mbi.marginal_oracles.logspace_sum_product_fast
- mbi.marginal_oracles.logspace_sum_product_fast(log_factors: list[mbi.factor.Factor], dom: mbi.domain.Domain, einsum_fn=<function einsum>) → Factor
Numerically stable algorithm for computing a sum-product in log space.
This appears to be the most numerically stable way to do this computation without materializing sum(log_factors). Materializing sum(log_factors) generally gives better numerical stability, but at the cost of higher memory usage. That cost can potentially be mitigated by using scan_einsum from mbi.einsum with an appropriately chosen sequential kwarg.
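As a rough illustration of the trade-off described above, here is a minimal NumPy sketch of two ways to compute the same quantity. The first assumes one possible non-materializing approach: shift each factor by its own maximum, exponentiate, contract with einsum, then take the log and add the shifts back. The second materializes the broadcast sum of all log factors and applies a logsumexp over the eliminated variables. This is not necessarily how mbi implements either path; factors are represented as hypothetical (axes, log_values) pairs rather than mbi Factor objects, and both function names are illustrative only.

```python
import string

import numpy as np


def logspace_sum_product_shifted(log_factors, out_axes):
    """Sketch: log sum_{S - D} prod_i exp(F_i) without forming sum_i F_i.

    log_factors: list of (axes, log_values) pairs (hypothetical representation).
    out_axes: variables to keep (the domain D); each must appear in some factor.
    """
    all_axes = sorted({a for axes, _ in log_factors for a in axes})
    letter = {name: string.ascii_letters[i] for i, name in enumerate(all_axes)}
    shifted, total_shift = [], 0.0
    for axes, vals in log_factors:
        m = np.max(vals)  # one scalar shift per factor
        shifted.append(np.exp(vals - m))  # entries are now at most 1
        total_shift += m
    inputs = ",".join("".join(letter[a] for a in axes) for axes, _ in log_factors)
    spec = inputs + "->" + "".join(letter[a] for a in out_axes)
    # Contract in linear space, then undo the shifts in log space.
    return np.log(np.einsum(spec, *shifted)) + total_shift


def logspace_sum_product_materialized(log_factors, out_axes):
    """Sketch: broadcast sum_i F_i over all of S, then logsumexp over S - D."""
    all_axes = sorted({a for axes, _ in log_factors for a in axes})
    total = 0.0
    for axes, vals in log_factors:
        # Reorder this factor's axes to the global order, then broadcast
        # over the variables it does not mention.
        own = sorted(axes, key=all_axes.index)
        v = np.transpose(vals, [axes.index(a) for a in own])
        shape = [v.shape[own.index(a)] if a in axes else 1 for a in all_axes]
        total = total + v.reshape(shape)
    drop = tuple(i for i, a in enumerate(all_axes) if a not in out_axes)
    m = np.max(total, axis=drop, keepdims=True)
    res = np.log(np.sum(np.exp(total - m), axis=drop, keepdims=True)) + m
    res = np.squeeze(res, axis=drop)
    kept = [a for a in all_axes if a in out_axes]
    return np.transpose(res, [kept.index(a) for a in out_axes])
```

The shifted version cannot overflow, since every exponentiated entry is at most 1, but it can still underflow when the maxima of different factors fall on incompatible joint assignments. The materialized version avoids that by shifting the full sum, at the price of an array whose size is the product of all variable sizes in S.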
https://github.com/jax-ml/jax/issues/24915
- Parameters:
log_factors – a list of log-space factors.
dom – the desired domain of the output factor.
einsum_fn – the einsum function used to contract the factors (defaults to einsum).
- Returns:
log sum_{S - D} prod_i exp(F_i), where
F_i = log_factors[i],
D = dom, the domain of the output factor,
S is the union of the domains of the F_i,
and the sum runs over all joint assignments of the variables in S - D.
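To make the return value concrete, the following brute-force sketch evaluates the formula directly by enumerating every joint assignment of S. It is exponential in the number of variables and meant only as a reference for checking outputs on tiny examples; the (axes, table) factor representation and the function name are hypothetical, not part of mbi.

```python
import itertools
import math


def brute_force_log_sum_product(log_factors, sizes, out_axes):
    """Evaluate log sum_{S - D} prod_i exp(F_i) by full enumeration.

    log_factors: list of (axes, table); table maps a tuple of values
      (one per axis, in axes order) to a log-space number.
    sizes: dict mapping each variable in S to its number of values.
    out_axes: the variables in D, in the desired output order.
    """
    S = sorted(sizes)
    groups = {}
    for assignment in itertools.product(*(range(sizes[v]) for v in S)):
        x = dict(zip(S, assignment))
        # log prod_i exp(F_i) = sum_i F_i at this joint assignment
        log_term = sum(table[tuple(x[a] for a in axes)] for axes, table in log_factors)
        groups.setdefault(tuple(x[v] for v in out_axes), []).append(log_term)
    result = {}
    for key, terms in groups.items():
        # Sum over S - D in log space, shifting by the max for stability.
        m = max(terms)
        result[key] = m + math.log(sum(math.exp(t - m) for t in terms))
    return result
```

For example, with binary A and B, F_1(a, b) = a + b, F_2(b) = 0 and D = {A}, the entry for a = 0 is log(e^0 + e^1) = log(1 + e), and the entry for a = 1 is 1 + log(1 + e).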