
Abstract Distributions

distreqx.distributions._distribution.AbstractDistribution

Base class for all distreqx distributions.

log_prob(self, value: PyTree[Array]) -> PyTree[Array] abstractmethod

Calculates the log probability of an event.

Arguments:

  • value: An event.

Returns:

  • The log probability log P(value).
prob(self, value: PyTree[Array]) -> PyTree[Array] abstractmethod

Calculates the probability of an event.

Arguments:

  • value: An event.

Returns:

  • The probability P(value).
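The two methods describe the same quantity on different scales: prob(value) = exp(log_prob(value)). The following pure-Python sketch uses a univariate normal to show why log space is preferred in the tails. The helper names are hypothetical and are not part of the distreqx API.

```python
import math

def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of a univariate normal, computed directly in log space."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def normal_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Density obtained by exponentiating the log density."""
    return math.exp(normal_log_prob(x, loc, scale))

# At x = 0 the standard normal density is 1 / sqrt(2 * pi).
print(normal_prob(0.0))
# Far in the tail the density underflows to 0.0, while the log density stays finite.
print(normal_prob(40.0), normal_log_prob(40.0))
```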
cdf(self, value: PyTree[Array]) -> PyTree[Array] abstractmethod

Evaluates the cumulative distribution function at value.

Arguments:

  • value: An event.

Returns:

  • The CDF evaluated at value, i.e. P[X <= value].
survival_function(self, value: PyTree[Array]) -> PyTree[Array] abstractmethod

Evaluates the survival function at value.

Note that the default implementation defines the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should override it with a more stable definition where one exists for the distribution.

Arguments:

  • value: An event.

Returns:

  • The survival function evaluated at value, i.e. P[X > value].
log_survival_function(self, value: PyTree[Array]) -> PyTree[Array] abstractmethod

Evaluates the log of the survival function at value.

Note that the default implementation defines the log of the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should override it with a more stable definition where one exists for the distribution.

Arguments:

  • value: An event.

Returns:

  • The log of the survival function evaluated at value, i.e. log P[X > value].
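To see why a definition in terms of the CDF, such as 1 - CDF, can be unstable, consider a standard normal. In the upper tail the CDF rounds to 1.0, so the subtraction cancels catastrophically. The complementary error function gives the survival function directly and avoids this. The following is a pure-Python sketch with hypothetical helper names, not distreqx code.

```python
import math

SQRT2 = math.sqrt(2.0)

def normal_cdf(x: float) -> float:
    """Standard normal CDF, P[X <= x]."""
    return 0.5 * math.erfc(-x / SQRT2)

def naive_survival(x: float) -> float:
    """Survival function defined as 1 - CDF; loses precision in the upper tail."""
    return 1.0 - normal_cdf(x)

def stable_survival(x: float) -> float:
    """Survival function via the complementary error function; no cancellation."""
    return 0.5 * math.erfc(x / SQRT2)

def stable_log_survival(x: float) -> float:
    """Log survival function built on the stable survival function."""
    return math.log(stable_survival(x))

# At x = 10 the CDF rounds to 1.0, so 1 - CDF is exactly 0.0 (and its log
# would be -inf), while the erfc-based value is still a small positive number.
for x in (1.0, 5.0, 10.0):
    print(x, naive_survival(x), stable_survival(x))
```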
kl_divergence(self, other_dist, **kwargs) -> PyTree[Array] abstractmethod

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.
  • kwargs: Additional kwargs.

Returns:

  • The KL divergence KL(self || other_dist).
cross_entropy(self, other_dist, **kwargs) -> Array

Calculates the cross entropy to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.
  • kwargs: Additional kwargs.

Returns:

  • The cross entropy H(self || other_dist).
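Cross entropy, entropy and KL divergence satisfy H(p || q) = H(p) + KL(p || q). A concrete cross_entropy can therefore be assembled from entropy and kl_divergence. The sketch below assumes that approach and checks the identity for two univariate normals using closed-form expressions. All helper names are hypothetical.

```python
import math

def normal_entropy(scale: float) -> float:
    """Differential entropy of N(loc, scale**2) in nats (independent of loc)."""
    return 0.5 * math.log(2.0 * math.pi * math.e * scale**2)

def normal_kl(loc1: float, scale1: float, loc2: float, scale2: float) -> float:
    """Closed-form KL(N(loc1, scale1**2) || N(loc2, scale2**2))."""
    return (math.log(scale2 / scale1)
            + (scale1**2 + (loc1 - loc2) ** 2) / (2.0 * scale2**2)
            - 0.5)

def normal_cross_entropy(loc1: float, scale1: float, loc2: float, scale2: float) -> float:
    """H(p || q) = E_p[-log q(X)], computed directly from the Gaussian densities."""
    return (0.5 * math.log(2.0 * math.pi * scale2**2)
            + (scale1**2 + (loc1 - loc2) ** 2) / (2.0 * scale2**2))

p, q = (0.0, 1.0), (1.0, 2.0)
via_identity = normal_entropy(p[1]) + normal_kl(*p, *q)
print(via_identity, normal_cross_entropy(*p, *q))  # the two values agree
```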

distreqx.distributions._distribution.AbstractSampleLogProbDistribution (AbstractDistribution)

Abstract distribution + concrete sample_and_log_prob.

distreqx.distributions._distribution.AbstractProbDistribution (AbstractDistribution)

Abstract distribution + concrete prob.

prob(self, value: PyTree[Array]) -> PyTree[Array]

distreqx.distributions._distribution.AbstractCDFDistribution (AbstractDistribution)

Abstract distribution + concrete cdf.

cdf(self, value: PyTree[Array]) -> PyTree[Array]

distreqx.distributions._distribution.AbstractSTDDistribution (AbstractDistribution)

Abstract distribution + concrete stddev.

stddev(self) -> PyTree[Array]

Calculate the standard deviation.

distreqx.distributions._distribution.AbstractSurivialDistribution (AbstractDistribution)

Abstract distribution + concrete survival_function and log_survival_function.

survival_function(self, value: PyTree[Array]) -> PyTree[Array]
log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]

distreqx.distributions.transformed.AbstractTransformed (AbstractSurivialDistribution, AbstractProbDistribution)

Distribution of a random variable transformed by a bijective function.

Let X be a continuous random variable and Y = f(X) be a random variable transformed by a differentiable bijection f (a "bijector"). Given the distribution of X (the "base distribution") and the bijector f, this class implements the distribution of Y (also known as the pushforward of the base distribution through f).

The probability density of Y can be computed by:

log p(y) = log p(x) - log|det J(f)(x)|

where p(x) is the probability density of X (the "base density") and J(f)(x) is the Jacobian matrix of f, both evaluated at x = f^{-1}(y).
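As a worked check of this formula, take X ~ N(0, 1) and the bijector f(x) = exp(x). Then log|det J(f)(x)| = x, and the pushforward is the standard log-normal distribution, so the change-of-variables result must match the textbook log-normal density. This is a pure-Python sketch with hypothetical helper names.

```python
import math

def std_normal_log_prob(x: float) -> float:
    """Log density of the base distribution X ~ N(0, 1)."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def transformed_log_prob(y: float) -> float:
    """log p(y) for Y = exp(X), via log p(y) = log p(x) - log|det J(f)(x)|."""
    x = math.log(y)          # x = f^{-1}(y)
    log_det_jacobian = x     # log|f'(x)| = log(exp(x)) = x
    return std_normal_log_prob(x) - log_det_jacobian

def lognormal_log_prob(y: float) -> float:
    """Textbook log-normal log density with mu = 0, sigma = 1, for comparison."""
    return -math.log(y) - 0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(y) ** 2

print(transformed_log_prob(2.5), lognormal_log_prob(2.5))  # identical up to rounding
```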

Sampling from a Transformed distribution involves two steps: sampling from the base distribution x ~ p(x) and then evaluating y = f(x). For example:

  import jax
  import jax.numpy as jnp
  from distreqx import bijectors, distributions

  dist = distributions.Normal(loc=0., scale=1.)
  bij = bijectors.ScalarAffine(shift=jnp.asarray([3., 3., 3.]))
  transformed_dist = distributions.Transformed(distribution=dist, bijector=bij)
  samples = transformed_dist.sample(jax.random.PRNGKey(0))
  print(samples)  # three equal values: one scalar base sample shifted by 3

This assumes that the forward function of the bijector is traceable, that is, a pure function without run-time branching. Bijectors that do not strictly meet this requirement can still be used, but the shapes, dtype, and KL computations involving the transformed distribution are not guaranteed to be correct.

distribution property readonly

The base distribution.

bijector property readonly

The bijector representing the transformation.

dtype: dtype property readonly

See Distribution.dtype.

event_shape: tuple property readonly

See Distribution.event_shape.

log_prob(self, value: Array) -> Array

See Distribution.log_prob.

sample(self, key: PRNGKeyArray) -> Array

Return a sample.

sample_and_log_prob(self, key: PRNGKeyArray) -> tuple

Return a sample and log prob.

This function is more efficient than calling sample and log_prob separately, because it uses only the forward methods of the bijector. It also works for bijectors that don't implement inverse methods.

Arguments:

  • key: PRNG key.

Returns:

  • A tuple of a sample and its log probability.
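The efficiency argument can be illustrated with a scalar sketch. The forward path evaluates f and its log-determinant once at the base sample x. The log_prob path must first invert the sample. The tanh bijector and all helper names below are hypothetical illustrations, not distreqx API. The check confirms that both paths yield the same log density.

```python
import math
import random

def base_log_prob(x: float) -> float:
    """Standard normal log density (the base distribution)."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def forward_and_log_det(x: float) -> tuple[float, float]:
    """Hypothetical bijector f(x) = tanh(x) and log|f'(x)| = log(1 - tanh(x)**2)."""
    y = math.tanh(x)
    return y, math.log1p(-y * y)

def inverse_and_log_det(y: float) -> tuple[float, float]:
    """Inverse of tanh and log|(f^{-1})'(y)| = -log(1 - y**2)."""
    x = math.atanh(y)
    return x, -math.log1p(-y * y)

def sample_and_log_prob(rng: random.Random) -> tuple[float, float]:
    """Sample x from the base, push it forward, and correct log p(x); forward only."""
    x = rng.gauss(0.0, 1.0)
    y, fldj = forward_and_log_det(x)
    return y, base_log_prob(x) - fldj

def log_prob(y: float) -> float:
    """The same density evaluated through the inverse, as log_prob must do."""
    x, ildj = inverse_and_log_det(y)
    return base_log_prob(x) + ildj

y, lp = sample_and_log_prob(random.Random(0))
print(y, lp, log_prob(y))  # lp and log_prob(y) agree up to rounding
```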
entropy(self, input_hint: Optional[Array] = None) -> Array

Calculates the Shannon entropy (in nats).

Only works for bijectors with a constant Jacobian determinant.

Arguments:

  • input_hint: an example sample from the base distribution, used to compute the constant forward log-determinant. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.

Returns:

  • The entropy of the distribution.

Raises:

  • NotImplementedError: if the bijector's Jacobian determinant is not known to be constant.
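The constant-Jacobian requirement follows from the change-of-variables formula. It implies H(Y) = H(X) + E[log|det J(f)(X)|]. Only when the log-determinant does not depend on x can the expectation be replaced by a single evaluation, which is what the input_hint is used for. The following pure-Python sketch uses a scalar affine bijector. The helper names are hypothetical.

```python
import math

def normal_entropy(scale: float) -> float:
    """Entropy in nats of a univariate normal with the given scale."""
    return 0.5 * math.log(2.0 * math.pi * math.e * scale**2)

def affine_transformed_entropy(base_scale: float, a: float, b: float) -> float:
    """Entropy of Y = a * X + b for X ~ N(0, base_scale**2).

    The forward log-determinant log|a| is the same for every x, so evaluating it
    at a single point (e.g. zero) is exact. The shift b does not affect entropy.
    """
    forward_log_det = math.log(abs(a))
    return normal_entropy(base_scale) + forward_log_det

# Y = -3 X + 2 with X ~ N(0, 1) is N(2, 9), so both numbers should agree.
print(affine_transformed_entropy(1.0, -3.0, 2.0), normal_entropy(3.0))
```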

distreqx.distributions.mvn_from_bijector.AbstractMultivariateNormalFromBijector (AbstractTransformed)

AbstractMultivariateNormalFromBijector()

scale: AbstractLinearBijector property readonly

The scale bijector.

loc: Array property readonly

The loc parameter of the distribution.

covariance(self) -> Array

Calculates the covariance matrix.

Returns:

  • The covariance matrix, of shape k x k.

variance(self) -> Array

Calculates the variance of all one-dimensional marginals.

stddev(self) -> Array

Calculates the standard deviation (the square root of the variance).
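The sketch below assumes the distribution is defined as Y = loc + S X with X ~ N(0, I), where S is the matrix of the linear scale bijector. Under that assumption, the covariance is S S^T, the marginal variances are its diagonal, and the standard deviations are their square roots. The matrix is a hypothetical example, and the code is pure Python.

```python
import math

def matmul_transpose(m: list[list[float]]) -> list[list[float]]:
    """Return M @ M.T for a square matrix given as nested lists."""
    n = len(m)
    return [[sum(m[i][k] * m[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]

# Hypothetical lower-triangular scale matrix S (k = 2).
scale_matrix = [[2.0, 0.0],
                [1.0, 3.0]]

covariance = matmul_transpose(scale_matrix)                     # S S^T, shape k x k
variance = [covariance[i][i] for i in range(len(covariance))]   # diagonal entries
stddev = [math.sqrt(v) for v in variance]                       # elementwise sqrt

print(covariance)  # [[4.0, 2.0], [2.0, 10.0]]
print(variance)    # [4.0, 10.0]
print(stddev)
```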

kl_divergence(self, other_dist, **kwargs) -> Array

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible disteqx distribution.
  • kwargs: Additional kwargs.

Returns:

  • The KL divergence KL(self || other_dist).
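Between two multivariate normals, the KL divergence has the closed form 0.5 [tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1) - k + log(det S2 / det S1)]. For brevity, the following pure-Python sketch specialises it to diagonal covariances. The helper name is hypothetical, and the sketch does not claim to reproduce how any particular distreqx subclass computes its KL.

```python
import math

def kl_mvn_diag(loc1: list[float], scale1: list[float],
                loc2: list[float], scale2: list[float]) -> float:
    """KL(N(loc1, diag(scale1**2)) || N(loc2, diag(scale2**2)))."""
    k = len(loc1)
    # tr(S2^-1 S1) for diagonal covariances.
    trace_term = sum((s1 / s2) ** 2 for s1, s2 in zip(scale1, scale2))
    # (m2 - m1)^T S2^-1 (m2 - m1).
    mahalanobis = sum(((m2 - m1) / s2) ** 2 for m1, m2, s2 in zip(loc1, loc2, scale2))
    # log(det S2 / det S1), using det = product of squared scales.
    log_det_ratio = sum(2.0 * math.log(s2 / s1) for s1, s2 in zip(scale1, scale2))
    return 0.5 * (trace_term + mahalanobis - k + log_det_ratio)

# Standard 2-D normal against one with scale 2 in each dimension.
print(kl_mvn_diag([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [2.0, 2.0]))
```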