TorchProbabilityDistribution
- class maze.distributions.torch_dist.TorchProbabilityDistribution(dist: T, action_space: gymnasium.spaces.Space)
Base class for wrapping Torch probability distributions.
- Parameters:
dist – The torch probability distribution.
action_space – The gym action space.
- entropy() → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the ProbabilityDistribution interface.
- kl(other: TorchProbabilityDistribution) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the ProbabilityDistribution interface.
- log_prob(actions: torch.Tensor) → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the ProbabilityDistribution interface.
- abstract classmethod required_logits_shape(action_space: gymnasium.spaces.Space) → Sequence[int]
Returns the required shape for the corresponding neural network logits output.
- Parameters:
action_space – The respective action space to compute logits for.
- Returns:
The required logits shape.
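One way a subclass might satisfy this abstract classmethod is sketched below. It deliberately uses stand-in classes instead of the Maze or Gymnasium types: `FakeDiscrete`, `SketchDistribution` and `SketchCategorical` are hypothetical names, and the rule that a categorical distribution needs one logit per discrete action is an assumption made for illustration, not something taken from the Maze source.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence


@dataclass
class FakeDiscrete:
    """Hypothetical stand-in for a discrete action space with n actions."""
    n: int


class SketchDistribution(ABC):
    """Hypothetical base class mirroring the abstract classmethod above."""

    @classmethod
    @abstractmethod
    def required_logits_shape(cls, action_space) -> Sequence[int]:
        """Return the shape the policy network's logits output must have."""


class SketchCategorical(SketchDistribution):
    """Assumption: a categorical distribution needs one logit per action."""

    @classmethod
    def required_logits_shape(cls, action_space: FakeDiscrete) -> Sequence[int]:
        return (action_space.n,)


# The shape can be queried on the class itself, before any network
# or distribution instance has been built.
logits_shape = SketchCategorical.required_logits_shape(FakeDiscrete(n=4))
print(logits_shape)
```

Because the method is a classmethod, a model builder could in principle size its output layer from the action space alone, which is presumably why the shape query is not tied to a distribution instance.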
- sample() → torch.Tensor
(overrides ProbabilityDistribution) Implementation of the ProbabilityDistribution interface.
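
To make the four interface methods documented above more concrete, here is a minimal sketch of a wrapper that delegates them to a `torch.distributions` object. `SketchTorchDistribution` is a hypothetical class written for this example. It omits the `action_space` argument and is not the Maze implementation, whose internals may differ. Only standard PyTorch calls are used (`Categorical`, `kl_divergence`, `entropy`, `log_prob`, `sample`).

```python
import torch
from torch.distributions import Categorical, Distribution, kl_divergence


class SketchTorchDistribution:
    """Hypothetical wrapper around a torch distribution (not the Maze class)."""

    def __init__(self, dist: Distribution):
        self.dist = dist

    def entropy(self) -> torch.Tensor:
        # Delegate to the wrapped torch distribution.
        return self.dist.entropy()

    def kl(self, other: "SketchTorchDistribution") -> torch.Tensor:
        # KL(self || other), computed by PyTorch's registered KL rules.
        return kl_divergence(self.dist, other.dist)

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        return self.dist.log_prob(actions)

    def sample(self) -> torch.Tensor:
        return self.dist.sample()


# Two batches of categorical distributions over 3 actions (batch size 2).
p = SketchTorchDistribution(Categorical(logits=torch.zeros(2, 3)))
q = SketchTorchDistribution(
    Categorical(logits=torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
)

actions = p.sample()             # one action index per batch element
log_probs = p.log_prob(actions)  # log-probability of each sampled action
entropy = p.entropy()            # per-element entropy of p
kl_pq = p.kl(q)                  # per-element KL(p || q)
```

In this sketch every method returns one value per batch element, so all four results have shape `(2,)`.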