ActionMaskingBlock

class maze.perception.blocks.general.action_masking.ActionMaskingBlock(*args: Any, **kwargs: Any)

An action masking block.

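The block takes two input keys: the first identifies the logits tensor and the second identifies the binary mask tensor. Masking is performed by adding the most negative representable float32 value (the "smallest possible" float32 number, roughly -3.4e38) to every logit whose corresponding mask value is False (0.0), so that the masked actions receive effectively zero probability after a softmax.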
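The masking arithmetic can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not Maze's implementation: plain lists stand in for tensors, and the helper names `mask_logits` and `softmax` are hypothetical. It adds the most negative finite float32 value to disallowed logits and shows that a subsequent softmax assigns those actions zero probability.

```python
import math

# Most negative finite float32 value (the value torch.finfo(torch.float32).min reports).
FLOAT32_MIN = -3.4028234663852886e38


def mask_logits(logits, mask):
    """Add FLOAT32_MIN to every logit whose mask entry is False (0.0); keep the others unchanged."""
    return [z + (0.0 if m else FLOAT32_MIN) for z, m in zip(logits, mask)]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]  # masked entries underflow to 0.0
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, 3.0]
mask = [1.0, 0.0, 1.0, 0.0]  # actions 1 and 3 are not allowed

masked = mask_logits(logits, mask)
probs = softmax(masked)
print([round(p, 4) for p in probs])  # → [0.8176, 0.0, 0.1824, 0.0]
```

Using a large finite negative number rather than `-inf` has a practical benefit in this sketch: if every action were masked, a softmax over all `-inf` logits would yield NaN, whereas the finite offset degrades to a uniform distribution. Whether this is the motivation behind the block's design is an assumption.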

Parameters
  • in_keys – Keys identifying the input tensors: the logits key first, then the mask key.

  • out_keys – Keys identifying the output tensors.

  • in_shapes – List of input shapes.

forward(block_input: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

(overrides PerceptionBlock)

Implementation of the PerceptionBlock interface.
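To illustrate the dict-in/dict-out contract of `forward`, here is a hypothetical, dependency-free stand-in. The class name `ToyActionMaskingBlock`, the key names, the use of Python lists instead of `torch.Tensor`, and the assumption of exactly one output key are illustrative choices, not part of the Maze API.

```python
from typing import Dict, List, Sequence

# Most negative finite float32 value, as in the block description.
FLOAT32_MIN = -3.4028234663852886e38


class ToyActionMaskingBlock:
    """Hypothetical stand-in mimicking the block's key-based interface (not Maze code)."""

    def __init__(self, in_keys: Sequence[str], out_keys: Sequence[str]) -> None:
        # Assumption: two inputs (logits key, then mask key) and a single output key.
        if len(in_keys) != 2:
            raise ValueError("expected two in_keys: [logits_key, mask_key]")
        if len(out_keys) != 1:
            raise ValueError("expected exactly one out_key")
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def forward(self, block_input: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Look up the inputs by key, in the documented order.
        logits = block_input[self.in_keys[0]]
        mask = block_input[self.in_keys[1]]
        # Push logits of disallowed actions (mask == 0.0) down to FLOAT32_MIN.
        masked = [z + (0.0 if m else FLOAT32_MIN) for z, m in zip(logits, mask)]
        return {self.out_keys[0]: masked}


block = ToyActionMaskingBlock(in_keys=["logits", "action_mask"], out_keys=["masked_logits"])
out = block.forward({"logits": [0.3, -1.2, 0.8], "action_mask": [1.0, 1.0, 0.0]})
print(out["masked_logits"])
```

Keying inputs and outputs by name lets such a block be wired between other blocks in a perception graph without relying on positional tensor order.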