Posts by Tags

AI

JAX

Improved batch norm with hijax?

6 minute read

Published:

Stateful operations in JAX, such as batchnorm, can sometimes be annoying to work with. As a DSL that relies on pure functions, JAX requires the usual functional approach of threading a state object through every function, changing $f :: x \to y$ to $f :: (x, state) \to (y, state)$. While this may not sound like much work, it can be quite tedious to manage in large and complex machine learning workflows. Part of this complexity comes from the fact that we often want to use vmap to batch, rather than carrying a batch dimension through everything. Although different JAX libraries implement stateful operations in different ways, they share a common mechanism for batchnorm: they all keep track of a named axis for vmap, then perform a collective operation such as pmean to compute the statistics across the vmapped batch. Here, we won’t be looking to simplify that aspect of batchnorm, but rather the state management.
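To make the pattern concrete, here is a minimal, illustrative sketch (not the post's code; the function and state names are hypothetical) of a batchnorm written as $f :: (x, state) \to (y, state)$, using an explicit batch dimension rather than vmap. A vmap-based library would instead compute the statistics with a collective such as jax.lax.pmean over its named axis.

```python
import jax.numpy as jnp

def batch_norm(x, state, momentum=0.9, eps=1e-5):
    """The usual functional pattern: f :: (x, state) -> (y, state)."""
    # Statistics over the explicit batch dimension. A vmap-based library
    # would instead call jax.lax.pmean(x, axis_name=...) over its named axis.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    y = (x - mean) / jnp.sqrt(var + eps)
    # Updated running statistics: the caller is responsible for threading
    # this new state into the next call.
    new_state = {
        "mean": momentum * state["mean"] + (1.0 - momentum) * mean,
        "var":  momentum * state["var"]  + (1.0 - momentum) * var,
    }
    return y, new_state

state = {"mean": jnp.zeros(4), "var": jnp.ones(4)}  # running statistics
x = jnp.ones((8, 4))                                # batch of 8 feature vectors
y, state = batch_norm(x, state)                     # state must be passed back in
```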

LLMs

embeddings

Can Quantization Reduce Separability?

1 minute read

Published:

I was recently experimenting with the linear separability of embeddings, and showed a graph indicating that the original embeddings were substantially more linearly separable than the quantized embeddings. One person commented that this didn’t make sense, and that the quantized vectors should be just as separable as the non-quantized ones. That didn’t seem right to me, but on the spot I couldn’t think of a trivial counterexample. In this post, I present two.
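To give a flavour of the argument (a toy sketch only, not necessarily one of the post's two counterexamples): two perfectly separable one-dimensional points can collapse onto the same quantized value, after which no linear classifier can tell them apart.

```python
import numpy as np

a = np.array([0.40])  # class A
b = np.array([0.45])  # class B: any threshold in (0.40, 0.45) separates them

# A crude quantizer that rounds to the nearest integer level.
quantize = np.round

print(quantize(a), quantize(b))  # [0.] [0.] -- the classes now coincide,
                                 # so they are no longer linearly separable
```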

hijax

Improved batch norm with hijax?

6 minute read

Published:

Stateful operations in JAX, such as batchnorm, can sometimes be annoying to work with. As a DSL that relies on pure functions, JAX requires the usual functional approach of threading a state object through every function, changing $f :: x \to y$ to $f :: (x, state) \to (y, state)$. While this may not sound like much work, it can be quite tedious to manage in large and complex machine learning workflows. Part of this complexity comes from the fact that we often want to use vmap to batch, rather than carrying a batch dimension through everything. Although different JAX libraries implement stateful operations in different ways, they share a common mechanism for batchnorm: they all keep track of a named axis for vmap, then perform a collective operation such as pmean to compute the statistics across the vmapped batch. Here, we won’t be looking to simplify that aspect of batchnorm, but rather the state management.

machine learning

Improved batch norm with hijax?

6 minute read

Published:

Stateful operations in JAX, such as batchnorm, can sometimes be annoying to work with. As a DSL that relies on pure functions, JAX requires the usual functional approach of threading a state object through every function, changing $f :: x \to y$ to $f :: (x, state) \to (y, state)$. While this may not sound like much work, it can be quite tedious to manage in large and complex machine learning workflows. Part of this complexity comes from the fact that we often want to use vmap to batch, rather than carrying a batch dimension through everything. Although different JAX libraries implement stateful operations in different ways, they share a common mechanism for batchnorm: they all keep track of a named axis for vmap, then perform a collective operation such as pmean to compute the statistics across the vmapped batch. Here, we won’t be looking to simplify that aspect of batchnorm, but rather the state management.

Can Quantization Reduce Separability?

1 minute read

Published:

I was recently experimenting with the linear separability of embeddings, and showed a graph indicating that the original embeddings were substantially more linearly separable than the quantized embeddings. One person commented that this didn’t make sense, and that the quantized vectors should be just as separable as the non-quantized ones. That didn’t seem right to me, but on the spot I couldn’t think of a trivial counterexample. In this post, I present two.

programming

python

Improved batch norm with hijax?

6 minute read

Published:

Stateful operations in JAX, such as batchnorm, can sometimes be annoying to work with. As a DSL that relies on pure functions, JAX requires the usual functional approach of threading a state object through every function, changing $f :: x \to y$ to $f :: (x, state) \to (y, state)$. While this may not sound like much work, it can be quite tedious to manage in large and complex machine learning workflows. Part of this complexity comes from the fact that we often want to use vmap to batch, rather than carrying a batch dimension through everything. Although different JAX libraries implement stateful operations in different ways, they share a common mechanism for batchnorm: they all keep track of a named axis for vmap, then perform a collective operation such as pmean to compute the statistics across the vmapped batch. Here, we won’t be looking to simplify that aspect of batchnorm, but rather the state management.

quantization

Can Quantization Reduce Separability?

1 minute read

Published:

I was recently experimenting with the linear separability of embeddings, and showed a graph indicating that the original embeddings were substantially more linearly separable than the quantized embeddings. One person commented that this didn’t make sense, and that the quantized vectors should be just as separable as the non-quantized ones. That didn’t seem right to me, but on the spot I couldn’t think of a trivial counterexample. In this post, I present two.

society