TorchActorCritic

class maze.core.agent.torch_actor_critic.TorchActorCritic(policy: maze.core.agent.torch_policy.TorchPolicy, critic: Union[maze.core.agent.torch_state_critic.TorchStateCritic, maze.core.agent.torch_state_action_critic.TorchStateActionCritic], device: str)

Encapsulates a structured torch policy and critic for training actor-critic algorithms in structured environments.

Parameters
  • policy – A structured torch policy for training in structured environments.

  • critic – A structured torch critic for training in structured environments.

  • device – Device the model (networks) should reside on ("cpu" or "cuda").

compute_actor_critic_output(record: maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord, temperature: float = 1.0) → Tuple[maze.core.agent.torch_policy_output.PolicyOutput, maze.core.agent.state_critic_input_output.StateCriticOutput]

Computes the policy and critic output in a single pass, managing the sub-steps, the individual critic types, and shared network embeddings.

Parameters
  • record – The StructuredSpacesRecord holding the observation and actor ids.

  • temperature – (Optional) The temperature used for initializing the probability distribution of the action heads.

Returns

A tuple of the policy and critic output.
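To illustrate what "policy and critic output in one go" means, here is a hedged sketch of the pattern: one loop over the sub-step observations in a record, producing a temperature-scaled action distribution and a state value per sub-step. The function and parameter names (`record` as a plain list, `policy_net`/`value_net` as callables) are assumptions for the sketch, not the actual Maze signatures:

```python
import math


def softmax(logits, temperature=1.0):
    # Temperature-scaled softmax: higher temperature flattens the distribution.
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def compute_actor_critic_output(record, policy_net, value_net, temperature=1.0):
    """Sketch: for each sub-step observation in `record`, build the action
    probability distribution and the critic's value estimate in one pass."""
    policy_out, critic_out = [], []
    for obs in record:
        logits = policy_net(obs)
        policy_out.append(softmax(logits, temperature))
        critic_out.append(value_net(obs))
    return policy_out, critic_out
```

In the real class the two networks may share embeddings, so computing both outputs in the same pass avoids running the shared encoder twice.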

property device

Implementation of TorchModel.

eval() → None

(overrides TorchModel)

Implementation of TorchModel.

load_state_dict(state_dict: Dict) → None

(overrides TorchModel)

Implementation of TorchModel.

parameters() → List[torch.Tensor]

(overrides TorchModel)

Implementation of TorchModel.

state_dict() → Dict

(overrides TorchModel)

Implementation of TorchModel.

to(device: str)

(overrides TorchModel)

Implementation of TorchModel.

train() → None

(overrides TorchModel)

Implementation of TorchModel.