mbi.estimation.mirror_descent_step

mbi.estimation.mirror_descent_step(state: MirrorDescentState, loss_fn: MarginalLossFn, marginal_oracle: MarginalOracle, total: Array | float, linesearch: bool = True) -> MirrorDescentState

Performs a single mirror descent step.
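The update can be written as follows under the usual Private-PGM mirror descent formulation. This is an assumption, because this page does not spell the update out. With step size η_t, the step moves the potentials θ against the gradient of the loss L with respect to the current marginals μ. It then recomputes the marginals with the oracle and the total N. When linesearch is enabled, η_t is chosen by backtracking; otherwise it is held fixed.

```latex
\theta_{t+1} = \theta_t - \eta_t \, \nabla L(\mu_t), \qquad
\mu_{t+1} = \mathrm{MarginalOracle}(\theta_{t+1}, N)
```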

Parameters:
  • state – Current algorithm state.

  • loss_fn – The marginal loss function.

  • marginal_oracle – A marginal oracle with signature (potentials, total) -> marginals.

  • total – The known or estimated total number of records.

  • linesearch – If True (default), uses Armijo line search to adapt the step size. If False, uses a fixed step size.

Returns:

Updated MirrorDescentState.
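The sketch below makes the step concrete in plain Python and does not use the mbi package. ToyState, toy_mirror_descent_step, softmax_oracle, and make_l2_loss are all hypothetical names. The oracle is a single-clique softmax scaled by total. The backtracking rule (halve the step until the loss drops by at least half of the first-order decrease) is an assumed Armijo-style test, not mbi's exact implementation.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple

Vector = List[float]
LossFn = Callable[[Vector], Tuple[float, Vector]]  # marginals -> (loss, dL/dmu)
Oracle = Callable[[Vector, float], Vector]         # (potentials, total) -> marginals


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for MirrorDescentState; the real fields may differ."""
    potentials: Vector  # theta, the log-space parameters
    marginals: Vector   # mu = oracle(theta, total)
    loss: float         # L(mu)
    gradient: Vector    # dL/dmu evaluated at mu
    stepsize: float     # step size used to reach this state


def toy_mirror_descent_step(state: ToyState, loss_fn: LossFn, oracle: Oracle,
                            total: float, linesearch: bool = True,
                            min_stepsize: float = 1e-8) -> ToyState:
    alpha = state.stepsize
    while True:
        # Mirror descent: move the potentials against the marginal-space gradient.
        theta = [t - alpha * g for t, g in zip(state.potentials, state.gradient)]
        mu = oracle(theta, total)
        loss, grad = loss_fn(mu)
        # Assumed Armijo-style test: the loss must drop by at least half of
        # the first-order decrease <gradient, mu_old - mu_new>.
        predicted = sum(g * (m0 - m1)
                        for g, m0, m1 in zip(state.gradient, state.marginals, mu))
        if not linesearch or state.loss - loss >= 0.5 * predicted or alpha < min_stepsize:
            break
        alpha *= 0.5  # backtrack: halve the step size and retry
    return ToyState(theta, mu, loss, grad, alpha)


def softmax_oracle(theta: Vector, total: float) -> Vector:
    """Toy oracle for a single clique: marginals = total * softmax(theta)."""
    top = max(theta)  # subtract the max for numerical stability
    weights = [math.exp(t - top) for t in theta]
    z = sum(weights)
    return [total * w / z for w in weights]


def make_l2_loss(target: Vector) -> LossFn:
    """L(mu) = 0.5 * ||mu - target||^2, whose gradient is mu - target."""
    def loss_fn(mu: Vector) -> Tuple[float, Vector]:
        diff = [m - y for m, y in zip(mu, target)]
        return 0.5 * sum(d * d for d in diff), diff
    return loss_fn


# Usage: fit toy counts over a 3-cell domain whose entries sum to total.
total = 100.0
loss_fn = make_l2_loss([50.0, 30.0, 20.0])
theta0 = [0.0, 0.0, 0.0]
mu0 = softmax_oracle(theta0, total)
loss0, grad0 = loss_fn(mu0)
s0 = ToyState(theta0, mu0, loss0, grad0, stepsize=1.0)

s1 = toy_mirror_descent_step(s0, loss_fn, softmax_oracle, total)  # with line search
fixed = toy_mirror_descent_step(s0, loss_fn, softmax_oracle, total, linesearch=False)

losses = [s0.loss]
state = s0
for _ in range(50):
    state = toy_mirror_descent_step(state, loss_fn, softmax_oracle, total)
    losses.append(state.loss)
```

On this toy problem a fixed step size of 1.0 overshoots and increases the loss. With backtracking, the step is halved until the loss drops enough. In this sketch the accepted step size is stored in the returned state, so the next call starts from it.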