mbi.marginal_loss.calculate_l2_lipschitz
- mbi.marginal_loss.calculate_l2_lipschitz(domain: Domain, cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]) → float
Estimate the Lipschitz constant of the gradient of L(x) = || f(x) - y ||_2^2, where f is a linear function.
For a twice-differentiable convex function, this constant equals the largest eigenvalue of the Hessian. When f is linear and written in matrix form as f(x) = Ax, the Hessian is the constant matrix 2 A^T A. This function computes that eigenvalue without materializing the n x n matrix (where n is the dimension of x) by using power iteration and leveraging jax.jvp.
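The link between the Hessian and the Lipschitz constant can be made explicit with a short derivation. Here A denotes the matrix representing the linear map f, i.e. f(x) = Ax; this is notation for the sketch, not a name used by the library.

```latex
\begin{aligned}
L(x) &= \lVert Ax - y \rVert_2^2, \\
\nabla L(x) &= 2A^\top (Ax - y), \\
\nabla^2 L(x) &= 2A^\top A, \\
\lVert \nabla L(x) - \nabla L(z) \rVert_2 &= \lVert 2A^\top A\,(x - z) \rVert_2
  \le \lambda_{\max}\!\left(2A^\top A\right) \lVert x - z \rVert_2 .
\end{aligned}
```

Because 2A^T A is symmetric positive semidefinite, its spectral norm equals its largest eigenvalue, and the bound is attained when x - z points along the corresponding eigenvector. The Lipschitz constant of grad(L) is therefore exactly lambda_max(2 A^T A) = 2 sigma_max(A)^2.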
- Parameters:
domain – The domain over which the loss_fn is defined.
cliques – The cliques over which the CliqueVector argument to loss_fn is defined.
loss_fn – The loss function, assumed to be of the form || f(x) - y ||_2^2 where f is linear.
- Returns:
An estimate of the Lipschitz constant of grad(L).
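The technique described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation. It operates on a flat 1-D jnp array instead of mbi's CliqueVector, and it uses jax.jvp applied to jax.grad (forward-over-reverse differentiation) to obtain Hessian-vector products. That is one standard way to use jax.jvp here, but it is an assumption about how calculate_l2_lipschitz uses it. The helper names hvp and estimate_grad_lipschitz, the random starting vector, and the fixed iteration count are all hypothetical choices made for this sketch.

```python
import jax
import jax.numpy as jnp


def hvp(loss_fn, x, v):
    # Hessian-vector product H(x) @ v computed forward-over-reverse:
    # the JVP of grad(loss_fn) at x in direction v. The n x n Hessian
    # itself is never formed.
    return jax.jvp(jax.grad(loss_fn), (x,), (v,))[1]


def estimate_grad_lipschitz(loss_fn, n, num_iters=100, seed=0):
    # Hypothetical helper: power iteration for the largest Hessian eigenvalue.
    # Assumes loss_fn maps a length-n 1-D array to a scalar and is a convex
    # quadratic such as ||A x - y||_2^2, so the Hessian is constant, PSD,
    # and (assumed here) nonzero.
    x = jnp.zeros(n)  # any point works: a quadratic's Hessian is constant
    hv = jax.jit(lambda v: hvp(loss_fn, x, v))
    v = jax.random.normal(jax.random.PRNGKey(seed), (n,))
    v = v / jnp.linalg.norm(v)
    for _ in range(num_iters):
        w = hv(v)
        v = w / jnp.linalg.norm(w)  # renormalize so the iterate stays unit length
    # Rayleigh quotient v^T H v with ||v|| = 1 approximates lambda_max(H).
    return float(jnp.vdot(v, hv(v)))
```

For example, with A = [[3, 1], [1, 2], [0, 1]], any y, and loss_fn = lambda x: jnp.sum((A @ x - y) ** 2), the estimate should agree with 2 * jnp.linalg.eigvalsh(A.T @ A).max() (about 26.77) up to float32 precision. Each iteration costs one Hessian-vector product, roughly the price of a couple of gradient evaluations, so memory stays linear in n rather than quadratic.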