TorchPolicy

class maze.core.agent.torch_policy.TorchPolicy(networks: Mapping[str | int, torch.nn.Module], distribution_mapper: DistributionMapper, device: str, substeps_with_separate_agent_nets: List[str | int] | None = None)

Encapsulates multiple torch policies along with a distribution mapper for training and rollouts in structured environments.

Parameters:
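  • networks – Mapping of policy networks to encapsulate.

  • distribution_mapper – Distribution mapper associated with the policy mapping.

  • device – Device the policy should be located on (e.g. "cpu" or "cuda").

  • substeps_with_separate_agent_nets – Optional list of sub-step keys in which each agent has its own network, keyed by its full actor ID; in all other sub-steps, the agents share one network keyed by the sub-step ID.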

compute_action(observation: Dict[str, numpy.ndarray], maze_state: Any | None = None, env: BaseEnv | None = None, actor_id: ActorID | None = None, deterministic: bool = False) → Dict[str, int | numpy.ndarray]

(overrides Policy)

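Implementation of Policy.compute_action. Computes an action for the given observation using the network selected by actor_id; the action is sampled from the action distribution unless deterministic is True.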
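Conceptually, the deterministic flag switches between sampling from the action distribution and taking its most likely value. The sketch below illustrates this for a single discrete action head, assuming the network outputs raw logits; select_discrete_action is a hypothetical helper, not part of the Maze API.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_discrete_action(logits: List[float], deterministic: bool = False,
                           rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: pick an action index from categorical logits."""
    probs = softmax(logits)
    if deterministic:
        # Most likely action (mode of the categorical distribution).
        return max(range(len(probs)), key=probs.__getitem__)
    # Otherwise, sample an index proportionally to its probability.
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(select_discrete_action(logits, deterministic=True))  # 1
print(select_discrete_action(logits, rng=random.Random(0)))
```

With deterministic=True the same observation always yields the same action, which is typically what you want for evaluation rollouts; sampling keeps exploration during training.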

compute_policy_output(record: StructuredSpacesRecord, temperature: float = 1.0) → PolicyOutput

Compute the full Policy output for all policy networks over a full (flat) environment step.

Parameters:
  • record – The StructuredSpacesRecord holding the observation and actor ids.

  • temperature – Optional, the temperature to use for initializing the probability distribution.

Returns:

The full policy output for the given record.

compute_substep_policy_output(observation: Dict[str, numpy.ndarray], actor_id: ActorID | None = None, temperature: float = 1.0) → PolicySubStepOutput

Compute the full output of the policy network selected by actor_id for a single sub-step.

Parameters:
  • observation – The observation to use as input.

  • actor_id – Optional, the actor id specifying the network to use.

  • temperature – Optional, the temperature to use for initializing the probability distribution.

Returns:

The computed PolicySubStepOutput.
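A common way to apply a temperature is to divide the logits by it before building the distribution: values below 1.0 sharpen the distribution, values above 1.0 flatten it. The following sketch shows this for a categorical distribution; it assumes this common convention and is not taken from the Maze source (tempered_probs is a hypothetical helper).

```python
import math
from typing import List


def tempered_probs(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature (a common way to apply a temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 3.0]
for t in (0.5, 1.0, 2.0):
    # The probability of the best action shrinks as the temperature grows.
    print(t, [round(p, 3) for p in tempered_probs(logits, t)])
```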

compute_top_action_candidates(observation: Dict[str, numpy.ndarray], num_candidates: int | None, maze_state: Any | None, env: BaseEnv | None, actor_id: ActorID | None = None) → Tuple[Sequence[Dict[str, int | numpy.ndarray]], Sequence[float]]

(overrides Policy)

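Implementation of Policy.compute_top_action_candidates. Returns a sequence of candidate actions together with a sequence of their scores.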
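For a single discrete action head, selecting top candidates amounts to ranking actions by probability and keeping the num_candidates most likely ones. The sketch below assumes one categorical head, assumes that num_candidates=None means "return all actions", and uses a hypothetical action key "action" and helper name; it only illustrates the shape of the (candidates, scores) return value.

```python
import heapq
from typing import Dict, List, Optional, Sequence, Tuple


def top_action_candidates(probs: Sequence[float], num_candidates: Optional[int],
                          action_key: str = "action"
                          ) -> Tuple[List[Dict[str, int]], List[float]]:
    """Hypothetical helper: return the most likely actions and their probabilities."""
    # Assumption: None means "rank all actions".
    n = len(probs) if num_candidates is None else min(num_candidates, len(probs))
    # Indices of the n largest probabilities, in descending order.
    best = heapq.nlargest(n, range(len(probs)), key=probs.__getitem__)
    candidates = [{action_key: i} for i in best]
    return candidates, [probs[i] for i in best]


actions, scores = top_action_candidates([0.1, 0.6, 0.3], num_candidates=2)
print(actions, scores)  # [{'action': 1}, {'action': 2}] [0.6, 0.3]
```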

eval() → None

(overrides TorchModel)

Implementation of TorchModel.eval. Switches all policy networks to evaluation mode.

load_state_dict(state_dict: Dict) → None

(overrides TorchModel)

Implementation of TorchModel.load_state_dict. Restores the state of the policy networks from state_dict.

needs_state() → bool

(overrides Policy)

Returns False: this policy does not require the maze state to compute an action.

network_for(actor_id: ActorID | None) → torch.nn.Module

Helper function that returns the network for the given actor ID. Depending on the separate agent networks mode, the network is looked up using either just the sub-step ID or the full actor ID as key.

Parameters:

actor_id – Actor ID to get a network for.

Returns:

Network corresponding to the given actor ID.
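The key resolution can be pictured with the following sketch. It rests on assumptions: the ActorID class below is a stand-in for an (sub-step key, agent ID) pair, strings stand in for the torch.nn.Module objects, network_key_for is a hypothetical helper, and the handling of actor_id=None (falling back to the single network, if there is only one) is an assumption rather than documented behavior.

```python
from typing import Hashable, List, NamedTuple, Optional, Union


class ActorID(NamedTuple):
    """Stand-in for Maze's ActorID: a sub-step key plus an agent ID."""
    step_key: Union[str, int]
    agent_id: int


def network_key_for(actor_id: Optional[ActorID], network_keys: List[Hashable],
                    substeps_with_separate_agent_nets: List[Union[str, int]]) -> Hashable:
    """Hypothetical helper: resolve which key of `networks` an actor maps to."""
    if actor_id is None:
        # Assumption: without an actor ID, only a single network is unambiguous.
        if len(network_keys) != 1:
            raise ValueError("actor_id is required when there are multiple networks")
        return network_keys[0]
    if actor_id.step_key in substeps_with_separate_agent_nets:
        return actor_id  # each agent of this sub-step has its own network
    return actor_id.step_key  # all agents of this sub-step share one network


# Strings stand in for the actual torch.nn.Module networks.
networks = {"pick": "net_pick",
            ActorID("place", 0): "net_place_0",
            ActorID("place", 1): "net_place_1"}
separate = ["place"]
print(networks[network_key_for(ActorID("pick", 3), list(networks), separate)])   # net_pick
print(networks[network_key_for(ActorID("place", 1), list(networks), separate)])  # net_place_1
```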

parameters() → List[torch.Tensor]

(overrides TorchModel)

Implementation of TorchModel.parameters. Returns the parameters of all policy networks as a single list.

seed(seed: int) → None

(overrides Policy)

Seeding is done globally rather than per policy, so this method does not seed anything itself.
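Since seeding happens globally, reproducibility usually comes from seeding every random number generator involved once, at start-up. The following is a minimal sketch; seed_globally is a hypothetical helper, and NumPy and PyTorch are only seeded if they are installed.

```python
import random


def seed_globally(seed: int) -> None:
    """Hypothetical helper: seed Python, NumPy and PyTorch RNGs if available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed, nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices
    except ImportError:
        pass  # PyTorch not installed, nothing to seed


seed_globally(42)
first = random.random()
seed_globally(42)
print(random.random() == first)  # True: same seed, same sequence
```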

state_dict() → Dict

(overrides TorchModel)

Implementation of TorchModel.state_dict. Returns the state of all policy networks.

to(device: str) → None

(overrides TorchModel)

Implementation of TorchModel.to. Moves all policy networks to the given device.

train() → None

(overrides TorchModel)

Implementation of TorchModel.train. Switches all policy networks to training mode.