InMemoryDataset

class maze.core.trajectory_recording.datasets.in_memory_dataset.InMemoryDataset(*args: Any, **kwargs: Any)

Base class for trajectory datasets used in imitation learning; keeps all loaded data in memory.

Provides the main functionality for parsing and appending records.

Parameters
  • input_data – The optional input data to fill the dataset with. This can be either a single file, a single directory, a list of files or a list of directories.

  • conversion_env_factory – Function for creating an environment for state and action conversion. For Maze envs, the environment configuration (i.e. space interfaces, wrappers etc.) determines the format of the actions and observations that will be derived from the recorded MazeActions and MazeStates (e.g. multi-step observations/actions etc.).

  • n_workers – Number of worker processes used for loading the data.

  • trajectory_processor – The processor object for processing and converting individual trajectories.

  • deserialize_in_main_thread – Specify whether to deserialize the trajectories in the main thread (True) or in the workers (False). If only a single trajectory file is given, and this file holds many trajectories that all have to be converted with the conversion env, setting this value to True makes sense, since the expensive operation is the conversion. However, if many files are given and no conversion is necessary, the expensive operation is the deserialization, which should then happen in the worker processes. Only relevant if n_workers > 1.
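A minimal construction sketch, assuming a project-specific env factory and trajectory processor (the names make_conversion_env and my_processor, and the trajectory_data directory, are placeholders for illustration, not part of the API):

from maze.core.trajectory_recording.datasets.in_memory_dataset import InMemoryDataset

def make_conversion_env():
    # Hypothetical factory: return the Maze env used for state/action conversion.
    raise NotImplementedError

dataset = InMemoryDataset(
    input_data="trajectory_data",               # placeholder directory of recorded .pkl files
    conversion_env_factory=make_conversion_env,
    n_workers=2,                                # load data in two worker processes
    trajectory_processor=my_processor,          # assumed: processor object defined elsewhere
    deserialize_in_main_thread=False,           # many files, no conversion: deserialize in workers
)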

append(trajectory: maze.core.trajectory_recording.records.trajectory_record.TrajectoryRecord) → None

Append a new trajectory to the dataset.

Parameters

trajectory – Trajectory to append.

static deserialize_trajectory(trajectory_file: Union[str, pathlib.Path]) → Generator[maze.core.trajectory_recording.records.trajectory_record.TrajectoryRecord, None, None]

Deserialize all trajectories located in the given file path.

Will attempt to load the given trajectory file. Supports pickled TrajectoryRecord objects, as well as lists or dictionaries with TrajectoryRecords as values.

Returns a generator that yields the individual trajectory records, regardless of the form (i.e., individual record, list, or dict) in which they were loaded.

Parameters

trajectory_file – File to load trajectory data from.

Returns

Generator yielding the individual trajectory records.
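Since this method is static and returns a generator, it combines naturally with append to fill a dataset from a single file (file path hypothetical; dataset as in the construction sketch above):

from maze.core.trajectory_recording.datasets.in_memory_dataset import InMemoryDataset

for record in InMemoryDataset.deserialize_trajectory("rollouts/episode_data.pkl"):
    dataset.append(record)  # each yielded item is an individual TrajectoryRecord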

static list_trajectory_files(data_dir: Union[str, pathlib.Path]) → List[pathlib.Path]

List pickle files (“pkl” suffix, used for trajectory data storage by default) in the given directory.

Parameters

data_dir – Where to look for the trajectory records (= pickle files).

Returns

A list of available pkl files in the given directory.
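A short usage sketch (directory name hypothetical):

from pathlib import Path
from maze.core.trajectory_recording.datasets.in_memory_dataset import InMemoryDataset

for path in InMemoryDataset.list_trajectory_files(Path("trajectory_data")):
    print(path)  # each entry is a pathlib.Path with a .pkl suffix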

load_data(input_data: Union[str, pathlib.Path, List[Union[str, pathlib.Path]]]) → None

Load the trajectory data from the given file or directory and append it to the dataset.

Should provide the main logic for performing the data load efficiently for the data at hand (e.g., splitting it up across multiple parallel workers). Beyond that, this class already provides multiple helper methods useful for loading (e.g., for deserializing differently structured trajectories or for converting Maze states to raw observations).

Parameters

input_data – Input data to load the trajectories from. This can be either a single file, a single directory, a list of files or a list of directories.
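For instance, additional recordings can be merged into an already constructed dataset (paths hypothetical):

dataset.load_data("more_rollouts")             # a single directory containing .pkl files
dataset.load_data(["run_1.pkl", "run_2.pkl"])  # or an explicit list of trajectory files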

random_split(lengths: Sequence[int], generator: torch.Generator = torch.default_generator) → List[torch.utils.data.Subset]

Randomly split the dataset into non-overlapping new datasets of given lengths.

The split is based on episodes – samples from the same episode will end up in the same subset. Based on the available episode lengths, this might result in subsets of slightly different lengths than specified.

Optionally fix the generator for reproducible results, e.g.:

self.random_split([3, 7], generator=torch.Generator().manual_seed(42))

Parameters
  • lengths – Lengths of the splits to be produced (best effort; the result might differ based on the available episode lengths).

  • generator – Generator used for the random permutation.

Returns

A list of the data subsets, each with a size roughly (!) corresponding to what was specified by lengths.
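A common pattern is an episode-aware train/validation split fed into standard PyTorch data loaders; the target lengths below are hypothetical, and the actual subset sizes may deviate slightly, as noted above:

import torch
from torch.utils.data import DataLoader

train_set, valid_set = dataset.random_split(
    [8000, 2000],                                 # best-effort target lengths
    generator=torch.Generator().manual_seed(42),  # fix the seed for reproducible splits
)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)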