Flips Per Second
In this example we will look at Flips Per Second (FPS) as a metric for evaluating the speed of a Gibbs sampling program. One flip is a single update of a single spin, so the total flip count of a run is the number of spins times the number of Gibbs sweeps times the number of parallel chains.
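As a quick worked sketch of the metric, using the CPU, batch-size-64 numbers reported further down:

# Worked example of the FPS metric, with numbers taken from the run below.
n_spins = 4264          # spins in the Pegasus graph
chain_len = 1000        # Gibbs sweeps per chain
batch_size = 64         # chains sampled in parallel
wall_time_s = 2.5344    # measured wall-clock time (CPU, batch size 64)

total_flips = chain_len * batch_size * n_spins   # 272,896,000 flips
print(total_flips / (wall_time_s * 1e9))         # ~0.1077 flips per ns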
import time
import dwave_networkx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
from isax import (
BlockGraph,
Edge,
IsingModel,
IsingSampler,
Node,
sample_chain,
SamplingArgs,
)
def create_dwave_pegasus_graph(pegasus_size, key):
graph = dwave_networkx.pegasus_graph(pegasus_size)
coord_to_node = {coord: Node() for coord in graph.nodes}
nx.relabel_nodes(graph, coord_to_node, copy=False)
nodes = list(graph.nodes)
edges = [Edge(u, v) for u, v in graph.edges()]
    # Greedy-color the graph; each color class becomes one block of spins
    # that can be updated simultaneously during Gibbs sampling.
coloring = nx.coloring.greedy_color(graph, strategy="DSATUR")
n_colors = max(coloring.values()) + 1
blocks = [[] for _ in range(n_colors)]
for node in graph.nodes:
blocks[coloring[node]].append(node)
block_graph = BlockGraph(blocks, edges)
key1, key2 = jax.random.split(key)
biases = jax.random.uniform(key1, (len(nodes),), minval=-0.1, maxval=0.1)
weights = jax.random.uniform(key2, (len(edges),), minval=-0.1, maxval=0.1)
model = IsingModel(weights=weights, biases=biases)
    print(
        f"Created Pegasus graph with {len(nodes)} nodes, "
        f"{len(edges)} edges, {n_colors} color blocks"
    )
return model, block_graph, nodes, blocks
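Each block is one color class of a proper coloring, so no two spins in the same block share an edge; that independence is what lets a whole block be updated in a single parallel Gibbs step. A toy check of the coloring property, on a small cycle rather than the full Pegasus topology:

# Illustrative check: a proper coloring never gives both endpoints of an
# edge the same color, so each color class is an independent set.
import networkx as nx

g = nx.cycle_graph(5)
coloring = nx.coloring.greedy_color(g, strategy="DSATUR")
assert all(coloring[u] != coloring[v] for u, v in g.edges())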
def time_sampling(
model, block_graph, nodes, blocks, chain_len, batch_size, n_reps, device
):
key = jax.random.key(42)
(adjs, masks, edge_infos), eqx_graph = block_graph.get_sampling_params()
samplers = [IsingSampler() for _ in range(len(blocks))]
sampling_args = SamplingArgs(
gibbs_steps=chain_len,
blocks_to_sample=list(range(len(blocks))),
adjs=adjs,
masks=masks,
edge_info=edge_infos,
eqx_graph=eqx_graph,
)
def sample_batch(key):
keys = jax.random.split(key, batch_size)
def sample_single(single_key):
            k_init, k_run = jax.random.split(single_key)
            # One key per block: reusing a single key would give equal-sized
            # blocks identical random initial states.
            block_keys = jax.random.split(k_init, len(blocks))
            init_state = []
            for block, block_key in zip(blocks, block_keys):
                # Draw uniform ±1 spins for the chain's initial state.
                block_state = (
                    jax.random.bernoulli(block_key, 0.5, (len(block),)).astype(
                        jnp.int32
                    )
                    * 2
                    - 1
                )
                init_state.append(block_state)
samples = sample_chain(init_state, samplers, model, sampling_args, k_run)
return samples
return jax.vmap(sample_single)(keys)
    # Pin the jitted computation to the requested device (CPU or GPU).
    jit_sample_batch = jax.jit(sample_batch, device=device)
keys = jax.random.split(key, n_reps)
    # The first call triggers JIT compilation, so time it separately.
    start_time = time.time()
    _ = jax.block_until_ready(jit_sample_batch(keys[0]))
    time_with_compile = time.time() - start_time

    # Subsequent calls reuse the compiled computation; average over n_reps.
    start_time = time.time()
    for i in range(n_reps):
        _ = jax.block_until_ready(jit_sample_batch(keys[i]))
    trials_time = time.time() - start_time
    time_without_compile = trials_time / n_reps

    # Every node flips once per Gibbs sweep, in every chain of the batch.
    thruput = chain_len * batch_size * len(nodes)
    flips_per_ns = thruput / (time_without_compile * 1e9)
print(f"chain_len: {chain_len}, batch_size: {batch_size}")
print(
f"Time with compile: {time_with_compile:.4f}s, "
f"time without compile: {time_without_compile:.4f}s, "
f"flips per ns: {flips_per_ns:.4f}, thruput: {thruput}"
)
return flips_per_ns
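The warm-up call above matters for two reasons: the first invocation of a jitted function triggers compilation, and JAX dispatches work asynchronously, so jax.block_until_ready is needed to make the timer cover the computation rather than just the dispatch. A minimal sketch of the same pattern in isolation:

# Minimal sketch of the warm-up-then-time pattern used in time_sampling.
import time

import jax
import jax.numpy as jnp

f = jax.jit(lambda a: a @ a)
x = jnp.ones((2000, 2000))

_ = jax.block_until_ready(f(x))  # first call: includes compilation
start = time.time()
_ = jax.block_until_ready(f(x))  # subsequent calls: compute only
print(f"{time.time() - start:.4f}s")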
pegasus_size = 14
chain_len = 1000
n_reps = 2
batch_sizes = [1, 4, 16, 64, 128, 256, 1024]
try:
    device_gpu = jax.devices("cuda")[0]
    has_gpu = True
except RuntimeError:
    # jax.devices raises RuntimeError when the requested backend is absent.
    has_gpu = False
device_cpu = jax.devices("cpu")[0]
# Create the graph once and reuse it
key = jax.random.key(42)
model, block_graph, nodes, blocks = create_dwave_pegasus_graph(pegasus_size, key)
flips_per_ns_cpu = []
for batch_size in batch_sizes:
val = time_sampling(
model, block_graph, nodes, blocks, chain_len, batch_size, n_reps, device_cpu
)
flips_per_ns_cpu.append(val)
if has_gpu:
flips_per_ns_gpu = []
for batch_size in batch_sizes:
val = time_sampling(
model, block_graph, nodes, blocks, chain_len, batch_size, n_reps, device_gpu
)
flips_per_ns_gpu.append(val)
Created Pegasus graph with 4264 nodes, 30404 edges, 4 color blocks
chain_len: 1000, batch_size: 1
Time with compile: 0.5638s, time without compile: 0.2302s, flips per ns: 0.0185, thruput: 4264000
chain_len: 1000, batch_size: 4
Time with compile: 0.9512s, time without compile: 0.5929s, flips per ns: 0.0288, thruput: 17056000
chain_len: 1000, batch_size: 16
Time with compile: 1.6824s, time without compile: 1.3140s, flips per ns: 0.0519, thruput: 68224000
chain_len: 1000, batch_size: 64
Time with compile: 2.9053s, time without compile: 2.5344s, flips per ns: 0.1077, thruput: 272896000
chain_len: 1000, batch_size: 128
Time with compile: 4.1339s, time without compile: 3.7625s, flips per ns: 0.1451, thruput: 545792000
chain_len: 1000, batch_size: 256
Time with compile: 5.6201s, time without compile: 5.2584s, flips per ns: 0.2076, thruput: 1091584000
chain_len: 1000, batch_size: 1024
Time with compile: 15.3236s, time without compile: 15.0730s, flips per ns: 0.2897, thruput: 4366336000
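The timings above are from the CPU run; the timings below repeat the same sweep of batch sizes on the GPU.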
chain_len: 1000, batch_size: 1
Time with compile: 1.2708s, time without compile: 0.0285s, flips per ns: 0.1498, thruput: 4264000
chain_len: 1000, batch_size: 4
Time with compile: 1.2622s, time without compile: 0.0339s, flips per ns: 0.5034, thruput: 17056000
chain_len: 1000, batch_size: 16
Time with compile: 1.4971s, time without compile: 0.0455s, flips per ns: 1.5007, thruput: 68224000
chain_len: 1000, batch_size: 64
Time with compile: 1.2715s, time without compile: 0.0641s, flips per ns: 4.2579, thruput: 272896000
chain_len: 1000, batch_size: 128
Time with compile: 1.2462s, time without compile: 0.0933s, flips per ns: 5.8480, thruput: 545792000
chain_len: 1000, batch_size: 256
Time with compile: 1.3404s, time without compile: 0.1517s, flips per ns: 7.1960, thruput: 1091584000
chain_len: 1000, batch_size: 1024
Time with compile: 1.7359s, time without compile: 0.4824s, flips per ns: 9.0515, thruput: 4366336000
plt.figure(figsize=(6, 6))
if has_gpu:
plt.plot(batch_sizes, flips_per_ns_gpu, label="isax GPU", marker="o")
plt.plot(batch_sizes, flips_per_ns_cpu, label="isax CPU", marker="s")
plt.legend()
plt.xlabel("Batch size")
plt.xscale("log")
plt.yscale("log")
plt.ylabel("Flips per ns")
plt.show()
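At the largest batch size the GPU sustains roughly 9.1 flips per ns against roughly 0.29 on the CPU, about a 30x gap. GPU wall-clock time also grows much more slowly than the batch size at small batches (0.0285 s at batch size 1 versus 0.0641 s at 64), since small batches leave the device underutilized, which is why flips per ns keeps climbing with batch size.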
