TorchSharedStateActionCritic

class maze.core.agent.torch_state_action_critic.TorchSharedStateActionCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[Union[str, int], bool], action_spaces_dict: Dict[Union[str, int], gym.spaces.Dict])

    One critic is shared across all sub-steps or actors (the default for standard gym-style environments). Can be instantiated via the SharedStateActionCriticComposer.
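All constructor and prediction arguments above are keyed first by sub-step (or actor) ID and then by observation or action name. A minimal sketch of that nested-dict convention, using plain Python floats in place of torch.Tensor values (the helper name `shared_critic_input` and the keys are illustrative, not part of the Maze API):

```python
from typing import Dict, Union

# Observations/actions mirror Dict[Union[str, int], Dict[str, Tensor]]:
# outer key = sub-step or actor ID, inner key = observation/action name.
observations: Dict[Union[str, int], Dict[str, float]] = {
    0: {"position": 0.5, "velocity": -0.1},
}
actions: Dict[Union[str, int], Dict[str, float]] = {
    0: {"torque": 0.3},
}


def shared_critic_input(obs, act):
    """Flatten all sub-steps into one input list, sketching how a
    shared critic can consume every sub-step with a single network."""
    flat = []
    for step_key in obs:
        flat.extend(obs[step_key].values())
        flat.extend(act[step_key].values())
    return flat


print(shared_critic_input(observations, actions))  # [0.5, -0.1, 0.3]
```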
    property num_critics
        (overrides TorchStateActionCritic): implementation of the TorchStateActionCritic interface.
    predict_next_q_values(next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]], alpha: Dict[Union[str, int], torch.Tensor]) → Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]
        (overrides TorchStateActionCritic): implementation of the TorchStateActionCritic interface.
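The signature takes next-step actions together with their log-probabilities and a per-sub-step entropy coefficient alpha, the ingredients of a soft Q target of the form Q(s', a') - alpha * log pi(a'|s'). A sketch of that combination on plain floats (this is the standard SAC target shape, not a literal excerpt of the Maze implementation):

```python
import math

# Per-sub-step inputs; floats stand in for torch.Tensor values.
next_q = {0: 1.5}                     # critic estimate Q(s', a')
next_log_prob = {0: math.log(0.25)}   # log pi(a' | s')
alpha = {0: 0.2}                      # entropy temperature per sub-step

# Soft Q value per sub-step: Q - alpha * log-prob (standard SAC form).
soft_q = {k: next_q[k] - alpha[k] * next_log_prob[k] for k in next_q}
```

The entropy bonus `-alpha * log_prob` rewards stochastic policies; a lower-probability action (more negative log-prob) raises the soft value.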
    predict_q_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions: Dict[Union[str, int], Dict[str, torch.Tensor]], gather_output: bool) → Dict[Union[str, int], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]]
        (overrides TorchStateActionCritic): implementation of the TorchStateActionCritic interface.
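Note that the return type maps each sub-step to a list of Q estimates, one per critic network. SAC-style losses typically reduce such a list with an element-wise minimum (the clipped double-Q trick) to counteract overestimation bias. A sketch of that reduction on plain floats (illustrative only; the keys and values are made up, and this is not a literal excerpt of the Maze code):

```python
# Each sub-step key maps to a list of per-critic Q estimates,
# mirroring Dict[Union[str, int], List[...]] from the signature.
q_values = {0: [1.2, 0.9], "pivot": [2.0, 2.5]}

# Clipped double-Q: take the minimum across critics per sub-step.
min_q = {step: min(qs) for step, qs in q_values.items()}
print(min_q)  # {0: 0.9, 'pivot': 2.0}
```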
-
property