TorchStepStateActionCritic

class maze.core.agent.torch_state_action_critic.TorchStepStateActionCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[Union[str, int], bool], action_spaces_dict: Dict[Union[str, int], gym.spaces.Dict])

Each sub-step or actor gets its own critic. This class can be instantiated via the StepStateActionCriticComposer.
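The one-critic-per-sub-step layout can be illustrated with a minimal, framework-free sketch. It assumes the critics are stored in a dictionary keyed by sub-step ID, as the `networks: Mapping[Union[str, int], torch.nn.Module]` parameter suggests. `StepCriticRegistry` and the toy lambda critics are hypothetical stand-ins for the torch modules and are not Maze classes.

```python
from typing import Callable, Dict, Mapping, Union

StepKey = Union[str, int]
# Hypothetical stand-in for a critic network: maps an observation vector to a Q-value.
Critic = Callable[[list], float]


class StepCriticRegistry:
    """Hypothetical sketch: one independent critic per sub-step (or actor)."""

    def __init__(self, networks: Mapping[StepKey, Critic]):
        # Copy so later changes to the caller's mapping do not leak in.
        self.networks: Dict[StepKey, Critic] = dict(networks)

    @property
    def num_critics(self) -> int:
        # In this sketch, one critic exists per sub-step key.
        return len(self.networks)

    def q_value(self, step_key: StepKey, observation: list) -> float:
        # Route the observation to the critic owned by this sub-step only.
        return self.networks[step_key](observation)


registry = StepCriticRegistry({0: lambda obs: sum(obs), 1: lambda obs: max(obs)})
print(registry.num_critics)             # 2
print(registry.q_value(0, [1.0, 2.0]))  # 3.0
print(registry.q_value(1, [1.0, 2.0]))  # 2.0
```

Keeping the critics separate means each sub-step's value estimate is trained only on that sub-step's observations and actions. A shared critic would instead have to disambiguate the sub-step from its input.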

property num_critics

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.

predict_next_q_values(next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]], alpha: Dict[Union[str, int], torch.Tensor]) → Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.
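The `alpha` and `next_actions_log_probs` parameters point to a Soft Actor-Critic style target, where the next-state value is Q minus an entropy term. The sketch below shows the standard SAC soft-value formula in plain Python. It is an assumption, not taken from the Maze source, that this method computes exactly this. The split between discrete and continuous action heads is inferred from the `only_discrete_spaces` constructor argument. Taking the minimum over several Q-networks (clipped double-Q) is likewise an assumption, suggested by the list-valued return type of `predict_q_values`. The function names are hypothetical.

```python
import math
from typing import List


def soft_next_value_discrete(q_values: List[List[float]], probs: List[float], alpha: float) -> float:
    """Soft state value V(s') for a discrete action head (hypothetical helper).

    q_values: one list of per-action Q-values per Q-network (e.g. twin critics).
    probs: policy probabilities pi(a|s') for each action.
    """
    # Clipped double-Q: element-wise minimum over the Q-networks, per action.
    min_q = [min(qs) for qs in zip(*q_values)]
    # Expectation over actions of Q(s', a) - alpha * log pi(a|s');
    # zero-probability actions contribute nothing (p * log p -> 0).
    return sum(p * (q - alpha * math.log(p)) for p, q in zip(probs, min_q) if p > 0.0)


def soft_next_value_continuous(q_values: List[float], log_prob: float, alpha: float) -> float:
    """Soft state value for one sampled continuous action a' ~ pi(.|s') (hypothetical helper)."""
    # Minimum over Q-networks, minus the entropy penalty for the sampled action.
    return min(q_values) - alpha * log_prob


print(soft_next_value_discrete([[1.0, 2.0], [1.5, 1.0]], [0.5, 0.5], alpha=0.0))  # 1.0
print(soft_next_value_continuous([2.0, 3.0], log_prob=-1.0, alpha=0.5))           # 2.5
```

For discrete spaces, the expectation over actions can be computed exactly from the policy logits. For continuous spaces, it is estimated from the sampled next action, which is why both `next_actions_logits` and `next_actions_log_probs` appear in the signature.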

predict_q_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions: Dict[Union[str, int], Dict[str, torch.Tensor]], gather_output: bool) → Dict[Union[str, int], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]]

(overrides TorchStateActionCritic)

Implementation of TorchStateActionCritic.
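For discrete action spaces, a critic typically outputs one Q-value per possible action, and "gathering" selects the value belonging to the action actually taken. The sketch below assumes that this is what `gather_output=True` does, and that `gather_output=False` returns the full per-action outputs. This page does not state either behavior. `gather_q_values` is a hypothetical plain-Python helper that mirrors a `torch.gather` call along the last dimension.

```python
from typing import List


def gather_q_values(all_action_q: List[List[float]], actions: List[int]) -> List[float]:
    """Pick Q(s, a) for the action actually taken in each batch row (hypothetical helper).

    all_action_q: batch of per-action Q-value rows, shape [batch][num_actions].
    actions: index of the taken action per row, shape [batch].
    """
    # Equivalent in torch: q.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    return [row[a] for row, a in zip(all_action_q, actions)]


q = [[0.1, 0.9, 0.3], [0.5, 0.2, 0.7]]
print(gather_q_values(q, [1, 2]))  # [0.9, 0.7]
```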