mbi.estimation
Algorithms for estimating graphical models from marginal-based loss functions.
This module provides a flexible set of optimization algorithms, all sharing the same API. The supported algorithms are:
1. Mirror Descent [our recommended algorithm]
2. L-BFGS (using back-propagation through belief propagation)
3. Regularized Dual Averaging
4. Interior Gradient
5. Universal Accelerated Mirror Descent
Each algorithm can be given an initial set of potentials, or can automatically initialize the potentials to zero for you. Any CliqueVector of potentials that supports the cliques of the marginal-based loss function can be used here.
Functions
- dual_averaging – Optimization using the Regularized Dual Averaging (RDA) algorithm.
- interior_gradient – Optimization using the Interior Point Gradient Descent algorithm.
- lbfgs – Gradient-based optimization on the potentials (theta) via L-BFGS.
- minimum_variance_unbiased_total – Estimates the total count from measurements with identity queries.
- mirror_descent – Optimization using the Mirror Descent algorithm.
- mle_from_marginals – Compute the MLE Graphical Model from the marginals.
- universal_accelerated_method – Optimization using the Universal Accelerated MD algorithm.
Classes
- Estimator – Defines the callable signature for marginal-based estimators.
- class mbi.estimation.Estimator(*args, **kwargs)
Bases: Protocol
Defines the callable signature for marginal-based estimators.
An estimator estimates a discrete distribution, or more generally a `Projectable` object, from a loss function defined over its low-dimensional marginals.
Examples of conforming functions from mbi.estimation:
- mirror_descent
- lbfgs
- dual_averaging
- interior_gradient
- universal_accelerated_method
- … and more from other modules
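The idea of a shared callable signature can be illustrated with Python's `typing.Protocol`. The sketch below is hypothetical: `ToyEstimator`, `uniform_estimator`, `fit_with` and the dict-based `Domain` alias are illustrative stand-ins rather than mbi's actual definitions, and the real `Estimator` protocol may declare different parameters and types.

```python
from typing import Any, Callable, Optional, Protocol, runtime_checkable

# Hypothetical stand-ins for mbi's Domain and loss types (not the real classes).
Domain = dict[str, int]          # attribute name -> number of categories
LossFn = Callable[[Any], float]  # loss evaluated on (some form of) marginals


@runtime_checkable
class ToyEstimator(Protocol):
    """Hypothetical protocol mimicking a shared estimator call shape."""

    def __call__(self, domain: Domain, loss_fn: LossFn, *,
                 known_total: Optional[float] = None,
                 potentials: Optional[Any] = None,
                 iters: int = 1000) -> Any: ...


def uniform_estimator(domain: Domain, loss_fn: LossFn, *,
                      known_total: Optional[float] = None,
                      potentials: Optional[Any] = None,
                      iters: int = 1000) -> dict[str, list[float]]:
    """Trivial conforming estimator: ignores the loss, returns uniform one-way marginals."""
    total = 1.0 if known_total is None else known_total
    return {attr: [total / n] * n for attr, n in domain.items()}


def fit_with(estimator: ToyEstimator, domain: Domain, loss_fn: LossFn, total: float) -> Any:
    # Any callable with the shared signature can be swapped in here.
    return estimator(domain, loss_fn, known_total=total)
```

Because every estimator accepts the same arguments, calling code such as `fit_with` does not need to know which optimization algorithm it was handed.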
- mbi.estimation.minimum_variance_unbiased_total(measurements: list[LinearMeasurement]) → float
Estimates the total count from measurements with identity queries.
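A minimal sketch of how a minimum-variance unbiased total could be formed, assuming each identity-query measurement is a noisy histogram with independent zero-mean noise of known standard deviation in every cell. `IdentityMeasurement` and `min_variance_total` are hypothetical names, and mbi's own implementation may handle additional cases (for example, measurements whose query is not the identity).

```python
from dataclasses import dataclass


@dataclass
class IdentityMeasurement:
    """Hypothetical stand-in: a noisy histogram of one marginal, with i.i.d.
    zero-mean noise of standard deviation `stddev` added to every cell."""
    noisy_counts: list[float]
    stddev: float


def min_variance_total(measurements: list[IdentityMeasurement]) -> float:
    if not measurements:
        raise ValueError("need at least one identity-query measurement")
    estimates, variances = [], []
    for m in measurements:
        # Summing n noisy cells gives an unbiased estimate of the total whose
        # variance is n * stddev**2 (variances of independent noise add up).
        estimates.append(sum(m.noisy_counts))
        variances.append(len(m.noisy_counts) * m.stddev ** 2)
    # Inverse-variance weighting is the minimum-variance unbiased linear combination.
    weights = [1.0 / v for v in variances]
    return sum(w * e for w, e in zip(weights, estimates)) / sum(weights)
```

For instance, a 2-cell marginal summing to 100 with stddev 1 (variance 2) and a 4-cell marginal summing to 104 with stddev 2 (variance 16) combine to 904/9 ≈ 100.44, pulled strongly toward the less noisy measurement.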
- mbi.estimation.mirror_descent(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle | ~mbi.approximate_oracles.StatefulMarginalOracle = <PjitFunction of <function message_passing_fast>>, iters: int = 1000, stateful: bool = False, stepsize: float | None = None, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None)
Optimization using the Mirror Descent algorithm.
This is a first-order proximal algorithm for solving a (possibly nonsmooth) convex optimization problem over the marginal polytope. It implements Algorithm 1 from the paper [“Graphical-model based estimation and inference for differential privacy”](https://arxiv.org/pdf/1901.09136). If stepsize is not provided, the algorithm uses a line search to choose step sizes that satisfy the Armijo condition.
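The core update can be sketched on a toy problem: take a gradient of the loss with respect to the marginals mu, subtract a multiple of it from the potentials theta, and recompute mu with the marginal oracle. Everything here is a hypothetical, self-contained stand-in rather than mbi's API: the domain is two binary attributes with a single clique {A, B}, so the "oracle" is just a scaled softmax, the loss is squared error against two one-way marginals, and a fixed step size replaces the line search that the library documents for the case where stepsize is None.

```python
import math

# Hypothetical toy setup (not mbi's API): two binary attributes A and B,
# one clique {A, B}, and noisy one-way marginals y_a and y_b as measurements.
CELLS = [(0, 0), (0, 1), (1, 0), (1, 1)]


def oracle(theta, total):
    """Toy marginal oracle: with a single clique, inference is a scaled softmax."""
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [total * x / s for x in e]


def loss_and_grad(mu, y_a, y_b):
    """Squared-error loss on the two one-way marginals, and its gradient w.r.t. mu."""
    mu_a, mu_b = [0.0, 0.0], [0.0, 0.0]
    for (a, b), m in zip(CELLS, mu):
        mu_a[a] += m
        mu_b[b] += m
    r_a = [m - y for m, y in zip(mu_a, y_a)]
    r_b = [m - y for m, y in zip(mu_b, y_b)]
    loss = sum(r * r for r in r_a + r_b)
    grad = [2 * r_a[a] + 2 * r_b[b] for a, b in CELLS]
    return loss, grad


def mirror_descent_toy(y_a, y_b, total, iters=1000, stepsize=None):
    if stepsize is None:
        # Fixed, conservative step chosen for this toy loss (no line search here).
        stepsize = 1.0 / (4.0 * total)
    theta = [0.0] * len(CELLS)  # zero potentials, i.e. the uniform distribution
    history = []
    for _ in range(iters):
        mu = oracle(theta, total)
        loss, grad = loss_and_grad(mu, y_a, y_b)
        history.append(loss)
        # Mirror step: move the potentials against the gradient taken w.r.t. mu.
        theta = [t - stepsize * g for t, g in zip(theta, grad)]
    mu = oracle(theta, total)
    history.append(loss_and_grad(mu, y_a, y_b)[0])
    return mu, history
```

With y_a = [60, 40], y_b = [30, 70] and a total of 100, the loss starts at 1000 for the uniform model and the iterates approach the product distribution [18, 42, 12, 28], the maximum-entropy table consistent with both marginals.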
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn or a list of Linear Measurements.
known_total – The known or estimated number of records in the data.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
stepsize – The step size for the optimization. If not provided, this algorithm will use a line search to automatically choose appropriate step sizes.
callback_fn – A function called at each iteration with the current intermediate solution (a CliqueVector).
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- Returns:
A MarkovRandomField object with the estimated potentials and marginals.
- mbi.estimation.lbfgs(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None)
Gradient-based optimization on the potentials (theta) via L-BFGS.
This optimizer works by calculating the gradients with respect to the potentials by back-propagating through the marginal inference oracle.
This is a standard approach for fitting the parameters of a graphical model without noise (i.e., when you know the exact marginals). In this case, the loss function with respect to theta is convex, so this approach enjoys convergence guarantees. With generic marginal loss functions, such as those that arise with noisy marginals, the loss is typically convex with respect to mu but not with respect to theta. Therefore, this optimizer is not guaranteed to converge to the global optimum in all cases. In practice, it tends to work well in these settings despite the non-convexity. This approach appeared in the paper [“Learning Graphical Model Parameters with Approximate Marginal Inference”](https://arxiv.org/abs/1301.3193).
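What "back-propagating through the marginal inference oracle" computes can be shown on a toy oracle where the chain rule is easy to write by hand. In this hypothetical sketch the model has a single clique, so the oracle is mu = total * softmax(theta), and the gradient with respect to theta is total * (diag(p) - p p^T) times the gradient with respect to mu. The hand-derived gradient is checked against central finite differences. The L-BFGS update that would consume this gradient is omitted, and none of these helper names come from mbi.

```python
import math


def softmax(theta):
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [x / s for x in e]


def loss_theta(theta, y, total):
    """Toy loss as a function of the potentials: oracle followed by squared error."""
    mu = [total * p for p in softmax(theta)]  # toy single-clique "marginal oracle"
    return sum((m - t) ** 2 for m, t in zip(mu, y))


def grad_theta(theta, y, total):
    """Chain rule through the oracle: dL/dtheta_j = total * p_j * (g_j - sum_i p_i g_i)."""
    p = softmax(theta)
    g_mu = [2.0 * (total * q - t) for q, t in zip(p, y)]  # dL/dmu
    mean_g = sum(q * g for q, g in zip(p, g_mu))
    return [total * q * (g - mean_g) for q, g in zip(p, g_mu)]


def finite_diff_grad(theta, y, total, h=1e-6):
    """Central finite differences, used only to check grad_theta."""
    grad = []
    for i in range(len(theta)):
        up, dn = list(theta), list(theta)
        up[i] += h
        dn[i] -= h
        grad.append((loss_theta(up, y, total) - loss_theta(dn, y, total)) / (2 * h))
    return grad
```

The gradient components always sum to zero, reflecting that adding a constant to every potential leaves the distribution unchanged; an optimizer working in theta-space sees this flat direction, which is one reason the loss can be non-convex or poorly conditioned in theta even when it is convex in mu.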
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn or a list of Linear Measurements.
known_total – The known or estimated number of records in the data. If loss_fn is provided as a list of LinearMeasurements, this argument is optional. Otherwise, it is required.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
callback_fn –
…
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- mbi.estimation.mle_from_marginals(marginals: CliqueVector, known_total: float, iters: int = 250, marginal_oracle: MarginalOracle = <PjitFunction of <function message_passing_stable>>, callback_fn=<function <lambda>>, mesh: Mesh | None = None) → MarkovRandomField
Compute the MLE Graphical Model from the marginals.
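For a decomposable clique structure such as {A, B} and {B, C}, the maximum-likelihood model given consistent clique marginals has a well-known closed form: the joint equals the product of the clique marginals divided by the separator marginal. The sketch below uses that closed form as a sanity check on what an MLE fit should reproduce. It is a swapped-in technique, not how mle_from_marginals works (its iters argument suggests an iterative optimizer that also handles non-decomposable clique sets), and `decomposable_mle` is a hypothetical name.

```python
def decomposable_mle(mu_ab, mu_bc):
    """mu_ab[a][b] and mu_bc[b][c] are count tables that must agree on B.
    Returns joint[a][b][c] = mu_ab[a][b] * mu_bc[b][c] / mu_b[b]."""
    n_a, n_b, n_c = len(mu_ab), len(mu_bc), len(mu_bc[0])
    mu_b = [sum(mu_bc[b]) for b in range(n_b)]  # separator marginal from the BC table
    mu_b_from_ab = [sum(mu_ab[a][b] for a in range(n_a)) for b in range(n_b)]
    if any(abs(x - y) > 1e-9 * max(1.0, abs(x)) for x, y in zip(mu_b, mu_b_from_ab)):
        raise ValueError("clique marginals disagree on the separator B")
    joint = [[[0.0] * n_c for _ in range(n_b)] for _ in range(n_a)]
    for a in range(n_a):
        for b in range(n_b):
            if mu_b[b] == 0:
                continue  # zero separator mass: leave those cells at zero
            for c in range(n_c):
                joint[a][b][c] = mu_ab[a][b] * mu_bc[b][c] / mu_b[b]
    return joint
```

Summing the result over C recovers mu_ab and summing over A recovers mu_bc, so the fitted model matches both input marginals exactly.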
- Parameters:
marginals – The marginal probabilities.
known_total – The known or estimated number of records in the data.
- Returns:
A MarkovRandomField object with the final potentials and marginals.
- mbi.estimation.dual_averaging(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None) → MarkovRandomField
Optimization using the Regularized Dual Averaging (RDA) algorithm.
RDA is an accelerated proximal algorithm for solving a smooth convex optimization problem over the marginal polytope. This algorithm requires knowledge of the Lipschitz constant of the gradient of the loss function.
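A hedged sketch of one accelerated dual averaging scheme, on a hypothetical toy problem rather than mbi's API: two binary attributes, a single clique (so the oracle is a scaled softmax), and squared error on two one-way marginals. Each iteration averages gradients with weights proportional to t, maps the accumulated gradient through the entropy-regularized oracle using the Lipschitz constant, and mixes iterates with weight c = 2/(t + 1). The constant `lipschitz = 4 * total` used in the usage note is one we derived for this particular toy loss and scaling only; this is not necessarily the exact update mbi performs.

```python
import math

# Hypothetical toy setup (not mbi's API): two binary attributes A and B,
# one clique {A, B}, and noisy one-way marginals y_a and y_b.
CELLS = [(0, 0), (0, 1), (1, 0), (1, 1)]


def oracle(theta, total):
    """Toy marginal oracle: with a single clique, inference is a scaled softmax."""
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [total * x / s for x in e]


def loss_and_grad(mu, y_a, y_b):
    """Squared-error loss on the two one-way marginals, and its gradient w.r.t. mu."""
    mu_a, mu_b = [0.0, 0.0], [0.0, 0.0]
    for (a, b), m in zip(CELLS, mu):
        mu_a[a] += m
        mu_b[b] += m
    r_a = [m - y for m, y in zip(mu_a, y_a)]
    r_b = [m - y for m, y in zip(mu_b, y_b)]
    loss = sum(r * r for r in r_a + r_b)
    grad = [2 * r_a[a] + 2 * r_b[b] for a, b in CELLS]
    return loss, grad


def dual_averaging_toy(y_a, y_b, total, lipschitz, iters=1000):
    w = v = oracle([0.0] * len(CELLS), total)  # start from the uniform distribution
    gbar = [0.0] * len(CELLS)  # weighted running average of gradients
    for t in range(1, iters + 1):
        c = 2.0 / (t + 1)
        # Query point: convex combination of the averaged and dual-averaging iterates.
        u = [(1 - c) * wi + c * vi for wi, vi in zip(w, v)]
        _, g = loss_and_grad(u, y_a, y_b)
        gbar = [(1 - c) * gb + c * gi for gb, gi in zip(gbar, g)]
        # Dual averaging step: entropy-regularized minimizer of the accumulated linear model.
        theta = [-t * (t + 1) / (4.0 * lipschitz) * gb for gb in gbar]
        v = oracle(theta, total)
        w = [(1 - c) * wi + c * vi for wi, vi in zip(w, v)]
    return w
```

Called as `dual_averaging_toy([60.0, 40.0], [30.0, 70.0], 100.0, lipschitz=400.0)`, the returned marginals keep a total of 100 and drive the loss from 1000 (uniform start) to well below 1. The Lipschitz constant enters directly in the step that builds theta, which is why the documented algorithm needs it.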
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn or a list of Linear Measurements.
lipschitz – The Lipschitz constant of the gradient of the loss function.
known_total – The known or estimated number of records in the data.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
callback_fn – A function to call with intermediate solution at each iteration.
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- Returns:
A MarkovRandomField object with the final potentials and marginals.
- mbi.estimation.interior_gradient(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None)
Optimization using the Interior Point Gradient Descent algorithm.
Interior Gradient is an accelerated proximal algorithm for solving a smooth convex optimization problem over the marginal polytope. This algorithm requires knowledge of the Lipschitz constant of the gradient of the loss function. This algorithm is based on the paper titled [“Interior Gradient and Proximal Methods for Convex and Conic Optimization”](https://epubs.siam.org/doi/abs/10.1137/S1052623403427823?journalCode=sjope8).
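A hedged sketch of an accelerated interior gradient (Bregman-proximal) scheme, again on a hypothetical single-clique toy problem rather than mbi's API. Each iteration queries the gradient at a mix of the averaged iterate x and the mirror iterate z, takes a KL-proximal step on z (which, in potentials, is a gradient step of size 1/(c * lipschitz)), and updates c so that c_new**2 = (1 - c_new) * c**2. The step rule, the c sequence and `lipschitz = 4 * total` are our assumptions for this toy loss, not a statement of mbi's exact update.

```python
import math

# Hypothetical toy setup (not mbi's API): two binary attributes A and B,
# one clique {A, B}, and noisy one-way marginals y_a and y_b.
CELLS = [(0, 0), (0, 1), (1, 0), (1, 1)]


def oracle(theta, total):
    """Toy marginal oracle: with a single clique, inference is a scaled softmax."""
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [total * x / s for x in e]


def loss_and_grad(mu, y_a, y_b):
    """Squared-error loss on the two one-way marginals, and its gradient w.r.t. mu."""
    mu_a, mu_b = [0.0, 0.0], [0.0, 0.0]
    for (a, b), m in zip(CELLS, mu):
        mu_a[a] += m
        mu_b[b] += m
    r_a = [m - y for m, y in zip(mu_a, y_a)]
    r_b = [m - y for m, y in zip(mu_b, y_b)]
    loss = sum(r * r for r in r_a + r_b)
    grad = [2 * r_a[a] + 2 * r_b[b] for a, b in CELLS]
    return loss, grad


def interior_gradient_toy(y_a, y_b, total, lipschitz, iters=1000):
    theta = [0.0] * len(CELLS)
    x = z = oracle(theta, total)  # start from the uniform distribution
    c = 1.0
    for _ in range(iters):
        # Query point between the averaged iterate x and the mirror iterate z.
        u = [(1 - c) * xi + c * zi for xi, zi in zip(x, z)]
        _, g = loss_and_grad(u, y_a, y_b)
        # KL-proximal step on z, written as a step on its potentials.
        theta = [t - gi / (c * lipschitz) for t, gi in zip(theta, g)]
        z = oracle(theta, total)
        x = [(1 - c) * xi + c * zi for xi, zi in zip(x, z)]
        # Next c solves c_new**2 = (1 - c_new) * c**2.
        c = (math.sqrt(c ** 4 + 4 * c ** 2) - c ** 2) / 2
    return x
```

Because z always comes out of the softmax oracle, every iterate stays strictly inside the simplex, which is the "interior" in the method's name.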
- Parameters:
domain – The domain over which the model should be defined.
loss_fn – A MarginalLossFn or a list of Linear Measurements.
lipschitz – The Lipschitz constant of the gradient of the loss function.
known_total – The known or estimated number of records in the data.
potentials – The initial potentials. Must be defined over a set of cliques that supports the cliques in the loss_fn.
marginal_oracle – The function to use to compute marginals from potentials.
iters – The maximum number of optimization iterations.
callback_fn – A function called at each iteration with the current intermediate solution (a CliqueVector).
mesh – Determines how the marginal oracle and loss calculation will be sharded across devices.
- Returns:
A MarkovRandomField object with the optimized potentials and marginals.
- mbi.estimation.universal_accelerated_method(domain: ~mbi.domain.Domain, loss_fn: ~mbi.marginal_loss.MarginalLossFn | list[~mbi.marginal_loss.LinearMeasurement], *, known_total: float | None = None, potentials: ~mbi.clique_vector.CliqueVector | None = None, marginal_oracle: ~mbi.marginal_oracles.MarginalOracle = <PjitFunction of <function message_passing_stable>>, iters: int = 1000, callback_fn: ~collections.abc.Callable[[~mbi.clique_vector.CliqueVector], None] = <function <lambda>>, mesh: ~jax._src.mesh.Mesh | None = None)
Optimization using the Universal Accelerated MD algorithm.