mbi.estimation.mirror_descent_step
- mbi.estimation.mirror_descent_step(state: MirrorDescentState, loss_fn: MarginalLossFn, marginal_oracle: MarginalOracle, total: Array | float, linesearch: bool = True) → MirrorDescentState
Performs a single mirror descent step.
- Parameters:
state – Current algorithm state.
loss_fn – The marginal loss function.
marginal_oracle – A marginal oracle with signature (potentials, total) -> marginals.
total – The known or estimated total number of records.
linesearch – If True (default), uses Armijo line search to adapt the step size. If False, uses a fixed step size.
- Returns:
Updated MirrorDescentState.
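This page does not spell out the update rule. As a rough illustration only, the pure-Python sketch below assumes a mirror descent step taken in potential space (theta <- theta - alpha * grad L(mu), with mu recomputed by the marginal oracle) on a single discrete attribute. It also assumes an Armijo-style backtracking test that halves alpha until a sufficient-decrease condition holds. ToyState, toy_oracle, make_loss, toy_step, and the constant c = 0.5 are hypothetical stand-ins, not part of mbi; the real MirrorDescentState fields and acceptance rule may differ.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyState:
    # Hypothetical stand-in for MirrorDescentState (NOT the mbi class).
    potentials: list[float]  # log-space parameters theta
    loss: float              # loss at the marginals implied by potentials
    step_size: float         # current step size alpha


def toy_oracle(potentials, total):
    # Hypothetical oracle with signature (potentials, total) -> marginals:
    # marginals = total * softmax(potentials), computed stably.
    m = max(potentials)
    w = [math.exp(p - m) for p in potentials]
    z = sum(w)
    return [total * wi / z for wi in w]


def make_loss(y):
    # Hypothetical marginal loss: L(mu) = 0.5 * sum((mu - y)^2).
    # Returns (loss, gradient with respect to mu).
    def loss_fn(mu):
        diff = [a - b for a, b in zip(mu, y)]
        return 0.5 * sum(d * d for d in diff), diff
    return loss_fn


def toy_step(state, loss_fn, oracle, total, linesearch=True, c=0.5, min_step=1e-12):
    mu = oracle(state.potentials, total)
    loss, grad = loss_fn(mu)
    alpha = state.step_size
    while True:
        # Mirror step in potential space, then map back to marginals.
        theta = [p - alpha * g for p, g in zip(state.potentials, grad)]
        new_mu = oracle(theta, total)
        new_loss, _ = loss_fn(new_mu)
        if not linesearch:
            break  # fixed step size: accept unconditionally
        # Armijo-style test: require loss drop >= c * <grad, mu - new_mu>.
        decrease = sum(g * (a - b) for g, a, b in zip(grad, mu, new_mu))
        if loss - new_loss >= c * decrease or alpha <= min_step:
            break
        alpha *= 0.5  # backtrack and retry with a smaller step
    return ToyState(theta, new_loss, alpha)
```

In this toy, the oracle's output always sums to total, and with linesearch=False the step size stored in the state is applied unchanged. Calling toy_step repeatedly, feeding each returned state into the next call, gives a non-increasing loss when the line search is on.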