get_log_rhos
class maze.train.trainers.impala.impala_vtrace.get_log_rhos(target_action_log_probs: List[Dict[str, torch.Tensor]], behaviour_action_log_probs: List[Dict[str, torch.Tensor]])
Computes the log_rhos (log importance ratios) used in the V-trace calculation from the log probabilities of the selected multi-discrete actions under the behaviour and target policies.
- Parameters
target_action_log_probs – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors of shape [T, B] corresponding to the sampling log probability of the chosen action w.r.t. the target policy.
behaviour_action_log_probs – A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors of shape [T, B] corresponding to the sampling log probability of the chosen action w.r.t. the behaviour policy.
- Returns
A list (w.r.t. the substeps of the env) of tensors, each of shape [T, B].
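The sketch below illustrates the shapes involved. It is not the Maze implementation: the function name `sketch_log_rhos` and the action keys `move` and `turn` are hypothetical. It assumes the usual V-trace definition of a log importance ratio (target log probability minus behaviour log probability) and that the components of a multi-discrete action are independent, so their log probabilities are summed within each substep.

```python
from typing import Dict, List

import torch


def sketch_log_rhos(
    target_action_log_probs: List[Dict[str, torch.Tensor]],
    behaviour_action_log_probs: List[Dict[str, torch.Tensor]],
) -> List[torch.Tensor]:
    """Hypothetical helper: one [T, B] log-ratio tensor per env substep."""
    log_rhos = []
    for target_step, behaviour_step in zip(
        target_action_log_probs, behaviour_action_log_probs
    ):
        # Sum the per-action log probs (assumption: independent action heads).
        # Both sums iterate over the same keys so the actions line up.
        target_sum = torch.stack([target_step[k] for k in target_step]).sum(dim=0)
        behaviour_sum = torch.stack([behaviour_step[k] for k in target_step]).sum(dim=0)
        # log(pi / mu) = log(pi) - log(mu)
        log_rhos.append(target_sum - behaviour_sum)
    return log_rhos


# Example: one substep, two action heads, T = 3 time steps, B = 2 environments.
T, B = 3, 2
target = [{"move": torch.full((T, B), -0.5), "turn": torch.full((T, B), -1.0)}]
behaviour = [{"move": torch.full((T, B), -0.7), "turn": torch.full((T, B), -1.2)}]
log_rhos = sketch_log_rhos(target, behaviour)
```

In this example the result is a list with a single [3, 2] tensor, matching the return shape described above.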