PolicySubStepOutput

class maze.core.agent.torch_policy_output.PolicySubStepOutput(action_logits: Dict[str, torch.Tensor], prob_dist: DictProbabilityDistribution, embedding_logits: Dict[str, torch.Tensor] | None, actor_id: ActorID)

Dataclass that holds the output of the policy's compute full output method for a single sub-step.

action_logits: Dict[str, torch.Tensor]

A dictionary mapping each action head to its action logits, from which the probability distribution is parameterized.

actor_id: ActorID

The ID of the actor this output belongs to.

embedding_logits: Dict[str, torch.Tensor] | None

The embedding output, if applicable (None otherwise). It is used as the input to the critic network.

property entropy: torch.Tensor

The entropy of the probability distribution.

prob_dist: DictProbabilityDistribution

The DictProbabilityDistribution instance corresponding to action_logits.
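
The relationship between the fields above can be illustrated with a minimal, torch-free sketch. It is not the Maze implementation: SketchSubStepOutput, softmax and categorical_entropy are hypothetical names, plain Python lists stand in for torch.Tensor, a tuple stands in for ActorID, and prob_dist is left out because the distribution is derived directly from the logits. The sketch also assumes that every action head is categorical and that the entropy of the dictionary distribution is the sum of the per-head entropies, which this page does not state.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Convert raw logits to probabilities (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(logits: List[float]) -> float:
    """Entropy (in nats) of the categorical distribution defined by the logits."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


@dataclass
class SketchSubStepOutput:
    """Hypothetical stand-in that mirrors the documented fields."""

    # Maps each action head name to its logits (lists instead of tensors).
    action_logits: Dict[str, List[float]]
    # Optional embedding output; None when not applicable.
    embedding_logits: Optional[Dict[str, List[float]]]
    # Stand-in for ActorID, assumed here to be a (step key, agent id) pair.
    actor_id: Tuple[str, int]

    @property
    def entropy(self) -> float:
        # Assumption: per-head entropies are summed over all action heads.
        return sum(categorical_entropy(l) for l in self.action_logits.values())


out = SketchSubStepOutput(
    action_logits={"move": [0.0, 0.0], "jump": [0.0, 0.0, 0.0, 0.0]},
    embedding_logits=None,
    actor_id=("step_0", 0),
)
print(out.entropy)
```

With uniform logits over two and four actions, the summed entropy under these assumptions is ln 2 + ln 4 = ln 8, roughly 2.079 nats.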