from_importance_weights

class maze.train.trainers.impala.impala_vtrace.from_importance_weights(log_rhos: torch.Tensor, discounts: torch.Tensor, rewards: torch.Tensor, values: torch.Tensor, bootstrap_value: torch.Tensor, clip_rho_threshold: Optional[float], clip_pg_rho_threshold: Optional[float])

V-trace from log importance weights.

Calculates V-trace actor-critic targets as described in "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" by Espeholt, Soyer, Munos et al.

In the notation used throughout documentation and comments, T refers to the time dimension ranging from 0 to T-1. B refers to the batch size. This code also supports the case where all tensors have the same number of additional dimensions, e.g., rewards is [T, B, C], values is [T, B, C], bootstrap_value is [B, C].
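
As an illustration, the following is a minimal sketch of the V-trace recursion this function evaluates, written in plain PyTorch. The helper name vtrace_sketch is hypothetical, and it assumes the paper's trace-cutting coefficient c-bar is fixed at 1; the library's actual implementation may differ in detail:

    import torch

    def vtrace_sketch(log_rhos, discounts, rewards, values, bootstrap_value,
                      clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0):
        # Importance weights rho_t = pi(a_t|x_t) / mu(a_t|x_t), recovered from log-space.
        rhos = torch.exp(log_rhos)
        clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
        # Trace-cutting coefficients c_t = min(c_bar, rho_t); c_bar = 1 assumed here.
        cs = torch.clamp(rhos, max=1.0)

        # V(x_{t+1}) for each step, bootstrapped with the value estimate at time T.
        values_t_plus_1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
        # Temporal-difference terms delta_t = rho_t * (r_t + gamma_t V(x_{t+1}) - V(x_t)).
        deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

        # Backward recursion: acc_s = delta_s + gamma_s c_s acc_{s+1}; then vs_s = V(x_s) + acc_s.
        acc = torch.zeros_like(bootstrap_value)
        corrections = []
        for t in reversed(range(rewards.shape[0])):
            acc = deltas[t] + discounts[t] * cs[t] * acc
            corrections.append(acc)
        corrections.reverse()
        vs = torch.stack(corrections) + values

        # Policy gradient advantages use vs shifted by one step and their own rho clipping.
        vs_t_plus_1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
        clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
        pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
        return vs, pg_advantages

Because all operations are elementwise along the trailing dimensions, the same sketch also covers the [T, B, C] case mentioned above.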

Parameters
  • log_rhos – A float32 tensor of shape [T, B] representing the log importance sampling weights, i.e. log(target_policy(a) / behaviour_policy(a)). V-trace performs operations on rhos in log-space for numerical stability.

  • discounts – A float32 tensor of shape [T, B] with discounts encountered when following the behaviour policy.

  • rewards – A float32 tensor of shape [T, B] containing rewards generated by following the behaviour policy.

  • values – A float32 tensor of shape [T, B] with the value function estimates with respect to the target policy.

  • bootstrap_value – A float32 tensor of shape [B] with the value function estimate at time T.

  • clip_rho_threshold – A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). This corresponds to rho-bar in the paper. If None, no clipping is applied.

  • clip_pg_rho_threshold – A scalar float32 tensor with the clipping threshold on rho_s as used in the policy gradient term rho_s * delta log pi(a|x_s) * (r_s + gamma * v_{s+1} - V(x_s)). If None, no clipping is applied.

Returns

A VTraceReturns namedtuple (vs, pg_advantages) where:

  • vs – A float32 tensor of shape [T, B]. Can be used as a target to train a baseline via the loss (V(x_t) - vs_t)^2.

  • pg_advantages – A float32 tensor of shape [T, B]. Can be used as the advantage in the calculation of policy gradients.
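
A usage example with randomly generated inputs; the shapes and threshold values below are illustrative, not prescriptive:

    import torch
    from maze.train.trainers.impala.impala_vtrace import from_importance_weights

    T, B = 5, 4  # unroll length and batch size (illustrative values)
    log_rhos = torch.zeros(T, B)          # on-policy case: pi == mu, so log rho = 0
    discounts = torch.full((T, B), 0.99)  # gamma at every step (use 0 at episode ends)
    rewards = torch.randn(T, B)
    values = torch.randn(T, B)
    bootstrap_value = torch.randn(B)

    returns = from_importance_weights(
        log_rhos=log_rhos, discounts=discounts, rewards=rewards, values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0)

    vs, pg_advantages = returns.vs, returns.pg_advantages  # both of shape [T, B]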