pub fn distributed_train_step<F, G>(
compute_gradients_fn: F,
apply_gradients_fn: G,
aggregator: &mut dyn GradientAggregator,
) -> Result<f32, ModelError>
Performs a single distributed training step: forward, backward, aggregate, update.
This is the main entry point for one iteration of distributed training.
It calls compute_gradients_fn to run the local forward and backward pass,
then aggregates gradients across workers via the provided aggregator,
and finally applies the aggregated gradients through apply_gradients_fn.
§Arguments
compute_gradients_fn – closure that returns (loss, local_gradients).
apply_gradients_fn – closure that receives the aggregated gradients and updates model parameters (e.g. via an optimizer step).
aggregator – the gradient aggregation strategy (e.g. AllReduceAggregator, or LocalAggregator for single-worker training).
§Returns
The scalar loss value produced by compute_gradients_fn.
§Example
let loss = distributed_train_step(
|| { /* forward + backward */ Ok((loss_val, grads)) },
|agg_grads| { optimizer.apply(agg_grads); Ok(()) },
&mut aggregator,
)?;
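To show how the pieces fit together end to end, here is a minimal self-contained sketch of the step on a one-parameter toy model. The GradientAggregator trait shape, the ModelError variant, the where-bounds on F and G, and the LocalAggregator implementation below are assumptions for illustration, not the crate's actual definitions.

```rust
#[derive(Debug)]
pub enum ModelError {
    Aggregation(String), // assumed error shape, for illustration only
}

// Assumed trait shape: reduce local gradients into aggregated gradients.
pub trait GradientAggregator {
    fn aggregate(&mut self, local: Vec<f32>) -> Result<Vec<f32>, ModelError>;
}

// Hypothetical single-worker aggregator: returns local gradients unchanged.
pub struct LocalAggregator;

impl GradientAggregator for LocalAggregator {
    fn aggregate(&mut self, local: Vec<f32>) -> Result<Vec<f32>, ModelError> {
        Ok(local)
    }
}

// Sketch of the documented flow: forward/backward, aggregate, apply.
// The closure bounds here are assumed from the docs, not copied from the crate.
pub fn distributed_train_step<F, G>(
    compute_gradients_fn: F,
    apply_gradients_fn: G,
    aggregator: &mut dyn GradientAggregator,
) -> Result<f32, ModelError>
where
    F: FnOnce() -> Result<(f32, Vec<f32>), ModelError>,
    G: FnOnce(Vec<f32>) -> Result<(), ModelError>,
{
    let (loss, local_grads) = compute_gradients_fn()?; // forward + backward
    let agg_grads = aggregator.aggregate(local_grads)?; // cross-worker reduce
    apply_gradients_fn(agg_grads)?; // optimizer / parameter update
    Ok(loss)
}

// One step on a toy model: loss = (w - 3)^2, gradient = 2(w - 3), plain SGD.
fn train_once() -> Result<(f32, f32), ModelError> {
    let mut w = 0.0_f32;
    let lr = 0.1_f32;
    let mut aggregator = LocalAggregator;

    let loss = distributed_train_step(
        // `move` copies `w` (f32 is Copy), so this closure does not borrow it
        move || {
            let loss = (w - 3.0) * (w - 3.0);
            let grad = 2.0 * (w - 3.0);
            Ok((loss, vec![grad]))
        },
        // apply the aggregated gradient with a plain SGD update
        |agg| {
            w -= lr * agg[0];
            Ok(())
        },
        &mut aggregator,
    )?;
    Ok((loss, w))
}

fn main() -> Result<(), ModelError> {
    let (loss, w) = train_once()?;
    println!("loss = {loss}, w = {w}");
    Ok(())
}
```

In a multi-worker setup only the aggregator argument would change; the two closures stay purely local, which is why the aggregation strategy is injected rather than baked into the step.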