Jump Problems
Jump problems define the rates and effects of discrete jumps in the system.
Abstract Base Classes
jumpax.AbstractJumpProblem
Abstract class for general jump problems.
init(u: Shaped[Array, '?*u'], args: PyTree[Any], key: Key[Array, '']) -> PyTree[Any]
Initialize the state of the affect.
Arguments:
u: the current state of the system
args: any static arguments as passed to jumpax.solve
key: a random key to use
Returns:
The initialized state.
rate(t: Real[ArrayLike, ''], u: Shaped[Array, '?*u'], args: PyTree[Any]) -> ~_Rate
Compute the rate of each possible reaction.
Arguments:
t: current time
u: current state
args: static arguments
Returns:
The rates for each possible reaction.
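As an illustration of what rate returns, consider a hypothetical birth-death process with one species X and two reactions: birth (\(\emptyset \to X\)) at a constant rate k_birth, and death (\(X \to \emptyset\)) at rate k_death times the population. A rate function with the documented (t, u, args) signature then returns one entry per reaction. The function name and the (k_birth, k_death) layout of args are assumptions made for this sketch.

```python
import jax.numpy as jnp


def birth_death_rate(t, u, args):
    """Hypothetical rate function for a birth-death process.

    Reaction 0 (birth, 0 -> X) fires at the constant rate k_birth.
    Reaction 1 (death, X -> 0) fires at rate k_death * u[0].
    """
    k_birth, k_death = args
    # One entry per reaction; t is unused because these rates do not depend on time.
    return jnp.array([k_birth, k_death * u[0]])


rates = birth_death_rate(0.0, jnp.array([10.0]), (2.0, 0.5))
# rates == [2.0, 5.0]
```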
affect(t: Real[ArrayLike, ''], u: Shaped[Array, '?*u'], args: PyTree[Any], jump_state: PyTree[Any]) -> tuple
Compute the result of applying the affect.
Arguments:
t: the time of the affect
u: the current state
args: any static arguments
jump_state: the current state of the affect
Returns:
A tuple of the new state and the new affect state.
leap_delta(t: Real[ArrayLike, ''], u: Shaped[Array, '?*u'], args: PyTree[Any])
Compute the per-reaction net state change.
Arguments:
t: current time
u: current state
args: static arguments
Returns:
An array of shape (R, S) where R is the number of reactions and S is
the number of species, representing the net state change for each reaction.
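The (R, S) layout turns a leaping update into a matrix product: if reaction j fires n_j times during a step, the state changes by \(\sum_j n_j \Delta_{j\cdot}\), which is counts @ delta. The sketch below illustrates this for a hypothetical birth-death process (R = 2, S = 1) and pairs it with Poisson-distributed firing counts, as in classical tau-leaping. The helper names and the fixed step size tau are assumptions; this is not a description of jumpax's solver.

```python
import jax
import jax.numpy as jnp

# Net state change for a birth-death process: R = 2 reactions, S = 1 species.
# Row j is the change in each species when reaction j fires once.
delta = jnp.array([[1],    # birth: X -> X + 1
                   [-1]])  # death: X -> X - 1


def apply_counts(u, counts, delta):
    """Apply counts[j] firings of each reaction j: u + sum_j counts[j] * delta[j]."""
    return u + counts @ delta


def tau_leap_step(key, u, rates, delta, tau):
    """One hypothetical tau-leaping step: Poisson counts with mean rates * tau."""
    counts = jax.random.poisson(key, rates * tau)
    return apply_counts(u, counts, delta)


u = jnp.array([10])
print(apply_counts(u, jnp.array([3, 1]), delta))  # [12]
```

Plain tau-leaping like this can drive populations negative when tau is large relative to the rates; the array layout by itself does not prevent that.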
jumpax.AbstractAffect
Abstract base class for all affects.
init(u: Shaped[Array, '?*u'], args: PyTree[Any], key: Key[Array, '']) -> ~_JumpState
Initialize the state of the affect.
Arguments:
u: the current state of the system
args: any static arguments as passed to jumpax.solve
key: a random key to use
Returns:
The initialized state.
__call__(t: Real[ArrayLike, ''], u: Shaped[Array, '?*u'], args: PyTree[Any], jump_state: ~_JumpState) -> tuple
Compute the result of applying the affect.
Arguments:
t: the time of the affect
u: the current state
args: any static arguments
jump_state: the current state of the affect
Returns:
A tuple of the new state and the new affect state.
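Affect state lets an affect carry information between firings, for example a PRNG key so that each jump can have a random size. The hypothetical class below follows the documented init/__call__ protocol, but it is a plain Python class rather than a jumpax.AbstractAffect subclass, since the base class's construction requirements are not described here. Its jump state is a PRNG key: init stores the key it receives, and each call splits it, draws a burst size, and returns the new system state together with the advanced key.

```python
import jax
import jax.numpy as jnp


class RandomBurstAffect:
    """Hypothetical stateful affect: adds a random burst of molecules to species 0.

    Follows the documented init/__call__ protocol; the jump state is a PRNG key.
    """

    def __init__(self, max_burst):
        self.max_burst = max_burst

    def init(self, u, args, key):
        # The initial jump state is simply the key passed in.
        return key

    def __call__(self, t, u, args, jump_state):
        key, subkey = jax.random.split(jump_state)
        # Burst size is uniform on {1, ..., max_burst}.
        burst = jax.random.randint(subkey, (), 1, self.max_burst + 1)
        # Return the new system state and the advanced key as the new jump state.
        return u.at[0].add(burst), key


affect = RandomBurstAffect(max_burst=5)
u0 = jnp.array([0, 3])
state = affect.init(u0, None, jax.random.PRNGKey(0))
u1, state = affect(0.0, u0, None, state)
```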
Jump Problem Types
jumpax.ConstantRateJump(jumpax.AbstractJumpProblem)
A jump process with a rate \(\lambda(u)\) that depends only on the current state.
The rate function is evaluated only at jump times, making this suitable for processes where the rate does not depend on continuously evolving dynamics between jumps.
__init__(rate_fn: Callable, affect_fn: jumpax.AbstractAffect)
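Since \(\lambda(u)\) can only change when a jump changes u, the waiting time to the next jump is exponentially distributed with rate equal to the total rate, which is what Gillespie's direct method exploits. The sketch below performs one such step and evaluates the rates once per jump. It is illustrative only: gillespie_step and birth_death_rate are hypothetical names, the rate function is assumed to take (t, u, args) like AbstractJumpProblem.rate, and nothing here is claimed to be jumpax's internal algorithm.

```python
import jax
import jax.numpy as jnp


def gillespie_step(key, t, u, rate_fn, delta, args):
    """One step of Gillespie's direct method when rates depend only on the state.

    rate_fn(t, u, args) -> (R,) rates (t is ignored for a constant-rate jump);
    delta is the (R, S) net state change per reaction.
    """
    rates = rate_fn(t, u, args)  # evaluated once, at the current jump time
    total = jnp.sum(rates)
    key_time, key_pick = jax.random.split(key)
    # Rates cannot change until the next jump, so the waiting time is Exponential(total).
    dt = jax.random.exponential(key_time) / total
    # Pick reaction j with probability rates[j] / total.
    j = jax.random.categorical(key_pick, jnp.log(rates))
    return t + dt, u + delta[j]


def birth_death_rate(t, u, args):
    k_birth, k_death = args
    return jnp.array([k_birth, k_death * u[0]])


delta = jnp.array([[1], [-1]])
t1, u1 = gillespie_step(
    jax.random.PRNGKey(0), 0.0, jnp.array([10]), birth_death_rate, delta, (2.0, 0.5)
)
```

If every rate is zero, total is zero and dt becomes infinite, meaning no further jumps occur; a full simulator would need to handle that case explicitly.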
jumpax.MassActionJump(jumpax.AbstractJumpProblem)
Array-based mass-action reaction system.
The propensity for reaction \(j\) is given by:
\[
\lambda_j(u) = \kappa_j \prod_{i=1}^{S} \binom{u_i}{r_{ji}},
\]
where \(\kappa_j\) is the rate constant, \(u_i\) is the population of species \(i\), and \(r_{ji}\) is the stoichiometric coefficient of species \(i\) in reaction \(j\). The binomial coefficient \(\binom{n}{k} = 0\) when \(k > n\).
__init__(reactants: Int[Array, 'R S'], net_stoich: Int[Array, 'R S'], *, rates: Float[Array, 'R'])
Arguments:
reactants: int array \((R, S)\) of nonnegative reactant stoichiometric coefficients \(r_{ji}\) for each reaction \(j\) and species \(i\).
net_stoich: int array \((R, S)\) with the net state change per reaction.
rates: float array \((R,)\) of rate constants \(\kappa_j\).
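The mass-action propensity, a rate constant times a product of binomial coefficients over species, vectorizes directly over the (R, S) reactants array. The function below is a sketch of that computation in JAX, not jumpax's implementation. mass_action_propensities is a hypothetical name, the binomial coefficients are evaluated through log-gamma so the whole computation stays array-based, and any reaction that needs more molecules of a species than are present gets propensity zero.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def mass_action_propensities(u, reactants, rates):
    """lambda_j(u) = kappa_j * prod_i binom(u_i, r_ji), vectorized over reactions.

    u: (S,) populations; reactants: (R, S) coefficients r_ji; rates: (R,) constants kappa_j.
    """
    n = jnp.asarray(u, jnp.float32)[None, :]  # (1, S), broadcast against every reaction
    k = jnp.asarray(reactants, jnp.float32)   # (R, S)
    feasible = n >= k                         # binom(n, k) = 0 whenever k > n
    # log binom(n, k) via log-gamma; the maximum() keeps gammaln finite on infeasible
    # entries, which the where() below masks out anyway.
    log_binom = gammaln(n + 1) - gammaln(k + 1) - gammaln(jnp.maximum(n - k, 0.0) + 1)
    binom = jnp.where(feasible, jnp.exp(log_binom), 0.0)
    return rates * jnp.prod(binom, axis=1)    # (R,)


# Species (A, B); reactions: A + B -> 0, 2A -> B, 0 -> A.
reactants = jnp.array([[1, 1],
                       [2, 0],
                       [0, 0]])
rates = jnp.array([0.5, 0.1, 2.0])
props = mass_action_propensities(jnp.array([3, 2]), reactants, rates)
# props is approximately [3.0, 0.3, 2.0]
```

Going through gammaln keeps everything in floating point, so the results match the exact integer products only up to rounding.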
jumpax.VariableRateJump(jumpax.AbstractJumpProblem)
A jump process with a rate \(\lambda(t, u)\) that depends on time and state.
Unlike jumpax.ConstantRateJump, the rate is evaluated continuously during
integration, making this suitable for processes coupled to ODEs or other
continuously evolving dynamics.
__init__(rate_fn: Callable, affect_fn: jumpax.AbstractAffect, leap_delta_fn: Callable | None = None)
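With a time-dependent rate, the waiting time to the next jump is no longer a single exponential draw. One common approach, sketched below under simplifying assumptions, draws a threshold \(E \sim \mathrm{Exp}(1)\) and integrates the rate until the cumulative hazard reaches \(E\). The rate \(\lambda(t, u) = k (1 + \sin t)\, u\) is hypothetical, the integration uses fixed explicit Euler steps, and u is held fixed between jumps; in a coupled system the same loop would also advance u with the ODE. This illustrates the technique and does not describe how jumpax integrates a VariableRateJump.

```python
import jax
import jax.numpy as jnp


def rate_fn(t, u, args):
    # Hypothetical rate lambda(t, u) = k * (1 + sin t) * u.
    k = args
    return k * (1.0 + jnp.sin(t)) * u


def next_jump_time(key, t0, u, args, dt=1e-3, t_max=100.0):
    """Integrate the rate until the cumulative hazard reaches an Exp(1) threshold."""
    threshold = jax.random.exponential(key)

    def cond(carry):
        t, hazard = carry
        return (hazard < threshold) & (t < t_max)

    def body(carry):
        t, hazard = carry
        # Explicit Euler step of d(hazard)/dt = lambda(t, u); u is held fixed here.
        return t + dt, hazard + dt * rate_fn(t, u, args)

    init = (jnp.asarray(t0, jnp.float32), jnp.asarray(0.0, jnp.float32))
    t_jump, _ = jax.lax.while_loop(cond, body, init)
    return t_jump


t_jump = next_jump_time(jax.random.PRNGKey(0), 0.0, 5.0, 0.2)
```

For this particular rate the cumulative hazard has the closed form \(k u (t + 1 - \cos t)\), which gives a simple check that the loop stops where the hazard reaches the threshold.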
jumpax.StatelessAffect(jumpax.AbstractAffect)
A convenience wrapper for the common use case of stateless affect functions.
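Conceptually, a stateless affect only needs a function from (t, u, args) to the new state, and the wrapper supplies the jump-state plumbing that the AbstractAffect protocol requires. The class below is a hypothetical sketch of that idea. It assumes the wrapped function has the signature fn(t, u, args) and that the jump state is None; the real StatelessAffect may differ in both respects.

```python
import jax.numpy as jnp


class StatelessAffectSketch:
    """Hypothetical stand-in showing how a stateless affect can satisfy the protocol."""

    def __init__(self, fn):
        # fn(t, u, args) -> new u (assumed signature)
        self.fn = fn

    def init(self, u, args, key):
        # Nothing to carry between jumps, so the jump state is None.
        return None

    def __call__(self, t, u, args, jump_state):
        # Apply the wrapped function and pass the (empty) jump state through unchanged.
        return self.fn(t, u, args), jump_state


increment = StatelessAffectSketch(lambda t, u, args: u.at[0].add(1))
state = increment.init(jnp.array([4, 2]), None, None)
new_u, new_state = increment(0.0, jnp.array([4, 2]), None, state)
# new_u == [5, 2], new_state is None
```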