mbi.estimation.dual_averaging_step

mbi.estimation.dual_averaging_step(state: DualAveragingState, loss_fn: MarginalLossFn, marginal_oracle: MarginalOracle, total: Array | float, lipschitz: float, gamma: float, t: int) -> DualAveragingState

Performs a single dual averaging step.

Parameters:
  • state – Current algorithm state.

  • loss_fn – The marginal loss function.

  • marginal_oracle – A marginal oracle with signature (potentials, total) -> marginals.

  • total – The known or estimated total number of records.

  • lipschitz – Lipschitz constant of the gradient, divided by total.

  • gamma – Variance-related parameter (typically 0 in the deterministic case).

  • t – Current iteration number (1-indexed).

Returns:

Updated DualAveragingState.
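
To show how these arguments could fit together, here is a minimal pure-Python sketch of one common form of accelerated dual averaging on a toy single-attribute problem. It is an illustration under stated assumptions, not the library's implementation: ToyState and its fields (w, v, gbar), softmax_oracle, toy_grad, toy_loss, and the exact update rule are all hypothetical, and the sketch takes a gradient function directly where the real function takes loss_fn.

```python
import math
from typing import Callable, List, NamedTuple

Vector = List[float]


class ToyState(NamedTuple):
    """Hypothetical stand-in for DualAveragingState (field names are assumptions)."""
    w: Vector     # running average of oracle outputs (the returned estimate)
    v: Vector     # most recent oracle output
    gbar: Vector  # weighted running average of gradients


def softmax_oracle(potentials: Vector, total: float) -> Vector:
    """Toy marginal oracle: (potentials, total) -> marginals summing to total."""
    m = max(potentials)  # subtract the max for numerical stability
    exps = [math.exp(p - m) for p in potentials]
    z = sum(exps)
    return [total * e / z for e in exps]


def toy_dual_averaging_step(
    state: ToyState,
    grad_fn: Callable[[Vector], Vector],
    marginal_oracle: Callable[[Vector, float], Vector],
    total: float,
    lipschitz: float,
    gamma: float,
    t: int,
) -> ToyState:
    """One step of an assumed accelerated dual averaging rule (t is 1-indexed)."""
    c = 2.0 / (t + 1)                      # averaging weight
    beta = gamma * (t + 1) ** 1.5 / 2.0    # extra damping, zero when gamma == 0
    # Query point: mix of the running average and the latest oracle output.
    u = [(1 - c) * wi + c * vi for wi, vi in zip(state.w, state.v)]
    g = grad_fn(u)
    gbar = [(1 - c) * gb + c * gi for gb, gi in zip(state.gbar, g)]
    # Map the averaged gradient to potentials, then back to marginals.
    scale = -t * (t + 1) / (4.0 * lipschitz + beta) / total
    v = marginal_oracle([scale * gb for gb in gbar], total)
    w = [(1 - c) * wi + c * vi for wi, vi in zip(state.w, v)]
    return ToyState(w, v, gbar)


# Toy problem: fit a 3-cell marginal to a target under squared error.
target = [70.0, 20.0, 10.0]
total = 100.0


def toy_loss(mu: Vector) -> float:
    return 0.5 * sum((m - y) ** 2 for m, y in zip(mu, target))


def toy_grad(mu: Vector) -> Vector:
    return [m - y for m, y in zip(mu, target)]


uniform = [total / 3.0] * 3
init = ToyState(w=uniform, v=uniform, gbar=[0.0, 0.0, 0.0])

state = init
for t in range(1, 4):  # lipschitz=1.0 and gamma=0.0 are arbitrary illustrative values
    state = toy_dual_averaging_step(
        state, toy_grad, softmax_oracle, total, lipschitz=1.0, gamma=0.0, t=t
    )
```

In this sketch the caller owns the loop and passes the 1-indexed t on each call, which matches the single-step shape of the signature above. Because every oracle output sums to total, the running average w also sums to total after each step.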