Abstract Distributions¤

distreqx.distributions.AbstractDistribution ¤

Base class for all distreqx distributions.

log_prob(value: PyTree[Array]) -> PyTree[Array] ¤

Calculates the log probability of an event.

Arguments:

  • value: An event.

Returns:

  • The log probability log P(value).

prob(value: PyTree[Array]) -> PyTree[Array] ¤

Calculates the probability of an event.

Arguments:

  • value: An event.

Returns:

  • The probability P(value).

cdf(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the cumulative distribution function at value.

Arguments:

  • value: An event.

Returns:

  • The CDF evaluated at value, i.e. P[X <= value].

survival_function(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the survival function at value.

Note that the default implementation defines the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The survival function evaluated at value, i.e. P[X > value].

log_survival_function(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the log of the survival function at value.

Note that the default implementation defines the log of the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The log of the survival function evaluated at value, i.e. log P[X > value].

kl_divergence(other_dist, **kwargs) -> PyTree[Array] ¤

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.
  • kwargs: Additional kwargs.

Returns:

  • The KL divergence KL(self || other_dist).

cross_entropy(other_dist, **kwargs) -> Array ¤

Calculates the cross entropy to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.
  • kwargs: Additional kwargs.

Returns:

  • The cross entropy H(self || other_dist).
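The KL divergence and cross entropy described above are tied together by the standard identity H(p, q) = H(p) + KL(p || q). The following is a minimal sketch in plain Python that checks the identity on two small discrete distributions. It does not call distreqx, the helper names are hypothetical, and whether the library computes cross_entropy this way is an assumption not confirmed here.

```python
import math


def entropy(p):
    """Shannon entropy H(p) in nats of a discrete distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(p, q):
    """Cross entropy H(p, q) = -sum_i p_i * log(q_i) in nats."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]

lhs = cross_entropy(p, q)               # H(p, q)
rhs = entropy(p) + kl_divergence(p, q)  # H(p) + KL(p || q)
```

Both sides agree up to floating-point rounding, and the KL divergence of a distribution to itself is zero.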

distreqx.distributions.AbstractSampleLogProbDistribution(distreqx.distributions.AbstractDistribution) ¤

Abstract distribution + concrete sample_and_log_prob.

distreqx.distributions.AbstractProbDistribution(distreqx.distributions.AbstractDistribution) ¤

Abstract distribution + concrete prob.

prob(value: PyTree[Array]) -> PyTree[Array] ¤

Calculates the probability of an event.

Arguments:

  • value: An event.

Returns:

  • The probability P(value).
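"Abstract distribution + concrete prob" describes a mixin-style class: a subclass supplies log_prob and inherits a ready-made prob. The following is a minimal sketch of that pattern using Python's abc module. The class names are hypothetical, and deriving prob as exp(log_prob) is an assumption about how such a mixin could work, not a statement about the library's source.

```python
import math
from abc import ABC, abstractmethod


class SketchDistribution(ABC):
    """Hypothetical base class: every method is abstract."""

    @abstractmethod
    def log_prob(self, value):
        ...

    @abstractmethod
    def prob(self, value):
        ...


class SketchProbMixin(SketchDistribution):
    """Hypothetical mixin: fills in prob, leaves log_prob abstract."""

    def prob(self, value):
        # Assumed definition: exponentiate the log probability.
        return math.exp(self.log_prob(value))


class StandardNormal(SketchProbMixin):
    """Concrete subclass that only has to define log_prob."""

    def log_prob(self, value):
        return -0.5 * value**2 - 0.5 * math.log(2.0 * math.pi)


dist = StandardNormal()
density_at_zero = dist.prob(0.0)  # equals 1 / sqrt(2 * pi)
```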

distreqx.distributions.AbstractCDFDistribution(distreqx.distributions.AbstractDistribution) ¤

Abstract distribution + concrete cdf.

cdf(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the cumulative distribution function at value.

Arguments:

  • value: An event.

Returns:

  • The CDF evaluated at value, i.e. P[X <= value].

distreqx.distributions.AbstractSTDDistribution(distreqx.distributions.AbstractDistribution) ¤

Abstract distribution + concrete stddev.

stddev() -> PyTree[Array] ¤

Calculates the standard deviation.

distreqx.distributions.AbstractSurvivalDistribution(distreqx.distributions.AbstractDistribution) ¤

Abstract distribution + concrete survival_function and log_survival_function.

survival_function(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the survival function at value.

Note that the default implementation defines the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The survival function evaluated at value, i.e. P[X > value].

log_survival_function(value: PyTree[Array]) -> PyTree[Array] ¤

Evaluates the log of the survival function at value.

Note that the default implementation defines the log of the survival function in terms of the CDF, which is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The log of the survival function evaluated at value, i.e. log P[X > value].
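To illustrate the stability note above, the sketch below compares a CDF-based survival function with a direct definition for a standard normal, in plain Python. It is a standalone illustration with hypothetical function names and does not use distreqx. The direct version relies on math.erfc, which is one known stable form for the normal distribution.

```python
import math


def normal_cdf(x):
    """CDF of a standard normal, P[X <= x]."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def survival_via_cdf(x):
    """Generic definition P[X > x] = 1 - CDF(x); loses precision in the tail."""
    return 1.0 - normal_cdf(x)


def survival_stable(x):
    """Direct definition via the complementary error function."""
    return 0.5 * math.erfc(x / math.sqrt(2.0))


x = 10.0
naive = survival_via_cdf(x)   # the CDF rounds to 1.0, so the tail is lost
stable = survival_stable(x)   # keeps the tiny but nonzero tail probability
```

Far in the tail the CDF-based value collapses to exactly zero, so its logarithm is undefined, while the direct form still returns a usable positive number. Near the centre of the distribution the two definitions agree.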

distreqx.distributions.AbstractTransformed(distreqx.distributions.AbstractSurvivalDistribution, distreqx.distributions.AbstractProbDistribution) ¤

Abstract base class for transformed distributions.

See distreqx.distributions.Transformed for full documentation.

dtype property ¤

See Distribution.dtype.

event_shape property ¤

See Distribution.event_shape.

log_prob(value: Array) -> Array ¤

See Distribution.log_prob.

sample(key: Key[Array, '']) -> Array ¤

Return a sample.

sample_and_log_prob(key: Key[Array, '']) -> tuple ¤

Return a sample and log prob.

This function is more efficient than calling sample and log_prob separately, because it uses only the forward methods of the bijector. It also works for bijectors that don't implement inverse methods.

Arguments:

  • key: PRNG key.

Returns:

  • A tuple of a sample and its log probs.
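The forward-only computation described above can be sketched with the change-of-variables rule: draw x from the base distribution, push it through the bijector to get y, and subtract the forward log-determinant from the base log probability. The example below is plain Python with hypothetical helper names, using a standard normal base and an exp bijector (so y is log-normal). It is an illustration of the idea, not the library's implementation.

```python
import math
import random


def base_log_prob(x):
    """Log density of the standard normal base distribution."""
    return -0.5 * x**2 - 0.5 * math.log(2.0 * math.pi)


def forward(x):
    """Bijector forward map y = exp(x)."""
    return math.exp(x)


def forward_log_det_jacobian(x):
    """log |dy/dx| for y = exp(x), which is simply x."""
    return x


def sample_and_log_prob(rng):
    """Sample y and its log density using only forward bijector methods."""
    x = rng.gauss(0.0, 1.0)
    y = forward(x)
    log_prob_y = base_log_prob(x) - forward_log_det_jacobian(x)
    return y, log_prob_y


def lognormal_log_pdf(y):
    """Closed-form log density of the log-normal, used only as a check."""
    log_y = math.log(y)
    return -log_y - 0.5 * math.log(2.0 * math.pi) - 0.5 * log_y**2


y, log_prob_y = sample_and_log_prob(random.Random(0))
```

No inverse of the bijector is evaluated at any point: the base sample x is already known, so the inverse is never needed.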

entropy(input_hint: Array | None = None) -> Array ¤

Calculates the Shannon entropy (in nats).

Only works for bijectors with constant Jacobian determinant.

Arguments:

  • input_hint: an example sample from the base distribution, used to compute the constant forward log-determinant. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.

Returns:

  • The entropy of the distribution.

Raises:

  • NotImplementedError: if bijector's Jacobian determinant is not known to be constant.
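When the Jacobian determinant is constant, the entropy of the transformed variable is the base entropy plus the forward log-determinant, which may be evaluated at any input. The sketch below shows this for a scalar affine map y = scale * x + shift applied to a standard normal, in plain Python. The function names are hypothetical, and the zero default for input_hint mirrors the description above only loosely.

```python
import math


def standard_normal_entropy():
    """Entropy of a standard normal in nats."""
    return 0.5 * math.log(2.0 * math.pi * math.e)


def affine_forward_log_det(scale, x):
    """log |dy/dx| for y = scale * x + shift; constant, so x is ignored."""
    return math.log(abs(scale))


def transformed_entropy(scale, input_hint=0.0):
    """H(Y) = H(X) + log |det J|, with the Jacobian evaluated at input_hint."""
    return standard_normal_entropy() + affine_forward_log_det(scale, input_hint)


scale = 3.0  # the shift does not affect the entropy, so it is omitted
h = transformed_entropy(scale)

# Closed-form entropy of a normal with standard deviation |scale|.
closed_form = 0.5 * math.log(2.0 * math.pi * math.e * scale**2)
```

Because the log-determinant does not depend on the input, any input_hint gives the same result. That is why the method can fall back to a zero array when no hint is provided.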

distreqx.distributions.AbstractMultivariateNormalFromBijector(distreqx.distributions.AbstractTransformed) ¤

Abstract base class for multivariate normal distributions defined via a bijector.

covariance() -> Array ¤

Calculates the covariance matrix.

Returns:

  • The covariance matrix, of shape k x k.

variance() -> Array ¤

Calculates the variance of all one-dimensional marginals.

stddev() -> Array ¤

Calculates the standard deviation (the square root of the variance).
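The three quantities above are related in a simple way: the variances of the one-dimensional marginals are the diagonal of the covariance matrix, and the standard deviation is their element-wise square root. A minimal plain-Python sketch with hypothetical helper names and nested lists in place of arrays:

```python
import math


def variance_from_covariance(cov):
    """Marginal variances are the diagonal entries of the k x k covariance."""
    return [cov[i][i] for i in range(len(cov))]


def stddev_from_variance(var):
    """Standard deviation is the element-wise square root of the variance."""
    return [math.sqrt(v) for v in var]


cov = [
    [4.0, 1.2],
    [1.2, 9.0],
]
var = variance_from_covariance(cov)
std = stddev_from_variance(var)
```

The off-diagonal entries describe dependence between components and play no part in the marginal variances.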

kl_divergence(other_dist, **kwargs) -> Array ¤

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.
  • kwargs: Additional kwargs.

Returns:

  • The KL divergence KL(self || other_dist).
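For normal distributions the KL divergence has a closed form. The sketch below covers only the special case of diagonal covariances, where the divergence is a sum of per-dimension terms. It is written in plain Python with a hypothetical function name, and it is not claimed to be the formula or code path the library uses for general covariance matrices.

```python
import math


def kl_diag_normal(mean_p, std_p, mean_q, std_q):
    """KL(p || q) in nats for normals with diagonal covariance.

    Each dimension contributes
    log(std_q / std_p) + (std_p^2 + (mean_p - mean_q)^2) / (2 * std_q^2) - 1/2.
    """
    total = 0.0
    for mp, sp, mq, sq in zip(mean_p, std_p, mean_q, std_q):
        total += math.log(sq / sp) + (sp**2 + (mp - mq) ** 2) / (2.0 * sq**2) - 0.5
    return total


# Identical distributions: the divergence vanishes.
kl_same = kl_diag_normal([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0])

# Shift the first mean by one standard deviation.
kl_diff = kl_diag_normal([0.0, 1.0], [1.0, 2.0], [1.0, 1.0], [1.0, 2.0])
```

The divergence is zero only when the two distributions coincide, and it is not symmetric in its arguments in general.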