StateCriticStepInput

class maze.core.agent.state_critic_input_output.StateCriticStepInput(tensor_dict: Dict[str, torch.Tensor], actor_id: maze.core.env.structured_env.ActorID)

State critic input for a single sub-step of the env. It holds the tensor_dict and the actor_id of the actor from which the embedding logits were retrieved, if applicable; otherwise it holds the id of the corresponding actor.

The tensor dict holds the observations of the corresponding env sub-step, as well as the logits coming from the actor if a shared embedding is used.

actor_id: maze.core.env.structured_env.ActorID

The actor id of the corresponding actor.

classmethod build(policy_step_output: maze.core.agent.torch_policy_output.PolicySubStepOutput, observation: Dict[str, torch.Tensor]) → maze.core.agent.state_critic_input_output.StateCriticStepInput

Build the critic input for an individual step, by combining the policy step output and the given observation.

Parameters
  • policy_step_output – The output of the corresponding policy, used to check for shared embedding outputs.

  • observation – The observation as the default input to the critic.

Returns

The critic input for this specific step.
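The combination performed by build can be illustrated with a minimal, self-contained sketch. It is not the library's implementation. All classes below are hypothetical stand-ins (ActorIDSketch, PolicySubStepOutputSketch, StateCriticStepInputSketch), plain Python lists replace torch.Tensor so the snippet runs without PyTorch, and it assumes that shared-embedding logits, when present, are merged into a copy of the observation dict while the actor id is taken from the policy output.

```python
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Optional, Union


class ActorIDSketch(NamedTuple):
    """Hypothetical stand-in for maze.core.env.structured_env.ActorID."""
    step_key: Union[str, int]
    agent_id: int


@dataclass
class PolicySubStepOutputSketch:
    """Hypothetical stand-in for PolicySubStepOutput (only the fields used here)."""
    actor_id: ActorIDSketch
    # Assumption: only set when the actor and critic share an embedding.
    embedding_logits: Optional[Dict[str, Any]] = None


@dataclass
class StateCriticStepInputSketch:
    tensor_dict: Dict[str, Any]  # values would be torch.Tensor in practice
    actor_id: ActorIDSketch

    @classmethod
    def build(cls, policy_step_output: PolicySubStepOutputSketch,
              observation: Dict[str, Any]) -> "StateCriticStepInputSketch":
        # The observation is the default critic input; copy it so the caller's dict is untouched.
        tensor_dict = dict(observation)
        # Assumption: shared-embedding logits are added alongside the observation.
        if policy_step_output.embedding_logits is not None:
            tensor_dict.update(policy_step_output.embedding_logits)
        return cls(tensor_dict=tensor_dict, actor_id=policy_step_output.actor_id)


# Usage: one critic input without and one with a (hypothetical) shared embedding.
obs = {"observation": [0.1, 0.2]}
plain = StateCriticStepInputSketch.build(PolicySubStepOutputSketch(ActorIDSketch(0, 0)), obs)
shared = StateCriticStepInputSketch.build(
    PolicySubStepOutputSketch(ActorIDSketch("step_1", 1), embedding_logits={"latent": [0.5]}), obs)
print(plain.tensor_dict)   # {'observation': [0.1, 0.2]}
print(shared.tensor_dict)  # {'observation': [0.1, 0.2], 'latent': [0.5]}
```

Copying the observation before merging keeps build free of side effects on the environment's observation dict, which matters if the same observation is also fed to the policy.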

tensor_dict: Dict[str, torch.Tensor]

The tensor dict to use as an input for the critic.