DistributedActors
- class maze.train.parallelization.distributed_actors.distributed_actors.DistributedActors(env_factory: Callable[[], StructuredEnv | StructuredEnvSpacesMixin | LogStatsEnv], policy: TorchPolicy, n_rollout_steps: int, n_actors: int, batch_size: int)
The base class for all distributed actors.
Distributed actors run rollouts independently. Rollouts are recorded and made available in batches to be used during training. When a new policy version is made available, it is distributed to all actors.
- Parameters:
env_factory – Factory function for envs to run rollouts on.
policy – Structured policy to sample actions from.
n_rollout_steps – Number of rollout steps to record in one rollout.
n_actors – Number of distributed actors to run simultaneously.
batch_size – Size of the batch the rollouts are collected in.
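The cycle described above (actors record rollouts, the learner consumes them in batches, and each new policy version is sent back to all actors) can be sketched with plain Python objects. This is a toy, single-process illustration and not Maze code: `run_rollout`, `rollout_queue`, `actor_policies`, and the `{"version": ...}` dicts are hypothetical stand-ins for real environments, the rollout queue, and policy state dicts.

```python
import queue
from typing import Dict, List

N_ACTORS, N_ROLLOUT_STEPS, BATCH_SIZE = 2, 3, 2

# Hypothetical stand-ins: a shared queue and one policy copy per actor.
rollout_queue: "queue.Queue[List[int]]" = queue.Queue()
actor_policies: List[Dict[str, int]] = [{"version": 0} for _ in range(N_ACTORS)]


def run_rollout(policy: Dict[str, int]) -> List[int]:
    """Record one rollout; each step just stores the policy version used."""
    return [policy["version"] for _ in range(N_ROLLOUT_STEPS)]


seen_versions = []
learner_policy = {"version": 0}
for epoch in range(2):
    # 1. Every actor records one rollout and enqueues it.
    for policy in actor_policies:
        rollout_queue.put(run_rollout(policy))
    # 2. The learner takes one batch of rollouts from the queue.
    batch = [rollout_queue.get() for _ in range(BATCH_SIZE)]
    seen_versions.append({v for rollout in batch for v in rollout})
    # 3. "Training" yields a new policy version, which is sent to all actors.
    learner_policy = {"version": learner_policy["version"] + 1}
    actor_policies = [dict(learner_policy) for _ in range(N_ACTORS)]
```

In a real subclass the actors would typically run in separate workers rather than in one loop; the sketch only shows the order of the three steps.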
- abstract broadcast_updated_policy(state_dict: Dict) None
Broadcast the newest version of the policy to the actors.
- Parameters:
state_dict – State of the new policy version to broadcast.
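Because this method is abstract, each concrete subclass decides how the state reaches its actors. The following is a minimal in-process sketch of the idea, assuming a hypothetical `ToyActor` class that is not part of Maze; the state dict here is an ordinary Python dict rather than a real policy state.

```python
import copy
from typing import Any, Dict, List


class ToyActor:
    """Hypothetical stand-in for an actor; not part of Maze."""

    def __init__(self) -> None:
        self.policy_state: Dict[str, Any] = {}
        self.policy_version = 0

    def update_policy(self, state_dict: Dict[str, Any]) -> None:
        # Keep a private copy so later learner updates cannot leak in.
        self.policy_state = copy.deepcopy(state_dict)
        self.policy_version += 1


def broadcast_updated_policy(actors: List[ToyActor], state_dict: Dict[str, Any]) -> None:
    """Push the same policy state to every actor."""
    for actor in actors:
        actor.update_policy(state_dict)


actors = [ToyActor() for _ in range(3)]
learner_state = {"weights": [0.1, 0.2]}
broadcast_updated_policy(actors, learner_state)
learner_state["weights"][0] = 9.9  # the learner keeps training afterwards
```

Copying the state per actor is a design choice of this sketch: it keeps every actor on a consistent snapshot until the next broadcast.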
- abstract collect_outputs(learner_device: str) Tuple[StructuredSpacesRecord, float, float, float]
Collect self.batch_size actor outputs from the queue and return them batched, with time as the first dimension and the batch as the second.
- Parameters:
learner_device – The device of the learner.
- Returns:
A tuple of (1) the batched version of ActorOutputs, (2) the queue size before dequeueing, (3) the queue size after dequeueing, and (4) the time it took to dequeue the outputs.
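The shape of the return value can be illustrated with a standard-library queue. This is a sketch under stated assumptions, not the Maze implementation: rollouts are plain Python lists instead of StructuredSpacesRecord objects, the `learner_device` argument is omitted because nothing is moved to a device, and `actor_queue` is a hypothetical name.

```python
import queue
import time
from typing import Any, List, Tuple


def collect_outputs(
    actor_queue: "queue.Queue[List[Any]]", batch_size: int
) -> Tuple[List[List[Any]], int, int, float]:
    """Dequeue batch_size rollouts and return them time-major."""
    start = time.monotonic()
    size_before = actor_queue.qsize()
    # get() blocks until a rollout is available.
    rollouts = [actor_queue.get() for _ in range(batch_size)]
    size_after = actor_queue.qsize()
    # rollouts is indexed [batch][time]; transpose to [time][batch].
    batched = [list(step) for step in zip(*rollouts)]
    return batched, size_before, size_after, time.monotonic() - start


q: "queue.Queue[List[str]]" = queue.Queue()
for actor_id in range(3):
    q.put([f"a{actor_id}_t{t}" for t in range(2)])  # n_rollout_steps = 2

batch, before, after, seconds = collect_outputs(q, batch_size=2)
# batch[0] holds time step 0 of both dequeued rollouts.
```

The queue sizes before and after dequeueing give a rough signal of whether the actors or the learner are the bottleneck.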
- get_epoch_stats_aggregator() LogStatsAggregator
Return the collected epoch stats aggregator.
- get_stats_value(event: Callable, level: LogStatsLevel, name: str | None = None) int | float | numpy.ndarray | dict
Obtain a single value from the epoch statistics dict.
- Parameters:
event – The event interface method of the value in question.
name – The output_name of the statistics in case it has been specified in maze.core.log_stats.event_decorators.define_epoch_stats().
level – Must be set to LogStatsLevel.EPOCH; step or episode statistics are not propagated.