mbi.MarginalLossFn

class mbi.MarginalLossFn(cliques: list[tuple[str, ...]], loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int], lipschitz: float | None = None)

Bases: object

A loss function defined over the concatenated vector of marginals for a set of cliques.
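To make "concatenated vector of marginals" concrete, here is a minimal NumPy sketch. It assumes a plain dict mapping each clique tuple to a NumPy array can stand in for `CliqueVector`; the clique names, the data, and the helpers `concatenate_marginals` and `squared_error` are hypothetical and not part of mbi.

```python
import numpy as np

cliques = [("A",), ("A", "B")]

# Hypothetical stand-in for a CliqueVector: clique tuple -> marginal array.
marginals = {
    ("A",): np.array([0.5, 0.5]),
    ("A", "B"): np.array([[0.2, 0.3], [0.1, 0.4]]),
}
# Hypothetical noisy measurements of the same marginals.
observed = {
    ("A",): np.array([0.6, 0.4]),
    ("A", "B"): np.array([[0.2, 0.3], [0.1, 0.4]]),
}


def concatenate_marginals(values, cliques):
    """Flatten each clique's marginal in clique order and join them into one vector."""
    return np.concatenate([np.ravel(values[c]) for c in cliques])


def squared_error(values):
    """Squared L2 distance between concatenated marginals and concatenated observations."""
    diff = concatenate_marginals(values, cliques) - concatenate_marginals(observed, cliques)
    return float(diff @ diff)


print(concatenate_marginals(marginals, cliques).shape)  # (6,)
print(squared_error(marginals))  # approximately 0.02
```

Because concatenation only reorders entries, the squared error of the concatenated vector equals the sum of the per-clique squared errors.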

cliques

A list of cliques (tuples of attribute names) that define the scope of the marginals used in the loss function.

Type:

list[tuple[str, …]]

loss_fn

A callable that takes a CliqueVector (representing the marginals) and returns a numeric loss value.

Type:

collections.abc.Callable[[mbi.clique_vector.CliqueVector], jax.Array | numpy.ndarray | numpy.bool | numpy.number | float | int]

lipschitz

An optional float giving the Lipschitz constant of the loss function's gradient. Optimization algorithms can use it, for example, to choose a step size; leave it as None when the constant is unknown.

Type:

float | None
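As a rough illustration of how a gradient Lipschitz constant is typically used, the sketch below runs gradient descent with the standard step size 1/L on a weighted squared-error loss f(μ) = Σᵢ wᵢ (μᵢ − yᵢ)². Its Hessian is diag(2w), so L = 2·max(w). The weights, data, and function names are hypothetical, and the sketch does not claim to mirror how mbi's own optimizers consume `lipschitz`.

```python
import numpy as np

# Hypothetical per-entry weights and observed (concatenated) marginals.
w = np.array([1.0, 4.0, 2.0])
y = np.array([0.2, 0.5, 0.3])


def loss(mu):
    """Weighted squared error: sum_i w_i * (mu_i - y_i)**2."""
    return float(np.sum(w * (mu - y) ** 2))


def grad(mu):
    """Gradient of the weighted squared error."""
    return 2.0 * w * (mu - y)


# The gradient is linear in mu with Hessian diag(2 * w), so it is
# Lipschitz with constant L = 2 * max(w).
lipschitz = 2.0 * float(np.max(w))

mu = np.zeros_like(y)
history = [loss(mu)]
for _ in range(100):
    mu = mu - grad(mu) / lipschitz  # step size 1 / L
    history.append(loss(mu))

print(lipschitz)  # 8.0
print(np.allclose(mu, y))  # True
```

With step 1/L, each coordinate's error shrinks by the factor 1 − wᵢ/max(w), so the entry with the largest weight converges in one step and the others decay geometrically.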


Methods

__init__

Method generated by attrs for class MarginalLossFn.

Attributes

cliques

loss_fn

lipschitz

cliques: list[tuple[str, ...]]
loss_fn: Callable[[CliqueVector], Array | ndarray | bool | number | float | int]
lipschitz: float | None
__call__(marginals: CliqueVector) → Array | ndarray | bool | number | float | int

Evaluate the loss function on the given marginals and return the numeric loss value.
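Calling a MarginalLossFn evaluates the loss on a set of marginals. The standard-library sketch below mirrors the documented fields; the class `MarginalLossFnSketch` is hypothetical, uses `dataclasses` in place of attrs and a plain dict in place of `CliqueVector`, and assumes that calling the object simply forwards the marginals to `loss_fn`.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

Clique = tuple[str, ...]
Marginals = dict[Clique, np.ndarray]  # hypothetical stand-in for CliqueVector


@dataclass(frozen=True)
class MarginalLossFnSketch:
    """Hypothetical stand-in mirroring the documented fields of MarginalLossFn."""

    cliques: list[Clique]
    loss_fn: Callable[[Marginals], float]
    lipschitz: Optional[float] = None

    def __call__(self, marginals: Marginals) -> float:
        # Forward the marginals to the wrapped loss function.
        return self.loss_fn(marginals)


# Hypothetical observed marginal for a single clique.
observed = {("A", "B"): np.array([[0.25, 0.25], [0.25, 0.25]])}


def l1_loss(marginals: Marginals) -> float:
    """Sum of absolute differences between marginals and observations."""
    return float(sum(np.abs(marginals[c] - observed[c]).sum() for c in observed))


loss = MarginalLossFnSketch(cliques=[("A", "B")], loss_fn=l1_loss)
estimate = {("A", "B"): np.array([[0.5, 0.0], [0.25, 0.25]])}
print(loss(estimate))  # 0.5
print(loss.lipschitz)  # None
```

The L1 loss is not differentiable everywhere, so its gradient has no finite Lipschitz constant; that is why `lipschitz` is left at its default of None here.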