BaseDistributedWorkersWithBuffer

class maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.BaseDistributedWorkersWithBuffer(env_factory: Callable[[], Union[maze.core.env.structured_env.StructuredEnv, maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin, maze.core.log_stats.log_stats_env.LogStatsEnv]], worker_policy: maze.core.agent.torch_policy.TorchPolicy, n_rollout_steps: int, n_workers: int, batch_size: int, rollouts_per_iteration: int, split_rollouts_into_transitions: bool, env_instance_seeds: List[int], replay_buffer: maze.train.trainers.common.replay_buffer.replay_buffer.BaseReplayBuffer)

The base class for all distributed workers that store their collected rollouts in a replay buffer.

Distributed workers run rollouts independently. Rollouts are collected by calling the collect_rollouts method and are then added to the buffer.
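The abstract contract of this class can be illustrated with a minimal sketch built on Python's abc module. WorkersWithBufferSketch and NoOpWorkers are hypothetical names that only mirror the abstract methods documented on this page; they are not part of Maze, and NoOpWorkers does no real work.

```python
from abc import ABC, abstractmethod
from typing import Dict, Tuple


class WorkersWithBufferSketch(ABC):
    """Hypothetical stand-in mirroring the abstract methods documented here."""

    @abstractmethod
    def start(self) -> None:
        """Start all distributed workers."""

    @abstractmethod
    def stop(self) -> None:
        """Stop all distributed workers."""

    @abstractmethod
    def broadcast_updated_policy(self, state_dict: Dict) -> None:
        """Send the newest policy parameters to every worker."""

    @abstractmethod
    def collect_rollouts(self) -> Tuple[float, float, float]:
        """Move finished rollouts from the workers into the buffer."""


class NoOpWorkers(WorkersWithBufferSketch):
    """Trivial subclass: implements every abstract method without real workers."""

    def __init__(self) -> None:
        self.running = False
        self.policy_state: Dict = {}

    def start(self) -> None:
        self.running = True

    def stop(self) -> None:
        self.running = False

    def broadcast_updated_policy(self, state_dict: Dict) -> None:
        # Keep a shallow copy of the latest parameters.
        self.policy_state = dict(state_dict)

    def collect_rollouts(self) -> Tuple[float, float, float]:
        # No queue exists here, so report zero sizes and zero dequeue time.
        return 0.0, 0.0, 0.0
```

Instantiating WorkersWithBufferSketch directly raises TypeError, which is how Python enforces that a concrete subclass implements every method marked abstract.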

Parameters
  • env_factory – Factory function for envs to run rollouts on

  • worker_policy – Structured policy to sample actions from

  • n_rollout_steps – Number of rollout steps to record in one rollout

  • n_workers – Number of distributed workers to run simultaneously

  • batch_size – Size of the batch the rollouts are collected in

  • rollouts_per_iteration – The number of rollouts to collect each time the collect_rollouts method is called.

  • split_rollouts_into_transitions – Whether to split all computed rollouts into transitions before processing them.

  • env_instance_seeds – A list of seeds for the workers' env instances.

  • replay_buffer – The replay buffer to use.
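The effect of split_rollouts_into_transitions on what ends up in the buffer can be sketched as follows. The helper add_rollout_to_buffer is hypothetical, and modelling a rollout as a plain list of per-step records (and the buffer as a plain list) is an assumption made for illustration, not Maze's actual data layout.

```python
from typing import Any, List


def add_rollout_to_buffer(rollout: List[Any], buffer: List[Any],
                          split_rollouts_into_transitions: bool) -> None:
    """Hypothetical helper: store a rollout either whole or step by step."""
    if split_rollouts_into_transitions:
        # Each recorded step becomes its own buffer entry (a single transition).
        buffer.extend(rollout)
    else:
        # The complete rollout (n_rollout_steps steps) is stored as one entry.
        buffer.append(rollout)
```

With a three-step rollout, splitting yields three buffer entries, while keeping rollouts whole yields a single entry that contains all three steps.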

abstract broadcast_updated_policy(state_dict: Dict) → None

Broadcast the newest version of the policy to the workers.

Parameters

state_dict – State of the new policy version to broadcast.
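A minimal sketch of one way to broadcast, assuming each worker reads policy updates from its own queue.Queue inbox. broadcast_state_dict and the per-worker inbox layout are hypothetical; concrete Maze subclasses may transport the state dict differently.

```python
import copy
import queue
from typing import Dict, List


def broadcast_state_dict(state_dict: Dict, worker_inboxes: List[queue.Queue]) -> None:
    """Hypothetical sketch: hand every worker its own copy of the new parameters."""
    for inbox in worker_inboxes:
        # Deep-copy so no worker can mutate the learner's or another worker's parameters.
        inbox.put(copy.deepcopy(state_dict))
```

The deep copy keeps workers independent in this in-process sketch. With multiprocessing.Queue, objects are pickled when put on the queue, which has a similar copying effect.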

abstract collect_rollouts() → Tuple[float, float, float]

Collect worker outputs from the queue and add them to the buffer.

Returns

A tuple of (1) the queue size before dequeuing, (2) the queue size after dequeuing, and (3) the time it took to dequeue the outputs
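The three returned values can be produced as in the sketch below. It assumes workers push finished rollouts into a single queue.Queue and that the buffer is a plain list; collect_from_queue is a hypothetical helper, and dequeuing exactly rollouts_per_iteration items per call is an assumption based on that parameter's description.

```python
import queue
import time
from typing import Any, List, Tuple


def collect_from_queue(rollout_queue: "queue.Queue[Any]", buffer: List[Any],
                       rollouts_per_iteration: int) -> Tuple[float, float, float]:
    """Hypothetical sketch returning (size before, size after, dequeue time)."""
    size_before = rollout_queue.qsize()
    start = time.perf_counter()
    for _ in range(rollouts_per_iteration):
        # Blocks until a worker has produced the next rollout.
        buffer.append(rollout_queue.get())
    dequeue_time = time.perf_counter() - start
    size_after = rollout_queue.qsize()
    return float(size_before), float(size_after), dequeue_time
```

queue.Queue.qsize() is only approximate while other threads are putting items concurrently, so the before and after sizes are best treated as monitoring figures rather than exact counts.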

get_epoch_stats_aggregator() → maze.core.log_stats.log_stats.LogStatsAggregator

Return the collected epoch stats aggregator.

get_stats_value(event: Callable, level: maze.core.log_stats.log_stats.LogStatsLevel, name: Optional[str] = None) → Union[int, float, numpy.ndarray, dict]

Obtain a single value from the epoch statistics dict.

Parameters
  • event – The event interface method of the value in question.

  • level – Must be set to LogStatsLevel.EPOCH; step or episode statistics are not propagated.

  • name – The output_name of the statistics in case it has been specified in maze.core.log_stats.event_decorators.define_epoch_stats()
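The level restriction can be modelled with a toy lookup table keyed by (event, name). Level, reward, and this free function get_stats_value are hypothetical stand-ins for LogStatsLevel, an event interface method, and the documented method; the dictionary layout is an assumption, not Maze's internal representation.

```python
from enum import Enum
from typing import Callable, Dict, Optional, Tuple


class Level(Enum):
    """Hypothetical stand-in for LogStatsLevel."""
    STEP = "step"
    EPISODE = "episode"
    EPOCH = "epoch"


def reward() -> None:
    """Hypothetical event interface method, used here only as a lookup key."""


def get_stats_value(epoch_stats: Dict[Tuple[Callable, Optional[str]], float],
                    event: Callable, level: Level, name: Optional[str] = None) -> float:
    # Only epoch-level statistics are kept, so any other level is rejected.
    if level is not Level.EPOCH:
        raise ValueError("only Level.EPOCH statistics are available")
    return epoch_stats[(event, name)]
```

Passing the event method itself, rather than a string, lets the lookup key stay tied to the event interface that produced the statistic.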

sample_batch(learner_device: str) → maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord

Sample a batch from the buffer and return it as a batched structured spaces record.

Parameters

learner_device – The device of the learner (cpu or cuda).

Returns

A batched structured spaces record object holding the batched rollouts.
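A sketch of the sampling step, assuming uniform sampling without replacement and buffer records stored as flat dicts. This sample_batch is a hypothetical free function; it leaves out the conversion to a StructuredSpacesRecord and any placement of the data on learner_device.

```python
import random
from typing import Dict, List


def sample_batch(buffer: List[Dict[str, float]], batch_size: int,
                 rng: random.Random) -> Dict[str, List[float]]:
    """Hypothetical sketch: sample records uniformly and batch them per key."""
    # Sampling without replacement; raises ValueError if batch_size > len(buffer).
    records = rng.sample(buffer, batch_size)
    # Turn a list of per-step dicts into one dict of per-key lists.
    return {key: [r[key] for r in records] for key in records[0]}
```

Batching per key keeps the fields of each sampled record aligned by position, so the i-th observation and the i-th reward in the batch come from the same record.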

abstract start() → None

Start all distributed workers.

abstract stop() → None

Stop all distributed workers.
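One way a subclass could implement start() and stop() is sketched below with threads and a threading.Event. ThreadedWorkers is hypothetical and stands in for whatever worker processes a real implementation manages; each worker enqueues its own id as a placeholder for a real rollout.

```python
import queue
import threading
from typing import List


class ThreadedWorkers:
    """Hypothetical sketch: start() launches workers, stop() shuts them down."""

    def __init__(self, n_workers: int) -> None:
        self.n_workers = n_workers
        self.rollout_queue: "queue.Queue[int]" = queue.Queue()
        self._stop_event = threading.Event()
        self._threads: List[threading.Thread] = []

    def _run(self, worker_id: int) -> None:
        # Placeholder for "run a rollout and enqueue it", repeated until stop().
        while not self._stop_event.is_set():
            self.rollout_queue.put(worker_id)
            # Returns early as soon as the stop event is set.
            self._stop_event.wait(0.01)

    def start(self) -> None:
        self._stop_event.clear()
        self._threads = [threading.Thread(target=self._run, args=(i,), daemon=True)
                         for i in range(self.n_workers)]
        for t in self._threads:
            t.start()

    def stop(self) -> None:
        self._stop_event.set()
        for t in self._threads:
            t.join()
        self._threads = []
```

A shared Event lets stop() signal all workers at once, and joining each thread ensures no worker is still writing to the queue once stop() returns.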