Sampling
Models
isax.sample.IsingModel
Ising model with arbitrary edge interactions and local fields.
energy(state: Int[Array, 'num_nodes'], edge_indices: Int[Array, 'num_edges max_k'], edge_mask: Int[Array, 'num_edges max_k']) -> Float[Array, '']
Compute Ising energy.
Arguments:
state: Node spins, shape (num_nodes,).
edge_indices: Node indices for each edge, shape (num_edges, max_k).
edge_mask: Valid positions in edge_indices.
Returns:
Scalar energy.
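For reference, the energy of a model whose edges (hyperedges of up to max_k nodes) are padded and masked can be sketched in JAX as below. This is a minimal sketch, not the isax implementation: it assumes spins in {-1, +1}, the sign convention \(E(s) = -\sum_e J_e \prod_{j \in e} s_j - \sum_i b_i s_i\), and an edge_mask that is 1 on valid slots and 0 on padding. The weights and biases are passed in explicitly here, whereas IsingModel.energy takes them from the model; ising_energy is a hypothetical name.

```python
import jax.numpy as jnp


def ising_energy(state, edge_indices, edge_mask, edge_weights, node_biases):
    """k-body Ising energy under an ASSUMED convention (not isax's code).

    state:        (num_nodes,) spins in {-1, +1}
    edge_indices: (num_edges, max_k) node indices, padded
    edge_mask:    (num_edges, max_k) 1 for valid positions, 0 for padding
    edge_weights: (num_edges,) couplings J_e
    node_biases:  (num_nodes,) local fields b_i
    """
    spins = state[edge_indices]  # (num_edges, max_k)
    # Padded slots contribute a factor of 1, so they do not change the product
    spins = jnp.where(edge_mask.astype(bool), spins, 1)
    edge_terms = edge_weights * jnp.prod(spins, axis=-1)
    # Assumed sign convention: E = -sum_e J_e prod_e s - sum_i b_i s_i
    return -jnp.sum(edge_terms) - jnp.sum(node_biases * state)


state = jnp.array([1, -1, 1])
edge_indices = jnp.array([[0, 1, 0], [0, 1, 2]])  # edge 0 is padded to max_k = 3
edge_mask = jnp.array([[1, 1, 0], [1, 1, 1]])
energy = ising_energy(state, edge_indices, edge_mask,
                      jnp.array([1.0, 0.5]), jnp.array([0.1, 0.0, -0.2]))
# energy ≈ 1.6
```

Masking by substituting 1 (rather than 0) keeps padded slots neutral inside the product, so edges of different sizes can share one rectangular index array.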
to_sample_params(graph: EqxGraph, edge_info: list[jaxtyping.Int[Array, 'nodes max_edges']]) -> list[tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']]]
Extract weights and biases for each block.
Arguments:
graph: Graph structure.
edge_info: Edge indices per block.
Returns:
List of (edge_weights, node_biases) tuples.
Samplers
isax.sample.AbstractSampler
Base class for block-wise sampling algorithms.
isax.sample.IsingSampler
Gibbs sampler for the Ising model at a fixed inverse temperature \(\beta = 1\).
sample(current_state: Int[Array, 'num_nodes'], neighbor_states: Int[Array, 'num_nodes max_edges max_k-1'], neighbor_mask: Int[Array, 'num_nodes max_edges max_k-1'], model_params: tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']], runtime_params: None, sampler_state: None, key: Key[Array, '']) -> tuple[jaxtyping.Int[Array, 'num_nodes'], None]
Parallel Gibbs sampling for all nodes in block.
For each node \(i\), the effective field is computed as:
The probability of spin \(s_i = +1\) is then:
where \(\sigma\) is the sigmoid function.
Arguments:
current_state: Current spin values for nodes in this block.
neighbor_states: Spin values of neighbors from all edges.
neighbor_mask: Boolean mask for valid positions in neighbor_states.
model_params: Tuple of (edge_weights, node_biases) for this block.
runtime_params: None.
sampler_state: None.
key: JAX random key for stochastic sampling.
Returns:
Tuple of (new_block_state, None) where new_block_state contains the sampled
spin configuration.
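A parallel block update of this kind can be sketched in JAX as below. The sketch rests on assumptions that the signature alone does not pin down: spins in {-1, +1}; the energy convention \(E = -\sum_e J_e \prod_{j \in e} s_j - \sum_i b_i s_i\), which gives \(h_i = b_i + \sum_{e \ni i} J_e \prod_{j \in e \setminus \{i\}} s_j\) and \(P(s_i = +1) = \sigma(2\beta h_i)\); and a hypothetical per-node weight layout node_edge_weights of shape (num_nodes, max_edges) with zero weight on padded edge slots. The real sampler receives edge_weights of shape (num_edges,), and how it maps them onto neighbor_states is not described here. gibbs_block_update is an illustrative name, not part of isax.

```python
import jax
import jax.numpy as jnp


def gibbs_block_update(key, neighbor_states, neighbor_mask, node_edge_weights,
                       node_biases, beta=1.0):
    """One parallel Gibbs update for every node in a block (sketch).

    neighbor_states:   (num_nodes, max_edges, max_k - 1) spins in {-1, +1}
    neighbor_mask:     (num_nodes, max_edges, max_k - 1), 1 marks a valid slot
    node_edge_weights: (num_nodes, max_edges) HYPOTHETICAL per-node coupling
                       layout; padded edge slots must carry weight 0
    node_biases:       (num_nodes,) local fields b_i
    """
    # Padded neighbor slots contribute a factor of 1 to each edge product
    spins = jnp.where(neighbor_mask.astype(bool), neighbor_states, 1)
    edge_products = jnp.prod(spins, axis=-1)  # (num_nodes, max_edges)
    # h_i = b_i + sum_e J_e * prod_{j in e, j != i} s_j
    field = node_biases + jnp.sum(node_edge_weights * edge_products, axis=-1)
    # ASSUMED convention E = -sum_e J_e prod_e s - sum_i b_i s_i gives
    # P(s_i = +1) = sigmoid(2 * beta * h_i)
    p_up = jax.nn.sigmoid(2.0 * beta * field)
    up = jax.random.bernoulli(key, p_up)
    return jnp.where(up, 1, -1)


# One node with a 3-body edge (other members -1 and +1) and a padded edge slot
new_state = gibbs_block_update(
    jax.random.PRNGKey(0),
    neighbor_states=jnp.array([[[-1, 1], [1, 1]]]),
    neighbor_mask=jnp.array([[[1, 1], [0, 0]]]),
    node_edge_weights=jnp.array([[0.5, 0.0]]),
    node_biases=jnp.array([0.2]),
)
```

Every node in the block is updated from the same snapshot of its neighbors, so this is a valid Gibbs step only when no two nodes in the block share an edge; that is why sampling proceeds block by block. The beta argument lets the same sketch serve an annealed sampler; for a fixed-temperature sampler it stays at 1.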
initialize_state() -> None
Return None, since this sampler is stateless.
isax.sample.AnnealedIsingSampler
Simulated annealing sampler with time-varying temperature.
__init__(beta_fn: typing.Callable[[jaxtyping.Int[Array, '']], jaxtyping.Float[Array, '']])
Arguments:
beta_fn: Function mapping timestep \(t\) to inverse temperature \(\beta(t)\). For more complex schedules, it can also be a closure over (i.e., curry) a precomputed array of \(\beta\) values.
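Two ways to build such a beta_fn are sketched below: a closed-form ramp, and a closure (curry) over a precomputed array. Both schedules are hypothetical; the names linear_beta and make_array_schedule and the numeric values are made up for illustration. Both use jnp operations so they also work when \(t\) is a traced integer array inside jit or scan.

```python
import jax.numpy as jnp


def linear_beta(t):
    """HYPOTHETICAL schedule: beta ramps linearly from 0.1 to 2.0 over 100 steps."""
    frac = jnp.clip(t / 100.0, 0.0, 1.0)
    return 0.1 + (2.0 - 0.1) * frac


def make_array_schedule(betas):
    """Curry a precomputed beta array into a beta_fn(t) closure."""
    betas = jnp.asarray(betas, dtype=jnp.float32)

    def beta_fn(t):
        # Clamp the index so the schedule holds its last value past the end
        return betas[jnp.minimum(t, betas.shape[0] - 1)]

    return beta_fn


stepped = make_array_schedule([0.5, 1.0, 2.0])
# stepped(jnp.array(1)) -> 1.0 ; stepped(jnp.array(10)) -> 2.0
```

Either function could then be passed as AnnealedIsingSampler(beta_fn=...). Clamping the index is a design choice for this sketch; without it, out-of-range indexing in JAX is clamped silently anyway, but making it explicit documents the intended behavior.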
sample(current_state: Int[Array, 'num_nodes'], neighbor_states: Int[Array, 'num_nodes max_edges max_k-1'], neighbor_mask: Int[Array, 'num_nodes max_edges max_k-1'], model_params: tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']], runtime_params: None, sampler_state: Int[Array, ''], key: Key[Array, '']) -> tuple[jaxtyping.Int[Array, 'num_nodes'], jaxtyping.Int[Array, '']]
Sample at the inverse temperature \(\beta(t)\) given by the current timestep \(t\).
Arguments:
current_state: Current spin values for nodes in this block.
neighbor_states: Spin values of neighbors from all edges.
neighbor_mask: Boolean mask for valid positions in neighbor_states.
model_params: Tuple of (edge_weights, node_biases) for this block.
runtime_params: Unused, provided for API compatibility.
sampler_state: Current timestep (integer scalar).
key: JAX random key for stochastic sampling.
Returns:
Tuple of (new_block_state, new_sampler_state) where new_sampler_state is
incremented by 1.
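The role of sampler_state can be sketched as a counter threaded through jax.lax.scan: each step reads \(\beta(t)\), samples, and returns \(t + 1\), starting from the 0 that initialize_state provides. This is a simplified, hypothetical sketch (annealed_step and run_annealing are illustrative names): the effective field is held fixed to isolate the timestep handling, and it assumes spins in {-1, +1} with the update rule \(P(s_i = +1) = \sigma(2\beta h_i)\).

```python
import jax
import jax.numpy as jnp


def annealed_step(beta_fn, field, t, key):
    """One annealed update (sketch): beta is read from the timestep state t.

    field: (num_nodes,) effective fields h_i, held fixed in this sketch
    t:     integer scalar sampler state (the current timestep)
    """
    beta = beta_fn(t)
    # ASSUMED rule for spins in {-1, +1}: P(s_i = +1) = sigmoid(2 beta h_i)
    p_up = jax.nn.sigmoid(2.0 * beta * field)
    spins = jnp.where(jax.random.bernoulli(key, p_up), 1, -1)
    return spins, t + 1  # new sampler state: timestep incremented by 1


def run_annealing(beta_fn, field, num_steps, key):
    """Thread the timestep through lax.scan, starting from t = 0."""

    def body(t, step_key):
        spins, t_next = annealed_step(beta_fn, field, t, step_key)
        return t_next, spins

    t0 = jnp.zeros((), dtype=jnp.int32)
    keys = jax.random.split(key, num_steps)
    t_final, history = jax.lax.scan(body, t0, keys)
    return t_final, history  # history: (num_steps, num_nodes)


t_final, history = run_annealing(lambda t: 1.0 + 10.0 * t,
                                 jnp.array([5.0, -5.0]), 10,
                                 jax.random.PRNGKey(0))
# t_final is 10; at the last step beta = 91 and the spins are [1, -1]
```

Keeping the timestep in the returned state, rather than in a Python loop variable, is what lets the whole schedule run inside a single compiled scan.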
initialize_state() -> Int[Array, '']
Initialize the timestep to 0.
Sampling Utilities
isax.sample.SamplingArgs
Configuration and data for block-wise Gibbs sampling.
__init__(gibbs_steps: int, blocks_to_sample: list[int], data: tuple[tuple, isax.block.EqxGraph]) -> None
Arguments:
gibbs_steps: Number of Gibbs sampling iterations to perform.
blocks_to_sample: List of block indices to update (e.g., [0, 2] updates blocks 0 and 2).
data: Output from BlockGraph.get_sampling_params() containing adjacency information and graph structure.
isax.sample.sample_chain(block_states: list[jaxtyping.Int[Array, 'block_size']], samplers: list[isax.sample.AbstractSampler], model: IsingModel, sampling_args: SamplingArgs, key: Key[Array, '']) -> list[jaxtyping.Int[Array, 'gibbs_steps block_size']]
Run a Gibbs sampling chain for the specified number of steps.
Arguments:
block_states: Initial spin states for each block.
samplers: List of sampler instances, one per block.
model: Energy model providing weights and biases.
sampling_args: Configuration with the number of steps and blocks to sample.
key: JAX random key.
Returns:
List of arrays containing the history of block states over all Gibbs steps.
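The shape of such a chain can be illustrated with a small, self-contained toy. This is a sketch under stated assumptions, not isax.sample.sample_chain: it uses a 1D ring with uniform pairwise coupling, spins in {-1, +1}, the update rule \(P(s_i = +1) = \sigma(2\beta h_i)\), and an even/odd (checkerboard) blocking so that nodes inside a block share no edge. ring_field and toy_sample_chain are hypothetical names. Like the documented function, it returns one (gibbs_steps, block_size) history array per block.

```python
import jax
import jax.numpy as jnp


def ring_field(spins, nodes, coupling, bias):
    """Effective field on `nodes` of a 1D ring with uniform pairwise coupling."""
    n = spins.shape[0]
    left = spins[(nodes - 1) % n]
    right = spins[(nodes + 1) % n]
    return bias + coupling * (left + right)


def toy_sample_chain(spins, blocks, blocks_to_sample, gibbs_steps, key,
                     coupling=1.0, bias=0.0, beta=1.0):
    """HYPOTHETICAL block Gibbs chain on a ring (not isax.sample.sample_chain).

    spins:  (n,) initial spins in {-1, +1}
    blocks: list of node-index arrays; nodes in the same block share no edge
    Returns one (gibbs_steps, block_size) history array per block.
    """

    def sweep(spins, step_key):
        block_keys = jax.random.split(step_key, len(blocks_to_sample))
        for b, k in zip(blocks_to_sample, block_keys):
            nodes = blocks[b]
            h = ring_field(spins, nodes, coupling, bias)
            # ASSUMED rule for spins in {-1, +1}: P(s_i = +1) = sigmoid(2 beta h_i)
            up = jax.random.bernoulli(k, jax.nn.sigmoid(2.0 * beta * h))
            spins = spins.at[nodes].set(jnp.where(up, 1, -1))
        # Record every block's state after the sweep
        return spins, [spins[nodes] for nodes in blocks]

    keys = jax.random.split(key, gibbs_steps)
    _, history = jax.lax.scan(sweep, spins, keys)
    return history


n = 6
blocks = [jnp.arange(0, n, 2), jnp.arange(1, n, 2)]  # checkerboard: even / odd
spins0 = jnp.array([1, -1, 1, -1, 1, -1])
history = toy_sample_chain(spins0, blocks, [0, 1], 4, jax.random.PRNGKey(0))
# history[0].shape == (4, 3) and history[1].shape == (4, 3)
```

Blocks are updated sequentially within a sweep, so block 1 sees the freshly sampled spins of block 0. Recording every block after each sweep, including blocks not listed in blocks_to_sample, is a choice made for this sketch; unsampled blocks simply keep their initial values.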