BaseDistributedWorkersWithBuffer
-
class
maze.train.parallelization.distributed_actors.base_distributed_workers_with_buffer.
BaseDistributedWorkersWithBuffer
(env_factory: Callable[[], Union[maze.core.env.structured_env.StructuredEnv, maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin, maze.core.log_stats.log_stats_env.LogStatsEnv]], worker_policy: maze.core.agent.torch_policy.TorchPolicy, n_rollout_steps: int, n_workers: int, batch_size: int, rollouts_per_iteration: int, split_rollouts_into_transitions: bool, env_instance_seeds: List[int], replay_buffer: maze.train.trainers.common.replay_buffer.replay_buffer.BaseReplayBuffer)
The base class for all distributed workers that use a replay buffer.
Distributed workers run rollouts independently. Rollouts are collected by calling the collect_rollouts method and are then added to the buffer.
- Parameters
env_factory – Factory function for envs to run rollouts on
worker_policy – Structured policy to sample actions from
n_rollout_steps – Number of rollout steps to record in one rollout
n_workers – Number of distributed workers to run simultaneously
batch_size – Size of the batch the rollouts are collected in
rollouts_per_iteration – The number of rollouts to collect each time the collect_rollouts method is called.
split_rollouts_into_transitions – Specify whether all computed rollouts should be split into transitions before processing them
env_instance_seeds – A list of seeds for each worker's envs.
replay_buffer – The replay buffer to use.
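The effect of split_rollouts_into_transitions can be pictured with a small, library-free sketch. The names below (`Transition`, `Rollout`, `add_to_buffer`) are hypothetical stand-ins and not part of Maze. The sketch only assumes that a rollout is a sequence of n_rollout_steps transitions and that the buffer stores either whole rollouts or single transitions.

```python
from typing import Any, Dict, List

# Hypothetical stand-ins (not Maze types): a transition is a plain dict,
# a rollout is a list of transitions.
Transition = Dict[str, Any]
Rollout = List[Transition]


def add_to_buffer(buffer: List[Any], rollouts: List[Rollout],
                  split_rollouts_into_transitions: bool) -> None:
    """Append rollouts to the buffer, either whole or split into transitions."""
    for rollout in rollouts:
        if split_rollouts_into_transitions:
            buffer.extend(rollout)  # one buffer entry per transition
        else:
            buffer.append(rollout)  # one buffer entry per rollout


n_rollout_steps = 3
# Two toy rollouts, as if produced by two workers.
rollouts = [[{"worker": w, "step": s} for s in range(n_rollout_steps)]
            for w in range(2)]

split_buffer: List[Any] = []
add_to_buffer(split_buffer, rollouts, split_rollouts_into_transitions=True)

whole_buffer: List[Any] = []
add_to_buffer(whole_buffer, rollouts, split_rollouts_into_transitions=False)
```

With splitting enabled the toy buffer holds one entry per transition; without it, one entry per rollout.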
-
abstract
broadcast_updated_policy
(state_dict: Dict) → None
Broadcast the newest version of the policy to the workers.
- Parameters
state_dict – State of the new policy version to broadcast.
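Because this method is abstract, how the state reaches the workers is left to the subclass. The following is a rough sketch of the general idea only, using plain dictionaries instead of torch state dicts. `ToyWorker` and `ToyBroadcaster` are hypothetical names, and the version counter is an assumption added for illustration.

```python
import copy
from typing import Dict, List


class ToyWorker:
    """Hypothetical worker that keeps its own copy of the policy parameters."""

    def __init__(self) -> None:
        self.policy_state: Dict = {}
        self.policy_version = 0

    def load_state(self, state_dict: Dict, version: int) -> None:
        # Deep-copy so later learner-side changes do not leak into the worker.
        self.policy_state = copy.deepcopy(state_dict)
        self.policy_version = version


class ToyBroadcaster:
    """Hypothetical stand-in for a concrete distributed-workers subclass."""

    def __init__(self, n_workers: int) -> None:
        self.workers: List[ToyWorker] = [ToyWorker() for _ in range(n_workers)]
        self.version = 0

    def broadcast_updated_policy(self, state_dict: Dict) -> None:
        """Push the newest policy state to every worker."""
        self.version += 1
        for worker in self.workers:
            worker.load_state(state_dict, self.version)


learner_state = {"linear.weight": [0.1, 0.2]}
broadcaster = ToyBroadcaster(n_workers=3)
broadcaster.broadcast_updated_policy(learner_state)

# Mutating the learner state afterwards leaves the workers' copies untouched.
learner_state["linear.weight"][0] = 9.9
```

A real implementation would typically pass the state across process boundaries (for example through a queue or shared memory) rather than copying within one process.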
-
abstract
collect_rollouts
() → Tuple[float, float, float]
Collect worker outputs from the queue and add them to the buffer.
- Returns
A tuple of (1) the queue size before dequeueing, (2) the queue size after dequeueing, and (3) the time it took to dequeue the outputs.
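The return contract described above can be illustrated with a minimal sketch built on the standard library. This is not the Maze implementation: the free function, its parameters (`worker_output_queue`, `buffer`), and the use of a list as buffer are assumptions made for illustration.

```python
import queue
import time
from typing import Any, List, Tuple


def collect_rollouts(worker_output_queue: queue.Queue, buffer: List[Any],
                     rollouts_per_iteration: int) -> Tuple[float, float, float]:
    """Move worker outputs from the queue into the buffer and report queue stats."""
    q_size_before = worker_output_queue.qsize()
    start = time.perf_counter()
    for _ in range(rollouts_per_iteration):
        # get() blocks until a worker output is available.
        buffer.append(worker_output_queue.get())
    dequeue_time = time.perf_counter() - start
    q_size_after = worker_output_queue.qsize()
    return float(q_size_before), float(q_size_after), dequeue_time


# Toy usage: five pending worker outputs, two collected per call.
pending: queue.Queue = queue.Queue()
for i in range(5):
    pending.put({"rollout_id": i})

toy_buffer: List[Any] = []
size_before, size_after, elapsed = collect_rollouts(
    pending, toy_buffer, rollouts_per_iteration=2)
```

Comparing the queue size before and after each call gives a quick indication of whether the workers produce rollouts faster than the learner consumes them.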
-
get_epoch_stats_aggregator
() → maze.core.log_stats.log_stats.LogStatsAggregator
Return the collected epoch stats aggregator.
-
get_stats_value
(event: Callable, level: maze.core.log_stats.log_stats.LogStatsLevel, name: Optional[str] = None) → Union[int, float, numpy.ndarray, dict]
Obtain a single value from the epoch statistics dict.
- Parameters
event – The event interface method of the value in question.
name – The output_name of the statistics, in case it has been specified in maze.core.log_stats.event_decorators.define_epoch_stats().
level – Must be set to LogStatsLevel.EPOCH; step and episode statistics are not propagated.
-
sample_batch
(learner_device: str) → maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord
Sample a batch from the buffer and return it as a batched structured spaces record.
- Parameters
learner_device – The device of the learner (cpu or cuda).
- Returns
A batched structured spaces record object holding the batched rollouts.
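The idea of sampling from the buffer and batching the result can be sketched without Maze or torch. In this toy version a dict of lists stands in for the batched structured spaces record, sampling is uniform without replacement (an assumption), and the move to learner_device is omitted. The free function and its parameters are hypothetical.

```python
import random
from typing import Any, Dict, List


def sample_batch(buffer: List[Dict[str, Any]], batch_size: int,
                 rng: random.Random) -> Dict[str, List[Any]]:
    """Sample transitions uniformly and stack them key by key."""
    transitions = rng.sample(buffer, batch_size)  # without replacement
    # Turn a list of per-transition dicts into one dict of batched lists.
    return {key: [t[key] for t in transitions] for key in transitions[0]}


# Toy buffer of ten transitions with unique observations.
toy_transitions = [{"observation": i, "action": i % 2, "reward": 0.1 * i}
                   for i in range(10)]
batch = sample_batch(toy_transitions, batch_size=4, rng=random.Random(0))
```

Each key of the resulting batch holds batch_size entries, one per sampled transition.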