TorchStateCritic

class maze.core.agent.torch_state_critic.TorchStateCritic(networks: Mapping[str | int, torch.nn.Module], obs_spaces_dict: Dict[str | int, gymnasium.spaces.Dict], device: str)

Encapsulates multiple torch state critics for training in structured environments.

Parameters:
  • networks – Mapping of value function (critic) networks to encapsulate.

  • obs_spaces_dict – The observation spaces dict of the environment.

  • device – Device the critic networks should be located on (cpu or cuda).

compute_return(gamma: float, gae_lambda: float, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor, deltas: torch.Tensor | None = None) torch.Tensor

Compute bootstrapped return from rewards and estimated values.

Parameters:
  • gamma – Discounting factor

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)

  • rewards – Step rewards with shape (n_steps, n_workers)

  • values – Predicted values with shape (n_steps, n_workers)

  • dones – Step dones with shape (n_steps, n_workers)

  • deltas – Predicted value deltas to previous sub-step with shape (n_steps, n_workers)

Returns:

Per time step returns.
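
The following is a minimal sketch of how GAE-based returns can be computed from tensors with the shapes listed above. It is not the library's implementation: the function name gae_returns and the last_value argument (the value estimate used to bootstrap past the end of the rollout) are assumptions made for illustration, and the deltas argument is left out.

```python
from typing import Optional

import torch


def gae_returns(gamma: float, gae_lambda: float, rewards: torch.Tensor,
                values: torch.Tensor, dones: torch.Tensor,
                last_value: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Illustrative GAE(lambda) returns for tensors of shape (n_steps, n_workers)."""
    n_steps = rewards.shape[0]
    # 0.0 where the episode ended at step t, 1.0 otherwise.
    not_done = 1.0 - dones.float()
    if last_value is None:
        # Assumption: no value estimate past the rollout, so bootstrap with zero.
        last_value = torch.zeros_like(values[0])

    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(values[0])
    next_value = last_value
    for t in reversed(range(n_steps)):
        # One-step TD error; the bootstrap term is masked at episode ends.
        delta = rewards[t] + gamma * next_value * not_done[t] - values[t]
        # Exponentially weighted sum of TD errors (the GAE advantage).
        gae = delta + gamma * gae_lambda * not_done[t] * gae
        # Return target = advantage estimate + value baseline.
        returns[t] = gae + values[t]
        next_value = values[t]
    return returns
```

In this sketch, gae_lambda = 1 reduces to the discounted sum of rewards plus the bootstrap value (low bias, high variance), while gae_lambda = 0 yields the one-step target rewards[t] + gamma * values[t + 1] (low variance, higher bias).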

abstract compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) List[torch.Tensor]

Compute bootstrapped return for the whole structured step (i.e., all sub-steps).

Parameters:
  • gamma – Discounting factor

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)

  • rewards – List of sub-step rewards, each with shape (n_steps, n_workers)

  • values – List of sub-step detached values, each with shape (n_steps, n_workers)

  • dones – Step dones with shape (n_steps, n_workers)

Returns:

List of per-time-step returns, one tensor per sub-step.
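
Because this method is abstract, each subclass decides how sub-step rewards and values are combined. As one possible strategy, the sketch below sums the rewards over all sub-steps and hands every sub-step the same return. The names shared_structured_return, discounted_return and return_fn are hypothetical, and using the first sub-step's values for the whole step is an assumption, not documented behaviour.

```python
from typing import Callable, List

import torch

ReturnFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def discounted_return(rewards: torch.Tensor, values: torch.Tensor,
                      dones: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """Stand-in return function: plain discounted sum, ignores `values`."""
    returns = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        # Drop the future part of the sum where the episode ended at step t.
        running = rewards[t] + gamma * running * (1.0 - dones[t].float())
        returns[t] = running
    return returns


def shared_structured_return(rewards: List[torch.Tensor],
                             values: List[torch.Tensor],
                             dones: torch.Tensor,
                             return_fn: ReturnFn) -> List[torch.Tensor]:
    """Hypothetical strategy: one shared return for all sub-steps."""
    # Sum rewards over sub-steps to get one reward per structured step.
    shared_rewards = torch.stack(rewards).sum(dim=0)
    # Assumption: the first sub-step's values stand in for the whole step.
    shared_return = return_fn(shared_rewards, values[0], dones)
    # Every sub-step receives the same target tensor.
    return [shared_return for _ in values]
```

A subclass with one critic per sub-step could instead compute a separate return for each entry of rewards and values; the shared variant only makes sense when a single value estimate covers the full structured step.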

property device: str

implementation of TorchModel

eval() None

(overrides TorchModel)

implementation of TorchModel

load_state_dict(state_dict: Dict) None

(overrides TorchModel)

implementation of TorchModel

abstract property num_critics: int

Returns the number of critic networks.

parameters() List[torch.Tensor]

(overrides TorchModel)

implementation of TorchModel

predict_value(observation: Dict[str, numpy.ndarray], critic_id: int | str) Dict[str, torch.Tensor]

(overrides StateCritic)

implementation of StateCritic

state_dict() Dict

(overrides TorchModel)

implementation of TorchModel

to(device: str) None

(overrides TorchModel)

implementation of TorchModel

train() None

(overrides TorchModel)

implementation of TorchModel