mbi.marginal_loss

Defines loss functions based on linear measurements of marginals.

This module provides structures and functions for defining and calculating loss based on potentially noisy linear measurements of marginal distributions. Key components include the LinearMeasurement class to represent individual measurements and the MarginalLossFn class to define loss functions over CliqueVector objects, enabling the evaluation of model fit against observed or noisy data. Utilities for clique manipulation and feasibility checks are also included.

Functions

calculate_l2_lipschitz

Estimate the Lipschitz constant of L(x) = || f(x) - y ||_2^2 where f is a linear function.

from_linear_measurements

Construct a MarginalLossFn from a list of LinearMeasurements.

primal_feasibility

Calculates the average L1 distance between overlapping marginals in mu (consistency).

Classes

LinearMeasurement

A class for representing a private linear measurement of a marginal.

MarginalLossFn

A loss function over the concatenated vector of marginals.

class mbi.marginal_loss.LinearMeasurement(noisy_measurement: ~jax.Array, clique, stddev: float = 1.0, query: ~collections.abc.Callable[[~mbi.factor.Factor], ~jax.Array] = <function Factor.datavector>)[source]

Bases: object

A class for representing a private linear measurement of a marginal.

noisy_measurement

The noisy measurement of the marginal.

Type:

jax.Array

clique

The clique (a tuple of attribute names) defining the marginal.

Type:

tuple[str, …]

stddev

The standard deviation of the noise added to the measurement.

Type:

float

query

A linear function that, when applied to a Factor, extracts a vector with the same shape and interpretation as `noisy_measurement`.

Type:

collections.abc.Callable[[mbi.factor.Factor], jax.Array]

Method generated by attrs for class LinearMeasurement.

noisy_measurement: Array
clique: tuple[str, ...]
stddev: float
query: Callable[[Factor], Array]
class mbi.marginal_loss.MarginalLossFn(cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int], lipschitz: float | None = None)[source]

Bases: object

A loss function over the concatenated vector of marginals.

cliques

A list of cliques (tuples of attribute names) that define the scope of the marginals used in the loss function.

Type:

list[tuple[str, …]]

loss_fn

A callable that takes a CliqueVector (representing the marginals) and returns a numeric loss value.

Type:

collections.abc.Callable[[mbi.clique_vector.CliqueVector], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]

lipschitz

An optional float giving the Lipschitz constant of the gradient of the loss function. Optimization algorithms make use of this value.

Type:

float | None

Method generated by attrs for class MarginalLossFn.

cliques: list[tuple[str, ...]]
loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]
lipschitz: float | None
mbi.marginal_loss.calculate_l2_lipschitz(domain: Domain, cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]) float[source]

Estimate the Lipschitz constant of L(x) = || f(x) - y ||_2^2 where f is a linear function.

The Lipschitz constant of the gradient can usually be obtained as the largest eigenvalue of the Hessian, which for a linear function written in matrix form as f(x) = Ax is a constant multiple of A^T A (2 A^T A for L as written above). This function computes the same value without materializing this n x n matrix, using power iteration together with jax.jvp.

Parameters:
  • domain – The domain over which the loss_fn is defined.

  • cliques – The cliques (tuples of attribute names) whose marginals the loss_fn takes as input.

  • loss_fn – The loss function, assumed to be of the form || f(x) - y ||_2^2 where f is linear.

Returns:

An estimate of the Lipschitz constant of grad(L).
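The power-iteration idea described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: the helper names (`matvec`, `rmatvec`, `largest_eigenvalue_AtA`) are hypothetical, the matrix A is given explicitly as a list of rows, and the products A x and A^T r stand in for what the library obtains through jax.jvp. Since the gradient of || Ax - y ||_2^2 is 2 A^T (Ax - y), its Lipschitz constant is twice the eigenvalue estimated here; the sketch returns the eigenvalue of A^T A itself.

```python
import math
import random


def matvec(A, x):
    """Compute A @ x for a matrix stored as a list of rows."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def rmatvec(A, r):
    """Compute A^T @ r without building the transpose."""
    n = len(A[0])
    return [sum(A[i][j] * r[i] for i in range(len(A))) for j in range(n)]


def largest_eigenvalue_AtA(A, num_iters=200, seed=0):
    """Estimate the largest eigenvalue of A^T A by power iteration.

    Only matrix-vector products are used, so the n x n matrix A^T A
    is never formed.
    """
    rng = random.Random(seed)
    n = len(A[0])
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    norm = math.sqrt(sum(v * v for v in x))
    x = [v / norm for v in x]
    eig = 0.0
    for _ in range(num_iters):
        y = rmatvec(A, matvec(A, x))  # y = A^T (A x)
        eig = math.sqrt(sum(v * v for v in y))  # ||A^T A x|| with ||x|| = 1
        if eig == 0.0:
            return 0.0
        x = [v / eig for v in y]  # renormalize for the next iteration
    return eig


# Hypothetical example: A^T A = diag(9, 1).
A = [[3.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
estimate = largest_eigenvalue_AtA(A)
```

In this sketch the number of iterations is fixed; a practical version would more likely stop once successive estimates agree to within a tolerance.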

mbi.marginal_loss.from_linear_measurements(measurements: list[LinearMeasurement], norm: str = 'l2', normalize: bool = False, domain: Domain | None = None) MarginalLossFn[source]

Construct a MarginalLossFn from a list of LinearMeasurements.

Parameters:
  • measurements – A list of LinearMeasurements.

  • norm – Either “l1” or “l2”.

  • normalize – Whether to normalize the loss function by the length of the linear measurements and the estimated total.

  • domain – The domain over which the measurements were made, necessary for calculating the Lipschitz parameter.

Returns:

The MarginalLossFn L(mu) = sum_{c} || Q_c mu_c - y_c || (possibly squared or normalized).
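The formula above can be illustrated with a small stand-alone sketch. It does not use the mbi API. Marginals are plain dicts of flat lists, each measurement is a hypothetical `(clique, noisy_values, stddev, query)` tuple, and dividing the residual by `stddev` is an assumption made for this example, as is the choice to square the L2 norm. Normalization is left out.

```python
def linear_measurement_loss(marginals, measurements, norm="l2"):
    """Sum the per-clique residual norms || Q_c mu_c - y_c ||.

    marginals: dict mapping a clique (tuple of names) to a flat list of floats.
    measurements: list of (clique, noisy_values, stddev, query) tuples, where
        query maps a flat marginal to a vector shaped like noisy_values.
    """
    total = 0.0
    for clique, noisy, stddev, query in measurements:
        predicted = query(marginals[clique])
        # Assumption for this sketch: residuals are scaled by 1 / stddev.
        residual = [(p - y) / stddev for p, y in zip(predicted, noisy)]
        if norm == "l2":
            total += sum(r * r for r in residual)  # squared L2 norm
        elif norm == "l1":
            total += sum(abs(r) for r in residual)
        else:
            raise ValueError(f"unknown norm: {norm}")
    return total


def identity(marginal):
    """Query that returns the marginal unchanged."""
    return marginal


# Hypothetical marginals and noisy measurements over attributes A and B.
marginals = {
    ("A",): [0.5, 0.5],
    ("A", "B"): [0.25, 0.25, 0.25, 0.25],
}
measurements = [
    (("A",), [0.6, 0.4], 1.0, identity),
    (("A", "B"), [0.25, 0.25, 0.5, 0.0], 0.5, identity),
]
l2_loss = linear_measurement_loss(marginals, measurements, norm="l2")
l1_loss = linear_measurement_loss(marginals, measurements, norm="l1")
```

Passing the query as a callable mirrors the `query` attribute of LinearMeasurement: any linear map of the marginal can be measured, not only the marginal itself.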

mbi.marginal_loss.primal_feasibility(mu: CliqueVector) Array | ndarray | bool | number | float | int[source]

Calculates the average L1 distance between overlapping marginals in mu (consistency).
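As a rough illustration of the consistency measure, the sketch below projects every pair of overlapping marginals onto their shared attributes and averages the L1 distances between the projections. It is a plain-Python approximation of the idea, not the library's code: the dict-of-assignments representation and the names `project` and `average_l1_inconsistency` are hypothetical, and the exact averaging used by primal_feasibility may differ.

```python
from itertools import combinations


def project(clique, table, attrs):
    """Sum a marginal down to the attributes in attrs.

    table maps assignment tuples (ordered like clique) to probability mass.
    """
    idx = [clique.index(a) for a in attrs]
    out = {}
    for assignment, mass in table.items():
        key = tuple(assignment[i] for i in idx)
        out[key] = out.get(key, 0.0) + mass
    return out


def average_l1_inconsistency(marginals):
    """Average L1 distance between projections of each overlapping clique pair."""
    distances = []
    for (c1, t1), (c2, t2) in combinations(marginals.items(), 2):
        shared = tuple(a for a in c1 if a in c2)
        if not shared:
            continue  # disjoint cliques impose no consistency constraint
        p1 = project(c1, t1, shared)
        p2 = project(c2, t2, shared)
        keys = set(p1) | set(p2)
        distances.append(sum(abs(p1.get(k, 0.0) - p2.get(k, 0.0)) for k in keys))
    return sum(distances) / len(distances) if distances else 0.0


# Hypothetical marginals: the two tables disagree on the distribution of B.
mu = {
    ("A", "B"): {(0, 0): 0.1, (0, 1): 0.4, (1, 0): 0.2, (1, 1): 0.3},
    ("B", "C"): {(0, 0): 0.2, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.3},
}
inconsistency = average_l1_inconsistency(mu)
```

A value of zero means every pair of overlapping marginals agrees on its shared attributes; larger values indicate marginals that are further from being mutually consistent.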