mbi.marginal_loss.calculate_l2_lipschitz

mbi.marginal_loss.calculate_l2_lipschitz(domain: Domain, cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]) → float

Estimate the Lipschitz constant of the gradient of L(x) = || f(x) - y ||_2^2, where f is a linear function.

For a quadratic loss of this form, the Lipschitz constant of the gradient is the largest eigenvalue of the Hessian. When f is written in matrix form as f(x) = A x, the Hessian is constant and proportional to A^T A (for the loss exactly as written above, it is 2 A^T A). This function computes the same value without materializing this n x n matrix by using power iteration and leveraging jax.jvp.
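The approach described above can be sketched on plain arrays. This is a minimal illustration, not the library's implementation: the names `estimate_gradient_lipschitz`, `toy_loss`, `num_iters`, and `seed` are hypothetical, and the sketch takes a flat `jax.numpy` array as input rather than a CliqueVector. It assumes the Hessian-vector product is formed by applying `jax.jvp` to `jax.grad(loss_fn)`, which is one common way to avoid building the Hessian.

```python
import jax
import jax.numpy as jnp


def estimate_gradient_lipschitz(loss_fn, x0, num_iters=100, seed=0):
    """Estimate the largest Hessian eigenvalue of loss_fn by power iteration.

    Hypothetical helper for illustration; it works on a flat array x0.
    """
    grad_fn = jax.grad(loss_fn)

    def hvp(v):
        # Hessian-vector product: directional derivative of the gradient
        # at x0 along v. For a quadratic loss the Hessian is constant,
        # so the choice of x0 does not matter.
        return jax.jvp(grad_fn, (x0,), (v,))[1]

    # Random unit-norm starting vector.
    v = jax.random.normal(jax.random.PRNGKey(seed), x0.shape)
    v = v / jnp.linalg.norm(v)

    for _ in range(num_iters):
        w = hvp(v)
        v = w / jnp.linalg.norm(w)

    # Rayleigh quotient of the final unit vector approximates the
    # dominant eigenvalue.
    return float(jnp.vdot(v, hvp(v)))


# Toy problem (made-up data): f(x) = A x, so the Hessian of the loss is 2 A^T A.
A = jnp.array([[2.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
y = jnp.array([1.0, -1.0, 0.5])


def toy_loss(x):
    return jnp.sum((A @ x - y) ** 2)


estimate = estimate_gradient_lipschitz(toy_loss, jnp.zeros(2))

# Reference value from the explicit matrix, feasible only because A is tiny.
exact = 2.0 * float(jnp.linalg.eigvalsh(A.T @ A)[-1])
```

In this toy case 2 A^T A is diag(8, 2), so `estimate` and `exact` should both be close to 8. The explicit eigendecomposition is there only as a check; avoiding it is the point of the power-iteration approach.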

Parameters:
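  • domain – The domain over which the loss_fn is defined.

  • cliques – The cliques (tuples of attribute names) over which the CliqueVector passed to loss_fn is defined.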

  • loss_fn – The loss function, assumed to be of the form || f(x) - y ||_2^2 where f is linear.

Returns:

An estimate of the Lipschitz constant of grad(L), the gradient of L.