Source code for mbi.estimation

"""Algorithms for estimating graphical models from marginal-based loss functions.

This module provides a flexible set of optimization algorithms, each sharing
the same API.  The supported algorithms are:
1. Mirror Descent [our recommended algorithm]
2. L-BFGS (using back-belief propagation)
3. Regularized Dual Averaging
4. Interior Gradient
5. Universal accelerated mirror descent

Each algorithm can be given an initial set of potentials, or can automatically
initialize the potentials to zero for you.  Any CliqueVector of potentials that
supports the cliques of the marginal-based loss function can be used here.
"""

from __future__ import annotations

import functools
import math
from collections.abc import Callable
from typing import Any, NamedTuple, Protocol

import jax
import jax.numpy as jnp
import numpy as np
import optax

from . import marginal_loss, marginal_oracles
from .approximate_oracles import StatefulMarginalOracle
from .clique_vector import CliqueVector
from .domain import Domain
from .factor import Factor, Projectable
from .marginal_loss import LinearMeasurement
from .markov_random_field import MarkovRandomField


class Estimator(Protocol):
    """Callable protocol shared by all marginal-based estimators.

    An estimator fits a discrete distribution -- or, more generally, any
    `Projectable` object -- by minimizing a loss function defined over its
    low-dimensional marginals.

    Functions in `mbi.estimation` that conform to this protocol include:
      - `mirror_descent`
      - `lbfgs`
      - `dual_averaging`
      - `interior_gradient`
      - `universal_accelerated_method`

    Other modules provide further conforming estimators.
    """

    def __call__(
        self,
        domain: Domain,
        loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
        *,
        known_total: float | None = None,
        **kwargs: Any,
    ) -> Projectable:
        """Estimate a Projectable from noisy marginal measurements.

        Args:
            domain: The Domain describing the attributes, and their
                cardinalities, over which the model is defined.
            loss_fn: The objective to minimize, given either as a
                MarginalLossFn or as a list of LinearMeasurement objects.
            known_total: The known or estimated number of records, if
                available.  When omitted, the estimator attempts to learn it
                automatically.
            **kwargs: Extra keyword arguments specific to the concrete
                estimation algorithm.

        Returns:
            A Projectable that is, in some sense, maximally consistent with
            the noisy measurements.
        """
        ...
def minimum_variance_unbiased_total(measurements: list[LinearMeasurement]) -> float:
    """Estimate the total record count from identity-query measurements.

    Every measurement whose query is `Factor.datavector` (the identity query)
    yields one estimate of the total, namely the sum of its noisy answers, with
    variance `stddev**2 * size`.  The estimates are combined by inverse-variance
    weighting, which gives the minimum-variance unbiased combination.

    Args:
        measurements: The noisy linear measurements to inspect.

    Returns:
        The combined estimate of the total, clamped from below at 1.0.  If no
        measurement is usable, 1.0 is returned.  The result is always a float
        (previously the fallback and the clamp produced the int `1`).
    """
    estimates, variances = [], []
    for M in measurements:
        y = M.noisy_measurement
        try:
            # TODO: generalize to support any linear measurement that supports total query
            if M.query == Factor.datavector:  # query = Identity
                estimates.append(y.sum())
                variances.append(M.stddev**2 * y.size)
        except Exception:
            # Deliberately best-effort: a measurement that cannot be inspected
            # this way simply does not contribute to the estimate.
            continue

    if not estimates:
        return 1.0

    estimates, variances = np.array(estimates), np.array(variances)
    # Inverse-variance weighting of the individual estimates.
    variance = 1.0 / np.sum(1.0 / variances)
    estimate = variance * np.sum(estimates / variances)
    return max(1.0, estimate)
