TorchSharedStateActionCritic

class maze.core.agent.torch_state_action_critic.TorchSharedStateActionCritic(networks: Mapping[str | int, torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[str | int, bool], action_spaces_dict: Dict[str | int, gymnasium.spaces.Dict])

A single critic is shared across all sub-steps or actors (the default for standard gym-style environments). It can be instantiated via the SharedStateActionCriticComposer.
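To illustrate what "shared across all sub-steps" means, the framework-free sketch below routes every sub-step key to one and the same critic callable. `SharedCriticSketch`, `critic_for` and `evaluate` are hypothetical names, not part of Maze, and the `num_critics == 1` here is this sketch's own choice. The sketch does not attempt to reproduce how Maze combines sub-step observations internally.

```python
from typing import Callable, Dict, List, Mapping, Union

StepKey = Union[str, int]          # sub-step / actor key, mirroring the Maze signatures
Observation = Dict[str, List[float]]


class SharedCriticSketch:
    """Hypothetical stand-in: one critic callable serves every sub-step key."""

    def __init__(self, critic: Callable[[Observation], float], step_keys: List[StepKey]) -> None:
        self.critic = critic              # stored once, not once per sub-step
        self.step_keys = list(step_keys)

    @property
    def num_critics(self) -> int:
        # In this sketch there is exactly one critic, however many sub-steps exist.
        return 1

    def critic_for(self, step_key: StepKey) -> Callable[[Observation], float]:
        if step_key not in self.step_keys:
            raise KeyError(step_key)
        return self.critic                # every key resolves to the same object

    def evaluate(self, observations: Mapping[StepKey, Observation]) -> Dict[StepKey, float]:
        # Apply the shared critic to each sub-step's observation.
        return {key: self.critic_for(key)(obs) for key, obs in observations.items()}


# Usage: a toy "critic" that sums a feature vector, shared by sub-steps 0 and 1.
shared = SharedCriticSketch(lambda obs: sum(obs["features"]), step_keys=[0, 1])
values = shared.evaluate({0: {"features": [1.0, 2.0]}, 1: {"features": [0.5]}})
```

Because all keys resolve to one object, any parameter update made through one sub-step affects the estimates for every other sub-step, which is the point of a shared critic.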

property num_critics: int

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.

predict_next_q_values(next_observations: Dict[str | int, Dict[str, torch.Tensor]], next_actions: Dict[str | int, Dict[str, torch.Tensor]], next_actions_logits: Dict[str | int, Dict[str, torch.Tensor]], next_actions_log_probs: Dict[str | int, Dict[str, torch.Tensor]], alpha: Dict[str | int, torch.Tensor]) → Dict[str | int, torch.Tensor | Dict[str, torch.Tensor]]

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.
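The arguments `next_actions_logits`, `next_actions_log_probs` and `alpha` suggest Soft Actor-Critic (SAC) style targets. Assuming that is the intent, the sketch below computes the standard SAC soft state value for a single action head, taking the element-wise minimum over two critics (clipped double-Q, also an assumption). It is plain Python for readability, the function names are hypothetical, and it is not claimed to match Maze's exact computation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable: subtract the max before exponentiating.
    m = max(logits)
    log_total = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_total for x in logits]


def soft_value_discrete(q_per_critic: Sequence[Sequence[float]],
                        logits: Sequence[float], alpha: float) -> float:
    """V(s') = sum_a pi(a|s') * (min_i Q_i(s', a) - alpha * log pi(a|s'))."""
    q_min = [min(qs) for qs in zip(*q_per_critic)]   # element-wise min over critics
    log_probs = log_softmax(logits)
    return sum(math.exp(lp) * (q - alpha * lp) for q, lp in zip(q_min, log_probs))


def soft_value_continuous(q_per_critic: Sequence[float],
                          log_prob: float, alpha: float) -> float:
    """V(s') ~= min_i Q_i(s', a') - alpha * log pi(a'|s') for one sampled a'."""
    return min(q_per_critic) - alpha * log_prob


# Usage: two critics over two discrete actions with a uniform policy.
v_disc = soft_value_discrete([[1.0, 2.0], [1.5, 1.0]], logits=[0.0, 0.0], alpha=0.1)
v_cont = soft_value_continuous([2.0, 3.0], log_prob=-1.0, alpha=0.2)
```

In the discrete case the expectation over next actions can be taken exactly from the logits; in the continuous case it is approximated with the sampled next action and its log-probability. The entropy term `-alpha * log pi` raises the value of states where the policy is still uncertain.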

predict_q_values(observations: Dict[str | int, Dict[str, torch.Tensor]], actions: Dict[str | int, Dict[str, torch.Tensor]], gather_output: bool) → Dict[str | int, List[torch.Tensor | Dict[str, torch.Tensor]]]

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.
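A plausible reading of `gather_output` (an assumption, not confirmed by this page) is that, for a discrete action head whose critic outputs one Q-value per action, it selects the Q-value of the action actually taken instead of returning the whole per-action table. The hypothetical helpers below sketch that selection in plain Python.

```python
from typing import List, Sequence, Union


def gather_q_values(q_table: Sequence[Sequence[float]], actions: Sequence[int]) -> List[float]:
    """Select Q(s_i, a_i): one row of per-action Q-values per batch element."""
    if len(q_table) != len(actions):
        raise ValueError("q_table and actions must have the same batch size")
    return [row[a] for row, a in zip(q_table, actions)]


def predict_q_values_sketch(q_table: Sequence[Sequence[float]], actions: Sequence[int],
                            gather_output: bool) -> Union[List[float], List[List[float]]]:
    # Hypothetical helper: gather the taken action's Q-value, or return the full table.
    if gather_output:
        return gather_q_values(q_table, actions)
    return [list(row) for row in q_table]


q_table = [[0.1, 0.9, 0.3], [0.5, 0.2, 0.7]]   # batch of 2 states, 3 discrete actions
taken = [1, 2]
gathered = predict_q_values_sketch(q_table, taken, gather_output=True)
full = predict_q_values_sketch(q_table, taken, gather_output=False)
```

With PyTorch tensors the same selection is usually written as `q.gather(-1, actions.unsqueeze(-1)).squeeze(-1)`. Returning the full table is useful when a loss or target needs Q-values for every action, as in the discrete soft-value computation.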