Jumpax¤
Jumpax is a JAX-based library providing numerical jump process solvers.
It is heavily adapted from, and inspired by, JumpProcesses.jl and diffrax.
Installation¤
git clone https://github.com/lockwo/jumpax.git
cd jumpax
pip install -e .
Requires Python >= 3.10
Quick example¤
Simulate a simple birth-death process:
import jax
import jax.numpy as jnp
import jumpax as jpx
# Reactant stoichiometry (one row per reaction, one column per species):
# birth needs 0, death needs 1
reactants = jnp.array([[0], [1]])
# Net stoichiometry: birth adds 1, death removes 1
net_stoich = jnp.array([[1], [-1]])
# Rate constants, in the same order as the reactions: birth, death
rates = jnp.array([10.0, 0.1])
jumps = jpx.MassActionJump(reactants, net_stoich, rates=rates)
solver = jpx.SSA()
save = jpx.Save(states=True)
u0 = jnp.array([50.0]) # initial population
key = jax.random.key(0)
sol = jpx.solve(jumps, solver, save, u0, t0=0.0, t1=1.0, key=key)
# Keep only the saved entries whose time stamp is finite
mask = jnp.isfinite(sol.ts)
ts, us = sol.ts[mask], sol.us[mask]
print(f"Final population: {us[-1]}")
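For readers who want to see what a stochastic simulation of this birth-death process involves, here is a minimal standalone sketch of the direct (Gillespie) method in plain JAX. It does not use jumpax and is not a description of how `jpx.SSA` is implemented; the function name `birth_death_ssa` and its parameters are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def birth_death_ssa(key, u0, birth_rate, death_rate, t1):
    """Illustrative direct-method SSA; returns the population at time t1."""

    def cond(carry):
        t, _, _ = carry
        return t < t1

    def body(carry):
        t, u, key = carry
        key, k_time, k_choice = jax.random.split(key, 3)
        # Mass-action propensities: constant birth, death proportional to u.
        birth = birth_rate
        death = death_rate * u
        total = birth + death
        # The waiting time to the next event is exponential with rate `total`.
        dt = jax.random.exponential(k_time) / total
        # Pick birth with probability birth / total, otherwise death.
        is_birth = jax.random.uniform(k_choice) * total < birth
        du = jnp.where(is_birth, 1.0, -1.0)
        # Apply the jump only if it occurs before t1; otherwise stop at t1.
        fires = t + dt <= t1
        u_next = jnp.where(fires, u + du, u)
        t_next = jnp.where(fires, t + dt, t1)
        return t_next, u_next, key

    init = (jnp.asarray(0.0, dtype=float), jnp.asarray(u0, dtype=float), key)
    _, u_final, _ = jax.lax.while_loop(cond, body, init)
    return u_final


# Average over independent trajectories with the same parameters as above.
keys = jax.random.split(jax.random.key(0), 256)
finals = jax.vmap(lambda k: birth_death_ssa(k, 50.0, 10.0, 0.1, 1.0))(keys)
print(f"Mean final population: {finals.mean()}")
```

For these parameters the mean population obeys dm/dt = 10 - 0.1 m, so the sample mean at t = 1 should land near 100 - 50 exp(-0.1), roughly 54.8, up to sampling noise.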
Citation¤
If you found this library useful in academic research, please cite:
(Also consider starring the project on GitHub.)
See also: other libraries in the JAX ecosystem¤
Awesome JAX: a longer list of other JAX projects.