mbi.estimation.lbfgs

mbi.estimation.lbfgs(domain: mbi.domain.Domain, loss_fn: mbi.marginal_loss.MarginalLossFn | list[mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: mbi.clique_vector.CliqueVector | None = None, marginal_oracle: mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: collections.abc.Callable[[mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: jax._src.mesh.Mesh | None = None)

Gradient-based optimization on the potentials (theta) via L-BFGS.

This optimizer computes gradients with respect to the potentials by back-propagating through the marginal inference oracle.

This is a standard approach for fitting the parameters of a graphical model without noise (i.e., when you know the exact marginals). In this case, the loss function is convex with respect to theta, so this approach enjoys convergence guarantees. With generic marginal loss functions, which arise for instance with noisy marginals, the loss function is typically convex with respect to the marginals (mu) but not with respect to theta. Therefore, this optimizer is not guaranteed to converge to the global optimum in all cases. In practice, it tends to work well in these settings despite the non-convexity. This approach appeared in the paper [“Learning Graphical Model Parameters with Approximate Marginal Inference”](https://arxiv.org/abs/1301.3193).
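The idea of differentiating a loss on the marginals with respect to the potentials can be illustrated with a toy sketch. Everything here is an assumption made for illustration and is not the library's code: the model has a single clique, so the hypothetical `marginal_oracle` reduces to a scaled softmax; the loss is a squared error; the chain rule is written out by hand instead of using automatic differentiation; and plain gradient descent stands in for L-BFGS.

```python
import math

def marginal_oracle(theta, total):
    """Toy oracle (assumption): one clique, so mu = total * softmax(theta)."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    z = sum(exps)
    return [total * e / z for e in exps]

def loss_and_grad_theta(theta, y, total):
    """Squared-error loss on mu and its gradient with respect to theta."""
    mu = marginal_oracle(theta, total)
    g_mu = [m - t for m, t in zip(mu, y)]           # dL/dmu
    loss = 0.5 * sum(g * g for g in g_mu)
    p = [m / total for m in mu]
    # Back-propagate through the oracle. For mu = total * softmax(theta):
    #   dmu_i/dtheta_j = total * p_i * (delta_ij - p_j)
    # so (J^T g)_j = total * p_j * (g_j - sum_i p_i * g_i).
    avg = sum(pi * gi for pi, gi in zip(p, g_mu))
    g_theta = [total * pj * (gj - avg) for pj, gj in zip(p, g_mu)]
    return loss, g_theta

def fit(y, total, iters=2000, step=0.5):
    """Plain gradient descent on theta (a stand-in for L-BFGS)."""
    theta = [0.0] * len(y)
    for _ in range(iters):
        _, g = loss_and_grad_theta(theta, y, total)
        theta = [t - step * gi for t, gi in zip(theta, g)]
    return theta

y = [0.2, 0.3, 0.5]                  # target marginal, sums to the total
theta = fit(y, total=1.0)
mu = marginal_oracle(theta, 1.0)     # fitted marginal, close to y
```

In the real optimizer the oracle performs message passing over many cliques and the gradient comes from automatic differentiation, but the structure is the same: evaluate the oracle, evaluate the loss on the resulting marginals, and push the gradient back to theta.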

Parameters:
  • domain – The domain over which the model should be defined.

  • loss_fn – A MarginalLossFn or a list of LinearMeasurements.

  • known_total – The known or estimated number of records in the data. If loss_fn is provided as a list of LinearMeasurements, this argument is optional. Otherwise, it is required.

  • potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.

  • marginal_oracle – The function to use to compute marginals from potentials.

  • iters – The maximum number of optimization iterations.

  • callback_fn – A callable that takes a CliqueVector and returns None.

  • mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
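To show why known_total can be optional when loss_fn is a list of measurements, here is a minimal sketch of one plausible way to estimate the total from noisy marginals using inverse-variance weighting. The function `estimate_total` and its `(noisy_values, stddev)` input format are hypothetical, and it assumes each measurement is a full noisy marginal with independent noise of the given standard deviation in every cell. The library's actual estimator may differ.

```python
def estimate_total(measurements):
    """Combine per-measurement totals by inverse-variance weighting.

    measurements: list of (noisy_values, stddev) pairs (hypothetical format).
    Each pair is assumed to be a complete noisy marginal, so its cells sum
    to the total plus noise with variance len(noisy_values) * stddev**2.
    """
    estimates, variances = [], []
    for values, stddev in measurements:
        estimates.append(sum(values))                 # one estimate of the total
        variances.append(len(values) * stddev ** 2)   # variance of that estimate
    weights = [1.0 / v for v in variances]
    return sum(w * e for w, e in zip(weights, estimates)) / sum(weights)

measurements = [
    ([10.0, 20.0, 70.5], 1.0),   # sums to 100.5, variance 3
    ([48.0, 51.0], 2.0),         # sums to 99.0, variance 8
]
total = estimate_total(measurements)
```

The lower-variance measurement gets the larger weight, so the combined estimate lies between the two individual sums and closer to the first one.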