mbi.marginal_oracles.logspace_sum_product_fast

mbi.marginal_oracles.logspace_sum_product_fast(log_factors: list[mbi.factor.Factor], dom: mbi.domain.Domain, einsum_fn=<function einsum>) → Factor

Numerically stable algorithm for computing a sum-product in log space.

This appears to be the most stable algorithm for this computation that does not require materializing sum(log_factors). Materializing sum(log_factors) generally gives better numerical stability, but at the cost of increased memory usage. That cost can potentially be mitigated by using scan_einsum from mbi.einsum with an appropriately chosen sequential keyword argument.
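The stability trade-off described above shows up in a small log-space matrix product. The sketch below is an illustration under our own assumptions, not the mbi implementation: it uses plain NumPy arrays instead of mbi.factor.Factor objects, and the helper names shifted_log_matmul and materialized_log_matmul are hypothetical. The first helper shifts each factor by its own maximum, so it never builds the summed log-factors; the second materializes that sum as an (i, j, k) array and applies a per-entry logsumexp over j.

```python
import numpy as np


def shifted_log_matmul(log_a, log_b):
    """Hypothetical helper: log(exp(log_a) @ exp(log_b)) with one max shift per factor.

    Never builds anything larger than the inputs and the output.
    """
    ma, mb = np.max(log_a), np.max(log_b)
    return np.log(np.exp(log_a - ma) @ np.exp(log_b - mb)) + ma + mb


def materialized_log_matmul(log_a, log_b):
    """Hypothetical helper: same quantity, from the materialized sum log_a[i, j] + log_b[j, k]."""
    joint = log_a[:, :, None] + log_b[None, :, :]  # shape (i, j, k): the extra memory
    m = np.max(joint, axis=1, keepdims=True)  # one shift per output entry (i, k)
    return np.log(np.sum(np.exp(joint - m), axis=1)) + m[:, 0, :]


# Each factor's largest entry lines up with the other factor's tiny entry.
# After the per-factor shift, both products become exp(-800) * 1, and
# exp(-800) underflows to 0.0 in float64.
log_a = np.array([[0.0, -800.0]])
log_b = np.array([[-800.0], [0.0]])

with np.errstate(divide="ignore"):  # the shifted product is 0, so log gives -inf
    fast = shifted_log_matmul(log_a, log_b)
exact = materialized_log_matmul(log_a, log_b)
# fast is [[-inf]]; exact is [[-800 + log 2]] (about -799.307)
```

The per-factor shift only guarantees that no exponentiated entry overflows; it cannot prevent underflow when the dominant terms of different factors do not coincide. The materialized version shifts each output entry by its own maximum, which is why it is more stable, at the price of an intermediate whose size is the product of all the dimensions involved.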

https://github.com/jax-ml/jax/issues/24915

https://stackoverflow.com/questions/23630277/numerically-stable-way-to-multiply-log-probability-matrices-in-numpy

Parameters:
  • log_factors – a list of log-space factors.

  • dom – the desired domain of the output factor.

  • einsum_fn – the einsum implementation used to contract the factors (defaults to einsum).
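The einsum_fn argument in the signature suggests the contraction itself is pluggable. The sketch below shows one way such a parameter can slot into a max-shifted log-space einsum. It is an assumption-laden stand-in, not mbi's code: shifted_logspace_einsum is a hypothetical name, factors are NumPy arrays addressed by einsum subscripts rather than Factor/Domain objects, and any callable with np.einsum's calling convention is accepted.

```python
import functools

import numpy as np


def shifted_logspace_einsum(subscripts, log_factors, einsum_fn=np.einsum):
    """Hypothetical stand-in: log(einsum_fn(subscripts, *exp(log_factors))).

    Each factor is shifted by its own maximum m_i before exponentiating, so
    exponentiated entries lie in [0, 1] and cannot overflow. einsum is linear
    in each operand, so the shifts come back out as a single additive sum(m_i).
    Assumes every factor has at least one finite entry.
    """
    maxes = [np.max(f) for f in log_factors]
    shifted = [np.exp(f - m) for f, m in zip(log_factors, maxes)]
    return np.log(einsum_fn(subscripts, *shifted)) + sum(maxes)


# Log-values near 1000 would overflow a naive exp (np.exp(1000.0) is inf).
rng = np.random.default_rng(0)
log_a = rng.normal(size=(3, 4)) + 1000.0
log_b = rng.normal(size=(4, 5)) + 1000.0

# Swap in a different contraction routine, here np.einsum with
# contraction-order optimization enabled.
optimized_einsum = functools.partial(np.einsum, optimize=True)
out = shifted_logspace_einsum("ij,jk->ik", [log_a, log_b], einsum_fn=optimized_einsum)
```

Keeping the contraction behind a parameter means a memory-saving or sequential einsum can be substituted without touching the log-space bookkeeping.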

Returns:

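A Factor over dom containing

$$\log \sum_{S \setminus D} \prod_i \exp(F_i),$$

where the sum ranges over all assignments to the attributes in S but not in D, and

  • F_i = log_factors[i],

  • D = dom, the domain of the output factor,

  • S is the union of the domains of the F_i.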
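To make the returned quantity concrete, here is a brute-force reference evaluation on tiny discrete domains in plain Python. It is a hypothetical helper, not part of mbi: factors are (attrs, table) pairs rather than mbi.factor.Factor objects, and the sizes dict stands in for mbi.domain.Domain. It enumerates every joint assignment, so it is exponential in the number of attributes and only useful for checking small cases.

```python
import itertools
import math


def brute_force_log_sum_product(log_factors, sizes, out_attrs):
    """Evaluate log sum_{S \\ D} prod_i exp(F_i) by enumerating every assignment.

    log_factors: list of (attrs, table) pairs; table maps a tuple of values,
        one per attribute in attrs, to the log-value F_i at that assignment.
    sizes: dict mapping each attribute to its number of possible values.
    out_attrs: tuple of attributes making up D (assumed to be a subset of S).
    Returns a dict mapping each assignment of out_attrs to its log-value.
    """
    all_attrs = sorted({a for attrs, _ in log_factors for a in attrs})  # S
    summed = [a for a in all_attrs if a not in out_attrs]  # S \ D
    result = {}
    for out_vals in itertools.product(*(range(sizes[a]) for a in out_attrs)):
        terms = []
        for sum_vals in itertools.product(*(range(sizes[a]) for a in summed)):
            assign = {**dict(zip(out_attrs, out_vals)), **dict(zip(summed, sum_vals))}
            # log prod_i exp(F_i) = sum_i F_i at this joint assignment.
            terms.append(sum(table[tuple(assign[a] for a in attrs)] for attrs, table in log_factors))
        # Stable logsumexp over the eliminated attributes.
        m = max(terms)
        result[out_vals] = m + math.log(sum(math.exp(t - m) for t in terms))
    return result


# Two binary attributes A and B; F_1 over (A, B), F_2 over (B,); D = (A,).
f1 = (("A", "B"), {(0, 0): 0.0, (0, 1): 1.0, (1, 0): 2.0, (1, 1): 3.0})
f2 = (("B",), {(0,): 0.5, (1,): -0.5})
out = brute_force_log_sum_product([f1, f2], {"A": 2, "B": 2}, ("A",))
# For A = 0 both B-terms equal 0.5, so out[(0,)] is 0.5 + log 2;
# for A = 1 both equal 2.5, so out[(1,)] is 2.5 + log 2.
```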