TorchStateCritic
-
class
maze.core.agent.torch_state_critic.
TorchStateCritic
(networks: Mapping[Union[str, int], torch.nn.Module], obs_spaces_dict: Dict[Union[str, int], gym.spaces.Dict], device: str)
Encapsulates multiple torch state critics for training in structured environments.
- Parameters
networks – Mapping of value functions (critic) to encapsulate.
obs_spaces_dict – The observation spaces dict of the environment.
device – Device the critic networks should be located on (cpu or cuda).
-
compute_return
(gamma: float, gae_lambda: float, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor, deltas: torch.Tensor = None) → torch.Tensor
Compute bootstrapped return from rewards and estimated values.
- Parameters
gamma – Discounting factor
gae_lambda – Bias-variance trade-off factor for the Generalized Advantage Estimator (GAE)
rewards – Step rewards with shape (n_steps, n_workers)
values – Predicted values with shape (n_steps, n_workers)
dones – Step dones with shape (n_steps, n_workers)
deltas – Predicted value deltas to previous sub-step with shape (n_steps, n_workers)
- Returns
Per time step returns.
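The following is a minimal standalone sketch of how a bootstrapped return can be derived with GAE from tensors of the shapes listed above. It is not the library's implementation: the function name gae_return is hypothetical, the value after the last step is assumed to be zero, and the optional deltas argument is left out because its exact semantics are not described here.

```python
import torch


def gae_return(gamma: float, gae_lambda: float, rewards: torch.Tensor,
               values: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: GAE-based return for (n_steps, n_workers) tensors."""
    n_steps = rewards.shape[0]
    not_done = 1.0 - dones.float()

    # Value of the following step; assumed to be zero after the last step.
    next_values = torch.cat([values[1:], torch.zeros_like(values[:1])], dim=0)

    # One-step temporal difference errors.
    td_errors = rewards + gamma * next_values * not_done - values

    # Accumulate discounted TD errors backwards in time (the GAE advantage).
    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(n_steps)):
        running = td_errors[t] + gamma * gae_lambda * not_done[t] * running
        advantages[t] = running

    # Return = advantage estimate + value baseline.
    return advantages + values


# Example call with 3 steps and 2 workers.
example_returns = gae_return(
    gamma=0.99, gae_lambda=0.95,
    rewards=torch.ones(3, 2), values=torch.zeros(3, 2), dones=torch.zeros(3, 2),
)
```

With gae_lambda = 1 this reduces to the discounted reward-to-go, and with gae_lambda = 0 to the one-step bootstrapped target; values in between trade bias against variance.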
-
abstract
compute_structured_return
(gamma: float, gae_lambda: float, rewards: List[torch.Tensor], values: List[torch.Tensor], dones: torch.Tensor) → List[torch.Tensor]
Compute bootstrapped return for the whole structured step (i.e., all sub-steps).
- Parameters
gamma – Discounting factor
gae_lambda – Bias-variance trade-off factor for the Generalized Advantage Estimator (GAE)
rewards – List of sub-step rewards, each with shape (n_steps, n_workers)
values – List of sub-step detached values, each with shape (n_steps, n_workers)
dones – Step dones with shape (n_steps, n_workers)
- Returns
List of per-time-step returns, one tensor per sub-step.
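Because this method is abstract, each subclass decides how sub-step returns are combined. As one possible illustration only, the sketch below treats every sub-step independently and applies a per-sub-step return function to each (rewards, values) pair. The names structured_return_sketch and reward_to_go are hypothetical, and reward_to_go is a simple stand-in that ignores values and gae_lambda.

```python
from typing import Callable, List

import torch


def reward_to_go(gamma: float, gae_lambda: float, rewards: torch.Tensor,
                 values: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """Stand-in return function: discounted reward-to-go, reset at episode ends."""
    out = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * (1.0 - dones[t].float()) * running
        out[t] = running
    return out


def structured_return_sketch(return_fn: Callable[..., torch.Tensor],
                             gamma: float, gae_lambda: float,
                             rewards: List[torch.Tensor],
                             values: List[torch.Tensor],
                             dones: torch.Tensor) -> List[torch.Tensor]:
    """Apply return_fn to every sub-step separately (an assumed strategy)."""
    return [
        # Detach values so no gradient flows through the return targets.
        return_fn(gamma, gae_lambda, sub_rewards, sub_values.detach(), dones)
        for sub_rewards, sub_values in zip(rewards, values)
    ]


# Two sub-steps, 3 time steps, 1 worker.
rewards = [torch.ones(3, 1), 2.0 * torch.ones(3, 1)]
values = [torch.zeros(3, 1), torch.zeros(3, 1)]
dones = torch.zeros(3, 1)
returns = structured_return_sketch(reward_to_go, 1.0, 0.95, rewards, values, dones)
```

The result has one tensor per sub-step, each with shape (n_steps, n_workers), matching the return type documented above.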
-
property
device
Implementation of TorchModel.
-
eval
() → None
(overrides TorchModel) Implementation of TorchModel.
-
load_state_dict
(state_dict: Dict) → None
(overrides TorchModel) Implementation of TorchModel.
-
abstract property
num_critics
Returns the number of critic networks.
-
parameters
() → List[torch.Tensor]
(overrides TorchModel) Implementation of TorchModel.
-
predict_value
(observation: Dict[str, numpy.ndarray], critic_id: Union[int, str]) → Dict[str, torch.Tensor]
(overrides StateCritic) Implementation of StateCritic.
-
state_dict
() → Dict
(overrides TorchModel) Implementation of TorchModel.
-
to
(device: str) → None
(overrides TorchModel) Implementation of TorchModel.
-
train
() → None
(overrides TorchModel) Implementation of TorchModel.