Math
distreqx.utils.math.multiply_no_nan: None
distreqx.utils.math.power_no_nan: None
distreqx.utils.math.mul_exp(x: Array, logp: Array) -> Array
Returns x * exp(logp), with zero output wherever exp(logp) == 0.
Arguments:
- x: An array.
- logp: An array representing logarithms.
Returns:
- The result of x * exp(logp).
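The zero-output rule matters because x may be infinite at points where exp(logp) underflows to 0, and inf * 0 evaluates to NaN. The following is a minimal JAX sketch of that behaviour, not the library's source. The name mul_exp_sketch is hypothetical, and the sketch assumes that x is masked to zero wherever exp(logp) == 0.

```python
import jax.numpy as jnp


def mul_exp_sketch(x, logp):
    """Hypothetical re-implementation of x * exp(logp) with a zero-output rule."""
    p = jnp.exp(logp)
    # Assumption: mask x to 0 where p is exactly 0, so inf * 0 cannot yield NaN.
    x = jnp.where(p == 0, 0.0, x)
    return x * p


print(mul_exp_sketch(jnp.inf, -jnp.inf))  # 0.0 (masked product)
print(jnp.inf * jnp.exp(-jnp.inf))        # nan (naive product)
```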
distreqx.utils.math.normalize(*, probs: Optional[Array] = None, logits: Optional[Array] = None) -> Array
Normalizes either logits, via log_softmax, or probabilities, so that they sum to one.
Arguments:
- probs: Probability values.
- logits: Logit values.
Returns:
- Normalized probabilities, or normalized logits (that is, log-probabilities).
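The two normalization paths can be sketched as follows. This sketch assumes that normalization runs over the last axis and that probs is used only when logits is None. normalize_sketch is a hypothetical name. The library's actual handling of axes, and of calls that pass both or neither argument, may differ.

```python
import jax
import jax.numpy as jnp


def normalize_sketch(*, probs=None, logits=None):
    """Hypothetical stand-in for normalize; normalizes over the last axis."""
    if logits is None:
        # Probabilities: divide by their sum so they add up to one.
        probs = jnp.asarray(probs)
        return probs / probs.sum(axis=-1, keepdims=True)
    # Logits: log_softmax subtracts logsumexp, giving normalized log-probabilities.
    logits = jnp.asarray(logits)
    return jax.nn.log_softmax(logits, axis=-1)


print(normalize_sketch(probs=jnp.array([1.0, 3.0])))  # [0.25 0.75]
print(jnp.exp(normalize_sketch(logits=jnp.array([2.0, -1.0, 0.5]))).sum())
```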
distreqx.utils.math.sum_last(x: Array, ndims: int) -> Array
Sums the last ndims axes of array x.
Arguments:
- x: An array.
- ndims: The number of trailing dimensions to sum over.
Returns:
- The sum of x over its last ndims dimensions.
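For instance, with ndims=2 an array of shape (2, 3, 4) is reduced to shape (2,). The following is a minimal sketch that assumes the function simply reduces over the trailing ndims axes with jnp.sum. sum_last_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def sum_last_sketch(x, ndims):
    """Hypothetical stand-in for sum_last: sum over the trailing `ndims` axes."""
    # Negative axis indices -ndims, ..., -1 select the last `ndims` dimensions.
    axes_to_sum = tuple(range(-ndims, 0))
    return jnp.sum(x, axis=axes_to_sum)


x = jnp.ones((2, 3, 4))
print(sum_last_sketch(x, 2))        # [12. 12.]
print(sum_last_sketch(x, 1).shape)  # (2, 3)
```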
distreqx.utils.math.log_expbig_minus_expsmall(big: Array, small: Array) -> Array
Numerically stable implementation of log(exp(big) - exp(small)).
Arguments:
- big: First input.
- small: Second input. It must satisfy small <= big.
Returns:
- The value of log(exp(big) - exp(small)).
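Stability follows from factoring out exp(big): log(exp(big) - exp(small)) = big + log(1 - exp(small - big)). Because small - big <= 0, the remaining exponential lies in [0, 1] and cannot overflow. The following is a minimal sketch of that identity using log1p, not necessarily the library's exact code. log_expbig_minus_expsmall_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def log_expbig_minus_expsmall_sketch(big, small):
    """Hypothetical stable log(exp(big) - exp(small)), assuming small <= big."""
    # Factor out exp(big); exp(small - big) lies in [0, 1], so it never overflows.
    return big + jnp.log1p(-jnp.exp(small - big))


big, small = 1000.0, 999.0
print(jnp.log(jnp.exp(big) - jnp.exp(small)))        # nan: exp overflows, inf - inf
print(log_expbig_minus_expsmall_sketch(big, small))  # approx. 999.5413
```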
distreqx.utils.math.log_beta(a: Array, b: Array) -> Array
Computes the log of the beta function, log B(a, b).
Arguments:
- a: First input. It must be positive.
- b: Second input. It must be positive.
Returns:
- The value log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b), where Gamma is the Gamma function and log Gamma is evaluated with a numerically stable routine.
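Working in log space is necessary because Gamma overflows for moderately large arguments. For example, Gamma(500) is not representable even in double precision, so computing B(a, b) directly and then taking the log fails. The following is a minimal sketch of the formula above that uses jax.scipy.special.gammaln for log Gamma. log_beta_sketch is a hypothetical name, and the library may use a different log-Gamma routine.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta_sketch(a, b):
    """Hypothetical stand-in for log_beta: log B(a, b) via log-Gamma."""
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return gammaln(a) + gammaln(b) - gammaln(a + b)


print(jnp.exp(log_beta_sketch(2.0, 3.0)))  # approx. 0.0833, since B(2, 3) = 1/12
print(log_beta_sketch(500.0, 500.0))       # finite, although Gamma(500) overflows
```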
distreqx.utils.math.log_beta_multivariate(a: Array) -> Array
Computes the log of the multivariate beta function, log B(a).
Arguments:
- a: An array of length K containing positive values.
Returns:
- The value log B(a) = sum_{k=1}^{K} log Gamma(a_k) - log Gamma(sum_{k=1}^{K} a_k), where Gamma is the Gamma function and log Gamma is evaluated with a numerically stable routine.
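When K = 2, the formula reduces to log B(a_1, a_2) from log_beta, which provides a quick consistency check. The following is a minimal sketch that assumes the K values lie along the last axis. log_beta_multivariate_sketch is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta_multivariate_sketch(a):
    """Hypothetical stand-in for log_beta_multivariate; the K values lie on the last axis."""
    # sum_k log Gamma(a_k) - log Gamma(sum_k a_k)
    return jnp.sum(gammaln(a), axis=-1) - gammaln(jnp.sum(a, axis=-1))


a = jnp.array([1.0, 1.0, 1.0])
# approx. 0.5, since B(1, 1, 1) = Gamma(1)^3 / Gamma(3) = 1/2
print(jnp.exp(log_beta_multivariate_sketch(a)))
```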