mbi.Factor

class mbi.Factor(domain: Domain, values: Array)

Bases: object

Represents a factor defined over a discrete domain.

A factor can be thought of as a potential function or an unnormalized probability distribution over a set of discrete variables defined by a Domain object. It maps each configuration of the domain to a value.

domain

The discrete domain over which the factor is defined.

Type:

Domain

values

A JAX array containing the factor’s values. The shape of this array matches the shape specified by the domain.

Type:

jax.Array
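The sketch below makes the relationship between domain and values concrete. It uses plain NumPy rather than mbi. A plain dict stands in for a Domain object; that stand-in is an assumption made only for illustration.

```python
import itertools
import numpy as np

# Hypothetical stand-in for a Domain: ordered attribute names mapped to sizes.
domain = {"X": 2, "Y": 3}
attrs = tuple(domain)                    # ("X", "Y")
shape = tuple(domain[a] for a in attrs)  # (2, 3)

# The values array has one axis per attribute, in domain order.
values = np.arange(6, dtype=float).reshape(shape)

# Each configuration (one value per attribute) indexes exactly one entry.
table = {
    config: values[config]
    for config in itertools.product(*(range(n) for n in shape))
}
print(table[(1, 2)])  # value stored for X=1, Y=2
```

Because the axes follow the domain's attribute order, a configuration is just a tuple of indices into `values`.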

Supported Operations:
  • Creation: zeros, ones, random for creating factors.

  • Reshaping: transpose, expand for modifying the domain/shape.

  • Aggregation: sum, logsumexp, project for marginalizing attributes.

  • Element-wise: exp, log, normalize for value transformations.

  • Binary Ops: +, -, *, /, dot for combining factors.
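The list above does not say how a binary operation combines two factors defined over different attributes. One common convention is to align axes by attribute name and broadcast both operands to the union of their domains. This sketch assumes that convention; it is not taken from the mbi source. `align` is a hypothetical helper.

```python
import numpy as np

def align(attrs, values, target_attrs):
    """Hypothetical helper: permute `values` (axes named by `attrs`) into the
    order of `target_attrs`, adding a size-1 axis for each missing attribute."""
    present = [a for a in target_attrs if a in attrs]
    # Reorder the existing axes so they follow the target order.
    values = np.transpose(values, [attrs.index(a) for a in present])
    # Insert size-1 axes so NumPy broadcasting lines attributes up by name.
    shape = [values.shape[present.index(a)] if a in present else 1
             for a in target_attrs]
    return values.reshape(shape)

# f(X, Y) and g(Y, Z) combined into h(X, Y, Z) = f(X, Y) * g(Y, Z).
f_attrs, f_vals = ("X", "Y"), np.arange(6.0).reshape(2, 3)
g_attrs, g_vals = ("Y", "Z"), np.ones((3, 4))
union = tuple(dict.fromkeys(f_attrs + g_attrs))  # ("X", "Y", "Z"), order kept
h = align(f_attrs, f_vals, union) * align(g_attrs, g_vals, union)
print(h.shape)  # (2, 3, 4)
```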

Example Usage:
>>> from mbi import Domain, Factor
>>> domain = Domain.fromdict({'X': 2, 'Y': 3})
>>> factor = Factor.ones(domain)
>>> print(factor.domain)
Domain(X: 2, Y: 3)


Methods

  • __init__: Method generated by attrs for class Factor.

  • abstract

  • apply_sharding: Apply sharding constraint to the factor values.

  • copy: Returns a copy of the factor (potentially shallow due to JAX).

  • datavector: Returns the factor's values as a flattened vector or original array.

  • dot

  • exp: Applies element-wise exponentiation (jnp.exp) to the factor's values.

  • expand: Expands the factor's domain to include new attributes.

  • log: Applies element-wise logarithm (jnp.log) to the factor's values.

  • logsumexp: Computes the log-sum-exp along specified attribute axes.

  • max: Computes the maximum value along specified attribute axes.

  • normalize: Normalizes the factor so its values sum to total (or log-normalize).

  • ones: Creates a Factor object with all values initialized to one.

  • pad

  • project: Computes the marginal distribution by summing/logsumexp'ing out other attributes.

  • random: Creates a Factor object with random values (uniform 0-1).

  • slice: Slices the factor by fixing specific attribute values.

  • sum: Computes the sum along specified attribute axes.

  • supports

  • transpose: Rearranges the factor's axes according to the new attribute order.

  • zeros: Creates a Factor object with all values initialized to zero.

Attributes

domain

values

domain: Domain
values: Array
classmethod zeros(domain: Domain) → Factor

Creates a Factor object with all values initialized to zero.

classmethod ones(domain: Domain) → Factor

Creates a Factor object with all values initialized to one.

classmethod random(domain: Domain) → Factor

Creates a Factor object with random values drawn uniformly between 0 and 1.
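If these constructors simply fill an array with the domain's shape, their effect can be sketched in NumPy. This is not the mbi implementation, which stores JAX arrays. How Factor.random obtains its randomness is not documented here, so the seeded NumPy generator below is an assumption.

```python
import numpy as np

# Hypothetical domain {X: 2, Y: 3}; the shape lists attribute sizes in domain order.
shape = (2, 3)

zeros = np.zeros(shape)                    # analogue of Factor.zeros
ones = np.ones(shape)                      # analogue of Factor.ones
rng = np.random.default_rng(seed=0)        # explicit, seeded NumPy generator
rand = rng.uniform(0.0, 1.0, size=shape)   # analogue of Factor.random: draws in [0, 1)
```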

classmethod abstract(domain: Domain) → Factor

transpose(attrs: Sequence[str]) → Factor

Rearranges the factor’s axes according to the new attribute order.
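Reordering axes by attribute name amounts to computing a permutation and passing it to a transpose. The sketch below shows this with NumPy. The helper `transpose` and its argument layout are hypothetical; they are not mbi's API.

```python
import numpy as np

def transpose(attrs, values, new_attrs):
    """Hypothetical helper: reorder the axes of `values` (named by `attrs`)
    so that they follow `new_attrs`."""
    if sorted(attrs) != sorted(new_attrs):
        raise ValueError("new_attrs must be a permutation of attrs")
    # Axis i of the result is the old axis holding attribute new_attrs[i].
    return np.transpose(values, [attrs.index(a) for a in new_attrs])

values = np.arange(6).reshape(2, 3)                  # axes (X, Y)
swapped = transpose(("X", "Y"), values, ("Y", "X"))  # axes (Y, X)
print(swapped.shape)  # (3, 2)
```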

expand(domain)

Expands the factor’s domain to include new attributes.
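This sketch assumes that expanding a factor repeats (broadcasts) its existing values along every newly added attribute, which is the usual meaning; the description above does not spell it out. `expand` here is a hypothetical NumPy helper, and a dict stands in for the target Domain.

```python
import numpy as np

def expand(attrs, values, new_domain):
    """Hypothetical helper: broadcast a factor over `attrs` to `new_domain`,
    a dict mapping every attribute (old and new) to its size."""
    new_attrs = tuple(new_domain)
    if not set(attrs) <= set(new_attrs):
        raise ValueError("the new domain must contain every existing attribute")
    # Put existing axes in the new domain's order, then add size-1 axes.
    present = [a for a in new_attrs if a in attrs]
    values = np.transpose(values, [attrs.index(a) for a in present])
    shape = [new_domain[a] if a in attrs else 1 for a in new_attrs]
    # Repeat the values along each new axis (a read-only broadcast view).
    return np.broadcast_to(values.reshape(shape), tuple(new_domain.values()))

f = np.array([0.2, 0.8])                 # factor over X
g = expand(("X",), f, {"X": 2, "Y": 3})  # factor over (X, Y)
print(g.shape)  # (2, 3)
```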

max(attrs: Sequence[str] | None = None) → Factor

Computes the maximum value along specified attribute axes.

sum(attrs: Sequence[str] | None = None) → Factor

Computes the sum along specified attribute axes.

logsumexp(attrs: Sequence[str] | None = None) → Factor

Computes the log-sum-exp along specified attribute axes.
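max, sum and logsumexp share a pattern. Each maps attribute names to axis indices, reduces over those axes, and keeps the remaining attributes. The NumPy sketch below reduces over explicitly named attributes only and does not model what happens when attrs is None. The helpers `reduce_over` and `logsumexp` are hypothetical.

```python
import numpy as np

def reduce_over(attrs, values, over, op):
    """Hypothetical helper: apply the reduction `op` over the named attributes
    `over`; returns the remaining attribute names and the reduced array."""
    axes = tuple(attrs.index(a) for a in over)
    kept = tuple(a for a in attrs if a not in over)
    return kept, op(values, axis=axes)

def logsumexp(values, axis):
    # Numerically stable log(sum(exp(values))): shift by the max before exp.
    m = np.max(values, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(values - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

attrs = ("X", "Y")
values = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
kept, s = reduce_over(attrs, values, ["Y"], np.sum)     # sum out Y
_, mx = reduce_over(attrs, values, ["X"], np.max)       # max over X
_, lse = reduce_over(attrs, values, ["Y"], logsumexp)   # log-sum-exp over Y
print(kept, s)  # ('X',) and the row sums 6 and 15
```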

project(attrs: str | Sequence[str], log: bool = False) → Factor

Computes the marginal over attrs by summing out all other attributes, or by applying logsumexp over them when log is True.
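A projection reduces over the complement of the requested attributes: ordinary sums for values in linear space, log-sum-exp for values in log space. This NumPy sketch also returns the axes in the order the caller asked for, which is an assumption about the real method. `project` here is a hypothetical helper.

```python
import numpy as np

def project(attrs, values, keep, log=False):
    """Hypothetical helper: marginal over `keep`, summing out every other
    attribute (or combining them with log-sum-exp when log=True)."""
    if isinstance(keep, str):
        keep = (keep,)
    drop = tuple(i for i, a in enumerate(attrs) if a not in keep)
    if log:
        # Stable log-sum-exp over the dropped axes.
        m = np.max(values, axis=drop, keepdims=True)
        out = m + np.log(np.sum(np.exp(values - m), axis=drop, keepdims=True))
        out = np.squeeze(out, axis=drop)
    else:
        out = np.sum(values, axis=drop)
    # Reorder the surviving axes to match the order requested in `keep`.
    remaining = [a for a in attrs if a in keep]
    return np.transpose(out, [remaining.index(a) for a in keep])

attrs = ("X", "Y", "Z")
joint = np.full((2, 3, 4), 1.0 / 24)    # uniform joint distribution
p_zx = project(attrs, joint, ("Z", "X"))
print(p_zx.shape)  # (4, 2)
```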

slice(evidence: dict[str, int | ndarray | Array]) → Factor

Slices the factor by fixing specific attribute values.

If at least one attribute has numpy-valued evidence, the returned factor will have a new leading dimension called ‘_mbi_evidence’ corresponding to the number of evidence points.

Parameters:

evidence – A dictionary mapping attribute names to the values they should be fixed to.

Returns:

A new Factor with the specified attributes fixed and removed from the domain.
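Slicing indexes the fixed axes and keeps the rest. With array-valued evidence, one slice is taken per evidence point, which gives the leading dimension described above. In this NumPy sketch that leading axis is just an unnamed array dimension; the real method names it ‘_mbi_evidence’ in the returned domain. `slice_factor` is a hypothetical helper.

```python
import numpy as np

def slice_factor(attrs, values, evidence):
    """Hypothetical helper: fix the attributes in `evidence` and drop their axes.
    Array-valued evidence yields a leading axis with one entry per evidence point."""
    fixed = [a for a in attrs if a in evidence]
    kept = tuple(a for a in attrs if a not in evidence)
    # Move the fixed axes to the front so any evidence axis ends up leading.
    order = [attrs.index(a) for a in fixed] + [attrs.index(a) for a in kept]
    values = np.transpose(values, order)
    index = tuple(np.asarray(evidence[a]) for a in fixed)
    return kept, values[index]

attrs = ("X", "Y")
values = np.arange(6.0).reshape(2, 3)
kept1, row = slice_factor(attrs, values, {"X": 1})           # factor over Y
kept2, batch = slice_factor(attrs, values, {"Y": np.array([0, 2, 2, 1])})
print(kept2, batch.shape)  # ('X',) (4, 2)
```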

supports(attrs: str | Sequence[str]) → bool

exp(out=None) → Factor

Applies element-wise exponentiation (jnp.exp) to the factor’s values.

log(out=None) → Factor

Applies element-wise logarithm (jnp.log) to the factor’s values.

normalize(total: float = 1.0, log: bool = False) → Factor

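Normalizes the factor so that its values sum to total. If log is True, the values are treated as log-values and normalized in log space, so that their exponentials sum to total.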
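Both modes can be written in one line each. In linear space, scale by total divided by the current sum. In log space, add log(total) and subtract the log-sum-exp of the values. This NumPy sketch illustrates that arithmetic; `normalize` here is a hypothetical helper, not the mbi method itself.

```python
import numpy as np

def normalize(values, total=1.0, log=False):
    """Hypothetical helper: rescale so entries sum to `total`; with log=True the
    entries are log-values and exp(result) sums to `total`."""
    if log:
        m = np.max(values)
        lse = m + np.log(np.sum(np.exp(values - m)))  # stable log-sum-exp
        return values + np.log(total) - lse
    return values * total / np.sum(values)

v = np.array([1.0, 3.0])
print(normalize(v))  # [0.25 0.75]
print(np.exp(normalize(np.log(v), total=2.0, log=True)).sum())  # close to 2.0
```

The log-space branch never leaves log space, so very small unnormalized probabilities do not underflow.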

copy() → Factor

Returns a copy of the factor. The copy may be shallow and share its values array with the original; because JAX arrays are immutable, this sharing is safe.

dot(other: Factor) → Array | ndarray | bool | number | float | int

datavector(flatten: bool = True) → Array

Returns the factor’s values as a flattened vector when flatten is True, and as an array with the domain’s shape otherwise.
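The sketch below shows how a configuration maps to a position in the flattened vector. It assumes row-major (C) order, the default for both NumPy and jax.numpy reshape; the description above does not state the order. The array is a hypothetical factor over (X: 2, Y: 3).

```python
import numpy as np

values = np.arange(6.0).reshape(2, 3)  # hypothetical factor over (X: 2, Y: 3)

vec = values.reshape(-1)   # flatten=True analogue: one entry per configuration
full = values              # flatten=False analogue: keeps the (2, 3) shape

# In row-major order, configuration (x, y) lands at position x * 3 + y.
x, y = 1, 2
assert vec[x * 3 + y] == values[x, y]
```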

pad(mesh: Mesh | None, pad_value: Literal[0, '-inf']) → Factor

apply_sharding(mesh: Mesh | None) → Factor

Apply sharding constraint to the factor values.

The sharding strategy is determined automatically from the provided mesh and the factor’s domain.

Parameters:

mesh – The mesh over which the factor should be sharded.

Returns:

A new factor identical to self with sharding constraints applied to the values.