
Sampling

Models

isax.sample.IsingModel

Ising model with arbitrary edge interactions and local fields.

energy(state: Int[Array, 'num_nodes'], edge_indices: Int[Array, 'num_edges max_k'], edge_mask: Int[Array, 'num_edges max_k']) -> Float[Array, '']

Compute the Ising energy of a spin configuration.

\[H = -\sum_{e \in E} J_e \prod_{i \in e} s_i - \sum_i h_i s_i\]

Arguments:

  • state: Node spins, shape (num_nodes,).
  • edge_indices: Node indices for each edge, shape (num_edges, max_k).
  • edge_mask: Mask marking the valid (non-padded) positions in edge_indices, shape (num_edges, max_k).

Returns:

Scalar energy.
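The following is a minimal plain-JAX sketch of the energy formula, not the isax implementation. It assumes that padded entries of edge_indices hold some valid node index and are switched off by edge_mask. The couplings J and fields h are passed explicitly to the hypothetical helper ising_energy. The method signature above does not list them, so in IsingModel they presumably live on the model itself.

```python
import jax.numpy as jnp


def ising_energy(state, edge_indices, edge_mask, J, h):
    """Hypothetical helper: H = -sum_e J_e prod_{i in e} s_i - sum_i h_i s_i.

    state:        (num_nodes,) spins in {-1, +1}
    edge_indices: (num_edges, max_k) node indices, padded up to max_k
    edge_mask:    (num_edges, max_k) 1 for real entries, 0 for padding
    J:            (num_edges,) edge weights (assumed explicit argument)
    h:            (num_nodes,) local fields (assumed explicit argument)
    """
    # Gather the spin of every node referenced by each edge.
    spins = state[edge_indices]  # (num_edges, max_k)
    # Padded slots become 1 so they leave the product unchanged.
    spins = jnp.where(edge_mask.astype(bool), spins, 1)
    edge_terms = J * jnp.prod(spins, axis=-1)  # (num_edges,)
    return -jnp.sum(edge_terms) - jnp.sum(h * state)


# A pairwise edge (0, 1) padded to max_k = 3, plus a 3-body edge (0, 1, 2).
state = jnp.array([1, -1, 1])
edge_indices = jnp.array([[0, 1, 0], [0, 1, 2]])
edge_mask = jnp.array([[1, 1, 0], [1, 1, 1]])
energy = ising_energy(state, edge_indices, edge_mask,
                      J=jnp.array([1.0, 0.5]), h=jnp.array([0.1, 0.0, -0.2]))
# energy ≈ 1.6
```

Masking by substituting 1 into the product (rather than 0) keeps hyperedges of different orders in one rectangular array.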

to_sample_params(graph: EqxGraph, edge_info: list[jaxtyping.Int[Array, 'nodes max_edges']]) -> list[tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']]]

Extract weights and biases for each block.

Arguments:

  • graph: Graph structure.
  • edge_info: Edge indices per block.

Returns:

List of (edge_weights, node_biases) tuples.

Samplers

isax.sample.AbstractSampler

Base class for block-wise sampling algorithms.

isax.sample.IsingSampler

Gibbs sampler for the Ising model at a fixed inverse temperature \(\beta = 1\).

sample(current_state: Int[Array, 'num_nodes'], neighbor_states: Int[Array, 'num_nodes max_edges max_k-1'], neighbor_mask: Int[Array, 'num_nodes max_edges max_k-1'], model_params: tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']], runtime_params: None, sampler_state: None, key: Key[Array, '']) -> tuple[jaxtyping.Int[Array, 'num_nodes'], None]

Parallel Gibbs sampling for all nodes in the block.

For each node \(i\), the effective field is computed as:

\[h_i^{\text{eff}} = h_i + \sum_{e \ni i} J_e \prod_{j \in e \setminus \{i\}} s_j\]

The probability of spin \(s_i = +1\) is then:

\[P(s_i = +1) = \sigma(2h_i^{\text{eff}}) = \frac{1}{1 + e^{-2h_i^{\text{eff}}}}\]

where \(\sigma\) is the sigmoid function.
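The factor of 2 follows from the Boltzmann distribution \(P(s) \propto e^{-H(s)}\) at \(\beta = 1\). Assuming each node appears at most once per edge, the only terms of \(H\) that involve \(s_i\) sum to \(-s_i h_i^{\text{eff}}\). Conditioning on all other spins \(s_{-i}\) therefore gives:

\[P(s_i = +1 \mid s_{-i}) = \frac{e^{h_i^{\text{eff}}}}{e^{h_i^{\text{eff}}} + e^{-h_i^{\text{eff}}}} = \frac{1}{1 + e^{-2h_i^{\text{eff}}}} = \sigma(2h_i^{\text{eff}})\]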

Arguments:

  • current_state: Current spin values for nodes in this block.
  • neighbor_states: Spin values of neighbors from all edges.
  • neighbor_mask: Boolean mask for valid positions in neighbor_states.
  • model_params: Tuple of (edge_weights, node_biases) for this block.
  • runtime_params: Unused; pass None.
  • sampler_state: Unused; pass None, since this sampler is stateless.
  • key: JAX random key for stochastic sampling.

Returns:

Tuple of (new_block_state, None) where new_block_state contains the sampled spin configuration.
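The update above can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the isax code. The hypothetical function gibbs_block_update takes edge weights already aligned with neighbor_states as shape (num_nodes, max_edges), with weight 0 on padded edge slots. This layout is an assumption, because the documented model_params carries a flat (num_edges,) array.

```python
import jax
import jax.numpy as jnp


def gibbs_block_update(neighbor_states, neighbor_mask, edge_weights, node_biases, key, beta=1.0):
    """Hypothetical sketch of one parallel Gibbs update for a block of nodes.

    neighbor_states: (n, max_edges, max_k - 1) spins of the other nodes in each edge
    neighbor_mask:   same shape, 1 for real entries and 0 for padding
    edge_weights:    (n, max_edges) J_e per (node, edge slot); assumed 0 for padded slots
    node_biases:     (n,) local fields h_i
    """
    # Padded neighbor slots contribute a factor of 1 to the product.
    spins = jnp.where(neighbor_mask.astype(bool), neighbor_states, 1)
    # h_eff_i = h_i + sum_e J_e * prod_{j in e \ {i}} s_j
    h_eff = node_biases + jnp.sum(edge_weights * jnp.prod(spins, axis=-1), axis=-1)
    # P(s_i = +1) = sigmoid(2 * beta * h_eff); beta = 1 matches IsingSampler.
    p_up = jax.nn.sigmoid(2.0 * beta * h_eff)
    # All nodes in the block draw their new spin simultaneously.
    up = jax.random.uniform(key, h_eff.shape) < p_up
    return jnp.where(up, 1, -1)
```

Leaving beta at 1.0 gives the fixed-temperature update described here. Passing a different beta gives the tempered version used by annealing.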

initialize_state() -> None

Return None, since this sampler is stateless.

isax.sample.AnnealedIsingSampler

Simulated annealing sampler with time-varying temperature.

__init__(beta_fn: typing.Callable[[jaxtyping.Int[Array, '']], jaxtyping.Float[Array, '']])

Arguments:

  • beta_fn: Function mapping timestep \(t\) to the inverse temperature \(\beta(t)\). For more complex schedules, it can close over (curry) a precomputed array of \(\beta\) values, e.g. lambda t: betas[t].
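Two example schedules are sketched below under assumed values. The names num_steps, beta_min, beta_max, linear_beta and geometric_beta are hypothetical. The first computes \(\beta(t)\) in closed form. The second closes over a precomputed array, as the argument description suggests.

```python
import jax.numpy as jnp

num_steps = 100
beta_min, beta_max = 0.1, 3.0


def linear_beta(t):
    # Linear ramp from beta_min to beta_max, clamped after the last step.
    frac = jnp.minimum(t, num_steps - 1) / (num_steps - 1)
    return beta_min + (beta_max - beta_min) * frac


# Geometric schedule precomputed once; the callable just indexes into it.
betas = jnp.geomspace(beta_min, beta_max, num_steps)


def geometric_beta(t):
    return betas[t]
```

Either callable could then be passed as AnnealedIsingSampler(beta_fn=geometric_beta). The array-backed form is convenient when the schedule is expensive to compute or comes from elsewhere.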

sample(current_state: Int[Array, 'num_nodes'], neighbor_states: Int[Array, 'num_nodes max_edges max_k-1'], neighbor_mask: Int[Array, 'num_nodes max_edges max_k-1'], model_params: tuple[jaxtyping.Float[Array, 'num_edges'], jaxtyping.Float[Array, 'num_nodes']], runtime_params: None, sampler_state: Int[Array, ''], key: Key[Array, '']) -> tuple[jaxtyping.Int[Array, 'num_nodes'], jaxtyping.Int[Array, '']]

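Run one parallel Gibbs update of the block at inverse temperature \(\beta(t)\), where \(t\) is the current timestep held in sampler_state. Compared with IsingSampler.sample, the effective field is scaled by \(\beta(t)\):

\[P(s_i = +1) = \sigma\big(2\beta(t)\, h_i^{\text{eff}}\big)\]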

Arguments:

  • current_state: Current spin values for nodes in this block.
  • neighbor_states: Spin values of neighbors from all edges.
  • neighbor_mask: Boolean mask for valid positions in neighbor_states.
  • model_params: Tuple of (edge_weights, node_biases) for this block.
  • runtime_params: Unused, provided for API compatibility.
  • sampler_state: Current timestep (integer scalar).
  • key: JAX random key for stochastic sampling.

Returns:

Tuple of (new_block_state, new_sampler_state) where new_sampler_state is incremented by 1.
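The sketch below shows how the timestep threads through repeated calls. It starts at 0, as initialize_state() does, and is incremented by 1 per update. To stay self-contained it holds h_eff fixed, which is a simplification. In the real sampler h_eff would be recomputed from the neighbors at each step. The functions annealed_step and run are hypothetical, not part of isax.

```python
import jax
import jax.numpy as jnp


def annealed_step(spins, h_eff, t, key, beta_fn):
    """Hypothetical sketch: one annealed Gibbs update with a fixed effective field.

    Returns (new_spins, t + 1), mirroring the (new_block_state, new_sampler_state) pair.
    """
    beta = beta_fn(t)
    p_up = jax.nn.sigmoid(2.0 * beta * h_eff)
    up = jax.random.uniform(key, spins.shape) < p_up
    return jnp.where(up, 1, -1).astype(spins.dtype), t + 1


def run(spins, h_eff, key, beta_fn, num_steps):
    # Carry (spins, t) through num_steps updates with lax.scan; t starts at 0.
    spins = jnp.asarray(spins, dtype=jnp.int32)

    def body(carry, step_key):
        spins, t = carry
        return annealed_step(spins, h_eff, t, step_key, beta_fn), None

    keys = jax.random.split(key, num_steps)
    (spins, t), _ = jax.lax.scan(body, (spins, jnp.array(0, dtype=jnp.int32)), keys)
    return spins, t
```

Keeping the timestep in the sampler state, rather than in a Python loop counter, lets the whole chain run inside lax.scan or jit.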

initialize_state() -> Int[Array, '']

Initialize the timestep to 0.

Sampling Utilities

isax.sample.SamplingArgs

Configuration and data for block-wise Gibbs sampling.

__init__(gibbs_steps: int, blocks_to_sample: list[int], data: tuple[tuple, isax.block.EqxGraph]) -> None

Arguments:

  • gibbs_steps: Number of Gibbs sampling iterations to perform.
  • blocks_to_sample: List of block indices to update (e.g., [0, 2] updates blocks 0 and 2).
  • data: Output from BlockGraph.get_sampling_params() containing adjacency information and graph structure.

isax.sample.sample_chain(block_states: list[jaxtyping.Int[Array, 'block_size']], samplers: list[isax.sample.AbstractSampler], model: IsingModel, sampling_args: SamplingArgs, key: Key[Array, '']) -> list[jaxtyping.Int[Array, 'gibbs_steps block_size']]

Run a block Gibbs sampling chain for the number of steps given by gibbs_steps in sampling_args, updating the blocks listed in blocks_to_sample.

Arguments:

  • block_states: Initial spin states for each block.
  • samplers: List of sampler instances, one per block.
  • model: Energy model providing weights and biases.
  • sampling_args: Configuration with number of steps and blocks to sample.
  • key: JAX random key.

Returns:

List of arrays containing the history of block states over all Gibbs steps.
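The block Gibbs scheme behind a chain like this can be sketched on a toy problem. This is a self-contained illustration, not how sample_chain is implemented. It uses a 1D Ising ring split into a hypothetical even/odd pair of blocks instead of isax's graph, model and sampler objects. No two nodes in the same block share an edge, so each block can be resampled in parallel given the other. The history for each block has shape (gibbs_steps, block_size), matching the return shape documented above. ring_block_chain and its parameters are assumptions.

```python
import jax
import jax.numpy as jnp


def ring_block_chain(key, num_nodes=8, gibbs_steps=5, J=1.0, h=0.0):
    """Hypothetical sketch: two-block (even/odd) Gibbs chain on a 1D Ising ring.

    Returns one history array of shape (gibbs_steps, num_nodes // 2) per block.
    """
    assert num_nodes % 2 == 0, "even/odd blocks need an even ring"
    blocks = [jnp.arange(0, num_nodes, 2), jnp.arange(1, num_nodes, 2)]
    key, init_key = jax.random.split(key)
    state = jax.random.choice(init_key, jnp.array([-1, 1]), (num_nodes,))

    def update_block(state, idx, k):
        # Effective field from the two ring neighbours of each node in the block.
        left = state[(idx - 1) % num_nodes]
        right = state[(idx + 1) % num_nodes]
        h_eff = h + J * (left + right)
        p_up = jax.nn.sigmoid(2.0 * h_eff)
        up = jax.random.uniform(k, idx.shape) < p_up
        return state.at[idx].set(jnp.where(up, 1, -1).astype(state.dtype))

    def step(state, k):
        k0, k1 = jax.random.split(k)
        state = update_block(state, blocks[0], k0)  # even nodes, given odd
        state = update_block(state, blocks[1], k1)  # odd nodes, given even
        # Record each block's state after the full sweep.
        return state, (state[blocks[0]], state[blocks[1]])

    _, (hist0, hist1) = jax.lax.scan(step, state, jax.random.split(key, gibbs_steps))
    return [hist0, hist1]
```

In this sketch both blocks are updated every step. A blocks_to_sample list would restrict which update_block calls run.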