from __future__ import annotations
import functools
from collections.abc import Callable, Sequence
from typing import Literal, Protocol
import attr
import chex
import jax
import jax.numpy as jnp
import numpy as np
from .domain import Domain
# Enable 64-bit (x64) dtypes in JAX. NOTE: this is a process-wide side effect
# of importing this module, not something scoped to Factor. Without it JAX
# defaults to 32-bit dtypes (e.g. for the jnp.zeros / jnp.ones calls used by
# the Factor constructors).
jax.config.update("jax_enable_x64", True)
@functools.partial(
    jax.tree_util.register_dataclass, meta_fields=["domain"], data_fields=["values"]
)
@attr.dataclass(frozen=True)
class Factor:
    """Represents a factor defined over a discrete domain.

    A factor can be thought of as a potential function or an unnormalized
    probability distribution over a set of discrete variables defined by a
    `Domain` object. It maps each configuration of the domain to a value.

    The class is registered as a JAX pytree with `domain` as static metadata
    and `values` as the data leaf.

    Attributes:
        domain (Domain): The discrete domain over which the factor is defined.
        values (jax.Array): A JAX array containing the factor's values. The
            shape of this array matches the shape specified by the `domain`
            (except after `pad`, which may make it larger).

    Supported Operations:
        - Creation: `zeros`, `ones`, `random`, `abstract` for creating factors.
        - Reshaping: `transpose`, `expand` for modifying the domain/shape.
        - Aggregation: `sum`, `max`, `logsumexp`, `project` for marginalizing
          attributes; `slice` for fixing attributes to observed values.
        - Element-wise: `exp`, `log`, `normalize` for value transformations.
        - Binary Ops: `+`, `-`, `*`, `/` (scalars allowed on either side) and
          `dot` for combining factors.

    Example Usage:
        >>> from mbi import Domain  # Needed for doctest context
        >>> domain = Domain.fromdict({'X': 2, 'Y': 3})
        >>> factor = Factor.ones(domain)
        >>> print(factor.domain)
        Domain(X: 2, Y: 3)
    """

    domain: Domain
    values: jax.Array

    def __post_init__(self):
        # NOTE(review): attrs invokes `__attrs_post_init__`, not the stdlib
        # dataclass hook `__post_init__`, so this shape check does not appear
        # to run automatically on construction. It is deliberately NOT renamed:
        # `pad` below intentionally builds factors whose `values` are larger
        # than `domain.shape`, and JAX pytree utilities may rebuild a Factor
        # with non-array leaves; both would trip this check if it were active.
        if self.values.shape != self.domain.shape:
            raise ValueError("values must be same shape as domain.")

    # Constructors

    @classmethod
    def zeros(cls, domain: Domain) -> Factor:
        """Creates a Factor over `domain` with all values initialized to zero."""
        return cls(domain, jnp.zeros(domain.shape))

    @classmethod
    def ones(cls, domain: Domain) -> Factor:
        """Creates a Factor over `domain` with all values initialized to one."""
        return cls(domain, jnp.ones(domain.shape))

    @classmethod
    def random(cls, domain: Domain) -> Factor:
        """Creates a Factor with independent Uniform[0, 1) values.

        Values are drawn from NumPy's global random state (`np.random.rand`),
        not from a JAX PRNG key, so results depend on `np.random.seed`.
        """
        return cls(domain, jnp.asarray(np.random.rand(*domain.shape)))

    @classmethod
    def abstract(cls, domain: Domain) -> Factor:
        """Creates a shape-only Factor whose `values` is a `jax.ShapeDtypeStruct`.

        No array is allocated; `values` only records `domain.shape` and a
        float64 dtype. Presumably intended for shape-level computations
        (e.g. `jax.eval_shape`) — confirm against callers.
        """
        return cls(domain, jax.ShapeDtypeStruct(domain.shape, jnp.float64))

    # Reshaping operations

    def transpose(self, attrs: Sequence[str]) -> Factor:
        """Rearranges the factor's axes according to a new attribute order.

        Args:
            attrs: The same attributes as this factor's domain, in the desired
                order.

        Returns:
            A Factor holding the same values, with its domain and axes ordered
            as `attrs`.

        Raises:
            ValueError: if `attrs` is not the same set of attributes as the
                domain.
        """
        if set(attrs) != set(self.domain.attrs):
            raise ValueError("attrs must be same as domain attributes")
        newdom = self.domain.project(attrs)
        # ax[i] is where the i-th current attribute lands in the new order.
        ax = newdom.axes(self.domain.attrs)
        values = jnp.moveaxis(self.values, range(len(ax)), ax)
        return Factor(newdom, values)

    def expand(self, domain: Domain) -> Factor:
        """Broadcasts the factor onto a larger domain.

        The result is constant along every attribute of `domain` that is not
        part of `self.domain`.

        Args:
            domain: Target domain; must contain this factor's domain.

        Returns:
            A Factor over `domain`, with axes in `domain`'s attribute order.

        Raises:
            ValueError: if `domain` does not contain `self.domain`.
        """
        if not domain.contains(self.domain):
            raise ValueError("Expanded domain must contain domain.")
        dims = len(domain) - len(self.domain)
        # Append singleton axes for the new attributes, then move the existing
        # axes to the positions their attributes occupy in the target domain.
        values = self.values.reshape(self.domain.shape + tuple([1] * dims))
        ax = domain.axes(self.domain.attrs)
        values = jnp.moveaxis(values, range(len(ax)), ax)
        values = jnp.broadcast_to(values, domain.shape)
        return Factor(domain, values)

    # Functions that aggregate along some subset of axes

    def _aggregate(self, fn: Callable, attrs: Sequence[str] | None = None) -> Factor:
        """Reduces `values` with `fn` over the axes of `attrs` and drops them.

        Args:
            fn: A reduction accepting an `axis=` keyword (e.g. `jnp.sum`).
            attrs: Attributes to reduce over; None means all attributes.

        Returns:
            A Factor over the remaining attributes.
        """
        attrs = self.domain.attrs if attrs is None else attrs
        axes = self.domain.axes(attrs)
        values = fn(self.values, axis=axes)
        newdom = self.domain.marginalize(attrs)
        return Factor(newdom, values)

    def max(self, attrs: Sequence[str] | None = None) -> Factor:
        """Computes the maximum along the given attributes (all if None)."""
        return self._aggregate(jnp.max, attrs)

    def sum(self, attrs: Sequence[str] | None = None) -> Factor:
        """Computes the sum along the given attributes (all if None)."""
        return self._aggregate(jnp.sum, attrs)

    def logsumexp(self, attrs: Sequence[str] | None = None) -> Factor:
        """Computes the log-sum-exp along the given attributes (all if None)."""
        return self._aggregate(jax.scipy.special.logsumexp, attrs)

    def project(self, attrs: str | Sequence[str], log: bool = False) -> Factor:
        """Computes the marginal over `attrs` by reducing out all other attributes.

        Args:
            attrs: Attribute name, or sequence of names, to keep.
            log: If True, values are treated as log-potentials and the other
                attributes are removed with logsumexp; otherwise they are summed.

        Returns:
            A Factor over `attrs`, with axes in the order given by `attrs`.
        """
        if isinstance(attrs, str):
            attrs = (attrs,)
        marginalized = self.domain.marginalize(attrs).attrs
        result = self.logsumexp(marginalized) if log else self.sum(marginalized)
        return result.transpose(attrs)

    def slice(self, evidence: dict[str, int | np.ndarray | jax.Array]) -> Factor:
        """Slices the factor by fixing specific attribute values.

        If at least one attribute has array-valued (ndim > 0) evidence, the
        returned factor will have a new leading dimension called
        '_mbi_evidence' corresponding to the number of evidence points.

        Args:
            evidence: A dictionary mapping attribute names to the values they
                should be fixed to. Keys that are not attributes of this
                factor's domain are ignored.

        Returns:
            A new Factor with the specified attributes fixed and removed from
            the domain.

        Raises:
            ValueError: if array-valued evidence entries differ in length.
        """
        slices = [slice(None)] * len(self.domain)
        relevant = [e for e in evidence if e in self.domain.attrs]
        adv_indices = []
        has_vector = False
        ev_size = None
        # Loop variable is `name` (not `attr`) to avoid shadowing the
        # module-level `attr` import.
        for name in relevant:
            axis = self.domain.axes((name,))[0]
            val = evidence[name]
            slices[axis] = val
            adv_indices.append(axis)
            if hasattr(val, "ndim") and val.ndim > 0:
                has_vector = True
                if ev_size is None:
                    ev_size = val.shape[0]
                elif ev_size != val.shape[0]:
                    raise ValueError("All vector evidence must have same size.")
        values = self.values[tuple(slices)]
        domain = self.domain.marginalize(relevant)
        if has_vector:
            # Once any index is an array, NumPy/JAX also treat the integer
            # indices as advanced indices. The broadcast index dimension is
            # placed where the first advanced index was if all advanced
            # indices are adjacent, and at axis 0 otherwise.
            adv_indices.sort()
            is_contiguous = (adv_indices[-1] - adv_indices[0] + 1) == len(adv_indices)
            target_axis = adv_indices[0] if is_contiguous else 0
            # We want the evidence dimension to be at axis 0.
            if target_axis != 0:
                values = jnp.moveaxis(values, target_axis, 0)
            new_labels = None
            if self.domain.labels is not None:
                new_labels = (tuple(range(ev_size)),)
            new = Domain(["_mbi_evidence"], [ev_size], labels=new_labels)
            domain = new.merge(domain)
        return Factor(domain, values)

    def supports(self, attrs: str | Sequence[str]) -> bool:
        """Delegates to `Domain.supports` for this factor's domain."""
        return self.domain.supports(attrs)

    # Functions that operate element-wise

    def exp(self, out=None) -> Factor:
        """Applies element-wise exponentiation (jnp.exp) to the factor's values.

        Args:
            out: Unused; retained so existing call sites that pass it still work.
        """
        return Factor(self.domain, jnp.exp(self.values))

    def log(self, out=None) -> Factor:
        """Applies element-wise natural logarithm (jnp.log) to the factor's values.

        Args:
            out: Unused; retained so existing call sites that pass it still work.
        """
        return Factor(self.domain, jnp.log(self.values))

    def normalize(self, total: float = 1.0, log: bool = False) -> Factor:
        """Rescales the factor so that its values sum to `total`.

        Args:
            total: Target total mass.
            log: If True, values are treated as log-potentials and shifted so
                that logsumexp of the result equals log(total).

        Returns:
            The normalized Factor.
        """
        if log:
            return self + jnp.log(total) - self.logsumexp()
        return self * total / self.sum()

    def copy(self) -> Factor:
        """Returns `self`.

        No copy is made: the dataclass is frozen and JAX arrays are immutable.
        """
        return self

    def __float__(self):
        """Converts a factor over the empty domain to a Python float.

        Raises:
            ValueError: if the domain is not empty.
        """
        if len(self.domain) > 0:
            raise ValueError("Domain must be empty to convert to float.")
        return float(self.values)

    # Binary operations between two factors

    def _binaryop(self, fn: Callable, other: Factor | chex.Numeric) -> Factor:
        """Applies `fn(self_values, other_values)` after aligning the two domains.

        Scalar (0-d) operands are treated as factors over the empty domain.
        Both operands are broadcast onto the merged domain (self's attributes
        first) before `fn` is applied.

        Raises:
            TypeError: if `other` is neither a Factor nor a scalar.
        """
        if not isinstance(other, Factor):
            ndim = jnp.ndim(other)
            if ndim != 0:
                # Previously this fell through to an opaque AttributeError on
                # `other.domain`; fail with an explicit message instead.
                raise TypeError(
                    "Factor binary operations require a Factor or scalar operand, "
                    f"got {type(other).__name__} with ndim={ndim}."
                )
            other = Factor(Domain([], []), jnp.asarray(other))
        newdom = self.domain.merge(other.domain)
        factor1 = self.expand(newdom)
        factor2 = other.expand(newdom)
        return Factor(newdom, fn(factor1.values, factor2.values))

    def __sub__(self, other: Factor | chex.Numeric) -> Factor:
        return self._binaryop(jnp.subtract, other)

    def __truediv__(self, other: Factor | chex.Numeric) -> Factor:
        return self._binaryop(jnp.divide, other)

    def __mul__(self, other: Factor | chex.Numeric) -> Factor:
        """Multiply two factors together.

        Example Usage:
            >>> f1 = Factor.ones(Domain(['a','b'], [2,3]))
            >>> f2 = Factor.ones(Domain(['b','c'], [3,4]))
            >>> f3 = f1 * f2
            >>> print(f3.domain)
            Domain(a: 2, b: 3, c: 4)

        Args:
            other: the other factor (or scalar) to multiply

        Returns:
            the product of the two factors
        """
        return self._binaryop(jnp.multiply, other)

    def __add__(self, other: Factor | chex.Numeric) -> Factor:
        return self._binaryop(jnp.add, other)

    def __radd__(self, other: chex.Numeric) -> Factor:
        return self + other

    def __rsub__(self, other: chex.Numeric) -> Factor:
        """Computes `other - self`.

        The previous implementation returned `self + (-1 * other)`, which is
        `self - other` with the operands reversed.
        """
        return self._binaryop(lambda x, y: jnp.subtract(y, x), other)

    def __rmul__(self, other: chex.Numeric) -> Factor:
        return self * other

    def __rtruediv__(self, other: chex.Numeric) -> Factor:
        """Computes `other / self` element-wise (e.g. `1 / factor`)."""
        return self._binaryop(lambda x, y: jnp.divide(y, x), other)

    def dot(self, other: Factor) -> chex.Numeric:
        """Returns the inner product `sum(self.values * other.values)`.

        Entries where either factor is exactly zero are excluded from the sum
        (via `where=`), so e.g. a 0 * inf pair contributes 0 rather than NaN.

        Raises:
            ValueError: if the two factors' domains differ.
        """
        if self.domain != other.domain:
            raise ValueError(f"Domains do not match {self.domain} != {other.domain}")
        return jnp.sum(
            self.values * other.values,
            where=(self.values != 0) & (other.values != 0),
        )

    def datavector(self, flatten: bool = True) -> jax.Array:
        """Returns the factor's values, flattened to 1-D unless `flatten=False`."""
        return self.values.flatten() if flatten else self.values

    def pad(
        self, mesh: jax.sharding.Mesh | None, pad_value: Literal[0, "-inf"]
    ) -> Factor:
        """Pads values so each mesh-sharded axis divides evenly across shards.

        For every attribute whose name is also a mesh axis name, appends
        `-size % num_shards` entries at the end of that axis, so its length
        becomes a multiple of that mesh axis size. The domain is NOT changed,
        so the returned factor's `values` may be larger than `domain.shape`.

        Args:
            mesh: Device mesh, or None to return `self` unchanged.
            pad_value: Fill for padded entries: 0, or "-inf" for -jnp.inf.

        Returns:
            A Factor with the same domain and (possibly) padded values.
        """
        if mesh is None:
            return self
        pad_amounts = [0] * len(self.domain)
        for i, ax in enumerate(self.domain):
            if ax in mesh.axis_names:
                size = self.domain[ax]
                num_shards = mesh.axis_sizes[mesh.axis_names.index(ax)]
                pad_amounts[i] = -size % num_shards
        values = jnp.pad(
            self.values,
            pad_width=tuple((0, w) for w in pad_amounts),
            constant_values=0.0 if pad_value == 0 else -jnp.inf,
        )
        # We keep the domain as-is here, even though values is now larger.
        # We have a couple of options
        # 1. Explicitly unpad when we are done.
        # 2. Never unpad, just expand the domain, and make sure new elements
        #    are impossible.
        # 3. Allow values to be an array where each dim is >= the domain
        #    implies, and truncate when necessary.
        return Factor(self.domain, values)

    def apply_sharding(self, mesh: jax.sharding.Mesh | None) -> Factor:
        """Apply sharding constraint to the factor values.

        Each factor axis whose attribute name matches a mesh axis name is
        sharded along that mesh axis; all other axes are replicated.

        Args:
            mesh: The mesh over which the factor should be sharded, or None to
                return `self` unchanged.

        Returns:
            A new factor identical to self with sharding constraints applied to
            the values.
        """
        if mesh is None:
            return self
        pspec = [None] * len(self.domain)
        for i, ax in enumerate(self.domain):
            if ax in mesh.axis_names:
                pspec[i] = ax
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*pspec))
        return Factor(
            domain=self.domain,
            values=jax.lax.with_sharding_constraint(self.values, sharding),
        )
class Projectable(Protocol):
    """A projectable is an object that can be projected onto a subset of attributes to compute a marginal.

    This is a structural (duck-typed) interface: any object with a `domain`
    property and `project` / `supports` methods matching these signatures
    satisfies it.

    Example projectables:
        * Dataset
        * Factor
        * CliqueVector
        * MarkovRandomField
    """

    @property
    def domain(self) -> Domain:
        """Returns the domain over which this projectable is defined."""

    def project(self, attrs: str | Sequence[str]) -> Factor:
        """Projection onto a subset of attributes.

        Args:
            attrs: Attribute name, or sequence of names, to project onto.

        Returns:
            A Factor holding the marginal over `attrs`.
        """

    def supports(self, attrs: str | Sequence[str]) -> bool:
        """Returns true if the given attributes can be projected onto."""