StateCriticStepInput
- class maze.core.agent.state_critic_input_output.StateCriticStepInput(tensor_dict: Dict[str, torch.Tensor], actor_id: ActorID)
State critic input for a single sub-step of the env. It holds the tensor_dict and the actor_id of the actor from which the embedding logits were retrieved, if applicable; otherwise it holds the actor_id of the corresponding actor.
The tensor dict holds the observations of the corresponding env sub-step as well as the logits coming from the actor, if a shared embedding is used.
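To make the two fields concrete, here is a minimal, self-contained sketch of the data the class carries. It is not Maze's implementation: the classes below are hypothetical stand-ins, plain lists of floats replace torch.Tensor so the snippet runs without PyTorch, and the ActorID stand-in assumes Maze's ActorID is a (step_key, agent_id) pair.

```python
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Union


class ActorID(NamedTuple):
    """Hypothetical stand-in for Maze's ActorID: the sub-step key plus the agent index."""
    step_key: Union[str, int]
    agent_id: int


@dataclass
class StateCriticStepInput:
    """Hypothetical mirror of the documented class; lists stand in for torch.Tensor."""
    tensor_dict: Dict[str, List[float]]
    actor_id: ActorID


# Critic input for sub-step 0 of agent 0 without a shared embedding:
# the tensor dict holds only the observation of that sub-step.
step_input = StateCriticStepInput(
    tensor_dict={"observation": [0.1, 0.2, 0.3]},
    actor_id=ActorID(step_key=0, agent_id=0),
)
print(step_input.actor_id.step_key, sorted(step_input.tensor_dict))
```

In a real setup, each key of tensor_dict would map to a torch.Tensor, and one such object would be built per env sub-step.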
- classmethod build(policy_step_output: PolicySubStepOutput, observation: Dict[str, torch.Tensor]) → StateCriticStepInput
Build the critic input for an individual step by combining the policy step output and the given observation.
- Parameters:
policy_step_output – The output of the corresponding policy, used to check for shared embedding outputs.
observation – The observation as the default input to the critic.
- Returns:
The critic input for this specific step.
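The combining logic of build can be sketched as follows. This is an illustration under stated assumptions, not Maze's source: PolicySubStepOutput is a hypothetical stand-in whose embedding_logits field (set only when a shared embedding is used) is an assumed name, merging the logits into a copy of the observation dict is an assumed strategy, and plain lists replace torch.Tensor so the snippet runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional, Union

Tensor = List[float]  # stand-in for torch.Tensor in this sketch


class ActorID(NamedTuple):
    """Hypothetical stand-in for Maze's ActorID."""
    step_key: Union[str, int]
    agent_id: int


@dataclass
class PolicySubStepOutput:
    """Hypothetical stand-in with only the fields this sketch needs."""
    actor_id: ActorID
    # Assumed field: populated only when the actor and critic share an embedding.
    embedding_logits: Optional[Dict[str, Tensor]] = None


@dataclass
class StateCriticStepInput:
    tensor_dict: Dict[str, Tensor]
    actor_id: ActorID

    @classmethod
    def build(cls, policy_step_output: PolicySubStepOutput,
              observation: Dict[str, Tensor]) -> "StateCriticStepInput":
        # The observation is the default critic input; copy it so the caller's dict is untouched.
        tensor_dict = dict(observation)
        # With a shared embedding, add the actor's embedding logits to the critic input.
        if policy_step_output.embedding_logits is not None:
            tensor_dict.update(policy_step_output.embedding_logits)
        # Keep the actor ID of the policy step the logits (if any) came from.
        return cls(tensor_dict=tensor_dict, actor_id=policy_step_output.actor_id)


obs = {"observation": [0.1, 0.2]}
plain = StateCriticStepInput.build(PolicySubStepOutput(ActorID(0, 0)), obs)
shared = StateCriticStepInput.build(
    PolicySubStepOutput(ActorID(0, 0), embedding_logits={"latent": [0.5, -0.5]}), obs)
print(sorted(plain.tensor_dict))   # ['observation']
print(sorted(shared.tensor_dict))  # ['latent', 'observation']
```

Copying the observation before merging keeps the caller's observation dict unchanged, so the same observation can safely be reused for other sub-steps.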
- tensor_dict: Dict[str, torch.Tensor]
The tensor dict to use as an input for the critic.