get_log_rhos

maze.train.trainers.impala.impala_vtrace.get_log_rhos(target_action_log_probs: List[Dict[str, torch.Tensor]], behaviour_action_log_probs: List[Dict[str, torch.Tensor]])

Computes the log importance-sampling ratios (log_rhos) needed for V-trace from the log-probabilities that the target and the behaviour policy assign to the chosen multi-discrete actions.
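Written out, and assuming the multi-discrete action heads are sampled independently so that the joint log-probability is the sum of the per-head log-probabilities, the log-ratio for one substep would take the form below. The notation is illustrative and not taken from the Maze source: s indexes the env substep (list position), k the action key (dict key), π the target policy, μ the behaviour policy, and x_{t,b} the observation at time step t of batch element b.

```latex
\log \rho^{(s)}_{t,b}
  = \sum_{k \in \mathcal{A}_s}
    \left[ \log \pi\!\left(a^{(s,k)}_{t,b} \mid x_{t,b}\right)
         - \log \mu\!\left(a^{(s,k)}_{t,b} \mid x_{t,b}\right) \right],
  \qquad 0 \le t < T,\; 0 \le b < B
```

A value of \log \rho = 0 means both policies assign the sampled action the same probability.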

Parameters
  • target_action_log_probs – A list (one entry per substep of the env) of dicts (one entry per action key) of tensors of shape [T, B] (T time steps, B batch size), each holding the log-probability of the chosen action under the target policy.

  • behaviour_action_log_probs – A list (one entry per substep of the env) of dicts (one entry per action key) of tensors of shape [T, B], each holding the log-probability of the chosen action under the behaviour policy.

Returns

A list (one entry per substep of the env) of log_rhos tensors, each of shape [T, B].
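The following is a minimal sketch of this computation in plain Python, assuming the per-action log-ratios of a substep are summed into a single [T, B] value (the usual reduction when multi-discrete heads are sampled independently). The name get_log_rhos_sketch and the ArrayLike type variable are hypothetical; this is not the Maze implementation. Because the sketch only uses + and -, the same code should work with torch.Tensor values of shape [T, B]; plain floats stand in for them in the usage lines.

```python
from typing import Dict, List, TypeVar

# Hypothetical placeholder for the value type: anything that supports + and -,
# e.g. a torch.Tensor of shape [T, B]. Plain floats are used in the example.
ArrayLike = TypeVar("ArrayLike")


def get_log_rhos_sketch(
    target_action_log_probs: List[Dict[str, ArrayLike]],
    behaviour_action_log_probs: List[Dict[str, ArrayLike]],
) -> List[ArrayLike]:
    """Return one log-ratio per env substep, summed over action keys (assumed reduction)."""
    if len(target_action_log_probs) != len(behaviour_action_log_probs):
        raise ValueError("target and behaviour inputs must cover the same substeps")

    log_rhos = []
    for target_step, behaviour_step in zip(target_action_log_probs, behaviour_action_log_probs):
        # Both policies must report log-probs for the same action keys.
        if target_step.keys() != behaviour_step.keys():
            raise ValueError("target and behaviour action keys differ")
        # log(pi / mu) of the joint action = sum of per-key log-ratios,
        # assuming the action heads are sampled independently.
        # sum() starts from 0, which also broadcasts against a [T, B] tensor.
        log_rhos.append(sum(target_step[key] - behaviour_step[key] for key in target_step))
    return log_rhos


# Usage: one substep with two action keys, floats standing in for [T, B] tensors.
target = [{"move": -0.5, "fire": -1.0}]
behaviour = [{"move": -0.75, "fire": -0.5}]
print(get_log_rhos_sketch(target, behaviour))  # [-0.25]
```

Downstream, V-trace typically exponentiates these values and truncates them at a threshold ρ̄, i.e. ρ = min(ρ̄, exp(log ρ)). Keeping the ratio in log-space until that point avoids dividing two potentially tiny probabilities.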