ESTrainer
-
class maze.train.trainers.es.es_trainer.ESTrainer(algorithm_config: maze.train.trainers.es.es_algorithm_config.ESAlgorithmConfig, torch_policy: maze.core.agent.torch_policy.TorchPolicy, shared_noise: maze.train.trainers.es.es_shared_noise_table.SharedNoiseTable, normalization_stats: Optional[Dict[str, Tuple[numpy.ndarray, numpy.ndarray]]])
Trainer class for OpenAI Evolution Strategies.
- Parameters
algorithm_config – Algorithm parameters.
torch_policy – Multi-step policy encapsulating the policy networks.
shared_noise – The noise table, with the same content for every worker and the master.
normalization_stats – Normalization statistics as calculated by the NormalizeObservationWrapper.
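The shared_noise parameter relies on every process holding identical noise. A minimal sketch of that idea follows, assuming the table serves the same purpose as in the original OpenAI ES implementation: each process builds the same table from a common seed, so a perturbation can be communicated as an integer offset instead of a full vector. The class SharedNoiseTableSketch, its methods, the seed, and the table size are hypothetical and do not describe Maze's SharedNoiseTable API.

```python
import numpy as np


class SharedNoiseTableSketch:
    """Hypothetical stand-in for a shared noise table (not Maze's SharedNoiseTable).

    Every process builds the table from the same seed, so all copies hold
    identical noise and a perturbation can be identified by a single offset.
    """

    def __init__(self, count: int, seed: int = 123):
        rng = np.random.default_rng(seed)
        self.noise = rng.standard_normal(count).astype(np.float32)

    def get(self, offset: int, dim: int) -> np.ndarray:
        # Slice of length `dim` starting at `offset` (a view, no copy).
        return self.noise[offset:offset + dim]

    def sample_index(self, rng: np.random.Generator, dim: int) -> int:
        # Any offset that leaves room for `dim` consecutive entries.
        return int(rng.integers(0, len(self.noise) - dim + 1))


# The "master" and a "worker" build their tables independently from the same seed.
master = SharedNoiseTableSketch(count=10_000, seed=42)
worker = SharedNoiseTableSketch(count=10_000, seed=42)

dim = 5  # number of policy parameters in this toy example
offset = worker.sample_index(np.random.default_rng(0), dim)
# The worker only needs to send `offset`; the master recovers the same perturbation.
recovered = master.get(offset, dim)
```

Because the table is generated rather than transmitted, its memory cost is paid once per process, while per-sample communication shrinks to one integer.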
-
load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
- Parameters
state_dict – The state dict.
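The pattern behind load_state_dict, restoring model and optimizer state together, is sketched below with a hypothetical ToyESTrainer. The dictionary keys ("model", "optimizer") and the Adam-style moment fields are assumptions chosen for illustration; they do not describe the layout of the state dict that ESTrainer actually expects.

```python
from typing import Dict

import numpy as np


class ToyESTrainer:
    """Hypothetical trainer used only to illustrate the state-dict pattern."""

    def __init__(self, dim: int):
        self.params = np.zeros(dim)  # stands in for the policy weights
        self.adam_m = np.zeros(dim)  # optimizer first-moment estimate
        self.adam_v = np.zeros(dim)  # optimizer second-moment estimate
        self.adam_t = 0              # optimizer step counter

    def state_dict(self) -> Dict:
        # Copies decouple the snapshot from later in-place updates.
        return {
            "model": {"params": self.params.copy()},
            "optimizer": {"m": self.adam_m.copy(), "v": self.adam_v.copy(), "t": self.adam_t},
        }

    def load_state_dict(self, state_dict: Dict) -> None:
        # Restore both model and optimizer so training resumes where it stopped.
        self.params = state_dict["model"]["params"].copy()
        opt = state_dict["optimizer"]
        self.adam_m, self.adam_v, self.adam_t = opt["m"].copy(), opt["v"].copy(), opt["t"]


trained = ToyESTrainer(dim=3)
trained.params[:] = [0.5, -1.0, 2.0]
trained.adam_t = 7
snapshot = trained.state_dict()

resumed = ToyESTrainer(dim=3)
resumed.load_state_dict(snapshot)
```

Restoring the optimizer moments along with the weights matters because an adaptive optimizer would otherwise restart from zero moments after a checkpoint is loaded.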
-
train(distributed_rollouts: maze.train.trainers.es.distributed.es_distributed_rollouts.ESDistributedRollouts, n_epochs: Optional[int] = None, model_selection: Optional[maze.train.trainers.common.model_selection.model_selection_base.ModelSelectionBase] = None) → None
(overrides Trainer) Run the ES training loop.
- Parameters
distributed_rollouts – The distribution interface for experience collection.
n_epochs – Number of epochs to train.
model_selection – Optional model selection class that receives model evaluation results.
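A single-process sketch of what one ES epoch typically involves: sample perturbations from a noise table, evaluate mirrored (antithetic) pairs, apply centered-rank fitness shaping, and step along the estimated gradient (1 / (2nσ)) Σ (F(θ + σε_i) − F(θ − σε_i)) ε_i, with the raw returns F replaced by shaped ranks. The function es_train, its hyperparameters, and the on_eval hook (a stand-in for model_selection) are hypothetical. In ESTrainer the evaluations would be dispatched through distributed_rollouts rather than run in a local loop, and its update rule may differ; this sketch uses a plain gradient-ascent step.

```python
from typing import Callable, Optional

import numpy as np


def centered_ranks(x: np.ndarray) -> np.ndarray:
    """Map fitness values to ranks scaled into [-0.5, 0.5] (fitness shaping)."""
    ranks = np.empty(len(x), dtype=np.float64)
    ranks[np.argsort(x)] = np.arange(len(x))
    return ranks / (len(x) - 1) - 0.5


def es_train(fitness: Callable[[np.ndarray], float], theta: np.ndarray,
             n_epochs: int = 200, n_pairs: int = 20, sigma: float = 0.1,
             lr: float = 0.02, seed: int = 0,
             on_eval: Optional[Callable[[int, float], None]] = None) -> np.ndarray:
    rng = np.random.default_rng(seed)
    # Noise table: in a distributed setup every process would build it from the
    # same seed, and workers would report only the offsets they used.
    noise = np.random.default_rng(seed + 1).standard_normal(100_000)
    dim = theta.size
    for epoch in range(n_epochs):
        offsets = rng.integers(0, noise.size - dim + 1, size=n_pairs)
        eps = np.stack([noise[o:o + dim] for o in offsets])  # shape (n_pairs, dim)
        # Antithetic sampling: evaluate theta + sigma*eps and theta - sigma*eps.
        f_pos = np.array([fitness(theta + sigma * e) for e in eps])
        f_neg = np.array([fitness(theta - sigma * e) for e in eps])
        shaped = centered_ranks(np.concatenate([f_pos, f_neg]))
        weights = shaped[:n_pairs] - shaped[n_pairs:]
        # Gradient estimate of the Gaussian-smoothed objective, then an ascent step.
        grad = weights @ eps / (2 * n_pairs * sigma)
        theta = theta + lr * grad
        if on_eval is not None:
            on_eval(epoch, fitness(theta))  # hypothetical model-selection hook
    return theta


target = np.array([1.0, -2.0, 0.5])
history = []
theta = es_train(lambda p: -float(np.sum((p - target) ** 2)), np.zeros(3),
                 on_eval=lambda epoch, score: history.append((epoch, score)))
```

Rank-based shaping makes the update insensitive to the scale of the returns and robust to outlier episodes, which is why ES implementations commonly use it instead of raw rewards.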