TorchDeltaStateCritic
- class maze.core.agent.torch_state_critic.TorchDeltaStateCritic(networks: Mapping[str | int, torch.nn.Module], obs_spaces_dict: Dict[str | int, gymnasium.spaces.Dict], device: str)
The first sub-step gets a regular critic; each subsequent sub-step predicts a delta with respect to the previous sub-step's critic. Can be instantiated via the DeltaStateCriticComposer.
- compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]
(overrides TorchStateCritic) Computes the return based on the shared reward, i.e. the reward summed across all sub-steps.
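A minimal sketch of a shared-reward return computation, assuming standard GAE (generalized advantage estimation), per-sub-step rewards and values given as tensors of shape (T,), and `dones[t] == 1` marking an episode end at step t. The function name `shared_reward_returns` is hypothetical, and bootstrapping beyond the last step is omitted (treated as a zero value); this illustrates the idea, not the maze implementation.

```python
from typing import List

import torch


def shared_reward_returns(
    gamma: float,
    gae_lambda: float,
    rewards: List[torch.Tensor],
    values: List[torch.Tensor],
    dones: torch.Tensor,
) -> List[torch.Tensor]:
    """Hypothetical sketch: GAE returns per sub-step, all driven by one shared reward."""
    # Shared reward: sum the per-sub-step rewards element-wise, shape (T,).
    shared_reward = torch.stack(rewards).sum(dim=0)
    not_done = 1.0 - dones.float()
    n_steps = shared_reward.shape[0]

    returns = []
    for sub_step_values in values:
        advantages = torch.zeros_like(sub_step_values)
        # Assumption: no bootstrap value after the last step (treated as 0).
        next_value = torch.zeros(())
        next_advantage = torch.zeros(())
        for t in reversed(range(n_steps)):
            # TD error using the shared reward; do not bootstrap across an episode end.
            delta = shared_reward[t] + gamma * next_value * not_done[t] - sub_step_values[t]
            next_advantage = delta + gamma * gae_lambda * not_done[t] * next_advantage
            advantages[t] = next_advantage
            next_value = sub_step_values[t]
        # Return = advantage + this sub-step's value baseline.
        returns.append(advantages + sub_step_values)
    return returns
```

Since every sub-step uses the same summed reward, the resulting returns differ between sub-steps only through each sub-step's value estimates.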
- property num_critics: int
(overrides TorchStateCritic) Implementation of TorchStateCritic.
- predict_value(observation: Dict[str, numpy.ndarray], critic_id: int | str) → torch.Tensor
(overrides StateCritic) Predictions depend on previous sub-steps, so this method is not supported by the delta state critic.
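To illustrate why a single-sub-step prediction cannot work, here is a toy sketch of the delta idea, assuming each sub-step has its own linear head and the predicted delta is simply added to the previous sub-step's value. `ToyDeltaCritic` and its layout are hypothetical; the actual networks and the way maze composes the values may differ.

```python
from typing import List

import torch
from torch import nn


class ToyDeltaCritic(nn.Module):
    """Hypothetical sketch of a delta critic, not the maze implementation."""

    def __init__(self, obs_dim: int, n_sub_steps: int):
        super().__init__()
        # One linear value head per sub-step (assumption for illustration).
        self.heads = nn.ModuleList([nn.Linear(obs_dim, 1) for _ in range(n_sub_steps)])

    def forward(self, observations: List[torch.Tensor]) -> List[torch.Tensor]:
        values = []
        prev_value = None
        for head, obs in zip(self.heads, observations):
            out = head(obs).squeeze(-1)
            if prev_value is None:
                # First sub-step: a regular value estimate.
                value = out
            else:
                # Later sub-steps: a delta on top of the previous sub-step's value.
                value = prev_value + out
            values.append(value)
            prev_value = value
        return values
```

The value of sub-step k needs the value of sub-step k-1, so it can only be computed when all preceding sub-steps are evaluated together.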
- predict_values(critic_input: StateCriticInput) → StateCriticOutput
(overrides StateCritic) Implementation of StateCritic.