

Trait GradientAggregator 

pub trait GradientAggregator: Send {
    // Required method
    fn aggregate(
        &mut self,
        local_gradients: &[Tensor],
    ) -> Result<Vec<Tensor>, ModelError>;
}

Strategy for combining gradients across distributed workers.
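A minimal sketch of the strategy idea, under stated assumptions. Tensor and ModelError below are simplified stand-ins (a flat f32 buffer and a string error), not the crate's real types. PassThrough, PreScale and run_on_worker_thread are hypothetical names used only for illustration. The sketch shows two interchangeable strategies selected at runtime behind Box<dyn GradientAggregator>. It also shows why the Send supertrait matters: Send lets the boxed aggregator be moved onto another thread.

```rust
use std::thread;

/// Stand-in for the crate's Tensor type (assumption: a flat buffer of f32 values).
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

/// Stand-in for the crate's ModelError type (assumption).
#[derive(Debug, PartialEq)]
pub struct ModelError(pub String);

pub trait GradientAggregator: Send {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}

/// Hypothetical single-worker strategy: there is nothing to combine, so it returns a copy.
pub struct PassThrough;

impl GradientAggregator for PassThrough {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        Ok(local_gradients.to_vec())
    }
}

/// Hypothetical strategy: scale every gradient by 1 / world_size, as a worker
/// might do before a sum-based all-reduce so that the reduced result is a mean.
pub struct PreScale {
    pub world_size: usize,
}

impl GradientAggregator for PreScale {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        let scale = 1.0 / self.world_size as f32;
        Ok(local_gradients
            .iter()
            .map(|t| Tensor { data: t.data.iter().map(|x| x * scale).collect() })
            .collect())
    }
}

/// Runs whichever strategy was chosen on a separate thread. This compiles only
/// because the Send supertrait makes Box<dyn GradientAggregator> Send.
pub fn run_on_worker_thread(
    mut aggregator: Box<dyn GradientAggregator>,
    grads: Vec<Tensor>,
) -> Result<Vec<Tensor>, ModelError> {
    thread::spawn(move || aggregator.aggregate(&grads))
        .join()
        .expect("worker thread panicked")
}
```

Because callers depend only on the trait, switching strategies is a change at construction time (Box::new(PassThrough) versus Box::new(PreScale { world_size: 2 })). The training code itself stays the same.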

Required Methods


fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError>

Aggregates this worker's local gradients. Returns the combined gradients on success, or a ModelError on failure.
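The following sketch shows one possible implementation of aggregate: element-wise averaging of gradients across workers. It is not the crate's implementation. Tensor and ModelError are again simplified stand-ins. InMemoryMean, its peer_gradients field and the two error variants are hypothetical. peer_gradients takes the place of whatever a real communication backend (for example, an all-reduce) would deliver.

```rust
/// Stand-in for the crate's Tensor type (assumption: a flat buffer of f32 values).
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

/// Stand-in for the crate's ModelError type (assumption; variants are hypothetical).
#[derive(Debug, PartialEq)]
pub enum ModelError {
    /// Workers disagreed on how many gradient tensors they produced.
    CountMismatch { expected: usize, found: usize },
    /// A tensor's length differed between workers.
    LengthMismatch { index: usize, expected: usize, found: usize },
}

pub trait GradientAggregator: Send {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}

/// Hypothetical aggregator that averages local gradients with gradients
/// already received from peer workers (one Vec<Tensor> per peer).
pub struct InMemoryMean {
    pub peer_gradients: Vec<Vec<Tensor>>,
}

impl GradientAggregator for InMemoryMean {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        // Start the running sum from this worker's own gradients.
        let mut sums: Vec<Tensor> = local_gradients.to_vec();
        for peer in &self.peer_gradients {
            // Every worker must contribute the same number of tensors.
            if peer.len() != sums.len() {
                return Err(ModelError::CountMismatch { expected: sums.len(), found: peer.len() });
            }
            for (index, (acc, g)) in sums.iter_mut().zip(peer).enumerate() {
                // Corresponding tensors must have the same length.
                if g.data.len() != acc.data.len() {
                    return Err(ModelError::LengthMismatch {
                        index,
                        expected: acc.data.len(),
                        found: g.data.len(),
                    });
                }
                // Element-wise accumulation into the running sum.
                for (a, b) in acc.data.iter_mut().zip(&g.data) {
                    *a += *b;
                }
            }
        }
        // Divide by the number of contributing workers (peers plus this one).
        let workers = (self.peer_gradients.len() + 1) as f32;
        for t in &mut sums {
            for v in &mut t.data {
                *v /= workers;
            }
        }
        Ok(sums)
    }
}
```

Mismatched counts or lengths produce an error instead of a panic. That matches the Result return type of aggregate and leaves recovery to the training loop. The sketch never mutates self. The &mut self receiver would, however, let a stateful strategy (one holding a communication handle or running buffers, say) update its state between calls.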

Implementors