TorchModelBlock

class maze.perception.blocks.general.torch_model_block.TorchModelBlock(*args: Any, **kwargs: Any)

A block that wraps a standard PyTorch nn.Module as a shape-normalized Maze perception block.

Parameters
  • in_keys – Keys identifying the input tensors.

  • out_keys – Keys identifying the output tensors.

  • in_shapes – List of input shapes.

  • in_num_dims – Required number of dimensions for the corresponding input.

  • out_num_dims – Required number of dimensions for the corresponding output.

  • net – A PyTorch nn.Module whose forward method must accept a dict of input tensors (Dict[str, torch.Tensor]) and must return a dict of output tensors.
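As a minimal sketch of the contract `net` has to satisfy, the module below reads one tensor from its input dict and returns its result under a single output key. The class name `DictMLP`, the keys `observation` and `logits`, and the layer sizes are all hypothetical and chosen only for illustration.

```python
from typing import Dict

import torch
from torch import nn


class DictMLP(nn.Module):
    """Hypothetical net satisfying the dict-in/dict-out contract of `net`."""

    def __init__(self, in_key: str, out_key: str, in_features: int, out_features: int) -> None:
        super().__init__()
        self.in_key = in_key
        self.out_key = out_key
        self.layers = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Linear(32, out_features),
        )

    def forward(self, tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Look up the input tensor by key and return the result under the output key.
        return {self.out_key: self.layers(tensor_dict[self.in_key])}


net = DictMLP(in_key="observation", out_key="logits", in_features=8, out_features=4)
out = net({"observation": torch.randn(16, 8)})
print(out["logits"].shape)  # torch.Size([16, 4])
```

A module like this would then be passed as `net`, with `in_keys` and `out_keys` presumably matching the dict keys it reads and writes, and `in_num_dims`/`out_num_dims` set to fit its (batch, features) tensors.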

normalized_forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

(overrides ShapeNormalizationBlock)

Implementation of the ShapeNormalizationBlock interface.
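The sketch below illustrates what shape normalization around a forward call could look like. It assumes that surplus leading dimensions (for example a time axis in front of the batch axis) are flattened into one batch dimension before the call and restored on the outputs afterwards. The helper `shape_normalized_call` is hypothetical and is not Maze's actual ShapeNormalizationBlock code.

```python
from typing import Callable, Dict

import torch


def shape_normalized_call(
    forward: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]],
    block_input: Dict[str, torch.Tensor],
    in_num_dims: Dict[str, int],
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: flatten surplus leading dims, call forward, restore them."""
    leading_shape = None
    normalized = {}
    for key, tensor in block_input.items():
        n_extra = tensor.dim() - in_num_dims[key]
        # Leading dims that get merged: the surplus dims plus the batch dim.
        # Assumes every input shares the same leading dims.
        leading_shape = tensor.shape[: n_extra + 1]
        normalized[key] = tensor.reshape(-1, *tensor.shape[n_extra + 1:])
    outputs = forward(normalized)
    # Split the flattened batch dim of every output back into the original leading dims.
    return {key: out.reshape(*leading_shape, *out.shape[1:]) for key, out in outputs.items()}


def two_dim_only(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Stand-in for a net that only accepts (batch, features) input.
    x = tensor_dict["obs"]
    assert x.dim() == 2, "this net only accepts (batch, features) input"
    return {"logits": x.sum(dim=1, keepdim=True)}


obs = torch.ones(5, 16, 8)  # e.g. (time, batch, features)
result = shape_normalized_call(two_dim_only, {"obs": obs}, {"obs": 2})
print(result["logits"].shape)  # torch.Size([5, 16, 1])
```

The wrapped net only ever sees a 2-D tensor of shape (80, 8), while the caller gets outputs back with the original (time, batch) layout.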