Block Bijector
distreqx.bijectors.block.Block (AbstractBijector)
A wrapper that promotes a bijector to a block bijector.
A block bijector applies a bijector to a k-dimensional array of events, but considers that array of events to be a single event. In practical terms, this means that the log det Jacobian will be summed over its last k dimensions.
For example, consider a scalar bijector (such as Tanh) that operates on scalar events. We may want to apply this bijector identically to a 4D array of shape [N, H, W, C] representing a sequence of N images. Doing so naively with a vmap will produce a log det Jacobian of shape [N, H, W, C], because the scalar bijector assumes scalar events, so all 4 dimensions are treated as batch dimensions. To promote the scalar bijector to a "block scalar" that operates on the 3D arrays of shape [H, W, C], use Block(bijector, ndims=3). Applying the block bijector then produces a log det Jacobian of shape [N], as desired.
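The shape bookkeeping in this example can be sketched in plain NumPy, used here as a stand-in for jax.numpy. This is not the distreqx implementation: the helpers `tanh_forward_and_log_det` and `block_forward_and_log_det` are hypothetical, and the elementwise Tanh log det uses the identity log(1 - tanh(x)^2) = 2(log 2 - x - softplus(-2x)) for numerical stability.

```python
import numpy as np


def tanh_forward_and_log_det(x):
    """Hypothetical scalar Tanh bijector acting elementwise on scalar events.

    Returns y = tanh(x) and the per-element log |dy/dx|, which has the same
    shape as x because every element is treated as its own event.
    """
    y = np.tanh(x)
    # Numerically stable form of log(1 - tanh(x)**2).
    log_det = 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))
    return y, log_det


def block_forward_and_log_det(forward_and_log_det, x, ndims):
    """Hypothetical block wrapper: sums the inner log det over its last `ndims` axes."""
    y, log_det = forward_and_log_det(x)
    if ndims > 0:
        log_det = np.sum(log_det, axis=tuple(range(-ndims, 0)))
    return y, log_det


N, H, W, C = 2, 4, 4, 3
x = np.random.default_rng(0).normal(size=(N, H, W, C))

_, ldj_scalar = tanh_forward_and_log_det(x)
_, ldj_block = block_forward_and_log_det(tanh_forward_and_log_det, x, ndims=3)
print(ldj_scalar.shape)  # (2, 4, 4, 3): one log det per scalar event
print(ldj_block.shape)   # (2,): one log det per image
```

The block version leaves the forward output unchanged; only the log det is reduced, so each [H, W, C] image contributes a single log det value.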
In general, suppose bijector operates on n-dimensional events. Then Block(bijector, k) will promote bijector to a block bijector that operates on (k + n)-dimensional events, summing the log det Jacobian over its last k dimensions.
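The general rule can be illustrated with a bijector on 1-D events (n = 1) promoted with k = 2, so the result treats the last k + n = 3 axes as one event. This is a minimal NumPy sketch under stated assumptions, not the distreqx API: `make_scale_vector` and `block` are hypothetical helpers, and the inner bijector is a simple diagonal scaling whose log det per event is sum(log_scale).

```python
import numpy as np


def make_scale_vector(log_scale):
    """Hypothetical bijector on 1-D events (n = 1): y = exp(log_scale) * x.

    The Jacobian is diagonal, so each event's log det is sum(log_scale);
    the log det has shape x.shape[:-1] (the trailing axis is the event).
    """
    def forward_and_log_det(x):
        y = np.exp(log_scale) * x
        log_det = np.broadcast_to(np.sum(log_scale), x.shape[:-1])
        return y, log_det
    return forward_and_log_det


def block(forward_and_log_det, k):
    """Hypothetical Block(bijector, k): sum the inner log det over its last k axes."""
    def wrapped(x):
        y, log_det = forward_and_log_det(x)
        if k > 0:
            log_det = np.sum(log_det, axis=tuple(range(-k, 0)))
        return y, log_det
    return wrapped


log_scale = np.array([0.1, -0.2, 0.3])  # event size 3 for the inner bijector
x = np.ones((5, 6, 7, 3))               # batch 5, then k = 2 axes, then n = 1 axis

inner = make_scale_vector(log_scale)
_, ldj_inner = inner(x)                 # shape (5, 6, 7): events are the last axis
_, ldj_block = block(inner, k=2)(x)     # shape (5,): events are the last 3 axes
print(ldj_inner.shape, ldj_block.shape)
```

Each entry of ldj_block equals 6 * 7 * sum(log_scale), i.e. the inner per-event log det summed over the two promoted axes. With distreqx itself, the signature documented below suggests the corresponding object would be built as Block(bijector, ndims=2).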
__init__(self, bijector: AbstractBijector, ndims: int)
Initializes a Block.
Arguments:
bijector: the bijector to be promoted to a block bijector. It can be a distreqx bijector or a callable to be wrapped by Lambda.
ndims: number of trailing dimensions to promote to event dimensions; the log det Jacobian is summed over these dimensions.