SAC

class maze.train.trainers.sac.sac_trainer.SAC(algorithm_config: maze.train.trainers.sac.sac_algorithm_config.SACAlgorithmConfig, learner_model: maze.core.agent.torch_actor_critic.TorchActorCritic, distributed_actors: maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.BaseDistributedWorkersWithBuffer, model_selection: Optional[maze.train.trainers.common.model_selection.best_model_selection.BestModelSelection], evaluator: Optional[maze.train.trainers.common.evaluators.rollout_evaluator.RolloutEvaluator])

Multi-step soft actor-critic (SAC) trainer.

Parameters
  • algorithm_config – Algorithm options.

  • learner_model – Structured torch actor critic to train.

  • distributed_actors – Distributed actors for collection of training rollouts.

  • model_selection – Optional model selection class, receives model evaluation results.

  • evaluator – The evaluator to use.

evaluate() → None

Run an evaluation on the evaluation environment.

load_state(file_path: Union[str, BinaryIO]) → None

(overrides Trainer)

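Implementation of the Trainer interface. Loads the trainer state from file_path, which may be a path string or an open binary file object.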
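The Union[str, BinaryIO] argument type means the caller can pass either a path or an already opened binary stream. The following is a minimal sketch of that pattern only, not the Maze implementation: read_state is a hypothetical helper, and pickle is an assumed stand-in for whatever serialization the trainer actually uses.

```python
import io
import pickle
from typing import Any, BinaryIO, Dict, Union


def read_state(file_path: Union[str, BinaryIO]) -> Dict[str, Any]:
    """Hypothetical helper: read a pickled state dict from a path or a binary stream."""
    if isinstance(file_path, str):
        # A string is treated as a path on disk and opened in binary mode.
        with open(file_path, "rb") as handle:
            return pickle.load(handle)
    # Anything else is assumed to be an open, readable binary stream.
    return pickle.load(file_path)


# Usage with an in-memory stream instead of a file on disk.
buffer = io.BytesIO()
pickle.dump({"epoch": 3}, buffer)
buffer.seek(0)
state = read_state(buffer)
```

Accepting both forms lets the same call work for checkpoints on disk and for in-memory buffers, for example in tests.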

load_state_dict(state_dict: Dict) → None

Set the model and optimizer state.

Parameters

state_dict – The state dict.

train(n_epochs: Optional[int] = None) → None

(overrides Trainer)

Wrapper around the regular train function that ensures all processes are closed properly.

Parameters

n_epochs – Number of epochs to train.
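The wrapping described for train() can be illustrated with a try/finally block, which runs the cleanup even if training raises. This is a minimal sketch under stated assumptions, not the Maze source: DummyWorkers, SketchTrainer, _train and stop are hypothetical names, and how a value of None for n_epochs is resolved is not specified here, so the sketch simply treats it as zero epochs.

```python
from typing import List, Optional


class DummyWorkers:
    """Hypothetical stand-in for a pool of rollout worker processes."""

    def __init__(self) -> None:
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


class SketchTrainer:
    """Hypothetical trainer showing only the wrap-and-clean-up pattern."""

    def __init__(self, workers: DummyWorkers) -> None:
        self.workers = workers
        self.epochs_run: List[int] = []

    def _train(self, n_epochs: Optional[int]) -> None:
        # Assumption: None means zero epochs in this sketch.
        for epoch in range(n_epochs or 0):
            if epoch == 2:
                raise RuntimeError("simulated failure")
            self.epochs_run.append(epoch)

    def train(self, n_epochs: Optional[int] = None) -> None:
        try:
            self._train(n_epochs)
        finally:
            # Runs on normal completion and on exceptions alike.
            self.workers.stop()


workers = DummyWorkers()
trainer = SketchTrainer(workers)
try:
    trainer.train(n_epochs=5)
except RuntimeError:
    pass  # The failure still propagates to the caller after cleanup.
```

After the simulated failure, the workers have been stopped and the exception has still reached the caller, which is the behavior a cleanup wrapper is meant to guarantee.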