TorchModel

class maze.core.agent.torch_model.TorchModel(device: str)

Abstract base class for torch models that encapsulate one or more networks.

Parameters

device – Device the networks should be located on (cpu or cuda)
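Because TorchModel is abstract, a concrete model has to implement each of the abstract methods listed below. The following is a minimal sketch, assuming the model simply holds a dictionary of named torch.nn.Module networks. The class name NamedNetworksModel and its networks argument are hypothetical. The sketch is standalone: it does not import or subclass the Maze base class, so it runs with plain PyTorch and only mirrors the documented interface.

```python
from typing import Dict, List

import torch
from torch import nn


class NamedNetworksModel:
    """Hypothetical stand-in mirroring the documented TorchModel interface."""

    def __init__(self, networks: Dict[str, nn.Module], device: str):
        self.networks = networks
        self._device = device
        self.to(device)  # place every network on the requested device

    @property
    def device(self) -> str:
        return self._device

    def parameters(self) -> List[torch.Tensor]:
        # Flatten the parameters of every encapsulated network into one list.
        return [p for net in self.networks.values() for p in net.parameters()]

    @property
    def num_params(self) -> int:
        # Total number of scalar elements across all parameter tensors.
        return sum(p.numel() for p in self.parameters())

    def state_dict(self) -> Dict:
        # One entry per network, keyed by the network's name (an assumption).
        return {name: net.state_dict() for name, net in self.networks.items()}

    def load_state_dict(self, state_dict: Dict) -> None:
        for name, net in self.networks.items():
            net.load_state_dict(state_dict[name])

    def to(self, device: str) -> None:
        self._device = device
        for net in self.networks.values():
            net.to(device)  # nn.Module.to moves the module in place

    def eval(self) -> None:
        for net in self.networks.values():
            net.eval()

    def train(self) -> None:
        for net in self.networks.values():
            net.train()


model = NamedNetworksModel({"policy": nn.Linear(4, 2), "critic": nn.Linear(4, 1)}, device="cpu")
print(model.num_params)  # (4*2 + 2) + (4*1 + 1) = 15
```

Keying the combined state dict by network name keeps each network's entries separate, so a single dictionary can be passed to torch.save and restored later.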

property device

Returns the device the networks are located on.

abstract eval() → None

Set all networks to eval mode.

abstract load_state_dict(state_dict: Dict) → None

Set the state dicts of all encapsulated networks.

Parameters

state_dict – The torch state dictionary.

property num_params

Returns overall number of network parameters.
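The overall parameter count is the total number of scalar elements across all parameter tensors of all networks. A minimal sketch of that computation with plain PyTorch, assuming the networks are nn.Module instances; the two layers below are illustrative. This counts every parameter, including any with requires_grad=False; whether the Maze implementation filters such parameters is not stated here.

```python
from torch import nn

# Two illustrative networks standing in for the ones a TorchModel encapsulates.
networks = [nn.Linear(8, 16), nn.Linear(16, 3)]

# Count every scalar in every parameter tensor (weights and biases).
num_params = sum(p.numel() for net in networks for p in net.parameters())
print(num_params)  # (8*16 + 16) + (16*3 + 3) = 144 + 51 = 195
```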

abstract parameters() → List[torch.Tensor]

Returns all parameters of all networks.
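A common use of a flat list of parameters is handing it to a single optimizer so that one update step touches every network. A minimal sketch with plain PyTorch, assuming the list is built by concatenating the parameters of two illustrative networks (policy and critic are hypothetical names):

```python
import torch
from torch import nn

torch.manual_seed(0)
policy = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# Flat list of tensors covering both networks, analogous to parameters().
params = list(policy.parameters()) + list(critic.parameters())
before = [p.detach().clone() for p in params]

# One optimizer updates the weights of every network at once.
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randn(5, 4)
loss = policy(x).pow(2).mean() + critic(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```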

abstract state_dict() → Dict

Returns a state dict composed of the state dicts of all encapsulated networks.
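Since a composed state dict is an ordinary nested dictionary, it can be saved and restored in one piece, which is how state_dict and load_state_dict typically pair up for checkpointing. A sketch of the round trip with plain PyTorch, assuming (hypothetically) that the composed dict is keyed by network name; the exact key layout Maze uses is not specified here.

```python
import io

import torch
from torch import nn


def make_networks():
    # Hypothetical helper: builds networks with a fixed architecture.
    return {"policy": nn.Linear(4, 2), "critic": nn.Linear(4, 1)}


source = make_networks()
# Compose one dict from the per-network state dicts.
composed = {name: net.state_dict() for name, net in source.items()}

# Serialize and deserialize it, e.g. as a checkpoint (an in-memory buffer here).
buffer = io.BytesIO()
torch.save(composed, buffer)
buffer.seek(0)
restored = torch.load(buffer)

# Load each sub-dict into a freshly initialised network of the same architecture.
target = make_networks()
for name, net in target.items():
    net.load_state_dict(restored[name])
```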

abstract to(device: str) → None

Move all networks to the specified device.

Parameters

device – The target device.
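A sketch of moving every network to one device, assuming the networks live in a dictionary. It falls back to the CPU when CUDA is unavailable, so it runs on any machine.

```python
import torch
from torch import nn

networks = {"policy": nn.Linear(4, 2), "critic": nn.Linear(4, 1)}

# Pick the target device; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

for net in networks.values():
    net.to(device)  # nn.Module.to moves parameters and buffers in place

# Every parameter should now report the target device type.
print({p.device.type for net in networks.values() for p in net.parameters()})
```

For modules, to works in place, which is why the sketch ignores the return value. For plain tensors, Tensor.to returns a new tensor, so the result must be reassigned.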

abstract train() → None

Set all networks to training mode.
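Switching between training and evaluation mode matters for layers such as dropout and batch normalization, whose behavior differs between the two. A sketch with plain PyTorch showing the mode flag flipping on every network, assuming the networks are held in a dictionary; set_mode is a hypothetical helper. Note that evaluation mode does not disable gradient tracking; that is what torch.no_grad is for.

```python
import torch
from torch import nn

networks = {
    "policy": nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2)),
    "critic": nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 1)),
}


def set_mode(training: bool) -> None:
    # nn.Module.train(mode) recurses into submodules; eval() equals train(False).
    for net in networks.values():
        net.train(training)


set_mode(False)  # evaluation: dropout is a no-op, batch norm uses running stats
x = torch.randn(3, 4)
with torch.no_grad():
    out1 = networks["policy"](x)
    out2 = networks["policy"](x)
print(torch.equal(out1, out2))  # True: no dropout randomness in eval mode

set_mode(True)
print(all(net.training for net in networks.values()))  # True
```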