ESTrainer
- class maze.train.trainers.es.es_trainer.ESTrainer(algorithm_config: ESAlgorithmConfig, torch_policy: TorchPolicy, shared_noise: SharedNoiseTable, normalization_stats: Dict[str, Tuple[numpy.ndarray, numpy.ndarray]] | None)
Trainer class for OpenAI Evolution Strategies.
- Parameters:
algorithm_config – Algorithm parameters.
torch_policy – Multi-step policy encapsulating the policy networks.
shared_noise – The noise table, with the same content for every worker and the master.
normalization_stats – Normalization statistics as calculated by the NormalizeObservationWrapper.
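The `normalization_stats` argument maps observation keys to a pair of arrays. The following NumPy sketch shows how such statistics could be applied to a dict observation. It assumes, for illustration only, that each pair holds a per-key (mean, std); `normalize_observation`, the `eps` guard and the example keys are hypothetical and not part of the Maze API.

```python
import numpy as np
from typing import Dict, Tuple

# Hypothetical alias: one (mean, std) pair per observation key.
# The (mean, std) ordering is an assumption made for this example.
NormalizationStats = Dict[str, Tuple[np.ndarray, np.ndarray]]


def normalize_observation(obs: Dict[str, np.ndarray],
                          stats: NormalizationStats,
                          eps: float = 1e-8) -> Dict[str, np.ndarray]:
    """Standardize every entry that has statistics; pass other entries through unchanged."""
    normalized = {}
    for key, value in obs.items():
        if key in stats:
            mean, std = stats[key]
            # eps avoids division by zero for constant features.
            normalized[key] = (value - mean) / (std + eps)
        else:
            normalized[key] = value
    return normalized


stats: NormalizationStats = {"observation": (np.array([1.0, 2.0]), np.array([2.0, 4.0]))}
obs = {"observation": np.array([3.0, 6.0]), "mask": np.array([1.0, 0.0])}
result = normalize_observation(obs, stats)
print(result)
```

Keys without statistics (here the hypothetical `"mask"`) are returned untouched, so only the observation entries that were actually measured get rescaled.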
- load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
- Parameters:
state_dict – The state dict.
- train(distributed_rollouts: ESDistributedRollouts, n_epochs: int | None = None, model_selection: ModelSelectionBase | None = None) → None
(overrides Trainer)
Run the ES training loop.
- Parameters:
distributed_rollouts – The distribution interface for experience collection.
n_epochs – Number of epochs to train.
model_selection – Optional model selection class; receives the model evaluation results.
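The steps of the ES training loop are not spelled out on this page. As a rough illustration of an OpenAI-style Evolution Strategies update and of why a noise table shared by all processes is useful, here is a minimal single-process NumPy sketch. It is not Maze's implementation: `ToyNoiseTable`, `centered_ranks`, `es_epoch`, the `evaluate` callback and all hyperparameter values are hypothetical; rollouts run locally rather than through `ESDistributedRollouts`; and the update uses plain gradient ascent, whereas the reference OpenAI code applies an Adam optimizer.

```python
import numpy as np


class ToyNoiseTable:
    """Hypothetical stand-in for a shared noise table.

    Every process that builds the table with the same seed and size holds
    identical noise, so a perturbation can be identified by an integer offset.
    """

    def __init__(self, size: int = 100_000, seed: int = 123):
        rng = np.random.default_rng(seed)
        self.noise = rng.standard_normal(size).astype(np.float32)

    def sample_offset(self, rng: np.random.Generator, dim: int) -> int:
        # Any start index that leaves room for a full slice of length dim.
        return int(rng.integers(0, len(self.noise) - dim + 1))

    def get(self, offset: int, dim: int) -> np.ndarray:
        return self.noise[offset:offset + dim]


def centered_ranks(x: np.ndarray) -> np.ndarray:
    """Fitness shaping: replace raw returns by their ranks scaled to [-0.5, 0.5]."""
    ranks = np.empty(len(x), dtype=np.float32)
    ranks[x.argsort()] = np.arange(len(x), dtype=np.float32)
    return ranks / (len(x) - 1) - 0.5


def es_epoch(theta, evaluate, table, rng, n_pairs=50, sigma=0.05, lr=0.02):
    """One ES update with mirrored (antithetic) perturbations and plain gradient ascent."""
    dim = theta.size
    offsets = [table.sample_offset(rng, dim) for _ in range(n_pairs)]
    eps = np.stack([table.get(o, dim) for o in offsets])  # shape (n_pairs, dim)
    # Evaluate theta + sigma * eps and theta - sigma * eps for every offset.
    returns = np.array([[evaluate(theta + sigma * e), evaluate(theta - sigma * e)] for e in eps])
    shaped = centered_ranks(returns.ravel()).reshape(returns.shape)
    weights = shaped[:, 0] - shaped[:, 1]
    # Search-gradient estimate: noise vectors weighted by their shaped return difference.
    grad = weights @ eps / (2 * n_pairs * sigma)
    return theta + lr * grad


# Usage: maximize -||theta - target||^2 starting from the origin.
rng = np.random.default_rng(0)
table = ToyNoiseTable()
target = np.ones(5)
theta = np.zeros(5)
for _ in range(300):
    theta = es_epoch(theta, lambda th: -np.sum((th - target) ** 2), table, rng)
print(np.round(theta, 2))
```

Because every process can rebuild the same table from the same seed, a worker in a distributed setup only needs to report an offset and two returns per perturbation instead of a full parameter-sized vector; this is presumably the motivation behind passing the same `shared_noise` table to every worker and the master.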