class maze.core.agent.serialized_torch_policy.SerializedTorchPolicy(model: Union[omegaconf.DictConfig, Dict], state_dict_file: str, spaces_dict_file: str, device: str, deterministic: bool)

Structured policy used for rollouts of trained models.

Builds the models from the model composer configuration and the dumped spaces configuration, then sets the state of the individual policies from the state dict dump.

Policies are set to eval mode by default.
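The build-and-restore step described above can be sketched in plain PyTorch. This is a minimal sketch, not Maze's implementation: `build_policy_network` is a hypothetical stand-in for the networks the model composer would create from the spaces configuration, and the dump layout (a dict mapping policy IDs to module state dicts) is an assumption.

```python
import os
import tempfile

import torch
from torch import nn


def build_policy_network(obs_dim: int, num_actions: int) -> nn.Module:
    """Hypothetical stand-in for a network built by the model composer."""
    return nn.Sequential(nn.Linear(obs_dim, 16), nn.ReLU(), nn.Linear(16, num_actions))


def restore_policies(state_dict_file: str, obs_dim: int, num_actions: int) -> dict:
    """Rebuild one network per policy ID, load its dumped weights, switch to eval mode."""
    # Assumed dump layout: {policy_id: state_dict}
    dump = torch.load(state_dict_file, map_location="cpu")
    policies = {}
    for policy_id, state in dump.items():
        net = build_policy_network(obs_dim, num_actions)
        net.load_state_dict(state)
        net.eval()  # rollouts only: disable training-mode behavior such as dropout
        policies[policy_id] = net
    return policies


# Usage: dump a freshly initialized "trained" policy, then restore it from disk.
trained = {0: build_policy_network(4, 2)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state_dict.pt")
    torch.save({pid: net.state_dict() for pid, net in trained.items()}, path)
    restored = restore_policies(path, obs_dim=4, num_actions=2)
```

The restored network produces the same outputs as the one that was dumped, and it starts in eval mode, mirroring the default noted above.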

  • model – Model composer configuration

  • state_dict_file – Path to dumped state dictionaries of the trained policies

  • spaces_dict_file – Path to dumped spaces configuration (action and observation spaces of the env the policy was trained on, used for model initialization)

  • device – Device to place the policy models on (e.g. "cpu" or "cuda")

  • deterministic – If True, actions are computed deterministically; otherwise, they are sampled from the action probability distribution.
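To illustrate the difference the `deterministic` flag makes, here is a minimal sketch for a single discrete action head, assuming the policy outputs logits. Taking the argmax as the "deterministic" choice is an assumption for illustration; Maze's own distribution classes may define it differently. `select_action` is a hypothetical helper.

```python
import torch
from torch.distributions import Categorical


def select_action(logits: torch.Tensor, deterministic: bool) -> int:
    """Hypothetical helper: greedy action if deterministic, else a sample."""
    if deterministic:
        # Always the most probable action
        return int(torch.argmax(logits, dim=-1))
    # Draw from the categorical distribution defined by the logits
    return int(Categorical(logits=logits).sample())


logits = torch.tensor([0.1, 2.0, -1.0])
greedy = select_action(logits, deterministic=True)    # always index 1
sampled = select_action(logits, deterministic=False)  # any of 0, 1, 2
```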

compute_action(observation: Dict[str, numpy.ndarray], maze_state: Optional[Any] = None, env: Optional[maze.core.env.base_env.BaseEnv] = None, actor_id: maze.core.env.structured_env.ActorID = None, deterministic: bool = False) → Dict[str, Union[int, numpy.ndarray]]

(overrides TorchPolicy)
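The signature above maps a dict of NumPy observations to a dict of actions. The toy function below only mirrors those input and output types; it is not Maze code. The keys `"observation"` and `"action"` are hypothetical, and only the greedy (argmax) branch is shown to keep it short.

```python
from typing import Dict, Union

import numpy as np
import torch
from torch import nn


def compute_action_sketch(
    observation: Dict[str, np.ndarray], net: nn.Module
) -> Dict[str, Union[int, np.ndarray]]:
    """Toy dict-in/dict-out action computation (hypothetical keys)."""
    obs = torch.as_tensor(observation["observation"], dtype=torch.float32)
    with torch.no_grad():  # inference only, no gradient tracking
        logits = net(obs)
    return {"action": int(torch.argmax(logits))}


net = nn.Linear(3, 2).eval()
out = compute_action_sketch({"observation": np.ones(3, dtype=np.float32)}, net)
```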


seed(seed: int)

(overrides TorchPolicy)

Sets the torch manual seed.
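Setting torch's global manual seed makes stochastic action sampling reproducible. A minimal sketch in plain PyTorch; the uniform four-action distribution and the `sample_actions` helper are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Categorical


def sample_actions(seed: int, n: int = 5) -> list:
    """Hypothetical helper: seed torch, then draw n sampled actions."""
    torch.manual_seed(seed)  # same effect the seed() method documents
    dist = Categorical(logits=torch.zeros(4))
    return [int(dist.sample()) for _ in range(n)]


run_a = sample_actions(42)
run_b = sample_actions(42)  # identical to run_a
```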