# Source code for mbi.markov_random_field

"""Defines the MarkovRandomField class representing learned graphical models.

This module provides the `MarkovRandomField` class, which encapsulates the
results of learning a graphical model. It stores the learned potentials,
the resulting marginal distributions, and the associated total count (e.g.,
number of records). It also offers methods for querying marginals and
generating synthetic data.
"""

from collections.abc import Sequence
import chex
import math
import numpy as np

from . import junction_tree, marginal_oracles
from .clique_vector import CliqueVector
from .dataset import Dataset
from .factor import Factor


@chex.dataclass(frozen=True, kw_only=False)
class MarkovRandomField:
    """Represents a learned graphical model.

    This class encapsulates the components of a Markov Random Field that has
    been learned from data. It stores the learned potentials, the resulting
    marginal distributions over specified cliques, and the total count (e.g.,
    number of records or equivalent sample size) associated with the model.

    Attributes:
        potentials (CliqueVector): A `CliqueVector` containing the learned
            potential functions for the cliques in the model.
        marginals (CliqueVector): A `CliqueVector` containing the marginal
            distributions for a set of cliques, derived from the potentials.
        total (chex.Numeric): The total count or effective sample size
            represented by the model. This is often used for scaling or
            interpreting the marginals.
    """

    potentials: CliqueVector
    marginals: CliqueVector
    total: chex.Numeric = 1

    def project(self, attrs: str | Sequence[str]) -> Factor:
        """Return the marginal Factor over ``attrs``.

        Args:
            attrs: A single attribute name or a sequence of attribute names.

        Returns:
            The stored marginal projected onto ``attrs`` when
            ``self.marginals`` supports those attributes. Otherwise, the
            result of ``marginal_oracles.variable_elimination`` over
            ``self.potentials`` (with ``self.total`` passed through).
        """
        if isinstance(attrs, str):
            attrs = (attrs,)
        attrs = tuple(attrs)
        if self.marginals.supports(attrs):
            return self.marginals.project(attrs)
        return marginal_oracles.variable_elimination(
            self.potentials, attrs, self.total
        )

    def supports(self, attrs: str | Sequence[str]) -> bool:
        """Return whether the model's domain supports ``attrs``.

        A bare string is treated as one attribute name, exactly as in
        ``project``, so it is never handed on as a sequence of characters.
        """
        if isinstance(attrs, str):
            attrs = (attrs,)
        return self.marginals.domain.supports(tuple(attrs))

    def synthetic_data(
        self, rows: int | None = None, method: str = "round"
    ) -> Dataset:
        """Generates synthetic data based on the learned model's marginals.

        Columns are generated one at a time in reverse elimination order of a
        junction tree built over the model's cliques. Each column is drawn
        conditionally on the already-generated columns it shares a clique
        with, or from its one-way marginal if there are none.

        Args:
            rows: The number of rows to generate. If not provided (``None``,
                or any falsy value such as 0), uses the model total, which is
                usually estimated automatically. At least one row is always
                generated.
            method: Specification for strategy to use to generate records.
                - "round" for randomized (stratified) rounding
                - "sample" for i.i.d sampling

        Returns:
            A synthetic dataset whose marginals should closely match those of
            the model.

        Raises:
            ValueError: If ``method`` is not "round" or "sample".
        """
        if method not in ("round", "sample"):
            raise ValueError(
                f"method must be 'round' or 'sample', got {method!r}"
            )
        total = max(1, int(rows or self.total))
        domain = self.domain
        cliques = [set(cl) for cl in self.cliques]
        jtree, elimination_order = junction_tree.make_junction_tree(
            domain, cliques
        )
        potentials = self.potentials.expand(list(jtree.nodes))
        marginals = marginal_oracles.message_passing_stable(potentials)

        data = {}
        used = set()
        for col in elimination_order[::-1]:
            # Parents: already-generated attributes sharing a model clique
            # with ``col``. The first column visited therefore has none.
            neighbors = set().union(*(cl for cl in cliques if col in cl))
            parents = tuple(used & neighbors)
            used.add(col)

            if not parents:
                marg = marginals.project((col,)).datavector(flatten=False)
                data[col] = _synthetic_col(marg, total, method)
                continue

            parent_data = np.stack([data[p] for p in parents], axis=-1)
            # float64 keeps the conditional CDFs accurate. Axis order follows
            # ``parents + (col,)``, so the last axis is the child.
            marg = np.asarray(
                marginals.project(parents + (col,)).datavector(flatten=False),
                dtype=np.float64,
            )
            marg_parents = marg.sum(axis=-1, keepdims=True)
            # Parent configurations with zero mass get an all-zero row.
            cond_probs = np.divide(
                marg,
                marg_parents,
                out=np.zeros_like(marg),
                where=marg_parents != 0,
            )
            data[col] = _conditional_column(
                cond_probs.cumsum(axis=-1),
                parent_data,
                domain[col],
                total,
                method,
            )
        return Dataset(data, domain)

    @property
    def domain(self):
        """Returns the Domain object associated with this graphical model."""
        return self.potentials.domain

    @property
    def cliques(self):
        """Returns the list of cliques the model's potentials are defined over."""
        return self.potentials.cliques


