Plain Python Training Example (low-level)

This tutorial demonstrates how to train an A2C agent with Maze in plain Python without utilizing RunContext. In the process, it introduces and explains some of Maze's most important components and concepts.

This is complementary to the article on high-level training in plain Python, which walks through the same setup (but with RunContext support).

Environment Setup

We will first prepare our environment for use with Maze. In order to use Maze’s parallelization capabilities, it is necessary to define a factory function that returns a MazeEnv of your environment. This is easily done for Gym environments:

def cartpole_env_factory():
    """ Env factory for the cartpole MazeEnv """
    # Registered gym environments can be instantiated first and then provided to GymMazeEnv:
    cartpole_env = gym.make("CartPole-v0")
    maze_env = GymMazeEnv(env=cartpole_env)

    # Another possibility is to supply the gym env string to GymMazeEnv directly:
    maze_env = GymMazeEnv(env="CartPole-v0")

    return maze_env

If your own environment is not a gym.Env, you must transform it into a MazeEnv yourself, as is shown here, and have your factory return that. A custom gym env, on the other hand, can be wrapped with GymMazeEnv as shown above.
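Why a factory function rather than a ready-made env object? A vector environment (or a worker process) has to be able to build its own, independent copy of the environment. The sketch below illustrates the difference with a hypothetical ToyEnv class standing in for a MazeEnv; it does not use Maze at all.

```python
from typing import Callable, List


class ToyEnv:
    """Hypothetical stand-in for a MazeEnv that keeps mutable episode state."""

    def __init__(self) -> None:
        self.state = 0

    def reset(self) -> int:
        self.state = 0
        return self.state

    def step(self, action: int) -> int:
        self.state += action
        return self.state


def toy_env_factory() -> ToyEnv:
    """Plays the role of cartpole_env_factory: every call builds a fresh env."""
    return ToyEnv()


def make_envs(factory: Callable[[], ToyEnv], n: int) -> List[ToyEnv]:
    # Calling the factory n times yields n independent instances.
    return [factory() for _ in range(n)]


envs = make_envs(toy_env_factory, 2)
envs[0].step(1)

# Anti-pattern for contrast: repeating one instance makes every slot share state.
shared = [toy_env_factory()] * 2
shared[0].step(1)
```

Stepping envs[0] leaves envs[1] untouched, whereas both entries of shared change because they are the same object.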

We instantiate one environment. This will be used for convenient access to observation and action spaces later.

env = cartpole_env_factory()
observation_space = env.observation_space
action_space = env.action_space

Model Setup

Now that the environment setup is done, let us develop the policy and value networks that will be used, paying special attention to the format required by Maze. When creating your own models, it is important to know two things:

  1. Maze works with dictionaries throughout, which means that arguments for the constructor and the input and return values of the forward method are dicts with user-defined keys.

  2. Policy and value network constructors have required arguments: for policy nets, these are obs_shapes and action_logit_shapes; for value nets, this is obs_shapes.

The required format is explained in more detail here. With this in mind, let us create a simple linear mapping network that satisfies these requirements:

class CartpolePolicyNet(nn.Module):
    """ Simple linear policy net for demonstration purposes. """
    def __init__(self, obs_shapes: Dict[str, Sequence[int]], action_logit_shapes: Dict[str, Sequence[int]]):
        super().__init__()
        self.policy_network = nn.Sequential(
            nn.Linear(in_features=obs_shapes['observation'][0], out_features=action_logit_shapes['action'][0]))

    def forward(self, x_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Since x_dict has to be a dictionary in Maze, we extract the input for the network.
        x = x_dict['observation']

        # Do the forward pass.
        logits = self.policy_network(x)

        # Since the return value has to be a dict again, put the
        # forward pass result into a dict with the correct key.
        logits_dict = {'action': logits}

        return logits_dict

# Instantiate our custom policy net.
policy_net = CartpolePolicyNet(
    obs_shapes={'observation': observation_space.spaces['observation'].shape},
    action_logit_shapes={'action': (action_space.spaces['action'].n,)})


class CartpoleValueNet(nn.Module):
    """ Simple linear value net for demonstration purposes. """
    def __init__(self, obs_shapes: Dict[str, Sequence[int]]):
        super().__init__()
        self.value_net = nn.Sequential(nn.Linear(in_features=obs_shapes['observation'][0], out_features=1))

    def forward(self, x_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ Forward method. """
        # The same as for the policy can be said about the value
        # net: Inputs and outputs have to be dicts.
        x = x_dict['observation']

        value = self.value_net(x)

        value_dict = {'value': value}
        return value_dict
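The dictionary contract described in the two requirements above does not depend on PyTorch. As a rough illustration, the sketch below reimplements the policy net's linear forward pass with plain Python lists; policy_forward and the weight values are hypothetical, and only the key names ('observation', 'action') mirror the code above.

```python
from typing import Dict, List

Vector = List[float]


def linear(x: Vector, weights: List[Vector], bias: Vector) -> Vector:
    """y = W x + b for a single (unbatched) observation."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weights, bias)]


def policy_forward(x_dict: Dict[str, Vector], weights: List[Vector], bias: Vector) -> Dict[str, Vector]:
    # Input: read the observation by its key, as CartpolePolicyNet.forward does.
    x = x_dict["observation"]
    # Output: wrap the logits in a dict keyed by the action name.
    return {"action": linear(x, weights, bias)}


# CartPole has a 4-dimensional observation and 2 discrete actions,
# so the weight matrix is 2 x 4 (one row per action logit).
weights = [[0.1, 0.0, 0.0, 0.0], [0.0, 0.2, 0.0, 0.0]]
bias = [0.0, 0.5]
out = policy_forward({"observation": [1.0, 2.0, 3.0, 4.0]}, weights, bias)
```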

Policy Setup

For a policy, we need a parametrization for the policy (provided by the policy network) and a probability distribution we can sample from. We will subsequently define and instantiate each of these.

Policy Network

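Instantiate the policy network with the correct shapes of the observation and action spaces.

policy_net = CartpolePolicyNet(
    obs_shapes={'observation': observation_space.spaces['observation'].shape},
    action_logit_shapes={'action': (action_space.spaces['action'].n,)})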

We can use one of Maze's capabilities, shape normalization (see ShapeNormalizationBlock), with these models by wrapping them with the TorchModelBlock.

maze_wrapped_policy_net = TorchModelBlock(
    in_keys='observation', out_keys='action',
    in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
    out_num_dims=2, net=policy_net)
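Our reading of in_num_dims=[2] is that the wrapped network expects 2-D input of shape (batch, features), and that the wrapper adds or removes leading dimensions to get there. The sketch below shows that general idea with plain nested lists; call_with_normalized_shape is a hypothetical helper, not Maze's ShapeNormalizationBlock implementation.

```python
from typing import Any, Callable, List

Batch = List[List[float]]


def num_dims(x: Any) -> int:
    """Nesting depth of a non-empty, rectangular nested list (0 for a scalar)."""
    dims = 0
    while isinstance(x, list):
        dims += 1
        x = x[0]
    return dims


def call_with_normalized_shape(net: Callable[[Batch], Batch], x: Any, in_num_dims: int) -> Any:
    # Prepend batch dimensions until the input has the rank the net expects.
    added = 0
    while num_dims(x) < in_num_dims:
        x = [x]
        added += 1
    y = net(x)
    # Remove the dimensions added above so the output matches the caller's input.
    for _ in range(added):
        y = y[0]
    return y


def double_batch(batch: Batch) -> Batch:
    """Toy 'network' that only accepts 2-D input of shape (batch, features)."""
    if num_dims(batch) != 2:
        raise ValueError("expected a 2-D batch")
    return [[2.0 * v for v in row] for row in batch]


single = call_with_normalized_shape(double_batch, [1.0, 2.0], in_num_dims=2)
batched = call_with_normalized_shape(double_batch, [[1.0], [3.0]], in_num_dims=2)
```

A single unbatched observation and an already batched input both pass through the same 2-D-only network.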

Since Maze supports multiple actors, we need to map each policy network to its corresponding actor ID. As we have only one policy, this is a trivial mapping:

policy_networks = {0: maze_wrapped_policy_net}
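With several actors, the same dictionary would simply hold more entries, and the policy would pick the network that belongs to the acting agent. A minimal sketch of that lookup, where constant_logits and forward_for_actor are hypothetical helpers rather than Maze functions:

```python
from typing import Callable, Dict, List

Obs = Dict[str, List[float]]
Network = Callable[[Obs], Dict[str, List[float]]]


def constant_logits(x_dict: Obs) -> Dict[str, List[float]]:
    """Hypothetical stand-in for maze_wrapped_policy_net: uniform logits over 2 actions."""
    return {"action": [0.0, 0.0]}


# One entry per actor ID; with a single actor the mapping is trivial.
policy_networks: Dict[int, Network] = {0: constant_logits}


def forward_for_actor(actor_id: int, x_dict: Obs) -> Dict[str, List[float]]:
    # Select the network that belongs to the acting agent, then run its forward pass.
    return policy_networks[actor_id](x_dict)


logits = forward_for_actor(0, {"observation": [0.0, 0.0, 0.0, 0.0]})
```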

Policy Distribution

Initializing the proper probability distribution for the policy is rather easy with Maze. Simply provide the DistributionMapper with the action space and you automatically get the proper distribution to use.

distribution_mapper = DistributionMapper(action_space=action_space, distribution_mapper_config={})

Optionally, you can specify a different distribution with the distribution_mapper_config argument. For example, explicitly requesting a Categorical distribution for a discrete action space looks like this:

distribution_mapper = DistributionMapper(
    action_space=action_space,
    distribution_mapper_config=[{
        "action_space": gym.spaces.Discrete,
        "distribution": "maze.distributions.categorical.CategoricalProbabilityDistribution"}])

Since the standard distribution taken by Maze for a discrete action space is a Categorical distribution anyway (as can be seen here), both definitions of the distribution_mapper have the same result. For more information about the DistributionMapper, see Action Spaces and Distributions.
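To make concrete what a categorical distribution does with the logits produced by the policy network, here is a minimal, stand-alone sketch (softmax, sampling, log-probability, entropy and mode). The Categorical class is hypothetical and not Maze's CategoricalProbabilityDistribution.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class Categorical:
    """Minimal categorical distribution parametrized by logits (sketch only)."""

    def __init__(self, logits: List[float]) -> None:
        self.probs = softmax(logits)

    def sample(self, rng: random.Random) -> int:
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

    def log_prob(self, action: int) -> float:
        return math.log(self.probs[action])

    def entropy(self) -> float:
        return -sum(p * math.log(p) for p in self.probs if p > 0)

    def mode(self) -> int:
        # The most likely action; used when acting deterministically.
        return max(range(len(self.probs)), key=self.probs.__getitem__)


dist = Categorical([0.0, 0.0])  # equal logits give a uniform distribution over 2 actions
rng = random.Random(0)
samples = [dist.sample(rng) for _ in range(1000)]
```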

Instantiating the Policy

We have both necessary ingredients to define a policy: a parametrization, given by the policy network, and a distribution. With these, we can instantiate a policy. This is done with the TorchPolicy class:

torch_policy = TorchPolicy(networks=policy_networks, distribution_mapper=distribution_mapper, device='cpu')
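Conceptually, the policy chains the two ingredients: the network turns an observation into logits, and the distribution turns logits into an action. The sketch below shows that composition in plain Python; policy_net's weights and compute_action are hypothetical and only loosely mirror what TorchPolicy does internally.

```python
import math
import random
from typing import Dict, List


def policy_net(x_dict: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Hypothetical linear net over a 4-D observation; weights chosen arbitrarily."""
    w = [[-1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
    x = x_dict["observation"]
    return {"action": [sum(wi * xi for wi, xi in zip(row, x)) for row in w]}


def compute_action(obs: Dict[str, List[float]], rng: random.Random, deterministic: bool = False) -> Dict[str, int]:
    # 1) Parametrization: the network maps the observation to logits.
    logits = policy_net(obs)["action"]
    # 2) Distribution: softmax turns the logits into probabilities.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 3) Either take the most likely action or sample one.
    if deterministic:
        action = max(range(len(probs)), key=probs.__getitem__)
    else:
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return {"action": action}
```

Keeping the parametrization and the distribution separate means the same network can be paired with a different distribution via the DistributionMapper without changing the network code.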

Critic Setup

The setup of a critic (or value function) is similar to the setup of a policy, the main difference being that we do not need a probability distribution.

Value Network

value_net = CartpoleValueNet(obs_shapes={'observation': observation_space.spaces['observation'].shape})

maze_wrapped_value_net = TorchModelBlock(
    in_keys='observation', out_keys='value',
    in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
    out_num_dims=2, net=value_net)

value_networks = {0: maze_wrapped_value_net}

Instantiating the Critic

This step is analogous to the instantiation of the policy above. In Maze, critics can have different forms (see Value Functions (Critics)). Here, we use a simple shared critic. Shared means that the same critic will be used for all sub-steps (in a multi-step setting) and all actors. Since we only have one actor in this example and are in a one-step setting, the TorchSharedStateCritic reduces to a vanilla StateCritic (aka a state-dependent value function).

torch_critic = TorchSharedStateCritic(networks=value_networks, obs_spaces_dict=env.observation_spaces_dict,
                                      device='cpu', stack_observations=False)
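The sketch below contrasts a shared critic, which uses one value network no matter which sub-step or actor is asking, with a per-actor critic that keeps one network per key. Both classes and the (sub-step, actor) key format are hypothetical illustrations, not Maze's critic classes.

```python
from typing import Callable, Dict, List, Tuple

Obs = Dict[str, List[float]]
ValueNet = Callable[[Obs], Dict[str, float]]
ActorKey = Tuple[int, int]  # hypothetical (sub-step, actor) identifier


def toy_value_net(x_dict: Obs) -> Dict[str, float]:
    """Hypothetical value net: V(s) = sum of the observation entries."""
    return {"value": sum(x_dict["observation"])}


class SharedCriticSketch:
    """A single value network reused for every sub-step and actor."""

    def __init__(self, net: ValueNet) -> None:
        self.net = net

    def predict_value(self, obs: Obs, actor_key: ActorKey) -> float:
        # actor_key is deliberately ignored: sharing means one estimate for all.
        return self.net(obs)["value"]


class PerActorCriticSketch:
    """For contrast: a separate value network per sub-step/actor."""

    def __init__(self, nets: Dict[ActorKey, ValueNet]) -> None:
        self.nets = nets

    def predict_value(self, obs: Obs, actor_key: ActorKey) -> float:
        return self.nets[actor_key](obs)["value"]


shared = SharedCriticSketch(toy_value_net)
obs = {"observation": [0.5, 0.25, 0.0, 0.0]}
```

With one actor and one sub-step, as in this example, both variants hold a single network and behave like a plain state-value function.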

Initializing the ActorCritic Model.

In Maze, policies and critics are encapsulated by an ActorCritic model. Details about this can be found in Actor-Critics. We will use A2C to train the cartpole env. The correct ActorCritic model to use for A2C is the TorchActorCritic:

actor_critic_model = TorchActorCritic(policy=torch_policy, critic=torch_critic, device='cpu')
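In an actor-critic method such as A2C, the critic is trained to predict discounted returns, and the policy-gradient term is scaled by the advantage A_t = R_t - V(s_t). Below is a minimal sketch of the textbook n-step estimate, computed backwards over a short rollout and bootstrapped from the critic's value of the state after the last step. The function name is hypothetical, and Maze's A2C implementation may differ in details.

```python
from typing import List, Tuple


def n_step_returns_and_advantages(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    bootstrap_value: float,
    gamma: float,
) -> Tuple[List[float], List[float]]:
    """Bootstrapped discounted returns R_t and advantages A_t = R_t - V(s_t)."""
    returns = [0.0] * len(rewards)
    running = bootstrap_value
    # Walk the rollout backwards; an episode end cuts off the bootstrap.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running * (0.0 if dones[t] else 1.0)
        returns[t] = running
    advantages = [r - v for r, v in zip(returns, values)]
    return returns, advantages


returns, advantages = n_step_returns_and_advantages(
    rewards=[1.0, 1.0, 1.0], values=[0.5, 0.5, 0.5],
    dones=[False, False, False], bootstrap_value=0.0, gamma=0.5)
```

The returns serve as regression targets for the critic, while the advantages weight the log-probabilities of the taken actions in the policy loss.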

Trainer Setup

The last steps are the instantiation of the algorithm configuration and the corresponding trainer. We use A2C for this example. Its hyperparameters are supplied to Maze with an algorithm-dependent AlgorithmConfig object, which for A2C is A2CAlgorithmConfig. We will use the default parameters, which can be found here.

algorithm_config = A2CAlgorithmConfig(

In order to use the distributed trainers, we create a vector environment (i.e., multiple environment instances encapsulated to be stepped simultaneously) using the environment factory function:

train_envs = SequentialVectorEnv(
    [cartpole_env_factory for _ in range(2)], logging_prefix="train")
eval_envs = SequentialVectorEnv(
    [cartpole_env_factory for _ in range(2)], logging_prefix="eval")

(In this case, we create sequential vector environments, i.e. all environment instances are located in the main process and stepped sequentially. When we are ready to scale the training, we might want to use e.g. sub-process distributed vector environments.)
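The following sketch shows what "stepped sequentially" means: all environments live in the current process, and one call to step loops over them in order and collects the results into lists. SequentialVecEnvSketch and CountdownEnv are hypothetical; in particular, the automatic reset of finished episodes is a common vector-env convention that we assume here, not a documented detail of Maze's SequentialVectorEnv.

```python
from typing import Callable, List, Tuple


class CountdownEnv:
    """Hypothetical toy env: episode ends after `length` steps, reward 1 per step."""

    def __init__(self, length: int = 3) -> None:
        self.length = length
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.t += 1
        return self.t, 1.0, self.t >= self.length


class SequentialVecEnvSketch:
    """All envs live in this process and are stepped one after another."""

    def __init__(self, factories: List[Callable[[], CountdownEnv]]) -> None:
        self.envs = [factory() for factory in factories]

    def reset(self) -> List[int]:
        return [env.reset() for env in self.envs]

    def step(self, actions: List[int]) -> Tuple[List[int], List[float], List[bool]]:
        observations, rewards, dones = [], [], []
        for env, action in zip(self.envs, actions):  # plain sequential loop
            obs, reward, done = env.step(action)
            if done:
                # Assumed convention: reset finished envs so the batch stays full.
                obs = env.reset()
            observations.append(obs)
            rewards.append(reward)
            dones.append(done)
        return observations, rewards, dones


# Mirrors [cartpole_env_factory for _ in range(2)]: the class itself acts as the factory.
vec_env = SequentialVecEnvSketch([CountdownEnv for _ in range(2)])
first_obs = vec_env.reset()
```

A sub-process variant would keep the same interface but move each env into its own worker process, which is why the trainer only needs the factories.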

For this example, we want to save the parameters of the best model in terms of mean achieved reward. This is done with the BestModelSelection class, an instance of which will be provided to the trainer.

model_selection = BestModelSelection(dump_file="", model=actor_critic_model)
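The logic behind best-model selection is simple: remember the best mean reward seen so far and snapshot the model whenever it is beaten. A hypothetical sketch of that bookkeeping follows; writing the snapshot to disk, which the dump_file argument appears to control, is left out.

```python
import copy
from typing import Any, Dict, Optional


class BestModelSelectionSketch:
    """Keeps a copy of the model state whenever the mean reward improves (sketch)."""

    def __init__(self) -> None:
        self.best_reward = float("-inf")
        self.best_state: Optional[Dict[str, Any]] = None

    def update(self, mean_reward: float, model_state: Dict[str, Any]) -> bool:
        if mean_reward > self.best_reward:
            self.best_reward = mean_reward
            # Deep-copy so later training updates do not overwrite the snapshot.
            self.best_state = copy.deepcopy(model_state)
            return True
        return False


selection = BestModelSelectionSketch()
params = {"weights": [0.0, 0.0]}
improved_first = selection.update(mean_reward=20.0, model_state=params)
params["weights"][0] = 1.0  # training continues and changes the live parameters
improved_second = selection.update(mean_reward=15.0, model_state=params)
```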

We can now instantiate an A2C trainer:

a2c_trainer = A2C(

Train the Agent

Before starting the training, we enable logging by calling:

log_dir = '.'
setup_logging(job_config=None, log_dir=log_dir)

Now, we can train the agent.


To get an out-of-sample estimate of our performance, evaluate on the evaluation envs:

a2c_trainer.evaluate(deterministic=False, repeats=1)
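Our interpretation of the two arguments (not confirmed by the text above) is that deterministic=False samples actions from the policy distribution instead of always taking the most likely one, and that repeats sets how many evaluation episodes are run per environment. A stand-alone sketch of such an evaluation loop, with a hypothetical CoinEnv and a hard-coded toy policy in place of the trained agent:

```python
import random
from statistics import mean
from typing import Callable, List, Tuple


class CoinEnv:
    """Hypothetical toy env: 5 steps per episode, reward 1 whenever action == 1."""

    def __init__(self) -> None:
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return 0

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.t += 1
        return 0, float(action == 1), self.t >= 5


def act(obs: int, deterministic: bool, rng: random.Random) -> int:
    # Toy policy with P(action=1) = 0.8: the mode is 1, sampling sometimes yields 0.
    if deterministic:
        return 1
    return 1 if rng.random() < 0.8 else 0


def evaluate(env_factories: List[Callable[[], CoinEnv]], deterministic: bool, repeats: int, seed: int = 0) -> float:
    rng = random.Random(seed)
    episode_returns = []
    for factory in env_factories:
        env = factory()
        for _ in range(repeats):
            obs, done, total = env.reset(), False, 0.0
            while not done:
                obs, reward, done = env.step(act(obs, deterministic, rng))
                total += reward
            episode_returns.append(total)
    # Mean episode return over all evaluation episodes.
    return mean(episode_returns)
```

Because the evaluation envs are separate instances from the training envs, the resulting mean return is not influenced by the training rollouts.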

Full Python Code

Here is the code without documentation for easier copy-pasting:

""" Low-level training of an A2C agent in plain Python. """

from typing import Dict, Sequence

import gym
import torch
import torch.nn as nn

from maze.core.agent.torch_actor_critic import TorchActorCritic
from maze.core.agent.torch_policy import TorchPolicy
from maze.core.agent.torch_state_critic import TorchSharedStateCritic
from maze.core.rollout.rollout_generator import RolloutGenerator
from maze.core.wrappers.maze_gym_env_wrapper import GymMazeEnv
from maze.distributions.distribution_mapper import DistributionMapper
from maze.perception.blocks.general.torch_model_block import TorchModelBlock
from maze.train.parallelization.vector_env.sequential_vector_env import SequentialVectorEnv
from maze.train.trainers.a2c.a2c_algorithm_config import A2CAlgorithmConfig
from maze.train.trainers.a2c.a2c_trainer import A2C
from maze.train.trainers.common.evaluators.rollout_evaluator import RolloutEvaluator
from maze.train.trainers.common.model_selection.best_model_selection import BestModelSelection
from maze.utils.log_stats_utils import setup_logging

# Environment Setup
# =================

# Environment Factory
# -------------------
# Define environment factory
def cartpole_env_factory():
    """ Env factory for the cartpole MazeEnv """
    # Registered gym environments can be instantiated first and then provided to GymMazeEnv:
    cartpole_env = gym.make("CartPole-v0")
    maze_env = GymMazeEnv(env=cartpole_env)

    # Another possibility is to supply the gym env string to GymMazeEnv directly:
    maze_env = GymMazeEnv(env="CartPole-v0")

    return maze_env

# Model Setup
# ===========
# Policy Network
# --------------
class CartpolePolicyNet(nn.Module):
    """ Simple linear policy net for demonstration purposes. """

    def __init__(self, obs_shapes: Dict[str, Sequence[int]], action_logit_shapes: Dict[str, Sequence[int]]):
        super().__init__()
        self.policy_network = nn.Sequential(
            nn.Linear(in_features=obs_shapes['observation'][0], out_features=action_logit_shapes['action'][0]))

    def forward(self, x_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Since x_dict has to be a dictionary in Maze, we extract the input for the network.
        x = x_dict['observation']

        # Do the forward pass.
        logits = self.policy_network(x)

        # Since the return value has to be a dict again, put the forward pass result into a dict with the
        # correct key.
        logits_dict = {'action': logits}

        return logits_dict

# Value Network
# -------------
class CartpoleValueNet(nn.Module):
    """ Simple linear value net for demonstration purposes. """

    def __init__(self, obs_shapes: Dict[str, Sequence[int]]):
        super().__init__()
        self.value_net = nn.Sequential(nn.Linear(in_features=obs_shapes['observation'][0], out_features=1))

    def forward(self, x_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ Forward method. """
        # The same as for the policy can be said about the value net. Inputs and outputs have to be dicts.
        x = x_dict['observation']

        value = self.value_net(x)

        value_dict = {'value': value}
        return value_dict

def train(n_epochs):
    # Instantiate one environment. This will be used for convenient access to observation
    # and action spaces.
    env = cartpole_env_factory()
    observation_space = env.observation_space
    action_space = env.action_space

    # Policy Setup
    # ------------

    # Policy Network
    # ^^^^^^^^^^^^^^
    # Instantiate policy with the correct shapes of observation and action spaces.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (action_space.spaces['action'].n,)})

    maze_wrapped_policy_net = TorchModelBlock(
        in_keys='observation', out_keys='action',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=policy_net)

    policy_networks = {0: maze_wrapped_policy_net}

    # Policy Distribution
    # ^^^^^^^^^^^^^^^^^^^
    distribution_mapper = DistributionMapper(action_space=action_space, distribution_mapper_config={})

    # Optionally, you can specify a different distribution with the distribution_mapper_config argument. Using a
    # Categorical distribution for a discrete action space would be done via
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config=[{
            "action_space": gym.spaces.Discrete,
            "distribution": "maze.distributions.categorical.CategoricalProbabilityDistribution"}])

    # Instantiating the Policy
    # ^^^^^^^^^^^^^^^^^^^^^^^^
    torch_policy = TorchPolicy(networks=policy_networks, distribution_mapper=distribution_mapper, device='cpu')

    # Value Function Setup
    # --------------------

    # Value Network
    # ^^^^^^^^^^^^^
    value_net = CartpoleValueNet(obs_shapes={'observation': observation_space.spaces['observation'].shape})

    maze_wrapped_value_net = TorchModelBlock(
        in_keys='observation', out_keys='value',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=value_net)

    value_networks = {0: maze_wrapped_value_net}

    # Instantiate the Value Function
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    torch_critic = TorchSharedStateCritic(networks=value_networks, obs_spaces_dict=env.observation_spaces_dict,
                                          device='cpu', stack_observations=False)

    # Initializing the ActorCritic Model.
    # -----------------------------------
    actor_critic_model = TorchActorCritic(policy=torch_policy, critic=torch_critic, device='cpu')

    # Instantiating the Trainer
    # =========================

    algorithm_config = A2CAlgorithmConfig(

    # Distributed Environments
    # ------------------------
    # In order to use the distributed trainers, the previously created env factory is supplied to one of Maze's
    # distribution classes:
    train_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="train")
    eval_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="eval")

    # Initialize best model selection.
    model_selection = BestModelSelection(dump_file="", model=actor_critic_model)

    a2c_trainer = A2C(rollout_generator=RolloutGenerator(train_envs),

    # Train the Agent
    # ===============
    # Before starting the training, we will enable logging by calling
    log_dir = '.'
    setup_logging(job_config=None, log_dir=log_dir)

    # Now, we can train the agent.

    return 0

if __name__ == '__main__':