mbi.estimation.dual_averaging
- mbi.estimation.dual_averaging(domain: mbi.domain.Domain, loss_fn: mbi.marginal_loss.MarginalLossFn | list[mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: mbi.clique_vector.CliqueVector | None = None, marginal_oracle: mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: collections.abc.Callable[[mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: jax._src.mesh.Mesh | None = None) → MarkovRandomField
Optimization using the Regularized Dual Averaging (RDA) algorithm.
RDA is an accelerated proximal algorithm for solving a smooth convex optimization problem over the marginal polytope. This algorithm requires knowledge of the Lipschitz constant of the gradient of the loss function.
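To make the update concrete, here is a minimal NumPy sketch of accelerated RDA on the simplest possible marginal polytope: a single clique, i.e. the probability simplex. It assumes an entropy prox term, so each prox step reduces to a softmax of the scaled, averaged gradients. That softmax plays the role a marginal oracle plays for a full model. This is an illustration under those assumptions, not mbi's implementation. The names `rda_simplex`, `softmax` and `grad_fn` are hypothetical; `iters` and `callback_fn` only mirror the parameters documented below.

```python
import numpy as np


def softmax(theta):
    """Entropic prox step on the simplex: argmax_p <theta, p> + entropy(p)."""
    z = np.exp(theta - theta.max())  # shift for numerical stability
    return z / z.sum()


def rda_simplex(grad_fn, n, lipschitz, iters=1000, callback_fn=lambda w: None):
    """Accelerated RDA over the probability simplex with an entropy prox (sketch)."""
    w = np.full(n, 1.0 / n)  # current solution, starts at the uniform distribution
    v = w.copy()             # most recent prox point
    gbar = np.zeros(n)       # weighted running average of gradients
    for t in range(1, iters + 1):
        c = 2.0 / (t + 1)
        u = (1 - c) * w + c * v                          # point where the gradient is taken
        gbar = (1 - c) * gbar + c * grad_fn(u)
        theta = -t * (t + 1) / (4 * lipschitz) * gbar    # plays the role of potentials
        v = softmax(theta)                               # single-clique "marginal oracle"
        w = (1 - c) * w + c * v
        callback_fn(w)                                   # report the intermediate solution
    return w


# Toy problem: recover p_true from exact linear measurements y = A @ p_true.
A = np.eye(4)
p_true = np.array([0.1, 0.2, 0.3, 0.4])
y = A @ p_true
grad_fn = lambda p: A.T @ (A @ p - y)       # gradient of 0.5 * ||A p - y||^2
L = np.linalg.eigvalsh(A.T @ A).max()       # upper bound on the gradient's Lipschitz constant
history = []
p_hat = rda_simplex(grad_fn, 4, L, iters=1000, callback_fn=history.append)
```

The step size `t * (t + 1) / (4 * lipschitz)` is where the Lipschitz constant enters. An overestimate of it only slows convergence, which is why an upper bound such as the largest eigenvalue of `A.T @ A` is used in the toy problem.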
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn, or a list of LinearMeasurement objects.
lipschitz – The Lipschitz constant of the gradient of the loss function.
known_total – The known or estimated number of records in the data.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in loss_fn.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
callback_fn – A function called at each iteration with the intermediate solution (a CliqueVector).
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- Returns:
A MarkovRandomField object with the final potentials and marginals.
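To illustrate what marginal_oracle computes, the following brute-force sketch maps potentials to clique marginals, for intuition only. It assumes potentials are given as a plain dict from a clique (tuple of attribute names) to a log-potential array. This dict is a hypothetical stand-in for CliqueVector, and `brute_force_marginals` is not part of mbi. Because it enumerates the full domain, its cost is exponential in the number of attributes. Message-passing oracles such as the default message_passing_stable instead exploit the clique structure. The `total` argument mimics known_total by scaling every marginal to sum to the number of records.

```python
import itertools
import numpy as np


def brute_force_marginals(shape, potentials, total=1.0):
    """Hypothetical marginal oracle: exact clique marginals by full enumeration.

    shape: dict mapping attribute name -> domain size.
    potentials: dict mapping clique (tuple of attribute names) -> log-potential
        array whose axes follow the clique's attribute order.
    total: number of records each marginal should sum to (cf. known_total).
    """
    attrs = list(shape)
    cells = list(itertools.product(*(range(shape[a]) for a in attrs)))

    def project(cell, clique):
        # Index of this full-domain cell within the clique's table.
        return tuple(cell[attrs.index(a)] for a in clique)

    # Unnormalized log-probability of each cell: sum of its clique log-potentials.
    scores = np.array([
        sum(phi[project(cell, clique)] for clique, phi in potentials.items())
        for cell in cells
    ])
    probs = np.exp(scores - scores.max())   # shift for numerical stability
    counts = probs * (total / probs.sum())  # scale so the full table sums to total

    marginals = {}
    for clique in potentials:
        table = np.zeros([shape[a] for a in clique])
        for cell, mass in zip(cells, counts):
            table[project(cell, clique)] += mass  # sums out attributes not in the clique
        marginals[clique] = table
    return marginals


# Two binary attributes; only A carries a non-trivial potential (weights 1:3).
marginals = brute_force_marginals(
    {"A": 2, "B": 2},
    {("A",): np.log([1.0, 3.0]), ("A", "B"): np.zeros((2, 2))},
    total=100.0,
)
```

Here the A-marginal splits the 100 records as 25 and 75, and the AB-marginal spreads each of those evenly over B.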