Function distributed_train_step 

pub fn distributed_train_step<F, G>(
    compute_gradients_fn: F,
    apply_gradients_fn: G,
    aggregator: &mut dyn GradientAggregator,
) -> Result<f32, ModelError>
where
    F: FnOnce() -> Result<(f32, Vec<Tensor>), ModelError>,
    G: FnOnce(&[Tensor]) -> Result<(), ModelError>,

Performs a single distributed training step: forward, backward, aggregate, update.

This is the main entry point for one iteration of distributed training. It calls compute_gradients_fn to run the local forward and backward pass, then aggregates gradients across workers via the provided aggregator, and finally applies the aggregated gradients through apply_gradients_fn.
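The three-step sequence can be sketched as follows. Note that `Tensor`, `ModelError`, and the shape of the `GradientAggregator` trait below are reconstructed from the signature for illustration and are not the crate's actual definitions:

```rust
// Stand-ins reconstructed from the signature; the real crate types differ.
type Tensor = Vec<f32>;

#[derive(Debug)]
pub struct ModelError(pub String);

pub trait GradientAggregator {
    /// Combine local gradients with those of peer workers.
    fn aggregate(&mut self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}

pub fn distributed_train_step<F, G>(
    compute_gradients_fn: F,
    apply_gradients_fn: G,
    aggregator: &mut dyn GradientAggregator,
) -> Result<f32, ModelError>
where
    F: FnOnce() -> Result<(f32, Vec<Tensor>), ModelError>,
    G: FnOnce(&[Tensor]) -> Result<(), ModelError>,
{
    let (loss, local_grads) = compute_gradients_fn()?;   // 1. local forward + backward
    let agg_grads = aggregator.aggregate(&local_grads)?; // 2. cross-worker aggregation
    apply_gradients_fn(&agg_grads)?;                     // 3. parameter update
    Ok(loss)
}

/// Single-worker aggregator: returns the local gradients unchanged.
pub struct PassThroughAggregator;

impl GradientAggregator for PassThroughAggregator {
    fn aggregate(&mut self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        Ok(grads.to_vec())
    }
}

fn main() {
    let mut agg = PassThroughAggregator;
    let mut params = vec![1.0_f32];
    let loss = distributed_train_step(
        || Ok((0.25, vec![vec![0.5]])),             // pretend loss and gradient
        |g| { params[0] -= 0.1 * g[0][0]; Ok(()) }, // SGD-style update
        &mut agg,
    )
    .unwrap();
    println!("loss = {loss}, params = {params:?}");
}
```

The `?` operator after each step means any failure in the local pass, the aggregation, or the update short-circuits the whole step and surfaces as a `ModelError`.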

§Arguments

  • compute_gradients_fn – closure that runs the local forward and backward pass and returns (loss, local_gradients).
  • apply_gradients_fn – closure that receives the aggregated gradients and updates model parameters (e.g. via an optimizer step).
  • aggregator – the gradient aggregation strategy (e.g. AllReduceAggregator or LocalAggregator for single-worker training).
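The aggregator argument decides how local gradients are combined across workers. As a rough illustration, an averaging strategy in the spirit of all-reduce might look like the following; the trait shape and the `AveragingAggregator` type are assumptions for the sketch, not part of this crate:

```rust
// Stand-ins reconstructed from the signature; the real crate types differ.
type Tensor = Vec<f32>;

#[derive(Debug)]
pub struct ModelError(pub String);

pub trait GradientAggregator {
    fn aggregate(&mut self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}

/// Toy all-reduce stand-in: averages the local gradients with gradient
/// sets received from peer workers (here supplied up front rather than
/// over a network).
pub struct AveragingAggregator {
    pub peer_grads: Vec<Vec<Tensor>>, // one gradient set per peer worker
}

impl GradientAggregator for AveragingAggregator {
    fn aggregate(&mut self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        let n = (self.peer_grads.len() + 1) as f32; // peers + self
        let mut out = grads.to_vec();
        // Element-wise sum of each peer's gradients into the local copy.
        for peer in &self.peer_grads {
            for (acc, g) in out.iter_mut().zip(peer) {
                for (a, v) in acc.iter_mut().zip(g) {
                    *a += v;
                }
            }
        }
        // Divide by the worker count to get the mean gradient.
        for t in &mut out {
            for a in t {
                *a /= n;
            }
        }
        Ok(out)
    }
}

fn main() {
    // Local gradient [1.0], one peer contributes [3.0]; mean is [2.0].
    let mut agg = AveragingAggregator { peer_grads: vec![vec![vec![3.0]]] };
    let avg = agg.aggregate(&[vec![1.0]]).unwrap();
    println!("averaged gradient: {avg:?}");
}
```

A single-worker aggregator (the role LocalAggregator plays above) would simply return the gradients unchanged.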

§Returns

The scalar loss value produced by compute_gradients_fn.

§Example

let loss = distributed_train_step(
    || { /* forward + backward */ Ok((loss_val, grads)) },
    |agg_grads| { optimizer.apply(agg_grads); Ok(()) },
    &mut aggregator, // e.g. a LocalAggregator for single-worker training
)?;