TorchStepStateCritic

class maze.core.agent.torch_state_critic.TorchStepStateCritic(networks: Mapping[Union[str, int], torch.nn.Module], obs_spaces_dict: Dict[Union[str, int], gym.spaces.Dict], device: str)

Each sub-step or actor gets its own individual critic. Instances can be created via the StepStateCriticComposer.
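To illustrate the "one critic per sub-step" layout, here is a minimal pure-Python sketch. It is not the Maze implementation: StepCriticSketch and LinearCritic are hypothetical stand-ins for this class and for a torch.nn.Module value network, and the observation spaces are reduced to plain dimensions. The key point is that each sub-step key maps to a separate network with its own parameters, so the number of critics equals the number of sub-steps.

```python
from typing import Dict, List, Union

StepKey = Union[str, int]


class LinearCritic:
    """Hypothetical stand-in for a torch.nn.Module value network: v(s) = w . s + b."""

    def __init__(self, obs_dim: int):
        self.weights = [0.0] * obs_dim
        self.bias = 0.0

    def __call__(self, obs: List[float]) -> float:
        return sum(w * x for w, x in zip(self.weights, obs)) + self.bias


class StepCriticSketch:
    """Hypothetical sketch: one independent critic per sub-step key."""

    def __init__(self, obs_dims: Dict[StepKey, int]):
        # Separate parameters for every sub-step; nothing is shared between critics.
        self.networks: Dict[StepKey, LinearCritic] = {
            key: LinearCritic(dim) for key, dim in obs_dims.items()
        }

    @property
    def num_critics(self) -> int:
        # One critic per sub-step.
        return len(self.networks)
```

For example, StepCriticSketch({0: 4, 1: 2}) holds two critics, and changing the parameters of the critic for sub-step 1 leaves the critic for sub-step 0 untouched.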

compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]

(overrides TorchStateCritic)

Computes the returns for each sub-step separately.
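The gamma and gae_lambda parameters point to generalized advantage estimation (GAE). The following minimal pure-Python sketch shows how returns could be computed per sub-step under that assumption; it is not Maze's implementation. The helper names gae_returns and structured_returns are hypothetical, tensors are replaced by plain lists for a single environment, each sub-step's values list is assumed to carry one extra bootstrap entry, and dones[t] is assumed to mark an episode end after step t. The dones are shared by all sub-steps, while each sub-step's rewards and values are processed independently.

```python
from typing import List, Sequence


def gae_returns(gamma: float, gae_lambda: float, rewards: Sequence[float],
                values: Sequence[float], dones: Sequence[bool]) -> List[float]:
    """GAE-based returns for one sub-step's trajectory (hypothetical helper).

    Assumption: values has len(rewards) + 1 entries; the last one bootstraps
    the value of the state after the final reward.
    """
    assert len(values) == len(rewards) + 1 and len(dones) == len(rewards)
    returns = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        mask = 0.0 if dones[t] else 1.0  # do not bootstrap across an episode end
        delta = rewards[t] + gamma * values[t + 1] * mask - values[t]
        gae = delta + gamma * gae_lambda * mask * gae
        returns[t] = gae + values[t]
    return returns


def structured_returns(gamma: float, gae_lambda: float,
                       rewards: Sequence[Sequence[float]],
                       values: Sequence[Sequence[float]],
                       dones: Sequence[bool]) -> List[List[float]]:
    """One return sequence per sub-step; all sub-steps share the same dones."""
    return [gae_returns(gamma, gae_lambda, r, v, dones)
            for r, v in zip(rewards, values)]
```

With gae_lambda = 1 and all-zero values this reduces to the discounted Monte Carlo return of each sub-step; with gae_lambda = 0 each return becomes the one-step TD target r_t + gamma * V(s_{t+1}).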

property num_critics

(overrides TorchStateCritic)

Implementation of TorchStateCritic; returns the number of critics, one per sub-step.

predict_values(critic_input: maze.core.agent.state_critic_input_output.StateCriticInput) → maze.core.agent.state_critic_input_output.StateCriticOutput

(overrides StateCritic)

Implementation of TorchStateCritic; the value for each sub-step is predicted by that sub-step's own critic.
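The routing idea behind per-sub-step value prediction can be sketched as follows. This is a hypothetical, pure-Python illustration rather than the Maze API: predict_step_values stands in for predict_values, the critic input is simplified to a list of (sub-step key, observation) pairs instead of a StateCriticInput, and the critics are plain callables instead of torch modules.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

StepKey = Union[str, int]
Critic = Callable[[Sequence[float]], float]


def predict_step_values(networks: Dict[StepKey, Critic],
                        critic_input: List[Tuple[StepKey, Sequence[float]]]) -> List[float]:
    """Hypothetical stand-in for predict_values.

    Each sub-step's observation is routed to that sub-step's own critic,
    producing one value per sub-step in input order.
    """
    return [networks[step_key](obs) for step_key, obs in critic_input]
```

Because no critic is shared, a sub-step whose key has no entry in networks raises a KeyError instead of silently falling back to another sub-step's critic.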