MultiCategoricalProbabilityDistribution
- class maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution(logits: torch.Tensor, action_space: gymnasium.spaces.MultiDiscrete, temperature: float)
Multi-categorical probability distribution.
The methods return properties aggregated across the sub-distributions using a reduce_fun such as mean or sum.
- Parameters:
logits – The concatenated action selection logits for all sub spaces.
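To make "concatenated logits" concrete, here is a minimal sketch in plain Python (not torch, and not the maze implementation). It assumes the sub-space sizes come from the `nvec` attribute of the `MultiDiscrete` space and that the temperature divides the logits before the softmax; both are assumptions, and the helper names `split_logits` and `softmax` are hypothetical.

```python
import math

def split_logits(logits, nvec):
    """Split a flat list of logits into one chunk per sub-space (hypothetical helper)."""
    assert len(logits) == sum(nvec), "expected one logit per category"
    chunks, start = [], 0
    for n in nvec:
        chunks.append(logits[start:start + n])
        start += n
    return chunks

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; dividing by the temperature is an assumption."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

nvec = [3, 2]                          # e.g. MultiDiscrete([3, 2])
logits = [0.0, 1.0, 2.0, 0.5, 0.5]     # 3 + 2 concatenated logits
chunks = split_logits(logits, nvec)
probs = [softmax(c, temperature=1.0) for c in chunks]
```

Each chunk then defines one categorical sub-distribution over its own sub-space.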
- deterministic_sample() → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
- entropy(reduce_fun: callable = torch.mean) → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
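The role of reduce_fun can be illustrated with a small sketch in plain Python. It assumes that one entropy value is computed per sub-distribution and that reduce_fun then collapses those values into one number; the function `multi_entropy` is hypothetical and only mirrors that idea.

```python
import math
from statistics import fmean

def entropy(probs):
    """Shannon entropy (in nats) of one categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def multi_entropy(sub_probs, reduce_fun):
    """Entropy per sub-distribution, aggregated by reduce_fun (hypothetical helper)."""
    return reduce_fun([entropy(p) for p in sub_probs])

sub_probs = [[0.25, 0.25, 0.25, 0.25], [0.5, 0.5]]
entropy_mean = multi_entropy(sub_probs, fmean)  # average over sub-distributions
entropy_sum = multi_entropy(sub_probs, sum)     # total over sub-distributions
```

With sum, the result equals the entropy of the joint distribution if the sub-actions are independent; with mean, the value stays on a comparable scale regardless of how many sub-spaces there are.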
- kl(other: MultiCategoricalProbabilityDistribution, reduce_fun: callable = torch.mean) → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
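As a rough illustration of a KL divergence between two multi-categorical distributions, the plain-Python sketch below computes one divergence per pair of sub-distributions and then aggregates. The argument order used by the library (self versus other) is not stated here, so the sketch simply computes KL(p || q); `kl_divergence` is a hypothetical helper.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two categorical distributions over the same categories."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

p_subs = [[0.5, 0.5], [0.9, 0.1]]
q_subs = [[0.5, 0.5], [0.5, 0.5]]

# One KL value per sub-space, then aggregated the way a reduce_fun would.
per_sub_kl = [kl_divergence(p, q) for p, q in zip(p_subs, q_subs)]
kl_sum = sum(per_sub_kl)
kl_mean = kl_sum / len(per_sub_kl)
```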
- log_prob(actions: torch.Tensor) → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
- neg_log_prob(actions: torch.Tensor) → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
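For intuition on log_prob and neg_log_prob, here is a minimal plain-Python sketch. It assumes the sub-actions are independent, so the joint log-probability of a multi-discrete action is the sum of the per-sub-space log-probabilities; whether the library aggregates in exactly this way is an assumption, and `joint_log_prob` is a hypothetical helper.

```python
import math

def joint_log_prob(sub_probs, action):
    """Sum of log-probabilities of the chosen category in each sub-space."""
    return sum(math.log(probs[a]) for probs, a in zip(sub_probs, action))

sub_probs = [[0.2, 0.8], [0.5, 0.25, 0.25]]
action = [1, 0]  # one category index per sub-space

log_p = joint_log_prob(sub_probs, action)  # log(0.8) + log(0.5)
neg_log_p = -log_p                         # negative log-probability
```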
- classmethod required_logits_shape(action_space: gymnasium.spaces.MultiDiscrete) → Sequence[int]
implementation of TorchProbabilityDistribution interface
- sample() → torch.Tensor
(overrides ProbabilityDistribution) implementation of TorchProbabilityDistribution interface
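The difference between sample() and deterministic_sample() can be sketched in plain Python. The sketch assumes that deterministic sampling picks the most likely category in each sub-space and that stochastic sampling draws one category per sub-space; this is a common convention, not something confirmed by this page, and the functions below are hypothetical stand-ins rather than the library's methods.

```python
import random

def deterministic_sample(sub_probs):
    """Pick the most probable category in every sub-space (assumed behaviour)."""
    return [max(range(len(p)), key=p.__getitem__) for p in sub_probs]

def sample(sub_probs, rng):
    """Draw one category per sub-space according to its probabilities."""
    return [rng.choices(range(len(p)), weights=p, k=1)[0] for p in sub_probs]

sub_probs = [[0.1, 0.7, 0.2], [0.6, 0.4]]
rng = random.Random(0)  # seeded for reproducibility

greedy_action = deterministic_sample(sub_probs)
random_action = sample(sub_probs, rng)
```

Both calls return one index per sub-space, matching the layout of a MultiDiscrete action.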