mbi.marginal_loss.MarginalLossFn
- class mbi.marginal_loss.MarginalLossFn(cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int], lipschitz: float | None = None)
Bases: object
A loss function over the concatenated vector of marginals.
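To make "concatenated vector of marginals" concrete: if the cliques are $r_1, \dots, r_k$ and $\mu_{r_i}$ is the marginal on clique $r_i$ flattened to a vector, the loss is a function of their concatenation $\mu$. As an illustrative example only (the squared-error form and the measurements $y_{r_i}$ are assumptions, not something this class prescribes):

```latex
\mu = (\mu_{r_1}, \dots, \mu_{r_k}), \qquad
L(\mu) = \frac{1}{2} \sum_{i=1}^{k} \lVert \mu_{r_i} - y_{r_i} \rVert_2^2, \qquad
\nabla_{\mu_{r_i}} L(\mu) = \mu_{r_i} - y_{r_i}
```

Because $\lVert \nabla L(\mu) - \nabla L(\nu) \rVert_2 = \lVert \mu - \nu \rVert_2$, the gradient of this particular loss is 1-Lipschitz, so a lipschitz value of 1.0 would describe it.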
- cliques
A list of cliques (tuples of attribute names) that define the scope of the marginals used in the loss function.
- Type:
list[tuple[str, …]]
- loss_fn
A callable that takes a CliqueVector (representing the marginals) and returns a numeric loss value.
- Type:
collections.abc.Callable[[mbi.clique_vector.CliqueVector], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]
- lipschitz
An optional float giving the Lipschitz constant of the gradient of the loss function. Optimization algorithms can use it, for example, to choose a step size for gradient-based updates. None means no constant is supplied.
- Type:
float | None
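A minimal sketch of what values for these three fields might look like. It assumes a plain dict mapping each clique to a NumPy array as a stand-in for CliqueVector, and the measurements, helper names (concat, squared_loss, l1_loss) and candidate marginals are hypothetical. It does not construct the real class; it only illustrates a smooth loss with a known Lipschitz constant next to a nonsmooth one for which lipschitz would stay None.

```python
import numpy as np

# Assumption: a dict {clique: ndarray} stands in for mbi's CliqueVector.
cliques = [("A",), ("A", "B")]

# Hypothetical noisy measurements, one array per clique.
y = {("A",): np.array([4.0, 6.0]), ("A", "B"): np.array([[1.0, 3.0], [2.0, 4.0]])}


def concat(marginals, cliques):
    """Flatten each clique's marginal and concatenate them in clique order."""
    return np.concatenate([np.ravel(marginals[cl]) for cl in cliques])


def squared_loss(marginals):
    # Smooth loss: its gradient is 1-Lipschitz, so lipschitz=1.0 fits.
    diff = concat(marginals, cliques) - concat(y, cliques)
    return 0.5 * float(np.sum(diff**2))


def l1_loss(marginals):
    # Nonsmooth loss: its gradient has no finite Lipschitz constant, so lipschitz=None.
    diff = concat(marginals, cliques) - concat(y, cliques)
    return float(np.sum(np.abs(diff)))


candidate = {("A",): np.array([5.0, 5.0]), ("A", "B"): np.array([[2.0, 3.0], [2.0, 3.0]])}
print(squared_loss(candidate), l1_loss(candidate))
```

Concatenating in a fixed clique order keeps the flattened vector aligned with the cliques field, which is what makes "a loss over the concatenated vector" well defined.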
Methods
- __init__: Method generated by attrs for class MarginalLossFn.
Attributes
- cliques: list[tuple[str, ...]]
- loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]
- lipschitz: float | None
- __call__(marginals: CliqueVector) → Array | ndarray | bool | number | float | int
Call self as a function, evaluating the loss on the given marginals.
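A minimal sketch of how an optimizer might consume both __call__ and lipschitz. SketchLoss is a hypothetical stand-in with the same fields as MarginalLossFn, marginals are modeled as a dict instead of a CliqueVector, and gradient_step (with its 1e-2 fallback step size) is a hypothetical helper, not part of mbi. The gradient is written out by hand for the squared-error loss; this is not a claim about how mbi computes gradients.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

# Assumption: marginals are modeled as a dict {clique: ndarray} instead of a CliqueVector.
Marginals = dict[tuple[str, ...], np.ndarray]


@dataclass(frozen=True)
class SketchLoss:
    """Hypothetical stand-in with the same fields and __call__ as MarginalLossFn."""

    cliques: list[tuple[str, ...]]
    loss_fn: Callable[[Marginals], float]
    lipschitz: Optional[float] = None

    def __call__(self, marginals: Marginals) -> float:
        # Calling the object evaluates the wrapped loss on the marginals.
        return self.loss_fn(marginals)


y = {("A",): np.array([3.0, 7.0])}  # hypothetical measurement

loss = SketchLoss(
    cliques=[("A",)],
    loss_fn=lambda mu: 0.5 * float(np.sum((mu[("A",)] - y[("A",)]) ** 2)),
    lipschitz=1.0,
)


def gradient_step(loss: SketchLoss, mu: Marginals, grad: Marginals) -> Marginals:
    """One gradient-descent step using step size 1 / lipschitz when it is known."""
    step = 1.0 / loss.lipschitz if loss.lipschitz is not None else 1e-2
    return {cl: mu[cl] - step * grad[cl] for cl in loss.cliques}


mu0 = {("A",): np.array([5.0, 5.0])}
grad0 = {("A",): mu0[("A",)] - y[("A",)]}  # hand-written gradient of the squared error
mu1 = gradient_step(loss, mu0, grad0)
print(loss(mu0), loss(mu1))  # 4.0 0.0
```

For this quadratic loss, a step of 1 / lipschitz lands exactly on the minimizer in one step; for general smooth losses it is the classic safe step size for gradient descent.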