TorchStateCritic
- class maze.core.agent.torch_state_critic.TorchStateCritic(networks: Mapping[str | int, torch.nn.Module], obs_spaces_dict: Dict[str | int, gymnasium.spaces.Dict], device: str)
Encapsulates multiple torch state critics for training in structured environments.
- Parameters:
networks – Mapping of value function (critic) networks to encapsulate.
obs_spaces_dict – The observation spaces dict of the environment.
device – Device the critics should be located on (cpu or cuda).
- compute_return(gamma: float, gae_lambda: float, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor, deltas: torch.Tensor | None = None) → torch.Tensor
Compute bootstrapped return from rewards and estimated values.
- Parameters:
gamma – Discount factor
gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)
rewards – Step rewards with shape (n_steps, n_workers)
values – Predicted values with shape (n_steps, n_workers)
dones – Step dones with shape (n_steps, n_workers)
deltas – Predicted value deltas relative to the previous sub-step, with shape (n_steps, n_workers)
- Returns:
Per-time-step returns.
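To make the GAE-based bootstrapping concrete, here is a minimal sketch of the standard GAE return recursion: δ_t = r_t + γ(1 − d_t)V_{t+1} − V_t, A_t = δ_t + γλ(1 − d_t)A_{t+1}, and return R_t = A_t + V_t. It is not maze's implementation. It uses NumPy arrays in place of torch tensors so it runs without PyTorch, ignores the optional `deltas` argument, and takes a hypothetical `bootstrap_values` argument (the value estimate for the state after the last step) that the documented signature does not have.

```python
import numpy as np


def gae_returns(gamma, gae_lambda, rewards, values, dones, bootstrap_values=None):
    """Sketch of GAE-based bootstrapped returns (not maze's implementation).

    rewards, values, dones: arrays of shape (n_steps, n_workers).
    bootstrap_values: hypothetical argument holding V(s) for the state that
        follows the last step, shape (n_workers,); zeros if omitted.
    """
    n_steps, n_workers = rewards.shape
    if bootstrap_values is None:
        bootstrap_values = np.zeros(n_workers)
    # 1.0 where the episode continues after step t, 0.0 where it ended.
    not_done = 1.0 - dones.astype(np.float64)

    returns = np.zeros((n_steps, n_workers))
    gae = np.zeros(n_workers)
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(n_steps)):
        next_values = bootstrap_values if t == n_steps - 1 else values[t + 1]
        # One-step TD error; the next value is masked out at episode ends.
        delta = rewards[t] + gamma * next_values * not_done[t] - values[t]
        # GAE recursion: exponentially weighted sum of future TD errors.
        gae = delta + gamma * gae_lambda * not_done[t] * gae
        # Return target = advantage + value baseline.
        returns[t] = gae + values[t]
    return returns


# Usage: 3 steps, 2 workers; worker 1 finishes its episode at step 1.
rewards = np.ones((3, 2))
values = np.zeros((3, 2))
dones = np.array([[False, False], [False, True], [False, False]])
returns = gae_returns(gamma=1.0, gae_lambda=1.0, rewards=rewards, values=values, dones=dones)
```

With gae_lambda = 1 the recursion yields discounted Monte-Carlo returns that are bootstrapped only at the end of the rollout; with gae_lambda = 0 it reduces to one-step TD targets r_t + γV_{t+1}.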
- abstract compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]
Compute bootstrapped return for the whole structured step (i.e., all sub-steps).
- Parameters:
gamma – Discount factor
gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)
rewards – List of sub-step rewards, each with shape (n_steps, n_workers)
values – List of sub-step detached values, each with shape (n_steps, n_workers)
dones – Step dones with shape (n_steps, n_workers)
- Returns:
List of per-time-step returns, one per sub-step
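Because compute_structured_return is abstract, its behaviour is defined by the concrete critic subclasses. One simple strategy, assumed here purely for illustration, gives every sub-step its own critic and computes each sub-step's returns independently while sharing the step-level dones. The sketch uses NumPy instead of torch and a compact GAE helper that treats the value after the last step as zero (both assumptions); `structured_returns` and `_gae_returns` are hypothetical names.

```python
from typing import List

import numpy as np


def _gae_returns(gamma: float, gae_lambda: float, rewards: np.ndarray,
                 values: np.ndarray, dones: np.ndarray) -> np.ndarray:
    """Compact GAE returns for arrays of shape (n_steps, n_workers).

    Assumption: the value after the last step is taken as zero.
    """
    not_done = 1.0 - dones.astype(np.float64)
    # Shift values by one step and pad with zeros so V_{t+1} exists for every t.
    next_values = np.vstack([values[1:], np.zeros((1, values.shape[1]))])
    td_errors = rewards + gamma * next_values * not_done - values
    returns = np.zeros_like(td_errors)
    gae = np.zeros(values.shape[1])
    # Backward GAE recursion over time steps.
    for t in reversed(range(rewards.shape[0])):
        gae = td_errors[t] + gamma * gae_lambda * not_done[t] * gae
        returns[t] = gae + values[t]
    return returns


def structured_returns(gamma: float, gae_lambda: float, rewards: List[np.ndarray],
                       values: List[np.ndarray], dones: np.ndarray) -> List[np.ndarray]:
    """Hypothetical per-sub-step strategy: one independent return per sub-step."""
    # Each sub-step pairs its own rewards with its own critic's (detached) values;
    # the step-level dones are shared by all sub-steps.
    return [_gae_returns(gamma, gae_lambda, r, v, dones)
            for r, v in zip(rewards, values)]
```

Critics that share or chain values across sub-steps would combine the sub-step values differently; this sketch covers only the independent case.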
- property device: str
Implementation of TorchModel.
- eval() → None
(overrides TorchModel) Implementation of TorchModel.
- load_state_dict(state_dict: Dict) → None
(overrides TorchModel) Implementation of TorchModel.
- abstract property num_critics: int
Returns the number of critic networks.
- Returns:
Number of critic networks.
- parameters() → List[torch.Tensor]
(overrides TorchModel) Implementation of TorchModel.
- predict_value(observation: Dict[str, numpy.ndarray], critic_id: int | str) → Dict[str, torch.Tensor]
(overrides StateCritic) Implementation of StateCritic.
- state_dict() → Dict
(overrides TorchModel) Implementation of TorchModel.
- to(device: str) → None
(overrides TorchModel) Implementation of TorchModel.
- train() → None
(overrides TorchModel) Implementation of TorchModel.