TorchStateActionCritic
- class maze.core.agent.torch_state_action_critic.TorchStateActionCritic(networks: Mapping[str | int, torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[str | int, bool], action_spaces_dict: Dict[str | int, gymnasium.spaces.Dict])
Encapsulates multiple torch state action critics for training in structured environments.
- Parameters:
networks – Mapping of value functions (critics) to encapsulate.
num_policies – The number of corresponding policies.
device – Device the networks should be located on (cpu or cuda).
only_discrete_spaces – A dict specifying, for each step, whether its action space holds only discrete sub-spaces.
action_spaces_dict – A dict holding the action space of each step.
- compute_state_action_value_step(observation: Dict[str, torch.Tensor], action: Dict[str, torch.Tensor], critic_id: str | int | tuple) → List[torch.Tensor]
Predict the Q-value for the given observation and action at the step identified by critic_id.
- Parameters:
observation – The observation for the current step.
action – The action performed at the current step.
critic_id – The current step key of the multi-step env.
- Returns:
A list of tensors holding the predicted q value for each critic.
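The returned list holds one entry per critic. As a minimal sketch of how such a list might be reduced, the example below takes the element-wise minimum over critics (a clipped double-Q scheme). This reduction is an assumption and is not stated in this reference. Plain Python lists stand in for torch tensors, and `min_over_critics` is a hypothetical helper, not part of the class.

```python
from typing import List


def min_over_critics(q_values_per_critic: List[List[float]]) -> List[float]:
    """Element-wise minimum over the per-critic Q-value lists.

    Each inner list stands in for one critic's tensor of Q-values
    (one value per batch item).
    """
    if not q_values_per_critic:
        raise ValueError("expected at least one critic")
    # zip(*...) groups the i-th batch entry of every critic together.
    return [min(per_item) for per_item in zip(*q_values_per_critic)]


# Two hypothetical critics scoring a batch of three (observation, action) pairs.
q_critic_0 = [1.0, 0.5, -0.2]
q_critic_1 = [0.8, 0.7, -0.4]
q_min = min_over_critics([q_critic_0, q_critic_1])
```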
- compute_state_action_values_step(observation: Dict[str, torch.Tensor], critic_id: str | int | tuple) → List[Dict[str, torch.Tensor]]
Predict the Q-values of all actions for the given observation at the step identified by critic_id (discrete action spaces only).
- Parameters:
observation – The observation for the current step.
critic_id – The current step key of the multi-step env.
- Returns:
A list of dicts, one per critic, holding the predicted Q-value for each action.
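To illustrate the return structure, the sketch below assumes each dict maps an action head name to the Q-values of all its discrete actions, and picks out the value of the action actually taken. The head name `"move"`, the helper `gather_taken_action`, and the use of plain lists in place of tensors are all hypothetical.

```python
from typing import Dict, List


def gather_taken_action(
    q_per_critic: List[Dict[str, List[float]]],
    taken: Dict[str, int],
) -> List[Dict[str, float]]:
    """For every critic and action head, pick the Q-value of the action taken."""
    return [
        {head: q_values[taken[head]] for head, q_values in critic_out.items()}
        for critic_out in q_per_critic
    ]


# Hypothetical output of two critics for one discrete head with three actions.
q_all = [{"move": [0.1, 0.9, 0.3]}, {"move": [0.2, 0.7, 0.4]}]
q_taken = gather_taken_action(q_all, {"move": 1})
```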
- property device: str
Implementation of TorchModel.
- eval() → None
(overrides TorchModel) Implementation of TorchModel.
- load_state_dict(state_dict: Dict) → None
(overrides TorchModel) Implementation of TorchModel.
- abstract property num_critics: int
Returns the number of critic networks.
- Returns:
Number of critic networks.
- parameters() → List[torch.Tensor]
(overrides TorchModel) Implementation of TorchModel.
- per_critic_parameters() → List[List[torch.Tensor]]
Retrieve all trainable critic parameters (to be assigned to optimizers).
- Returns:
A list of lists holding the parameters of the base critics, one inner list per critic (matching the number of critics per step).
- abstract predict_next_q_values(next_observations: Dict[str | int, Dict[str, torch.Tensor]], next_actions: Dict[str | int, Dict[str, torch.Tensor]], next_actions_logits: Dict[str | int, Dict[str, torch.Tensor]], next_actions_log_probs: Dict[str | int, Dict[str, torch.Tensor]], alpha: Dict[str | int, torch.Tensor]) → Dict[str | int, torch.Tensor | Dict[str, torch.Tensor]]
Predict the target Q-value for the next step: \(V(s_t) := \mathbb{E}_{a_t \sim \pi}\left[Q(s_t, a_t) - \alpha \log \pi(a_t \mid s_t)\right]\).
- Parameters:
next_observations – The next observations.
next_actions – The next actions sampled from the policy.
next_actions_logits – The logits of the next actions (only relevant for the discrete case).
next_actions_log_probs – The log probabilities of the actions.
alpha – The alpha or entropy coefficient for each step.
- Returns:
A dict, keyed by step, holding tensors that represent the predicted next Q-value.
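For a single discrete action head, the expectation in the formula above can be evaluated exactly as a sum over actions. The following is a minimal sketch under that assumption. It uses plain Python floats instead of tensors, and `soft_state_value` is a hypothetical helper; concrete subclasses may compute the target differently.

```python
import math
from typing import List


def soft_state_value(q_values: List[float], logits: List[float], alpha: float) -> float:
    """V(s) = sum_a pi(a|s) * (Q(s, a) - alpha * log pi(a|s)) for one discrete head."""
    # Numerically stable log-softmax over the logits.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    log_probs = [l - log_norm for l in logits]
    # Weight each action's entropy-regularized Q-value by its probability.
    return sum(math.exp(lp) * (q - alpha * lp) for q, lp in zip(q_values, log_probs))


# Uniform policy over two actions: V = mean(Q) + alpha * log(2).
value = soft_state_value(q_values=[1.0, 3.0], logits=[0.0, 0.0], alpha=0.5)
```

With `alpha = 0` the entropy term vanishes and the result reduces to the policy-weighted mean of the Q-values.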
- abstract predict_q_values(observations: Dict[str | int, Dict[str, torch.Tensor]], actions: Dict[str | int, Dict[str, torch.Tensor]], gather_output: bool) → Dict[str | int, List[torch.Tensor | Dict[str, torch.Tensor]]]
(overrides StateActionCritic) Implementation of StateActionCritic.
- state_dict() → Dict
(overrides TorchModel) Implementation of TorchModel.
- to(device: str) → None
(overrides TorchModel) Implementation of TorchModel.
- train() → None
(overrides TorchModel) Implementation of TorchModel.