
Sparse Ising Model on MNIST

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from isax import (
    BlockGraph,
    Edge,
    IsingModel,
    IsingSampler,
    Node,
    sample_chain,
    SamplingArgs,
)
from torchvision import datasets
train_dataset = datasets.MNIST(root="/tmp/mnist", train=True, download=True)
test_dataset = datasets.MNIST(root="/tmp/mnist", train=False, download=True)

train_images = train_dataset.data.numpy()
train_labels = train_dataset.targets.numpy()
test_images = test_dataset.data.numpy()
test_labels = test_dataset.targets.numpy()

# Binarize images
train_images = (train_images > 127).astype(np.int32)
test_images = (test_images > 127).astype(np.int32)

# Flatten to vectors
train_images = train_images.reshape(train_images.shape[0], -1)
test_images = test_images.reshape(test_images.shape[0], -1)

target_classes = [0, 1]
train_mask = np.isin(train_labels, target_classes)
test_mask = np.isin(test_labels, target_classes)

train_images = jnp.array(train_images[train_mask])
train_labels = jnp.array(train_labels[train_mask])
test_images = jnp.array(test_images[test_mask])
test_labels = jnp.array(test_labels[test_mask])

print(f"Train: {train_images.shape}, Test: {test_images.shape}")
Train: (12665, 784), Test: (2115, 784)
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axes):
    ax.imshow(train_images[i].reshape(28, 28), cmap="gray")
    ax.set_title(f"Label: {train_labels[i]}")
    ax.axis("off")
plt.tight_layout()
plt.show()

Figure: the first five binarized training digits with their labels.

n_visible = 28 * 28
n_hidden = 200
n_total = n_visible + n_hidden

k_connections = 50  # each hidden unit connects to 50 random visible units

key = jax.random.key(42)

visible_nodes = [Node() for _ in range(n_visible)]
hidden_nodes = [Node() for _ in range(n_hidden)]

edges = []
rng = np.random.default_rng(0)  # seeded NumPy RNG for graph construction; the JAX key is reserved for sampling
for h in hidden_nodes:
    visible_indices = rng.choice(n_visible, size=k_connections, replace=False)
    for v_idx in visible_indices:
        edges.append(Edge(visible_nodes[int(v_idx)], h))

n_edges = len(edges)
print(f"Sparse: {n_edges} (vs {n_visible * n_hidden} for RBM)")
print(f"Sparsity: {n_edges / (n_visible * n_hidden) * 100:.1f}%")

blocks = [visible_nodes, hidden_nodes]
graph = BlockGraph(blocks, edges)
params = graph.get_sampling_params()
edge_indices, edge_mask = graph.get_edge_structure()
Sparse: 10000 (vs 156800 for RBM)
Sparsity: 6.4%
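Each of the 200 hidden units picks 50 distinct visible partners, so the graph has 200 × 50 = 10,000 edges, versus 784 × 200 = 156,800 couplings for a fully connected RBM over the same units, i.e. roughly 6.4% of the dense connectivity.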
key, subkey = jax.random.split(key)
init_weights = jax.random.normal(subkey, (n_edges,)) * 0.01

data_mean = jnp.mean(train_images, axis=0)
eps = 1e-6
visible_biases = jnp.log((data_mean + eps) / (1 - data_mean + eps)) * 0.5
hidden_biases = jnp.zeros(n_hidden)
init_biases = jnp.concatenate([visible_biases, hidden_biases])

model = IsingModel(weights=init_weights, biases=init_biases)
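The 0.5 factor in the visible-bias initialization comes from the ±1 spin convention. Assuming the sampler uses the standard single-spin conditional P(s = +1) = σ(2b) (the usual form for a spin with energy contribution −b·s), matching the empirical pixel mean p gives b = (1/2) · log(p / (1 − p)), i.e. half the logit computed above; eps guards against pixels that are always off or always on.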
@jax.jit
def data_to_spins(data):
    return 2 * data - 1


@jax.jit
def spins_to_data(spins):
    return (spins + 1) // 2


sampler = IsingSampler()

# Positive phase: only sample hidden (block 1), visible clamped
pos_sampling_args = SamplingArgs(
    gibbs_steps=200,
    blocks_to_sample=[1],
    data=params,
)

# Negative phase: sample both blocks freely
neg_sampling_args = SamplingArgs(
    gibbs_steps=400,
    blocks_to_sample=[0, 1],
    data=params,
)
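The loss defined next is the standard contrastive-divergence surrogate. Assuming model.energy implements the usual Ising energy E(s) = −Σ_{(i,j)} w_ij · s_i · s_j − Σ_i b_i · s_i (consistent with how it is called on edge_indices and edge_mask below), the gradient of mean(E_pos) − mean(E_neg) with respect to w_ij is ⟨s_i s_j⟩_neg − ⟨s_i s_j⟩_pos, so a descent step raises the correlations observed under clamped data and lowers those of the free-running model, which is exactly the CD-k update.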
def cd_loss(
    weights,
    biases,
    data_batch,
    key,
):
    """Contrastive-divergence loss: mean positive-phase (clamped) energy
    minus mean negative-phase (free-running) energy."""
    model = IsingModel(weights=weights, biases=biases)
    batch_size = data_batch.shape[0]

    visible_spins = data_to_spins(data_batch)

    key, subkey = jax.random.split(key)
    hidden_init = jax.random.choice(subkey, jnp.array([-1, 1]), (batch_size, n_hidden))

    def sample_positive(visible, hidden, k):
        states = sample_chain(
            [visible, hidden],
            [sampler, sampler],
            model,
            pos_sampling_args,
            jax.random.fold_in(k, 0),
        )
        return states[0][-1], states[1][-1]

    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, batch_size)
    v_pos, h_pos = jax.vmap(sample_positive)(visible_spins, hidden_init, keys)

    def sample_negative(visible, hidden, k):
        states = sample_chain(
            [visible, hidden],
            [sampler, sampler],
            model,
            neg_sampling_args,
            jax.random.fold_in(k, 1),
        )
        return states[0][-1], states[1][-1]

    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, batch_size)
    v_neg, h_neg = jax.vmap(sample_negative)(v_pos, h_pos, keys)

    def compute_energy(v, h):
        state = jnp.concatenate([v, h])
        return model.energy(state, edge_indices, edge_mask)

    e_pos = jax.vmap(compute_energy)(v_pos, h_pos)
    e_neg = jax.vmap(compute_energy)(v_neg, h_neg)

    return jnp.mean(e_pos) - jnp.mean(e_neg)


@eqx.filter_jit
def cd_step(weights, biases, opt_state, optimizer, data_batch, key):
    loss, (grad_w, grad_b) = jax.value_and_grad(cd_loss, argnums=(0, 1))(
        weights, biases, data_batch, key
    )

    # One Adam instance, with separate states for the weight and bias leaves
    updates_w, opt_state_w = optimizer.update(grad_w, opt_state[0], weights)
    updates_b, opt_state_b = optimizer.update(grad_b, opt_state[1], biases)

    new_weights = optax.apply_updates(weights, updates_w)
    new_biases = optax.apply_updates(biases, updates_b)

    return new_weights, new_biases, (opt_state_w, opt_state_b), loss
batch_size = 64
n_epochs = 12
learning_rate = 0.005

optimizer = optax.adam(learning_rate=learning_rate)
opt_state = (optimizer.init(model.weights), optimizer.init(model.biases))

weights = model.weights
biases = model.biases
losses = []

n_batches = len(train_images) // batch_size

for epoch in range(n_epochs):
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, len(train_images))
    train_images_shuffled = train_images[perm]

    epoch_loss = 0.0
    for batch_idx in range(n_batches):
        batch = train_images_shuffled[
            batch_idx * batch_size : (batch_idx + 1) * batch_size
        ]

        key, subkey = jax.random.split(key)
        weights, biases, opt_state, loss = cd_step(
            weights, biases, opt_state, optimizer, batch, subkey
        )
        epoch_loss += loss

    avg_loss = epoch_loss / n_batches
    losses.append(float(avg_loss))
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")
Epoch 1/12, Loss: -12.9094
Epoch 2/12, Loss: -13.1847
Epoch 3/12, Loss: -9.7431
Epoch 4/12, Loss: -7.0727
Epoch 5/12, Loss: -5.5887
Epoch 6/12, Loss: -4.1621
Epoch 7/12, Loss: -2.7466
Epoch 8/12, Loss: -2.2553
Epoch 9/12, Loss: -1.7346
Epoch 10/12, Loss: -1.2224
Epoch 11/12, Loss: -1.0403
Epoch 12/12, Loss: -0.3680
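The per-epoch averages are collected in losses but never visualized; a minimal sketch using only variables defined above:

plt.figure(figsize=(6, 3))
plt.plot(range(1, n_epochs + 1), losses, marker="o")
plt.xlabel("Epoch")
plt.ylabel("Mean CD loss")
plt.tight_layout()
plt.show()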
gen_sampling_args = SamplingArgs(
    gibbs_steps=400,
    blocks_to_sample=[0, 1],
    data=params,
)

trained_model = IsingModel(weights=weights, biases=biases)
sample_fn = eqx.filter_jit(sample_chain)


@eqx.filter_jit
@eqx.filter_vmap
def generate_sample(key):
    k1, k2, k_run = jax.random.split(key, 3)
    init_visible = jax.random.choice(k1, jnp.array([-1, 1]), (n_visible,))
    init_hidden = jax.random.choice(k2, jnp.array([-1, 1]), (n_hidden,))

    samples = sample_fn(
        [init_visible, init_hidden],
        [sampler, sampler],
        trained_model,
        gen_sampling_args,
        k_run,
    )
    return samples[0][-1]
n_gen = 16
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_gen)
generated_spins = generate_sample(keys)
generated = spins_to_data(generated_spins)

n_cols, n_rows = 4, 4
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i].reshape(28, 28), cmap="gray")
    ax.axis("off")
plt.suptitle("Generated Samples", fontsize=14)
plt.tight_layout()
plt.show()

Figure: a 4 × 4 grid of digits generated from random initial spins.

half_pixels = 14 * 28  # Top 14 rows
top_visible_nodes = visible_nodes[:half_pixels]
bottom_visible_nodes = visible_nodes[half_pixels:]

clamped_blocks = [top_visible_nodes, bottom_visible_nodes, hidden_nodes]
clamped_graph = BlockGraph(clamped_blocks, edges)
clamped_params = clamped_graph.get_sampling_params()

# Only sample bottom visible (block 1) and hidden (block 2), clamp top (block 0)
clamped_sampling_args = SamplingArgs(
    gibbs_steps=500,
    blocks_to_sample=[1, 2],
    data=clamped_params,
)
n_test = 8
test_batch = test_images[:n_test]


@jax.jit
@jax.vmap
def complete_image(visible_data, key):
    k1, k2, k_run = jax.random.split(key, 3)

    visible_spins = data_to_spins(visible_data).astype(jnp.int32)
    top_spins = visible_spins[:half_pixels]

    bottom_init = jax.random.choice(k1, jnp.array([-1, 1]), (n_visible - half_pixels,))
    hidden_init = jax.random.choice(k2, jnp.array([-1, 1]), (n_hidden,))

    states = sample_fn(
        [top_spins, bottom_init, hidden_init],
        [sampler, sampler, sampler],
        trained_model,
        clamped_sampling_args,
        k_run,
    )

    full_spins = jnp.concatenate([states[0][-1], states[1][-1]])
    return spins_to_data(full_spins)


key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_test)
completed = complete_image(test_batch, keys)

fig, axes = plt.subplots(3, n_test, figsize=(16, 6))
row_titles = ["Original", "Input (top half)", "Completed"]
for i in range(n_test):
    axes[0, i].imshow(test_batch[i].reshape(28, 28), cmap="gray")

    # Cast to float before masking: setting 0.5 on the int array would truncate to 0
    masked = test_batch[i].reshape(28, 28).astype(jnp.float32)
    masked = masked.at[14:, :].set(0.5)
    axes[1, i].imshow(masked, cmap="gray", vmin=0, vmax=1)
    axes[1, i].axhline(14, color="red", linewidth=2)

    axes[2, i].imshow(completed[i].reshape(28, 28), cmap="gray")
    axes[2, i].axhline(14, color="red", linewidth=2)

    for row in range(3):
        axes[row, i].axis("off")

# axis("off") hides ylabels, so annotate the rows with text labels instead
for row, title in enumerate(row_titles):
    axes[row, 0].text(
        -0.15, 0.5, title, transform=axes[row, 0].transAxes,
        rotation=90, va="center", fontsize=12,
    )

plt.tight_layout()
plt.show()

Figure: original digits (top row), masked inputs showing only the top half (middle row), and model completions (bottom row); the red line marks the clamp boundary.
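Because block 0 stays clamped (blocks_to_sample=[1, 2] leaves it untouched, per the comment above), the top halves of the completions should match the binarized inputs exactly; a quick sanity check:

print(
    "Top halves preserved:",
    bool(jnp.all(completed[:, :half_pixels] == test_batch[:, :half_pixels])),
)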

@jax.jit
@jax.vmap
def compute_data_energy(visible_data, key):
    k_init, k_run = jax.random.split(key)  # don't reuse one key for both init and chain
    visible_spins = data_to_spins(visible_data)
    hidden = jax.random.choice(k_init, jnp.array([-1, 1]), (n_hidden,))

    states = sample_fn(
        [visible_spins, hidden],
        [sampler, sampler],
        trained_model,
        pos_sampling_args,
        k_run,
    )

    state = jnp.concatenate([states[0][-1], states[1][-1]])
    return trained_model.energy(state, edge_indices, edge_mask)


n = 1000
key, subkey = jax.random.split(key)
data_keys = jax.random.split(subkey, n)
data_energies = compute_data_energy(train_images[:n], data_keys)

key, subkey = jax.random.split(key)
k_noise, k_chains = jax.random.split(subkey)  # separate keys: one for the noise images, one per chain
noise_keys = jax.random.split(k_chains, n)
random_data = jax.random.bernoulli(k_noise, 0.5, (n, n_visible)).astype(jnp.int32)
noise_energies = compute_data_energy(random_data, noise_keys)

print(f"Data energy: {jnp.mean(data_energies):.2f} +/- {jnp.std(data_energies):.2f}")
print(
    f"Random energy: {jnp.mean(noise_energies):.2f} +/- {jnp.std(noise_energies):.2f}"
)

plt.figure(figsize=(8, 4))
plt.hist(data_energies, bins=30, alpha=0.7, label="Data", density=True)
plt.hist(noise_energies, bins=30, alpha=0.7, label="Random", density=True)
plt.xlabel("Energy")
plt.ylabel("Density")
plt.legend()
plt.show()
Data energy: -3175.77 +/- 51.00
Random energy: -332.40 +/- 123.59

Figure: energy histograms for training data versus random binary inputs.
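Given the printed statistics, the two distributions are far apart; a crude midpoint threshold on the energies computed above quantifies the separation:

threshold = (jnp.mean(data_energies) + jnp.mean(noise_energies)) / 2
print(f"Data below threshold: {jnp.mean(data_energies < threshold):.3f}")
print(f"Noise above threshold: {jnp.mean(noise_energies > threshold):.3f}")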