TorchModelBlock
- class maze.perception.blocks.general.torch_model_block.TorchModelBlock(*args: Any, **kwargs: Any)
A block transforming a common nn.Module to a shape-normalized Maze perception block.
- Parameters:
in_keys – Keys identifying the input tensors.
out_keys – Keys identifying the output tensors.
in_shapes – List of input shapes.
in_num_dims – Required number of dimensions for the corresponding input.
out_num_dims – Required number of dimensions for the corresponding output.
net – A PyTorch nn.Module whose forward method accepts a dict of input tensors as its parameter and returns a dict of output tensors.
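The requirement on net (dict of tensors in, dict of tensors out) can be illustrated with a small module. This is a minimal sketch in plain PyTorch, not code taken from Maze: the class name DictMLP, the keys "observation" and "latent", and the layer sizes are all hypothetical and chosen only for illustration.

```python
from typing import Dict

import torch
from torch import nn


class DictMLP(nn.Module):
    """Hypothetical net: reads one tensor from the input dict and
    writes one tensor to the output dict."""

    def __init__(self, in_key: str, out_key: str, in_features: int, out_features: int):
        super().__init__()
        self.in_key = in_key
        self.out_key = out_key
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Look up the input tensor by key, transform it, and return it under the output key.
        x = tensor_dict[self.in_key]
        return {self.out_key: torch.relu(self.linear(x))}


# Keys and sizes below are assumptions made for this example only.
net = DictMLP(in_key="observation", out_key="latent", in_features=8, out_features=4)
out = net({"observation": torch.randn(2, 8)})
print(out["latent"].shape)  # torch.Size([2, 4])
```

A module of this shape is the kind of object the net parameter describes; the keys it reads and writes would presumably need to match in_keys and out_keys.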
- normalized_forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]
(overrides ShapeNormalizationBlock) Implementation of the ShapeNormalizationBlock interface.
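The general idea behind shape normalization can be sketched as follows: surplus leading dimensions are collapsed so that the tensor has the required number of dimensions, the computation runs on the flattened tensor, and the leading dimensions are restored afterwards. This is an illustration in plain PyTorch under that assumption, not Maze's implementation; the helper names flatten_leading_dims and restore_leading_dims are hypothetical.

```python
from typing import Tuple

import torch


def flatten_leading_dims(x: torch.Tensor, num_dims: int) -> Tuple[torch.Tensor, torch.Size]:
    """Collapse surplus leading dimensions so that x has exactly num_dims dimensions."""
    extra = x.dim() - num_dims
    if extra < 0:
        raise ValueError(f"expected at least {num_dims} dims, got {x.dim()}")
    # Remember the leading shape so it can be restored later.
    lead_shape = x.shape[: extra + 1]
    return x.reshape(-1, *x.shape[extra + 1:]), lead_shape


def restore_leading_dims(y: torch.Tensor, lead_shape: torch.Size) -> torch.Tensor:
    """Expand the first dimension of y back into the remembered leading shape."""
    return y.reshape(*lead_shape, *y.shape[1:])


# A (5, 3, 8) tensor fed to a computation that expects 2 dimensions.
x = torch.randn(5, 3, 8)
flat, lead_shape = flatten_leading_dims(x, num_dims=2)   # shape (15, 8)
y = torch.nn.Linear(8, 4)(flat)                          # shape (15, 4)
restored = restore_leading_dims(y, lead_shape)           # shape (5, 3, 4)
```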