def _initialize(domain, loss_fn, known_total, potentials): """Initializes loss function, total records, and potentials for estimation algorithms.""" if isinstance(loss_fn, list): if known_total is None: known_total = minimum_variance_unbiased_total(loss_fn) loss_fn = marginal_loss.from_linear_measurements(loss_fn, domain=domain) elif known_total is None: raise ValueError("Must set known_total if giving a custom MarginalLossFn") if potentials is None: potentials = CliqueVector.zeros(domain, loss_fn.cliques) if not all(potentials.supports(cl) for cl in loss_fn.cliques): potentials = potentials.expand(loss_fn.cliques) return loss_fn, known_total, potentials def _get_stateful_oracle( marginal_oracle: marginal_oracles.MarginalOracle | StatefulMarginalOracle, stateful: bool, ) -> StatefulMarginalOracle: if stateful: return marginal_oracle def wrapper(theta, total, state): return marginal_oracle(theta, total), state return wrapper
def mirror_descent(
    domain: Domain,
    loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
    *,
    known_total: float | None = None,
    potentials: CliqueVector | None = None,
    marginal_oracle: (
        marginal_oracles.MarginalOracle | StatefulMarginalOracle
    ) = marginal_oracles.message_passing_fast,
    iters: int = 1000,
    stateful: bool = False,
    stepsize: float | None = None,
    callback_fn: Callable[[CliqueVector], None] = lambda _: None,
    mesh: jax.sharding.Mesh | None = None,
):
    """Optimization using the Mirror Descent algorithm.

    This is a first-order proximal optimization algorithm for solving a
    (possibly nonsmooth) convex optimization problem over the marginal
    polytope.  It implements Algorithm 1 from the paper
    ["Graphical-model based estimation and inference for differential privacy"]
    (https://arxiv.org/pdf/1901.09136).

    When no stepsize is given, a backtracking-style rule is used: a proposed
    step is kept only if it satisfies the Armijo sufficient-decrease condition.
    The step size grows by 1% after an accepted step and is halved after a
    rejected one.

    Args:
        domain: The domain over which the model should be defined.
        loss_fn: A MarginalLossFn or a list of Linear Measurements.
        known_total: The known or estimated number of records in the data.
        potentials: The initial potentials.  Must be defined over a set of
            cliques that supports the cliques in the loss_fn.
        marginal_oracle: The function to use to compute marginals from
            potentials.
        iters: The maximum number of optimization iterations.
        stateful: Whether `marginal_oracle` already has the stateful
            `(theta, total, state)` signature.  Requires an explicit stepsize.
        stepsize: The step size for the optimization.  If not provided, the
            line search described above chooses step sizes automatically.
        callback_fn: A function called once per iteration with the marginals
            computed at the start of that iteration.
        mesh: Determines how the marginal oracle and loss calculation will be
            sharded across devices.

    Returns:
        A MarkovRandomField object with the estimated potentials and marginals.

    Raises:
        ValueError: If `stateful` is set but no `stepsize` is given.
    """
    line_search = stepsize is None
    if line_search and stateful:
        raise ValueError(
            "Stepsize should be manually tuned when using a stateful oracle."
        )

    loss_fn, known_total, potentials = _initialize(
        domain, loss_fn, known_total, potentials
    )
    oracle = _get_stateful_oracle(
        functools.partial(marginal_oracle, mesh=mesh), stateful
    )

    @jax.jit
    def update(theta, alpha, state=None):
        mu, state = oracle(theta, known_total, state)
        loss, dL = jax.value_and_grad(loss_fn)(mu)
        proposal = theta - alpha * dL
        if not line_search:
            # Fixed step size: always take the step.
            return proposal, loss, alpha, mu, state

        # Armijo test on the proposed step; keep the old iterate if it fails.
        proposal_mu, _ = oracle(proposal, known_total, state)
        proposal_loss = loss_fn(proposal_mu)
        accepted = loss - proposal_loss >= 0.5 * alpha * dL.dot(mu - proposal_mu)
        next_alpha = jax.lax.select(accepted, 1.01 * alpha, 0.5 * alpha)
        next_theta = jax.lax.cond(accepted, lambda: proposal, lambda: theta)
        next_loss = jax.lax.select(accepted, proposal_loss, loss)
        return next_theta, next_loss, next_alpha, mu, state

    # Theory suggests the initial learning rate should be inversely
    # proportional to L.  We also divide by scaling factor to account for
    # the fact that gradients are scaled up by a factor of known_total.
    # See Eq 75. of https://www.cs.uic.edu/~zhangx/teaching/bregman.pdf.
    L = loss_fn.lipschitz or 1.0
    alpha = stepsize if not line_search else 2.0 / (L * known_total)

    # One oracle call up front to obtain the initial oracle state.
    _, state = oracle(potentials, known_total, state=None)
    for _ in range(iters):
        potentials, _, alpha, mu, state = update(potentials, alpha, state)
        callback_fn(mu)

    marginals, _ = oracle(potentials, known_total, state)
    return MarkovRandomField(
        potentials=potentials, marginals=marginals, total=known_total
    )
