TorchSharedStateCritic

class maze.core.agent.torch_state_critic.TorchSharedStateCritic(networks: Mapping[str | int, torch.nn.Module], obs_spaces_dict: Dict[str | int, gymnasium.spaces.Dict], device: str, stack_observations: bool)

One critic is shared across all sub-steps or actors. This is the default critic for standard gym-style environments.

In multi-step and multi-agent scenarios, observations from different sub-steps are merged into a single observation dictionary. Keys shared by multiple sub-steps are expected to have the same value in each of them and appear only once in the merged dictionary.
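The merge rule above can be sketched in plain Python. This is a minimal illustration, not Maze's actual implementation: the function name `merge_sub_step_observations` is hypothetical, and observations are assumed to be flat dictionaries of NumPy arrays. The sketch raises an error when a shared key carries conflicting values, rather than silently keeping one of them.

```python
from typing import Dict, List

import numpy as np


def merge_sub_step_observations(sub_step_obs: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: merge per-sub-step observation dicts into one dict."""
    merged: Dict[str, np.ndarray] = {}
    for obs in sub_step_obs:
        for key, value in obs.items():
            if key in merged:
                # Keys common to several sub-steps are expected to hold identical values
                if not np.array_equal(merged[key], value):
                    raise ValueError(f"conflicting values for shared key '{key}'")
            else:
                # First occurrence of a key: store it once in the merged dict
                merged[key] = value
    return merged
```

For example, merging `{"a": [1, 2], "b": [0]}` with `{"a": [1, 2], "c": [5]}` yields a single dictionary with keys `a`, `b` and `c`, where `a` appears only once.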

It can be instantiated via the SharedStateCriticComposer.

compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]

(overrides TorchStateCritic)

Computes the return based on the shared reward, i.e. the reward summed across all sub-steps.
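A minimal sketch of a shared-reward return, assuming NumPy arrays in place of torch tensors and standard Generalized Advantage Estimation (GAE). The function names and the explicit `last_value` bootstrap argument are hypothetical; Maze's own bootstrapping of the final step and its handling of the per-sub-step `values` list may differ. Because the critic is shared, the sketch assumes the value predictions are identical for every sub-step, uses the first entry, and returns the same return array for each sub-step.

```python
from typing import List

import numpy as np


def gae_return(gamma: float, gae_lambda: float, rewards: np.ndarray, values: np.ndarray,
               dones: np.ndarray, last_value: np.ndarray) -> np.ndarray:
    """GAE-based return; rewards, values and dones have shape (n_steps, n_envs)."""
    n_steps, n_envs = rewards.shape
    not_dones = 1.0 - dones.astype(np.float64)
    returns = np.zeros((n_steps, n_envs))
    gae = np.zeros(n_envs)
    for t in reversed(range(n_steps)):
        # Bootstrap from the (hypothetical) last_value after the final step
        next_value = last_value if t == n_steps - 1 else values[t + 1]
        # One-step temporal-difference error; no bootstrapping past episode ends
        delta = rewards[t] + gamma * next_value * not_dones[t] - values[t]
        gae = delta + gamma * gae_lambda * not_dones[t] * gae
        returns[t] = gae + values[t]
    return returns


def shared_structured_return(gamma: float, gae_lambda: float, rewards: List[np.ndarray],
                             values: List[np.ndarray], dones: np.ndarray,
                             last_value: np.ndarray) -> List[np.ndarray]:
    # Shared reward: sum the per-sub-step rewards element-wise
    shared_rewards = np.sum(np.stack(rewards), axis=0)
    # Single shared critic: assume all sub-step value estimates are identical
    shared_return = gae_return(gamma, gae_lambda, shared_rewards, values[0], dones, last_value)
    # Every sub-step receives the same return
    return [shared_return for _ in rewards]
```

For example, with `gamma=0.5`, `gae_lambda=1.0`, zero value estimates, no episode ends and sub-step rewards `[[1], [0]]` and `[[1], [2]]`, the shared rewards are `[[2], [2]]` and each sub-step receives the return `[[3], [2]]`, which is simply the discounted sum of the shared rewards.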

property num_critics: int

(overrides TorchStateCritic)

There is a single shared critic network.

predict_value(observation: Dict[str, numpy.ndarray], critic_id: int | str) → torch.Tensor

(overrides StateCritic)

Predictions depend on previous sub-steps; thus, this method is not supported in the delta state critic.

predict_values(critic_input: StateCriticInput) → StateCriticOutput

(overrides StateCritic)

Implementation of TorchStateCritic.