GNNBlock

class maze.perception.blocks.feed_forward.graph_nn.GNNBlock(*args: Any, **kwargs: Any)

A customizable graph neural network (GNN) block.

Parameters
  • in_keys – A single key identifying the input tensor.

  • out_keys – A single key identifying the output tensor.

  • in_shapes – List of input shapes.

  • edges – List of graph edges required for message passing (aggregation).

  • aggregate – The aggregation function to use (max, mean, sum).

  • hidden_units – List containing the number of hidden units for hidden layers.

  • non_lin – The non-linearity to apply after each layer.

  • with_layer_norm – If True, layer normalization is applied.

  • node2node_aggr – If True, node-to-node message passing is applied.

  • edge2node_aggr – If True, edge-to-node message passing is applied.

  • node2edge_aggr – If True, node-to-edge message passing is applied.

  • edge2edge_aggr – If True, edge-to-edge message passing is applied.

  • with_node_embedding – If True, a node embedding is computed.

  • with_edge_embedding – If True, an edge embedding is computed.
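To make the edges, aggregate and *_aggr options more concrete, the following pure-Python sketch implements two of the four message-passing directions (node-to-node and edge-to-node) with the three listed reducers. It is not the Maze implementation. The helper names (AGGREGATORS, node2node_aggregate, edge2node_aggregate), the representation of edges as (u, v) index pairs, treating edges as undirected, and returning zeros for nodes that receive no messages are all assumptions made for illustration.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical reducers keyed by the names listed for `aggregate` (not a Maze API).
AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {
    "sum": sum,
    "mean": lambda xs: sum(xs) / len(xs),
    "max": max,
}


def _reduce_columns(rows: List[Sequence[float]], num_feats: int, aggregate: str) -> List[float]:
    """Apply the chosen reducer feature-wise; return zeros if there are no messages."""
    if not rows:
        return [0.0] * num_feats
    reduce = AGGREGATORS[aggregate]
    return [reduce([row[f] for row in rows]) for f in range(num_feats)]


def node2node_aggregate(
    node_features: Sequence[Sequence[float]],
    edges: Sequence[Tuple[int, int]],
    aggregate: str = "mean",
) -> List[List[float]]:
    """Each node receives the aggregated features of its neighbouring nodes."""
    neighbours: List[List[int]] = [[] for _ in node_features]
    for u, v in edges:  # edges treated as undirected (assumption)
        neighbours[u].append(v)
        neighbours[v].append(u)
    num_feats = len(node_features[0])
    return [
        _reduce_columns([node_features[j] for j in nbrs], num_feats, aggregate)
        for nbrs in neighbours
    ]


def edge2node_aggregate(
    edge_features: Sequence[Sequence[float]],
    edges: Sequence[Tuple[int, int]],
    num_nodes: int,
    aggregate: str = "mean",
) -> List[List[float]]:
    """Each node receives the aggregated features of its incident edges."""
    incident: List[List[int]] = [[] for _ in range(num_nodes)]
    for k, (u, v) in enumerate(edges):
        incident[u].append(k)
        incident[v].append(k)
    num_feats = len(edge_features[0])
    return [
        _reduce_columns([edge_features[k] for k in ks], num_feats, aggregate)
        for ks in incident
    ]


if __name__ == "__main__":
    nodes = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
    edges = [(0, 1), (1, 2)]  # a path graph 0 - 1 - 2
    print(node2node_aggregate(nodes, edges, "sum"))  # [[3.0, 2.0], [6.0, 4.0], [3.0, 2.0]]
    print(edge2node_aggregate([[10.0], [20.0]], edges, 3, "mean"))  # [[10.0], [15.0], [20.0]]
```

Node-to-edge and edge-to-edge passing would follow the same pattern with edges as the receivers. A tensor-based implementation would typically replace these Python loops with vectorized scatter/gather operations.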

normalized_forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

(overrides ShapeNormalizationBlock)

Implementation of the ShapeNormalizationBlock interface.
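The base-class mechanics are not described here. Assuming ShapeNormalizationBlock folds any extra leading (for example batch or time) dimensions into one axis before calling normalized_forward, and restores them on the outputs afterwards, the shape bookkeeping can be sketched as follows. normalize_shape and denormalize_shape are hypothetical helpers, not Maze functions, and this sketch rejects inputs with fewer dimensions than expected rather than padding them.

```python
from math import prod
from typing import Tuple


def normalize_shape(shape: Tuple[int, ...], expected_ndim: int) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Collapse extra leading dims so the shape has `expected_ndim` dims.

    Returns the normalized shape and the original leading dims needed to undo it.
    """
    if len(shape) < expected_ndim:
        raise ValueError(f"expected at least {expected_ndim} dims, got {len(shape)}")
    n_lead = len(shape) - expected_ndim + 1  # number of dims folded into the first axis
    lead = tuple(shape[:n_lead])
    return (prod(lead),) + tuple(shape[n_lead:]), lead


def denormalize_shape(out_shape: Tuple[int, ...], lead: Tuple[int, ...]) -> Tuple[int, ...]:
    """Re-expand the first axis of an output shape into the saved leading dims."""
    if out_shape[0] != prod(lead):
        raise ValueError("first axis does not match the saved leading dims")
    return tuple(lead) + tuple(out_shape[1:])
```

For example, a node-feature input of shape (4, 8, 5, 16) with an expected rank of 3 normalizes to (32, 5, 16) with saved leading dims (4, 8); an output of shape (32, 5, 64) then denormalizes to (4, 8, 5, 64).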