Adding Step-Conditional Action Masking

In this part of the tutorial we will learn how to substantially increase the sample efficiency of our agents by adding step-conditional action masking (masks that depend on the decision taken in the previous sub-step) to the structured environment.

The complete code for this part of the tutorial can be found here.

# relevant files
- cutting_2d
    - main.py
    - env
        - struct_env_masked.py

In particular, we will add two different masks:

  • Inventory mask: restricts the selection to inventory slots that actually hold a piece large enough to fulfill the customer order.

  • Rotation mask (called cutting_mask in the code): restricts the cutting rotation to valid options, i.e., rotations for which the ordered piece fits into the selected inventory piece. This mask can only be computed once the cutting piece has been selected in the first sub-step, hence the name step-conditional masking.
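Why do masks improve sample efficiency? A common way for a policy to consume such a mask (assumed here for illustration; this is not necessarily how Maze implements it) is to set the logits of masked actions to minus infinity before the softmax. Invalid actions then receive exactly zero probability, so the agent never samples them and never has to learn to avoid them. The helper masked_softmax below is hypothetical, not a Maze API, and assumes at least one action is valid.

```python
import numpy as np


def masked_softmax(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn policy logits into action probabilities,
    giving zero probability to every action whose mask entry is 0.
    Assumes at least one mask entry is 1."""
    # Replace logits of invalid actions by -inf so that exp() maps them to exactly 0.
    masked_logits = np.where(mask > 0, logits, -np.inf)
    # Subtract the largest valid logit for numerical stability.
    masked_logits = masked_logits - np.max(masked_logits)
    exp = np.exp(masked_logits)
    return exp / exp.sum()


# Six inventory slots; only slots 0 and 1 hold a piece that fits the order.
logits = np.array([0.3, -1.2, 2.0, 0.5, 0.0, 1.1])
inventory_mask = np.array([1, 1, 0, 0, 0, 0], dtype=np.float32)
probs = masked_softmax(logits, inventory_mask)
print(probs.round(3))
```

Note that slot 2 has the highest raw logit but still ends up with zero probability: the mask removes it from the choice entirely.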

The figure below provides a sketch of the two masks.

[Figure: sketch of the inventory mask and the rotation mask (cutting2d_masking.png)]

In this example, only the first two inventory pieces are large enough to fit the customer order. The four rightmost inventory slots do not hold a piece at all and are therefore masked as well. If the order were rotated by 90° for cutting, it would no longer fit into the selected inventory piece, so this option can simply be masked.
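The two masks from the figure can be reproduced on a small toy state. The sketch below uses hypothetical piece sizes (not taken from the figure) and assumes, as the environment code does, that empty inventory slots are stored as all-zero rows and that rotation index 0 means "unrotated" while index 1 means "rotated by 90°". Slot 2 holds a piece that is too small, and the last four slots are empty.

```python
import numpy as np


def inventory_mask(inventory: np.ndarray, ordered_piece: np.ndarray) -> np.ndarray:
    """1.0 for every slot holding a piece the order fits into in at least one rotation."""
    mask = np.zeros(len(inventory), dtype=np.float32)
    sorted_order = np.sort(ordered_piece)
    for i, piece in enumerate(inventory):
        if np.all(piece > 0):  # slot actually holds a piece
            # comparing sorted dimensions checks "fits in some rotation"
            mask[i] = float(np.all(sorted_order <= np.sort(piece)))
    return mask


def rotation_mask(selected_piece: np.ndarray, ordered_piece: np.ndarray) -> np.ndarray:
    """Entry 0: order fits unrotated; entry 1: order fits rotated by 90 degrees."""
    return np.array([np.all(ordered_piece <= selected_piece),
                     np.all(ordered_piece[::-1] <= selected_piece)], dtype=np.float32)


# Hypothetical toy state: 7 slots, slot 2 too small, slots 3-6 empty.
inventory = np.array([[100, 40], [60, 80], [30, 20], [0, 0], [0, 0], [0, 0], [0, 0]])
ordered_piece = np.array([50, 30])

print(inventory_mask(inventory, ordered_piece))    # only slots 0 and 1 are selectable
print(rotation_mask(inventory[0], ordered_piece))  # slot 0 only works unrotated
print(rotation_mask(inventory[1], ordered_piece))  # slot 1 works in both rotations
```

The rotation mask depends on which slot was chosen, which is exactly why it can only be provided in the second sub-step.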

Masked Structured Environment

One way to incorporate the two masks into our structured environment is to simply inherit from the initial version and extend it with the following changes:

  • Add the two masks to the observation spaces (i.e., inventory_mask and cutting_mask).

  • Compute the actual mask for each of the two sub-steps in the respective functions (i.e., _obs_selection_step and _obs_cutting_step).

env/struct_env_masked.py
from copy import deepcopy
from typing import Dict, List, Tuple

import gym
import numpy as np
from tutorial_maze_env.part06_struct_env.env.maze_env import maze_env_factory
from tutorial_maze_env.part06_struct_env.env.struct_env import StructuredCutting2DEnvironment
from maze.core.env.maze_env import MazeEnv


class MaskedStructuredCutting2DEnvironment(StructuredCutting2DEnvironment):
    """Masked structured environment version of the cutting 2D environment.
    The environment alternates between the two sub-steps:

    - Select cutting piece
    - Select cutting configuration (cutting order and cutting orientation)

    :param maze_env: The "flat" cutting 2D environment to wrap.
    """

    def __init__(self, maze_env: MazeEnv):
        super().__init__(maze_env)

        # add masks to observation spaces
        max_inventory = self.observation_conversion.max_pieces_in_inventory
        self._observation_spaces_dict[0].spaces["inventory_mask"] = \
            gym.spaces.Box(low=np.float32(0), high=np.float32(1), shape=(max_inventory,), dtype=np.float32)

        self._observation_spaces_dict[1].spaces["cutting_mask"] = \
            gym.spaces.Box(low=np.float32(0), high=np.float32(1), shape=(2,), dtype=np.float32)

    @staticmethod
    def _obs_selection_step(flat_obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Formats initial observation / observation available for the first sub-step."""
        observation = deepcopy(flat_obs)

        # prepare inventory mask
        sorted_order = np.sort(observation["ordered_piece"].flatten())
        sorted_inventory = np.sort(observation["inventory"], axis=1)

        observation["inventory_mask"] = np.all(observation["inventory"] > 0, axis=1).astype(np.float32)
        for i in np.nonzero(observation["inventory_mask"])[0]:
            # exclude pieces which do not fit
            observation["inventory_mask"][i] = np.all(sorted_order <= sorted_inventory[i])

        return observation

    @staticmethod
    def _obs_cutting_step(flat_obs: Dict[str, np.ndarray], selected_piece_idx: int) -> Dict[str, np.ndarray]:
        """Formats observation available for the second sub-step."""

        selected_piece = flat_obs["inventory"][selected_piece_idx].squeeze()
        ordered_piece = flat_obs["ordered_piece"]

        # prepare cutting action mask
        cutting_mask = np.zeros((2,), dtype=np.float32)

        # rotation 0: ordered piece fits into the selected piece without rotation
        if np.all(ordered_piece <= selected_piece):
            cutting_mask[0] = 1.0

        # rotation 1: ordered piece fits after rotating it by 90 degrees
        if np.all(ordered_piece[::-1] <= selected_piece):
            cutting_mask[1] = 1.0

        return {"selected_piece": selected_piece,
                "ordered_piece": ordered_piece,
                "cutting_mask": cutting_mask}


def struct_env_factory(max_pieces_in_inventory: int, raw_piece_size: Tuple[int, int],
                       static_demand: List[Tuple[int, int]]) -> StructuredCutting2DEnvironment:
    """Convenience factory function that compiles a trainable structured environment.
    (for argument details see: Cutting2DEnvironment)
    """

    # init maze environment including observation and action interfaces
    env = maze_env_factory(max_pieces_in_inventory=max_pieces_in_inventory,
                           raw_piece_size=raw_piece_size,
                           static_demand=static_demand)

    # convert flat to structured environment
    return MaskedStructuredCutting2DEnvironment(env)
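The inventory mask in _obs_selection_step is computed with a Python loop over all occupied slots, which runs on every selection step. If this ever becomes a bottleneck, a loop-free variant based on NumPy broadcasting is possible. The sketch below is our own alternative, not part of the tutorial code; it re-states the loop as a standalone function and checks on random inventories (with a 200-slot inventory and all-zero rows for empty slots, both assumptions mirroring the environment) that both variants produce the same mask.

```python
import numpy as np


def inventory_mask_loop(inventory: np.ndarray, ordered_piece: np.ndarray) -> np.ndarray:
    """Same logic as _obs_selection_step, written as a standalone function."""
    sorted_order = np.sort(ordered_piece.flatten())
    sorted_inventory = np.sort(inventory, axis=1)
    mask = np.all(inventory > 0, axis=1).astype(np.float32)
    for i in np.nonzero(mask)[0]:
        mask[i] = np.all(sorted_order <= sorted_inventory[i])
    return mask


def inventory_mask_vectorized(inventory: np.ndarray, ordered_piece: np.ndarray) -> np.ndarray:
    """Loop-free variant: compare the sorted order against all sorted slots at once."""
    occupied = np.all(inventory > 0, axis=1)
    # (2,) broadcast against (n_slots, 2), then reduce over the two dimensions
    fits = np.all(np.sort(ordered_piece.flatten()) <= np.sort(inventory, axis=1), axis=1)
    return (occupied & fits).astype(np.float32)


# Compare both variants on random inventories where roughly half the slots are empty.
rng = np.random.default_rng(0)
for _ in range(100):
    inventory = rng.integers(0, 100, size=(200, 2)) * (rng.random((200, 1)) < 0.5)
    ordered_piece = rng.integers(1, 100, size=2)
    assert np.array_equal(inventory_mask_loop(inventory, ordered_piece),
                          inventory_mask_vectorized(inventory, ordered_piece))
```

For a 200-slot inventory either version is cheap; the vectorized form mainly avoids per-slot Python overhead.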

Test Script

When re-running the main script from the previous section with the masked version of the structured environment, we now get the following output:

action_space 1:      Dict(piece_idx:Discrete(200))
observation_space 1: Dict(inventory:Box(200, 2), inventory_size:Box(1,), ordered_piece:Box(2,), inventory_mask:Box(200,))
observation 1:       dict_keys(['inventory', 'inventory_size', 'ordered_piece', 'inventory_mask'])
action_space 2:      Dict(cut_order:Discrete(2), cut_rotation:Discrete(2))
observation_space 2: Dict(ordered_piece:Box(2,), selected_piece:Box(2,), cutting_mask:Box(2,))
observation 2:       dict_keys(['selected_piece', 'ordered_piece', 'cutting_mask'])

As expected, both masks are contained in the respective observations and observation spaces. In the next section we will use these masks to improve the sample efficiency of our trainers.