TorchPolicy¶

class maze.core.agent.torch_policy.TorchPolicy(networks: Mapping[Union[str, int], torch.nn.Module], distribution_mapper: maze.distributions.distribution_mapper.DistributionMapper, device: str, substeps_with_separate_agent_nets: Optional[List[Union[int, str]]] = None)¶

Encapsulates multiple torch policies along with a distribution mapper for training and rollouts in structured environments.
- Parameters
  networks – Mapping of policy networks to encapsulate.
  distribution_mapper – Distribution mapper associated with the policy mapping.
  device – Device the policy should be located on ("cpu" or "cuda").
  substeps_with_separate_agent_nets – Optional list of sub-step keys whose agents each use a separate network (see network_for below).
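Example (a minimal construction sketch for a single-sub-step setup; the toy network, the "observation"/"action" keys, and the empty distribution_mapper_config are illustrative assumptions, not part of the documented signature):

    import torch.nn as nn
    from gym import spaces

    from maze.core.agent.torch_policy import TorchPolicy
    from maze.distributions.distribution_mapper import DistributionMapper

    class ToyPolicyNet(nn.Module):
        """Toy network mapping the observation dict to a dict of action logits."""
        def __init__(self, obs_dim: int, n_actions: int):
            super().__init__()
            self.head = nn.Linear(obs_dim, n_actions)

        def forward(self, obs_dict):
            return {"action": self.head(obs_dict["observation"])}

    # Assumed: DistributionMapper built from the dict action space with an
    # empty config, i.e. falling back to default distributions per action head.
    action_space = spaces.Dict({"action": spaces.Discrete(4)})
    distribution_mapper = DistributionMapper(action_space=action_space,
                                             distribution_mapper_config=[])

    # One network per sub-step, keyed by sub-step ID.
    policy = TorchPolicy(networks={0: ToyPolicyNet(obs_dim=8, n_actions=4)},
                         distribution_mapper=distribution_mapper,
                         device="cpu")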
compute_action(observation: Dict[str, numpy.ndarray], maze_state: Optional[Any] = None, env: Optional[maze.core.env.base_env.BaseEnv] = None, actor_id: maze.core.env.structured_env.ActorID = None, deterministic: bool = False) → Dict[str, Union[int, numpy.ndarray]]¶

(overrides Policy) Implementation of the Policy interface.
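Example (a usage sketch reusing the toy policy from the construction example above; the observation key, and the assumption that the numpy observation is converted to torch tensors internally, follow from the signature but are not spelled out on this page):

    import numpy as np

    obs = {"observation": np.zeros(8, dtype=np.float32)}

    # Sample an action from the policy's action distribution ...
    action = policy.compute_action(observation=obs, deterministic=False)

    # ... or take the most likely action instead of sampling:
    greedy_action = policy.compute_action(observation=obs, deterministic=True)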
compute_policy_output(record: maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord, temperature: float = 1.0) → maze.core.agent.torch_policy_output.PolicyOutput¶

Compute the full policy output of all policy networks over a full (flat) environment step.

- Parameters
  record – The StructuredSpacesRecord holding the observations and actor IDs.
  temperature – Optional, the temperature used for initializing the probability distribution.
- Returns
  The full policy output for the given record.
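Example (a hedged sketch; here `record` is assumed to be a StructuredSpacesRecord collected during training — its construction is documented with the trajectory-recording classes, not on this page):

    # Compute the outputs of every sub-step network for the recorded flat
    # step in one call, e.g. when evaluating a training loss.
    policy_output = policy.compute_policy_output(record, temperature=1.0)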
compute_substep_policy_output(observation: Dict[str, numpy.ndarray], actor_id: maze.core.env.structured_env.ActorID = None, temperature: float = 1.0) → maze.core.agent.torch_policy_output.PolicySubStepOutput¶

Compute the full output of the policy specified by the actor ID.

- Parameters
  observation – The observation to use as input.
  actor_id – Optional, the actor ID specifying the network to use.
  temperature – Optional, the temperature used for initializing the probability distribution.
- Returns
  The computed PolicySubStepOutput.
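Example (for a single-network policy the actor ID can be omitted; the toy observation dict from above is reused):

    # Full output of the (single) sub-step network, including the probability
    # distribution parametrized with the given temperature.
    substep_output = policy.compute_substep_policy_output(observation=obs,
                                                          temperature=1.0)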
compute_top_action_candidates(observation: Dict[str, numpy.ndarray], num_candidates: Optional[int], maze_state: Optional[Any], env: Optional[maze.core.env.base_env.BaseEnv], actor_id: maze.core.env.structured_env.ActorID = None) → Tuple[Sequence[Dict[str, Union[int, numpy.ndarray]]], Sequence[float]]¶

(overrides Policy) Implementation of the Policy interface.
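Example (a sketch of querying ranked action candidates; since needs_state() is False for this policy, see below, maze_state and env can simply be passed as None):

    # The three most promising actions together with their scores, e.g. for
    # candidate pruning in a search procedure.
    candidates, scores = policy.compute_top_action_candidates(observation=obs,
                                                              num_candidates=3,
                                                              maze_state=None,
                                                              env=None)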
eval() → None¶

(overrides TorchModel) Implementation of the TorchModel interface.
load_state_dict(state_dict: Dict) → None¶

(overrides TorchModel) Implementation of the TorchModel interface.
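Example (a save/restore sketch combining load_state_dict with state_dict, documented below; the file name is arbitrary):

    import torch

    # Persist the weights of all encapsulated networks ...
    torch.save(policy.state_dict(), "torch_policy.pt")

    # ... and restore them into a policy of identical architecture.
    policy.load_state_dict(torch.load("torch_policy.pt"))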
needs_state() → bool¶

(overrides Policy) This policy does not require the state() object to compute the action.
network_for(actor_id: Optional[maze.core.env.structured_env.ActorID]) → torch.nn.Module¶

Helper function returning the network for the given actor ID (using either just the sub-step ID or the full actor ID as key, depending on the separate-agent-networks mode).

- Parameters
  actor_id – Actor ID to get a network for.
- Returns
  Network corresponding to the given actor ID.
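Example (a brief sketch; the ActorID field names step_key and agent_id follow the structured-env conventions and should be checked against the ActorID docs):

    from maze.core.env.structured_env import ActorID

    # With separate agent networks disabled, only the sub-step ID part of the
    # actor ID is used to look up the network.
    net = policy.network_for(ActorID(step_key=0, agent_id=0))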
parameters() → List[torch.Tensor]¶

(overrides TorchModel) Implementation of the TorchModel interface.
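The flat parameter list plugs directly into a torch optimizer, for example:

    import torch

    # Optimize all encapsulated networks jointly; the learning rate is arbitrary.
    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)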
state_dict() → Dict¶

(overrides TorchModel) Implementation of the TorchModel interface.
to(device: str) → None¶

(overrides TorchModel) Implementation of the TorchModel interface.
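For example, moving the policy to the GPU when one is available:

    import torch

    # Move all encapsulated networks to the chosen device.
    policy.to("cuda" if torch.cuda.is_available() else "cpu")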
train() → None¶

(overrides TorchModel) Implementation of the TorchModel interface.
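train() and eval() (above) mirror torch.nn.Module's mode switching across all encapsulated networks, affecting layers such as dropout and batch norm:

    policy.train()   # before gradient updates
    policy.eval()    # before rollouts or evaluation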