BCLoss
-
class maze.train.trainers.imitation.bc_loss.BCLoss(action_spaces_dict: Dict[Union[int, str], gym.spaces.Dict], entropy_coef: float, loss_discrete: torch.nn.Module = torch.nn.CrossEntropyLoss, loss_box: torch.nn.Module = torch.nn.MSELoss, loss_multi_binary: torch.nn.Module = torch.nn.BCEWithLogitsLoss)
Loss function for behavioral cloning.
-
action_spaces_dict: Dict[Union[int, str], gym.spaces.Dict]
Action spaces being trained on, used to determine the appropriate loss function for each action.
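The choice of loss per action-space type can be illustrated with a minimal sketch. It assumes only the defaults listed in the constructor signature above; the `DEFAULT_LOSSES` mapping, its string keys, and the `select_loss` helper are hypothetical and are not part of Maze.

```python
import torch
from torch import nn

# Hypothetical mapping from action-space kind to the default loss class
# named in the BCLoss signature. This is an illustration, not Maze's code.
DEFAULT_LOSSES = {
    "discrete": nn.CrossEntropyLoss,
    "box": nn.MSELoss,
    "multi_binary": nn.BCEWithLogitsLoss,
}


def select_loss(space_kind: str) -> nn.Module:
    """Instantiate the loss module for the given action-space kind."""
    return DEFAULT_LOSSES[space_kind]()


# Discrete: logits of shape (batch, n_actions) and integer class targets.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])
discrete_loss = select_loss("discrete")(logits, targets)

# MultiBinary: logits and float 0/1 targets of the same shape.
mb_logits = torch.zeros(2, 4)
mb_targets = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
multi_binary_loss = select_loss("multi_binary")(mb_logits, mb_targets)
```

Each loss here is a scalar tensor. A Box action would be handled the same way, with `nn.MSELoss` comparing the predicted and the demonstrated continuous values.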
-
calculate_loss(policy: maze.core.agent.torch_policy.TorchPolicy, observations: List[Dict[str, numpy.ndarray]], actions: List[Dict[str, torch.Tensor]], action_logits: Optional[List[Dict[str, torch.Tensor]]], actor_ids: List[maze.core.env.structured_env.ActorID], events: maze.train.trainers.imitation.imitation_events.ImitationEvents) → torch.Tensor
Calculate and return the training loss for one step (in structured scenarios, one step consists of multiple sub-steps).
- Parameters
policy – Structured policy to evaluate.
observations – List of observations, corresponding to actor_ids.
actions – List of actions, corresponding to actor_ids.
action_logits – Action logits of the policy (optional).
actor_ids – List of actor IDs.
events – Events of the current episode.
- Returns
Total loss
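How per-sub-step losses could combine into one total can be sketched as follows. This is an assumption-laden illustration, not Maze's implementation: the `total_bc_loss` helper, the head names `move` and `force`, and the plain summation are all hypothetical, and the entropy term controlled by `entropy_coef` is left out.

```python
from typing import Callable, Dict, List

import torch
from torch import nn

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def total_bc_loss(
    action_logits: List[Dict[str, torch.Tensor]],
    actions: List[Dict[str, torch.Tensor]],
    loss_fns: Dict[str, LossFn],
) -> torch.Tensor:
    """Sum the per-head losses over all sub-steps (hypothetical helper)."""
    total = torch.zeros(())
    # One list entry per sub-step, one dict key per action head.
    for step_logits, step_actions in zip(action_logits, actions):
        for head, logits in step_logits.items():
            total = total + loss_fns[head](logits, step_actions[head])
    return total


# Two sub-steps: a discrete head first, then a continuous (box) head.
loss_fns = {"move": nn.CrossEntropyLoss(), "force": nn.MSELoss()}
action_logits = [
    {"move": torch.zeros(2, 3)},
    {"force": torch.tensor([[0.5], [1.5]])},
]
actions = [
    {"move": torch.tensor([0, 1])},
    {"force": torch.tensor([[1.0], [1.0]])},
]
loss = total_bc_loss(action_logits, actions, loss_fns)
```

The result is a single scalar tensor, matching the `torch.Tensor` return type documented above, so it can be passed directly to `backward()` once the logits come from a trainable policy.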
-