mbi.approximate_oracles.StatefulMarginalOracle

class mbi.approximate_oracles.StatefulMarginalOracle(*args, **kwargs)

Bases: Protocol

Defines the callable signature for stateful marginal oracle functions.

A stateful marginal oracle computes (approximate) marginals from log-space potentials and also manages an internal state. Iterative algorithms typically use this state to avoid redundant work between successive calls, e.g., by preserving messages across calls in message passing.
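To make the role of state concrete, here is a minimal pure-Python sketch. It is not mbi code: the dict standing in for CliqueVector and the helper names `softmax_scaled` and `caching_oracle` are hypothetical. Instead of preserving messages, this toy oracle caches each clique's last result, a simpler analog of warm starting. It also treats cliques independently, which is exact only when the cliques share no attributes.

```python
import math
from typing import Any

# Hypothetical stand-in for mbi's CliqueVector: clique -> flat table of log-potentials.
Potentials = dict[tuple[str, ...], list[float]]


def softmax_scaled(logits: list[float], total: float) -> list[float]:
    # Numerically stable softmax, rescaled so the entries sum to `total`.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [total * e / z for e in exps]


def caching_oracle(potentials: Potentials, total: float = 1.0,
                   state: Any = None, mesh: Any = None) -> tuple[Potentials, Any]:
    # Toy oracle: treats cliques independently (exact only if they share no attributes).
    # State caches each clique's last (logits, total, marginal) so unchanged cliques are reused.
    if state is None:
        state = {"cache": {}, "hits": 0}
    cache = dict(state["cache"])  # copy: never mutate the caller's state in place
    hits = state["hits"]
    marginals: Potentials = {}
    for clique, logits in potentials.items():
        cached = cache.get(clique)
        if cached is not None and cached[0] == logits and cached[1] == total:
            marginals[clique] = cached[2]
            hits += 1
        else:
            marginals[clique] = softmax_scaled(logits, total)
            cache[clique] = (list(logits), total, marginals[clique])
    # `mesh` is accepted for signature compatibility but ignored here.
    return marginals, {"cache": cache, "hits": hits}


# An iterative algorithm threads the returned state into the next call.
potentials = {("A",): [0.0, 0.0], ("B", "C"): [0.0, 1.0, 2.0, 3.0]}
state = None
for step in range(3):
    # Pretend each optimization step only changes the potential on ("A",).
    potentials = {**potentials, ("A",): [0.5 * step, 0.0]}
    marginals, state = caching_oracle(potentials, total=100.0, state=state)

print(state["hits"])  # → 2: ("B", "C") was reused on steps 1 and 2
```

The caller never inspects the state; it only passes back whatever the previous call returned, which is the contract this protocol describes.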

Methods

__init__

__call__(potentials: CliqueVector, total: float = 1.0, state: Any = None, mesh: Mesh | None = None) → tuple[CliqueVector, Any]

Computes marginals from log-space potentials and manages state.

Parameters:
  • potentials – A CliqueVector representing the log-space potentials of a graphical model.

  • total – The normalization factor, i.e., the value each computed marginal should sum to: typically the total number of records, or 1.0 for probabilities. Defaults to 1.0.

  • state – An optional argument to pass state between calls. The oracle may use this state and return an updated version.

  • mesh – An optional device mesh specifying how the computation is sharded across devices. Defaults to None.

Returns:

A tuple containing:

  • CliqueVector: The computed marginals.

  • Any: The updated state.

Return type:

tuple[CliqueVector, Any]
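The sketch below shows one way a plain function can satisfy a callback protocol with this signature. `StatefulOracleLike`, `counting_oracle`, and the dict-based `CliqueVector` alias are hypothetical stand-ins, not mbi's types. The state here is just a call counter, and marginals are computed per clique, which is exact only for cliques with disjoint attributes.

```python
import math
from typing import Any, Optional, Protocol, runtime_checkable

# Hypothetical stand-in for mbi's CliqueVector: clique -> flat table of log-potentials.
CliqueVector = dict[tuple[str, ...], list[float]]


@runtime_checkable
class StatefulOracleLike(Protocol):
    # Mirrors the documented __call__ signature; `mesh` is typed as Any in this sketch.
    def __call__(self, potentials: CliqueVector, total: float = 1.0,
                 state: Any = None, mesh: Any = None) -> tuple[CliqueVector, Any]: ...


def counting_oracle(potentials: CliqueVector, total: float = 1.0,
                    state: Optional[int] = None, mesh: Any = None) -> tuple[CliqueVector, int]:
    """Toy oracle: exact only when cliques share no attributes. State counts calls."""
    marginals: CliqueVector = {}
    for clique, logits in potentials.items():
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        z = sum(exps)
        # Normalize so each clique's marginal sums to `total`.
        marginals[clique] = [total * e / z for e in exps]
    calls = 0 if state is None else state
    return marginals, calls + 1  # `mesh` is accepted but ignored in this sketch


oracle: StatefulOracleLike = counting_oracle
potentials = {("A",): [0.0, math.log(3.0)]}
marginals, state = oracle(potentials, total=8.0)
marginals, state = oracle(potentials, total=8.0, state=state)
print(marginals[("A",)], state)
```

Because `StatefulOracleLike` is a `typing.Protocol`, any callable with a compatible signature conforms structurally; no inheritance is needed.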