def _optimize(loss_and_grad_fn, params, iters=250, callback_fn=lambda _: None): """Runs an optimization loop using Optax L-BFGS.""" if len(jax.tree.leaves(params)) == 0: # Nothing to optimize callback_fn(params) return params def loss_fn(theta): return loss_and_grad_fn(theta)[0] @jax.jit def update(params, opt_state): loss, grad = loss_and_grad_fn(params) updates, opt_state = optimizer.update( grad, opt_state, params, value=loss, grad=grad, value_fn=loss_fn ) return optax.apply_updates(params, updates), opt_state, loss optimizer = optax.lbfgs( memory_size=1, linesearch=optax.scale_by_zoom_linesearch(128, max_learning_rate=1), ) state = optimizer.init(params) prev_loss = float("inf") for t in range(iters): params, state, loss = update(params, state) callback_fn(params) # if loss == prev_loss: break prev_loss = loss return params
def lbfgs(
    domain: Domain,
    loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
    *,
    known_total: float | None = None,
    potentials: CliqueVector | None = None,
    marginal_oracle: marginal_oracles.MarginalOracle = marginal_oracles.message_passing_stable,
    iters: int = 1000,
    callback_fn: Callable[[CliqueVector], None] = lambda _: None,
    mesh: jax.sharding.Mesh | None = None,
):
    """Gradient-based optimization on the potentials (theta) via L-BFGS.

    This optimizer computes gradients with respect to the potentials by
    back-propagating through the marginal inference oracle.  This is a standard
    approach for fitting the parameters of a graphical model without noise
    (i.e., when you know the exact marginals).  In that case the loss is convex
    in theta, so the approach enjoys convergence guarantees.

    With generic marginal loss functions, which arise for instance with noisy
    marginals, the loss is typically convex with respect to mu but not with
    respect to theta.  This optimizer is therefore not guaranteed to reach the
    global optimum in all cases, although in practice it tends to work well
    despite the non-convexity.

    This approach appeared in the paper
    ["Learning Graphical Model Parameters with Approximate Marginal Inference"](https://arxiv.org/abs/1301.3193).

    Args:
        domain: The domain over which the model should be defined.
        loss_fn: A MarginalLossFn or a list of Linear Measurements.
        known_total: The known or estimated number of records in the data.
            Optional when loss_fn is a list of LinearMeasurements, required
            otherwise.
        potentials: The initial potentials.  Must be defined over a set of
            cliques that supports the cliques in the loss_fn.
        marginal_oracle: The function to use to compute marginals from
            potentials.
        iters: The maximum number of optimization iterations.
        callback_fn: A function called after every iteration with the
            marginals of the current potentials.
        mesh: Determines how the marginal oracle and loss calculation will be
            sharded across devices.

    Returns:
        A MarkovRandomField object with the optimized potentials and marginals.
    """
    loss_fn, known_total, potentials = _initialize(
        domain, loss_fn, known_total, potentials
    )
    oracle = functools.partial(marginal_oracle, mesh=mesh)

    def infer(theta):
        return oracle(theta, known_total)

    def objective(theta):
        return loss_fn(infer(theta))

    def report(theta):
        # Callers observe marginals, not raw potentials.
        callback_fn(infer(theta))

    potentials = _optimize(
        jax.value_and_grad(objective), potentials, iters=iters, callback_fn=report
    )
    return MarkovRandomField(
        potentials=potentials,
        marginals=infer(potentials),
        total=known_total,
    )
