TorchPolicy
- class maze.core.agent.torch_policy.TorchPolicy(networks: Mapping[str | int, torch.nn.Module], distribution_mapper: DistributionMapper, device: str, substeps_with_separate_agent_nets: List[str | int] | None = None)
Encapsulates multiple torch policies along with a distribution mapper for training and rollouts in structured environments.
- Parameters:
networks – Mapping of policy networks to encapsulate.
distribution_mapper – Distribution mapper associated with the policy mapping.
device – Device the policy should be located on (cpu or cuda).
- compute_action(observation: Dict[str, numpy.ndarray], maze_state: Any | None = None, env: BaseEnv | None = None, actor_id: ActorID | None = None, deterministic: bool = False) → Dict[str, int | numpy.ndarray]
(overrides Policy) Implementation of Policy.
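The deterministic flag of compute_action can be illustrated with a minimal, library-independent sketch. It assumes that deterministic mode returns the most probable action and that the default mode samples from the distribution; the source does not spell this out, and select_action is a hypothetical helper, not part of TorchPolicy.

```python
import random
from typing import Sequence


def select_action(probs: Sequence[float], deterministic: bool, rng: random.Random) -> int:
    """Pick an action index from a categorical distribution (hypothetical helper)."""
    if deterministic:
        # Assumed behaviour: always return the most probable action.
        return max(range(len(probs)), key=lambda i: probs[i])
    # Assumed behaviour: sample an index in proportion to its probability.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.1, 0.7, 0.2]
rng = random.Random(0)  # seeded so the sketch is reproducible

greedy = select_action(probs, deterministic=True, rng=rng)
sampled = [select_action(probs, deterministic=False, rng=rng) for _ in range(1000)]
```

Here greedy is always the index of the largest probability, while sampled contains a mix of indices whose frequencies roughly follow probs.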
- compute_policy_output(record: StructuredSpacesRecord, temperature: float = 1.0) → PolicyOutput
Compute the full Policy output for all policy networks over a full (flat) environment step.
- Parameters:
record – The StructuredSpacesRecord holding the observation and actor ids.
temperature – Optional, the temperature to use for initializing the probability distribution.
- Returns:
The full Policy output for the record given.
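As a rough illustration of what a temperature parameter usually does, the sketch below divides logits by the temperature before applying a softmax. This is the common convention and only an assumption about how the distribution mapper uses the value; softmax_with_temperature is a hypothetical helper written in plain Python rather than torch.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Turn logits into probabilities, scaling them by 1 / temperature first."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Subtract the maximum for numerical stability; this does not change the result.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, temperature=0.5)  # more peaked
base = softmax_with_temperature(logits, temperature=1.0)   # unchanged logits
flat = softmax_with_temperature(logits, temperature=5.0)   # closer to uniform
```

Under this convention, temperatures below 1.0 concentrate probability on the highest-scoring action, and temperatures above 1.0 spread it more evenly.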
- compute_substep_policy_output(observation: Dict[str, numpy.ndarray], actor_id: ActorID | None = None, temperature: float = 1.0) → PolicySubStepOutput
Compute the full output of a specified policy.
- Parameters:
observation – The observation to use as input.
actor_id – Optional, the actor id specifying the network to use.
temperature – Optional, the temperature to use for initializing the probability distribution.
- Returns:
The computed PolicySubStepOutput.
- compute_top_action_candidates(observation: Dict[str, numpy.ndarray], num_candidates: int | None, maze_state: Any | None, env: BaseEnv | None, actor_id: ActorID | None = None) → Tuple[Sequence[Dict[str, int | numpy.ndarray]], Sequence[float]]
(overrides Policy) Implementation of Policy.
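The return type of compute_top_action_candidates pairs a sequence of actions with a sequence of scores. A minimal sketch of that shape for a single discrete action space might look as follows. It assumes candidates are ranked by probability and that num_candidates=None means "return all"; neither is confirmed by the source, and top_candidates is a hypothetical helper.

```python
import heapq
from typing import List, Optional, Sequence, Tuple


def top_candidates(probs: Sequence[float], num_candidates: Optional[int]) -> Tuple[List[int], List[float]]:
    """Return the most probable action indices and their probabilities, best first."""
    # Assumption: None means all candidates are returned.
    k = len(probs) if num_candidates is None else min(num_candidates, len(probs))
    best = heapq.nlargest(k, enumerate(probs), key=lambda pair: pair[1])
    actions = [index for index, _ in best]
    scores = [prob for _, prob in best]
    return actions, scores


actions, scores = top_candidates([0.1, 0.6, 0.3], num_candidates=2)
all_actions, all_scores = top_candidates([0.1, 0.6, 0.3], num_candidates=None)
```

The two returned sequences are aligned, so scores[i] belongs to actions[i].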
- eval() → None
(overrides TorchModel) Implementation of TorchModel.
- load_state_dict(state_dict: Dict) → None
(overrides TorchModel) Implementation of TorchModel.
- needs_state() → bool
(overrides Policy) This policy does not require the state() object to compute the action.
- network_for(actor_id: ActorID | None) → torch.nn.Module
Helper function that returns the network for the given policy ID. The lookup key is either just the sub-step ID or the full Actor ID, depending on the separated agent networks mode.
- Parameters:
actor_id – Actor ID to get a network for
- Returns:
Network corresponding to the given policy ID.
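The key selection described for network_for can be sketched without torch. The example assumes that ActorID is a named tuple with step_key and agent_id fields, that sub-steps listed in substeps_with_separate_agent_nets are keyed by the full Actor ID, and that a missing actor ID is only valid when a single network exists. All three are assumptions, and lookup_network is a hypothetical stand-in, not the library's implementation.

```python
from typing import Any, List, Mapping, NamedTuple, Optional, Union

StepKey = Union[str, int]


class ActorID(NamedTuple):
    """Stand-in for the library's ActorID (field names are assumed)."""
    step_key: StepKey
    agent_id: int


def lookup_network(networks: Mapping[Any, Any],
                   actor_id: Optional[ActorID],
                   separate_substeps: List[StepKey]) -> Any:
    """Return the network registered for the given actor (hypothetical helper)."""
    if actor_id is None:
        # Assumption: omitting the actor ID only works with exactly one network.
        if len(networks) != 1:
            raise ValueError("actor_id is required when there are multiple networks")
        return next(iter(networks.values()))
    if actor_id.step_key in separate_substeps:
        # Separate networks per agent: the full Actor ID is the key.
        return networks[actor_id]
    # Shared network per sub-step: only the sub-step key is used.
    return networks[actor_id.step_key]


# Strings stand in for torch.nn.Module instances.
networks = {
    "pick": "shared_pick_net",
    ActorID("move", 0): "move_net_agent_0",
    ActorID("move", 1): "move_net_agent_1",
}

shared = lookup_network(networks, ActorID("pick", 3), separate_substeps=["move"])
separate = lookup_network(networks, ActorID("move", 1), separate_substeps=["move"])
```

In this sketch every agent acting in the "pick" sub-step shares one network, while each agent in the "move" sub-step gets its own.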
- parameters() → List[torch.Tensor]
(overrides TorchModel) Implementation of TorchModel.
- state_dict() → Dict
(overrides TorchModel) Implementation of TorchModel.
- to(device: str) → None
(overrides TorchModel) Implementation of TorchModel.
- train() → None
(overrides TorchModel) Implementation of TorchModel.