mbi.MarginalOracle

class mbi.MarginalOracle(*args, **kwargs)

Bases: Protocol

Defines the callable signature for stateless marginal oracle functions.

A marginal oracle consumes log-space potentials (CliqueVector) of a graphical model and returns its marginals (CliqueVector). The returned marginals will be defined over the same domain and set of cliques as the potentials.
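To make the input/output contract concrete, here is a minimal NumPy sketch of the idea, not the library's implementation. It uses plain dicts as a hypothetical stand-in for CliqueVector, and the name brute_force_sketch and the domain argument are illustrative assumptions. It also assumes each clique lists its attributes in the same order as the domain.

```python
import numpy as np

def brute_force_sketch(potentials, domain, total=1.0):
    """Toy oracle. `potentials` maps a clique (tuple of attribute names)
    to a log-potential array; `domain` maps attribute name -> size.
    Both are hypothetical stand-ins, not mbi types."""
    attrs = list(domain)
    log_joint = np.zeros(tuple(domain[a] for a in attrs))
    for clique, theta in potentials.items():
        # Align theta's axes with the joint's axes, then broadcast-add.
        bshape = tuple(domain[a] if a in clique else 1 for a in attrs)
        log_joint = log_joint + theta.reshape(bshape)
    # Subtract the max before exponentiating to avoid overflow.
    joint = np.exp(log_joint - log_joint.max())
    joint *= total / joint.sum()
    marginals = {}
    for clique in potentials:
        # Sum out every attribute that is not in the clique.
        drop = tuple(i for i, a in enumerate(attrs) if a not in clique)
        marginals[clique] = joint.sum(axis=drop)
    return marginals

rng = np.random.default_rng(0)
domain = {"A": 2, "B": 3, "C": 2}
potentials = {("A", "B"): rng.normal(size=(2, 3)),
              ("B", "C"): rng.normal(size=(3, 2))}
marginals = brute_force_sketch(potentials, domain, total=10.0)
```

The output has one table per input clique, each with the same shape as its potential and each summing to total.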

Different marginal oracles should generally produce the same results (up to numerical error), but they can differ in time and space complexity and in numerical stability. Examples of conforming functions from mbi.marginal_oracles:

  • message_passing_stable: Computes marginals using message passing, operating in log-space for numerical stability.

  • message_passing_fast: A faster and more memory-efficient message passing algorithm that uses einsum, but is less numerically stable than message_passing_stable.

  • brute_force_marginals: Computes marginals by materializing the full joint distribution.

  • einsum_marginals: Computes marginals using einsum; generally not recommended for large models.
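The stability trade-off mentioned in the list can be illustrated with a generic NumPy example. This is a sketch of why working in log-space helps, not code taken from any of the functions listed; the helper names are hypothetical.

```python
import numpy as np

def normalize_naive(theta):
    # Exponentiate directly: overflows when entries of theta are large.
    weights = np.exp(theta)
    return weights / weights.sum()

def normalize_log_space(theta):
    # Shift by the max so the largest exponent is exp(0) = 1,
    # then normalize using the log of the partition function.
    shifted = theta - theta.max()
    log_z = np.log(np.exp(shifted).sum())
    return np.exp(shifted - log_z)

theta = np.array([1000.0, 999.0, 998.0])  # large log-potentials
with np.errstate(over="ignore", invalid="ignore"):
    naive = normalize_naive(theta)        # inf / inf produces NaN
stable = normalize_log_space(theta)       # finite, sums to 1
```

With these inputs the naive version returns NaN in every entry, while the shifted version returns a valid distribution.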

Methods

__init__

__call__(potentials: CliqueVector, total: float = 1.0, mesh: Mesh | None = None) -> CliqueVector

Computes marginals from log-space potentials.

Parameters:
  • potentials – A CliqueVector representing the log-space potentials of a graphical model.

  • total – The normalization factor, typically the total number of records or a probability sum. Defaults to 1.0.

  • mesh – An optional mesh which determines how the computation will be sharded across multiple machines.

Returns:

A CliqueVector of the computed marginals.
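Because the class is a Protocol, any callable with a compatible signature conforms structurally, without subclassing. The sketch below shows the pattern using only the standard library. The names OracleLike, Table, independent_cliques_oracle, and estimate are hypothetical, and the dict-of-lists type is a simplified stand-in for CliqueVector rather than the real class.

```python
import math
from typing import Dict, List, Optional, Protocol, Tuple

Clique = Tuple[str, ...]
Table = Dict[Clique, List[float]]  # hypothetical stand-in for CliqueVector

class OracleLike(Protocol):
    """Mirrors the documented call signature with simplified types."""
    def __call__(self, potentials: Table, total: float = 1.0,
                 mesh: Optional[object] = None) -> Table: ...

def independent_cliques_oracle(potentials: Table, total: float = 1.0,
                               mesh: Optional[object] = None) -> Table:
    """Toy oracle: exact only when the cliques share no attributes."""
    marginals = {}
    for clique, theta in potentials.items():
        peak = max(theta)
        weights = [math.exp(t - peak) for t in theta]  # shifted for stability
        z = sum(weights)
        marginals[clique] = [total * w / z for w in weights]
    return marginals

def estimate(oracle: OracleLike, potentials: Table, total: float) -> Table:
    # Accepts any conforming callable; the oracle holds no state between calls.
    return oracle(potentials, total=total)

result = estimate(
    independent_cliques_oracle,
    {("A",): [0.0, 0.0], ("B",): [0.0, math.log(3.0)]},
    total=100.0,
)
```

Typing the parameter as the protocol lets callers swap one oracle for another without changing the surrounding code.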