mbi.approximate_oracles.StatefulMarginalOracle
- class mbi.approximate_oracles.StatefulMarginalOracle(*args, **kwargs)
Bases: Protocol
Defines the callable signature for stateful marginal oracle functions.
A stateful marginal oracle computes (approximate) marginals from log-space potentials while also managing internal state that iterative algorithms can reuse across calls as an optimization (e.g., preserving messages in message passing).
Methods
__init__
- __call__(potentials: CliqueVector, total: float = 1.0, state: Any = None, mesh: Mesh | None = None) → tuple[CliqueVector, Any]
Computes marginals from log-space potentials and manages state.
- Parameters:
potentials – A CliqueVector representing the log-space potentials of a graphical model.
total – The normalization factor, typically the total number of records or a probability sum. Defaults to 1.0.
state – An optional argument to pass state between calls. The oracle may use this state and return an updated version.
mesh – Specifies how the computation will be sharded across devices.
- Returns:
A tuple containing:
CliqueVector: The computed marginals.
Any: The updated state.
- Return type:
tuple[CliqueVector, Any]
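Example
The call contract described above can be illustrated with a minimal, standard-library-only sketch. It does not use mbi itself: Potentials is a hypothetical stand-in for CliqueVector (a plain dict of log-potentials per clique), StatefulOracleSketch is an illustrative protocol that mirrors the documented signature, and softmax_oracle is a toy oracle that normalizes each clique independently rather than performing real marginal inference. The state here is just a call counter, standing in for whatever an actual oracle would carry between calls (such as messages).

```python
import math
from typing import Any, Protocol

# Hypothetical stand-in for CliqueVector: clique name -> flat list of log-potentials.
Potentials = dict[str, list[float]]


class StatefulOracleSketch(Protocol):
    """Illustrative protocol mirroring the documented __call__ signature."""

    def __call__(
        self,
        potentials: Potentials,
        total: float = 1.0,
        state: Any = None,
        mesh: Any = None,
    ) -> tuple[Potentials, Any]: ...


def softmax_oracle(potentials, total=1.0, state=None, mesh=None):
    """Toy oracle: softmax per clique, scaled to `total`; `state` counts calls."""
    calls = 0 if state is None else state["calls"]
    marginals = {}
    for clique, logits in potentials.items():
        shift = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(v - shift) for v in logits]
        norm = sum(weights)
        marginals[clique] = [total * w / norm for w in weights]
    # Return the marginals together with the updated state, as the protocol requires.
    return marginals, {"calls": calls + 1}


# A plain function satisfies the protocol structurally; no subclassing is needed.
oracle: StatefulOracleSketch = softmax_oracle

potentials = {"A": [0.0, 0.0], "B": [0.0, math.log(3.0)]}
state = None  # the first call starts without state
for _ in range(3):
    # Thread the state returned by one call into the next call.
    marginals, state = oracle(potentials, total=100.0, state=state)
```

The caller passes state=None on the first call and then feeds each returned state back into the next call. In this sketch the mesh argument is accepted but ignored.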