mbi.estimation.interior_gradient

mbi.estimation.interior_gradient(domain: mbi.domain.Domain, loss_fn: mbi.marginal_loss.MarginalLossFn | list[mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: mbi.clique_vector.CliqueVector | None = None, marginal_oracle: mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: collections.abc.Callable[[mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: jax._src.mesh.Mesh | None = None)

Optimization using the Interior Gradient algorithm.

Interior Gradient is an accelerated proximal algorithm for solving a smooth convex optimization problem over the marginal polytope. It requires knowledge of the Lipschitz constant of the gradient of the loss function. The method is based on the paper [“Interior Gradient and Proximal Methods for Convex and Conic Optimization”](https://epubs.siam.org/doi/abs/10.1137/S1052623403427823?journalCode=sjope8) by Auslender and Teboulle.
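The loop below is a minimal, self-contained sketch of an accelerated interior gradient iteration. It assumes the simplest possible "marginal polytope", the probability simplex over a single attribute, and an entropy kernel, so the proximal step reduces to a softmax of the accumulated potentials. This is not mbi's implementation. The names `interior_gradient_simplex`, `grad_fn`, and `softmax` are hypothetical, and mbi works with CliqueVector potentials and a marginal oracle such as message_passing_stable rather than a plain softmax. The sketch shows where the Lipschitz constant L enters: it fixes the step-size sequence, where each a in (0, 1) solves a² = c·(1/L)·(1 − a) and c is the running product of (1 − a). The `callback_fn` hook loosely mirrors the documented parameter but receives a plain list here.

```python
import math
from typing import Callable, List


def softmax(theta: List[float]) -> List[float]:
    # Numerically stable softmax: maps potentials to a strictly positive point on the simplex.
    m = max(theta)
    w = [math.exp(t - m) for t in theta]
    s = sum(w)
    return [v / s for v in w]


def interior_gradient_simplex(
    grad_fn: Callable[[List[float]], List[float]],
    n: int,
    lipschitz: float,
    iters: int = 500,
    callback_fn: Callable[[List[float]], None] = lambda _: None,
) -> List[float]:
    """Hypothetical toy: accelerated interior gradient on the n-simplex (entropy kernel)."""
    l = 1.0 / lipschitz
    c = 1.0
    theta = [0.0] * n  # potentials; the prox point z = softmax(theta) stays in the interior
    x = z = softmax(theta)
    for _ in range(iters):
        # Step size: a in (0, 1) solves a**2 = c * l * (1 - a).
        a = (math.sqrt((c * l) ** 2 + 4 * c * l) - c * l) / 2
        # Extrapolated point at which the gradient is evaluated.
        y = [(1 - a) * xi + a * zi for xi, zi in zip(x, z)]
        c *= 1 - a
        g = grad_fn(y)
        # Accumulate weighted gradients in potential space (dual-averaging form).
        theta = [t - (a / c) * gi for t, gi in zip(theta, g)]
        z = softmax(theta)
        # New iterate: a convex combination, so it stays on the simplex.
        x = [(1 - a) * xi + a * zi for xi, zi in zip(x, z)]
        callback_fn(x)
    return x
```

For example, minimizing ½‖μ − target‖² with target = [0.5, 0.3, 0.2] uses `grad_fn = lambda y: [yi - ti for yi, ti in zip(y, target)]` and `lipschitz=1.0`; because the target lies inside the simplex, the returned point approaches it as `iters` grows.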

Parameters:
  • domain – The domain over which the model should be defined.

  • loss_fn – A MarginalLossFn or a list of LinearMeasurement objects.

  • known_total – The known or estimated number of records in the data.

  • potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in loss_fn.

  • marginal_oracle – The function to use to compute marginals from potentials.

  • iters – The maximum number of optimization iterations.

  • callback_fn – A function called at each iteration with the current iterate, a CliqueVector.

  • mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.

Returns:

A MarkovRandomField object with the optimized potentials and marginals.