mbi.marginal_loss
Defines loss functions based on linear measurements of marginals.
This module provides structures and functions for defining and calculating loss based on potentially noisy linear measurements of marginal distributions. Key components include the LinearMeasurement class to represent individual measurements and the MarginalLossFn class to define loss functions over CliqueVector objects, enabling the evaluation of model fit against observed or noisy data. Utilities for clique manipulation and feasibility checks are also included.
Functions
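- calculate_l2_lipschitz(domain, cliques, loss_fn): Estimate the Lipschitz constant of the gradient of L(x) = || f(x) - y ||_2^2 where f is a linear function.
- from_linear_measurements(measurements[, norm, normalize, domain]): Construct a MarginalLossFn from a list of LinearMeasurements.
- primal_feasibility(mu): Calculate the average L1 distance between overlapping marginals in mu (a measure of consistency).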
Classes
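- LinearMeasurement(noisy_measurement, clique[, stddev, query]): A class for representing a private linear measurement of a marginal.
- MarginalLossFn(cliques, loss_fn[, lipschitz]): A loss function over the concatenated vector of marginals.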
- class mbi.marginal_loss.LinearMeasurement(noisy_measurement: jax.Array, clique, stddev: float = 1.0, query: collections.abc.Callable[[mbi.factor.Factor], jax.Array] = Factor.datavector)
Bases: object
A class for representing a private linear measurement of a marginal.
- noisy_measurement
The noisy measurement of the marginal.
- Type:
jax.Array
- clique
The clique (a tuple of attribute names) defining the marginal.
- Type:
tuple[str, …]
- stddev
The standard deviation of the noise added to the measurement.
- Type:
float
- query
A linear function that, when applied to a Factor, extracts a vector with the same shape and interpretation as `noisy_measurement`.
- Type:
collections.abc.Callable[[mbi.factor.Factor], jax.Array]
The __init__ method of LinearMeasurement is generated by attrs.
- noisy_measurement: Array
- clique: tuple[str, ...]
- stddev: float
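To make the fields concrete, here is a minimal sketch of how a measurement of this shape could be produced: count the records falling in each cell of a clique marginal, then add noise with standard deviation stddev. It does not use mbi. ToyMeasurement and measure_marginal are hypothetical names; records are assumed to be integer-coded; the flattened count vector stands in for what the default query (Factor.datavector) would extract; and Gaussian noise is an assumption, since the class only records a standard deviation.

```python
import math
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyMeasurement(NamedTuple):
    """Hypothetical stand-in for LinearMeasurement (no mbi dependency)."""
    noisy_measurement: jax.Array
    clique: tuple[str, ...]
    stddev: float = 1.0


def measure_marginal(data, attrs, sizes, clique, stddev, key):
    """Count records per cell of the clique marginal, then add Gaussian noise.

    data: integer-coded records, shape (num_records, len(attrs)).
    sizes: number of possible values of each attribute.
    """
    cols = [attrs.index(a) for a in clique]
    dims = [sizes[a] for a in clique]
    # Row-major flat cell index (the order jnp.ravel uses on the marginal table).
    flat = jnp.zeros(data.shape[0], dtype=data.dtype)
    for c, d in zip(cols, dims):
        flat = flat * d + data[:, c]
    counts = jnp.bincount(flat, length=math.prod(dims))
    # Assumption: Gaussian noise with the recorded standard deviation.
    noise = stddev * jax.random.normal(key, counts.shape)
    return ToyMeasurement(counts + noise, tuple(clique), stddev)


data = jnp.array([[0, 1], [1, 1], [0, 0], [1, 1]])
m = measure_marginal(data, ("A", "B"), {"A": 2, "B": 2}, ("A", "B"),
                     stddev=1.0, key=jax.random.PRNGKey(0))
```

Setting stddev to 0.0 returns the exact counts, which is a convenient way to check the cell ordering.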
- class mbi.marginal_loss.MarginalLossFn(cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int], lipschitz: float | None = None)
Bases: object
A loss function over the concatenated vector of marginals.
- cliques
A list of cliques (tuples of attribute names) that define the scope of the marginals used in the loss function.
- Type:
list[tuple[str, …]]
- loss_fn
A callable that takes a CliqueVector (representing the marginals) and returns a numeric loss value.
- Type:
collections.abc.Callable[[mbi.clique_vector.CliqueVector], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]
- lipschitz
An optional float representing the Lipschitz constant of the gradient of the loss function. Gradient-based optimization algorithms can use it, for example, to choose a step size.
- Type:
float | None
The __init__ method of MarginalLossFn is generated by attrs.
- cliques: list[tuple[str, ...]]
- loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]
- lipschitz: float | None
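A minimal sketch of the idea behind MarginalLossFn, using a plain dict from clique to flattened marginal as a hypothetical stand-in for CliqueVector. Because a dict is a JAX pytree, jax.grad differentiates the loss with respect to every marginal at once. ToyMarginalLoss and squared_error are illustrative names, not part of mbi.

```python
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp

# Hypothetical stand-in for CliqueVector: clique -> flattened marginal.
Marginals = dict[tuple[str, ...], jax.Array]


class ToyMarginalLoss(NamedTuple):
    cliques: list[tuple[str, ...]]
    loss_fn: Callable[[Marginals], jax.Array]
    lipschitz: Optional[float] = None


# Target marginals that the loss compares against.
y = {("A",): jnp.array([3.0, 1.0]), ("A", "B"): jnp.array([2.0, 1.0, 0.0, 1.0])}


def squared_error(mu: Marginals) -> jax.Array:
    # Sum of squared differences over every clique in the scope.
    return sum(jnp.sum((mu[cl] - y[cl]) ** 2) for cl in y)


loss = ToyMarginalLoss(cliques=list(y), loss_fn=squared_error)
mu0 = {cl: jnp.zeros_like(v) for cl, v in y.items()}
value = loss.loss_fn(mu0)               # scalar loss at mu0
grads = jax.grad(loss.loss_fn)(mu0)     # dict with the same keys as mu0
```

The gradient has the same structure as the input marginals, which is what lets an optimizer update all cliques in one step.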
- mbi.marginal_loss.calculate_l2_lipschitz(domain: Domain, cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]) -> float
Estimate the Lipschitz constant of the gradient of L(x) = || f(x) - y ||_2^2 where f is a linear function.
For such a quadratic loss, the Lipschitz constant of grad(L) is the largest eigenvalue of the Hessian of L; writing f(x) = A x, the Hessian is the constant matrix 2 A^T A. This function estimates that eigenvalue without materializing the n x n matrix, using power iteration and jax.jvp.
- Parameters:
domain – The domain over which the loss_fn is defined.
cliques – The cliques (tuples of attribute names) whose marginals loss_fn takes as input.
loss_fn – The loss function, assumed to be of the form || f(x) - y ||_2^2 where f is linear.
- Returns:
An estimate of the Lipschitz constant of grad(L).
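The power-iteration approach described above can be sketched on a plain array x instead of a CliqueVector. This is not the library's implementation: l2_lipschitz_estimate is a hypothetical helper that assumes loss_fn is a quadratic of the form || A x - y ||_2^2, obtains Hessian-vector products as the forward-mode derivative (jax.jvp) of jax.grad(loss_fn), and returns a Rayleigh quotient after a fixed number of iterations.

```python
import jax
import jax.numpy as jnp


def l2_lipschitz_estimate(loss_fn, x0, num_iters=100, seed=0):
    """Hypothetical helper: power iteration on the Hessian of a quadratic loss."""
    grad_fn = jax.grad(loss_fn)

    @jax.jit
    def hvp(v):
        # Forward-mode derivative of grad_fn at x0 in direction v equals H @ v,
        # so the n x n Hessian H is never formed.
        return jax.jvp(grad_fn, (x0,), (v,))[1]

    v = jax.random.normal(jax.random.PRNGKey(seed), x0.shape)
    v = v / jnp.linalg.norm(v)
    for _ in range(num_iters):
        w = hvp(v)
        v = w / jnp.linalg.norm(w)
    # Rayleigh quotient v^T H v of the unit vector approximates the top eigenvalue.
    return float(v @ hvp(v))


A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, 1.0])
loss = lambda x: jnp.sum((A @ x - y) ** 2)

estimate = l2_lipschitz_estimate(loss, jnp.zeros(2))
exact = 2.0 * float(jnp.linalg.eigvalsh(A.T @ A)[-1])
```

For a quadratic loss the Hessian does not depend on the point x0, so any point (here zeros) works, and the estimate should approach 2 times the largest eigenvalue of A^T A. Convergence speed depends on the gap between the two largest eigenvalues.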
- mbi.marginal_loss.from_linear_measurements(measurements: list[LinearMeasurement], norm: str = 'l2', normalize: bool = False, domain: Domain | None = None) -> MarginalLossFn
Construct a MarginalLossFn from a list of LinearMeasurements.
- Parameters:
measurements – A list of LinearMeasurements.
norm – Either “l1” or “l2”.
normalize – Whether to normalize the loss by the number of linear measurements and the estimated total.
domain – The domain over which the measurements were made; required for calculating the Lipschitz parameter.
- Returns:
The MarginalLossFn L(mu) = sum_{c} || Q_c mu_c - y_c || (possibly squared or normalized).
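A minimal sketch of the loss L(mu) = sum_{c} || Q_c mu_c - y_c || that from_linear_measurements is documented to build, using a dict of flattened marginals as a hypothetical stand-in for CliqueVector. The names are illustrative. For simplicity this sketch squares the L2 residual, ignores stddev, and omits normalization, so it may differ from the library's exact weighting.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyMeasurement(NamedTuple):
    noisy_measurement: jax.Array          # y_c
    clique: tuple[str, ...]               # c
    query: Callable = lambda mu_c: mu_c   # Q_c (identity by default)


def toy_loss_from_measurements(measurements, norm="l2"):
    """Hypothetical builder for L(mu) = sum_c || Q_c mu_c - y_c ||."""
    if norm not in ("l1", "l2"):
        raise ValueError(f"norm must be 'l1' or 'l2', got {norm!r}")

    def loss_fn(mu):
        total = 0.0
        for m in measurements:
            residual = m.query(mu[m.clique]) - m.noisy_measurement
            if norm == "l2":
                total = total + jnp.sum(residual ** 2)      # squared L2 norm
            else:
                total = total + jnp.sum(jnp.abs(residual))  # L1 norm
        return total

    return loss_fn


# One measurement of the A-marginal (sum over B) and one of the full (A, B) table.
measurements = [
    ToyMeasurement(jnp.array([3.0, 1.0]), ("A", "B"),
                   lambda mu_c: mu_c.reshape(2, 2).sum(axis=1)),
    ToyMeasurement(jnp.array([2.0, 1.0, 0.0, 1.0]), ("A", "B")),
]
mu = {("A", "B"): 2.0 * jnp.ones(4)}
l2_loss = toy_loss_from_measurements(measurements, "l2")(mu)
l1_loss = toy_loss_from_measurements(measurements, "l1")(mu)
```

Because each query is linear, the squared-L2 variant is exactly the quadratic form that calculate_l2_lipschitz assumes.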
- mbi.marginal_loss.primal_feasibility(mu: CliqueVector) -> Array | ndarray | bool | number | float | int
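Calculate the average L1 distance between overlapping marginals in mu, compared on the attributes they share. This measures how far mu is from being consistent; a value of zero means every pair of overlapping marginals agrees.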
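A minimal sketch of this consistency measure, assuming marginals are stored as a dict from clique to an n-dimensional array whose axes follow the clique's attribute order. For each pair of cliques that share attributes, both marginals are summed down to the shared attributes and compared in L1, and the distances are averaged over overlapping pairs. That pairwise averaging is an assumption and may not match how primal_feasibility weights or normalizes its terms; project and average_overlap_l1 are hypothetical names.

```python
import itertools

import jax.numpy as jnp


def project(marginal, clique, attrs):
    """Sum out attributes of `clique` not in `attrs`; order remaining axes as `attrs`."""
    drop = tuple(i for i, a in enumerate(clique) if a not in attrs)
    reduced = jnp.sum(marginal, axis=drop) if drop else marginal
    kept = [a for a in clique if a in attrs]
    return jnp.transpose(reduced, [kept.index(a) for a in attrs])


def average_overlap_l1(mu):
    """Average L1 distance between pairs of marginals, compared on shared attributes."""
    dists = []
    for c1, c2 in itertools.combinations(mu, 2):
        shared = tuple(a for a in c1 if a in c2)
        if not shared:
            continue  # marginals with no attributes in common impose no constraint
        diff = project(mu[c1], c1, shared) - project(mu[c2], c2, shared)
        dists.append(jnp.sum(jnp.abs(diff)))
    return float(jnp.mean(jnp.stack(dists))) if dists else 0.0


mu = {
    ("A", "B"): jnp.array([[2.0, 1.0], [0.0, 1.0]]),  # B-marginal: [2, 2]
    ("B", "C"): jnp.array([[3.0, 1.0], [0.0, 0.0]]),  # B-marginal: [4, 0]
    ("D",): jnp.array([1.0, 3.0]),                    # shares nothing; skipped
}
```

Here only the (A, B) and (B, C) pair overlaps, and their B-marginals differ by 4 in L1, so the average is 4.0.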