pub fn train_epoch_distributed<F, O: GraphOptimizer>(
graph: &mut Graph,
optimizer: &mut O,
aggregator: &mut dyn GradientAggregator,
trainable_nodes: &[NodeId],
num_batches: usize,
train_batch_fn: &mut F,
) -> Result<EpochMetrics, ModelError>
Train one epoch with distributed gradient synchronization.
After each batch’s backward pass, gradients are collected from the
trainable parameter nodes, aggregated across all ranks using the
provided GradientAggregator (e.g. AllReduceAggregator, or
LocalAggregator for single-rank runs), written back to the graph, and
the optimizer is then stepped. This is the standard data-parallel
training pattern.
The caller supplies a closure train_batch_fn that, given the graph
and a batch index, must:
- Set up the forward pass for the batch (feed inputs, compute the prediction).
- Compute the loss and call graph.backward(loss_node).
- Return Ok(loss_scalar).
The function returns the mean loss across all batches.
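The per-batch loop described above can be sketched as a minimal, self-contained program. The Graph, GradientAggregator, and LocalAggregator stand-ins below, and the simplified train_epoch (which drops the optimizer, trainable_nodes, and the crate's real error type), are assumptions for illustration only; the crate's actual types and signatures may differ.

```rust
// Hypothetical stand-ins for the crate's types; shapes are assumed.
struct Graph {
    grads: Vec<f32>, // gradients of the trainable parameters
}

trait GradientAggregator {
    // Synchronize gradients across ranks in place.
    fn aggregate(&mut self, grads: &mut [f32]);
}

// Single-rank aggregator: nothing to synchronize, so a no-op,
// mirroring the role of LocalAggregator in the docs above.
struct LocalAggregator;
impl GradientAggregator for LocalAggregator {
    fn aggregate(&mut self, _grads: &mut [f32]) {}
}

// Simplified epoch loop: run the closure per batch, aggregate
// gradients, then (in the real function) step the optimizer.
fn train_epoch<F>(
    graph: &mut Graph,
    aggregator: &mut dyn GradientAggregator,
    num_batches: usize,
    train_batch_fn: &mut F,
) -> Result<f32, String>
where
    F: FnMut(&mut Graph, usize) -> Result<f32, String>,
{
    let mut total_loss = 0.0;
    for batch in 0..num_batches {
        // Closure performs forward + backward and reports the batch loss.
        let loss = train_batch_fn(graph, batch)?;
        // Synchronize gradients across ranks before the optimizer step.
        aggregator.aggregate(&mut graph.grads);
        // (optimizer.step() would go here in the real function)
        total_loss += loss;
    }
    // Mean loss across all batches, as the real function returns.
    Ok(total_loss / num_batches as f32)
}

fn main() {
    let mut graph = Graph { grads: vec![0.0; 2] };
    let mut agg = LocalAggregator;
    // Pretend losses 0, 1, 2, 3 for four batches.
    let mean = train_epoch(&mut graph, &mut agg, 4, &mut |_g, batch| {
        Ok(batch as f32)
    })
    .unwrap();
    println!("{mean}"); // prints 1.5
}
```

The closure takes the graph by mutable reference so it can feed inputs and trigger the backward pass; returning Result lets a failed batch abort the epoch early via `?`.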