mbi.estimation.mirror_descent
- mbi.estimation.mirror_descent(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_fast>>, iters: int = 1000, stepsize: float | None = None, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None) → MarkovRandomField
Optimization using the Mirror Descent algorithm.
This is a first-order proximal optimization algorithm for solving a (possibly nonsmooth) convex optimization problem over the marginal polytope. It implements Algorithm 1 from the paper “Graphical-model based estimation and inference for differential privacy” (https://arxiv.org/pdf/1901.09136). If stepsize is not provided, the algorithm uses a line search to automatically choose step sizes that satisfy the Armijo condition.
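As an illustration of the mirror descent update with a line search, here is a minimal, self-contained sketch for the simplest case: a model with a single clique, where the marginal polytope (scaled by the total) is a simplex and computing marginals from potentials reduces to a scaled softmax. It assumes an identity-query squared-error loss, an Armijo-style sufficient-decrease test with constant c = 0.5, step halving on failure, and step doubling after each accepted step. These constants and the names mirror_descent_toy and scaled_softmax are hypothetical choices for illustration, not the library's implementation, which operates on CliqueVector potentials and calls the configured marginal_oracle.

```python
import math


def scaled_softmax(theta, total):
    # Single-clique stand-in for a marginal oracle: potentials (log-space)
    # -> marginals that are nonnegative and sum to `total`.
    top = max(theta)
    w = [math.exp(t - top) for t in theta]
    s = sum(w)
    return [total * wi / s for wi in w]


def mirror_descent_toy(y, total, iters=200, stepsize=None, c=0.5):
    """Entropic mirror descent on 0.5 * ||mu - y||^2 over the scaled simplex.

    Update in potential space: theta <- theta - eta * grad L(mu).
    If stepsize is None, eta is chosen by backtracking until the
    Armijo-style test  L(mu) - L(mu_new) >= c * <grad, mu - mu_new>  holds.
    """
    def loss(mu):
        return 0.5 * sum((m - t) ** 2 for m, t in zip(mu, y))

    theta = [0.0] * len(y)                  # uniform initial potentials
    mu = scaled_softmax(theta, total)
    eta = 1.0 if stepsize is None else stepsize
    for _ in range(iters):
        g = [m - t for m, t in zip(mu, y)]  # gradient of the loss at mu
        while True:
            new_theta = [t - eta * gi for t, gi in zip(theta, g)]
            new_mu = scaled_softmax(new_theta, total)
            if stepsize is not None:
                break                       # fixed step size: no line search
            decrease = loss(mu) - loss(new_mu)
            predicted = sum(gi * (a - b) for gi, a, b in zip(g, mu, new_mu))
            if decrease >= c * predicted:
                break                       # sufficient decrease: accept step
            eta *= 0.5                      # step too long: backtrack
        theta, mu = new_theta, new_mu
        if stepsize is None:
            eta *= 2.0                      # let the step grow again next time
    return mu
```

For example, mirror_descent_toy([2.0, 1.0, 1.0], total=4.0) moves toward [2.0, 1.0, 1.0], which is the minimizer because it already lies on the simplex scaled to 4. Passing a small fixed stepsize (say 0.1) also converges here, just without the adaptive step control.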
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn or a list of LinearMeasurement objects.
known_total – The known or estimated number of records in the data.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn, i.e., every clique in the loss_fn must be contained in some clique of the potentials.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
stepsize – The step size for the optimization. If not provided, this algorithm will use a line search to automatically choose appropriate step sizes.
callback_fn – A function to call at each iteration. Per its type annotation, it receives a CliqueVector (the current estimate) and returns None.
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- Returns:
A MarkovRandomField object with the estimated potentials and marginals.
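When loss_fn is given as a list of linear measurements, it is turned into a loss over clique marginals. The sketch below assumes one common form, a squared error scaled by each measurement's standard deviation, 0.5 * sum_i ||(Q_i mu_{C_i} - y_i) / sigma_i||^2. ToyMeasurement and l2_loss are hypothetical names, not the mbi LinearMeasurement class or the library's exact loss, and marginals are represented as plain lists keyed by clique.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyMeasurement:
    # Hypothetical stand-in for a linear measurement: y ~= Q @ mu_C + noise.
    clique: tuple
    noisy_measurement: list
    stddev: float = 1.0
    query: Optional[list] = None  # rows of Q; None means the identity query


def apply_query(query, mu):
    # Multiply the (flattened) clique marginal by Q, or pass it through.
    if query is None:
        return list(mu)
    return [sum(q * m for q, m in zip(row, mu)) for row in query]


def l2_loss(marginals, measurements):
    """Assumed form: 0.5 * sum_i ||(Q_i mu_{C_i} - y_i) / sigma_i||^2."""
    total = 0.0
    for m in measurements:
        pred = apply_query(m.query, marginals[m.clique])
        total += 0.5 * sum(((p - y) / m.stddev) ** 2
                           for p, y in zip(pred, m.noisy_measurement))
    return total
```

Scaling each residual by 1/sigma means noisier measurements pull less on the estimate.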
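known_total can be supplied by the caller. One standard way to estimate it from noisy marginals is inverse-variance weighting: summing a noisy identity-query marginal with k cells and independent noise of standard deviation sigma gives an unbiased estimate of the total with variance k * sigma^2, and the estimates are then averaged with weights proportional to 1 / variance. This is shown as an assumption, not as what mirror_descent does when known_total is omitted; estimate_total is a hypothetical helper.

```python
def estimate_total(measurements):
    """Inverse-variance weighted estimate of the number of records.

    measurements: list of (noisy_marginal, stddev) pairs, each an identity-query
    measurement of a full clique marginal with independent noise per cell.
    """
    num = 0.0
    den = 0.0
    for y, sigma in measurements:
        var = len(y) * sigma ** 2  # variance of sum(y) as an estimate of N
        num += sum(y) / var
        den += 1.0 / var
    return num / den
```

Marginals with fewer cells give lower-variance totals, so they receive more weight.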
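The potentials parameter must be defined over cliques that support the cliques in the loss_fn. Reading "supports" as "every loss clique is a subset of some potential clique" (an interpretation assumed here), the check can be sketched as follows; supports is a hypothetical helper and cliques are plain tuples of attribute names.

```python
def supports(model_cliques, loss_cliques):
    """Return True if every loss clique is a subset of some model clique."""
    model_sets = [frozenset(c) for c in model_cliques]
    return all(any(frozenset(c) <= m for m in model_sets) for c in loss_cliques)
```

Attribute order inside a clique does not matter for this check, which is why the cliques are compared as sets.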
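The marginal_oracle maps potentials to clique marginals, and the default uses message passing. For intuition, the same quantity can be computed for a tiny model by brute-force enumeration of p(x) proportional to exp(sum_C theta_C(x_C)). This is a reference sketch only, under assumed dict-based representations; brute_force_marginals is a hypothetical name, its signature differs from the MarginalOracle interface, and its cost is exponential in the number of attributes.

```python
import itertools
import math


def brute_force_marginals(domain, potentials, total=1.0):
    """Exact clique marginals of p(x) proportional to exp(sum_C theta_C(x_C)).

    domain: dict attribute -> number of values (dict order fixes axis order).
    potentials: dict clique (tuple of attributes) -> dict assignment -> log-potential.
    Returns dict clique -> dict assignment -> marginal scaled to sum to `total`.
    """
    attrs = list(domain)
    scores = {}
    for x in itertools.product(*(range(domain[a]) for a in attrs)):
        value = dict(zip(attrs, x))
        scores[x] = sum(theta[tuple(value[a] for a in clique)]
                        for clique, theta in potentials.items())
    top = max(scores.values())  # subtract the max score for numerical stability
    weights = {x: math.exp(s - top) for x, s in scores.items()}
    z = sum(weights.values())
    marginals = {}
    for clique in potentials:
        idx = [attrs.index(a) for a in clique]
        marg = {}
        for x, w in weights.items():
            key = tuple(x[i] for i in idx)
            marg[key] = marg.get(key, 0.0) + total * w / z
        marginals[clique] = marg
    return marginals
```

On small models this can serve as a ground truth when checking a faster oracle.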