BCTrainer
-
class maze.train.trainers.imitation.bc_trainer.BCTrainer(algorithm_config: maze.train.trainers.imitation.bc_algorithm_config.BCAlgorithmConfig, data_loader: torch.utils.data.DataLoader, policy: maze.core.agent.torch_policy.TorchPolicy, optimizer: torch.optim.Optimizer, loss: maze.train.trainers.imitation.bc_loss.BCLoss)
Trainer for behavioral cloning.
Runs training on the provided trajectory data and evaluates the policy by rolling it out with the evaluator passed to train().
In structured (multi-step) envs, all policies are trained simultaneously on the substep actions and observations present in the trajectory data.
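To make the multi-substep behavior concrete, here is a minimal, framework-free sketch of behavioral cloning in a structured env. It is not Maze's implementation: the two-substep expert data, `LinearSoftmaxPolicy`, and `train_bc` are hypothetical, and a linear softmax policy trained with plain SGD on cross-entropy stands in for `TorchPolicy`, `BCLoss`, and the torch optimizer. Each record is routed to the policy of its own substep, so every substep policy is updated within the same pass over the data.

```python
import math
import random
from collections import defaultdict

# Hypothetical expert records: (substep_key, observation, expert_action_index).
Record = tuple[int, list[float], int]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class LinearSoftmaxPolicy:
    """Hypothetical stand-in for one substep policy: logits = W @ obs + b."""

    def __init__(self, n_obs: int, n_actions: int) -> None:
        self.w = [[0.0] * n_obs for _ in range(n_actions)]
        self.b = [0.0] * n_actions

    def probs(self, obs: list[float]) -> list[float]:
        return softmax([sum(wi * x for wi, x in zip(row, obs)) + bi
                        for row, bi in zip(self.w, self.b)])

    def act(self, obs: list[float]) -> int:
        """Greedy action: index of the most probable action."""
        p = self.probs(obs)
        return max(range(len(p)), key=p.__getitem__)

    def sgd_step(self, obs: list[float], action: int, lr: float) -> float:
        """One SGD step on -log p(action | obs); returns the loss before the step."""
        p = self.probs(obs)
        loss = -math.log(p[action])
        for a in range(len(p)):
            # d(loss)/d(logit_a) = p_a - 1[a == action]
            g = p[a] - (1.0 if a == action else 0.0)
            self.b[a] -= lr * g
            for j, x in enumerate(obs):
                self.w[a][j] -= lr * g * x
        return loss


def train_bc(records: list[Record], policies: dict[int, LinearSoftmaxPolicy],
             n_epochs: int = 200, lr: float = 0.5, seed: int = 0) -> float:
    """Train all substep policies in one loop; returns the mean per-record loss of the last epoch."""
    rng = random.Random(seed)
    data = list(records)
    mean_loss = float("nan")
    for _ in range(n_epochs):
        rng.shuffle(data)
        per_substep_loss: dict[int, float] = defaultdict(float)
        for substep, obs, action in data:
            # Route each record to the policy responsible for its substep.
            per_substep_loss[substep] += policies[substep].sgd_step(obs, action, lr)
        mean_loss = sum(per_substep_loss.values()) / len(data)
    return mean_loss


# Toy expert data for a hypothetical two-substep env.
records: list[Record] = [
    (0, [1.0, 0.0], 1), (0, [0.0, 1.0], 0), (0, [0.9, 0.2], 1), (0, [0.1, 0.8], 0),
    (1, [1.0, 0.0, 0.0], 0), (1, [0.0, 1.0, 0.0], 1), (1, [0.0, 0.0, 1.0], 2),
]
policies = {0: LinearSoftmaxPolicy(2, 2), 1: LinearSoftmaxPolicy(3, 3)}
final_loss = train_bc(records, policies)
print(f"mean loss in last epoch: {final_loss:.4f}")
print([policies[s].act(o) for s, o, _ in records])  # greedy action per record
```

The design choice worth noting is the routing by substep key: a single shuffled stream of records drives the updates of all substep policies, rather than training each policy in a separate run.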
-
data_loader: torch.utils.data.DataLoader
Data loader that supplies the trajectory data.
-
imitation_events: maze.train.trainers.imitation.imitation_events.ImitationEvents = <abc.ImitationEventsProxy object>
Imitation-specific training events.
-
load_state_dict(state_dict: Dict) → None
Set the model and optimizer state.
Parameters: state_dict – The state dict.
-
loss: maze.train.trainers.imitation.bc_loss.BCLoss
Class providing the training loss function.
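For reference, a common behavioral-cloning objective for a structured env with discrete actions is the mean negative log-likelihood of the expert actions under the policy of the substep they belong to. This is offered as background and is not necessarily the exact form that `BCLoss` computes. Here $\mathcal{D}$ is the trajectory dataset, $k$ a substep key, $o$ and $a$ the recorded observation and expert action, and $\pi_\theta^{(k)}$ the policy for substep $k$:

$$
\mathcal{L}(\theta) = -\frac{1}{|\mathcal{D}|} \sum_{(k,\, o,\, a) \in \mathcal{D}} \log \pi_\theta^{(k)}(a \mid o)
$$

Each record contributes only through the policy of its own substep, so minimizing this single loss trains all substep policies together.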
-
optimizer: torch.optim.Optimizer
Optimizer to use.
-
policy: maze.core.agent.torch_policy.TorchPolicy
Structured policy to train.
-
train(evaluator: maze.train.trainers.common.evaluators.evaluator.Evaluator, n_epochs: Optional[int] = None, eval_every_k_iterations: Optional[int] = None) → None
(overrides Trainer) Run training.
Parameters:
evaluator – Evaluator to use for evaluation rollouts.
n_epochs – Number of epochs to train for.
eval_every_k_iterations – Run an evaluation every k iterations, in addition to the evaluation that runs automatically at the end of each epoch. If None, evaluations run at the end of each epoch only.
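The interaction between `n_epochs` and `eval_every_k_iterations` can be modeled as a simple schedule. The sketch below is a hypothetical, pure-Python model of that schedule, not Maze code: `evaluation_points` and `iterations_per_epoch` (for example, the number of batches the data loader yields per epoch) are illustrative names. Two details are assumptions not stated in this documentation: that the iteration counter restarts at each epoch, and that an epoch-end evaluation falling on a multiple of k runs only once.

```python
from typing import Optional


def evaluation_points(n_epochs: int, iterations_per_epoch: int,
                      eval_every_k_iterations: Optional[int] = None) -> list[tuple[int, int]]:
    """Return (epoch, iteration) pairs after which an evaluation would run.

    Hypothetical model: the iteration counter is 1-based and restarts each
    epoch; a periodic evaluation that coincides with the epoch end counts once.
    """
    points: list[tuple[int, int]] = []
    for epoch in range(n_epochs):
        for iteration in range(1, iterations_per_epoch + 1):
            periodic = (eval_every_k_iterations is not None
                        and iteration % eval_every_k_iterations == 0)
            epoch_end = iteration == iterations_per_epoch
            if periodic or epoch_end:
                points.append((epoch, iteration))
    return points


print(evaluation_points(2, 5))     # [(0, 5), (1, 5)]  epoch-end evaluations only
print(evaluation_points(2, 5, 3))  # [(0, 3), (0, 5), (1, 3), (1, 5)]
```

With `eval_every_k_iterations=None`, only the epoch-end evaluations remain, which matches the documented default.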
-
train_stats: maze.core.log_stats.log_stats.LogStatsAggregator = <maze.core.log_stats.log_stats.LogStatsAggregator object>
Training statistics.
-