DistributionMapper

class maze.distributions.distribution_mapper.DistributionMapper(action_space: gym.spaces.Dict, distribution_mapper_config: Union[List[Union[None, Mapping[str, Any], Any]], Mapping[Union[str, Type], Union[None, Mapping[str, Any], Any]]])

Provides a mapping of spaces and action heads to the respective probability distributions to be used.

This gives full flexibility to assign different distributions to the same gym action space type (e.g. one gym.spaces.Box space could be modeled with a Beta distribution and another with a DiagonalGaussian distribution). It also allows adding and registering arbitrary custom distributions.

Parameters
  • action_space – The dictionary action space.

  • distribution_mapper_config – A Distribution mapper configuration (for details see the docs).
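As a rough illustration of the mapping idea, the sketch below resolves a distribution name for each action head from a list of config entries. It does not use Maze. The key names `action_space`, `action_head` and `distribution`, the rule that head-specific entries override space-type entries, the default table, and the stand-in classes `Box` and `Discrete` are all assumptions made for illustration; the Maze docs remain the reference for the actual configuration format.

```python
from typing import Any, Dict, List, Mapping


class Discrete:  # hypothetical stand-in for gym.spaces.Discrete
    def __init__(self, n: int) -> None:
        self.n = n


class Box:  # hypothetical stand-in for gym.spaces.Box
    def __init__(self, dim: int) -> None:
        self.dim = dim


# Assumed default distribution per space type (illustrative only).
DEFAULTS = {Discrete: "Categorical", Box: "DiagonalGaussian"}


def resolve_distributions(
    action_space: Mapping[str, Any], config: List[Mapping[str, Any]]
) -> Dict[str, str]:
    """Map every action head to a distribution name."""
    # Start from the default for each head's space type.
    mapping = {head: DEFAULTS[type(space)] for head, space in action_space.items()}
    # Pass 1: entries keyed by space type apply to all heads of that type.
    for entry in config:
        if "action_space" in entry:
            for head, space in action_space.items():
                if isinstance(space, entry["action_space"]):
                    mapping[head] = entry["distribution"]
    # Pass 2: entries keyed by head name override everything else.
    for entry in config:
        if "action_head" in entry:
            mapping[entry["action_head"]] = entry["distribution"]
    return mapping


spaces = {"steer": Box(1), "throttle": Box(1), "gear": Discrete(5)}
config = [{"action_head": "throttle", "distribution": "Beta"}]
resolved = resolve_distributions(spaces, config)
print(resolved)
```

Here two spaces of the same type end up with different distributions because one of them is addressed by its action head name.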

action_head_distribution(action_head: str, logits: torch.Tensor, temperature: float) → maze.distributions.torch_dist.TorchProbabilityDistribution

Creates a probability distribution for a given action head.

Parameters
  • action_head – The name of the action head (action dictionary key).

  • logits – The logits to parameterize the distribution from.

  • temperature – Controls the sampling behaviour:
      * 1.0 corresponds to unmodified sampling
      * smaller than 1.0 concentrates the action distribution towards deterministic sampling

Returns

(ProbabilityDistribution) the appropriate instance of a ProbabilityDistribution
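The effect of the temperature parameter can be sketched for a categorical action head in plain Python. This assumes the common convention of dividing the logits by the temperature before applying a softmax; the documentation above does not state how Maze applies the temperature internally, and `softmax_with_temperature` is a hypothetical helper, not part of Maze.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits / temperature (assumed convention)."""
    scaled = [x / temperature for x in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
unmodified = softmax_with_temperature(logits, temperature=1.0)
sharpened = softmax_with_temperature(logits, temperature=0.1)
print(unmodified)
print(sharpened)
```

With a temperature of 0.1, almost all probability mass moves to the action with the largest logit, which is what "towards deterministic sampling" means in practice.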

logits_dict_to_distribution(logits_dict: Dict[str, torch.Tensor], temperature: float) → maze.distributions.dict.DictProbabilityDistribution

Creates a dictionary probability distribution for a given logits dictionary.

Parameters
  • logits_dict – A logits dictionary [action_head: action_logits] to parameterize the distribution from.

  • temperature – Controls the sampling behaviour:
      * 1.0 corresponds to unmodified sampling
      * smaller than 1.0 concentrates the action distribution towards deterministic sampling

Returns

(DictProbabilityDistribution) the respective instance of a DictProbabilityDistribution.
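To show what a dictionary-structured distribution amounts to, the following minimal sketch draws one action per head from a logits dictionary. It is not Maze code: it assumes every head is categorical, uses plain Python lists instead of `torch.Tensor`, and assumes the temperature divides the logits. `sample_action_dict` is a hypothetical name.

```python
import math
import random
from typing import Dict, List


def sample_action_dict(
    logits_dict: Dict[str, List[float]], temperature: float, rng: random.Random
) -> Dict[str, int]:
    """Sample one categorical action index per action head."""
    actions = {}
    for head, logits in logits_dict.items():
        scaled = [x / temperature for x in logits]
        shift = max(scaled)  # numerical stability
        weights = [math.exp(s - shift) for s in scaled]
        # random.choices normalizes the weights internally.
        actions[head] = rng.choices(range(len(logits)), weights=weights, k=1)[0]
    return actions


rng = random.Random(0)
logits_dict = {"move": [0.0, 5.0, 0.0], "jump": [1.0, 1.0]}
near_greedy = sample_action_dict(logits_dict, temperature=0.01, rng=rng)
print(near_greedy)
```

The result is keyed by the same action head names as the logits dictionary, mirroring the [action_head: action_logits] structure described above.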

required_logits_shape(action_head: str) → Sequence[int]

Returns the required logits shape (network output shape) for a given action head.

Parameters

action_head – The name of the action head (action dictionary key).

Returns

The required logits shape.
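The required logits shape depends on which distribution is mapped to the action head. The sketch below illustrates two typical cases under stated assumptions: a categorical distribution needs one logit per discrete action, and a diagonal Gaussian needs a mean and a standard-deviation parameter per action dimension. Both helper functions are hypothetical, and the shapes Maze actually returns for its distributions may differ.

```python
from typing import Sequence, Tuple


def categorical_logits_shape(num_actions: int) -> Tuple[int, ...]:
    """One logit per discrete action (assumed convention)."""
    return (num_actions,)


def diagonal_gaussian_logits_shape(action_shape: Sequence[int]) -> Tuple[int, ...]:
    """Mean and std parameter per dimension: the last axis is doubled (assumed)."""
    return tuple(action_shape[:-1]) + (2 * action_shape[-1],)


discrete_shape = categorical_logits_shape(5)
box_shape = diagonal_gaussian_logits_shape((3,))
print(discrete_shape, box_shape)
```

A policy network's output layer for a given head would then be sized to match the returned shape.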