Abstract Distributions
distreqx.distributions._distribution.AbstractDistribution
Base class for all distreqx distributions.
log_prob(self, value: PyTree[Array]) -> PyTree[Array]
abstractmethod
Calculates the log probability of an event.
Arguments:
- value: An event.
Returns:
- The log probability log P(value).
prob(self, value: PyTree[Array]) -> PyTree[Array]
abstractmethod
Calculates the probability of an event.
Arguments:
- value: An event.
Returns:
- The probability P(value).
cdf(self, value: PyTree[Array]) -> PyTree[Array]
abstractmethod
Evaluates the cumulative distribution function at value.
Arguments:
- value: An event.
Returns:
- The CDF evaluated at value, i.e. P[X <= value].
survival_function(self, value: PyTree[Array]) -> PyTree[Array]
abstractmethod
Evaluates the survival function at value.
Note that by default the survival function is defined in terms of the CDF, as 1 - CDF(value), which is not necessarily numerically stable. More stable definitions should be implemented in subclasses for distributions where they exist.
Arguments:
- value: An event.
Returns:
- The survival function evaluated at value, i.e. P[X > value].
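As an illustration of why the CDF-based default can lose precision, here is a minimal sketch in plain JAX (not distreqx code) for the standard normal, using jax.scipy.special.ndtr as its CDF. For a large value, 1 - CDF(value) rounds to zero, while the symmetry P[X > x] = P[X < -x] lets the tail probability be evaluated directly.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr

x = jnp.float32(10.0)

# Default-style definition: survival = 1 - CDF.
# ndtr(10.0) rounds to exactly 1.0, so the tail mass is lost.
naive_sf = 1.0 - ndtr(x)

# Stable alternative for the standard normal: by symmetry, P[X > x] = P[X < -x].
stable_sf = ndtr(-x)

print(naive_sf, stable_sf)
```

A subclass for a specific distribution would override survival_function with a definition of this second kind.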
log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]
abstractmethod
Evaluates the log of the survival function at value.
Note that by default the log of the survival function is defined in terms of the CDF, as log(1 - CDF(value)), which is not necessarily numerically stable. More stable definitions should be implemented in subclasses for distributions where they exist.
Arguments:
- value: An event.
Returns:
- The log of the survival function evaluated at value, i.e. log P[X > value].
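The same effect is more severe in log space, where the CDF-based default returns -inf. A minimal sketch in plain JAX (not distreqx code) for the standard normal, comparing log1p(-CDF) with jax.scipy.special.log_ndtr evaluated at -value:

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, log_ndtr

x = jnp.float32(10.0)

# Default-style definition: log(1 - CDF). The CDF has already rounded to 1.0,
# so log1p(-1.0) evaluates to -inf.
naive_log_sf = jnp.log1p(-ndtr(x))

# Stable alternative for the standard normal: log P[X > x] = log P[X < -x],
# and log_ndtr evaluates the log-CDF directly in the far tail.
stable_log_sf = log_ndtr(-x)

print(naive_log_sf, stable_log_sf)  # the stable value is roughly -53.2
```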
kl_divergence(self, other_dist, **kwargs) -> PyTree[Array]
abstractmethod
Calculates the KL divergence to another distribution.
Arguments:
- other_dist: A compatible distreqx distribution.
- kwargs: Additional kwargs.
Returns:
- The KL divergence KL(self || other_dist).
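For intuition, KL(p || q) is the expectation under p of log p(x) - log q(x). A minimal sketch for two univariate normals, comparing the known closed form with a Monte Carlo estimate of that expectation; normal_kl and normal_kl_mc are hypothetical helpers written in plain JAX, not part of distreqx.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def normal_kl(loc_p, scale_p, loc_q, scale_q):
    """Closed-form KL(N(loc_p, scale_p^2) || N(loc_q, scale_q^2))."""
    var_ratio = (scale_p / scale_q) ** 2
    mean_term = ((loc_p - loc_q) / scale_q) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - jnp.log(var_ratio))

def normal_kl_mc(key, loc_p, scale_p, loc_q, scale_q, n=100_000):
    """Monte Carlo estimate of E_p[log p(x) - log q(x)] from n samples of p."""
    x = loc_p + scale_p * jax.random.normal(key, (n,))
    return jnp.mean(norm.logpdf(x, loc_p, scale_p) - norm.logpdf(x, loc_q, scale_q))

exact = normal_kl(0.0, 1.0, 1.0, 1.0)
estimate = normal_kl_mc(jax.random.PRNGKey(0), 0.0, 1.0, 1.0, 1.0)
print(exact, estimate)  # exact is 0.5
```

The Monte Carlo route works for any pair with tractable log_prob, while a closed form like normal_kl is what a distribution-specific kl_divergence would typically use when one is available.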
cross_entropy(self, other_dist, **kwargs) -> Array
Calculates the cross entropy to another distribution.
Arguments:
- other_dist: A compatible distreqx distribution.
- kwargs: Additional kwargs.
Returns:
- The cross entropy H(self || other_dist).
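The cross entropy -E_p[log q(x)] decomposes as the entropy of p plus KL(p || q). A minimal sketch checking that identity for two univariate normals in plain JAX; the helper names are hypothetical and this is not how distreqx necessarily computes cross_entropy.

```python
import jax.numpy as jnp

def normal_entropy(scale):
    """Differential entropy of N(loc, scale^2), in nats."""
    return 0.5 * jnp.log(2.0 * jnp.pi * jnp.e * scale**2)

def normal_kl(loc_p, scale_p, loc_q, scale_q):
    """Closed-form KL(N(loc_p, scale_p^2) || N(loc_q, scale_q^2))."""
    var_ratio = (scale_p / scale_q) ** 2
    mean_term = ((loc_p - loc_q) / scale_q) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - jnp.log(var_ratio))

def normal_cross_entropy(loc_p, scale_p, loc_q, scale_q):
    """Closed form of -E_p[log q(x)] for two univariate normals."""
    return 0.5 * jnp.log(2.0 * jnp.pi * scale_q**2) + (
        scale_p**2 + (loc_p - loc_q) ** 2
    ) / (2.0 * scale_q**2)

# H(p, q) = H(p) + KL(p || q)
via_kl = normal_entropy(1.0) + normal_kl(0.0, 1.0, 1.0, 1.0)
direct = normal_cross_entropy(0.0, 1.0, 1.0, 1.0)
print(via_kl, direct)
```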
distreqx.distributions._distribution.AbstractSampleLogProbDistribution (AbstractDistribution)
Abstract distribution + concrete sample_and_log_prob.
distreqx.distributions._distribution.AbstractProbDistribution (AbstractDistribution)
Abstract distribution + concrete prob.
prob(self, value: PyTree[Array]) -> PyTree[Array]
distreqx.distributions._distribution.AbstractCDFDistribution (AbstractDistribution)
Abstract distribution + concrete cdf.
cdf(self, value: PyTree[Array]) -> PyTree[Array]
distreqx.distributions._distribution.AbstractSTDDistribution (AbstractDistribution)
Abstract distribution + concrete stddev.
stddev(self) -> PyTree[Array]
Calculates the standard deviation.
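Each of the classes above adds one concrete method on top of the abstract interface, and a concrete distribution inherits from whichever of them it needs. A minimal sketch of that mixin pattern with plain Python abc classes, assuming the concrete methods are derived as follows: sample_and_log_prob samples and then scores the sample, prob = exp(log_prob), cdf = exp(log_cdf), and stddev = sqrt(variance). All class names and the log_cdf and variance hooks are hypothetical and are not distreqx's actual definitions.

```python
import abc

import jax
import jax.numpy as jnp

class SketchDistribution(abc.ABC):
    """Hypothetical stand-in for the abstract base class."""

    @abc.abstractmethod
    def log_prob(self, value): ...

    @abc.abstractmethod
    def sample(self, key): ...

class SampleLogProbMixin(SketchDistribution):
    # Assumed derivation: draw a sample, then score it with log_prob.
    def sample_and_log_prob(self, key):
        x = self.sample(key)
        return x, self.log_prob(x)

class ProbMixin(SketchDistribution):
    # Assumed derivation: prob = exp(log_prob).
    def prob(self, value):
        return jnp.exp(self.log_prob(value))

class CDFMixin(SketchDistribution):
    @abc.abstractmethod
    def log_cdf(self, value): ...

    # Assumed derivation: cdf = exp(log_cdf).
    def cdf(self, value):
        return jnp.exp(self.log_cdf(value))

class STDMixin(SketchDistribution):
    @abc.abstractmethod
    def variance(self): ...

    # Assumed derivation: stddev = sqrt(variance).
    def stddev(self):
        return jnp.sqrt(self.variance())

class Exponential(SampleLogProbMixin, ProbMixin, CDFMixin, STDMixin):
    """Toy exponential distribution; log_prob and log_cdf assume value >= 0."""

    def __init__(self, rate):
        self.rate = jnp.asarray(rate)

    def log_prob(self, value):
        return jnp.log(self.rate) - self.rate * value

    def sample(self, key):
        return jax.random.exponential(key) / self.rate

    def log_cdf(self, value):
        # log(1 - exp(-rate * value)), using expm1 for accuracy near zero.
        return jnp.log(-jnp.expm1(-self.rate * value))

    def variance(self):
        return 1.0 / self.rate**2

dist = Exponential(rate=2.0)
x, log_p = dist.sample_and_log_prob(jax.random.PRNGKey(0))
print(dist.prob(0.0), dist.cdf(jnp.log(2.0) / 2.0), dist.stddev())
```

Splitting the concrete methods into separate mixins lets a distribution opt out of a default (for example, override cdf with a closed form) without touching the others.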
distreqx.distributions._distribution.AbstractSurivialDistribution (AbstractDistribution)
distreqx.distributions.transformed.AbstractTransformed (AbstractSurivialDistribution, AbstractProbDistribution)
Distribution of a random variable transformed by a bijective function.
Let X be a continuous random variable and Y = f(X) be a random variable transformed by a differentiable bijection f (a "bijector"). Given the distribution of X (the "base distribution") and the bijector f, this class implements the distribution of Y (also known as the pushforward of the base distribution through f).
The probability density of Y can be computed by:
log p(y) = log p(x) - log|det J(f)(x)|
where p(x) is the probability density of X (the "base density") and J(f)(x) is the Jacobian matrix of f, both evaluated at x = f^{-1}(y).
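A minimal sketch of the density formula above for a hypothetical affine bijector f(x) = A x + b applied to a two-dimensional standard normal, written in plain JAX rather than with distreqx classes. Because the pushforward of N(0, I) through f is N(b, A A^T), the result can be checked against jax.scipy.stats.multivariate_normal.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

# Hypothetical bijector f(x) = A @ x + b with an invertible matrix A.
A = jnp.array([[2.0, 0.0], [0.5, 1.5]])
b = jnp.array([1.0, -1.0])

def transformed_log_prob(y):
    # Invert the bijector: x = f^{-1}(y).
    x = jnp.linalg.solve(A, y - b)
    # Base density: independent standard normals.
    log_px = jnp.sum(norm.logpdf(x))
    # The Jacobian of an affine map is A at every x, so log|det J(f)(x)| = log|det A|.
    _, log_abs_det = jnp.linalg.slogdet(A)
    return log_px - log_abs_det

y = jnp.array([0.3, 0.7])
log_py = transformed_log_prob(y)

# Reference: the pushforward of N(0, I) through f is N(b, A A^T).
reference = multivariate_normal.logpdf(y, mean=b, cov=A @ A.T)
print(log_py, reference)
```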
Sampling from a Transformed distribution involves two steps: sampling from the base distribution x ~ p(x) and then evaluating y = f(x). For example:
```python
import jax
import jax.numpy as jnp
import distrax

dist = distrax.Normal(loc=0., scale=1.)
bij = distrax.ScalarAffine(shift=jnp.asarray([3., 3., 3.]))
transformed_dist = distrax.Transformed(distribution=dist, bijector=bij)
samples = transformed_dist.sample(jax.random.PRNGKey(0))
print(samples)  # [2.7941577, 2.7941577, 2.7941577]
```
This assumes that the forward function of the bijector is traceable; that is, it is a pure function that does not contain run-time branching. Functions that do not strictly meet this requirement can still be used, but we cannot guarantee that the shapes, dtype, and KL computations involving the transformed distribution will be correct.
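One way to see whether a forward function is traceable is jax.eval_shape, which infers output shapes and dtypes by tracing the function without running the computation. A minimal sketch with a hypothetical piecewise-linear forward map: the version with a Python if fails to trace, while the jnp.where version does. Whether distreqx itself relies on jax.eval_shape for its shape inference is an assumption here.

```python
import jax
import jax.numpy as jnp

def forward_branching(x):
    # Run-time branch in Python: needs a concrete value of x, so it cannot be traced.
    if x > 0:
        return x
    return 0.5 * x

def forward_traceable(x):
    # Same map expressed with jnp.where, which stays inside the traced computation.
    return jnp.where(x > 0, x, 0.5 * x)

spec = jax.ShapeDtypeStruct((), jnp.float32)

# Shape and dtype are inferred without evaluating the function on real data.
traceable_out = jax.eval_shape(forward_traceable, spec)
print(traceable_out)

try:
    jax.eval_shape(forward_branching, spec)
    branching_traced = True
except jax.errors.ConcretizationTypeError:
    # Converting a traced value to a Python bool raises a subclass of this error.
    branching_traced = False
print("branching version traced:", branching_traced)
```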
distribution
property
readonly
The base distribution.
bijector
property
readonly
The bijector representing the transformation.
dtype: dtype
property
readonly
See Distribution.dtype.
event_shape: tuple
property
readonly
See Distribution.event_shape.
log_prob(self, value: Array) -> Array
See Distribution.log_prob.
sample(self, key: PRNGKeyArray) -> Array
Return a sample.
sample_and_log_prob(self, key: PRNGKeyArray) -> tuple
Return a sample and its log probability.
This function is more efficient than calling sample and log_prob separately, because it uses only the forward methods of the bijector. It also works for bijectors that don't implement inverse methods.
Arguments:
- key: PRNG key.
Returns:
- A tuple of a sample and its log probability.
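A minimal sketch of why only forward methods are needed: the base sample x is already at hand, so log p(y) = log p(x) - log|det J(f)(x)| can be evaluated without computing f^{-1}(y). The elementwise exp bijector and the helper names below are hypothetical, written in plain JAX; exp is chosen only because its inverse is easy to write for the cross-check at the end.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical elementwise bijector f(x) = exp(x), using forward methods only.
def forward(x):
    return jnp.exp(x)

def forward_log_det_jacobian(x):
    # d/dx exp(x) = exp(x), so for an elementwise map log|det J(f)(x)| = sum(x).
    return jnp.sum(x)

def sample_and_log_prob(key, shape=(3,)):
    x = jax.random.normal(key, shape)  # x ~ p(x), standard normal base
    y = forward(x)                     # y = f(x)
    log_py = jnp.sum(norm.logpdf(x)) - forward_log_det_jacobian(x)
    return y, log_py

y, log_py = sample_and_log_prob(jax.random.PRNGKey(0))

# Cross-check through the inverse path: x = log(y), then the same formula.
x_rec = jnp.log(y)
log_py_inverse = jnp.sum(norm.logpdf(x_rec)) - jnp.sum(x_rec)
print(log_py, log_py_inverse)
```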
entropy(self, input_hint: Optional[Array] = None) -> Array
Calculates the Shannon entropy (in nats).
Only works for bijectors with a constant Jacobian determinant.
Arguments:
- input_hint: An example sample from the base distribution, used to compute the constant forward log-determinant. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.
Returns:
- The entropy of the distribution.
Raises:
- NotImplementedError: If the bijector's Jacobian determinant is not known to be constant.
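When log|det J(f)(x)| equals the same constant c for every x, the change-of-variables formula gives H(Y) = H(X) + c, which is why a single input_hint is enough. A minimal sketch for a hypothetical scalar affine bijector y = scale * x + shift applied to a standard normal, checked against the known entropy of N(shift, scale^2); this is plain JAX, not the distreqx implementation.

```python
import jax.numpy as jnp

def base_entropy():
    """Entropy of the standard normal base distribution, in nats."""
    return 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)

def affine_entropy(scale):
    # For y = scale * x + shift, log|det J| = log|scale| at every x,
    # so H(Y) = H(X) + log|scale|; the shift does not change the entropy.
    return base_entropy() + jnp.log(jnp.abs(scale))

scale = 3.0
# Reference: Y ~ N(shift, scale^2) has entropy 0.5 * log(2 * pi * e * scale^2).
reference = 0.5 * jnp.log(2.0 * jnp.pi * jnp.e * scale**2)
print(affine_entropy(scale), reference)
```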
distreqx.distributions.mvn_from_bijector.AbstractMultivariateNormalFromBijector (AbstractTransformed)
AbstractMultivariateNormalFromBijector()
scale: AbstractLinearBijector
property
readonly
The scale bijector.
loc: Array
property
readonly
The loc parameter of the distribution.
covariance(self) -> Array
Calculates the covariance matrix.
Returns:
- The covariance matrix, of shape k x k, where k is the event dimension.
variance(self) -> Array
Calculates the variance of all one-dimensional marginals.
stddev(self) -> Array
Calculates the standard deviation (the square root of the variance).
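Assuming the scale bijector acts as a matrix L, so that Y = loc + L X with X ~ N(0, I), the covariance, variance, and standard deviation all follow from L. A minimal sketch in plain JAX with a hypothetical lower-triangular L:

```python
import jax.numpy as jnp

# Hypothetical lower-triangular scale matrix L: Y = loc + L @ X with X ~ N(0, I).
L = jnp.array([[1.0, 0.0, 0.0],
               [0.5, 2.0, 0.0],
               [-1.0, 0.3, 0.7]])

covariance = L @ L.T             # Cov[Y] = L Cov[X] L^T = L L^T, shape (k, k)
variance = jnp.diag(covariance)  # variances of the one-dimensional marginals
stddev = jnp.sqrt(variance)      # square root of the variance

# The marginal variances are the squared row norms of L, so the full
# k x k covariance is not needed just to get the variance.
variance_from_rows = jnp.sum(L**2, axis=1)
print(covariance, variance, stddev)
```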
kl_divergence(self, other_dist, **kwargs) -> Array
Calculates the KL divergence to another distribution.
Arguments:
- other_dist: A compatible distreqx distribution.
- kwargs: Additional kwargs.
Returns:
- The KL divergence KL(self || other_dist).
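Between two multivariate normals the KL divergence has a closed form: KL(N(m_p, S_p) || N(m_q, S_q)) = 1/2 [tr(S_q^-1 S_p) + (m_q - m_p)^T S_q^-1 (m_q - m_p) - k + log det S_q - log det S_p]. A minimal sketch of that formula with a hypothetical mvn_kl helper on dense covariance matrices, in plain JAX; it is not distreqx's implementation, which may work with the scale bijectors directly.

```python
import jax.numpy as jnp

def mvn_kl(loc_p, cov_p, loc_q, cov_q):
    """Closed-form KL(N(loc_p, cov_p) || N(loc_q, cov_q))."""
    k = loc_p.shape[-1]
    diff = loc_q - loc_p
    # Use linear solves instead of explicitly inverting cov_q.
    trace_term = jnp.trace(jnp.linalg.solve(cov_q, cov_p))
    mahalanobis = diff @ jnp.linalg.solve(cov_q, diff)
    _, logdet_p = jnp.linalg.slogdet(cov_p)
    _, logdet_q = jnp.linalg.slogdet(cov_q)
    return 0.5 * (trace_term + mahalanobis - k + logdet_q - logdet_p)

cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
loc = jnp.array([0.5, -0.5])
self_kl = mvn_kl(loc, cov, loc, cov)  # KL of a distribution with itself is 0
print(self_kl)
```

With diagonal covariances the formula reduces to a sum of independent univariate KL terms, which is a convenient sanity check.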