Normalizing Flows
In this tutorial, adapted from https://gebob19.github.io/normalizing-flows/, we will implement a simple RealNVP normalizing flow. Normalizing flows are a class of generative models whose main advantage is an explicit representation of densities and likelihoods; the price is that every layer must be invertible and have a tractable Jacobian determinant. For an introduction to normalizing flows, see https://arxiv.org/abs/1912.02762.
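The key quantity is the change-of-variables formula: if F is an invertible map sending data x to a base variable z = F(x) with tractable density p_z, then

$$\log p_x(x) = \log p_z\big(F(x)\big) + \log \left| \det J_F(x) \right|,$$

so the exact log-likelihood of the data can be computed and maximized, provided F is invertible and its Jacobian determinant is cheap to evaluate. This is exactly the quantity we will optimize below.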
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from distreqx import bijectors, distributions
Let's define our simple dataset.
n_samples = 20000
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05)
X, y = noisy_moons
X = StandardScaler().fit_transform(X)
xlim, ylim = [-3, 3], [-3, 3]
plt.scatter(X[:, 0], X[:, 1], s=10, color="red")
plt.xlim(xlim)
plt.ylim(ylim)
Now we can define our custom RealNVP coupling bijector.
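A RealNVP coupling layer splits the input into two blocks, x1 (the first k coordinates) and x2, leaves x1 unchanged, and applies an affine transformation to x2 whose scale s and shift t are arbitrary functions of x1:

$$z_1 = x_1, \qquad z_2 = x_2 \odot \exp\big(s(x_1)\big) + t(x_1), \qquad \log\left|\det J_F\right| = \sum_i s(x_1)_i.$$

The inverse is $x_2 = (z_2 - t(z_1)) \odot \exp(-s(z_1))$, so the layer can be inverted without inverting the networks. In the code below, `sig_net` plays the role of s, `mu_net` plays the role of t, and `flip` swaps which block is left untouched so that successive layers transform different coordinates.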
class RNVP(
bijectors.AbstractFwdLogDetJacBijector, bijectors.AbstractInvLogDetJacBijector
):
_is_constant_jacobian: bool
_is_constant_log_det: bool
sig_net: eqx.Module
mu_net: eqx.Module
d: int
k: int
flip: bool
def __init__(self, d, k, flip, key, hidden=32):
self._is_constant_jacobian = False
self._is_constant_log_det = False
self.flip = flip
keys = jax.random.split(key, 4)
self.d = d
self.k = k
self.sig_net = eqx.nn.Sequential(
[
eqx.nn.Linear(k, hidden, key=keys[0]),
eqx.nn.Lambda(jax.nn.swish),
eqx.nn.Linear(hidden, d - k, key=keys[1]),
]
)
self.mu_net = eqx.nn.Sequential(
[
eqx.nn.Linear(k, hidden, key=keys[2]),
eqx.nn.Lambda(jax.nn.swish),
eqx.nn.Linear(hidden, d - k, key=keys[3]),
]
)
    def forward_and_log_det(self, x):
        # Split the input: x1 passes through unchanged, x2 is transformed by an
        # affine map whose scale and shift depend only on x1.
        x1, x2 = x[: self.k], x[self.k :]
        if self.flip:
            x1, x2 = x2, x1
        sig = self.sig_net(x1)
        z1, z2 = x1, x2 * jnp.exp(sig) + self.mu_net(x1)
        if self.flip:
            z1, z2 = z2, z1
        z_hat = jnp.concatenate([z1, z2])
        # The Jacobian is triangular, so its log-determinant is just the sum of sig.
        log_det = jnp.sum(sig)
        return z_hat, log_det

    def inverse(self, y):
        # Undo the affine coupling: recover x2 using the same scale and shift
        # networks evaluated on the untouched block z1.
        z1, z2 = y[: self.k], y[self.k :]
        if self.flip:
            z1, z2 = z2, z1
        x1 = z1
        x2 = (z2 - self.mu_net(z1)) * jnp.exp(-self.sig_net(z1))
        if self.flip:
            x1, x2 = x2, x1
        x_hat = jnp.concatenate([x1, x2])
        return x_hat
def forward(self, x):
y, _ = self.forward_and_log_det(x)
return y
def inverse_and_log_det(self, y):
raise NotImplementedError(
f"Bijector {self.name} does not implement `inverse_and_log_det`."
)
def same_as(self, other) -> bool:
return type(other) is RNVP
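As a quick sanity check (a minimal sketch with arbitrary input values), running a point forward and then backward through a single coupling layer should recover it, and the returned log-determinant should equal the sum of the scale network's output on the untouched block:

# Round-trip check: inverse(forward(x)) should recover x.
test_key = jax.random.key(42)
test_layer = RNVP(d=2, k=1, flip=False, key=test_key)
test_x = jnp.array([0.3, -1.2])
test_z, test_log_det = test_layer.forward_and_log_det(test_x)
print(jnp.allclose(test_layer.inverse(test_z), test_x, atol=1e-5))
# With flip=False and k=1, the untouched block is test_x[:1].
print(jnp.allclose(test_log_det, jnp.sum(test_layer.sig_net(test_x[:1]))))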
Since we want to stack these together, we can use a chain bijector to accomplish this.
n = 3
key = jax.random.key(0)
keys = jax.random.split(key, n)
bijector_chain = bijectors.Chain(
    [RNVP(d=2, k=1, flip=bool(i % 2), key=keys[i], hidden=600) for i in range(n)]
)
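Because each coupling layer is invertible, the composed chain is too. A quick round trip through `bijector_chain` on the first data point (just a sketch, printing the reconstruction error and the total log-determinant) confirms this:

x0 = jnp.asarray(X[0])
z0, chain_log_det = bijector_chain.forward_and_log_det(x0)
print(jnp.max(jnp.abs(bijector_chain.inverse(z0) - x0)), chain_log_det)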
Flows map the data distribution p(x) to a base distribution p(z) via a function F; samples are generated by drawing z from p(z) and applying F^-1. In general, p(z) is chosen so that sampling and log-probability evaluation are tractable. A common choice is a standard Gaussian, which is what we use here.
base_distribution = distributions.MultivariateNormalDiag(jnp.zeros(2))
base_distribution_sample = eqx.filter_vmap(base_distribution.sample)
base_distribution_log_prob = eqx.filter_vmap(base_distribution.log_prob)
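For reference, a single draw from the base distribution and its log-probability look like this (the key value here is arbitrary):

z_example = base_distribution.sample(jax.random.key(1))
print(z_example, base_distribution.log_prob(z_example))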
Here we plot the initial, untrained, samples.
num_samples = 2000
base_samples = base_distribution_sample(jax.random.split(key, num_samples))
transformed_samples = eqx.filter_vmap(bijector_chain.inverse)(base_samples)
plt.scatter(
transformed_samples[:, 0],
transformed_samples[:, 1],
s=10,
color="blue",
label="Untrained F^-1",
)
plt.scatter(base_samples[:, 0], base_samples[:, 1], s=10, color="red", label="Base")
plt.legend()
plt.xlim(xlim)
plt.ylim(ylim)
plt.title("Initial Samples")
plt.show()
learning_rate = 1e-3
num_iters = 1000
batch_size = 128
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(bijector_chain, eqx.is_inexact_array))
def log_prob(params: bijectors.AbstractBijector, data):
    # Change of variables: log p_x(x) = log p_z(F(x)) + log|det J_F(x)|.
    z, log_det = params.forward_and_log_det(data)
    return base_distribution.log_prob(z) + log_det
def loss(params, batch):
return -jnp.mean(eqx.filter_vmap(log_prob, in_axes=(None, 0))(params, batch))
@eqx.filter_jit
def update(model, batch, opt_state, optimizer):
    loss_val, grads = eqx.filter_value_and_grad(loss)(model, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_val
losses = []
for i in range(num_iters):
    if i % 500 == 0:
        print(i)
    # Split the key each iteration so every step sees a different random batch.
    key, batch_key = jax.random.split(key)
    batch_indices = jax.random.choice(
        batch_key, jnp.arange(X.shape[0]), (batch_size,), replace=False
    )
    batch = X[batch_indices]
    bijector_chain, opt_state, loss_val = update(
        bijector_chain, batch, opt_state, optimizer
    )
    losses.append(loss_val)
plt.plot(losses)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Training Loss Over Time")
plt.show()
After training we can plot both F(x) (to see where the data ends up in the base space) and F^-1(z) (to generate new samples).
transformed_samples_trained = eqx.filter_vmap(bijector_chain.inverse)(base_samples)
plt.scatter(
transformed_samples[:, 0], transformed_samples[:, 1], s=10, label="Initial F^-1"
)
plt.scatter(X[:, 0], X[:, 1], s=5, label="Data")
plt.scatter(
transformed_samples_trained[:, 0],
transformed_samples_trained[:, 1],
s=10,
label="Trained F^-1",
)
plt.xlim(xlim)
plt.ylim(ylim)
plt.legend()
plt.show()
data_to_noise = eqx.filter_vmap(bijector_chain.forward)(X)
plt.scatter(data_to_noise[:, 0], data_to_noise[:, 1], s=10, label="Z = F(X)")
plt.xlim(xlim)
plt.ylim(ylim)
plt.legend()
plt.show()
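Finally, because the flow gives an explicit density, we can evaluate the learned log-probability on a grid and visualize it directly (a minimal sketch reusing the `log_prob` helper defined above; the grid bounds simply follow the plotting limits):

grid_x, grid_y = jnp.meshgrid(jnp.linspace(-3, 3, 100), jnp.linspace(-3, 3, 100))
grid_points = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
grid_log_probs = eqx.filter_vmap(log_prob, in_axes=(None, 0))(bijector_chain, grid_points)
plt.contourf(grid_x, grid_y, grid_log_probs.reshape(grid_x.shape), levels=50)
plt.colorbar(label="log p(x)")
plt.title("Learned log-density")
plt.show()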