"""Provides the Dataset class for representing and manipulating tabular data.
This module defines the `Dataset` class, which serves as a wrapper around a
numpy array, associating it with a `Domain` object. It allows for
structured representation of data, facilitating operations like projection onto
subsets of attributes and conversion into a data vector format suitable for
various statistical and machine learning tasks.
"""
from __future__ import annotations
import csv
import functools
import json
from collections.abc import Sequence
from typing import Any
import attr
import jax
import jax.numpy as jnp
import math
import numpy as np
from numpy.typing import ArrayLike, NDArray
from .domain import Domain
from .factor import Factor
import warnings
def _validate_column(data: np.ndarray, size: int):
if data.ndim != 1:
raise ValueError(f"Expected column data to be 1D, found shape {data.shape}")
if not np.issubdtype(data.dtype, np.integer):
raise ValueError(f"Expected integer data, got {data.dtype}")
if not np.all((data >= 0) & (data < size)):
raise ValueError(f"Expected data in range [0, {size})")
def _validate_data(data: dict[str, np.ndarray], domain: Domain):
if set(data.keys()) != set(domain.attrs):
raise ValueError("Keys in data dictionary must match domain attributes")
n = None
for col in data:
_validate_column(data[col], domain[col])
if n is None:
n = data[col].shape[0]
if n != data[col].shape[0]:
raise ValueError("Expected data to have same size for each record.")
def _validate_mapping(map_array: np.ndarray, attr: str):
if map_array.ndim != 1:
raise ValueError(f"Mapping for {attr} must be 1D array")
if not np.issubdtype(map_array.dtype, np.integer):
raise ValueError(f"Mapping for {attr} must be integers")
if np.any(map_array < 0):
raise ValueError(f"Mapping for {attr} must be non-negative")
[docs]
class Dataset:
def __init__(
self,
data: ArrayLike | dict[str, ArrayLike],
domain: Domain,
weights: np.ndarray | None = None,
):
"""create a Dataset object
:param data: a numpy array (n x d) or a dictionary of 1d arrays (length n), keyed by attribute.
:param domain: a domain object
:param weight: weight for each row
"""
if isinstance(data, np.ndarray):
if data.shape[1] != len(domain.attrs):
raise ValueError("Shape of data does not match shape of domain")
n = data.shape[0]
data = {attr: data[:, i] for i, attr in enumerate(domain.attrs)}
elif isinstance(data, dict):
if len(data) > 0:
n = list(data.values())[0].shape[0]
else:
n = None
elif hasattr(data, "values"): # Pandas DataFrame
warnings.warn(
"Pandas dataframe inputs are deprecated, please pass in a dictionary of numpy arrays instead."
)
n = data.shape[0]
data = {attr: data[attr].values for attr in domain.attrs}
else:
raise ValueError(f"Unrecognized data type {type(data)}")
_validate_data(data, domain)
if n == None:
if weights is None:
raise ValueError(
"Weights must be provided if data is empty (cannot infer N)"
)
n = weights.size
if weights is None:
weights = np.ones(n)
assert n == weights.size
self.domain = domain
self._data = data
self.weights = weights
self._n = n
[docs]
def to_dict(self) -> dict[str, np.ndarray]:
return self._data
@property
def df(self):
import pandas
return pandas.DataFrame(self._data)
[docs]
@staticmethod
def synthetic(domain: Domain, N: int) -> Dataset:
"""Generate synthetic data conforming to the given domain
:param domain: The domain object
:param N: the number of individuals
"""
arr = [np.random.randint(low=0, high=n, size=N) for n in domain.shape]
values = np.array(arr).T
return Dataset(values, domain)
[docs]
@staticmethod
def load(path: str, domain: str | Domain) -> Dataset:
"""Load data into a dataset object
:param path: path to csv file
:param domain: path to json file encoding the domain information
"""
if isinstance(domain, str):
with open(domain, "r", encoding="utf-8") as f:
config = json.load(f)
domain_obj = Domain(config.keys(), config.values())
else:
domain_obj = domain
with open(path, "r", encoding="utf-8") as f:
reader = csv.reader(f)
header = next(reader)
header_map = {name: i for i, name in enumerate(header)}
if not set(domain_obj.attrs) <= set(header):
raise ValueError("data must contain domain attributes")
indices = [header_map[attr] for attr in domain_obj.attrs]
data = []
for row in reader:
# Convert to int, handling potential float strings like '1.0'
try:
mapped_row = [int(float(row[i])) for i in indices]
except ValueError:
# Fallback or error if data is not numeric
# Assuming domain implies discrete/integer data
mapped_row = [int(row[i]) for i in indices]
data.append(mapped_row)
return Dataset(np.array(data), domain_obj)
[docs]
def project(self, cols: int | str | Sequence[str] | Sequence[int]) -> Factor:
"""project dataset onto a subset of columns"""
if isinstance(cols, (str, int)):
cols = [cols]
domain = self.domain.project(cols)
data = {col: self._data[col] for col in domain.attrs}
data = Dataset(data, domain, self.weights)
return Factor(data.domain, jnp.asarray(data.datavector(flatten=False)))
[docs]
def supports(self, cols: str | Sequence[str]) -> bool:
return self.domain.supports(cols)
[docs]
def drop(self, cols: Sequence[str]) -> Factor:
"""Returns a new Dataset with the specified columns removed."""
proj = [c for c in self.domain if c not in cols]
return self.project(proj)
@property
def records(self) -> int:
"""Returns the number of records (rows) in the dataset."""
return self._n
[docs]
def datavector(self, flatten: bool = True) -> NDArray:
"""return the database in vector-of-counts form"""
dims = self.domain.shape
if len(dims) == 0:
result = self.weights.sum()
return np.array([result]) if flatten else result
multi_index = tuple(self._data[a] for a in self.domain.attrs)
linear_indices = np.ravel_multi_index(multi_index, dims, order="C")
counts = np.bincount(
linear_indices, minlength=math.prod(dims), weights=self.weights
)
return counts if flatten else counts.reshape(dims)
[docs]
def compress(self, mapping: dict[str, np.ndarray]) -> Dataset:
"""
Compresses the dataset by mapping domain elements to a smaller domain.
Args:
mapping: A dictionary where keys are attribute names and values are 1D arrays.
mapping[attr][i] gives the new value for original value i.
Returns:
A new Dataset with transformed values and updated domain.
"""
new_data = dict(self._data)
new_domain_config = self.domain.config.copy()
for attr, map_array in mapping.items():
if attr not in self.domain:
continue
_validate_mapping(map_array, attr)
if map_array.shape[0] != self.domain[attr]:
raise ValueError(
f"Mapping size {map_array.shape[0]} does not match domain size {self.domain[attr]} for attribute {attr}"
)
new_col = map_array[self._data[attr]]
new_data[attr] = new_col.astype(np.min_scalar_type(np.max(map_array)))
new_size = int(np.max(map_array) + 1)
new_domain_config[attr] = new_size
new_domain = Domain(new_domain_config.keys(), new_domain_config.values())
return Dataset(new_data, new_domain, self.weights)
[docs]
def decompress(self, mapping: dict[str, np.ndarray]) -> Dataset:
"""
Decompresses the dataset by reversing the mapping.
Since the mapping is surjective, the reverse mapping is one-to-many.
We sample uniformly from the possible original values.
Args:
mapping: The same mapping dictionary used for compression.
Returns:
A new Dataset with restored domain size and sampled values.
"""
new_data = dict(self._data)
new_domain_config = self.domain.config.copy()
for attr, map_array in mapping.items():
if attr not in self.domain:
continue
_validate_mapping(map_array, attr)
permutation = np.argsort(map_array)
sorted_map = map_array[permutation]
compressed_domain_size = int(np.max(map_array) + 1)
counts = np.bincount(sorted_map, minlength=compressed_domain_size)
starts = np.zeros(compressed_domain_size + 1, dtype=int)
starts[1:] = np.cumsum(counts)
starts = starts[:-1]
current_col = self._data[attr]
col_counts = counts[current_col]
if np.any(col_counts == 0):
raise ValueError(
f"Data contains values for {attr} that have no preimage in the mapping."
)
random_offsets = np.floor(
np.random.rand(len(current_col)) * col_counts
).astype(int)
lookup_indices = starts[current_col] + random_offsets
new_col = permutation[lookup_indices]
new_data[attr] = new_col.astype(np.min_scalar_type(len(map_array) - 1))
new_domain_config[attr] = len(map_array)
new_domain = Domain(new_domain_config.keys(), new_domain_config.values())
return Dataset(new_data, new_domain, self.weights)
@functools.partial(
jax.tree_util.register_dataclass,
meta_fields=["domain"],
data_fields=["data", "weights"],
)
@attr.dataclass(frozen=True)
class JaxDataset:
"""Represents a discrete dataset backed by JAX Arrays.
Attributes:
data (dict[str, jax.Array]): A dictionary of 1D JAX arrays where keys are attributes
and values are columns of data.
domain (Domain): A `Domain` object describing the attributes and their
possible discrete values.
weights (jax.Array | None): An optional 1D JAX array representing the
weight for each record in the dataset. If None, all records are
assumed to have a weight of 1.
"""
data: dict[str, jax.Array]
domain: Domain
weights: jax.Array | None = None
@staticmethod
def synthetic(domain: Domain, records: int) -> JaxDataset:
"""Generate synthetic data conforming to the given domain
:param domain: The domain object
:param records: the number of individuals
"""
data = {}
for attr, n in zip(domain.attrs, domain.shape):
data[attr] = jnp.array(np.random.randint(low=0, high=n, size=records))
return JaxDataset(data, domain)
def project(self, cols: str | Sequence[str]) -> Factor:
"""project dataset onto a subset of columns"""
if isinstance(cols, (str, int)):
cols = [cols]
domain = self.domain.project(cols)
dims = domain.shape
if not dims:
w = self.weights if self.weights is not None else jnp.ones(self.records)
result = w.sum()
return Factor(domain, jnp.array([result]))
length = math.prod(dims)
dtype = np.min_scalar_type(length-1)
multi_index = [self.data[a] for a in domain.attrs]
multi_index[0] = multi_index[0].astype(dtype)
linear_indices = jnp.ravel_multi_index(
tuple(multi_index), dims, mode="wrap", order="C"
)
counts = jnp.bincount(linear_indices, weights=self.weights, minlength=length)
return Factor(domain, counts.reshape(dims))
def supports(self, cols: str | Sequence[str]) -> bool:
return self.domain.supports(cols)
@property
def records(self) -> int:
"""Returns the number of records (rows) in the dataset."""
if not self.data:
raise ValueError("Dataset is empty (no columns).")
return list(self.data.values())[0].shape[0]
def apply_sharding(self, mesh: jax.sharding.Mesh) -> JaxDataset:
pspec = jax.sharding.PartitionSpec(mesh.axis_names)
sharding = jax.sharding.NamedSharding(mesh, pspec)
new_data = {}
for k, v in self.data.items():
new_data[k] = jax.lax.with_sharding_constraint(v, sharding)
weights = (
self.weights
if self.weights is None
else jax.lax.with_sharding_constraint(self.weights, sharding)
)
return JaxDataset(new_data, self.domain, weights)