TorchStateCritic

class maze.core.agent.torch_state_critic.TorchStateCritic(networks: Mapping[Union[str, int], torch.nn.Module], obs_spaces_dict: Dict[Union[str, int], gym.spaces.Dict], device: str)

Encapsulates multiple torch state critics for training in structured environments.

Parameters
  • networks – Mapping of value functions (critic) to encapsulate.

  • obs_spaces_dict – The observation spaces dict of the environment.

  • device – Device the critic networks should be located on (cpu or cuda).

compute_return(gamma: float, gae_lambda: float, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor, deltas: torch.Tensor = None) → torch.Tensor

Compute bootstrapped return from rewards and estimated values.

Parameters
  • gamma – Discounting factor

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)

  • rewards – Step rewards with shape (n_steps, n_workers)

  • values – Predicted values with shape (n_steps, n_workers)

  • dones – Step dones with shape (n_steps, n_workers)

  • deltas – Predicted value deltas to previous sub-step with shape (n_steps, n_workers)

Returns

Per time step returns.
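
A bootstrapped GAE return can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: plain nested Python lists stand in for tensors of shape (n_steps, n_workers), the optional deltas argument is ignored, and the last time step is assumed to bootstrap from its own value estimate. The function name gae_returns is hypothetical.

```python
from typing import List

Matrix = List[List[float]]  # stands in for a tensor of shape (n_steps, n_workers)


def gae_returns(gamma: float, gae_lambda: float,
                rewards: Matrix, values: Matrix,
                dones: List[List[bool]]) -> Matrix:
    """Hypothetical helper: per-time-step returns via the GAE recursion."""
    n_steps, n_workers = len(rewards), len(rewards[0])
    # Assumption: the final row keeps the value estimate as its return (bootstrap).
    returns = [row[:] for row in values]
    for w in range(n_workers):
        gae = 0.0
        for t in reversed(range(n_steps - 1)):
            # A done flag at step t cuts off bootstrapping from step t + 1.
            not_done = 0.0 if dones[t][w] else 1.0
            delta = rewards[t][w] + gamma * values[t + 1][w] * not_done - values[t][w]
            gae = delta + gamma * gae_lambda * not_done * gae
            # Return = advantage estimate + value baseline.
            returns[t][w] = gae + values[t][w]
    return returns
```

With gae_lambda = 1 the recursion reduces to the discounted sum of rewards plus the bootstrapped final value; with gae_lambda = 0 it reduces to the one-step target rewards[t] + gamma * values[t + 1].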

abstract compute_structured_return(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]

Compute bootstrapped return for the whole structured step (i.e., all sub-steps).

Parameters
  • gamma – Discounting factor

  • gae_lambda – Bias vs. variance trade-off factor for the Generalized Advantage Estimator (GAE)

  • rewards – List of sub-step rewards, each with shape (n_steps, n_workers)

  • values – List of sub-step detached values, each with shape (n_steps, n_workers)

  • dones – Step dones with shape (n_steps, n_workers)

Returns

List of per-time-step returns, one tensor per sub-step.
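
Because this method is abstract, each subclass decides how sub-step returns are combined. The sketch below shows only one conceivable strategy, treating every sub-step independently and sharing the same dones; whether any concrete subclass behaves this way is an assumption. The names per_sub_step_returns and return_fn are hypothetical, and nested lists again stand in for tensors.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]  # stands in for a tensor of shape (n_steps, n_workers)
Dones = List[List[bool]]
# Hypothetical signature of a single-sub-step return computation.
ReturnFn = Callable[[float, float, Matrix, Matrix, Dones], Matrix]


def per_sub_step_returns(gamma: float, gae_lambda: float,
                         rewards: Sequence[Matrix], values: Sequence[Matrix],
                         dones: Dones, return_fn: ReturnFn) -> List[Matrix]:
    """Hypothetical strategy: apply return_fn to each sub-step on its own."""
    # One rewards entry and one values entry are expected per sub-step.
    assert len(rewards) == len(values)
    # The dones are shared by all sub-steps of a structured step.
    return [return_fn(gamma, gae_lambda, r, v, dones)
            for r, v in zip(rewards, values)]
```

The output list has the same length and ordering as the rewards and values lists, matching the documented return type.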

property device

implementation of TorchModel

eval() → None

(overrides TorchModel)

implementation of TorchModel

load_state_dict(state_dict: Dict) → None

(overrides TorchModel)

implementation of TorchModel

abstract property num_critics

Returns the number of critic networks.

parameters() → List[torch.Tensor]

(overrides TorchModel)

implementation of TorchModel

predict_value(observation: Dict[str, numpy.ndarray], critic_id: Union[int, str]) → Dict[str, torch.Tensor]

(overrides StateCritic)

implementation of StateCritic

state_dict() → Dict

(overrides TorchModel)

implementation of TorchModel

to(device: str) → None

(overrides TorchModel)

implementation of TorchModel

train() → None

(overrides TorchModel)

implementation of TorchModel