Math
distreqx.utils.math.multiply_no_nan
distreqx.utils.math.power_no_nan
distreqx.utils.math.mul_exp(x: Array, logp: Array) -> Array
Returns x * exp(logp) with zero output if exp(logp) == 0.
Arguments:
- x: An array.
- logp: An array representing logarithms.
Returns:
- The result of x * exp(logp).
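The point of mul_exp is that a plain x * exp(logp) returns nan when exp(logp) underflows to 0 and x is infinite, because 0 * inf is nan. A minimal sketch of that behaviour in JAX, assuming the zero case is handled by masking x with jnp.where (this is an illustration, not necessarily the distreqx source):

```python
import jax.numpy as jnp


def mul_exp(x, logp):
    """Sketch: x * exp(logp), returning 0 wherever exp(logp) == 0."""
    p = jnp.exp(logp)
    # Where p is exactly zero (e.g. logp = -inf), replace x by 0 so that a
    # non-finite x (inf or nan) cannot turn 0 * x into nan.
    x = jnp.where(p == 0, 0.0, x)
    return x * p


# jnp.inf * jnp.exp(-jnp.inf) would be nan; the masked version gives 0.
result = mul_exp(jnp.inf, -jnp.inf)
```

Masking x rather than the product also keeps the gradient with respect to x finite at points where p == 0.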
distreqx.utils.math.normalize(*, probs: Optional[Array] = None, logits: Optional[Array] = None) -> Array
Normalizes logits (via log_softmax) or probabilities (so that they sum to one).
Arguments:
- probs: Probability values.
- logits: Logit values.
Returns:
- Normalized probabilities or logits.
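The two branches of normalize can be sketched as follows. This assumes that exactly one of probs and logits is passed and that normalization runs over the last axis; both are assumptions for illustration, not confirmed distreqx behaviour, and the ValueError is a choice of this sketch.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def normalize(*, probs: Optional[jax.Array] = None,
              logits: Optional[jax.Array] = None) -> jax.Array:
    """Sketch: normalize probs (sum to one) or logits (log_softmax)."""
    # Assumption: exactly one argument is given; normalize over the last axis.
    if probs is None and logits is None:
        raise ValueError("Pass either probs or logits.")
    if logits is None:
        probs = jnp.asarray(probs)
        # Divide by the total so the probabilities sum to one.
        return probs / jnp.sum(probs, axis=-1, keepdims=True)
    logits = jnp.asarray(logits)
    # log_softmax subtracts logsumexp, so exp(result) sums to one.
    return jax.nn.log_softmax(logits, axis=-1)


p = normalize(probs=jnp.array([1.0, 1.0, 2.0]))
log_p = normalize(logits=jnp.array([0.0, 1.0, 2.0]))
```

Returning normalized logits, rather than probabilities, keeps the logits branch in log space, which avoids underflow for very negative logits.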
distreqx.utils.math.sum_last(x: Array, ndims: int) -> Array
Sums the last ndims axes of array x.
Arguments:
- x: An array.
- ndims: The number of last dimensions to sum.
Returns:
- The sum of the last ndims dimensions of x.
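Summing the last ndims axes amounts to reducing over the axis tuple (-ndims, ..., -1). A minimal JAX sketch; the explicit ndims == 0 branch is an assumption about the intended edge case (nothing is summed):

```python
import jax.numpy as jnp


def sum_last(x, ndims: int):
    """Sketch: sum over the trailing `ndims` axes of x."""
    if ndims == 0:
        # No axes to reduce; return x unchanged.
        return x
    # Negative axis indices -ndims, ..., -1 select the trailing axes.
    axes = tuple(range(-ndims, 0))
    return jnp.sum(x, axis=axes)


x = jnp.ones((2, 3, 4))
totals = sum_last(x, 2)  # shape (2,): each entry sums a 3 x 4 block of ones
```

A typical use is turning per-element log-probabilities into per-event log-probabilities by summing over the event dimensions.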
distreqx.utils.math.log_expbig_minus_expsmall(big: Array, small: Array) -> Array
Stable implementation of log(exp(big) - exp(small)).
Arguments:
- big: First input.
- small: Second input. It must satisfy small <= big.
Returns:
- The resulting log(exp(big) - exp(small)).
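One standard way to make this stable is to factor out exp(big): log(exp(big) - exp(small)) = big + log(1 - exp(small - big)) = big + log1p(-exp(small - big)). Because small <= big, exp(small - big) lies in [0, 1], so nothing overflows. A sketch of that rewrite in JAX (an illustration of the technique, not necessarily the exact distreqx code):

```python
import jax.numpy as jnp


def log_expbig_minus_expsmall(big, small):
    """Sketch: stable log(exp(big) - exp(small)), assuming small <= big."""
    # exp(small - big) is in [0, 1], so it cannot overflow; log1p keeps
    # precision when exp(small - big) is very small.
    return big + jnp.log1p(-jnp.exp(small - big))


# Naively, exp(1000.0) is inf in float32 and inf - inf is nan.
value = log_expbig_minus_expsmall(1000.0, 999.0)
```

When small == big the result is -inf, matching log(0).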
distreqx.utils.math.log_beta(a: Array, b: Array) -> Array
Obtains the log of the beta function log B(a, b).
Arguments:
- a: First input. It must be positive.
- b: Second input. It must be positive.
Returns:
- The value log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b), where Gamma is the Gamma function, obtained through a stable computation of log Gamma.
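The formula maps directly onto jax.scipy.special.gammaln, which evaluates log|Gamma(z)| without forming Gamma(z), so large arguments do not overflow. A minimal sketch using that function (the choice of gammaln is an assumption; the library may use another log-Gamma routine):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta(a, b):
    """Sketch: log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)."""
    # Working in log space avoids overflow: Gamma(1000.0) is far beyond float32.
    return gammaln(a) + gammaln(b) - gammaln(a + b)


# B(2, 3) = Gamma(2) Gamma(3) / Gamma(5) = 1 * 2 / 24 = 1 / 12.
lb = log_beta(2.0, 3.0)
```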
distreqx.utils.math.log_beta_multivariate(a: Array) -> Array
Obtains the log of the multivariate beta function log B(a).
Arguments:
- a: An array of length K containing positive values.
Returns:
- The value log B(a) = sum_{k=1}^{K} log Gamma(a_k) - log Gamma(sum_{k=1}^{K} a_k), where Gamma is the Gamma function, obtained through a stable computation of log Gamma.
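The same log-Gamma approach extends to the multivariate case. The sketch below assumes the K components sit on the last axis, so any leading axes act as a batch; that batching convention is an assumption, not stated above.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta_multivariate(a):
    """Sketch: log B(a) = sum_k log Gamma(a_k) - log Gamma(sum_k a_k)."""
    a = jnp.asarray(a)
    # Assumption: components are on the last axis; leading axes are batch axes.
    return jnp.sum(gammaln(a), axis=-1) - gammaln(jnp.sum(a, axis=-1))


# For K = 2 this reduces to log B(a_1, a_2); here B(2, 3) = 1 / 12.
two = log_beta_multivariate(jnp.array([2.0, 3.0]))
# B(1, 1, 1) = Gamma(1)^3 / Gamma(3) = 1 / 2.
three = log_beta_multivariate(jnp.array([1.0, 1.0, 1.0]))
```

This quantity is the log normalizing constant of a Dirichlet distribution with concentration a.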