Block Bijector

distreqx.bijectors.block.Block (AbstractBijector)

A wrapper that promotes a bijector to a block bijector.

A block bijector applies a bijector to a k-dimensional array of events, but considers that array of events to be a single event. In practical terms, this means that the log det Jacobian will be summed over its last k dimensions.

For example, consider a scalar bijector (such as Tanh) that operates on scalar events. We may want to apply this bijector identically to a 4D array of shape [N, H, W, C] representing a sequence of N images. Doing so naively with a vmap will produce a log det Jacobian of shape [N, H, W, C], because the scalar bijector assumes scalar events and so treats all 4 dimensions as batch dimensions. To promote the scalar bijector to a "block scalar" bijector that operates on 3D arrays, use Block(bijector, ndims=3). Applying the block bijector then produces a log det Jacobian of shape [N], as desired.
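
The shape bookkeeping in this example can be sketched in plain JAX without the library itself. The helper names below (tanh_forward_log_det, block_log_det) are hypothetical and are not part of distreqx; the sketch only mimics what summing the log det Jacobian over the last 3 dimensions does to the shapes.

```python
import jax
import jax.numpy as jnp


def tanh_forward_log_det(x):
    # Elementwise log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in a
    # numerically stable form. Hypothetical helper, not a distreqx function.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


def block_log_det(log_det, ndims):
    # Treat the last `ndims` axes as one event by summing over them.
    # Hypothetical helper illustrating the reduction described above.
    axes = tuple(range(-ndims, 0))
    return jnp.sum(log_det, axis=axes)


# A sequence of N=2 images with H=4, W=4, C=3.
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 3))

elementwise = tanh_forward_log_det(x)           # every dimension is batch
blocked = block_log_det(elementwise, ndims=3)   # [H, W, C] is a single event

print(elementwise.shape)  # (2, 4, 4, 3)
print(blocked.shape)      # (2,)
```

The elementwise result keeps one log det value per scalar, while the blocked result keeps one value per image, which is the quantity needed when each image is a single event.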

In general, suppose bijector operates on n-dimensional events. Then, Block(bijector, k) will promote bijector to a block bijector that operates on (k + n)-dimensional events, summing the log det Jacobian over its last k dimensions.
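
Written out, the general rule amounts to the following. The notation is introduced here for illustration and does not appear in the source: f is the wrapped bijector, J_f its Jacobian, and i_1, ..., i_k index the k promoted dimensions, so that each slice x_{i_1,...,i_k} is an n-dimensional event.

$$
\log\left|\det J_{\mathrm{Block}(f,k)}(x)\right| = \sum_{i_1,\dots,i_k} \log\left|\det J_f\left(x_{i_1,\dots,i_k}\right)\right|
$$

The sum is valid because f acts on each slice independently, so the Jacobian of the block bijector is block diagonal and its log determinant is the sum of the per-slice log determinants.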

__init__(self, bijector: AbstractBijector, ndims: int)

Initializes a Block.

Arguments:

  • bijector: the bijector to be promoted to a block bijector. It can be a distreqx bijector or a callable to be wrapped by Lambda.
  • ndims: number of dimensions to promote to event dimensions.