ActorCritic

class maze.train.trainers.common.actor_critic.actor_critic_trainer.ActorCritic(algorithm_config: Union[maze.train.trainers.a2c.a2c_algorithm_config.A2CAlgorithmConfig, maze.train.trainers.ppo.ppo_algorithm_config.PPOAlgorithmConfig, maze.train.trainers.impala.impala_algorithm_config.ImpalaAlgorithmConfig], rollout_generator: Union[maze.core.rollout.rollout_generator.RolloutGenerator, maze.train.parallelization.distributed_actors.distributed_actors.DistributedActors], evaluator: Optional[maze.train.trainers.common.evaluators.rollout_evaluator.RolloutEvaluator], model: maze.core.agent.torch_actor_critic.TorchActorCritic, model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection])

Base class for actor critic trainers. Suitable for multi-step and multi-agent training.

Parameters
  • algorithm_config – Algorithm parameters.

  • rollout_generator – The rollout generator to use. This object encapsulates the env.

  • evaluator – The evaluator to use.

  • model – Structured torch actor critic model.

  • model_selection – Optional model selection class, receives model evaluation results.

evaluate() → None

Perform evaluation on the evaluation environment.

load_state(file_path: Union[str, BinaryIO]) → None

(overrides Trainer)

Implementation of Trainer.

load_state_dict(state_dict: Dict) → None

Set the model and optimizer state.

Parameters
  • state_dict – The state dict.
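The idea of restoring model and optimizer state from a single dict can be sketched in plain Python. This is an illustration only, not Maze's implementation: ToyComponent, ToyTrainer, the state_dict() counterpart method, and the "model" / "optimizer" keys are all hypothetical names chosen for the example.

```python
import copy
from typing import Any, Dict


class ToyComponent:
    """Hypothetical stand-in for a model or optimizer that exposes a state dict."""

    def __init__(self, **state: Any) -> None:
        self._state = dict(state)

    def state_dict(self) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate internal state by accident.
        return copy.deepcopy(self._state)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._state = copy.deepcopy(state_dict)


class ToyTrainer:
    """Hypothetical trainer that bundles model and optimizer state in one dict."""

    def __init__(self, model: ToyComponent, optimizer: ToyComponent) -> None:
        self.model = model
        self.optimizer = optimizer

    def state_dict(self) -> Dict[str, Any]:
        # Assumed layout: one entry per component (key names are illustrative).
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore both components from the combined dict.
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])


# Round trip: copy the state of one trainer into a freshly initialised one.
source = ToyTrainer(ToyComponent(weights=[1.0, 2.0]), ToyComponent(lr=0.001, step=5))
target = ToyTrainer(ToyComponent(weights=[0.0, 0.0]), ToyComponent(lr=0.1, step=0))
target.load_state_dict(source.state_dict())
```

Restoring the optimizer together with the model matters when training is resumed, since optimizer state such as step counts would otherwise restart from scratch.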

train(n_epochs: Optional[int] = None) → None

(overrides Trainer)

Main train method of the actor critic trainer. It is kept separate so that the main train method called by the runner can perform algorithm-specific operations around it (for example, when multiprocessing is involved).

Parameters
  • n_epochs – Number of epochs to train.
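One way to read the note about algorithm-specific operations is as a wrapping pattern: a subclass overrides train, does its own setup, delegates to the base train, and then cleans up. The sketch below shows that pattern in plain Python under stated assumptions. BaseTrainer, WorkerBackedTrainer, and default_n_epochs are hypothetical names, and the fallback to a configured default when n_epochs is None is an assumption, not something the reference above states.

```python
from typing import List, Optional


class BaseTrainer:
    """Hypothetical base trainer owning the generic epoch loop."""

    def __init__(self, default_n_epochs: int) -> None:
        self.default_n_epochs = default_n_epochs
        self.log: List[str] = []

    def train(self, n_epochs: Optional[int] = None) -> None:
        # Assumption: None means "use the configured default".
        n = self.default_n_epochs if n_epochs is None else n_epochs
        for epoch in range(n):
            self.log.append(f"epoch {epoch}")


class WorkerBackedTrainer(BaseTrainer):
    """Hypothetical subclass that wraps algorithm-specific steps around train."""

    def train(self, n_epochs: Optional[int] = None) -> None:
        self.log.append("start workers")  # e.g. spawn rollout processes
        try:
            super().train(n_epochs)
        finally:
            # Clean up even if training raises.
            self.log.append("stop workers")


trainer = WorkerBackedTrainer(default_n_epochs=2)
trainer.train()
```

The try/finally keeps the teardown step paired with the setup step, which is the main reason to wrap the base method rather than duplicate the epoch loop in each subclass.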