TorchSharedStateActionCritic
- class maze.core.agent.torch_state_action_critic.TorchSharedStateActionCritic(networks: Mapping[str | int, torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[str | int, bool], action_spaces_dict: Dict[str | int, gymnasium.spaces.Dict])
One critic is shared across all sub-steps or actors (the default for standard gym-style environments). Can be instantiated via the SharedStateActionCriticComposer.
- property num_critics: int
(overrides TorchStateActionCritic) Implementation of the corresponding TorchStateActionCritic member.
- predict_next_q_values(next_observations: Dict[str | int, Dict[str, torch.Tensor]], next_actions: Dict[str | int, Dict[str, torch.Tensor]], next_actions_logits: Dict[str | int, Dict[str, torch.Tensor]], next_actions_log_probs: Dict[str | int, Dict[str, torch.Tensor]], alpha: Dict[str | int, torch.Tensor]) → Dict[str | int, torch.Tensor | Dict[str, torch.Tensor]]
(overrides TorchStateActionCritic) Implementation of the corresponding TorchStateActionCritic member.
- predict_q_values(observations: Dict[str | int, Dict[str, torch.Tensor]], actions: Dict[str | int, Dict[str, torch.Tensor]], gather_output: bool) → Dict[str | int, List[torch.Tensor | Dict[str, torch.Tensor]]]
(overrides TorchStateActionCritic) Implementation of the corresponding TorchStateActionCritic member.
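
The predict_next_q_values signature takes action logits, log-probabilities and a per-step alpha, which suggests an entropy-regularised (soft actor-critic style) next-state value. This page does not state the formula, so the following is only a minimal sketch under that assumption. It uses plain Python floats instead of torch tensors, and every name in it (softmax, soft_state_value, shared_next_values, toy_critic) is hypothetical and not part of the Maze API.

```python
import math
from typing import Callable, Dict, List, Union

StepKey = Union[str, int]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_state_value(q_values: List[float], logits: List[float], alpha: float) -> float:
    """Standard discrete soft value: sum_a pi(a) * (Q(a) - alpha * log pi(a))."""
    probs = softmax(logits)
    return sum(p * (q - alpha * math.log(p)) for p, q in zip(probs, q_values))


def shared_next_values(
    critic: Callable[[List[float]], List[float]],
    next_observations: Dict[StepKey, List[float]],
    next_logits: Dict[StepKey, List[float]],
    alpha: Dict[StepKey, float],
) -> Dict[StepKey, float]:
    """Apply ONE critic to every sub-step key (the 'shared' idea, as assumed here)."""
    return {
        key: soft_state_value(critic(obs), next_logits[key], alpha[key])
        for key, obs in next_observations.items()
    }


def toy_critic(obs: List[float]) -> List[float]:
    """Hypothetical stand-in for a Q-network: one Q-value per discrete action."""
    return [obs[0], obs[0] + 2.0]


values = shared_next_values(
    critic=toy_critic,
    next_observations={0: [1.0], 1: [1.0]},
    next_logits={0: [0.0, 0.0], 1: [0.0, 0.0]},
    alpha={0: 0.5, 1: 0.0},
)
```

With uniform action probabilities, the entropy term adds alpha * ln 2 to the mean Q-value, so sub-step 0 (alpha = 0.5) gets a slightly higher value than sub-step 1 (alpha = 0), which reduces to the plain expectation.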