mbi.estimation.mirror_descent

mbi.estimation.mirror_descent(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_fast>>, iters: int = 1000, stepsize: float | None = None, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None) → MarkovRandomField

Optimization using the Mirror Descent algorithm.

This is a first-order proximal optimization algorithm for solving a (possibly nonsmooth) convex optimization problem over the marginal polytope. It implements Algorithm 1 from the paper “Graphical-model based estimation and inference for differential privacy” (https://arxiv.org/pdf/1901.09136). If stepsize is not provided, the algorithm uses a line search to choose step sizes that satisfy the Armijo condition.
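For intuition, the sketch below shows entropic mirror descent with a backtracking line search in the simplest possible setting: a single clique, where computing marginals from potentials reduces to a softmax scaled by the total count. It is a pure-Python illustration under stated assumptions, not mbi's JAX implementation. The helper names (scaled_softmax, mirror_descent_sketch), the initial step of 1.0, the double-then-halve step schedule, and the constant 0.5 in the sufficient-decrease test are all hypothetical choices made here; the total argument plays a role analogous to known_total.

```python
import math

# Hypothetical helpers for illustration only; these are not part of mbi.


def scaled_softmax(theta, total):
    """Map log-potentials to the scaled simplex {mu >= 0, sum(mu) = total}."""
    m = max(theta)                              # subtract max for numerical stability
    w = [math.exp(t - m) for t in theta]
    z = sum(w)
    return [total * wi / z for wi in w]


def mirror_descent_sketch(loss_and_grad, n, total=1.0, iters=200, stepsize=None):
    """Entropic mirror descent over one clique's marginal (a scaled simplex).

    loss_and_grad(mu) must return (loss, gradient w.r.t. mu). With stepsize=None,
    each step is found by backtracking until an Armijo-style sufficient-decrease
    test holds; otherwise the fixed stepsize is always accepted.
    """
    theta = [0.0] * n                           # uniform initial potentials
    mu = scaled_softmax(theta, total)
    loss, grad = loss_and_grad(mu)
    alpha = 1.0 if stepsize is None else stepsize
    for _ in range(iters):
        if stepsize is None:
            alpha *= 2.0                        # try a larger step before backtracking
        accepted = False
        for _ in range(30):
            # Mirror step: gradient step on the potentials, then map back to marginals.
            new_theta = [t - alpha * g for t, g in zip(theta, grad)]
            new_mu = scaled_softmax(new_theta, total)
            new_loss, new_grad = loss_and_grad(new_mu)
            # Decrease predicted by the linearization, <grad, mu - new_mu> (nonnegative).
            predicted = sum(g * (a - b) for g, a, b in zip(grad, mu, new_mu))
            if stepsize is not None or loss - new_loss >= 0.5 * predicted:
                accepted = True
                break
            alpha *= 0.5                        # sufficient decrease failed: halve the step
        if not accepted:
            break                               # no acceptable step found: treat as converged
        theta, mu, loss, grad = new_theta, new_mu, new_loss, new_grad
    return mu, loss


# Usage: recover a target distribution under a squared-error loss.
target = [0.5, 0.3, 0.2]


def squared_error(mu):
    residual = [m - t for m, t in zip(mu, target)]
    return 0.5 * sum(r * r for r in residual), residual


mu, loss = mirror_descent_sketch(squared_error, n=3)
print([round(m, 4) for m in mu])  # expected to be close to target
```

Because every iterate is produced by the softmax map, it stays on the scaled simplex automatically, which is what makes the mirror step convenient here compared with a projected gradient step.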

Parameters:
  • domain – The domain over which the model should be defined.

  • loss_fn – A MarginalLossFn or a list of LinearMeasurement objects.

  • known_total – The known or estimated number of records in the data.

  • potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.

  • marginal_oracle – The function to use to compute marginals from potentials.

  • iters – The maximum number of optimization iterations.

  • stepsize – The step size for the optimization. If not provided, this algorithm will use a line search to automatically choose appropriate step sizes.

  • callback_fn – A function called once per iteration. Per its type signature, it receives a CliqueVector and returns None.

  • mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
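When loss_fn is a list of linear measurements rather than a MarginalLossFn, the measurements have to be turned into a loss over marginals. This documentation does not state which loss is used, so the sketch below simply assumes a noise-weighted squared error, sum over measurements of ||Q μ − y||² / (2σ²), evaluated on one clique's flattened marginal vector. HypotheticalMeasurement and squared_error_loss are illustrative names, not mbi's LinearMeasurement class or its API.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for illustration; not mbi's LinearMeasurement or loss code.


@dataclass
class HypotheticalMeasurement:
    """A linear measurement of one clique marginal: y = Q mu + noise."""
    query: list[list[float]]    # rows of the query matrix Q
    noisy_answer: list[float]   # noisy answers y
    stddev: float               # noise standard deviation sigma


def squared_error_loss(mu, measurements):
    """Return (loss, grad) for sum_k ||Q_k mu - y_k||^2 / (2 * sigma_k^2)."""
    loss = 0.0
    grad = [0.0] * len(mu)
    for m in measurements:
        w = 1.0 / m.stddev ** 2                           # weight noisier measurements less
        for row, y in zip(m.query, m.noisy_answer):
            r = sum(q * x for q, x in zip(row, mu)) - y   # residual of one query
            loss += 0.5 * w * r * r
            for j, q in enumerate(row):
                grad[j] += w * r * q                      # derivative of 0.5*w*r^2 w.r.t. mu_j
    return loss, grad


# Usage: identity queries on a flattened 2x2 marginal, measured with sigma = 2.
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
measurements = [HypotheticalMeasurement(identity, [10.0, 20.0, 30.0, 40.0], 2.0)]
loss, grad = squared_error_loss([12.0, 20.0, 30.0, 40.0], measurements)
print(loss, grad)  # 0.5 [0.5, 0.0, 0.0, 0.0]
```

A (loss, gradient) pair of this shape is the kind of first-order information a mirror descent step consumes; a nonsmooth choice such as an L1 residual would also fit the problem class described above, but is not shown.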

Returns:

A MarkovRandomField object with the estimated potentials and marginals.