StateActionCritic

class maze.core.agent.state_action_critic.StateActionCritic

Structured state-action critic class designed to work with structured environments (see StructuredEnv).

It encapsulates the state-action critics and queries them for Q-values according to the provided policy ID.
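The routing described above can be illustrated with a small, framework-free sketch. It assumes that each policy ID maps to a list of one or more critics, and that a critic is any callable from an observation dictionary to a list of Q-values. Plain Python lists stand in for torch tensors. The names `PerPolicyCriticRegistry`, `query`, `sum_critic` and `scaled_critic` are hypothetical and are not part of the Maze API.

```python
from typing import Callable, Dict, List, Union

# Hypothetical stand-ins: an observation is a dict of named float lists,
# and a critic maps an observation to a list of Q-values.
Observation = Dict[str, List[float]]
Critic = Callable[[Observation], List[float]]
PolicyID = Union[str, int]


class PerPolicyCriticRegistry:
    """Hypothetical container that keeps one or more critics per policy ID."""

    def __init__(self, critics: Dict[PolicyID, List[Critic]]) -> None:
        self._critics = critics

    def query(self, policy_id: PolicyID, observation: Observation) -> List[List[float]]:
        # Look up the critics registered for this policy ID and evaluate each one.
        if policy_id not in self._critics:
            raise KeyError(f"no critics registered for policy ID {policy_id!r}")
        return [critic(observation) for critic in self._critics[policy_id]]


def sum_critic(obs: Observation) -> List[float]:
    # Toy critic: a single Q-value equal to the sum of all observation entries.
    return [sum(sum(values) for values in obs.values())]


def scaled_critic(obs: Observation) -> List[float]:
    # Toy critic: twice the sum, so the two critics can be told apart.
    return [2.0 * sum(sum(values) for values in obs.values())]


# Policy IDs may be strings or integers, matching Union[str, int] in the signature.
registry = PerPolicyCriticRegistry({"move": [sum_critic, scaled_critic], 0: [sum_critic]})
print(registry.query("move", {"position": [1.0, 2.0]}))  # [[3.0], [6.0]]
print(registry.query(0, {"position": [0.5]}))  # [[0.5]]
```

Keeping a list of critics per policy ID (rather than a single one) mirrors the list-valued return type of `predict_q_values`, where each policy ID yields one result per critic.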

abstract predict_q_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions: Dict[Union[str, int], Dict[str, torch.Tensor]], gather_output: bool) → Dict[Union[str, int], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]]

Predict the Q-values based on the observations and actions.

Parameters
  • observations – The observations for the current step, mapping each policy ID to a dictionary of observation tensors.

  • actions – The actions taken at the current step, mapping each policy ID to a dictionary of action tensors.

  • gather_output – Whether to gather the output in the discrete setting, i.e., select from the per-action Q-values the value of the action that was actually taken.

Returns

A dictionary mapping each policy ID to a list with one entry per critic; each entry holds the predicted Q-values, either as a tensor or as a dictionary of tensors.
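To make the input and output structure more concrete, here is a framework-free sketch of a `predict_q_values`-style function for the discrete setting. It rests on several assumptions: plain Python lists replace torch tensors, there is a single observation and action per policy ID (no batch dimension), each action dictionary has one hypothetical head named `"action"`, each critic outputs one Q-value per discrete action, and `gather_output=True` picks the Q-value of the taken action (in torch this would typically be done with `torch.gather` along the action dimension). `toy_predict_q_values`, `critic_a` and `critic_b` are hypothetical names, not Maze API.

```python
from typing import Callable, Dict, List, Union

PolicyID = Union[str, int]
Observation = Dict[str, List[float]]
Action = Dict[str, int]  # hypothetical: one discrete action index per action head
# Hypothetical critic: maps an observation to one Q-value per discrete action.
DiscreteCritic = Callable[[Observation], List[float]]


def toy_predict_q_values(
    critics: Dict[PolicyID, List[DiscreteCritic]],
    observations: Dict[PolicyID, Observation],
    actions: Dict[PolicyID, Action],
    gather_output: bool,
) -> Dict[PolicyID, List[Union[float, List[float]]]]:
    """Return, per policy ID, one entry per critic (gathered value or full Q-vector)."""
    result: Dict[PolicyID, List[Union[float, List[float]]]] = {}
    for policy_id, obs in observations.items():
        per_critic: List[Union[float, List[float]]] = []
        for critic in critics[policy_id]:
            q_values = critic(obs)  # one Q-value per discrete action
            if gather_output:
                # Keep only the Q-value of the action that was actually taken.
                per_critic.append(q_values[actions[policy_id]["action"]])
            else:
                # Keep the Q-values of all actions.
                per_critic.append(q_values)
        result[policy_id] = per_critic
    return result


def critic_a(obs: Observation) -> List[float]:
    # Toy critic over three discrete actions.
    s = sum(obs["state"])
    return [s, s + 1.0, s + 2.0]


def critic_b(obs: Observation) -> List[float]:
    # A second toy critic with different estimates.
    s = sum(obs["state"])
    return [2.0 * s, 2.0 * s + 1.0, 2.0 * s + 2.0]


critics = {"move": [critic_a, critic_b]}
observations = {"move": {"state": [1.0, 1.0]}}
actions = {"move": {"action": 2}}

print(toy_predict_q_values(critics, observations, actions, gather_output=True))
# {'move': [4.0, 6.0]}
print(toy_predict_q_values(critics, observations, actions, gather_output=False))
# {'move': [[2.0, 3.0, 4.0], [4.0, 5.0, 6.0]]}
```

Returning one entry per critic leaves it to the caller to decide how several Q-estimates for the same policy ID are combined.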