BCValidationEvaluator
- class maze.train.trainers.imitation.bc_validation_evaluator.BCValidationEvaluator(loss: BCLoss, model_selection: ModelSelectionBase | None, data_loader: torch.utils.data.DataLoader, log_substep_events: bool, logging_prefix: str | None = 'eval')
Evaluates a given policy on validation data.
Expects the first two items of each tuple returned by the dataset to be the observation_dict and the action_dict, in that order.
- Parameters:
data_loader – The data used for evaluation.
loss – Loss function to be used.
model_selection – Model selection interface that will be notified of the recorded rewards.
log_substep_events – Whether to log the individual substep events.
- evaluate(policy: TorchPolicy) → None
(overrides Evaluator) Evaluate the given policy (results are stored in stat logs) and dump the model if the reward improved.
- Parameters:
policy – Policy to evaluate.
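The evaluation flow described above can be illustrated with a minimal, library-free sketch. It is not the Maze implementation: `BestModelTracker`, `evaluate_on_validation_data`, `toy_policy`, and `zero_one_loss` are hypothetical names invented for this example, and treating the negative mean loss as the reward passed to model selection is an assumption. The only detail taken from the description is that the first two items of each dataset tuple are read as the observation_dict and the action_dict.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Sequence


class BestModelTracker:
    """Hypothetical stand-in for a model selection interface."""

    def __init__(self) -> None:
        self.best_reward: Optional[float] = None

    def update(self, reward: float) -> bool:
        """Record a reward; return True if it beats the best one seen so far."""
        if self.best_reward is None or reward > self.best_reward:
            self.best_reward = reward
            return True  # a real implementation would dump the model here
        return False


def evaluate_on_validation_data(
    policy: Callable[[Dict[str, Any]], Dict[str, Any]],
    loss_fn: Callable[..., float],
    batches: Iterable[Sequence[Any]],
    model_selection: Optional[BestModelTracker] = None,
) -> float:
    """Average the loss over all validation batches and notify model selection."""
    total_loss = 0.0
    num_batches = 0
    for batch in batches:
        # Only the first two items are used; any further items are ignored.
        observation_dict, action_dict = batch[0], batch[1]
        total_loss += loss_fn(policy, observation_dict, action_dict)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("validation data is empty")
    mean_loss = total_loss / num_batches
    if model_selection is not None:
        # Assumption: lower loss means higher reward.
        model_selection.update(-mean_loss)
    return mean_loss


def toy_policy(observation_dict: Dict[str, float]) -> Dict[str, int]:
    """Hypothetical policy: predict action 1 for positive x, otherwise 0."""
    return {"move": 1 if observation_dict["x"] > 0 else 0}


def zero_one_loss(policy, observation_dict, action_dict) -> float:
    """Fraction of action keys where the prediction differs from the target."""
    predicted = policy(observation_dict)
    mismatches = sum(predicted[key] != action_dict[key] for key in action_dict)
    return mismatches / len(action_dict)


# Each tuple carries a third item to show that extra entries are tolerated.
validation_batches = [
    ({"x": 1.0}, {"move": 1}, "extra"),
    ({"x": -2.0}, {"move": 0}, "extra"),
    ({"x": 3.0}, {"move": 0}, "extra"),
    ({"x": -1.0}, {"move": 0}, "extra"),
]

tracker = BestModelTracker()
mean_loss = evaluate_on_validation_data(
    toy_policy, zero_one_loss, validation_batches, tracker
)
print(mean_loss, tracker.best_reward)
```

In this sketch the policy mispredicts one of the four samples, so the mean loss is 0.25 and the tracker records a reward of -0.25. A later call with a lower loss would make `update` return True again, which is the point at which a model dump would be triggered.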