Source code for mbi.clique_vector

"""Defines the CliqueVector class for managing collections of factors over cliques.

This module introduces the `CliqueVector`, a data structure designed to hold
and manipulate sets of `Factor` objects, each associated with a specific clique
(a subset of attributes) within a domain. It facilitates operations common in
graphical models, such as projecting onto sub-cliques, expanding to larger cliques,
and performing arithmetic operations on these collections.
"""

from __future__ import annotations

import functools
import operator

import attr
import chex
import jax
import jax.numpy as jnp

from .clique_utils import Clique, reverse_clique_mapping
from .domain import Domain
from .factor import Factor, Projectable


@functools.partial(
    jax.tree_util.register_dataclass,
    meta_fields=["domain", "cliques"],
    data_fields=["arrays"],
)
@attr.dataclass
class CliqueVector:
    """Manages a collection of factors, each associated with a clique.

    This class provides a structure for holding and operating on multiple
    `Factor` objects, where each factor corresponds to a clique (a subset of
    attributes) within a larger domain. It's particularly useful in the
    context of graphical models for representing sets of potentials or
    marginals.

    The class is registered as a JAX pytree: `domain` and `cliques` are
    static metadata, while `arrays` holds the data that `jax.tree.map`
    traverses.

    Attributes:
        domain (Domain): The overall domain that encompasses all cliques.
        cliques (list[Clique]): A list of cliques (tuples of attribute names)
            for which factors are stored.
        arrays (dict[Clique, Factor]): A dictionary mapping each clique in
            `cliques` to its corresponding `Factor` object.
    """

    domain: Domain
    cliques: list[Clique]
    arrays: dict[Clique, Factor]

    def __attrs_post_init__(self):
        """Validates that `cliques` and the keys of `arrays` are consistent.

        Raises:
            ValueError: If the set of cliques differs from the keys of
                `arrays`, or if `cliques` contains duplicates.
        """
        if set(self.cliques) != set(self.arrays):
            raise ValueError("Cliques must be equal to keys of array.")
        if len(self.cliques) != len(set(self.cliques)):
            raise ValueError("Cliques must be unique.")

    @classmethod
    def zeros(cls, domain: Domain, cliques: list[Clique]) -> CliqueVector:
        """Creates a CliqueVector initialized with zero factors for each clique."""
        arrays = {cl: Factor.zeros(domain.project(cl)) for cl in cliques}
        return cls(domain, cliques, arrays)

    @classmethod
    def ones(cls, domain: Domain, cliques: list[Clique]) -> CliqueVector:
        """Creates a CliqueVector initialized with one factors for each clique."""
        arrays = {cl: Factor.ones(domain.project(cl)) for cl in cliques}
        return cls(domain, cliques, arrays)

    @classmethod
    def random(cls, domain: Domain, cliques: list[Clique]) -> CliqueVector:
        """Creates a CliqueVector initialized with random factors for each clique."""
        arrays = {cl: Factor.random(domain.project(cl)) for cl in cliques}
        return cls(domain, cliques, arrays)

    @classmethod
    def abstract(cls, domain: Domain, cliques: list[Clique]) -> CliqueVector:
        """Creates a CliqueVector whose factors come from `Factor.abstract`.

        Each clique is mapped to `Factor.abstract(domain.project(clique))`.
        """
        arrays = {cl: Factor.abstract(domain.project(cl)) for cl in cliques}
        return cls(domain, cliques, arrays)

    @classmethod
    def from_projectable(
        cls, data: Projectable, cliques: list[Clique]
    ) -> CliqueVector:
        """Creates a CliqueVector by projecting a data source onto the specified cliques."""
        arrays = {cl: data.project(cl) for cl in cliques}
        return cls(data.domain, cliques, arrays)

    @functools.cached_property
    def active_domain(self):
        """Returns the merged domain encompassing all attributes across all cliques."""
        domains = [self.domain.project(cl) for cl in self.cliques]
        return functools.reduce(lambda a, b: a.merge(b), domains, Domain([], []))

    # Memoization of `parent` was left commented out in the original source:
    # @functools.lru_cache(maxsize=None)
    def parent(self, clique: Clique) -> Clique | None:
        """Finds a clique in this vector that is a superset of the given clique.

        Args:
            clique: The clique to look up.

        Returns:
            The first clique in `self.cliques` (in list order) containing
            every attribute of `clique`, or None if there is no such clique.
        """
        # Build the query set once rather than on every iteration.
        wanted = set(clique)
        for candidate in self.cliques:
            if wanted <= set(candidate):
                return candidate
        return None

    def supports(self, clique: Clique) -> bool:
        """Checks if the given clique is supported (is a subset of any clique in the vector)."""
        return self.parent(clique) is not None

    def project(self, clique: Clique, log: bool = False) -> Factor:
        """Returns the factor over `clique`, marginalizing a larger one if needed.

        Args:
            clique: The clique to project onto; it is converted to a tuple
                before lookup.
            log: Forwarded to `Factor.project` when a supporting clique has
                to be projected down. It is not consulted when `clique` is
                stored in this vector directly.

        Returns:
            The stored factor if `clique` is a key of `arrays`; otherwise the
            factor of the first supporting clique projected onto `clique`.

        Raises:
            ValueError: If no clique in this vector contains `clique`.
        """
        clique = tuple(clique)
        if clique in self.arrays:
            return self[clique]
        # A single lookup both answers "is it supported?" and yields the parent.
        parent = self.parent(clique)
        if parent is None:
            raise ValueError(f"Cannot project onto unsupported clique {clique}.")
        return self[parent].project(clique, log=log)

    def expand(self, cliques: list[Clique]) -> CliqueVector:
        """Re-expresses this CliqueVector over an expanded set of cliques.

        If the original CliqueVector represents the potentials of a Graphical
        Model, and the given cliques support the cliques in the original
        CliqueVector, then the distribution represented by the new
        CliqueVector will be identical.

        Args:
            cliques: The new cliques the clique vector will be defined over.

        Returns:
            An expanded CliqueVector defined over the given set of cliques.
        """
        mapping = reverse_clique_mapping(cliques, self.cliques, domain=self.domain)
        arrays = {}
        for cl in cliques:
            dom = self.domain.project(cl)
            sources = mapping[cl]
            if len(sources) == 0:
                # No original clique was assigned to this one.
                arrays[cl] = Factor.zeros(dom)
            else:
                arrays[cl] = sum(self[cl2] for cl2 in sources).expand(dom)
        return CliqueVector(self.domain, cliques, arrays)

    def contract(self, cliques: list[Clique], log: bool = False) -> CliqueVector:
        """Computes a new CliqueVector by projecting this one onto a smaller set of cliques."""
        arrays = {cl: self.project(cl, log=log) for cl in cliques}
        return CliqueVector(self.domain, cliques, arrays)

    def normalize(self, total: float = 1, log: bool = True) -> CliqueVector:
        """Normalizes each factor within the CliqueVector.

        Args:
            total: Passed through to `Factor.normalize` for every factor.
            log: Passed through to `Factor.normalize` for every factor.

        Returns:
            A CliqueVector holding the result of `Factor.normalize` for each
            clique.
        """
        # Stop the tree traversal at Factor objects so that the Factor
        # method (not an array-level function) is applied.
        return jax.tree.map(
            lambda f: f.normalize(total, log),
            self,
            is_leaf=lambda node: isinstance(node, Factor),
        )

    def __mul__(self, const: chex.Numeric) -> CliqueVector:
        """Multiplies each factor in the vector by a constant."""
        return jax.tree.map(lambda f: f * const, self)

    def __rmul__(self, const: chex.Numeric) -> CliqueVector:
        """Right-multiplies each factor in the vector by a constant."""
        return self.__mul__(const)

    def __truediv__(self, const: chex.Numeric) -> CliqueVector:
        """Divides each factor in the vector by a constant."""
        return self.__mul__(1 / const)

    def __add__(self, other: chex.Numeric | CliqueVector) -> CliqueVector:
        """Adds another CliqueVector or a constant to this vector elementwise."""
        if isinstance(other, CliqueVector):
            return jax.tree.map(jnp.add, self, other)
        return jax.tree.map(lambda f: f + other, self)

    def __sub__(self, other: chex.Numeric | CliqueVector) -> CliqueVector:
        """Subtracts another CliqueVector or a constant from this vector elementwise."""
        return self + -1 * other

    def exp(self) -> CliqueVector:
        """Applies elementwise exponentiation (jnp.exp) to each factor."""
        return jax.tree.map(jnp.exp, self)

    def log(self) -> CliqueVector:
        """Applies elementwise logarithm (jnp.log) to each factor."""
        return jax.tree.map(jnp.log, self)

    def dot(self, other: CliqueVector) -> chex.Numeric:
        """Computes the dot product between this CliqueVector and another.

        `Factor.dot` is evaluated for each pair of corresponding factors and
        the per-clique results are summed.
        """
        dots = jax.tree.map(
            Factor.dot,
            self,
            other,
            is_leaf=lambda node: isinstance(node, Factor),
        )
        return jax.tree.reduce(operator.add, dots, 0)

    def size(self):
        """Calculates the total number of parameters across all factors in the vector."""
        return sum(self.domain.size(cl) for cl in self.cliques)

    def __getitem__(self, clique: Clique) -> Factor:
        """Retrieves the factor associated with the given clique."""
        return self.arrays[clique]

    def __setitem__(self, clique: Clique, value: Factor):
        """Sets the factor for a given clique, replacing the existing one.

        Raises:
            ValueError: If `clique` is not one of this vector's cliques.
        """
        if clique not in self.cliques:
            raise ValueError(f"Clique {clique} not in CliqueVector.")
        self.arrays[clique] = value

    def apply_sharding(self, mesh: jax.sharding.Mesh | None) -> CliqueVector:
        """Apply sharding constraint to each factor in the CliqueVector.

        The sharding strategy is automatically determined based on the
        provided mesh and the factor domains.

        Args:
            mesh: The mesh over which the factors should be sharded.

        Returns:
            A new CliqueVector identical to self with sharding constraints
            applied to the underlying factors.
        """
        arrays = {cl: self.arrays[cl].apply_sharding(mesh) for cl in self.cliques}
        return CliqueVector(self.domain, self.cliques, arrays)