MultiCategoricalProbabilityDistribution
class maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution(logits: torch.Tensor, action_space: gym.spaces.MultiDiscrete, temperature: float)
Multi-categorical probability distribution.
Methods that take a reduce_fun argument (entropy, kl) aggregate their result across the sub-distributions with it, for example torch.mean or torch.sum.
Parameters
    logits – The concatenated action selection logits for all sub spaces.
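Conceptually, a multi-categorical distribution is one categorical distribution per sub space of the MultiDiscrete action space, each fed from its own slice of the concatenated logits. The following minimal sketch illustrates that idea with plain torch.distributions.Categorical objects; it is not Maze's implementation. The helper split_into_categoricals and its nvec argument (a stand-in for gym.spaces.MultiDiscrete.nvec) are hypothetical, and dividing the logits by temperature is an assumed convention.

```python
import torch
from torch.distributions import Categorical


def split_into_categoricals(logits, nvec, temperature=1.0):
    """Hypothetical helper: one Categorical per sub space.

    Assumptions: `nvec` mimics gym.spaces.MultiDiscrete.nvec, and the
    temperature is applied by dividing the logits (not confirmed for Maze).
    """
    # The last dimension holds the logits of all sub spaces back to back.
    chunks = torch.split(logits, list(nvec), dim=-1)
    return [Categorical(logits=chunk / temperature) for chunk in chunks]


# A MultiDiscrete([3, 2]) space needs 3 + 2 = 5 logits per sample.
logits = torch.zeros(4, 5)  # batch of 4
dists = split_into_categoricals(logits, [3, 2])
print([d.probs.shape for d in dists])  # [torch.Size([4, 3]), torch.Size([4, 2])]
```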
deterministic_sample() → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
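For categorical distributions a deterministic sample is usually the mode, i.e. the argmax of each sub space's logits. The sketch below assumes that behavior and stacks one index per sub space along the last dimension; the function and its nvec argument (standing in for action_space.nvec) are hypothetical, not Maze's code.

```python
import torch


def deterministic_sample(logits, nvec):
    """Hypothetical sketch: pick the most likely action in every sub space.

    Assumption: `nvec` mimics gym.spaces.MultiDiscrete.nvec.
    """
    chunks = torch.split(logits, list(nvec), dim=-1)
    # One argmax per sub space, stacked into shape (..., len(nvec)).
    return torch.stack([chunk.argmax(dim=-1) for chunk in chunks], dim=-1)


logits = torch.tensor([[0.1, 2.0, -1.0, 0.5, 0.3]])
print(deterministic_sample(logits, [3, 2]))  # tensor([[1, 0]])
```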
entropy(reduce_fun: callable = torch.mean) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
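To illustrate what reduce_fun does here, the sketch computes one entropy per sub space and then aggregates over the sub-space dimension. Calling reduce_fun with a dim=-1 keyword is an assumption about how the reduction is applied; the helper and its nvec argument are hypothetical.

```python
import math

import torch
from torch.distributions import Categorical


def entropy(logits, nvec, reduce_fun=torch.mean):
    """Hypothetical sketch: per-sub-space entropies aggregated with reduce_fun.

    Assumptions: `nvec` mimics MultiDiscrete.nvec and reduce_fun is applied
    over the last (sub-space) dimension.
    """
    chunks = torch.split(logits, list(nvec), dim=-1)
    per_space = torch.stack([Categorical(logits=c).entropy() for c in chunks], dim=-1)
    return reduce_fun(per_space, dim=-1)


logits = torch.zeros(1, 5)  # uniform over 3 and over 2 actions
print(entropy(logits, [3, 2]))                       # tensor([0.8959]), mean of ln 3 and ln 2
print(entropy(logits, [3, 2], reduce_fun=torch.sum))  # tensor([1.7918]), ln 6
```

Using torch.sum gives the entropy of the joint distribution (the sub spaces are independent), whereas torch.mean keeps the value on the scale of a single sub space.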
kl(other: maze.distributions.multi_categorical.MultiCategoricalProbabilityDistribution, reduce_fun: callable = torch.mean) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
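The KL divergence between two multi-categorical distributions over the same action space can be formed sub space by sub space and then aggregated with reduce_fun. This minimal sketch uses torch.distributions.kl_divergence on matching Categorical pairs; the helper, its nvec argument and the dim=-1 reduction are illustrative assumptions rather than Maze's implementation.

```python
import torch
from torch.distributions import Categorical, kl_divergence


def kl(logits_p, logits_q, nvec, reduce_fun=torch.mean):
    """Hypothetical sketch: KL(p || q) per sub space, aggregated with reduce_fun."""
    chunks_p = torch.split(logits_p, list(nvec), dim=-1)
    chunks_q = torch.split(logits_q, list(nvec), dim=-1)
    per_space = torch.stack(
        [kl_divergence(Categorical(logits=p), Categorical(logits=q))
         for p, q in zip(chunks_p, chunks_q)],
        dim=-1,
    )
    return reduce_fun(per_space, dim=-1)


uniform = torch.zeros(1, 5)
skewed = torch.tensor([[2.0, 0.0, 0.0, 1.0, 0.0]])
print(kl(uniform, uniform, [3, 2]))  # tensor([0.])
print(kl(uniform, skewed, [3, 2]))   # positive value
```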
log_prob(actions: torch.Tensor) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
neg_log_prob(actions: torch.Tensor) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
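Because the sub spaces are independent, the joint log-probability of a multi-discrete action is the sum of the per-sub-space log-probabilities, and the negative log-probability is simply its negation. The sketch below shows this under the assumption that actions has shape (..., len(nvec)) with one index per sub space; the helpers are hypothetical and not Maze's code.

```python
import torch
from torch.distributions import Categorical


def log_prob(logits, nvec, actions):
    """Hypothetical sketch: joint log-probability of multi-discrete actions.

    Assumption: `actions` holds one integer index per sub space in its last dim.
    """
    chunks = torch.split(logits, list(nvec), dim=-1)
    per_space = [Categorical(logits=c).log_prob(actions[..., i]) for i, c in enumerate(chunks)]
    # Independent sub spaces: log p(a) = sum_i log p_i(a_i)
    return torch.stack(per_space, dim=-1).sum(dim=-1)


def neg_log_prob(logits, nvec, actions):
    """Negative log-probability, e.g. for use as a loss term."""
    return -log_prob(logits, nvec, actions)


logits = torch.zeros(1, 5)
actions = torch.tensor([[0, 1]])
print(log_prob(logits, [3, 2], actions))      # tensor([-1.7918]), log(1/3 * 1/2)
print(neg_log_prob(logits, [3, 2], actions))  # tensor([1.7918])
```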
classmethod required_logits_shape(action_space: gym.spaces.MultiDiscrete) → Sequence[int]
Implementation of the TorchProbabilityDistribution interface.
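Since the logits of all sub spaces are concatenated, the required logits shape plausibly has a single dimension equal to the total number of discrete choices. The sketch assumes exactly that; the stand-alone function takes a plain nvec sequence (a hypothetical stand-in for action_space.nvec) instead of a gym space.

```python
from typing import Sequence


def required_logits_shape(nvec: Sequence[int]) -> Sequence[int]:
    """Hypothetical sketch: one flat dimension holding all sub-space logits."""
    return (int(sum(nvec)),)


print(required_logits_shape([3, 2, 4]))  # (9,)
```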
sample() → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the TorchProbabilityDistribution interface.
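In contrast to deterministic_sample, sample draws stochastically. A minimal sketch, assuming one independent draw per sub space stacked along the last dimension; the helper and its nvec argument are hypothetical.

```python
import torch
from torch.distributions import Categorical


def sample(logits, nvec):
    """Hypothetical sketch: draw one action index per sub space."""
    chunks = torch.split(logits, list(nvec), dim=-1)
    return torch.stack([Categorical(logits=c).sample() for c in chunks], dim=-1)


torch.manual_seed(0)
logits = torch.zeros(8, 5)  # batch of 8, MultiDiscrete([3, 2])
actions = sample(logits, [3, 2])
print(actions.shape)  # torch.Size([8, 2])
```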