Transformed

distreqx.distributions.transformed.Transformed (AbstractTransformed, AbstractSTDDistribution)

Distribution of a random variable transformed by a bijective function.

Let X be a continuous random variable and Y = f(X) be a random variable transformed by a differentiable bijection f (a "bijector"). Given the distribution of X (the "base distribution") and the bijector f, this class implements the distribution of Y (also known as the pushforward of the base distribution through f).

The probability density of Y can be computed by:

log p(y) = log p(x) - log|det J(f)(x)|

where p(x) is the probability density of X (the "base density") and J(f)(x) is the Jacobian matrix of f, both evaluated at x = f^{-1}(y).
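For a scalar affine bijector the Jacobian determinant is just the scale, so the formula can be verified numerically with plain JAX. The names f_inv, shift, and scale below are illustrative and not part of the distreqx API:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Illustrative scalar affine bijector f(x) = scale * x + shift.
shift, scale = 3.0, 2.0

def f_inv(y):
    return (y - shift) / scale

y = jnp.asarray(4.0)
x = f_inv(y)  # x = f^{-1}(y) = 0.5

# log p(y) = log p(x) - log|det J(f)(x)|; for the affine map the
# Jacobian determinant is the constant `scale`.
log_p_y = norm.logpdf(x) - jnp.log(jnp.abs(scale))

# This agrees with the density of Y ~ N(shift, scale) evaluated directly.
print(log_p_y, norm.logpdf(y, loc=shift, scale=scale))
```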

Sampling from a Transformed distribution involves two steps: sampling from the base distribution x ~ p(x) and then evaluating y = f(x). For example:

  import jax
  import jax.numpy as jnp
  from distreqx.bijectors import ScalarAffine
  from distreqx.distributions import Normal, Transformed

  dist = Normal(loc=0., scale=1.)
  bij = ScalarAffine(shift=jnp.asarray([3., 3., 3.]))
  transformed_dist = Transformed(distribution=dist, bijector=bij)
  samples = transformed_dist.sample(jax.random.PRNGKey(0))
  print(samples)  # [2.7941577, 2.7941577, 2.7941577]

This assumes that the forward function of the bijector is traceable; that is, it is a pure function that does not contain run-time branching. Functions that do not strictly meet this requirement can still be used, but we cannot guarantee that the shapes, dtypes, and KL computations involving the transformed distribution will be correctly obtained.
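A traceable forward function is one that jax.make_jaxpr can stage out as a pure computation. The sketch below uses a hypothetical affine forward function (not the distreqx API) to show the trace from which output shapes and dtypes can be read without drawing a sample:

```python
import jax
import jax.numpy as jnp

# A pure, branch-free forward function stands in for a bijector's
# forward method (illustrative only).
def forward(x):
    return 2.0 * x + 3.0

# Tracing yields a Jaxpr; the output shapes and dtypes are available
# from the trace alone.
jaxpr = jax.make_jaxpr(forward)(jnp.zeros((3,)))
print(jaxpr)
```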

__init__(self, distribution: AbstractDistribution, bijector: AbstractBijector)

Initializes a Transformed distribution.

Arguments:

  • distribution: the base distribution.
  • bijector: a differentiable bijective transformation. Can be a bijector or a callable to be wrapped by the Lambda bijector.

mean(self) -> Array

Calculates the mean.

mode(self) -> Array

Calculates the mode.

kl_divergence(self, other_dist, **kwargs) -> Array

Obtains the KL divergence between two Transformed distributions.

This computes the KL divergence between two Transformed distributions with the same bijector. If the two Transformed distributions do not have the same bijector, an error is raised. To determine if the bijectors are equal, this method proceeds as follows:

  • If both bijectors are the same instance of a distreqx bijector, then they are declared equal.

  • If not the same instance, we check if they are equal according to their same_as predicate.

  • Otherwise, the string representation of the Jaxpr of the forward method of each bijector is compared. If both string representations are equal, the bijectors are declared equal.

  • Otherwise, the bijectors cannot be guaranteed to be equal and an error is raised.
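The final Jaxpr comparison can be sketched in plain JAX: two distinct callables that stage out to the same trace compare equal by string representation. Here f1 and f2 are hypothetical forward methods, not distreqx bijectors:

```python
import jax
import jax.numpy as jnp

def f1(x):
    return x + 1.0

def f2(x):  # a different callable computing the same thing
    return x + 1.0

# Trace both forward methods with an example input (the role played
# by input_hint below).
hint = jnp.zeros((2,))
s1 = str(jax.make_jaxpr(f1)(hint))
s2 = str(jax.make_jaxpr(f2)(hint))
print(s1 == s2)
```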

Arguments:

  • other_dist: A Transformed distribution.
  • input_hint: keyword argument, an example sample from the base distribution, used to trace the forward method. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.

Returns:

  • KL(self || other_dist), the KL divergence from this distribution to other_dist.

Raises:

  • NotImplementedError: If bijectors are not known to be equal.
  • ValueError: If the base distributions do not have the same event_shape.
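Because KL divergence is invariant under a shared bijection, the KL between two Transformed distributions with the same bijector equals the KL between their base distributions. A numeric sketch with the closed-form normal KL (kl_normal is a hypothetical helper, not part of distreqx):

```python
import jax.numpy as jnp

def kl_normal(mu1, s1, mu2, s2):
    # KL(N(mu1, s1^2) || N(mu2, s2^2)) in closed form.
    return jnp.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5

# KL between the base distributions N(0, 1) and N(1, 1).
kl_base = kl_normal(0.0, 1.0, 1.0, 1.0)

# Push both through the same affine bijector y = 2x + 3:
# N(mu, s) maps to N(2*mu + 3, 2*s).
kl_transformed = kl_normal(3.0, 2.0, 5.0, 2.0)

print(float(kl_base), float(kl_transformed))  # both 0.5
```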