MultiHeadAttentionBlock

class maze.perception.blocks.feed_forward.multi_head_attention.MultiHeadAttentionBlock(*args: Any, **kwargs: Any)

Implementation of a torch MultiHeadAttention block.

This block wraps torch.nn.MultiheadAttention and can be used for 1D data as well as sequential data.
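
For reference, a minimal sketch of the underlying torch.nn.MultiheadAttention call that this block wraps, using batch_first=True (available in PyTorch >= 1.9) so the tensor layout matches the \((N, L, E)\) / \((N, S, E)\) shapes documented below. How the block arranges dimensions internally is an implementation detail and not shown here.

    import torch

    embed_dim, num_heads = 16, 4
    mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    query = torch.randn(8, 10, embed_dim)   # (N, L, E)
    key = torch.randn(8, 20, embed_dim)     # (N, S, E)
    value = torch.randn(8, 20, embed_dim)   # (N, S, E)

    attn_output, attn_output_weights = mha(query, key, value)
    print(attn_output.shape)          # torch.Size([8, 10, 16]) -> (N, L, E)
    print(attn_output_weights.shape)  # torch.Size([8, 10, 20]) -> (N, L, S)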

Parameters
  • in_keys – Keys identifying the input tensors. The first key is the query, the second is the key and the third is the value. Optionally, an attention mask can be passed as an additional input to the block.
    - query: \((N, L, E)\), where L is the target sequence length, N is the batch size and E is the embedding dimension.
    - key: \((N, S, E)\), where S is the source sequence length, N is the batch size and E is the embedding dimension.
    - value: \((N, S, E)\), where S is the source sequence length, N is the batch size and E is the embedding dimension.
    - (Optional) attn_mask: 2D mask \((L, S)\), where L is the target sequence length and S is the source sequence length, or 3D mask \((N, L, S)\), where N is the batch size, L is the target sequence length and S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a bool, int or float tensor is provided, positions with False/~1 are not allowed to attend, while positions with True/1 are left unchanged.
    - (Optional) key_padding_mask: \((N, S)\), where N is the batch size and S is the source sequence length. If a bool, int or float tensor is provided, positions with the value False/~1 will be ignored, while positions with the value True/1 are left unchanged.

  • out_keys – Keys identifying the output tensors. The first key is the self-attention output; the second, optional key is the attention map.
    - attn_output: \((N, L, E)\), where L is the target sequence length, N is the batch size and E is the embedding dimension.
    - attn_output_weights: \((N, L, S)\), where N is the batch size, L is the target sequence length and S is the source sequence length.

  • in_shapes – List of input shapes.

  • num_heads – Parallel attention heads.

  • dropout – A dropout layer on attn_output_weights.

  • bias – Add bias as module parameter.

  • add_bias_kv – Add bias to the key and value sequences at dim=0.

  • add_zero_attn – Add a new batch of zeros to the key and value sequences at dim=1.

  • kdim – Total number of features in key. Default: None.

  • vdim – Total number of features in value. Default: None.

  • use_key_padding_mask – Specify whether a key padding mask is being used.

Note: If kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
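
A construction sketch for illustration only: the keyword arguments mirror the parameter list above, the concrete shapes are made up, and the convention that in_shapes omit the batch dimension is an assumption about how perception blocks are typically configured.

    from maze.perception.blocks.feed_forward.multi_head_attention import MultiHeadAttentionBlock

    # Illustrative values; shapes are (L, E) for the query and (S, E) for key/value,
    # assuming in_shapes are specified without the batch dimension.
    self_attention = MultiHeadAttentionBlock(
        in_keys=["query", "key", "value"],
        out_keys=["attn_output", "attn_output_weights"],
        in_shapes=[(10, 16), (20, 16), (20, 16)],
        num_heads=4,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        use_key_padding_mask=False,
    )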

normalized_forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

(overrides ShapeNormalizationBlock)

Implementation of the PerceptionBlock interface.
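
A minimal usage sketch, assuming the block constructed in the example above: the input dictionary keys match in_keys, the tensors already carry the batch dimension in the documented \((N, L, E)\) / \((N, S, E)\) layout, and the output dictionary is keyed by out_keys.

    import torch

    block_input = {
        "query": torch.randn(8, 10, 16),  # (N, L, E)
        "key": torch.randn(8, 20, 16),    # (N, S, E)
        "value": torch.randn(8, 20, 16),  # (N, S, E)
    }
    block_output = self_attention.normalized_forward(block_input)
    print(block_output["attn_output"].shape)          # (N, L, E)
    print(block_output["attn_output_weights"].shape)  # (N, L, S)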