Function train_epoch_distributed 

pub fn train_epoch_distributed<F, O: GraphOptimizer>(
    graph: &mut Graph,
    optimizer: &mut O,
    aggregator: &mut dyn GradientAggregator,
    trainable_nodes: &[NodeId],
    num_batches: usize,
    train_batch_fn: &mut F,
) -> Result<EpochMetrics, ModelError>
where
    F: FnMut(&mut Graph, usize) -> Result<f32, ModelError>,

Train one epoch with distributed gradient synchronization.

After each batch’s backward pass, gradients are collected from the trainable parameter nodes, aggregated across all ranks by the provided GradientAggregator (e.g. AllReduceAggregator, or LocalAggregator for single-rank runs), written back to the graph, and finally the optimizer is stepped. This is the standard data-parallel training pattern.
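The per-batch synchronization sequence can be sketched with simplified stand-ins. The `aggregate` and `sgd_step` functions below are illustrative toys, not the library's GradientAggregator or GraphOptimizer; with a single rank, averaging degenerates to the identity, mirroring the LocalAggregator case:

```rust
// Toy "aggregator": element-wise mean of each rank's gradient vector.
// Stand-in for AllReduceAggregator; not the library's real API.
fn aggregate(per_rank_grads: &[Vec<f32>]) -> Vec<f32> {
    let n = per_rank_grads.len() as f32;
    let len = per_rank_grads[0].len();
    (0..len)
        .map(|i| per_rank_grads.iter().map(|g| g[i]).sum::<f32>() / n)
        .collect()
}

// Toy optimizer step: plain SGD on the aggregated gradients,
// standing in for the GraphOptimizer step after write-back.
fn sgd_step(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

fn main() {
    let mut params = vec![1.0_f32, -2.0];
    // Pretend two ranks produced these gradients in their backward passes.
    let rank_grads = vec![vec![0.5, 1.0], vec![1.5, 3.0]];
    let avg = aggregate(&rank_grads); // [1.0, 2.0]
    sgd_step(&mut params, &avg, 0.1); // optimizer step on averaged grads
    println!("{:?}", params);
}
```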

The caller supplies a closure train_batch_fn that, given the graph and a batch index, must:

  1. Set up the forward pass for the batch (feed inputs, compute prediction).
  2. Compute the loss and call graph.backward(loss_node).
  3. Return Ok(loss_scalar).

The function returns the mean loss across all batches.
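The closure contract and the mean-loss return can be sketched from the caller's side. ToyGraph and run_epoch below are simplified stand-ins for the real Graph and this function (the aggregation and optimizer steps are elided); only the closure's shape and the loss averaging mirror the description above:

```rust
// Stand-in for the library's Graph; `backward` is a placeholder.
struct ToyGraph {
    last_loss: f32,
}

impl ToyGraph {
    fn backward(&mut self, _loss_node: usize) { /* placeholder */ }
}

// Simplified driver: calls the closure once per batch and
// returns the mean loss, as train_epoch_distributed does.
fn run_epoch<F>(
    graph: &mut ToyGraph,
    num_batches: usize,
    train_batch_fn: &mut F,
) -> Result<f32, String>
where
    F: FnMut(&mut ToyGraph, usize) -> Result<f32, String>,
{
    let mut total = 0.0_f32;
    for batch in 0..num_batches {
        total += train_batch_fn(graph, batch)?;
    }
    Ok(total / num_batches as f32)
}

fn main() -> Result<(), String> {
    let mut graph = ToyGraph { last_loss: 0.0 };
    let mut train_batch_fn = |g: &mut ToyGraph, batch: usize| {
        // 1. Forward pass for this batch would go here (feed inputs).
        let loss = 1.0 / (batch as f32 + 1.0); // pretend loss
        g.last_loss = loss;
        // 2. Compute loss and run the backward pass.
        g.backward(0);
        // 3. Return the scalar loss.
        Ok(loss)
    };
    let mean = run_epoch(&mut graph, 4, &mut train_batch_fn)?;
    println!("mean loss = {mean}");
    Ok(())
}
```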