log_probs_from_logits_and_actions_and_spaces

class maze.train.trainers.impala.impala_vtrace.log_probs_from_logits_and_actions_and_spaces(policy_logits: List[Dict[str, torch.Tensor]], actions: List[Dict[str, torch.Tensor]], distribution_mapper: maze.distributions.distribution_mapper.DistributionMapper)

Computes action log-probabilities from policy logits, actions, and action spaces.

In the notation used throughout the documentation and comments, T refers to the time dimension ranging from 0 to T-1, B refers to the batch size, and NUM_ACTIONS refers to the number of actions.

Parameters
  • policy_logits – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors of un-normalized log-probabilities (shape list[dict[str,[T, B, NUM_ACTIONS]]])

  • actions – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of action tensors (shape list[dict[str,[T, B]]])

  • distribution_mapper – A distribution mapper providing a mapping of action heads to distributions.

Returns

A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors of shape [T, B], holding the sampling log-probability of the chosen action under the policy, and a list (w.r.t. the substeps of the env) of DictProbability distributions corresponding to the per-step action distributions.
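A minimal usage sketch, not taken from the library's documentation: it assumes a single-substep env with one discrete action head named "move", and it assumes that DistributionMapper can be constructed from the dict action space with an empty mapping config (so the default Categorical distribution is used for Discrete spaces); the exact constructor arguments may differ across Maze versions.

import torch
from gym import spaces

from maze.distributions.distribution_mapper import DistributionMapper
from maze.train.trainers.impala.impala_vtrace import (
    log_probs_from_logits_and_actions_and_spaces,
)

# Hypothetical setup: one env substep with a single discrete action head "move".
T, B, NUM_ACTIONS = 5, 4, 3
action_space = spaces.Dict({"move": spaces.Discrete(NUM_ACTIONS)})

# Assumption: an empty mapping config falls back to the default distribution
# for each action space (Categorical for Discrete).
distribution_mapper = DistributionMapper(action_space=action_space,
                                         distribution_mapper_config={})

# Un-normalized log-probabilities of shape [T, B, NUM_ACTIONS] and sampled
# actions of shape [T, B], wrapped as list-of-dicts (one entry per substep).
policy_logits = [{"move": torch.randn(T, B, NUM_ACTIONS)}]
actions = [{"move": torch.randint(0, NUM_ACTIONS, (T, B))}]

log_probs, step_distributions = log_probs_from_logits_and_actions_and_spaces(
    policy_logits=policy_logits,
    actions=actions,
    distribution_mapper=distribution_mapper,
)

# The returned log-probs keep the list-of-dicts structure with [T, B] tensors.
assert log_probs[0]["move"].shape == (T, B)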