mbi.marginal_oracles

Functions for computing marginals from log-space potentials.

The functions in this library should all produce numerically identical outputs on well-behaved inputs, but they may differ in stability on poorly-behaved inputs and in runtime and memory performance.

We recommend using message_passing_stable with accelerated estimation algorithms like Interior Gradient, but using message_passing_fast with mirror descent.

Functions

brute_force_marginals

Compute marginals from (log-space) potentials by materializing the full joint distribution.

bulk_variable_elimination

Compute the marginals of the graphical model with the given potentials.

calculate_many_marginals

Calculate marginals for all projections using belief propagation.

einsum_marginals

Compute marginals from (log-space) potentials by using einsum.

kron_query

logspace_sum_product_fast

Numerically stable algorithm for computing sum product in log space.

logspace_sum_product_stable_v1

More stable implementation of logspace_sum_product.

message_passing_fast

Compute marginals from (log-space) potentials using the message passing algorithm.

message_passing_shafer_shenoy

Compute marginals from (log-space) potentials using the Shafer-Shenoy algorithm.

message_passing_stable

Compute marginals from (log-space) potentials using the message passing algorithm.

sum_product

Compute the sum-of-products of a list of Factors using einsum.

variable_elimination

Compute an out-of-model/unsupported marginal from the potentials.

Classes

MarginalOracle

Defines the callable signature for stateless marginal oracle functions.

class mbi.marginal_oracles.MarginalOracle(*args, **kwargs)

Bases: Protocol

Defines the callable signature for stateless marginal oracle functions.

A marginal oracle consumes log-space potentials (CliqueVector) of a graphical model and returns its marginals (CliqueVector). The returned marginals will be defined over the same domain and set of cliques as the potentials.

Different marginal oracles should usually produce identical results, but they may have different time/space complexities and numerical stabilities. Examples of conforming functions from mbi.marginal_oracles:

  • message_passing_stable: Computes marginals using message passing, operating in log-space for numerical stability.

  • message_passing_fast: A faster and more memory efficient message passing algorithm that uses einsum, but it is not as stable as message_passing_stable.

  • brute_force_marginals: Computes marginals by materializing the full joint distribution.

  • einsum_marginals: Computes marginals using einsum, generally not recommended for large models.
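As a rough illustration of the callable shape described above, the sketch below defines a protocol and one conforming function. It does not use mbi: Potentials (a plain dict) is a hypothetical stand-in for CliqueVector, ToyMarginalOracle and independent_oracle are made-up names, and the toy oracle is only correct when no two cliques share an attribute.

```python
import math
from typing import Protocol

# Hypothetical stand-in for CliqueVector: clique -> flat list of log-potentials.
Potentials = dict[tuple[str, ...], list[float]]


class ToyMarginalOracle(Protocol):
    """Assumed shape: log-space potentials in, marginals over the same cliques out."""

    def __call__(self, potentials: Potentials, total: float = 1.0) -> Potentials: ...


def independent_oracle(potentials: Potentials, total: float = 1.0) -> Potentials:
    """Toy oracle for disjoint cliques: each marginal is a softmax scaled by total."""
    marginals = {}
    for clique, log_values in potentials.items():
        shift = max(log_values)  # subtract the max before exponentiating
        weights = [math.exp(v - shift) for v in log_values]
        norm = sum(weights)
        marginals[clique] = [total * w / norm for w in weights]
    return marginals


oracle: ToyMarginalOracle = independent_oracle
result = oracle({("a",): [0.0, math.log(3.0)]}, total=100.0)
```

The returned dict has the same keys as the input, mirroring the statement that marginals are defined over the same set of cliques as the potentials.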

mbi.marginal_oracles.sum_product(factors: list[Factor], dom: Domain, einsum_fn: Callable = <function einsum>) -> Factor

Compute the sum-of-products of a list of Factors using einsum.

Parameters:
  • factors – A list of Factors.

  • dom – The target domain of the output factor.

Returns:

sum_{S - D} prod_i F_i, where

  • F_i = factors[i]

  • D = dom

  • S = union of domains of F_i
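To make the returned quantity concrete, here is a brute-force reference for sum_{S - D} prod_i F_i written with plain loops. It is a sketch of the semantics only, not the einsum-based implementation; toy_sum_product and the (attrs, table) factor representation are hypothetical.

```python
import itertools


def toy_sum_product(factors, sizes, dom):
    """Brute-force sum_{S - D} prod_i F_i.

    factors: list of (attrs, table) pairs; table maps value tuples to floats.
    sizes:   attribute name -> number of values.
    dom:     tuple of attribute names to keep (D).
    """
    all_attrs = sorted({a for attrs, _ in factors for a in attrs})  # S
    out = {}
    for values in itertools.product(*(range(sizes[a]) for a in all_attrs)):
        assignment = dict(zip(all_attrs, values))
        term = 1.0
        for attrs, table in factors:
            term *= table[tuple(assignment[a] for a in attrs)]
        key = tuple(assignment[a] for a in dom)
        out[key] = out.get(key, 0.0) + term  # accumulates over S - D
    return out


f_ab = (("a", "b"), {(0, 0): 1.0, (0, 1): 2.0, (1, 0): 3.0, (1, 1): 4.0})
f_bc = (("b", "c"), {(0, 0): 5.0, (0, 1): 6.0, (1, 0): 7.0, (1, 1): 8.0})
result = toy_sum_product([f_ab, f_bc], {"a": 2, "b": 2, "c": 2}, ("a", "c"))
```

With factors over (a, b) and (b, c) and D = (a, c), the result is an ordinary matrix product, which is what the einsum specification "ab,bc->ac" expresses.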

mbi.marginal_oracles.logspace_sum_product_fast(log_factors: list[Factor], dom: Domain, einsum_fn: Callable = <function einsum>) -> Factor

Numerically stable algorithm for computing sum product in log space.

This appears to be the most stable algorithm for this computation that does not require materializing sum(log_factors). Materializing sum(log_factors) generally gives better numerical stability, at the cost of increased memory usage. That memory cost can potentially be mitigated by using scan_einsum from mbi.einsum with an appropriately chosen sequential kwarg.

https://github.com/jax-ml/jax/issues/24915

https://stackoverflow.com/questions/23630277/numerically-stable-way-to-multiply-log-probability-matrices-in-numpy

Parameters:
  • log_factors – a list of log-space factors.

  • dom – The desired domain of the output factor.

Returns:

log sum_{S - D} prod_i exp(F_i), where

  • F_i = log_factors[i],

  • D is the input domain,

  • S is the union of the domains of F_i
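One common way to compute this quantity without materializing sum(log_factors) is to shift each factor by its own maximum, run an ordinary sum-product, and add the shifts back after taking the log, as discussed in the Stack Overflow thread linked above. The sketch below shows that idea for two matrix-shaped factors. Whether logspace_sum_product_fast works exactly this way is an assumption, and shifted_logspace_matmul is a hypothetical name.

```python
import math


def shifted_logspace_matmul(log_f1, log_f2):
    """Return log(exp(log_f1) @ exp(log_f2)) using one max-shift per factor."""
    m1 = max(max(row) for row in log_f1)
    m2 = max(max(row) for row in log_f2)
    f1 = [[math.exp(v - m1) for v in row] for row in log_f1]  # entries in [0, 1]
    f2 = [[math.exp(v - m2) for v in row] for row in log_f2]
    inner, cols = len(f2), len(f2[0])
    return [
        [math.log(sum(row[b] * f2[b][j] for b in range(inner))) + m1 + m2
         for j in range(cols)]
        for row in f1
    ]


log_f1 = [[1000.0, 999.0], [998.0, 1000.0]]
log_f2 = [[-5.0, 0.0], [0.0, -5.0]]

try:
    math.exp(log_f1[0][0])  # naive exponentiation of the raw log-factor
    naive_overflows = False
except OverflowError:
    naive_overflows = True

result = shifted_logspace_matmul(log_f1, log_f2)
```

The shift keeps every exponentiated entry in [0, 1], so overflow is avoided. Entries far below the maximum can still underflow to zero, which is the residual instability of this family of methods.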

mbi.marginal_oracles.logspace_sum_product_stable_v1(log_factors: list[Factor], dom: Domain, einsum_fn: Callable = <function einsum>) -> Factor

More stable implementation of logspace_sum_product.

This implementation may (or may not) materialize a Factor over the domain of all elements of log_factors. Without JIT, it will materialize this “super-factor”. Under JIT, there may be some instances where the compiler can figure out that it does not need to materialize this intermediate to compute the final output.
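The following sketch illustrates what materializing the super-factor buys in terms of stability. It builds the elementwise sum of two log-factors over all attributes and then eliminates the shared attribute with a log-sum-exp. The names logsumexp and materialized_logspace_matmul are hypothetical, and the code is a plain-Python illustration rather than the library's implementation.

```python
import math


def logsumexp(values):
    m = max(values)
    if m == -math.inf:  # every entry is log(0)
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def materialized_logspace_matmul(log_f1, log_f2):
    rows, inner, cols = len(log_f1), len(log_f2), len(log_f2[0])
    # "Super-factor" over (a, c, b): the elementwise sum of both log-factors.
    super_factor = [[[log_f1[a][b] + log_f2[b][c] for b in range(inner)]
                     for c in range(cols)] for a in range(rows)]
    # Eliminate b with a log-sum-exp over the innermost axis.
    return [[logsumexp(super_factor[a][c]) for c in range(cols)]
            for a in range(rows)]


log_f1 = [[0.0, -1000.0]]
log_f2 = [[-1000.0], [0.0]]
result = materialized_logspace_matmul(log_f1, log_f2)

# exp(-1000) underflows to zero in double precision, so exponentiating each
# factor separately would produce a sum of exactly 0 for this input.
underflows = math.exp(-1000.0) == 0.0
```

Here both terms of the sum equal -1000 in log space, so the result is -1000 + log(2). The price is an intermediate whose size is the product of all attribute cardinalities.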

mbi.marginal_oracles.brute_force_marginals(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None) -> CliqueVector

Compute marginals from (log-space) potentials by materializing the full joint distribution.
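A minimal sketch of the brute-force idea, assuming potentials are given as plain dicts rather than a CliqueVector: enumerate every joint state, normalize, and project onto each clique. The function name toy_brute_force_marginals and its arguments are hypothetical.

```python
import itertools
import math


def toy_brute_force_marginals(potentials, sizes, total=1.0):
    """potentials: clique -> {value tuple: log-potential}; sizes: attr -> cardinality."""
    attrs = sorted({a for clique in potentials for a in clique})
    states = list(itertools.product(*(range(sizes[a]) for a in attrs)))

    def log_joint(state):
        assignment = dict(zip(attrs, state))
        return sum(table[tuple(assignment[a] for a in clique)]
                   for clique, table in potentials.items())

    logs = [log_joint(s) for s in states]          # the full joint, in log space
    shift = max(logs)
    weights = [math.exp(v - shift) for v in logs]  # unnormalized probabilities
    norm = sum(weights)

    marginals = {clique: {} for clique in potentials}
    for state, w in zip(states, weights):
        assignment = dict(zip(attrs, state))
        for clique in potentials:
            key = tuple(assignment[a] for a in clique)
            marginals[clique][key] = marginals[clique].get(key, 0.0) + total * w / norm
    return marginals


potentials = {
    ("a", "b"): {(0, 0): 0.0, (0, 1): 1.0, (1, 0): -1.0, (1, 1): 0.5},
    ("b", "c"): {(0, 0): 0.2, (0, 1): 0.0, (1, 0): 0.0, (1, 1): 2.0},
}
marginals = toy_brute_force_marginals(potentials, {"a": 2, "b": 2, "c": 2}, total=10.0)
```

The number of enumerated states grows as the product of all attribute cardinalities, so this approach is only practical for small domains.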

mbi.marginal_oracles.einsum_marginals(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None, einsum_fn: Callable = <function einsum>) -> CliqueVector

Compute marginals from (log-space) potentials by using einsum.

This is a “brute-force” approach and is not recommended in practice.

mbi.marginal_oracles.message_passing_stable(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None, jtree: Graph | None = None) -> CliqueVector

Compute marginals from (log-space) potentials using the message passing algorithm.

This implementation operates entirely in log space until the last step, where it exponentiates the log-beliefs to get marginals. It is numerically very stable, but it may materialize factors defined over “super-cliques”, the nodes of the junction tree implied by the cliques in potentials. It may therefore require more memory than message_passing_fast.

Parameters:
  • potentials – The (log-space) potentials of a graphical model.

  • total – The normalization factor.

  • mesh – The mesh over which the computation should be sharded.

  • jtree – An optional junction tree that defines the message passing order.

Returns:

The marginals of the graphical model, defined over the same set of cliques as the input potentials. Each marginal is non-negative and sums to “total”.
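For intuition, the sketch below runs the textbook log-space sum-product recursion on the smallest non-trivial case: two cliques (a, b) and (b, c) joined by the separator b. It uses nested lists instead of a CliqueVector, all variable names are hypothetical, and it is not the library's code.

```python
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


theta_ab = [[0.0, 1.0], [-1.0, 0.5]]  # log-potential indexed [a][b]
theta_bc = [[0.2, 0.0], [0.0, 2.0]]   # log-potential indexed [b][c]
total = 10.0

# Log-messages over the separator b, one in each direction.
msg_ab_to_bc = [logsumexp([theta_ab[a][b] for a in range(2)]) for b in range(2)]
msg_bc_to_ab = [logsumexp(theta_bc[b]) for b in range(2)]

# Log-beliefs: each clique's own potential plus its incoming message.
belief_ab = [[theta_ab[a][b] + msg_bc_to_ab[b] for b in range(2)] for a in range(2)]
belief_bc = [[theta_bc[b][c] + msg_ab_to_bc[b] for c in range(2)] for b in range(2)]

# Exponentiate only at the end, after subtracting the log-partition function.
log_z = logsumexp([v for row in belief_ab for v in row])
marg_ab = [[total * math.exp(v - log_z) for v in row] for row in belief_ab]
marg_bc = [[total * math.exp(v - log_z) for v in row] for row in belief_bc]
```

Both marginals sum to total and agree on the shared attribute b, which is the consistency property message passing is meant to guarantee.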

mbi.marginal_oracles.message_passing_shafer_shenoy(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None, jtree: Graph | None = None) -> CliqueVector

Compute marginals from (log-space) potentials using the Shafer-Shenoy algorithm.

This implementation operates completely in logspace, and is more stable than message_passing_stable when potentials contain -inf values. It avoids subtraction of log-probabilities (division in probability space) which can lead to NaNs when dealing with zero probabilities.

Parameters:
  • potentials – The (log-space) potentials of a graphical model.

  • total – The normalization factor.

  • mesh – The mesh over which the computation should be sharded.

  • jtree – An optional junction tree that defines the message passing order.

Returns:

The marginals of the graphical model, defined over the same set of cliques as the input potentials. Each marginal is non-negative and sums to “total”.
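The NaN problem mentioned in the description can be reproduced with a few floats. The sketch below is simplified to a single node with element-wise messages (no marginalization step), and all variable names are hypothetical. It contrasts a division-style update, which subtracts an incoming log-message from the log-belief, with a Shafer-Shenoy-style update, which combines everything except that message.

```python
import math

NEG_INF = -math.inf

log_potential = [0.0, 0.5]
msg_from_left = [NEG_INF, 0.0]   # the left neighbour rules out state 0
msg_from_right = [0.3, -0.2]

log_belief = [p + l + r
              for p, l, r in zip(log_potential, msg_from_left, msg_from_right)]

# Division-style update: remove the left message from the belief by subtraction.
# For state 0 this evaluates -inf - (-inf), which is NaN.
to_left_by_subtraction = [b - l for b, l in zip(log_belief, msg_from_left)]

# Shafer-Shenoy-style update: combine everything except the left message.
to_left_by_exclusion = [p + r for p, r in zip(log_potential, msg_from_right)]
```

Where the incoming message is finite, both updates agree. They differ only on the state with zero probability, where the exclusion form stays finite.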

mbi.marginal_oracles.message_passing_fast(potentials: CliqueVector, total: float = 1, mesh: Mesh | None = None, einsum_fn: Callable = <function einsum>, jtree: Graph | None = None, logspace_sum_product_fn=<function logspace_sum_product_fast>) -> CliqueVector

Compute marginals from (log-space) potentials using the message passing algorithm.

This implementation leverages the “einsum” primitive to compute clique marginals without materializing marginals over the super cliques first (nodes in the junction tree). It can be much faster and more memory efficient than message_passing_stable, but there are some cases where this implementation is not as stable.

See the stackoverflow thread for the key difficulty here. https://stackoverflow.com/questions/23630277/numerically-stable-way-to-multiply-log-probability-matrices-in-numpy

Parameters:
  • potentials – The (log-space) potentials of a graphical model.

  • total – The normalization factor.

  • mesh – The mesh over which the computation should be sharded.

  • einsum_fn – A function with the same API and semantics as jnp.einsum.

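  • jtree – An optional junction tree that defines the message passing order.

  • logspace_sum_product_fn – The function used to compute log-space sum-products. Defaults to logspace_sum_product_fast.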

Returns:

The marginals of the graphical model, defined over the same set of cliques as the input potentials. Each marginal is non-negative and sums to “total”.

mbi.marginal_oracles.variable_elimination(potentials: CliqueVector, clique: tuple[str, ...], total: float = 1, mesh: Mesh | None = None, evidence: dict[str, int] | None = None) -> Factor

Compute an out-of-model/unsupported marginal from the potentials.

Parameters:
  • potentials – The (log-space) potentials of a Graphical Model.

  • clique – The subset of attributes whose marginal you want.

  • total – The normalization factor.

  • mesh – The mesh over which the computation should be sharded.

  • evidence – A dictionary mapping attribute names to observed values.

Returns:

The marginal defined over the domain of the input clique, where each entry is non-negative and sums to the input total.
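As an illustration of an out-of-model marginal, the sketch below takes a chain model with cliques (a, b), (b, c), (c, d) and computes the marginal over (a, d), which is not a clique of the model, by eliminating b and then c in log space. It uses nested lists rather than a CliqueVector, omits evidence handling, and the helper names are hypothetical.

```python
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def eliminate_shared(log_left, log_right):
    """Sum out the attribute shared by log_left's columns and log_right's rows."""
    inner = range(len(log_right))
    return [[logsumexp([row[k] + log_right[k][j] for k in inner])
             for j in range(len(log_right[0]))]
            for row in log_left]


theta_ab = [[0.0, 1.0], [0.5, -1.0]]   # indexed [a][b]
theta_bc = [[0.2, 0.0], [0.0, 2.0]]    # indexed [b][c]
theta_cd = [[-0.3, 0.4], [1.0, 0.0]]   # indexed [c][d]
total = 50.0

log_ac = eliminate_shared(theta_ab, theta_bc)  # eliminates b
log_ad = eliminate_shared(log_ac, theta_cd)    # eliminates c

# Normalize so that the entries are non-negative and sum to total.
log_norm = logsumexp([v for row in log_ad for v in row])
marginal_ad = [[total * math.exp(v - log_norm) for v in row] for row in log_ad]
```

Each elimination step only touches factors that mention the eliminated attribute, so the full joint over (a, b, c, d) is never built.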

mbi.marginal_oracles.bulk_variable_elimination(potentials: CliqueVector, marginal_queries: list[tuple[str, ...]], total: float = 1.0, mesh: Mesh | None = None) -> CliqueVector

Compute the marginals of the graphical model with the given potentials.

Unlike other marginal oracles, which only compute marginals for cliques in the potentials vector, this function can compute arbitrary marginals from an arbitrary model. Both runtime and compilation time can be expensive when there are a large number of marginal queries. This function compiles and runs variable_elimination for one query at a time, using parallelism and asynchronous computation to do the compilation in the background while running variable_elimination sequentially, one query at a time.

Parameters:
  • potentials – The (log-space) potentials of a Graphical Model.

  • marginal_queries – A list of cliques to obtain marginals for.

  • total – The normalization factor.

  • mesh – The mesh over which the computation should be sharded.

Returns:

A CliqueVector with the marginals computed over the specified cliques.

mbi.marginal_oracles.calculate_many_marginals(potentials: CliqueVector, marginal_queries: list[tuple[str, ...]], total: float = 1.0, belief_propagation_oracle: MarginalOracle = <PjitFunction of <function message_passing_stable>>, mesh: Mesh | None = None) -> CliqueVector

Calculate marginals for all projections using belief propagation.

Implements the algorithm from Section 10.3 of Koller and Friedman. This method may be faster than calling variable_elimination many times. Note: this implementation is experimental, and further work may be needed to optimize it. Contributions are welcome.

Parameters:
  • potentials – Potentials of a graphical model.

  • marginal_queries – a list of cliques whose marginals are desired.

Returns:

A CliqueVector containing one marginal for each clique in the input marginal_queries.

mbi.marginal_oracles.kron_query(potentials: CliqueVector, query_factors: dict[str, Array], total: float = 1, mesh: Mesh | None = None, suffix: str = '_answer') -> Factor