from_logits

maze.train.trainers.impala.impala_vtrace.from_logits(behaviour_policy_logits: List[Dict[str, torch.Tensor]], target_policy_logits: List[Dict[str, torch.Tensor]], actions: List[Dict[str, torch.Tensor]], distribution_mapper: maze.distributions.distribution_mapper.DistributionMapper, discounts: torch.Tensor, rewards: torch.Tensor, values: List[torch.Tensor], bootstrap_value: List[torch.Tensor], clip_rho_threshold: Optional[float], clip_pg_rho_threshold: Optional[float], device: Optional[str])

V-trace for softmax policies.

Calculates V-trace actor-critic targets for softmax policies as described in

“IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures” by Espeholt, Soyer, Munos et al.

Target policy refers to the policy we are interested in improving and behaviour policy refers to the policy that generated the given rewards and actions.

In the notation used throughout the documentation and comments, T refers to the time dimension, ranging from 0 to T-1, B refers to the batch size, and ACTION_SPACE refers to the list of action counts, one NUM_ACTIONS entry per action head.
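
For reference, the targets computed here follow the V-trace definition in the cited paper. A compact restatement, where rho-bar corresponds to clip_rho_threshold (the truncation level c-bar for the c_i weights is not exposed through this function's signature):

\[
v_s = V(x_s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \Big( \prod_{i=s}^{t-1} c_i \Big) \delta_t V,
\qquad
\delta_t V = \rho_t \big( r_t + \gamma V(x_{t+1}) - V(x_t) \big),
\]
\[
\rho_t = \min\Big( \bar{\rho}, \frac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)} \Big),
\qquad
c_i = \min\Big( \bar{c}, \frac{\pi(a_i \mid x_i)}{\mu(a_i \mid x_i)} \Big).
\]

The policy-gradient advantages pair rho_s, clipped at clip_pg_rho_threshold, with the one-step advantage r_s + gamma v_{s+1} - V(x_s).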

Parameters
  • behaviour_policy_logits – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors holding the un-normalized log-probabilities of the behaviour policy (shape list[dict[str, [T, B, NUM_ACTIONS]]]; see the shape sketch after this parameter list).

  • target_policy_logits – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors holding the un-normalized log-probabilities of the target policy (shape list[dict[str, [T, B, NUM_ACTIONS]]]).

  • actions – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) with the actions sampled from the behaviour policy (shape list[dict[str, [T, B]]]).

  • distribution_mapper – A distribution mapper providing a mapping of action heads to distributions.

  • discounts – A float32 tensor of shape [T, B] with the discounts encountered when following the behaviour policy.

  • rewards – A float32 tensor of shape [T, B] with the rewards generated by following the behaviour policy.

  • values – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B] with the value function estimates w.r.t. the target policy.

  • bootstrap_value – A list (w.r.t. the substeps of the env) of float32 tensors of shape [B] with the value function estimate at time T.

  • clip_rho_threshold – An optional scalar float with the clipping threshold for the importance weights (rho) used when calculating the baseline targets (vs); rho-bar in the paper.

  • clip_pg_rho_threshold – An optional scalar float with the clipping threshold on rho_s in the policy-gradient term rho_s ∇ log pi(a|x) (r + gamma v_{s+1} - V(x_s)).

  • device – The device the results should be moved to before being returned.
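
The nested input structure is easiest to see with concrete tensors. A minimal sketch, assuming a single-substep env with two hypothetical action heads, "move" (4 discrete actions) and "turn" (3 discrete actions); all sizes are arbitrary example values:

    import torch

    T, B = 20, 8  # rollout length and batch size

    # One list entry per env substep, one dict entry per action head.
    behaviour_policy_logits = [
        {"move": torch.randn(T, B, 4), "turn": torch.randn(T, B, 3)}
    ]
    target_policy_logits = [
        {"move": torch.randn(T, B, 4), "turn": torch.randn(T, B, 3)}
    ]
    actions = [  # actions sampled from the behaviour policy
        {"move": torch.randint(0, 4, (T, B)), "turn": torch.randint(0, 3, (T, B))}
    ]
    discounts = 0.99 * torch.ones(T, B)  # typically set to zero at terminal steps
    rewards = torch.randn(T, B)
    values = [torch.randn(T, B)]         # per-substep value estimates
    bootstrap_value = [torch.randn(B)]   # value estimate at time T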

Returns

A VTraceFromLogitsReturns namedtuple with the following fields:

  • vs – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B]. Can be used as a target to train a baseline, (V(x_t) - vs_t)^2.

  • pg_advantages – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B]. Can be used as an estimate of the advantage in the calculation of policy gradients.

  • log_rhos – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B] containing the log importance sampling weights (log rhos).

  • behaviour_action_log_probs – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B] containing the behaviour policy action log-probabilities (log mu(a_t)).

  • target_action_log_probs – A list (w.r.t. the substeps of the env) of float32 tensors of shape [T, B] containing the target policy action log-probabilities (log pi(a_t)).

  • target_step_action_dists – A list (w.r.t. the substeps of the env) of the action probability distributions w.r.t. the target policy.
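
A minimal usage sketch, assuming the example inputs above and a DistributionMapper already constructed for the env's dict action space (its construction is omitted here); the loss terms shown are one common way to consume the returned fields, not necessarily the exact losses used by Maze's IMPALA trainer:

    from maze.train.trainers.impala.impala_vtrace import from_logits

    vtrace_returns = from_logits(
        behaviour_policy_logits=behaviour_policy_logits,
        target_policy_logits=target_policy_logits,
        actions=actions,
        distribution_mapper=distribution_mapper,  # assumed to be built elsewhere
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
        device="cpu",
    )

    # Baseline (critic) loss: regress the value estimates towards the vs targets.
    baseline_loss = sum(
        0.5 * ((vs.detach() - v) ** 2).mean()
        for vs, v in zip(vtrace_returns.vs, values)
    )

    # Policy-gradient loss: target log-probabilities weighted by the V-trace advantages.
    policy_loss = -sum(
        (log_pi * adv.detach()).mean()
        for log_pi, adv in zip(
            vtrace_returns.target_action_log_probs, vtrace_returns.pg_advantages
        )
    )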