distreqx

distreqx (pronounced "dist-rex") is a JAX-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning, with all the benefits of JAX: native GPU/TPU acceleration, differentiability, vectorization, distributed workloads, XLA compilation, and more.

This package originated as a reimplementation of distrax (itself a subset of TensorFlow Probability (TFP), with some new features and an emphasis on JAX compatibility) using equinox. As a result, much of the original code, comments, documentation, and tests are taken or adapted directly from distrax (the original distrax copyright is available at the end of the README).

Current features include:

  • Probability distributions
  • Bijectors

Installation

git clone https://github.com/lockwo/distreqx.git
cd distreqx
pip install -e .

Requires Python 3.9+, JAX 0.4.11+, and Equinox 0.11.0+.

Documentation

Available at https://lockwo.github.io/distreqx/.

Quick example

import jax
from jax import numpy as jnp
from distreqx import distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

# A 3-dimensional Gaussian with diagonal covariance
dist = distributions.MultivariateNormalDiag(mu, sigma)

# Draw a sample and evaluate its log-density
samples = dist.sample(key)

print(dist.log_prob(samples))

Differences with Distrax

  • No official support/interoperability with TFP
  • The concept of a batch dimension is dropped. If you want to operate on a batch, use vmap (note that vmap can also be applied at construction time, e.g. vmap-ing the construction of a ScalarAffine)
  • Broader pytree enablement
  • Strict abstract/final design pattern