class maze.core.trajectory_recording.datasets.in_memory_dataset.InMemoryDataset(*args: Any, **kwargs: Any)

Base class for trajectory datasets used in imitation learning that keep all loaded data in memory.

Provides the main functionality for parsing and appending records.

Parameters:

  • dir_or_file – Directory or file containing the trajectory data. If given, the data are loaded on init.

  • conversion_env_factory – Function for creating an environment for state and action conversion. For Maze envs, the environment configuration (e.g. space interfaces, wrappers etc.) determines the format of the actions and observations that will be derived from the recorded MazeActions and MazeStates (e.g. multi-step observations/actions etc.).

  • n_workers – Number of worker processes used to load the data.

append(trajectory: maze.core.trajectory_recording.records.trajectory_record.TrajectoryRecord) → None

Append a new trajectory to the dataset.

Parameters:

trajectory – Trajectory to append.
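As a rough illustration of what appending to an in-memory dataset can involve, here is a minimal, hypothetical sketch. It assumes the trajectory has already been converted into a plain list of step records (the real method takes a TrajectoryRecord) and it remembers the index range of every appended episode, which is the kind of bookkeeping an episode-based split relies on. MiniInMemoryDataset and episode_ranges are illustrative names, not part of Maze.

```python
from typing import Any, List


class MiniInMemoryDataset:
    """Hypothetical, simplified stand-in (not the Maze class).

    Keeps all step records in one list and stores the index range of each
    appended episode so that samples can later be grouped by episode.
    """

    def __init__(self) -> None:
        self.step_records: List[Any] = []
        self.episode_ranges: List[range] = []

    def append(self, trajectory: List[Any]) -> None:
        # Assumption: `trajectory` is an already-converted list of step records.
        start = len(self.step_records)
        self.step_records.extend(trajectory)
        # Remember which flat indices belong to this episode.
        self.episode_ranges.append(range(start, len(self.step_records)))

    def __len__(self) -> int:
        return len(self.step_records)

    def __getitem__(self, index: int) -> Any:
        return self.step_records[index]
```

For example, appending a 2-step and then a 3-step episode yields a dataset of length 5 whose episode ranges are range(0, 2) and range(2, 5).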

static convert_trajectory(trajectory: maze.core.trajectory_recording.records.trajectory_record.TrajectoryRecord, conversion_env: Optional[maze.core.env.maze_env.MazeEnv]) → List[maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord]

Convert an episode trajectory record into a list of observations and actions using the given env.

Parameters:

  • trajectory – Episode record to load.

  • conversion_env – Env to use for conversion of MazeStates and MazeActions into observations and actions. Required only if state records are being loaded (i.e. conversion to raw actions and observations is needed).

Returns:

Loaded observations and actions, as a list of StructuredSpacesRecords (one per recorded step). Each record holds the observations and actions of the individual structured sub-steps, identified by their sub-step IDs (i.e., just one sub-step for non-structured scenarios).
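The distinction between records that need an env and records that do not can be sketched as below. This is a hypothetical simplification, not Maze code: StateStep, SpacesStep and convert_steps are illustrative names, and the env-based conversion is replaced by a plain callable convert_state. Steps recorded as (state, action) pairs must be converted; steps already recorded as raw spaces, keyed by sub-step ID, pass through unchanged.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

# Per-step output: {sub-step ID: (observation, action)}.
StepDict = Dict[Any, Tuple[Any, Any]]


@dataclass
class StateStep:
    """Hypothetical step recorded as (state, action); needs conversion."""
    state: Any
    action: Any


@dataclass
class SpacesStep:
    """Hypothetical step already recorded as raw spaces, keyed by sub-step ID."""
    substeps: StepDict


def convert_steps(steps: List[Any],
                  convert_state: Optional[Callable[[Any, Any], StepDict]]) -> List[StepDict]:
    """Return one {sub-step ID: (obs, action)} dict per step (illustrative sketch)."""
    converted: List[StepDict] = []
    for step in steps:
        if isinstance(step, StateStep):
            # A converter (in Maze: the conversion env) is only needed here.
            if convert_state is None:
                raise ValueError("a conversion function is required to load state records")
            converted.append(convert_state(step.state, step.action))
        else:
            # Raw spaces records are used as they are.
            converted.append(step.substeps)
    return converted
```

Passing None as the converter works as long as the episode contains only raw spaces records, mirroring the "required only if state records are being loaded" note above.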

static deserialize_trajectories(dir_or_file: Union[str, pathlib.Path]) → Generator[maze.core.trajectory_recording.records.trajectory_record.TrajectoryRecord, None, None]

Deserialize all trajectories located in a particular directory or file.

If a file path is passed in, attempts to load it. Supports pickled TrajectoryRecords, as well as lists or dictionaries containing TrajectoryRecords as values.

If a directory is passed in, locates all pickle files (with pkl suffix) in this directory, then attempts to load each of them (again also supporting lists and dictionaries of trajectory records).

Returns a generator that will yield the individual trajectory records, no matter in which form (i.e., individual, list, or dict) they were loaded.

Parameters:

dir_or_file – Directory or file to load trajectory data from.

Returns:

Generator yielding the individual trajectory records.
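The file/directory handling and the unpacking of single records, lists and dicts described above can be sketched with the standard pickle module. This is an illustrative approximation, not Maze's implementation: iter_records is a hypothetical name, any unpickled object that is neither a list nor a dict is treated as a single record, and the directory lookup assumes a non-recursive "*.pkl" glob.

```python
import pickle
from pathlib import Path
from typing import Any, Generator, Union


def iter_records(dir_or_file: Union[str, Path]) -> Generator[Any, None, None]:
    """Hypothetical sketch: yield individual records from a pkl file or a directory of them."""
    path = Path(dir_or_file)
    # Assumption: only files directly inside the directory with a ".pkl" suffix are used.
    files = sorted(path.glob("*.pkl")) if path.is_dir() else [path]
    for file in files:
        with open(file, "rb") as f:
            loaded = pickle.load(f)
        if isinstance(loaded, dict):
            # Dicts hold the records as their values.
            yield from loaded.values()
        elif isinstance(loaded, list):
            yield from loaded
        else:
            # Anything else is treated as a single record.
            yield loaded
```

Because it is a generator, files are only opened as the caller iterates, so large directories do not have to be unpickled up front.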

static list_trajectory_files(data_dir: Union[str, pathlib.Path]) → List[pathlib.Path]

List pickle files (“pkl” suffix, used for trajectory data storage by default) in the given directory.

Parameters:

data_dir – Where to look for the trajectory records (= pickle files).

Returns:

A list of available pkl files in the given directory.

load_data(dir_or_file: Union[str, pathlib.Path]) → None

Load the trajectory data from the given file or directory and append it to the dataset.

Implementations should provide the main loading logic in a way that is efficient for the data at hand (e.g. splitting the work across multiple parallel workers). Beyond that, this class already provides several helper methods useful for loading (e.g. for deserializing trajectories stored in different structures or for converting MazeStates to raw observations).

Parameters:

dir_or_file – Directory or file to load the trajectory data from.
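One way to split the loading across n_workers is to shard the list of trajectory files and load each shard concurrently. The sketch below is a hypothetical illustration, not Maze's loader: shard and load_in_parallel are made-up names, load_file stands for any function returning the records of one file, and it uses a thread pool for simplicity. For CPU-heavy unpickling and conversion, worker processes (as n_workers suggests) would typically scale better.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Sequence, TypeVar

T = TypeVar("T")


def shard(items: Sequence[T], n_workers: int) -> List[List[T]]:
    """Round-robin assignment of items (e.g. trajectory files) to n_workers shards."""
    return [list(items[i::n_workers]) for i in range(n_workers)]


def load_in_parallel(files: Sequence[str], n_workers: int,
                     load_file: Callable[[str], List[Any]]) -> List[Any]:
    """Hypothetical sketch: load every shard in its own worker and collect all records."""
    def load_shard(shard_files: List[str]) -> List[Any]:
        return [record for f in shard_files for record in load_file(f)]

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # pool.map returns the per-shard results in shard order.
        results = list(pool.map(load_shard, shard(files, n_workers)))
    # Flatten the per-shard lists into one list of records.
    return [record for shard_result in results for record in shard_result]
```

Round-robin sharding keeps the shards balanced in file count; if file sizes vary a lot, balancing by size would be a reasonable alternative.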

random_split(lengths: Sequence[int], generator: torch.Generator = torch.default_generator) → List[]

Randomly split the dataset into non-overlapping new datasets of given lengths.

The split is based on episodes – samples from the same episode will end up in the same subset. Depending on the available episode lengths, this might result in subsets whose lengths differ slightly from those specified.

Optionally fix the generator for reproducible results, e.g.:

self.random_split([3, 7], generator=torch.Generator().manual_seed(42))

Parameters:

  • lengths – Lengths of the splits to be produced (best effort; the result might differ based on the available episode lengths).

  • generator – Generator used for the random permutation.

Returns:

A list of the data subsets, each with a size roughly (!) corresponding to what was specified by lengths.
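The episode-based splitting can be illustrated with a small, self-contained sketch. It is an assumption about how such a split may work, not Maze's code: episode_random_split is a hypothetical name, it uses Python's random.Random(seed) in place of a seeded torch.Generator, and it fills the subsets greedily from a random permutation of whole episodes, so a subset can overshoot its target by up to one episode while the last subset takes whatever remains.

```python
import random
from typing import List, Sequence


def episode_random_split(episode_lengths: Sequence[int],
                         lengths: Sequence[int],
                         seed: int = 0) -> List[List[int]]:
    """Hypothetical sketch: split flat step indices into subsets without breaking up episodes.

    episode_lengths[i] is the number of steps of episode i; steps are indexed
    consecutively across episodes (episode 0 first, then episode 1, ...).
    """
    # Flat step indices belonging to each episode.
    episode_steps: List[List[int]] = []
    start = 0
    for n in episode_lengths:
        episode_steps.append(list(range(start, start + n)))
        start += n

    # Random permutation of the episodes (stand-in for a seeded torch.randperm).
    order = list(range(len(episode_lengths)))
    random.Random(seed).shuffle(order)

    subsets: List[List[int]] = [[] for _ in lengths]
    split_idx = 0
    for ep in order:
        # Advance to the next subset once the current one reached its target size;
        # the last subset receives all remaining episodes.
        while split_idx < len(lengths) - 1 and len(subsets[split_idx]) >= lengths[split_idx]:
            split_idx += 1
        subsets[split_idx].extend(episode_steps[ep])
    return subsets
```

With a fixed seed the split is reproducible, analogous to passing torch.Generator().manual_seed(42); the subsets are disjoint, cover all steps, and never separate the steps of one episode.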