mbi.marginal_oracles.message_passing_fast
- mbi.marginal_oracles.message_passing_fast(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None, einsum_fn=<function einsum>, jtree: Graph | None = None, logspace_sum_product_fn=<function logspace_sum_product_fast>) → CliqueVector
Compute marginals from (log-space) potentials using the message passing algorithm.
This implementation uses the “einsum” primitive to compute each clique marginal directly, without first materializing marginals over the super-cliques (the nodes of the junction tree). It can be much faster and more memory-efficient than message_passing_stable, but in some cases it is less numerically stable.
The key difficulty is contracting log-space factors without overflow or underflow; see this Stack Overflow thread: https://stackoverflow.com/questions/23630277/numerically-stable-way-to-multiply-log-probability-matrices-in-numpy
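The numerical difficulty can be sketched in NumPy. The helper log_einsum below is hypothetical (it is not part of mbi). It assumes one common stabilization discussed in threads like the one above: shift each log-space factor by its own maximum, exponentiate, contract with einsum, then take the log and add the shifts back. This avoids overflow. However, the second case shows that it can still underflow when the large entries of different factors do not line up. This illustrates the failure mode only and is not a description of how message_passing_fast is implemented.

```python
import numpy as np

def log_einsum(subscripts, *log_factors):
    """Return log(einsum(subscripts, *map(exp, log_factors))).

    Hypothetical helper using the max-shift trick: each factor is shifted by
    its own maximum before exponentiating, so no shifted entry exceeds
    exp(0) = 1 and nothing overflows. The shifts are added back in log space.
    """
    maxes = [np.max(f) for f in log_factors]
    shifted = [np.exp(f - m) for f, m in zip(log_factors, maxes)]
    with np.errstate(divide="ignore"):  # log(0) -> -inf without a warning
        out = np.log(np.einsum(subscripts, *shifted))
    return out + sum(maxes)

# Large log-potentials: exponentiating them directly overflows to inf.
A = np.array([[1000.0, 999.0], [998.0, 1000.0]])
B = np.array([[1000.0, 1001.0], [1000.0, 999.0]])
with np.errstate(over="ignore"):
    naive = np.log(np.einsum("ij,jk->ik", np.exp(A), np.exp(B)))
stable = log_einsum("ij,jk->ik", A, B)
# Exact reference: logsumexp over j of A[i, j] + B[j, k].
ref = np.logaddexp.reduce(A[:, :, None] + B[None, :, :], axis=1)

# Residual weakness: the large entries of C and D never coincide, so every
# shifted product underflows to 0 and the max-shift result is -inf.
C = np.array([[0.0, -2000.0]])
D = np.array([[-2000.0], [0.0]])
under_shift = log_einsum("ij,jk->ik", C, D)[0, 0]
under_true = np.logaddexp.reduce(C[0, :] + D[:, 0])  # -2000 + log(2)
```

Here `naive` is all inf and `stable` matches `ref`. In the second case, `under_shift` is -inf even though the true value is finite. Handling that case exactly requires a full logsumexp over the summed index. That is why a fast einsum-based contraction can trade some stability for speed.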
- Parameters:
potentials – The (log-space) potentials of a graphical model.
total – The normalization constant: each returned marginal sums to this value.
mesh – The mesh over which the computation should be sharded.
einsum_fn – A function with the same API and semantics as jnp.einsum.
jtree – An optional junction tree that defines the message passing order.
- Returns:
The marginals of the graphical model, defined over the same set of cliques as the input potentials. Each marginal is non-negative and sums to “total”.
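To make the return contract concrete, here is a brute-force NumPy reference for a toy model. It is a hypothetical testing aid, not mbi code. Cliques are plain einsum subscript strings instead of CliqueVector entries, and there is no junction tree, mesh, or sharding. The cost also grows with the product of all domain sizes. Each marginal is a single einsum whose output subscripts are the clique's variables, and it is normalized to sum to total.

```python
import numpy as np

def reference_marginals(log_potentials, total=1.0):
    """Brute-force clique marginals of a small discrete model (hypothetical).

    `log_potentials` maps a clique, written as an einsum subscript string
    (one letter per variable, e.g. "ab"), to its log-space potential table.
    Each marginal is one einsum whose output subscripts are that clique.
    """
    cliques = list(log_potentials)
    # Per-factor max shift to avoid overflow; it cancels when dividing by z.
    factors = [np.exp(t - np.max(t)) for t in log_potentials.values()]
    inputs = ",".join(cliques)
    z = np.einsum(f"{inputs}->", *factors)  # (shifted) partition function
    return {c: total * np.einsum(f"{inputs}->{c}", *factors) / z for c in cliques}

# Toy chain a - b - c - d with random log-potentials.
rng = np.random.default_rng(0)
log_potentials = {
    "ab": rng.normal(size=(2, 3)),
    "bc": rng.normal(size=(3, 4)),
    "cd": rng.normal(size=(4, 2)),
}
marginals = reference_marginals(log_potentials, total=1000.0)
for clique, mu in marginals.items():
    print(clique, mu.shape, mu.sum())  # each sum is 1000 up to rounding
```

Marginals of cliques that share variables agree on them. For example, summing the "ab" table over a gives the same vector as summing the "bc" table over c. This makes a cheap consistency check on any oracle's output.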