class maze.train.trainers.sac.sac_runners.SACRunner(state_dict_dump_file: str, dump_interval: Optional[int], spaces_config_dump_file: str, normalization_samples: int, eval_concurrency: int)

Common superclass for SAC runners, implementing the main training controls.

abstract create_distributed_eval_env(env_factory: Callable[[], maze.core.env.maze_env.MazeEnv], eval_concurrency: int, logging_prefix: str) → maze.train.parallelization.vector_env.structured_vector_env.StructuredVectorEnv

The individual runners implement the setup of the distributed evaluation env.

  • env_factory – Factory function for envs to run rollouts on.

  • eval_concurrency – The concurrency of the evaluation env.

  • logging_prefix – The logging prefix to use for the envs.


A vector env.
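The abstract method follows a factory pattern: the shared runner class leaves it to each concrete runner to decide how the evaluation envs are created and distributed. The self-contained sketch below illustrates that split. `ToySACRunner`, `ToyLocalSACRunner`, `ToySequentialVectorEnv` and `CountingEnv` are hypothetical stand-ins, not Maze classes. The local variant simply builds `eval_concurrency` env instances in the current process and steps them one after another.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Optional


class CountingEnv:
    """Hypothetical toy env: the observation is the number of steps since reset."""

    def __init__(self) -> None:
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> int:
        self.t += 1
        return self.t


class ToySequentialVectorEnv:
    """Hypothetical vector env stand-in: steps its sub-envs sequentially in one process."""

    def __init__(self, env_factories: List[Callable[[], CountingEnv]],
                 logging_prefix: Optional[str] = None) -> None:
        # Each factory call creates a fresh, independent env instance.
        self.envs = [factory() for factory in env_factories]
        self.logging_prefix = logging_prefix
        self.n_envs = len(self.envs)

    def reset(self) -> List[int]:
        return [env.reset() for env in self.envs]

    def step(self, actions: List[int]) -> List[int]:
        return [env.step(a) for env, a in zip(self.envs, actions)]


class ToySACRunner(ABC):
    """Hypothetical mirror of the abstract eval-env hook of a runner base class."""

    def __init__(self, eval_concurrency: int) -> None:
        self.eval_concurrency = eval_concurrency

    @abstractmethod
    def create_distributed_eval_env(self, env_factory: Callable[[], CountingEnv],
                                    eval_concurrency: int, logging_prefix: str) -> ToySequentialVectorEnv:
        ...


class ToyLocalSACRunner(ToySACRunner):
    """Concrete runner that keeps all eval envs local (no real distribution)."""

    def create_distributed_eval_env(self, env_factory, eval_concurrency, logging_prefix):
        return ToySequentialVectorEnv([env_factory] * eval_concurrency, logging_prefix=logging_prefix)


runner = ToyLocalSACRunner(eval_concurrency=4)
eval_env = runner.create_distributed_eval_env(CountingEnv, runner.eval_concurrency, "eval")
print(eval_env.n_envs, eval_env.reset(), eval_env.step([0, 0, 0, 0]))  # → 4 [0, 0, 0, 0] [1, 1, 1, 1]
```

A runner that really distributes evaluation would return a vector env backed by subprocesses or remote actors instead; the calling code stays the same because it only depends on the returned vector env.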

abstract create_distributed_rollout_workers(env_factory: Callable[[], maze.core.env.maze_env.MazeEnv], worker_policy: maze.core.agent.torch_policy.TorchPolicy, n_rollout_steps: int, n_workers: int, batch_size: int, replay_buffer_size: int, initial_buffer_size: int, initial_sampling_policy: maze.core.agent.policy.Policy, rollouts_per_iteration: int, split_rollouts_into_transitions: bool, env_instance_seeds: List[int], replay_buffer_seed: int) → maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.BaseDistributedWorkersWithBuffer

The individual runners implement the setup of the distributed training rollout actors.

  • env_factory – Factory function for envs to run rollouts on.

  • worker_policy – Structured policy to sample actions from.

  • n_rollout_steps – Number of rollout steps to record in one rollout.

  • n_workers – Number of distributed workers to run simultaneously.

  • batch_size – Size of the batch the rollouts are collected in.

  • replay_buffer_size – The total size of the replay buffer.

  • initial_buffer_size – The initial size of the replay buffer, filled by sampling from the initial_sampling_policy.

  • initial_sampling_policy – Initial sampling policy used to fill the buffer with initial_buffer_size initial samples.

  • rollouts_per_iteration – The number of rollouts to collect each time the collect_rollouts method is called.

  • split_rollouts_into_transitions – Specify whether to split rollouts into individual transitions.

  • env_instance_seeds – The seed for each worker's env.

  • replay_buffer_seed – The seed for the replay buffer.


A BaseDistributedWorkersWithBuffer object.
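Several of these parameters interact: the replay buffer is pre-filled with initial_buffer_size entries drawn from initial_sampling_policy, every worker env gets its own seed from env_instance_seeds, the buffer samples with its own replay_buffer_seed, and rollouts of n_rollout_steps are either stored whole or split into single transitions. The self-contained sketch below shows one plausible way these pieces fit together. `ToyReplayBuffer`, `random_policy`, `collect_rollout`, `fill_buffer`, the toy counting env and the seed-derivation scheme are all hypothetical; they do not reproduce how BaseDistributedWorkersWithBuffer works internally.

```python
import random
from collections import deque
from typing import Any, Callable, Deque, List, Tuple

# (observation, action, reward, next_observation) in a toy counting env
Transition = Tuple[int, int, float, int]


class ToyReplayBuffer:
    """Hypothetical FIFO replay buffer: fixed capacity, own seeded RNG for sampling."""

    def __init__(self, replay_buffer_size: int, seed: int) -> None:
        self.data: Deque[Any] = deque(maxlen=replay_buffer_size)  # evicts oldest items when full
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.data)

    def add(self, item: Any) -> None:
        self.data.append(item)

    def sample(self, batch_size: int) -> List[Any]:
        # Uniform sampling without replacement, reproducible via the buffer seed.
        return self.rng.sample(list(self.data), batch_size)


def random_policy(rng: random.Random) -> int:
    """Hypothetical initial sampling policy: uniform over two discrete actions."""
    return rng.randrange(2)


def collect_rollout(policy: Callable[[random.Random], int], rng: random.Random,
                    n_rollout_steps: int) -> List[Transition]:
    """Records n_rollout_steps transitions; observation = step count, reward 1.0 for action 1."""
    rollout: List[Transition] = []
    for obs in range(n_rollout_steps):
        action = policy(rng)
        rollout.append((obs, action, float(action), obs + 1))
    return rollout


def fill_buffer(buffer: ToyReplayBuffer, worker_rngs: List[random.Random], n_rollout_steps: int,
                initial_buffer_size: int, split_rollouts_into_transitions: bool) -> None:
    """Collects rollouts round-robin over workers until the buffer holds initial_buffer_size items."""
    i = 0
    while len(buffer) < initial_buffer_size:
        rng = worker_rngs[i % len(worker_rngs)]  # each worker keeps its own seeded RNG
        rollout = collect_rollout(random_policy, rng, n_rollout_steps)
        # Either store every transition separately or the whole rollout as one item.
        items = rollout if split_rollouts_into_transitions else [rollout]
        for item in items:
            if len(buffer) < initial_buffer_size:
                buffer.add(item)
        i += 1


# Derive one env seed per worker plus a separate replay buffer seed from a base seed.
base_rng = random.Random(1234)
n_workers, n_rollout_steps = 3, 5
env_instance_seeds = [base_rng.randrange(2**31) for _ in range(n_workers)]
replay_buffer_seed = base_rng.randrange(2**31)

buffer = ToyReplayBuffer(replay_buffer_size=100, seed=replay_buffer_seed)
worker_rngs = [random.Random(seed) for seed in env_instance_seeds]
fill_buffer(buffer, worker_rngs, n_rollout_steps, initial_buffer_size=12, split_rollouts_into_transitions=True)
batch = buffer.sample(batch_size=4)
print(len(buffer), len(batch))  # → 12 4
```

Keeping the env seeds and the buffer seed separate means the collected data and the order in which it is sampled can each be reproduced independently.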

eval_concurrency: int

Number of concurrent evaluation envs.

setup(cfg: omegaconf.DictConfig) → None

(overrides TrainingRunner)

See TrainingRunner.setup().
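Since setup() receives the whole configuration while the two abstract factory methods receive plain arguments, one natural way to structure such a runner is as a template method: the shared class reads the config and calls the hooks that each concrete runner implements. The sketch below is hypothetical and is not the actual SACRunner.setup() logic. A plain dict stands in for omegaconf.DictConfig, the config keys are invented, and the hook signatures are trimmed to a few of the documented parameters.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List


class ToyRunnerBase(ABC):
    """Hypothetical base: setup() reads the config, subclasses decide how envs and workers are built."""

    def setup(self, cfg: Dict[str, Any]) -> None:
        # Invented keys standing in for fields of an omegaconf.DictConfig.
        self.eval_env = self.create_distributed_eval_env(
            env_factory=cfg["env_factory"],
            eval_concurrency=cfg["eval_concurrency"],
            logging_prefix="eval")
        self.workers = self.create_distributed_rollout_workers(
            env_factory=cfg["env_factory"],
            n_workers=cfg["n_workers"])

    @abstractmethod
    def create_distributed_eval_env(self, env_factory: Callable[[], Any],
                                    eval_concurrency: int, logging_prefix: str) -> List[Any]:
        ...

    @abstractmethod
    def create_distributed_rollout_workers(self, env_factory: Callable[[], Any],
                                           n_workers: int) -> List[Any]:
        ...


class ToyLocalRunner(ToyRunnerBase):
    """Hypothetical local variant: the 'distributed' objects are plain in-process lists of envs."""

    def create_distributed_eval_env(self, env_factory, eval_concurrency, logging_prefix):
        return [env_factory() for _ in range(eval_concurrency)]

    def create_distributed_rollout_workers(self, env_factory, n_workers):
        return [env_factory() for _ in range(n_workers)]


runner = ToyLocalRunner()
runner.setup({"env_factory": lambda: {"name": "toy-env"}, "eval_concurrency": 2, "n_workers": 3})
print(len(runner.eval_env), len(runner.workers))  # → 2 3
```

With this split, a new execution backend only has to implement the two factory methods; the config handling in setup() is shared.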