def mle_from_marginals(
    marginals: CliqueVector,
    known_total: float,
    iters: int = 250,
    marginal_oracle: marginal_oracles.MarginalOracle = marginal_oracles.message_passing_stable,
    callback_fn=lambda *_: None,
    mesh: jax.sharding.Mesh | None = None,
) -> MarkovRandomField:
    """Compute the MLE graphical model from the marginals.

    Starting from zero potentials over the cliques of `marginals`, L-BFGS
    minimizes `-marginals.dot(log(mu))`, where `mu` is the oracle output for
    the current potentials.  The gradient used is `mu - marginals`.

    Args:
        marginals: The target marginals to fit.
            NOTE(review): the gradient `mu - marginals` suggests these are on
            the same scale as the oracle output for `known_total` -- confirm.
        known_total: The known or estimated number of records in the data.
        iters: The number of L-BFGS iterations to run.
        marginal_oracle: The function to use to compute marginals from
            potentials.  It is called as `marginal_oracle(theta, total, mesh)`.
        callback_fn: Accepted for signature compatibility.
            NOTE(review): this function does not currently invoke it.
        mesh: Determines how the marginal oracle will be sharded across
            devices.  It is used both during optimization and for the final
            marginals (previously the final marginals ignored it).

    Returns:
        A MarkovRandomField object with the final potentials and marginals.
    """

    def loss_and_grad_fn(theta):
        mu = marginal_oracle(theta, known_total, mesh)
        return -marginals.dot(mu.log()), mu - marginals

    potentials = CliqueVector.zeros(marginals.domain, marginals.cliques)
    potentials = _optimize(loss_and_grad_fn, potentials, iters=iters)
    return MarkovRandomField(
        potentials=potentials,
        # Use the same mesh as during optimization so that the returned
        # marginals are computed with the sharding the caller asked for.
        marginals=marginal_oracle(potentials, known_total, mesh),
        total=known_total,
    )
def _entropy_radius(domain_size: int) -> np.float64:
    """Return sqrt(n * log(n)) for n = domain_size, robust to very large n."""
    if domain_size <= 1:
        # log(1) == 0, so the bound degenerates to zero.
        return np.float64(0.0)
    # domain_size is an arbitrary-precision int; np.log / np.sqrt fail on ints
    # beyond 64 bits, whereas math.log accepts them.  Working in log space also
    # avoids converting the huge int to a float directly.
    log_n = math.log(domain_size)
    half_log_radius = 0.5 * (log_n + math.log(log_n))
    with np.errstate(over="ignore"):
        # Overflows to inf for astronomically large domains, which is fine.
        return np.exp(np.float64(half_log_radius))


def dual_averaging(
    domain: Domain,
    loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
    *,
    known_total: float | None = None,
    potentials: CliqueVector | None = None,
    marginal_oracle: marginal_oracles.MarginalOracle = marginal_oracles.message_passing_stable,
    iters: int = 1000,
    callback_fn: Callable[[CliqueVector], None] = lambda _: None,
    mesh: jax.sharding.Mesh | None = None,
) -> MarkovRandomField:
    """Optimization using the Regularized Dual Averaging (RDA) algorithm.

    RDA is an accelerated proximal algorithm for solving a smooth convex
    optimization problem over the marginal polytope.  It requires the Lipschitz
    constant of the gradient of the loss function, which is read from
    `loss_fn.lipschitz`.

    Args:
        domain: The domain over which the model should be defined.
        loss_fn: A MarginalLossFn or a list of Linear Measurements.
        known_total: The known or estimated number of records in the data.
        potentials: The initial potentials.  Must be defined over a set of
            cliques that supports the cliques in the loss_fn.
        marginal_oracle: The function to use to compute marginals from
            potentials.
        iters: The maximum number of optimization iterations.
        callback_fn: A function to call with the intermediate solution (the
            averaged marginals) at each iteration.
        mesh: Determines how the marginal oracle and loss calculation will be
            sharded across devices.  It is also forwarded to the final
            maximum-likelihood fit.

    Returns:
        A MarkovRandomField object with the final potentials and marginals.

    Raises:
        ValueError: If `loss_fn.lipschitz` is None.
    """
    loss_fn, known_total, potentials = _initialize(
        domain, loss_fn, known_total, potentials
    )
    if loss_fn.lipschitz is None:
        raise ValueError(
            "Dual Averaging requires a loss function with Lipschitz gradients."
        )

    D = _entropy_radius(domain.size())  # upper bound on entropy
    Q = 0  # upper bound on variance of stochastic gradients
    # Guard D == 0 (single-cell domain), which would otherwise give 0/0 = nan.
    gamma = Q / D if D > 0 else np.float64(0.0)
    L = loss_fn.lipschitz / known_total

    @jax.jit
    def update(w, v, gbar, c, beta, t):
        u = (1 - c) * w + c * v
        g = jax.grad(loss_fn)(u) / known_total
        gbar = (1 - c) * gbar + c * g  # running average of gradients
        theta = -t * (t + 1) / (4 * L + beta) * gbar
        v = marginal_oracle(theta, known_total, mesh)
        w = (1 - c) * w + c * v
        return w, v, gbar

    w = v = marginal_oracle(potentials, known_total, mesh)
    gbar = CliqueVector.zeros(domain, loss_fn.cliques)
    for t in range(1, iters + 1):
        c = 2.0 / (t + 1)
        beta = gamma * (t + 1) ** 1.5 / 2
        w, v, gbar = update(w, v, gbar, c, beta, t)
        callback_fn(w)

    # Forward the mesh so the final fit uses the same sharding as the loop.
    return mle_from_marginals(w, known_total, mesh=mesh)