def _synthetic_col(counts, total: int, method: str) -> np.ndarray:
    """Generate ``total`` values for one column from an unnormalized histogram.

    Args:
        counts: Non-negative weights, one per column value.
        total: Number of values to generate.
        method: "sample" draws i.i.d. values. Anything else uses randomized
            rounding: every value gets the floor of its expected count, and
            the leftover rows are given to values chosen in proportion to
            their fractional parts.

    Returns:
        A 1-D array of value indices of length ``total``. In the rounding
        case it is randomly shuffled.
    """
    # Work on a float64 NumPy copy: never mutate the caller's array, and give
    # np.random.choice probabilities that pass its float64 sum tolerance.
    counts = np.array(counts, dtype=np.float64)
    options = np.arange(counts.size, dtype=np.min_scalar_type(counts.size))
    if total == 0:
        return np.array([], dtype=int)
    if method == "sample":
        return np.random.choice(options, total, True, counts / counts.sum())
    scaled = counts * (total / counts.sum())
    frac, integ = np.modf(scaled)
    integ = integ.astype(int)
    extra = total - integ.sum()
    if extra > 0:
        idx = np.random.choice(options, extra, False, frac / frac.sum())
        integ[idx] += 1
    vals = np.repeat(options, integ)
    np.random.shuffle(vals)
    return vals


def _conditional_column(cond_cdfs, parent_data, domain_size, total, method):
    """Draw one child value per row given that row's parent values.

    Args:
        cond_cdfs: Array whose leading axes are indexed by parent values and
            whose last axis holds the child's conditional CDF.
        parent_data: ``(total, num_parents)`` array of already-generated
            parent values.
        domain_size: Number of possible child values.
        total: Number of rows.
        method: "sample" for i.i.d. uniforms. Anything else uses stratified
            uniforms within each parent group (randomized rounding).

    Returns:
        A length-``total`` array of child value indices.
    """
    uniques, inverse, counts = np.unique(
        parent_data, axis=0, return_inverse=True, return_counts=True
    )
    # Some NumPy 2.0.x releases return a non-1-D inverse when ``axis`` is
    # given; reduce it to one group id per row.
    inverse = np.asarray(inverse).reshape(total, -1)[:, 0]
    # Sort rows by parent group and break ties RANDOMLY. With a stable
    # (row-index) order, "round" mode would assign child values monotonically
    # in row index. Children sharing the same parents would then be
    # spuriously correlated through row position.
    perm = np.lexsort((np.random.rand(total), inverse))
    if method == "sample":
        u = np.random.rand(total)
    else:
        # Groups are contiguous in ``perm`` in the order of ``uniques``.
        group_starts = np.zeros(len(counts), dtype=int)
        np.cumsum(counts[:-1], out=group_starts[1:])
        ranks = np.empty(total, dtype=int)
        ranks[perm] = np.arange(total) - group_starts[inverse[perm]]
        # One jittered point per stratum [r/count, (r+1)/count) in each group.
        u = (ranks + np.random.rand(total)) / counts[inverse]

    unique_cdfs = cond_cdfs[tuple(uniques.T)]
    choices = np.empty(total, dtype=np.min_scalar_type(domain_size))
    u_sorted = u[perm]
    start = 0
    for cdf, count in zip(unique_cdfs, counts):
        end = start + count
        picked = np.searchsorted(cdf, u_sorted[start:end], side="right")
        # Clip when u >= cdf[-1] (float round-off or an all-zero CDF row).
        np.minimum(picked, domain_size - 1, out=picked)
        choices[perm[start:end]] = picked
        start = end
    return choices