
Math

distreqx.utils.math.multiply_no_nan: None

distreqx.utils.math.power_no_nan: None

distreqx.utils.math.mul_exp(x: Array, logp: Array) -> Array

Returns x * exp(logp), with the output set to zero wherever exp(logp) == 0.

Arguments:

  • x: An array.
  • logp: An array representing logarithms.

Returns:

  • The result of x * exp(logp).
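To illustrate why the zero rule matters, here is a minimal JAX sketch of the described behaviour. It is an assumed implementation written for this page, not the distreqx source. Without the masking step, an infinite x paired with exp(logp) == 0 would give inf * 0 = NaN.

```python
import jax.numpy as jnp


def mul_exp(x, logp):
    """Sketch (assumed implementation): x * exp(logp), zero where exp(logp) == 0."""
    p = jnp.exp(logp)
    # Replace x by 0 wherever p is exactly zero, so inf * 0 (or nan * 0)
    # cannot leak a NaN into the product.
    x = jnp.where(p == 0, 0.0, x)
    return x * p


x = jnp.array([2.0, jnp.inf])
logp = jnp.array([0.0, -jnp.inf])
print(mul_exp(x, logp))  # [2. 0.]
```

A naive `x * jnp.exp(logp)` on the same inputs would return `[2., nan]`.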

distreqx.utils.math.normalize(*, probs: Optional[Array] = None, logits: Optional[Array] = None) -> Array

Normalizes either probabilities, so that they sum to one, or logits, via log_softmax.

Arguments:

  • probs: Probability values.
  • logits: Logit values.

Returns:

  • Normalized probabilities or logits.
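The two normalization modes can be sketched in JAX as follows. This is an illustrative sketch, not the library code: the "exactly one argument" check and normalization over the last axis are assumptions made here.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def normalize(*, probs: Optional[jnp.ndarray] = None,
              logits: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    # Assumption: exactly one of probs / logits is supplied.
    if (probs is None) == (logits is None):
        raise ValueError("Pass exactly one of probs or logits.")
    if logits is None:
        probs = jnp.asarray(probs)
        # Rescale so the last axis sums to one.
        return probs / jnp.sum(probs, axis=-1, keepdims=True)
    logits = jnp.asarray(logits)
    # log_softmax returns normalized log-probabilities: exp(out) sums to one.
    return jax.nn.log_softmax(logits, axis=-1)


print(normalize(probs=jnp.array([1.0, 3.0])))   # [0.25 0.75]
print(normalize(logits=jnp.array([0.0, 0.0])))  # both entries equal log(0.5)
```

Note that the logits branch returns log-probabilities rather than probabilities.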

distreqx.utils.math.sum_last(x: Array, ndims: int) -> Array

Sums the last ndims axes of array x.

Arguments:

  • x: An array.
  • ndims: The number of last dimensions to sum.

Returns:

  • The sum of the last ndims dimensions of x.
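A short sketch of summing the trailing axes in JAX (assumed implementation, shown for illustration). The reduction axes are built as negative indices, so the same code works for any input rank.

```python
import jax.numpy as jnp


def sum_last(x, ndims):
    # Axes -ndims, ..., -1; for ndims == 0 this is the empty tuple,
    # in which case no reduction happens and x is returned unchanged.
    return jnp.sum(x, axis=tuple(range(-ndims, 0)))


x = jnp.ones((2, 3, 4))
print(sum_last(x, 2))  # [12. 12.]
```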

distreqx.utils.math.log_expbig_minus_expsmall(big: Array, small: Array) -> Array

Numerically stable implementation of log(exp(big) - exp(small)).

Arguments:

  • big: First input.
  • small: Second input. Must satisfy small <= big.

Returns:

  • The resulting log(exp(big) - exp(small)).
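One standard way to make this stable is to factor out exp(big): log(exp(big) - exp(small)) = big + log(1 - exp(small - big)). Because small <= big, exp(small - big) lies in [0, 1], so nothing overflows, and log1p avoids cancellation when the difference is tiny. The sketch below uses this rewrite; it is an assumption about the approach, not a copy of the distreqx code.

```python
import jax.numpy as jnp


def log_expbig_minus_expsmall(big, small):
    # Factor out exp(big); requires small <= big so that
    # exp(small - big) <= 1 and the log1p argument stays >= -1.
    return big + jnp.log1p(-jnp.exp(small - big))


# Naive evaluation overflows: exp(1000.) is inf in float32, giving inf - inf = nan.
print(log_expbig_minus_expsmall(jnp.array(1000.0), jnp.array(999.0)))
```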

distreqx.utils.math.log_beta(a: Array, b: Array) -> Array

Computes the logarithm of the beta function, log B(a, b).

Arguments:

  • a: First input. It must be positive.
  • b: Second input. It must be positive.

Returns:

  • The value log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b), where Gamma is the Gamma function; each log Gamma term is evaluated directly in log space for numerical stability.
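The formula maps directly onto `jax.scipy.special.gammaln`, which computes log Gamma without ever forming Gamma itself. The sketch below is an assumed implementation for illustration. Working in log space matters for large arguments: Gamma(500) overflows any floating-point type, but its logarithm is about 2600.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta(a, b):
    # log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)
    return gammaln(a) + gammaln(b) - gammaln(a + b)


print(log_beta(1.0, 1.0))  # B(1, 1) = 1, so approximately 0
print(log_beta(2.0, 3.0))  # B(2, 3) = 1/12, so approximately log(1/12)
```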

distreqx.utils.math.log_beta_multivariate(a: Array) -> Array

Computes the logarithm of the multivariate beta function, log B(a).

Arguments:

  • a: An array of length K containing positive values.

Returns:

  • The value log B(a) = sum_{k=1}^{K} log Gamma(a_k) - log Gamma(sum_{k=1}^{K} a_k), where Gamma is the Gamma function; each log Gamma term is evaluated directly in log space for numerical stability.
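A minimal JAX sketch of the multivariate formula, again built on `gammaln`. Reducing over the last axis, so that K is the trailing dimension, is an assumption made here to allow batched inputs. For K = 2 it reduces to log B(a, b) above.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_beta_multivariate(a):
    # Assumption: the K concentration values lie along the last axis.
    return jnp.sum(gammaln(a), axis=-1) - gammaln(jnp.sum(a, axis=-1))


print(log_beta_multivariate(jnp.array([1.0, 1.0, 1.0])))  # -log(2), since Gamma(3) = 2
```

This quantity is the log normalizer of a Dirichlet distribution with concentration a.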