MultiCategoricalProbabilityDistribution

class maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution(logits: torch.Tensor, action_space: gym.spaces.MultiDiscrete, temperature: float)

Multi-categorical probability distribution.

Methods that accept a reduce_fun (such as torch.mean or torch.sum) return properties aggregated across the sub-distributions.

Parameters

logits – The concatenated action selection logits for all sub spaces.
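The following is a minimal sketch of the idea, not the Maze implementation. It assumes the concatenated logits are split along the last dimension into one block per sub-space, that the logits are divided by the temperature, and that each block parameterizes an independent torch.distributions.Categorical. The nvec list stands in for the sizes of a gym.spaces.MultiDiscrete and is hypothetical. How the real class applies reduce_fun (for example, over which dimension) is not stated here.

```python
import torch
from torch.distributions import Categorical

# Hypothetical MultiDiscrete layout: three sub-spaces with 3, 2 and 4 choices.
nvec = [3, 2, 4]
temperature = 1.0  # assumption: logits are divided by the temperature

# One batch row of concatenated logits, length sum(nvec) == 9.
logits = torch.zeros(1, sum(nvec))

# Split the last dimension into one block of logits per sub-space.
sub_logits = torch.split(logits / temperature, nvec, dim=-1)
sub_dists = [Categorical(logits=block) for block in sub_logits]

# Per-sub-space entropies, stacked to shape (batch, num_sub_spaces).
entropies = torch.stack([d.entropy() for d in sub_dists], dim=-1)

# Aggregate across the sub-distributions with a reduce function.
mean_entropy = torch.mean(entropies)  # analogous to reduce_fun=torch.mean
sum_entropy = torch.sum(entropies)    # analogous to reduce_fun=torch.sum
```

With all-zero logits every sub-distribution is uniform, so the summed entropy equals log(3) + log(2) + log(4).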

deterministic_sample() → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface

entropy(reduce_fun: callable = torch.mean) → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface

kl(other: maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution, reduce_fun: callable = torch.mean) → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface
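As an illustration of what a reduced KL divergence between two multi-categorical distributions could look like, the sketch below computes one KL term per sub-space with torch.distributions.kl_divergence and then averages them. The helper split_dists and the nvec layout are hypothetical, and whether Maze computes the terms this way is an assumption.

```python
import torch
from torch.distributions import Categorical, kl_divergence

# Hypothetical layout: two sub-spaces with 3 and 2 choices.
nvec = [3, 2]


def split_dists(logits, nvec):
    """Hypothetical helper: one Categorical per block of concatenated logits."""
    return [Categorical(logits=b) for b in torch.split(logits, nvec, dim=-1)]


p_logits = torch.tensor([[2.0, 0.0, 0.0, 1.0, 0.0]])
q_logits = torch.zeros(1, sum(nvec))  # uniform in every sub-space

p = split_dists(p_logits, nvec)
q = split_dists(q_logits, nvec)

# KL(p_i || q_i) for each sub-space i, stacked to shape (batch, num_sub_spaces).
kls = torch.stack([kl_divergence(a, b) for a, b in zip(p, q)], dim=-1)

# Aggregate across sub-spaces, analogous to reduce_fun=torch.mean.
kl_mean = torch.mean(kls)

# Sanity check: a distribution has zero divergence from itself.
self_kl = torch.stack([kl_divergence(a, a) for a in p], dim=-1)
```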

log_prob(actions: torch.Tensor) → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface

neg_log_prob(actions: torch.Tensor) → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface
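A rough sketch of how log_prob and neg_log_prob relate for a multi-discrete action, under the assumption that the sub-spaces are treated as independent so that the joint log-probability is the sum of the per-sub-space terms. Whether the Maze class sums the terms or returns them in another form is not documented here; the nvec layout and the action values are hypothetical.

```python
import torch
from torch.distributions import Categorical

# Hypothetical layout: three sub-spaces with 3, 2 and 4 choices.
nvec = [3, 2, 4]
logits = torch.zeros(1, sum(nvec))  # uniform in every sub-space
sub_logits = torch.split(logits, nvec, dim=-1)

# One multi-discrete action per batch row: one index per sub-space.
actions = torch.tensor([[1, 0, 3]])

# Log-probability of each sub-action under its own sub-distribution.
log_probs = torch.stack(
    [Categorical(logits=b).log_prob(actions[..., i]) for i, b in enumerate(sub_logits)],
    dim=-1,
)

# Assumption: independent sub-spaces, so the joint log-probability is the sum.
joint_log_prob = log_probs.sum(dim=-1)

# The negative log-probability is the same quantity with the sign flipped.
joint_neg_log_prob = -joint_log_prob
```

For uniform sub-distributions the joint probability is 1 / (3 * 2 * 4), so joint_log_prob equals -log(24).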

classmethod required_logits_shape(action_space: gym.spaces.MultiDiscrete) → Sequence[int]

implementation of TorchProbabilityDistribution interface
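Since the logits parameter is described as the concatenation of the logits for all sub-spaces, a plausible reading is that the required shape is a flat vector with one entry per choice. The function below is a hypothetical stand-in that takes the sub-space sizes directly instead of a gym.spaces.MultiDiscrete; it is a sketch of that reading, not the library's code.

```python
from typing import Sequence, Tuple


def required_logits_shape_sketch(nvec: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical: one logit per choice, concatenated over all sub-spaces."""
    return (sum(int(n) for n in nvec),)


# Sub-spaces with 3, 2 and 4 choices need 3 + 2 + 4 logits in total.
shape = required_logits_shape_sketch([3, 2, 4])
```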

sample() → torch.Tensor

(overrides ProbabilityDistribution)

implementation of TorchProbabilityDistribution interface
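To contrast sample() with deterministic_sample(), the sketch below draws one index per sub-space stochastically and, separately, picks the highest-logit index per sub-space. Treating the deterministic sample as an argmax is an assumption, as are the nvec layout and the logit values.

```python
import torch
from torch.distributions import Categorical

# Hypothetical layout: three sub-spaces with 3, 2 and 4 choices.
nvec = [3, 2, 4]
logits = torch.tensor([[0.0, 5.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0]])
sub_logits = torch.split(logits, nvec, dim=-1)

# Stochastic: draw one index from each sub-distribution.
sampled = torch.stack(
    [Categorical(logits=b).sample() for b in sub_logits], dim=-1
)

# Deterministic (assumed to mean argmax): the most likely index per sub-space.
greedy = torch.stack([torch.argmax(b, dim=-1) for b in sub_logits], dim=-1)
```

Both results have shape (batch, num_sub_spaces), matching the layout of a MultiDiscrete action.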