ActorCritic
class maze.train.trainers.common.actor_critic.actor_critic_trainer.ActorCritic(algorithm_config: Union[maze.train.trainers.a2c.a2c_algorithm_config.A2CAlgorithmConfig, maze.train.trainers.ppo.ppo_algorithm_config.PPOAlgorithmConfig, maze.train.trainers.impala.impala_algorithm_config.ImpalaAlgorithmConfig], rollout_generator: Union[maze.core.rollout.rollout_generator.RolloutGenerator, maze.train.parallelization.distributed_actors.distributed_actors.DistributedActors], evaluator: Optional[maze.train.trainers.common.evaluators.rollout_evaluator.RolloutEvaluator], model: maze.core.agent.torch_actor_critic.TorchActorCritic, model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection])
Base class for actor-critic trainers. Suitable for multi-step and multi-agent training.
Parameters
algorithm_config – Algorithm parameters (an A2C, PPO or IMPALA algorithm config).
rollout_generator – The rollout generator to use (a RolloutGenerator or DistributedActors). This object encapsulates the environment.
evaluator – The evaluator to use; may be None.
model – Structured Torch actor-critic model.
model_selection – Optional model selection class; receives the model evaluation results.
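The relationship between the two optional arguments (the evaluator produces a score, and the model selection, if present, receives it) can be illustrated with a small pure-Python sketch. Everything in it is hypothetical: `ActorCriticSketch`, `BestModelSelectionSketch` and its `update` method are stand-ins invented for illustration, not the Maze classes or their actual API, and the rollout generator and model are reduced to plain placeholders.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class BestModelSelectionSketch:
    """Hypothetical stand-in for a best-model selection: tracks the best score seen so far."""
    best_score: float = float("-inf")
    history: List[float] = field(default_factory=list)

    def update(self, score: float) -> bool:
        """Record an evaluation score; return True if it is a new best."""
        self.history.append(score)
        if score > self.best_score:
            self.best_score = score
            return True
        return False


@dataclass
class ActorCriticSketch:
    """Hypothetical skeleton mirroring the ActorCritic constructor arguments."""
    algorithm_config: dict                          # stand-in for an A2C/PPO/IMPALA config
    rollout_generator: Callable[[], List[float]]    # stand-in: returns step rewards of one rollout
    evaluator: Optional[Callable[[], float]]        # stand-in: returns a mean evaluation reward
    model: dict                                     # stand-in for the TorchActorCritic
    model_selection: Optional[BestModelSelectionSketch] = None

    def evaluate(self) -> Optional[float]:
        # Without an evaluator there is nothing to evaluate (and nothing to select).
        if self.evaluator is None:
            return None
        score = self.evaluator()
        # The model selection, if configured, receives the evaluation result.
        if self.model_selection is not None:
            self.model_selection.update(score)
        return score


# Usage: three evaluations; the selection keeps the best score.
scores = iter([1.0, 3.0, 2.0])
selection = BestModelSelectionSketch()
trainer = ActorCriticSketch(
    algorithm_config={"lr": 3e-4},
    rollout_generator=lambda: [0.0, 1.0],
    evaluator=lambda: next(scores),
    model={},
    model_selection=selection,
)
for _ in range(3):
    trainer.evaluate()
print(selection.best_score)  # 3.0
```

Keeping both arguments optional lets the same trainer run without any evaluation overhead, while still supporting best-model checkpointing when an evaluator is supplied.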
load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
- param state_dict
The state dict.
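Restoring the optimizer state alongside the model matters when resuming training: optimizers such as Adam keep running statistics that would otherwise be reset. The sketch below shows the idea with plain dicts so it runs without PyTorch; the class name and the `"model"` / `"optimizer"` keys are assumptions for illustration, not the layout Maze actually uses. In a PyTorch-based trainer, the two assignments would typically become calls to `torch.nn.Module.load_state_dict` and `torch.optim.Optimizer.load_state_dict`.

```python
import copy
from typing import Any, Dict


class TrainerStateSketch:
    """Hypothetical trainer holding model and optimizer state as plain dicts."""

    def __init__(self) -> None:
        self.model_state: Dict[str, Any] = {"weights": [0.0, 0.0]}
        self.optimizer_state: Dict[str, Any] = {"step": 0}

    def state_dict(self) -> Dict[str, Any]:
        # Bundle both parts under assumed keys "model" and "optimizer".
        return {
            "model": copy.deepcopy(self.model_state),
            "optimizer": copy.deepcopy(self.optimizer_state),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore model and optimizer together so training resumes consistently.
        # Deep copies keep the trainer independent of the checkpoint object.
        self.model_state = copy.deepcopy(state_dict["model"])
        self.optimizer_state = copy.deepcopy(state_dict["optimizer"])


# Usage: checkpoint a "trained" instance and restore it into a fresh one.
trained = TrainerStateSketch()
trained.model_state["weights"] = [0.5, -0.5]
trained.optimizer_state["step"] = 100
checkpoint = trained.state_dict()

fresh = TrainerStateSketch()
fresh.load_state_dict(checkpoint)
print(fresh.optimizer_state["step"])  # 100
```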
train(n_epochs: Optional[int] = None) → None (overrides Trainer)
Main train method of the actor-critic trainer. Algorithm-specific trainers can wrap additional operations around this method inside the main train method that the runner calls (e.g., when multiprocessing is involved).
- param n_epochs
Number of epochs to train.
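The "operations around this method" pattern can be sketched as a subclass that overrides `train`, does its own setup and teardown, and delegates the epoch loop to the base class. This is a minimal illustration under stated assumptions: the class names and the `_train_epoch` hook are hypothetical, and treating `n_epochs=None` as "fall back to a configured default" is an assumption made for the sketch rather than documented behavior.

```python
from typing import List, Optional


class TrainerSketch:
    """Hypothetical base trainer whose train() the runner calls."""

    def __init__(self, default_epochs: int) -> None:
        self.default_epochs = default_epochs  # assumed fallback when n_epochs is None
        self.log: List[str] = []

    def train(self, n_epochs: Optional[int] = None) -> None:
        # Assumption: None means "use the configured number of epochs".
        epochs = self.default_epochs if n_epochs is None else n_epochs
        for epoch in range(epochs):
            self._train_epoch(epoch)

    def _train_epoch(self, epoch: int) -> None:
        # Placeholder for rollouts plus actor and critic updates.
        self.log.append(f"epoch {epoch}")


class MultiprocessingTrainerSketch(TrainerSketch):
    """Hypothetical subclass wrapping algorithm-specific work around the main loop."""

    def train(self, n_epochs: Optional[int] = None) -> None:
        self.log.append("start workers")  # e.g. launch distributed actors
        try:
            super().train(n_epochs)
        finally:
            self.log.append("stop workers")  # shut workers down even if training fails


# Usage: rely on the default epoch count.
trainer = MultiprocessingTrainerSketch(default_epochs=2)
trainer.train()
print(trainer.log)  # ['start workers', 'epoch 0', 'epoch 1', 'stop workers']
```

The `try`/`finally` guard reflects a common design choice for this kind of wrapper: worker processes are cleaned up even when an epoch raises.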