Trainer
-
class maze.train.trainers.common.trainer.Trainer(algorithm_config: AlgorithmConfigType, model: Optional[maze.core.agent.torch_model.TorchModel] = None)
Interface for trainers.
- Parameters
algorithm_config – Algorithm configuration, including all parameters expected in .train().
model – Model to train.
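To illustrate what implementing this interface involves, here is a minimal, self-contained sketch. It does not import Maze: `Trainer` below is a hypothetical stand-in ABC that declares the same three abstract methods documented on this page, and `NoOpTrainer` is a purely illustrative subclass. The real base class may differ in its constructor and attributes.

```python
from abc import ABC, abstractmethod
from typing import Any, BinaryIO, Dict, Optional, Union


class Trainer(ABC):
    """Hypothetical stand-in with the same abstract methods as the documented interface."""

    def __init__(self, algorithm_config: Any, model: Optional[Any] = None) -> None:
        self.algorithm_config = algorithm_config
        self.model = model

    @abstractmethod
    def train(self, n_epochs: Optional[int] = None, **kwargs) -> None:
        """Train for n_epochs epochs."""

    @abstractmethod
    def state_dict(self) -> Dict:
        """Return the state of all encapsulated components."""

    @abstractmethod
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """Restore state from a path or a binary stream."""


class NoOpTrainer(Trainer):
    """Hypothetical concrete trainer: implements every abstract method, trains nothing."""

    def train(self, n_epochs: Optional[int] = None, **kwargs) -> None:
        pass

    def state_dict(self) -> Dict:
        return {}

    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        pass


try:
    Trainer(algorithm_config={})  # abstract methods are unimplemented, so this fails
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

trainer = NoOpTrainer(algorithm_config={"n_epochs": 1})
trainer.train()
print(trainer.state_dict())  # {}
```

Because the base class is abstract, a subclass only becomes instantiable once it overrides train, state_dict and load_state.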
-
abstract load_state(file_path: Union[str, BinaryIO]) → None
Load the trainer state from a file. This is required for resuming training or for fine-tuning a model with different parameters.
- Parameters
file_path – Path or binary file object from which to load the state.
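Since file_path is typed Union[str, BinaryIO], an implementation has to handle both a path and an already-open binary stream. The sketch below shows one way to do that with a hypothetical helper `read_state`; it assumes pickle as the serialization format purely for illustration, since the format Maze actually uses is not specified here.

```python
import io
import os
import pickle
import tempfile
from typing import Any, BinaryIO, Dict, Union


def read_state(file_path: Union[str, BinaryIO]) -> Dict[str, Any]:
    """Hypothetical helper: load a pickled state dict from a path or an open binary stream."""
    if isinstance(file_path, str):
        # A path: open the file ourselves and close it when done.
        with open(file_path, "rb") as f:
            return pickle.load(f)
    # Otherwise assume an already-open binary file object owned by the caller.
    return pickle.load(file_path)


state = {"epoch": 5, "lr": 3e-4}

# 1) Load from an in-memory binary stream.
from_stream = read_state(io.BytesIO(pickle.dumps(state)))

# 2) Load from a file path.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "state.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    from_path = read_state(path)

print(from_stream == from_path == state)  # True
```

Leaving a passed-in stream open matches the usual convention that whoever opened a file object is responsible for closing it.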
-
abstract state_dict() → Dict
Returns the state dict composed of all encapsulated trainer components.
- Returns
The trainer’s state dict.
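A state dict "composed of all encapsulated trainer components" can be built by nesting each component's own state under a stable key. The sketch below assumes components that follow the state_dict()/load_state_dict() naming used by PyTorch modules and optimizers; `Counter`, `SketchTrainer` and the keys "model" and "optimizer" are hypothetical.

```python
from typing import Any, Dict


class Counter:
    """Hypothetical component exposing its own state_dict()/load_state_dict()."""

    def __init__(self) -> None:
        self.value = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.value = state["value"]


class SketchTrainer:
    """Hypothetical trainer encapsulating two components."""

    def __init__(self) -> None:
        self.model = Counter()
        self.optimizer = Counter()

    def state_dict(self) -> Dict[str, Any]:
        # One entry per encapsulated component, keyed by a stable name.
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Hand each sub-dict back to the component that produced it.
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])


trainer = SketchTrainer()
trainer.model.value = 7
trainer.optimizer.value = 2
snapshot = trainer.state_dict()
print(snapshot)  # {'model': {'value': 7}, 'optimizer': {'value': 2}}

restored = SketchTrainer()
restored.load_state_dict(snapshot)
print(restored.model.value, restored.optimizer.value)  # 7 2
```

Keeping one key per component lets the trainer restore each part independently, which is what resuming or fine-tuning relies on.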
-
abstract train(n_epochs: Optional[int] = None, **kwargs) → None
Train for n_epochs epochs. The kwargs carry additional configuration that some trainers need at training time, e.g. ESTrainer or BCTrainer. Parameters that are already set at initialization time can be read from the trainer’s algorithm config, so every parameter available at initialization time is optional at train time.
- Parameters
n_epochs – Number of epochs to train.
kwargs – Additional, trainer-specific parameters.
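The rule that parameters set at initialization are optional at train time amounts to a fallback: an explicit argument wins, otherwise the value comes from the algorithm config. The sketch below assumes, purely for illustration, that train-time kwargs override fields of a frozen dataclass config. `SketchConfig`, its fields and `SketchTrainer` are hypothetical, and Maze trainers may resolve their kwargs differently.

```python
from dataclasses import dataclass, fields, replace
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical algorithm config; field names are illustrative only."""
    n_epochs: int = 10
    learning_rate: float = 1e-3


class SketchTrainer:
    def __init__(self, algorithm_config: SketchConfig) -> None:
        self.algorithm_config = algorithm_config
        self.history: List[Tuple[int, float]] = []  # (epochs run, learning rate) per call

    def train(self, n_epochs: Optional[int] = None, **kwargs) -> None:
        # Reject kwargs that do not correspond to a config field.
        known = {f.name for f in fields(self.algorithm_config)}
        unknown = set(kwargs) - known
        if unknown:
            raise TypeError(f"unexpected train() arguments: {sorted(unknown)}")
        # Explicit arguments win; everything else falls back to the stored config.
        cfg = replace(self.algorithm_config, **kwargs)
        if n_epochs is not None:
            cfg = replace(cfg, n_epochs=n_epochs)
        epochs_run = 0
        for _ in range(cfg.n_epochs):
            # A real trainer would perform one epoch of updates here.
            epochs_run += 1
        self.history.append((epochs_run, cfg.learning_rate))


trainer = SketchTrainer(SketchConfig())
trainer.train()                                # n_epochs=10 from the config
trainer.train(n_epochs=2, learning_rate=5e-4)  # per-call overrides
print(trainer.history[0][0], trainer.history[1][0])  # 10 2
```

Because the config is frozen and overrides are applied to a copy, a call with overrides leaves the stored configuration untouched for later calls.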
-
abstract