SAC

class maze.train.trainers.sac.sac_trainer.SAC(algorithm_config: SACAlgorithmConfig, learner_model: TorchActorCritic, distributed_actors: BaseDistributedWorkersWithBuffer, model_selection: BestModelSelection | None, evaluator: RolloutEvaluator | None)

Multi-step soft actor-critic (SAC).

Parameters:
  • algorithm_config – Algorithm options.

  • learner_model – Structured torch actor critic to train.

  • distributed_actors – Distributed actors for collection of training rollouts.

  • model_selection – Optional model selection class, receives model evaluation results.

  • evaluator – Optional rollout evaluator used to evaluate the policy on the evaluation environment.
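The class docstring does not define what "multi-step" means here. One common reading is an n-step bootstrapped target. The sketch below shows that variant of the SAC soft Q target in plain Python. It is an assumption for illustration, not a description of how this class computes its losses. The function `n_step_soft_target` and its arguments are hypothetical, and a real implementation would operate on batched torch tensors. With a single step it reduces to the standard SAC target r + γ(min_i Q'_i(s', a') - α log π(a'|s')).

```python
from typing import Sequence


def n_step_soft_target(
    rewards: Sequence[float],
    dones: Sequence[bool],
    next_q_values: Sequence[float],
    next_log_prob: float,
    gamma: float = 0.99,
    alpha: float = 0.2,
) -> float:
    """Hypothetical n-step soft Q target for one starting transition.

    rewards[k] and dones[k] belong to step t+k (k = 0..n-1).
    next_q_values holds the target critics' estimates Q'_i(s_{t+n}, a') for an
    action a' freshly sampled from pi(.|s_{t+n}); next_log_prob is log pi(a'|s_{t+n}).
    """
    target = 0.0
    for k, (reward, done) in enumerate(zip(rewards, dones)):
        target += (gamma ** k) * reward
        if done:
            # The episode ended inside the n-step window, so there is nothing to bootstrap from.
            return target
    # Soft state value: clipped double-Q estimate minus the entropy penalty.
    soft_value = min(next_q_values) - alpha * next_log_prob
    return target + (gamma ** len(rewards)) * soft_value


# Two steps, no termination: 1 + 0.5 * 1 + 0.5**2 * (10 - 0.2 * -1) = 4.05
example = n_step_soft_target([1.0, 1.0], [False, False], [10.0, 12.0], -1.0, gamma=0.5, alpha=0.2)
```

Taking the minimum over the target critics is the clipped double-Q trick used by SAC to reduce overestimation. The entropy weight `alpha` is shown as a fixed number here, although SAC variants often tune it automatically.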

evaluate() → None

Perform evaluation on the evaluation environment. If a model selection instance was provided, it receives the evaluation results.
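A minimal sketch of how evaluation results can feed a model-selection hook is shown below. Everything in it is hypothetical: `run_episode(seed)` stands in for a rollout on the evaluation environment that returns one episode's total reward, and `BestScoreSelection` is an illustrative stand-in, not Maze's `BestModelSelection` API.

```python
from typing import Callable, List, Optional


def evaluate_policy(run_episode: Callable[[int], float], n_episodes: int) -> float:
    """Return the mean episode return over n_episodes evaluation rollouts."""
    if n_episodes <= 0:
        raise ValueError("n_episodes must be positive")
    returns = [run_episode(seed) for seed in range(n_episodes)]
    return sum(returns) / len(returns)


class BestScoreSelection:
    """Hypothetical model-selection hook that remembers the best score seen so far."""

    def __init__(self) -> None:
        self.best_score: Optional[float] = None
        self.history: List[float] = []

    def update(self, score: float) -> bool:
        """Record a score and return True if it is a new best (a checkpoint would be saved here)."""
        self.history.append(score)
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            return True
        return False


selection = BestScoreSelection()
# Dummy "rollouts" whose return equals the seed: mean of 0, 1, 2, 3 is 1.5.
mean_return = evaluate_policy(lambda seed: float(seed), n_episodes=4)
is_new_best = selection.update(mean_return)
```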

load_state(file_path: str | BinaryIO) → None

(overrides Trainer)

Implementation of Trainer: loads the trainer state from a file path or from an open binary file object.

load_state_dict(state_dict: Dict) → None

Set the model and optimizer state.

Parameters:

state_dict – The state dict to load, for example one previously returned by state_dict().

state_dict()

(overrides Trainer)

Implementation of Trainer: returns the current model and optimizer state as a dict.
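The state methods follow the usual save/restore pattern. The sketch below assumes a hypothetical trainer whose model and optimizer state are plain dicts, and it uses `pickle` as a stdlib stand-in for serialization. A torch-based trainer would more likely use `torch.save`/`torch.load`, which likewise accept either a path or a file-like object. `TrainerStateSketch` is illustrative only and is not the Maze implementation.

```python
import io
import pickle
from typing import BinaryIO, Dict, Union


class TrainerStateSketch:
    """Hypothetical trainer holding plain-dict 'model' and 'optimizer' states."""

    def __init__(self) -> None:
        self.model_state: Dict = {"weight": 0.0}
        self.optimizer_state: Dict = {"step": 0}

    def state_dict(self) -> Dict:
        # Return copies so later training steps do not mutate a saved snapshot.
        return {"model": dict(self.model_state), "optimizer": dict(self.optimizer_state)}

    def load_state_dict(self, state_dict: Dict) -> None:
        # Set the model and optimizer state from a previously saved dict.
        self.model_state = dict(state_dict["model"])
        self.optimizer_state = dict(state_dict["optimizer"])

    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        # Accept either a path on disk or an already opened binary stream.
        if isinstance(file_path, str):
            with open(file_path, "rb") as f:
                state = pickle.load(f)
        else:
            state = pickle.load(file_path)
        self.load_state_dict(state)


# Round trip through an in-memory binary stream.
src = TrainerStateSketch()
src.model_state["weight"] = 1.5
src.optimizer_state["step"] = 10
buffer = io.BytesIO()
pickle.dump(src.state_dict(), buffer)
buffer.seek(0)

dst = TrainerStateSketch()
dst.load_state(buffer)
```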

train(n_epochs: int | None = None) → None

(overrides Trainer)

Wraps the regular train function so that all processes are closed properly when training ends.

Parameters:

n_epochs – Number of epochs to train.
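Closing processes "properly" is typically done with `try`/`finally`, so cleanup runs whether training returns normally, raises, or is interrupted. The sketch below illustrates that pattern only. `DummyWorkers`, `TrainWrapperSketch`, `_train` and `default_epochs` are hypothetical names, and the fallback used when `n_epochs` is None is an assumption of this sketch, not documented behavior of the SAC class.

```python
from typing import List, Optional


class DummyWorkers:
    """Hypothetical stand-in for the distributed actors."""

    def __init__(self) -> None:
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


class TrainWrapperSketch:
    def __init__(self, workers: DummyWorkers, default_epochs: int = 3) -> None:
        self.workers = workers
        # Assumption for this sketch only: used when n_epochs is None.
        self.default_epochs = default_epochs
        self.completed_epochs: List[int] = []

    def _train(self, n_epochs: Optional[int]) -> None:
        # The "normal" training loop.
        epochs = self.default_epochs if n_epochs is None else n_epochs
        for epoch in range(epochs):
            self.completed_epochs.append(epoch)

    def train(self, n_epochs: Optional[int] = None) -> None:
        try:
            self._train(n_epochs)
        finally:
            # Runs on normal return, on exceptions, and on KeyboardInterrupt.
            self.workers.stop()


workers = DummyWorkers()
trainer = TrainWrapperSketch(workers)
trainer.train(n_epochs=2)
```

The exception itself still propagates to the caller after cleanup, so a failing run is not silently reported as finished.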