SACRunner

class maze.train.trainers.sac.sac_runners.SACRunner(state_dict_dump_file: str, dump_interval: Optional[int], spaces_config_dump_file: str, normalization_samples: int, eval_concurrency: int, initial_demonstration_trajectories: omegaconf.DictConfig)

Common superclass for SAC runners, implementing the main training controls.

abstract create_distributed_eval_env(env_factory: Callable[[], maze.core.env.maze_env.MazeEnv], eval_concurrency: int, logging_prefix: str) → maze.train.parallelization.vector_env.structured_vector_env.StructuredVectorEnv

The individual runners implement the setup of the distributed eval env.

Parameters
  • env_factory – Factory function for envs to run rollouts on.

  • eval_concurrency – The concurrency of the evaluation env.

  • logging_prefix – The logging prefix to use for the envs.

Returns

A vector env.
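
For illustration, a minimal sketch of how a concrete runner might satisfy this contract by stepping all evaluation envs sequentially in one process. The DevSACRunner name is hypothetical, and the SequentialVectorEnv constructor signature used here is an assumption, not confirmed by this page:

    from typing import Callable
    from maze.core.env.maze_env import MazeEnv
    from maze.train.trainers.sac.sac_runners import SACRunner
    from maze.train.parallelization.vector_env.sequential_vector_env import SequentialVectorEnv

    class DevSACRunner(SACRunner):  # hypothetical concrete runner
        # (a real subclass would implement create_distributed_rollout_workers as well)
        def create_distributed_eval_env(self,
                                        env_factory: Callable[[], MazeEnv],
                                        eval_concurrency: int,
                                        logging_prefix: str) -> SequentialVectorEnv:
            # Instantiate eval_concurrency copies of the env and run them
            # sequentially in the main process (constructor args are an assumption).
            return SequentialVectorEnv(
                [env_factory for _ in range(eval_concurrency)],
                logging_prefix=logging_prefix)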

abstract create_distributed_rollout_workers(env_factory: Callable[[], maze.core.env.maze_env.MazeEnv], worker_policy: maze.core.agent.torch_policy.TorchPolicy, n_rollout_steps: int, n_workers: int, batch_size: int, rollouts_per_iteration: int, split_rollouts_into_transitions: bool, env_instance_seeds: List[int], replay_buffer: maze.train.trainers.common.replay_buffer.replay_buffer.BaseReplayBuffer) → maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.BaseDistributedWorkersWithBuffer

The individual runners implement the setup of the distributed training rollout actors.

Parameters
  • env_factory – Factory function for envs to run rollouts on.

  • worker_policy – Structured policy to sample actions from.

  • n_rollout_steps – Number of rollout steps to record in one rollout.

  • n_workers – Number of distributed workers to run simultaneously.

  • batch_size – Size of the batch the rollouts are collected in.

  • rollouts_per_iteration – The number of rollouts to collect each time the collect_rollouts method is called.

  • split_rollouts_into_transitions – Specify whether to split rollouts into individual transitions.

  • env_instance_seeds – The seed for each worker's env.

  • replay_buffer – The replay buffer to use.

Returns

A BaseDistributedWorkersWithBuffer object.
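
As a hedged usage sketch, the surrounding training setup might invoke this factory roughly as follows. The keyword arguments mirror the documented signature; runner, policy, buffer, and env_factory are placeholder objects assumed to exist:

    # `runner`, `policy`, `buffer`, and `env_factory` are placeholders.
    workers = runner.create_distributed_rollout_workers(
        env_factory=env_factory,               # Callable[[], MazeEnv]
        worker_policy=policy,                  # TorchPolicy to sample actions from
        n_rollout_steps=100,
        n_workers=4,
        batch_size=64,
        rollouts_per_iteration=4,
        split_rollouts_into_transitions=True,  # store single transitions, not whole rollouts
        env_instance_seeds=[1234 + i for i in range(4)],
        replay_buffer=buffer,
    )
    workers.collect_rollouts()                 # fill the buffer with fresh experience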

eval_concurrency: int

Number of concurrent evaluation envs.

static init_replay_buffer(replay_buffer: maze.train.trainers.common.replay_buffer.replay_buffer.BaseReplayBuffer, initial_sampling_policy: Union[omegaconf.DictConfig, maze.core.agent.policy.Policy], initial_buffer_size: int, replay_buffer_seed: int, split_rollouts_into_transitions: bool, n_rollout_steps: int, env_factory: Callable[[], maze.core.env.maze_env.MazeEnv]) → None

Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

Parameters
  • replay_buffer – The replay buffer to use.

  • initial_sampling_policy – The initial sampling policy used to fill the buffer to the initial fill state.

  • initial_buffer_size – The initial size of the replay buffer filled by sampling from the initial sampling policy.

  • replay_buffer_seed – A seed for initializing and sampling from the replay buffer.

  • split_rollouts_into_transitions – Specify whether to split rollouts into individual transitions.

  • n_rollout_steps – Number of rollout steps to record in one rollout.

  • env_factory – Factory function for envs to run rollouts on.
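
A hedged sketch of pre-filling a buffer via this static method. The keyword arguments follow the signature above; UniformReplayBuffer, its constructor arguments, GymMazeEnv usage, and the sampling_policy placeholder are all assumptions:

    from maze.core.wrappers.maze_gym_env_wrapper import GymMazeEnv
    from maze.train.trainers.common.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
    from maze.train.trainers.sac.sac_runners import SACRunner

    buffer = UniformReplayBuffer(buffer_size=10000)   # constructor args are an assumption
    SACRunner.init_replay_buffer(
        replay_buffer=buffer,
        initial_sampling_policy=sampling_policy,      # placeholder Policy (or DictConfig)
        initial_buffer_size=1000,
        replay_buffer_seed=1234,
        split_rollouts_into_transitions=True,
        n_rollout_steps=100,
        env_factory=lambda: GymMazeEnv("CartPole-v0"),  # GymMazeEnv usage is an assumption
    )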

initial_demonstration_trajectories: omegaconf.DictConfig

Optionally, a trajectory, a list of trajectories, a directory, or a list of directories can be given to fill the replay buffer with. If this is not given (i.e., it is None), the initial replay buffer is filled by sampling from the initial_sampling_policy specified in the algorithm config.
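
For illustration only, such a config could be constructed programmatically with OmegaConf; the key name used below is a guess, not a documented schema:

    from omegaconf import OmegaConf

    # Hypothetical structure: point the runner at a directory of recorded
    # demonstration trajectories (key name is an assumption, not Maze's schema).
    initial_demonstration_trajectories = OmegaConf.create(
        {"input_data": "trajectories/my_experiment"})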

load_replay_buffer(replay_buffer: maze.train.trainers.common.replay_buffer.replay_buffer.BaseReplayBuffer, cfg: omegaconf.DictConfig) → None

Load the given trajectories as a dataset and fill the buffer with these trajectories.

Parameters
  • replay_buffer – The replay buffer to fill.

  • cfg – The dict config of the experiment.
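
A short hedged sketch of the call; presumably this runs during setup() when initial_demonstration_trajectories is set. The runner and buffer objects are placeholders, and the config path is hypothetical:

    from omegaconf import OmegaConf

    # `runner` and `buffer` are placeholders; cfg would normally be composed
    # by Hydra from the experiment's yaml files rather than loaded directly.
    cfg = OmegaConf.load("conf/my_experiment.yaml")  # hypothetical path
    runner.load_replay_buffer(replay_buffer=buffer, cfg=cfg)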

setup(cfg: omegaconf.DictConfig) → None

(overrides TrainingRunner)

See TrainingRunner.setup().
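
Putting it together, a minimal sketch of the runner lifecycle, assuming the usual TrainingRunner setup()/run() split; run() is not documented in this section and is an assumption, and DevSACRunner is the hypothetical subclass sketched above:

    # `cfg` is normally composed by Hydra; constructor arguments follow the
    # documented class signature, with placeholder values.
    runner = DevSACRunner(
        state_dict_dump_file="state_dict.pt",
        dump_interval=50,
        spaces_config_dump_file="spaces_config.pkl",
        normalization_samples=1000,
        eval_concurrency=2,
        initial_demonstration_trajectories=None,
    )
    runner.setup(cfg)  # builds envs, policy, replay buffer, and trainer
    runner.run()       # starts the main SAC training loop (assumed TrainingRunner API)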