Trainer

class maze.train.trainers.common.trainer.Trainer(algorithm_config: AlgorithmConfigType, model: Optional[maze.core.agent.torch_model.TorchModel] = None)

Interface for trainers.

Parameters

algorithm_config – Algorithm configuration, including all parameters expected in train().
model – Model to train.
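To illustrate what the interface asks of a subclass, here is a minimal sketch built on Python's standard abc module. TrainerSketch is a hypothetical stand-in that mirrors the documented signatures. It is not Maze's source, and the constructor body is an assumption.

```python
from abc import ABC, abstractmethod
from typing import BinaryIO, Dict, Optional, Union


class TrainerSketch(ABC):
    """Hypothetical stand-in mirroring the documented Trainer interface (not Maze's code)."""

    def __init__(self, algorithm_config, model=None):
        # Keep the configuration so train() can fall back to values fixed here.
        self.algorithm_config = algorithm_config
        self.model = model

    @abstractmethod
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """Restore trainer state from a path or an open binary stream."""

    @abstractmethod
    def state_dict(self) -> Dict:
        """Return the combined state of all trainer components."""

    @abstractmethod
    def train(self, n_epochs: Optional[int] = None, **kwargs) -> None:
        """Run training; n_epochs may be omitted if the config provides it."""
```

Because all three methods are abstract, a concrete trainer must implement every one of them before it can be instantiated.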

abstract load_state(file_path: Union[str, BinaryIO]) → None

Load the trainer state from a file. This is required for resuming training or for fine-tuning a model with different parameters.

Parameters

file_path – Path, or open binary stream, from which to load the state.
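Since file_path may be either a string or a BinaryIO, an implementation has to handle both. The sketch below shows one way to dispatch on the argument type. read_state is a hypothetical helper, and pickle is used purely for illustration; the serialization format Maze actually uses is not specified here.

```python
import io
import pickle
from typing import BinaryIO, Dict, Union


def read_state(file_path: Union[str, BinaryIO]) -> Dict:
    """Hypothetical helper: read a pickled state dict from a path or a binary stream."""
    if isinstance(file_path, str):
        # A string is treated as a filesystem path and opened here.
        with open(file_path, "rb") as f:
            return pickle.load(f)
    # Otherwise assume an already-open binary file-like object.
    return pickle.load(file_path)


# Usage with an in-memory stream standing in for a checkpoint file.
buffer = io.BytesIO()
pickle.dump({"epoch": 3}, buffer)
buffer.seek(0)
state = read_state(buffer)
```

Accepting a stream as well as a path lets callers load state from sources other than the local filesystem, such as an in-memory buffer.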

abstract state_dict() → Dict

Returns the state dict composed of all encapsulated trainer components.

Returns

The trainer’s state dict.
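"Composed of all encapsulated trainer components" can be read as one entry per component. The following sketch shows that composition with hypothetical ToyComponent and ToyTrainer classes. The keys "model" and "optimizer" are illustrative assumptions, not the keys Maze uses.

```python
from typing import Dict


class ToyComponent:
    """Hypothetical trainer component that owns some state."""

    def __init__(self, value: int):
        self.value = value

    def state_dict(self) -> Dict:
        return {"value": self.value}


class ToyTrainer:
    """Hypothetical trainer that composes the state of its components."""

    def __init__(self):
        # Component names are illustrative only.
        self.model = ToyComponent(1)
        self.optimizer = ToyComponent(2)

    def state_dict(self) -> Dict:
        # One entry per encapsulated component, keyed by component name.
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }


state = ToyTrainer().state_dict()
```

Nesting each component's own state dict under a name keeps the combined dict easy to split apart again when loading.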

abstract train(n_epochs: Optional[int] = None, **kwargs) → None

Train for n epochs. kwargs carry additional configuration that some trainers, e.g. ESTrainer or BCTrainer, need at training time. Some necessary parameters are set at initialization time and can be read from the trainer’s algorithm config; hence, all parameters available at initialization time are optional at train time.

Parameters

n_epochs – Number of epochs to train.
kwargs – Additional, trainer-specific parameters.
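The rule that initialization-time parameters are optional at train time amounts to a fallback: use the argument if given, otherwise read the value from the config. A minimal sketch follows. ToyAlgorithmConfig, its n_epochs field, and ToyTrainer are hypothetical; kwargs are accepted but ignored here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyAlgorithmConfig:
    """Hypothetical config; the field name n_epochs is an assumption."""
    n_epochs: int


class ToyTrainer:
    def __init__(self, algorithm_config: ToyAlgorithmConfig):
        self.algorithm_config = algorithm_config
        self.completed_epochs: List[int] = []

    def train(self, n_epochs: Optional[int] = None, **kwargs) -> None:
        # Fall back to the value fixed at initialization when none is passed.
        if n_epochs is None:
            n_epochs = self.algorithm_config.n_epochs
        for epoch in range(n_epochs):
            # A real trainer would update the model here; this only records progress.
            self.completed_epochs.append(epoch)


trainer = ToyTrainer(ToyAlgorithmConfig(n_epochs=3))
trainer.train()            # no argument: uses the config value (3 epochs)
trainer.train(n_epochs=2)  # explicit override: 2 more epochs
```

Testing for None rather than for falsiness keeps an explicit n_epochs=0 from being silently replaced by the config value.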