Transformed
distreqx.distributions.transformed.Transformed (AbstractTransformed, AbstractSTDDistribution)
Distribution of a random variable transformed by a bijective function.

Let X be a continuous random variable and Y = f(X) be a random variable transformed by a differentiable bijection f (a "bijector"). Given the distribution of X (the "base distribution") and the bijector f, this class implements the distribution of Y (also known as the pushforward of the base distribution through f).
The probability density of Y can be computed by:

log p(y) = log p(x) - log|det J(f)(x)|

where p(x) is the probability density of X (the "base density") and J(f)(x) is the Jacobian matrix of f, both evaluated at x = f^{-1}(y).
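As a quick numerical check of the formula, the sketch below uses plain Python (no distreqx) with the assumed example X ~ N(0, 1) and f = exp, so that Y is log-normal, x = log y, and log|det J(f)(x)| = x. The density given by the formula is compared with a finite-difference derivative of the CDF of Y, computed independently as P(Y <= y) = P(X <= log y). All function names are hypothetical and are not part of the distreqx API.

```python
import math


def base_log_prob(x: float) -> float:
    # log density of the standard normal base distribution N(0, 1)
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def inverse(y: float) -> float:
    # f(x) = exp(x), so f^{-1}(y) = log(y)
    return math.log(y)


def forward_log_det_jacobian(x: float) -> float:
    # d/dx exp(x) = exp(x), hence log|det J(f)(x)| = x
    return x


def transformed_log_prob(y: float) -> float:
    # log p(y) = log p(x) - log|det J(f)(x)|, evaluated at x = f^{-1}(y)
    x = inverse(y)
    return base_log_prob(x) - forward_log_det_jacobian(x)


def transformed_cdf(y: float) -> float:
    # exp is increasing, so P(Y <= y) = P(X <= log y) = Phi(log y)
    return 0.5 * (1.0 + math.erf(inverse(y) / math.sqrt(2.0)))


y = 2.5
h = 1e-5
density_from_formula = math.exp(transformed_log_prob(y))
density_from_cdf = (transformed_cdf(y + h) - transformed_cdf(y - h)) / (2.0 * h)
print(density_from_formula, density_from_cdf)  # the two values agree closely
```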
Sampling from a Transformed distribution involves two steps: sampling from the base distribution x ~ p(x) and then evaluating y = f(x). For example:
```python
import jax
import jax.numpy as jnp
from distreqx import bijectors, distributions

dist = distributions.Normal(loc=0., scale=1.)
bij = bijectors.ScalarAffine(shift=jnp.asarray([3., 3., 3.]))
transformed_dist = distributions.Transformed(distribution=dist, bijector=bij)
samples = transformed_dist.sample(jax.random.PRNGKey(0))
print(samples)  # shape (3,): every entry equals x + 3 for a single base sample x
```
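The two sampling steps can be mimicked without any library. The plain-Python sketch below (the names SHIFT, sample_base, forward, and sample_transformed are hypothetical) draws one scalar x from N(0, 1) and then shifts it by the vector [3, 3, 3]. Because a single x is broadcast against the shift, all entries of y are equal, which matches the behavior of the distreqx example.

```python
import random

SHIFT = [3.0, 3.0, 3.0]  # plays the role of the ScalarAffine shift


def sample_base(rng: random.Random) -> float:
    # Step 1: draw x ~ p(x), here the standard normal N(0, 1)
    return rng.gauss(0.0, 1.0)


def forward(x: float) -> list[float]:
    # Step 2: evaluate y = f(x); the scalar x is broadcast against the shift
    return [x + s for s in SHIFT]


def sample_transformed(rng: random.Random) -> list[float]:
    return forward(sample_base(rng))


rng = random.Random(0)
y = sample_transformed(rng)
print(y)  # three identical entries, each equal to x + 3

# Averaging many samples recovers the shifted mean, roughly 3 in each entry
n = 10_000
first_entries = [sample_transformed(rng)[0] for _ in range(n)]
print(sum(first_entries) / n)
```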
This assumes that the forward function of the bijector is traceable; that is, it is a pure function that does not contain run-time branching. Bijectors whose forward functions do not strictly meet this requirement can still be used, but we cannot guarantee that the shapes, dtypes, and KL divergences involving the transformed distribution are computed correctly.
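To illustrate what "traceable" means, the sketch below traces two hypothetical forward functions with jax.make_jaxpr. The first uses a Python if on the value of its input, which JAX cannot resolve while tracing abstract values, so tracing fails; the second expresses the same piecewise-linear bijection with jnp.where and traces without error. Neither function is a distreqx bijector; they only stand in for a bijector's forward method.

```python
import jax
import jax.numpy as jnp


def forward_with_branching(x):
    # Python-level branching on the value of x: not traceable
    if x > 0:
        return 2.0 * x
    return x


def forward_branch_free(x):
    # Same piecewise-linear bijection, expressed with jnp.where: traceable
    return jnp.where(x > 0, 2.0 * x, x)


def is_traceable(fn, example_input) -> bool:
    # Tracing replaces the input by an abstract value; converting it to a
    # Python bool raises a ConcretizationTypeError (or a subclass of it)
    try:
        jax.make_jaxpr(fn)(example_input)
        return True
    except jax.errors.ConcretizationTypeError:
        return False


x = jnp.asarray(1.0)
print(is_traceable(forward_with_branching, x))  # False
print(is_traceable(forward_branch_free, x))  # True
```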
__init__(self, distribution: AbstractDistribution, bijector: AbstractBijector)
Initializes a Transformed distribution.

Arguments:

- distribution: the base distribution.
- bijector: a differentiable bijective transformation. Can be a bijector or a callable to be wrapped by a Lambda bijector.
mean(self) -> Array
Calculates the mean.
mode(self) -> Array
Calculates the mode.
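For a general bijector, the mean and mode of Y are not obtained by applying f to the mean and mode of X; that shortcut is exact when f is affine (constant Jacobian). The plain-Python sketch below illustrates the mathematics only, not how distreqx computes these quantities, using the assumed example X ~ N(0, 1) and f = exp. Here f(E[X]) = 1, whereas E[Y] = exp(1/2) ≈ 1.6487; the mode of Y, found by maximizing log p(y) from the change-of-variables formula, is exp(-1) ≈ 0.3679 rather than f(0) = 1. The helper log_prob_y is hypothetical.

```python
import math
import random


def log_prob_y(y: float) -> float:
    # Change of variables with X ~ N(0, 1), f = exp: x = log y, log|det J| = x
    x = math.log(y)
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi) - x


# Mean: Monte Carlo estimate of E[exp(X)] versus exp(E[X]) = exp(0) = 1
rng = random.Random(0)
n = 200_000
mc_mean = sum(math.exp(rng.gauss(0.0, 1.0)) for _ in range(n)) / n
print(mc_mean, math.exp(0.5))  # both close to 1.6487, not 1.0

# Mode: grid search for the maximizer of log p(y) on (0, 3]
grid = [i * 1e-4 for i in range(1, 30_001)]
mode_y = max(grid, key=log_prob_y)
print(mode_y, math.exp(-1.0))  # both close to 0.3679, not 1.0
```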
kl_divergence(self, other_dist, **kwargs) -> Array
Obtains the KL divergence between two Transformed distributions.
This computes the KL divergence between two Transformed distributions with the same bijector. If the two Transformed distributions do not have the same bijector, an error is raised. To determine whether the bijectors are equal, this method proceeds as follows:

- If both bijectors are the same instance of a distreqx bijector, they are declared equal.
- If they are not the same instance, we check whether they are equal according to their same_as predicate.
- Otherwise, the string representation of the Jaxpr of the forward method of each bijector is compared. If both string representations are equal, the bijectors are declared equal.
- Otherwise, the bijectors cannot be guaranteed to be equal and an error is raised.
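The equality check above can be sketched as follows. This is a minimal illustration assuming only JAX: ShiftBijector is a hypothetical stand-in for a distreqx bijector, its same_as method is a simplified placeholder, and check_same_bijector is not the distreqx implementation. The shift is kept as a Python float so that it appears as a literal in the printed Jaxpr.

```python
import jax
import jax.numpy as jnp


class ShiftBijector:
    """Hypothetical bijector y = x + shift, standing in for a distreqx bijector."""

    def __init__(self, shift: float):
        self.shift = shift  # a Python float, so it appears as a literal in the Jaxpr

    def forward(self, x):
        return x + self.shift

    def same_as(self, other) -> bool:
        # Placeholder predicate: never claims equality, so the Jaxpr
        # comparison below is exercised
        return False


def check_same_bijector(bij1, bij2, input_hint) -> None:
    """Raises NotImplementedError unless bij1 and bij2 are declared equal."""
    # 1. Same instance
    if bij1 is bij2:
        return
    # 2. The same_as predicate
    if bij1.same_as(bij2):
        return
    # 3. Compare the string representations of the traced forward methods
    jaxpr1 = str(jax.make_jaxpr(bij1.forward)(input_hint))
    jaxpr2 = str(jax.make_jaxpr(bij2.forward)(input_hint))
    if jaxpr1 == jaxpr2:
        return
    # 4. Equality cannot be guaranteed
    raise NotImplementedError("Bijectors are not known to be equal.")


# Mirrors the default input_hint: zeros with the shape and dtype of a base sample
input_hint = jnp.zeros((3,), dtype=jnp.float32)

a = ShiftBijector(3.0)
check_same_bijector(a, a, input_hint)  # same instance: passes
check_same_bijector(a, ShiftBijector(3.0), input_hint)  # identical Jaxpr: passes
try:
    check_same_bijector(a, ShiftBijector(1.0), input_hint)  # different shift
except NotImplementedError as err:
    print("raised:", err)
```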
Arguments:

- other_dist: A Transformed distribution.
- input_hint: keyword argument, an example sample from the base distribution, used to trace the forward method. If not specified, it is computed using a zero array of the shape and dtype of a sample from the base distribution.

Returns:

KL(dist1 || dist2), where dist1 is this distribution and dist2 is other_dist.

Raises:

- NotImplementedError: If the bijectors are not known to be equal.
- ValueError: If the base distributions do not have the same event_shape.
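With a shared bijector the problem reduces to the base distributions, because the KL divergence is invariant under applying the same bijection to both distributions. The plain-Python sketch below checks this for normal base distributions and an assumed affine bijection y = a * x + b, using the closed-form KL between univariate normals; kl_normal and push_affine are hypothetical helpers, and the sketch illustrates the mathematics rather than the distreqx code path.

```python
import math


def kl_normal(m1: float, s1: float, m2: float, s2: float) -> float:
    # Closed-form KL(N(m1, s1^2) || N(m2, s2^2)) for univariate normals
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2.0 * s2 ** 2) - 0.5


def push_affine(m: float, s: float, a: float, b: float) -> tuple[float, float]:
    # If X ~ N(m, s^2) and y = a * x + b with a != 0, then Y ~ N(a * m + b, (|a| * s)^2)
    return a * m + b, abs(a) * s


a, b = -2.0, 3.0  # hypothetical affine bijection y = a * x + b
base_kl = kl_normal(0.0, 1.0, 1.0, 2.0)
transformed_kl = kl_normal(*push_affine(0.0, 1.0, a, b), *push_affine(1.0, 2.0, a, b))
print(base_kl, transformed_kl)  # equal up to floating-point rounding
```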