TrainingRunner

class maze.train.trainers.common.training_runner.TrainingRunner(state_dict_dump_file: str, dump_interval: Optional[int], spaces_config_dump_file: str, normalization_samples: int)

Base class for training runner implementations.

property cfg

Returns the Hydra config.

dump_interval: Optional[int]

If provided, the state dict will be dumped every `dump_interval` epochs.

property model_composer

Returns the model composer.

normalization_samples: int

Number of samples (i.e. steps) to collect for estimating observation normalization statistics at the beginning of training.
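As a toy illustration (not the library's implementation), estimating normalization statistics amounts to sampling `normalization_samples` observations and computing their mean and standard deviation; `sample_step` below is a hypothetical callable standing in for one environment step:

```python
import statistics
from typing import Callable, Tuple


def estimate_normalization_stats(
    sample_step: Callable[[], float], normalization_samples: int
) -> Tuple[float, float]:
    """Draw `normalization_samples` observations and return their mean and
    population standard deviation -- a sketch of how normalization
    statistics could be estimated before training starts."""
    observations = [sample_step() for _ in range(normalization_samples)]
    return statistics.fmean(observations), statistics.pstdev(observations)
```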

run(n_epochs: Optional[int] = None, **train_kwargs) → None

Runs training. While this method is designed to be overridden by individual subclasses, it provides some functionality that is useful in general:

  • Building the env factory for env + wrappers

  • Estimating normalization statistics from the env

  • If successfully estimated, wrapping the env factory so that envs are already built with the statistics

  • Building the model composer from model config and env spaces config

  • Serializing the env spaces configuration (so that the model composer can be re-loaded for future rollout)

  • Initializing logging setup

Parameters
  • n_epochs – Number of epochs to train.

  • train_kwargs – Additional arguments for trainer.train().
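A minimal sketch of how a concrete runner might override `run()`, using only the attributes documented on this page. The base class is stood in by a stub so the example is self-contained; `MyTrainingRunner` and its training loop are hypothetical, not part of the library:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TrainingRunner:
    """Stub mirroring the documented constructor; the real base class lives
    in maze.train.trainers.common.training_runner."""
    state_dict_dump_file: str
    dump_interval: Optional[int]
    spaces_config_dump_file: str
    normalization_samples: int

    def run(self, n_epochs: Optional[int] = None, **train_kwargs) -> None:
        raise NotImplementedError


class MyTrainingRunner(TrainingRunner):
    """Hypothetical subclass: trains for n_epochs and records the epochs at
    which the state dict would be dumped, per `dump_interval`."""

    def run(self, n_epochs: Optional[int] = None, **train_kwargs) -> None:
        self.dumped_at: List[int] = []
        for epoch in range(1, (n_epochs or 0) + 1):
            # ... one epoch of training would happen here ...
            if self.dump_interval is not None and epoch % self.dump_interval == 0:
                # The real runner writes to self.state_dict_dump_file in the
                # hydra-managed output directory.
                self.dumped_at.append(epoch)
```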

setup(cfg: omegaconf.DictConfig) → None

Sets up prerequisites to training, including wrapping the environment for observation normalization and instantiating the model composer.

Parameters
  • cfg – DictConfig defining the components to initialize.

spaces_config_dump_file: str

Where to save the env spaces configuration (output directory handled by hydra).

state_dict_dump_file: str

Where to save the best model (output directory handled by hydra).