StateActionCritic
- class maze.core.agent.state_action_critic.StateActionCritic
Structured state-action critic class designed to work with structured environments (see StructuredEnv). It encapsulates state-action critics and queries them for values according to the provided policy ID.
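The routing idea (one group of critics per policy ID, each queried with that ID's observation and action) can be sketched in plain Python. This is a hypothetical toy, not Maze's implementation: `ToyStateActionCritic`, the scalar "critics" and the observation/action keys (`x`, `move`, `item`) are all invented for illustration, and holding two critics under one ID (as in twin-Q setups) is an assumption.

```python
from typing import Callable, Dict, List, Union

# Hypothetical stand-ins: a "critic" maps (observation dict, action dict) to a Q-value.
Obs = Dict[str, float]
Act = Dict[str, float]
Critic = Callable[[Obs, Act], float]
PolicyID = Union[str, int]


class ToyStateActionCritic:
    """Toy sketch of the routing idea: a list of critics per policy ID."""

    def __init__(self, critics: Dict[PolicyID, List[Critic]]):
        self.critics = critics

    def predict_q_values(
        self,
        observations: Dict[PolicyID, Obs],
        actions: Dict[PolicyID, Act],
    ) -> Dict[PolicyID, List[float]]:
        # Query only the critics registered for each policy ID present in the input;
        # each ID yields one value per critic.
        return {
            policy_id: [critic(obs, actions[policy_id]) for critic in self.critics[policy_id]]
            for policy_id, obs in observations.items()
        }


# Two critics for policy ID 0, one critic for policy ID "pick" (hypothetical).
toy = ToyStateActionCritic({
    0: [lambda o, a: o["x"] + a["move"], lambda o, a: o["x"] - a["move"]],
    "pick": [lambda o, a: 2 * a["item"]],
})
out = toy.predict_q_values({0: {"x": 1.0}}, {0: {"move": 0.5}})
print(out)  # {0: [1.5, 0.5]}
```

Only the policy IDs present in the input are evaluated, which mirrors how a structured environment exposes one sub-step at a time.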
- abstract predict_q_values(observations: Dict[str | int, Dict[str, torch.Tensor]], actions: Dict[str | int, Dict[str, torch.Tensor]], gather_output: bool) → Dict[str | int, List[torch.Tensor | Dict[str, torch.Tensor]]]
Predict the Q-values based on the observations and actions.
- Parameters:
observations – The observations for the current step, given as one dictionary of tensors per policy ID.
actions – The actions performed at the current step, in the same structure as observations.
gather_output – Specify whether to gather the output in the discrete setting.
- Returns:
A dictionary keyed by policy ID, where each value is a list holding the predicted Q-values of each critic, either as a tensor or as a dictionary of tensors.
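For discrete action spaces, one common reading of gather_output is: a critic produces a Q-value for every possible action, and gathering selects only the Q-value of the action actually taken. The sketch below shows that reading with `torch.gather`; it is an assumption based on the parameter description, and `gather_q_values` is a hypothetical helper, not part of Maze.

```python
import torch


def gather_q_values(all_q: torch.Tensor, actions: torch.Tensor, gather_output: bool) -> torch.Tensor:
    """Hypothetical helper: return Q(s, a) for the taken actions, or the full Q table.

    all_q:   (batch, num_actions) Q-values for every discrete action.
    actions: (batch,) integer indices of the actions taken.
    """
    if not gather_output:
        return all_q  # keep Q-values for every action
    # Pick, per row, the column given by the taken action index.
    return torch.gather(all_q, dim=-1, index=actions.long().unsqueeze(-1)).squeeze(-1)


all_q = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
taken = torch.tensor([2, 0])
print(gather_q_values(all_q, taken, gather_output=True))         # tensor([3., 4.])
print(gather_q_values(all_q, taken, gather_output=False).shape)  # torch.Size([2, 3])
```

Returning the full table (gather_output=False) is useful when a loss needs Q-values for all actions, for example to take an expectation under the policy's action distribution.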