StateCriticStepInput

class maze.core.agent.state_critic_input_output.StateCriticStepInput(tensor_dict: Dict[str, torch.Tensor], actor_id: ActorID)

State critic input for a single sub-step of the env. It holds the tensor_dict and the actor_id corresponding to where the embedding logits were retrieved, if applicable; otherwise it holds just the corresponding actor.

The tensor dict holds the observations of the corresponding env sub-step, as well as the logits coming from the actor if a shared embedding is used.
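A minimal sketch of the container described above, assuming a plain frozen dataclass. `ActorID` is modelled here as a stand-in `NamedTuple` of `step_key` and `agent_id`, the logits key `"latent"` is hypothetical, and short Python lists stand in for `torch.Tensor` values so the example runs without PyTorch. This illustrates the two cases in the description; it is not the Maze source.

```python
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Union


class ActorID(NamedTuple):
    """Stand-in for Maze's ActorID (assumed: sub-step key plus agent index)."""
    step_key: Union[str, int]
    agent_id: int


@dataclass(frozen=True)
class StateCriticStepInput:
    """Critic input for one env sub-step (sketch, not the Maze source)."""
    # Observations of the sub-step, plus actor logits if a shared embedding is used.
    tensor_dict: Dict[str, Any]
    # The actor this input corresponds to.
    actor_id: ActorID


# Without a shared embedding, the critic sees only the sub-step observation.
plain = StateCriticStepInput(
    tensor_dict={"observation": [0.1, 0.2]},
    actor_id=ActorID(step_key=0, agent_id=0),
)

# With a shared embedding, actor logits sit next to the observation.
# The key name "latent" is hypothetical.
shared = StateCriticStepInput(
    tensor_dict={"observation": [0.1, 0.2], "latent": [0.7, -0.3]},
    actor_id=ActorID(step_key=0, agent_id=0),
)

print(sorted(plain.tensor_dict), sorted(shared.tensor_dict))
```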

actor_id: ActorID

The actor id of the corresponding actor.

classmethod build(policy_step_output: PolicySubStepOutput, observation: Dict[str, torch.Tensor]) → StateCriticStepInput

Build the critic input for an individual step by combining the policy step output and the given observation.

Parameters:
  • policy_step_output – The output of the corresponding policy, used to check for shared embedding outputs.

  • observation – The observation as the default input to the critic.

Returns:

The critic input for this specific step.
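The combination performed by build can be sketched as follows. This is an assumption about the behaviour, not the Maze source. `PolicySubStepOutput` is reduced to a hypothetical stand-in with an `embedding_logits` field that is `None` when no shared embedding is used. The observation is copied first, and any embedding logits are then added on top (logits win on a key clash). `ActorID` is again a stand-in, and lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Optional, Union


class ActorID(NamedTuple):
    """Stand-in for Maze's ActorID (assumed fields)."""
    step_key: Union[str, int]
    agent_id: int


@dataclass
class PolicySubStepOutput:
    """Hypothetical stand-in: only the fields this sketch needs."""
    actor_id: ActorID
    # Logits from the actor's shared embedding, or None if there is none.
    embedding_logits: Optional[Dict[str, Any]] = None


@dataclass(frozen=True)
class StateCriticStepInput:
    tensor_dict: Dict[str, Any]
    actor_id: ActorID

    @classmethod
    def build(cls, policy_step_output: PolicySubStepOutput,
              observation: Dict[str, Any]) -> "StateCriticStepInput":
        # The observation is always the default critic input; copy it so the
        # caller's dict is not mutated.
        tensor_dict = dict(observation)
        # If the policy used a shared embedding, add its logits as extra inputs.
        if policy_step_output.embedding_logits is not None:
            tensor_dict.update(policy_step_output.embedding_logits)
        return cls(tensor_dict=tensor_dict, actor_id=policy_step_output.actor_id)


obs = {"observation": [0.1, 0.2]}
actor = ActorID(step_key="pick", agent_id=0)

no_shared = StateCriticStepInput.build(PolicySubStepOutput(actor_id=actor), obs)
with_shared = StateCriticStepInput.build(
    PolicySubStepOutput(actor_id=actor, embedding_logits={"latent": [0.7, -0.3]}),
    obs,
)
```

Copying the observation before merging keeps the caller's dictionary unchanged, which matters if the same observation is reused elsewhere in the rollout.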

tensor_dict: Dict[str, torch.Tensor]

The tensor dict to use as an input for the critic.