Improved batch norm with hijax?

Published: February 2, 2026

Stateful operations in JAX, such as batch norm, can be annoying to work with. Because JAX relies on pure functions, it requires the usual functional approach of threading a state object through every function, changing $f :: x \to y$ into $f :: (x, state) \to (y, state)$. While this may not sound like much work, it can become quite tedious in large and complex machine learning workflows. Part of the complexity comes from the fact that we often want to batch with vmap, rather than give everything an explicit batch dimension. Although different JAX libraries implement stateful operations in different ways, they share a common mechanism for batch norm: each keeps track of a vmap axis name, then uses a collective operation like pmean to compute statistics across the vmapped batch. Here we won't be looking to simplify that dimension of batch norm, only the state management.
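To make the pattern concrete, here is a sketch of what explicit state threading looks like, using a toy running-mean "layer" (the names here are illustrative, not from any particular library):

import jax.numpy as jnp

def running_mean_layer(x, state, momentum=0.1):
    # A pure f :: (x, state) -> (y, state): the updated statistics must be
    # returned and re-passed at every call site.
    batch_mean = x.mean(axis=0)
    new_mean = (1 - momentum) * state["mean"] + momentum * batch_mean
    return x - new_mean, {"mean": new_mean}

state = {"mean": jnp.zeros(3)}
x = jnp.ones((8, 3))
y, state = running_mean_layer(x, state)  # the caller must capture the new state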

State management is not always pretty, and it leads to unergonomic interfaces such as Flax's:

logits, updates = state.apply_fn(
    {'params': params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=True, mutable=['batch_stats'])

or to somewhat complicated implementations that need careful handling of edge cases.

But what if there were a way to simplify this? Thanks to the experimental hijax features, I think there will be. A fuller explanation of what hijax is will come in a future blog post, but for now check out the JAX devlabs keynote for more. Basically, hijax is a new interface for working at intermediate levels of operations, between pytrees and primitives. It's experimental, so the interface is subject to change (this code was written against JAX v0.9.0), but the idea has a lot of potential to simplify these workloads.
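To give a flavor before the details: the central object below is Box, a mutable container from jax._src.hijax. Here is a minimal sketch of how it behaves, based on my reading of the current experimental API (the same get/set calls are used throughout this post):

import jax
import jax.numpy as jnp
from jax._src.hijax import Box  # experimental; module path and names may change

total = Box(jnp.array(0.0))

@jax.jit
def accumulate(x):
    # To the user this reads like plain mutation: no state threading.
    total.set(total.get() + x)
    return total.get()

accumulate(jnp.array(1.0))
accumulate(jnp.array(2.0))
print(total.get())  # the state persists across jit calls: 3.0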

Similar to how array refs allow for stateful operations (see my previous blog post), hijax has features that push in the same direction. Specifically, the idea of Quasi Dynamic Data (QDD) allows for stateful operations from the user's perspective. Under the hood, when a trace requires lowering, hijax primitives call their to_lojax methods. For Box, box_get_p.to_lojax simply returns the stored value's leaves, while box_set_p.to_lojax updates the internal _val attribute. In other words, the mutable semantics the user programs against get compiled down to direct value access. This allows a nice simplification of state handling in batch norm implementations. For example,

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax._src.hijax import Box

@register_pytree_node_class
class BatchNorm:

    def __init__(self, num_features, momentum=0.1, eps=1e-5, axis_name=None):
        # Running statistics live in mutable Boxes instead of threaded state.
        self.running_mean = Box(jnp.zeros(num_features))
        self.running_var = Box(jnp.ones(num_features))
        self.gamma = jnp.ones(num_features)
        self.beta = jnp.zeros(num_features)
        self.momentum = momentum
        self.eps = eps
        self.num_features = num_features
        self.axis_name = axis_name

    def __call__(self, x, training=True):
        if training:
            if self.axis_name is not None:
                # Batch statistics across a vmapped axis, via collectives.
                mean = jax.lax.pmean(x, axis_name=self.axis_name)
                mean_of_squares = jax.lax.pmean(x ** 2, axis_name=self.axis_name)
                var = mean_of_squares - mean ** 2
            else:
                # Plain batched input: reduce over the leading batch dimension.
                mean = x.mean(axis=0)
                var = x.var(axis=0)

            # Mutate the running statistics in place; stop_gradient keeps the
            # EMA update out of the differentiation path.
            rm = self.running_mean.get()
            rv = self.running_var.get()
            new_rm = jax.lax.stop_gradient(rm * (1 - self.momentum) + mean * self.momentum)
            new_rv = jax.lax.stop_gradient(rv * (1 - self.momentum) + var * self.momentum)
            self.running_mean.set(new_rm)
            self.running_var.set(new_rv)
        else:
            mean = self.running_mean.get()
            var = self.running_var.get()

        x_norm = (x - mean) / jnp.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

    # register_pytree_node_class requires these two methods; without them the
    # decorator fails at class-definition time.
    def tree_flatten(self):
        children = (self.running_mean, self.running_var, self.gamma, self.beta)
        aux_data = (self.momentum, self.eps, self.num_features, self.axis_name)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj.running_mean, obj.running_var, obj.gamma, obj.beta = children
        obj.momentum, obj.eps, obj.num_features, obj.axis_name = aux_data
        return obj
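As a quick sanity check of the interface, a single eager training step looks like this (shapes illustrative, no vmap yet):

bn = BatchNorm(num_features=4)
x = jnp.ones((8, 4))  # a batch of 8 examples with 4 features

y = bn(x, training=True)      # normalizes with the batch statistics...
print(bn.running_mean.get())  # ...and updates the running stats in place

Notice that nothing comes back out besides the output: the running statistics were mutated inside their Boxes.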

To really get this working with vmap, you need to patch BatchTrace with a cur_qdd method:

from jax._src.interpreters import batching

def _batch_trace_cur_qdd(self, x):
    # Forward quasi-dynamic-data queries to the value being batched.
    return x.cur_qdd()

if not hasattr(batching.BatchTrace, 'cur_qdd'):
    batching.BatchTrace.cur_qdd = _batch_trace_cur_qdd

This is a bit of a hack, since BatchTrace doesn't natively support hijax's mutable types yet. Mutable HiTypes like Box officially don't work with vmap or scan, but this hack has held up in my tests so far.

With a little sprinkling of vmap rules, like we did with refs, we can even replace Equinox's somewhat complicated substate/vmap setup. Basically, we just have to give JAX a vmap rule for this Box, in which we broadcast/propagate the vmapping to the elements stored inside the Box. There are certainly other ways to approach this, but for the simple Counter example below it is good enough.

from dataclasses import dataclass
from typing import ClassVar

from jax._src.interpreters import batching
from jax._src.interpreters.batching import not_mapped
from jax._src.hijax import (
    new_box_p,
    box_get_p,
    box_set_p,
)

def _new_box_batching(axis_data, batched_args, batch_dims, *, treedef):
    # A fresh Box carries no batch dimension of its own.
    box = new_box_p.bind(treedef=treedef)
    return box, not_mapped


def _box_get_batching(axis_data, batched_args, batch_dims, *, avals):
    (box,), (box_bdim,) = batched_args, batch_dims

    if box_bdim is not not_mapped:
        raise ValueError("Box cannot be batched")

    # Reads pass straight through; whatever was stored comes back unmapped.
    results = box_get_p.bind(box, avals=avals)
    out_bdims = (not_mapped,) * len(results)
    return results, out_bdims


def _box_set_batching(axis_data, batched_args, batch_dims, *, treedef):
    box, *vals = batched_args
    box_bdim, *val_bdims = batch_dims

    if box_bdim is not not_mapped:
        raise ValueError("Box cannot be batched")

    # Store the (possibly batched) values as-is: the batch dimension simply
    # rides along inside the Box's contents.
    box_set_p.bind(box, *vals, treedef=treedef)
    return [], []

# Register the rules so vmap knows how to trace through Box operations.
batching.fancy_primitive_batchers[new_box_p] = _new_box_batching
batching.fancy_primitive_batchers[box_get_p] = _box_get_batching
batching.fancy_primitive_batchers[box_set_p] = _box_set_batching


@register_pytree_node_class
@dataclass(init=False)
class Counter:
    index: Box
    _hash: int
    _counter: ClassVar[int] = 0

    def __init__(self, index):
        self.index = Box(index)
        self._hash = Counter._counter
        Counter._counter += 1

    def __hash__(self):
        return self._hash

    def __call__(self, x):
        value = self.index.get()
        new_x = x + value
        self.index.set(value + 1)
        return new_x

    def tree_flatten(self):
        children = (self.index,)
        aux_data = self._hash
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (index_box,) = children
        obj = object.__new__(cls)
        obj.index = index_box
        obj._hash = aux_data
        return obj


counter = Counter(jnp.array(0))
x = jnp.array(2.3)
f = jax.jit(counter)

print(counter.index.get())  # 0

_ = f(x)
print(counter.index.get())  # 1: the Box mutated across the jit boundary

_ = f(x)
print(counter.index.get())  # 2


import equinox as eqx
import jax.random as jr


class Model(eqx.Module):
    linear: eqx.nn.Linear
    counter: Counter
    v_counter: Counter

    def __init__(self, key):
        # Not-stateful layer
        self.linear = eqx.nn.Linear(2, 2, key=key)
        # Stateful layer.
        self.counter = Counter(jnp.array(0))
        # Vmap'd stateful layer. (Whose initial state will include a batch dimension.)
        self.v_counter = jax.vmap(Counter, out_axes=None)(jnp.array([0, 0]))

    def __call__(self, x):
        assert x.shape == (2,)
        x = self.linear(x)
        x = self.counter(x)
        x = jax.vmap(self.v_counter)(x)
        return x


key = jr.key(0)
model = Model(key)
x = jnp.array([5.0, -1.0])
_ = model(x)
print(model.counter.index.get())    # state of the plain counter after one call
print(model.v_counter.index.get())  # state of the vmapped counter (batched)

As you can see, this replicates the effect of Equinox's vmap/substate machinery in a much simpler and more elegant way (the custom vmap rules are annoying, of course, but that is mostly because hijax is experimental, and only a single implementation is needed per HiType like Box). This is just the beginning of the cool stuff you can do with hijax, and I look forward to seeing it adopted by more mainstream JAX packages; see the Flax hijax talk, for example.

Changelog

  1. February 2, 2026: Published initial version.
