BaseDistributedWorkersWithBuffer

class maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.BaseDistributedWorkersWithBuffer(env_factory: Callable[[], Union[maze.core.env.structured_env.StructuredEnv, maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin, maze.core.log_stats.log_stats_env.LogStatsEnv]], worker_policy: maze.core.agent.torch_policy.TorchPolicy, n_rollout_steps: int, n_workers: int, batch_size: int, rollouts_per_iteration: int, initial_sampling_policy: Union[omegaconf.DictConfig, maze.core.agent.policy.Policy], replay_buffer_size: int, initial_buffer_size: int, split_rollouts_into_transitions: bool, env_instance_seeds: List[int], replay_buffer_seed: int)

The base class for all distributed workers with buffer.

Distributed workers run rollouts independently. Rollouts are collected by calling the collect_rollouts method and are then added to the buffer.

Parameters
  • env_factory – Factory function for envs to run rollouts on

  • worker_policy – Structured policy to sample actions from

  • n_rollout_steps – Number of rollout steps to record in one rollout

  • n_workers – Number of distributed workers to run simultaneously

  • batch_size – Size of the batch the rollouts are collected in

  • rollouts_per_iteration – The number of rollouts to collect each time the collect_rollouts method is called.

  • initial_sampling_policy – The policy used to sample trajectories from to fill the buffer initially (before training starts).

  • replay_buffer_size – The max size of the replay buffer.

  • initial_buffer_size – The initial size of the replay buffer, filled by sampling from the initial_sampling_policy before training starts.

  • split_rollouts_into_transitions – Specify whether all computed rollouts should be split into transitions before processing them

  • env_instance_seeds – A list of seeds, one for each worker's envs.

  • replay_buffer_seed – A seed for initializing and sampling from the replay buffer.
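
Because the class is abstract, it is only used through concrete subclasses that provide the worker-management hooks documented below. The following is a minimal, hedged sketch; the class name and the empty method bodies are placeholders for this example, not part of Maze:

    from typing import Dict, Tuple

    from maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer import (
        BaseDistributedWorkersWithBuffer,
    )


    class DummyDistributedWorkersWithBuffer(BaseDistributedWorkersWithBuffer):
        """Hypothetical subclass sketch: only the abstract worker-management hooks are filled in."""

        def start(self) -> None:
            """Launch the rollout worker processes (implementation specific)."""

        def stop(self) -> None:
            """Shut the rollout worker processes down again."""

        def broadcast_updated_policy(self, state_dict: Dict) -> None:
            """Send the updated policy weights to every worker."""

        def collect_rollouts(self) -> Tuple[float, float, float]:
            """Drain the rollout queue into the replay buffer and return queue statistics."""
            return 0.0, 0.0, 0.0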

abstract broadcast_updated_policy(state_dict: Dict) → None

Broadcast the newest version of the policy to the workers.

Parameters

state_dict – State of the new policy version to broadcast.
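
For example, a learner loop might push its latest weights as in the hedged sketch below; `workers` and `learner_policy` are placeholders, and it assumes the learner policy exposes a PyTorch-style state_dict():

    # `workers` is a concrete BaseDistributedWorkersWithBuffer instance and
    # `learner_policy` the TorchPolicy being trained (both placeholders here).
    workers.broadcast_updated_policy(learner_policy.state_dict())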

abstract collect_rollouts() → Tuple[float, float, float]

Collect worker outputs from the queue and add them to the buffer.

Returns

A tuple of (1) the queue size before dequeuing, (2) the queue size after dequeuing, and (3) the time it took to dequeue the outputs.
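
A hedged usage sketch of the returned statistics (all names are placeholders):

    # Pull the finished rollouts into the replay buffer and log queue pressure.
    before, after, dequeue_time = workers.collect_rollouts()
    print(f"rollout queue: {before} -> {after} (dequeue took {dequeue_time:.3f}s)")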

get_epoch_stats_aggregator() → maze.core.log_stats.log_stats.LogStatsAggregator

Return the collected epoch stats aggregator

get_stats_value(event: Callable, level: maze.core.log_stats.log_stats.LogStatsLevel, name: Optional[str] = None) → Union[int, float, numpy.ndarray, dict]

Obtain a single value from the epoch statistics dict.

Parameters
  • event – The event interface method of the value in question.

  • name – The output_name of the statistics in case it has been specified in maze.core.log_stats.event_decorators.define_epoch_stats()

  • level – Must be set to LogStatsLevel.EPOCH, step or episode statistics are not propagated.
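
Reading an epoch-level statistic could look like the sketch below; `my_events.my_metric` is a placeholder for an event interface method decorated with define_epoch_stats() in your own project:

    from maze.core.log_stats.log_stats import LogStatsLevel

    # Only EPOCH-level statistics are propagated; step/episode levels are not available here.
    value = workers.get_stats_value(my_events.my_metric, LogStatsLevel.EPOCH)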

init_replay_buffer(initial_sampling_policy: Union[omegaconf.DictConfig, maze.core.agent.policy.Policy], initial_buffer_size: int, replay_buffer_seed: int) → None

Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

Parameters
  • initial_sampling_policy – The initial sampling policy used to fill the buffer to the initial fill state.

  • initial_buffer_size – The initial size of the replay buffer filled by sampling from the initial sampling policy.

  • replay_buffer_seed – A seed for initializing and sampling from the replay buffer.
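
A hedged sketch of calling it directly; `random_policy` is a placeholder for any Policy instance or a DictConfig describing one:

    # Fill the buffer with 10_000 initial samples drawn with a placeholder policy.
    workers.init_replay_buffer(
        initial_sampling_policy=random_policy,
        initial_buffer_size=10_000,
        replay_buffer_seed=1234,
    )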

sample_batch(learner_device: str) → maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord

Sample a batch from the buffer and return it as a batched structured spaces record.

Parameters

learner_device – The device of the learner (cpu or cuda).

Returns

A structured spaces record object holding the batched rollouts.
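
A hedged sketch of drawing a training batch (the device string is an example value, `workers` a placeholder):

    # Sample one batch from the replay buffer for the learner.
    batch = workers.sample_batch(learner_device="cpu")
    # `batch` is a StructuredSpacesRecord and can be fed into the learner's update step.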

abstract start() → None

Start all distributed workers

abstract stop() → None

Stop all distributed workers
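
Putting the pieces together, the intended lifecycle looks roughly like the following sketch; all names, the device string, and the number of epochs are placeholders:

    workers.start()                       # launch the rollout workers
    try:
        for epoch in range(10):
            workers.collect_rollouts()    # move finished rollouts into the buffer
            batch = workers.sample_batch(learner_device="cpu")
            # ... run a learner update on `batch`, then push the new weights ...
            workers.broadcast_updated_policy(learner_policy.state_dict())
    finally:
        workers.stop()                    # always shut the workers down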