Skip to main content

torsh_optim/distributed/
core.rs

//! Core distributed optimizer functionality.
//!
//! This module provides the main distributed optimizer wrapper and the core types
//! used to synchronize gradients across multiple processes during distributed training.
5
6use crate::{Optimizer, OptimizerResult, OptimizerState};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_core::error::{Result, TorshError};
11use torsh_tensor::Tensor;
12
/// Communication backend used to exchange gradients between processes.
///
/// Selected through the `backend` field of `DistributedConfig`. The enum derives
/// `PartialEq`/`Eq`/`Hash` so configurations can be compared directly
/// (e.g. `config.backend == DistributedBackend::NCCL`) or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DistributedBackend {
    /// NCCL backend for NVIDIA GPUs
    NCCL,
    /// MPI backend for general distributed computing
    MPI,
    /// Gloo backend for CPU and GPU
    Gloo,
    /// Custom backend, identified by name
    Custom(String),
}
25
/// Strategy used to combine gradients across processes.
///
/// Selected through the `sync_strategy` field of `DistributedConfig`. The enum
/// derives `PartialEq`/`Eq`/`Hash` so strategies can be compared directly or used
/// as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SyncStrategy {
    /// AllReduce: Sum gradients across all processes
    AllReduce,
    /// AllGather: Gather all gradients and average
    AllGather,
    /// ReduceScatter: Distribute gradient reduction across processes
    ReduceScatter,
}
36
37/// Configuration for distributed optimizer
38#[derive(Debug, Clone)]
39pub struct DistributedConfig {
40    /// Communication backend
41    pub backend: DistributedBackend,
42    /// Gradient synchronization strategy
43    pub sync_strategy: SyncStrategy,
44    /// World size (total number of processes)
45    pub world_size: usize,
46    /// Current process rank
47    pub rank: usize,
48    /// Process group for communication
49    pub process_group: Option<String>,
50    /// Whether to enable gradient compression
51    pub gradient_compression: bool,
52    /// Bucket size for gradient bucketing (in MB)
53    pub bucket_size_mb: f32,
54    /// Whether to overlap communication with computation
55    pub overlap_communication: bool,
56}
57
58impl Default for DistributedConfig {
59    fn default() -> Self {
60        Self {
61            backend: DistributedBackend::Gloo,
62            sync_strategy: SyncStrategy::AllReduce,
63            world_size: 1,
64            rank: 0,
65            process_group: None,
66            gradient_compression: false,
67            bucket_size_mb: 25.0,
68            overlap_communication: true,
69        }
70    }
71}
72
73/// Distributed optimizer wrapper
74///
75/// This wrapper can be applied to any optimizer to enable distributed training.
76/// It handles gradient synchronization across multiple processes before applying
77/// the underlying optimizer's update rule.
78pub struct DistributedOptimizer<O: Optimizer> {
79    optimizer: O,
80    config: DistributedConfig,
81    gradient_buckets: Vec<GradientBucket>,
82    #[allow(dead_code)]
83    communication_handle: Option<CommunicationHandle>,
84}
85
86/// Gradient bucket for efficient communication
87#[derive(Debug)]
88pub struct GradientBucket {
89    pub tensors: Vec<Arc<RwLock<Tensor>>>,
90    #[allow(dead_code)]
91    pub flattened_grad: Option<Tensor>,
92    pub size_bytes: usize,
93}
94
/// Identifies one asynchronous communication operation.
#[derive(Debug)]
pub struct CommunicationHandle {
    /// Identifier of the operation. Reserved for the asynchronous path and not
    /// read yet, hence the lint allowance.
    #[allow(dead_code)]
    pub operation_id: u64,
}
101
/// Counters describing gradient communication, intended for monitoring.
///
/// The derived `Default` yields all-zero values.
#[derive(Clone, Debug, Default)]
pub struct CommunicationStats {
    /// Number of communication operations performed.
    pub total_communications: u64,
    /// Total number of bytes transferred.
    pub total_bytes_transferred: u64,
    /// Mean duration of a communication operation, in milliseconds.
    pub average_communication_time_ms: f32,
    /// Ratio achieved by gradient compression.
    pub gradient_compression_ratio: f32,
}
110
111impl<O: Optimizer> DistributedOptimizer<O> {
112    /// Create a new distributed optimizer
113    pub fn new(optimizer: O, config: DistributedConfig) -> OptimizerResult<Self> {
114        let gradient_buckets = Vec::new();
115
116        Ok(Self {
117            optimizer,
118            config,
119            gradient_buckets,
120            communication_handle: None,
121        })
122    }
123
124    /// Get the underlying optimizer
125    pub fn inner(&self) -> &O {
126        &self.optimizer
127    }
128
129    /// Get the underlying optimizer mutably
130    pub fn inner_mut(&mut self) -> &mut O {
131        &mut self.optimizer
132    }
133
134    /// Get the configuration
135    pub fn config(&self) -> &DistributedConfig {
136        &self.config
137    }
138
139    /// Update the configuration
140    pub fn set_config(&mut self, config: DistributedConfig) {
141        self.config = config;
142    }
143
144    /// Get communication statistics
145    pub fn get_communication_stats(&self) -> CommunicationStats {
146        // In a real implementation, this would return actual statistics
147        CommunicationStats::default()
148    }
149
150    /// Synchronize gradients across all processes
151    pub fn synchronize_gradients(&mut self) -> OptimizerResult<()> {
152        // In a real implementation, this would perform actual gradient synchronization
153        // For now, this is a placeholder that would integrate with communication backends
154
155        match self.config.sync_strategy {
156            SyncStrategy::AllReduce => self.all_reduce_gradients(),
157            SyncStrategy::AllGather => self.all_gather_gradients(),
158            SyncStrategy::ReduceScatter => self.reduce_scatter_gradients(),
159        }
160    }
161
162    /// Perform all-reduce on gradients
163    fn all_reduce_gradients(&mut self) -> OptimizerResult<()> {
164        // Placeholder for all-reduce implementation
165        // In reality, this would call into the communication backend
166        Ok(())
167    }
168
169    /// Perform all-gather on gradients
170    fn all_gather_gradients(&mut self) -> OptimizerResult<()> {
171        // Placeholder for all-gather implementation
172        Ok(())
173    }
174
175    /// Perform reduce-scatter on gradients
176    fn reduce_scatter_gradients(&mut self) -> OptimizerResult<()> {
177        // Placeholder for reduce-scatter implementation
178        Ok(())
179    }
180
181    /// Create gradient buckets for efficient communication
182    pub fn create_gradient_buckets(
183        &mut self,
184        parameters: &[Arc<RwLock<Tensor>>],
185    ) -> OptimizerResult<()> {
186        let bucket_size_bytes = (self.config.bucket_size_mb * 1024.0 * 1024.0) as usize;
187        let mut current_bucket = GradientBucket {
188            tensors: Vec::new(),
189            flattened_grad: None,
190            size_bytes: 0,
191        };
192
193        for param in parameters {
194            let param_guard = param.read();
195            let param_size = param_guard.shape().numel() * 4; // Assuming f32
196
197            if current_bucket.size_bytes + param_size > bucket_size_bytes
198                && !current_bucket.tensors.is_empty()
199            {
200                // Start a new bucket
201                self.gradient_buckets.push(current_bucket);
202                current_bucket = GradientBucket {
203                    tensors: Vec::new(),
204                    flattened_grad: None,
205                    size_bytes: 0,
206                };
207            }
208
209            current_bucket.tensors.push(param.clone());
210            current_bucket.size_bytes += param_size;
211        }
212
213        // Add the last bucket if it has any tensors
214        if !current_bucket.tensors.is_empty() {
215            self.gradient_buckets.push(current_bucket);
216        }
217
218        Ok(())
219    }
220
221    /// Flatten gradients within a bucket for efficient communication
222    fn flatten_bucket_gradients(&self, bucket: &GradientBucket) -> OptimizerResult<Tensor> {
223        // This would flatten all gradients in the bucket into a single tensor
224        // For now, return a placeholder
225        let total_elements: usize = bucket
226            .tensors
227            .iter()
228            .map(|t| {
229                let guard = t.read();
230                guard.shape().numel()
231            })
232            .sum();
233
234        // Return a zero tensor as placeholder
235        let flattened = Tensor::zeros(&[total_elements], torsh_core::device::DeviceType::Cpu)?;
236        Ok(flattened)
237    }
238
239    /// Unflatten gradients after communication
240    fn unflatten_bucket_gradients(
241        &self,
242        bucket: &GradientBucket,
243        flattened: &Tensor,
244    ) -> OptimizerResult<()> {
245        // This would unflatten the communicated gradients back to individual tensors
246        // For now, this is a placeholder
247        Ok(())
248    }
249}
250
251impl<O: Optimizer> Optimizer for DistributedOptimizer<O> {
252    fn step(&mut self) -> OptimizerResult<()> {
253        // Synchronize gradients before optimization step
254        self.synchronize_gradients()?;
255
256        // Perform the optimization step
257        self.optimizer.step()
258    }
259
260    fn zero_grad(&mut self) {
261        self.optimizer.zero_grad();
262    }
263
264    fn get_lr(&self) -> Vec<f32> {
265        self.optimizer.get_lr()
266    }
267
268    fn set_lr(&mut self, lr: f32) {
269        self.optimizer.set_lr(lr);
270    }
271
272    fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>) {
273        self.optimizer.add_param_group(params, options);
274    }
275
276    fn state_dict(&self) -> OptimizerResult<OptimizerState> {
277        self.optimizer.state_dict()
278    }
279
280    fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()> {
281        self.optimizer.load_state_dict(state)
282    }
283}
284
285/// Extension trait for adding distributed functionality to any optimizer
286pub trait OptimizerExt: Optimizer + Sized {
287    /// Wrap this optimizer with distributed functionality
288    fn distributed(self, config: DistributedConfig) -> OptimizerResult<DistributedOptimizer<Self>> {
289        DistributedOptimizer::new(self, config)
290    }
291}
292
293impl<O: Optimizer> OptimizerExt for O {}