Source code for mbi.domain

"""Defines the Domain class for representing attribute domains.

This module provides the `Domain` class, which encapsulates a set of named
attributes and their corresponding discrete sizes (shapes). It facilitates
representing the structure of datasets and graphical models and supports
various operations like projection, marginalization, and merging of domains.
"""

import functools
from collections.abc import Iterator, Sequence
from typing import Any

import attr


[docs] @attr.dataclass(frozen=True) class Domain: """Represents a discrete domain defined by attributes and their sizes. This class encapsulates a set of named attributes and their corresponding discrete sizes (shapes). It provides methods for common domain operations. Attributes: attributes (tuple[str, ...]): A tuple containing the names of the attributes in the domain. shape (tuple[int, ...]): A tuple containing the integer sizes (number of discrete values) for each corresponding attribute in the `attributes` tuple. labels (tuple[tuple[Any, ...], ...] | None): An optional tuple of tuples containing semantic information (labels) for each attribute's values. Must be the same length as attributes, and each inner tuple must have length corresponding to the attribute's size. Supported Operations: - Projection (`project`): Creates a new domain with a subset of attributes. - Marginalization (`marginalize`): Creates a new domain excluding specified attributes. - Intersection (`intersect`): Creates a new domain containing only common attributes. - Merging (`merge`): Combines two domains into a larger one. - Size Calculation (`size`): Computes the total number of configurations in the domain or a subset. Example Usage (using fromdict): >>> domain = Domain.fromdict({'a': 2, 'b': 3}) >>> print(domain) Domain(a: 2, b: 3) """ attributes: tuple[str, ...] = attr.field(converter=tuple) shape: tuple[int, ...] = attr.field(converter=lambda sh: tuple(int(n) for n in sh)) labels: tuple[tuple[Any, ...], ...] | None = attr.field( default=None, converter=lambda l: tuple(tuple(x) for x in l) if l is not None else None, ) def __attrs_post_init__(self): if len(self.attributes) != len(self.shape): raise ValueError("Dimensions must be equal.") if len(self.attributes) != len(set(self.attributes)): raise ValueError("Attributes must be unique.") if self.labels is not None: if len(self.labels) != len(self.attributes): raise ValueError("Labels must be same length as attributes.") for i, l in enumerate(self.labels): if len(l) != self.shape[i]: raise ValueError( f"Labels for {self.attributes[i]} must have length {self.shape[i]}." ) @functools.cached_property def config(self) -> dict[str, int]: """Returns a dictionary of { attr : size } values.""" return dict(zip(self.attributes, self.shape)) @functools.cached_property def labels_config(self) -> dict[str, tuple[Any, ...]] | None: """Returns a dictionary of { attr : labels } values.""" if self.labels is None: return None return dict(zip(self.attributes, self.labels))
[docs] @staticmethod def fromdict(config: dict[str, int]) -> "Domain": """Construct a Domain object from a dictionary of { attr : size } values. Example Usage: >>> print(Domain.fromdict({'a': 10, 'b': 20})) Domain(a: 10, b: 20) Args: config: a dictionary of { attr : size } values Returns: the Domain object """ return Domain(config.keys(), config.values())
[docs] def project(self, attributes: str | Sequence[str]) -> "Domain": """Project the domain onto a subset of attributes. Args: attributes: the attributes to project onto Returns: the projected Domain object """ if isinstance(attributes, str): attributes = [attributes] if not set(attributes) <= set(self.attributes): raise ValueError(f"Cannot project {self} onto {attributes}.") shape = tuple(self.config[a] for a in attributes) labels = None if self.labels is not None: labels = tuple(self.labels_config[a] for a in attributes) return Domain(attributes, shape, labels=labels)
[docs] def marginalize(self, attrs: Sequence[str]) -> "Domain": """Marginalize out some attributes from the domain (opposite of project). Example Usage: >>> D1 = Domain(['a','b'], [10,20]) >>> print(D1.marginalize(['a'])) Domain(b: 20) Args: attrs: the attributes to marginalize out. Returns: the marginalized Domain object """ proj = [a for a in self.attributes if a not in attrs] return self.project(proj)
[docs] def contains(self, other: "Domain") -> bool: """Checks if this domain contains all attributes present in another domain.""" return set(other.attributes) <= set(self.attributes)
[docs] def canonical(self, attrs): """Returns attributes common to the domain and input, maintaining the domain's order.""" return tuple(a for a in self.attributes if a in attrs)
[docs] def invert(self, attrs): """Returns attributes present in the domain but not in the provided list.""" return [a for a in self.attributes if a not in attrs]
[docs] def intersect(self, other: "Domain") -> "Domain": """Intersect this Domain object with another. Example Usage: >>> D1 = Domain(['a','b'], [10,20]) >>> D2 = Domain(['b','c'], [20,30]) >>> print(D1.intersect(D2)) Domain(b: 20) Args: other: another Domain object Returns: the intersection of the two domains """ return self.project([a for a in self.attributes if a in other.attributes])
[docs] def axes(self, attrs: Sequence[str]) -> tuple[int, ...]: """Return the axes tuple for the given attributes. Args: attrs: the attributes Returns: a tuple with the corresponding axes """ return tuple(self.attributes.index(a) for a in attrs)
[docs] def merge(self, other: "Domain") -> "Domain": """Merge this Domain object with another. :param other: another Domain object :return: a new domain object covering the full domain Example: >>> D1 = Domain(['a','b'], [10,20]) >>> D2 = Domain(['b','c'], [20,30]) >>> print(D1.merge(D2)) Domain(a: 10, b: 20, c: 30) Args: other: another Domain object Returns: a new domain object covering the combined domain. """ extra = other.marginalize(self.attributes) new_labels = None if self.labels is not None and other.labels is not None: new_labels = self.labels + extra.labels return Domain( self.attributes + extra.attributes, self.shape + extra.shape, labels=new_labels, )
[docs] def size(self, attributes: Sequence[str] | None = None) -> int: """Return the total size of the domain. Example: >>> D1 = Domain(['a','b'], [10,20]) >>> D1.size() 200 >>> D1.size(['a']) 10 Args: attributes: A subset of attributes whose total size should be returned. Returns: the total size of the domain """ if attributes is None: return functools.reduce(lambda x, y: x * y, self.shape, 1) return self.project(attributes).size()
@property def attrs(self): """Alias for the `attributes` tuple.""" return self.attributes
[docs] def supports(self, attrs: str | Sequence[str]) -> bool: if isinstance(attrs, str): attrs = [attrs] return set(attrs) <= set(self.attributes)
def __contains__(self, name: str) -> bool: """Check if the given attribute is in the domain.""" return name in self.attributes def __getitem__(self, a: str) -> int: return self.config[a] def __iter__(self) -> Iterator[str]: return self.attributes.__iter__() def __len__(self) -> int: return len(self.attributes) def __str__(self) -> str: inner = ", ".join(["%s: %d" % x for x in zip(self.attributes, self.shape)]) return "Domain(%s)" % inner