TorchModelBlock
class maze.perception.blocks.general.torch_model_block.TorchModelBlock(*args: Any, **kwargs: Any)
A block that turns a common nn.Module into a shape-normalized Maze perception block.
Parameters
in_keys – Keys identifying the input tensors.
out_keys – Keys identifying the output tensors.
in_shapes – List of input shapes.
in_num_dims – Required number of dimensions for the corresponding input.
out_num_dims – Required number of dimensions for the corresponding output.
net – A PyTorch nn.Module whose forward method must accept a dict of input tensors as its parameter and must return a dict of output tensors.
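The `net` parameter expects a module with a dict-in, dict-out forward method. The sketch below shows one such module in plain PyTorch. The class name `DictMLP` and the keys "observation" and "latent" are hypothetical illustrations, not part of Maze; the keys would presumably correspond to the block's in_keys and out_keys.

```python
from typing import Dict

import torch
from torch import nn


class DictMLP(nn.Module):
    """Hypothetical net satisfying the dict-in/dict-out contract described for `net`."""

    def __init__(self, in_key: str, out_key: str, in_features: int, out_features: int):
        super().__init__()
        self.in_key = in_key
        self.out_key = out_key
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Look up the input tensor by key and return the result under the output key.
        return {self.out_key: torch.relu(self.linear(inputs[self.in_key]))}


net = DictMLP(in_key="observation", out_key="latent", in_features=8, out_features=16)
out = net({"observation": torch.randn(4, 8)})
print(out["latent"].shape)  # torch.Size([4, 16])
```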
normalized_forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]
(overrides ShapeNormalizationBlock) Implementation of the ShapeNormalizationBlock interface.
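To illustrate what "shape-normalized" plausibly means here, the following is a minimal sketch under an assumption not confirmed by this page: that shape normalization merges any extra leading (batch-like) dimensions of each input so it has exactly its required number of dimensions, runs the wrapped net, and then restores those leading dimensions on each output. The helper `shape_normalized_call` is hypothetical and is not Maze's ShapeNormalizationBlock implementation.

```python
from typing import Callable, Dict, List

import torch
from torch import nn

TensorDict = Dict[str, torch.Tensor]


def shape_normalized_call(net: Callable[[TensorDict], TensorDict], block_input: TensorDict,
                          in_keys: List[str], in_num_dims: List[int],
                          out_keys: List[str], out_num_dims: List[int]) -> TensorDict:
    """Hypothetical helper: flatten extra leading dims, call net, then restore them."""
    lead_shape = None
    normalized = {}
    for key, num_dims in zip(in_keys, in_num_dims):
        tensor = block_input[key]
        n_lead = tensor.dim() - num_dims + 1  # leading dims merged into one batch dim
        if n_lead < 0:
            raise ValueError(f"{key!r} has {tensor.dim()} dims, needs at least {num_dims - 1}")
        lead = tuple(tensor.shape[:n_lead])
        if lead_shape is None:
            lead_shape = lead
        elif lead != lead_shape:
            raise ValueError("inputs disagree on their leading dimensions")
        # n_lead == 0 (no batch dim) yields a batch of size 1.
        normalized[key] = tensor.reshape((-1,) + tuple(tensor.shape[n_lead:]))

    outputs = net(normalized)

    restored = {}
    for key, num_dims in zip(out_keys, out_num_dims):
        out = outputs[key]
        if out.dim() != num_dims:
            raise ValueError(f"{key!r} has {out.dim()} dims, expected {num_dims}")
        # Replace the merged batch dim with the original leading dims.
        restored[key] = out.reshape(lead_shape + tuple(out.shape[1:]))
    return restored


linear = nn.Linear(8, 16)
net = lambda d: {"latent": linear(d["observation"])}  # dict-in/dict-out net

x = torch.randn(5, 4, 8)  # (time, batch, features): one more dim than required
y = shape_normalized_call(net, {"observation": x}, ["observation"], [2], ["latent"], [2])
print(y["latent"].shape)  # torch.Size([5, 4, 16])
```

Under this assumption, the wrapped net only ever sees inputs with the configured number of dimensions, which is what lets an ordinary nn.Module be reused regardless of how many leading dimensions the caller supplies.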