def interior_gradient(
    domain: Domain,
    loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
    *,
    known_total: float | None = None,
    potentials: CliqueVector | None = None,
    marginal_oracle: marginal_oracles.MarginalOracle = marginal_oracles.message_passing_stable,
    iters: int = 1000,
    callback_fn: Callable[[CliqueVector], None] = lambda _: None,
    mesh: jax.sharding.Mesh | None = None,
):
    """Optimization using the Interior Point Gradient Descent algorithm.

    Interior Gradient is an accelerated proximal algorithm for solving a smooth
    convex optimization problem over the marginal polytope.  It requires the
    Lipschitz constant of the gradient of the loss function, which is read from
    `loss_fn.lipschitz`.

    This algorithm is based on the paper titled
    ["Interior Gradient and Proximal Methods for Convex and Conic Optimization"](https://epubs.siam.org/doi/abs/10.1137/S1052623403427823?journalCode=sjope8).

    Args:
        domain: The domain over which the model should be defined.
        loss_fn: A MarginalLossFn or a list of Linear Measurements.
        known_total: The known or estimated number of records in the data.
        potentials: The initial potentials.  Must be defined over a set of
            cliques that supports the cliques in the loss_fn.
        marginal_oracle: The function to use to compute marginals from
            potentials.
        iters: The maximum number of optimization iterations.
        callback_fn: A function called at each iteration with the current
            averaged marginals.
        mesh: Determines how the marginal oracle and loss calculation will be
            sharded across devices.  It is also forwarded to the final
            maximum-likelihood fit.

    Returns:
        A MarkovRandomField object with the optimized potentials and marginals.

    Raises:
        ValueError: If `loss_fn.lipschitz` is None.
    """
    loss_fn, known_total, potentials = _initialize(
        domain, loss_fn, known_total, potentials
    )
    if loss_fn.lipschitz is None:
        raise ValueError(
            "Interior Gradient requires a loss function with Lipschitz gradients."
        )

    # Algorithm parameters
    c = 1
    sigma = 1
    inv_lipschitz = sigma / loss_fn.lipschitz

    @jax.jit
    def update(theta, c, x, y, z):
        # The incoming y is not read; it is recomputed from x and z below.
        a = (
            ((c * inv_lipschitz) ** 2 + 4 * c * inv_lipschitz) ** 0.5
            - inv_lipschitz * c
        ) / 2
        y = (1 - a) * x + a * z
        c = c * (1 - a)
        g = jax.grad(loss_fn)(y)
        theta = theta - a / c / known_total * g
        z = marginal_oracle(theta, known_total, mesh)
        x = (1 - a) * x + a * z
        return theta, c, x, y, z

    # If we remove jit from marginal oracle, then we'll need to wrap this in
    # a jitted "init" function.
    x = y = z = marginal_oracle(potentials, known_total, mesh)
    theta = potentials
    for _ in range(iters):
        theta, c, x, y, z = update(theta, c, x, y, z)
        callback_fn(x)

    # Forward the mesh so the final fit uses the same sharding as the loop.
    return mle_from_marginals(x, known_total, mesh=mesh)
class _AcceleratedStepSearchState(NamedTuple): """State of the step search. Attributes: x: parameters defining the optimization algorithm (see Roulet and d'Aspremont Algorithm 2). z: same as x, see ref. u: dual variable corresponding to z. prev_stepsize: reciprocal of the estimate of the Lipshitz-continuity parameter of the gradient of the objective at the previous iteration of the algorithm. stepsize: reciprocal of the estimate of the Lipshitz-continuity parameter of the gradient of the objective at the current iteration of the algorithm. prev_theta: numerical value decreasing along iterates at the previous iteration of the algorithm, see ref. accept: whether the step is accepted or not. iter_search: iteration count of the search. References: Nesterov, [Universal Gradient Methods for Convex Optimization Problems](https://optimization-online.org/wp-content/uploads/2013/04/3833.pdf) Roulet and d'Aspremont, [Sharpness, Restart and Acceleration](https://arxiv.org/pdf/1702.03828) """ x: CliqueVector z: CliqueVector u: CliqueVector prev_stepsize: jnp.ndarray | float stepsize: jnp.ndarray | float prev_theta: jnp.ndarray | float accept: jnp.ndarray | bool iter_search: jnp.ndarray | int def _universal_accelerated_method_step_init( fun: Callable[[CliqueVector], jnp.ndarray], dual_init_params, dual_proj: Callable[..., Any], max_iter_search: int = 30, target_acc: float = 0.0, stepsize: float = 1.0, norm: int = 2, linesearch=True, ) -> tuple[ _AcceleratedStepSearchState, Callable[[_AcceleratedStepSearchState], bool], Callable[[_AcceleratedStepSearchState], _AcceleratedStepSearchState], ]: """Accelerated first order method adapted to any smoothness. Minimizes fun(x) over a constraint set M. The algorithm requires an oracle "dual_proj(g)" that computes argmin_y <g, y> + h(y) s.t. y in M where h is a distance generating function. This method is inspired from ref 1 and the algorithm is described in essentially described in Algorithm 2 of ref 2. 
One difference is that we keep track of the dual variable returned by the dual_proj to avoid mapping back and forth between the primal and dual spaces. This function provides the initial state and the continuation and body functions for the step the method (which searches for a valid stepsize each time). Args: fun: objective to minimize. dual_init_params: initial parameters in dual space. dual_proj: projection onto some constraint set according to a bregman divergence. max_iter_search: maximal number of iterations to run the search. target_acc: target accuracy of the method. If `fun` is non-smooth, this needs to be set > 0. Convergence beyond that target accuracy is not guaranteed. If the function is smooth, set `target_acc=0`. stepsize: initial estimate of the stepsize. norm: type of norm measuring the smoothness of `fun`. linesearch: if true, uses linesearch to determine acceptance of step, otherwise use constant stepsize given by `stepsize`. Returns: (init_carry, cond_fun, body_fun) where init_carry: initial state of the step search. cond_fun: continuation criterion when searching for next step. body_fun: step when searching step. 
References: 1 Nesterov, [Universal Gradient Methods for Convex Optimization Problems](https://optimization-online.org/wp-content/uploads/2013/04/3833.pdf) 2 Roulet and d'Aspremont, [Sharpness, Restart and Acceleration](https://arxiv.org/pdf/1702.03828) """ def cond_fun(carry: _AcceleratedStepSearchState) -> bool | jnp.ndarray: """Continuation criterion when searching for next step.""" return jnp.logical_not( jnp.logical_or(carry.accept, carry.iter_search >= max_iter_search), ) def body_fun( carry: _AcceleratedStepSearchState, ) -> _AcceleratedStepSearchState: """Step when searching step.""" # Computes new theta prev_theta, prev_smooth_estim = carry.prev_theta, 1 / carry.prev_stepsize smooth_estim, stepsize = 1 / carry.stepsize, carry.stepsize aux = 1 + 4 * smooth_estim / (prev_theta**2 * prev_smooth_estim) new_theta = 2 / (1 + jnp.sqrt(aux)) # We hardcode the first iteration to be prev_theta=-1 theta = jnp.where(carry.prev_theta < 0.0, 1.0, new_theta) # Computes sequences of params y = (1 - theta) * carry.x + theta * carry.z value_y, grad_y = jax.value_and_grad(fun)(y) u = carry.u - stepsize / theta * grad_y z = dual_proj(u) x = (1 - theta) * carry.x + theta * z # Check condition if linesearch: new_value = fun(x) if norm == 1: sq_norm_diff = optax.tree.norm( optax.tree.sub(x, y), ord=1, squared=True ) elif norm == 2: sq_norm_diff = optax.tree.norm( optax.tree_utils.tree_sub(x, y), ord=2, squared=True ) else: raise ValueError(f"norm={norm} not supported") taylor_approx = ( value_y + grad_y.dot(x - y) + 0.5 * smooth_estim * sq_norm_diff ) accept = new_value <= (taylor_approx + 0.5 * target_acc * theta) new_stepsize = 1.1 * stepsize else: accept = True new_stepsize = stepsize candidate = _AcceleratedStepSearchState( x=x, z=z, u=u, prev_stepsize=stepsize, stepsize=new_stepsize, prev_theta=theta, accept=accept, iter_search=jnp.asarray(0), ) base = carry._replace( stepsize=0.5 * carry.stepsize, iter_search=carry.iter_search + 1 ) return jax.tree.map(lambda x, y: 
jnp.where(accept, x, y), candidate, base) x = z = dual_proj(dual_init_params) u = dual_init_params init_carry = _AcceleratedStepSearchState( x=x, z=z, u=u, prev_stepsize=stepsize, stepsize=stepsize, prev_theta=jnp.asarray(-1.0), accept=jnp.asarray(False), iter_search=jnp.asarray(0), ) return init_carry, cond_fun, body_fun
def universal_accelerated_method(
    domain: Domain,
    loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
    *,
    known_total: float | None = None,
    potentials: CliqueVector | None = None,
    marginal_oracle: marginal_oracles.MarginalOracle = marginal_oracles.message_passing_stable,
    iters: int = 1000,
    callback_fn: Callable[[CliqueVector], None] = lambda _: None,
    mesh: jax.sharding.Mesh | None = None,
):
    """Optimization using the Universal Accelerated MD algorithm.

    Each outer iteration runs a step-size search built by
    `_universal_accelerated_method_step_init`, with the marginal oracle acting
    as the dual projection.  The final iterate is then converted into a
    graphical model with `mle_from_marginals`.

    Args:
        domain: The domain over which the model should be defined.
        loss_fn: A MarginalLossFn or a list of Linear Measurements.
        known_total: The known or estimated number of records in the data.
        potentials: The initial potentials.  Must be defined over a set of
            cliques that supports the cliques in the loss_fn.
        marginal_oracle: The function to use to compute marginals from
            potentials.
        iters: The number of outer optimization iterations.
        callback_fn: A function called after each outer iteration with the
            current primal iterate (marginals).
        mesh: Determines how the marginal oracle and loss calculation will be
            sharded across devices.  It is also forwarded to the final
            maximum-likelihood fit.

    Returns:
        The MarkovRandomField returned by `mle_from_marginals` for the final
        iterate.
    """
    loss_fn, known_total, potentials = _initialize(
        domain, loss_fn, known_total, potentials
    )
    oracle = functools.partial(marginal_oracle, mesh=mesh)

    carry, cond_fun, body_fun = _universal_accelerated_method_step_init(
        fun=loss_fn,
        dual_init_params=potentials,
        dual_proj=lambda x: oracle(x, known_total),
        max_iter_search=30,
        target_acc=0.0,
        stepsize=1.0 / known_total,
        norm=2,
        linesearch=True,
    )

    for _ in range(iters):
        # jax.lax.while_loop traces the body function, so no need to jit it.
        carry = jax.lax.while_loop(cond_fun, body_fun, carry)
        # Re-arm the search for the next outer iteration.
        # NOTE(review): iter_search is only reset when a step is accepted, so
        # once a search exhausts max_iter_search the remaining iterations look
        # like no-ops -- confirm whether that stall is intended.
        carry = carry._replace(accept=jnp.asarray(False))
        callback_fn(carry.x)

    # Forward the mesh so the final fit uses the same sharding as the loop.
    return mle_from_marginals(carry.x, known_total, mesh=mesh)