// tenflowers_neural/distributed/data_parallel.rs
//! DataParallel and DistributedDataParallel model wrappers with gradient sync

use std::sync::Arc;

use parking_lot::RwLock;

use tenflowers_core::{Device, Result, Tensor, TensorError};

use super::types::{
    BackendConfig, CollectiveOp, CollectiveResult, CommunicationBackend, CommunicationGroup,
    CommunicationRuntime, ReductionOp,
};
use crate::Model;

/// Data Parallel model wrapper for single-node multi-GPU training.
///
/// Wraps a single [`Model<f32>`] behind an `Arc<RwLock<..>>` and pairs it
/// with a list of target devices plus a communication runtime used by
/// `sync_gradients` for gradient all-reduce.
pub struct DataParallel {
    /// Base model to replicate across devices (shared, lock-protected).
    pub(crate) base_model: Arc<RwLock<Box<dyn Model<f32>>>>,
    /// Device assignments for each replica; guaranteed non-empty by `new`.
    pub(crate) device_replicas: Vec<Device>,
    /// Communication runtime used for gradient synchronization.
    pub(crate) comm_runtime: Arc<RwLock<CommunicationRuntime>>,
    /// Whether model is in training mode (gradient sync is a no-op when false).
    pub(crate) is_training: bool,
    /// Synchronization mode; settable via `set_sync_mode` but not consulted
    /// by the sync path visible in this file.
    pub(crate) sync_mode: SynchronizationMode,
}
26
/// Distributed Data Parallel model wrapper for multi-node training.
///
/// Each process owns one replica on its local `device` and synchronizes
/// gradients across `process_group` through the communication runtime.
pub struct DistributedDataParallel {
    /// Base model wrapped for distributed training (shared, lock-protected).
    pub(crate) base_model: Arc<RwLock<Box<dyn Model<f32>>>>,
    /// Communication group for this DDP instance (rank, world size, devices).
    pub(crate) process_group: Arc<CommunicationGroup>,
    /// Communication runtime used for broadcast and all-reduce collectives.
    pub(crate) comm_runtime: Arc<RwLock<CommunicationRuntime>>,
    /// Local device for this process; inputs are moved here in `forward`.
    pub(crate) device: Device,
    /// Whether to broadcast parameters from rank 0 on initialization
    /// (initialized to `true` by `new`).
    pub(crate) broadcast_buffers: bool,
    /// Whether model is in training mode (gradient sync is a no-op when false).
    pub(crate) is_training: bool,
    /// Gradient bucket size in bytes for batched communication
    /// (defaults to 25 MB in `new`).
    pub(crate) bucket_size: usize,
    /// DDP-specific configuration flags.
    pub(crate) ddp_config: DDPConfig,
}
46
/// Configuration for Distributed Data Parallel training.
///
/// NOTE(review): these flags are stored on the wrapper but not consulted by
/// the synchronization logic visible in this file — confirm where they take
/// effect.
#[derive(Debug, Clone)]
pub struct DDPConfig {
    /// Find unused parameters to skip in gradient sync.
    pub find_unused_parameters: bool,
    /// Use gradient-as-bucket-view for memory efficiency.
    pub gradient_as_bucket_view: bool,
    /// Enable static computation graph optimization.
    pub static_graph: bool,
    /// Delay all-reduce until backward is complete (defaults to `true`).
    pub delay_all_reduce: bool,
}
59
/// Synchronization modes for `DataParallel`.
///
/// NOTE(review): the mode is stored via `DataParallel::set_sync_mode` but is
/// not read by the sync path shown in this file — confirm intended effect.
#[derive(Debug, Clone, Copy)]
pub enum SynchronizationMode {
    /// Synchronous - wait for all replicas.
    Synchronous,
    /// Asynchronous - don't wait for all replicas.
    Asynchronous,
    /// Bounded staleness - allow at most `max_staleness` stale updates.
    BoundedStaleness { max_staleness: u32 },
}
70
71impl Default for DDPConfig {
72    fn default() -> Self {
73        Self {
74            find_unused_parameters: false,
75            gradient_as_bucket_view: false,
76            static_graph: false,
77            delay_all_reduce: true,
78        }
79    }
80}
81
82impl DataParallel {
83    /// Create new DataParallel model wrapper
84    pub fn new(
85        model: Box<dyn Model<f32>>,
86        devices: Vec<Device>,
87        comm_runtime: Arc<RwLock<CommunicationRuntime>>,
88    ) -> Result<Self> {
89        if devices.is_empty() {
90            return Err(TensorError::invalid_argument_op(
91                "DataParallel::new",
92                "No devices provided",
93            ));
94        }
95
96        #[allow(clippy::arc_with_non_send_sync)]
97        let base_model = Arc::new(RwLock::new(model));
98
99        Self::replicate_parameters(&base_model, &devices)?;
100
101        Ok(Self {
102            base_model,
103            device_replicas: devices,
104            comm_runtime,
105            is_training: true,
106            sync_mode: SynchronizationMode::Synchronous,
107        })
108    }
109
110    /// Replicate model parameters across devices
111    fn replicate_parameters(
112        model: &Arc<RwLock<Box<dyn Model<f32>>>>,
113        devices: &[Device],
114    ) -> Result<()> {
115        let model_read = model.read();
116        let parameters = model_read.parameters();
117
118        for param in parameters {
119            for device in devices {
120                if *param.device() != *device {
121                    param.to(device.clone())?;
122                }
123            }
124        }
125
126        Ok(())
127    }
128
129    /// Perform forward pass with data parallelism
130    pub fn forward_parallel(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
131        if inputs.len() != self.device_replicas.len() {
132            return Err(TensorError::invalid_argument_op(
133                "forward_parallel",
134                &format!(
135                    "Expected {} inputs for {} devices",
136                    self.device_replicas.len(),
137                    inputs.len()
138                ),
139            ));
140        }
141
142        let model = self.base_model.read();
143        let mut outputs = Vec::with_capacity(inputs.len());
144
145        for (input, device) in inputs.iter().zip(&self.device_replicas) {
146            let input_on_device = if *input.device() != *device {
147                input.to(device.clone())?
148            } else {
149                input.clone()
150            };
151
152            let output = model.forward(&input_on_device)?;
153            outputs.push(output);
154        }
155
156        Ok(outputs)
157    }
158
159    /// Synchronize gradients across all device replicas
160    pub fn sync_gradients(&mut self) -> Result<()> {
161        if !self.is_training {
162            return Ok(());
163        }
164
165        let mut model = self.base_model.write();
166        let mut parameters = model.parameters_mut();
167
168        for param in parameters.iter_mut() {
169            if let Some(grad) = param.grad() {
170                let comm_runtime = self.comm_runtime.read();
171
172                let op = CollectiveOp::AllReduce {
173                    reduction_op: ReductionOp::Average,
174                };
175
176                if let Ok(CollectiveResult::Tensor(synced_grad)) =
177                    comm_runtime.collective_op_f32(op, grad, None)
178                {
179                    param.set_grad(Some(synced_grad));
180                }
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Set synchronization mode
188    pub fn set_sync_mode(&mut self, mode: SynchronizationMode) {
189        self.sync_mode = mode;
190    }
191
192    /// Get device replicas
193    pub fn devices(&self) -> &[Device] {
194        &self.device_replicas
195    }
196}
197
198impl Model<f32> for DataParallel {
199    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
200        let inputs: Vec<Tensor<f32>> = self
201            .device_replicas
202            .iter()
203            .map(|device| input.to(device.clone()))
204            .collect::<Result<Vec<_>>>()?;
205
206        let outputs = self.forward_parallel(&inputs)?;
207
208        let primary_device = &self.device_replicas[0];
209        let gathered_outputs: Vec<Tensor<f32>> = outputs
210            .into_iter()
211            .map(|output| output.to(primary_device.clone()))
212            .collect::<Result<Vec<_>>>()?;
213
214        let mut result = gathered_outputs[0].clone();
215        for output in &gathered_outputs[1..] {
216            result = result.add(output)?;
217        }
218
219        let num_devices = gathered_outputs.len() as f32;
220        let divisor = Tensor::from_scalar(num_devices);
221        result.div(&divisor)
222    }
223
224    fn parameters(&self) -> Vec<&Tensor<f32>> {
225        vec![]
226    }
227
228    fn parameters_mut(&mut self) -> Vec<&mut Tensor<f32>> {
229        vec![]
230    }
231
232    fn set_training(&mut self, training: bool) {
233        self.is_training = training;
234        self.base_model.write().set_training(training);
235    }
236
237    fn zero_grad(&mut self) {
238        self.base_model.write().zero_grad();
239    }
240
241    fn as_any(&self) -> &dyn std::any::Any {
242        self
243    }
244}
245
246impl DistributedDataParallel {
247    /// Create new DistributedDataParallel model wrapper
248    pub fn new(
249        model: Box<dyn Model<f32>>,
250        device: Device,
251        process_group: Arc<CommunicationGroup>,
252        comm_runtime: Arc<RwLock<CommunicationRuntime>>,
253        config: DDPConfig,
254    ) -> Result<Self> {
255        #[allow(clippy::arc_with_non_send_sync)]
256        let base_model = Arc::new(RwLock::new(model));
257
258        let mut ddp = Self {
259            base_model,
260            process_group,
261            comm_runtime,
262            device,
263            broadcast_buffers: true,
264            is_training: true,
265            bucket_size: 25 * 1024 * 1024, // 25MB default bucket size
266            ddp_config: config,
267        };
268
269        if ddp.broadcast_buffers {
270            ddp.broadcast_parameters()?;
271        }
272
273        Ok(ddp)
274    }
275
276    /// Broadcast model parameters from rank 0 to all other ranks
277    fn broadcast_parameters(&mut self) -> Result<()> {
278        let model = self.base_model.read();
279        let parameters = model.parameters();
280
281        let comm_runtime = self.comm_runtime.read();
282
283        for param in parameters {
284            let op = CollectiveOp::Broadcast { root_rank: 0 };
285
286            if let Ok(CollectiveResult::Tensor(_synced_param)) =
287                comm_runtime.collective_op_f32(op, param, Some(&self.process_group.group_id))
288            {
289                // Update parameter with broadcasted value
290                // Note: This is a simplified implementation
291            }
292        }
293
294        Ok(())
295    }
296
297    /// Perform gradient synchronization using all-reduce
298    pub fn sync_gradients(&mut self) -> Result<()> {
299        if !self.is_training {
300            return Ok(());
301        }
302
303        let mut model = self.base_model.write();
304        let mut parameters = model.parameters_mut();
305
306        let mut gradient_buckets = self.create_gradient_buckets(&mut parameters)?;
307
308        let comm_runtime = self.comm_runtime.read();
309
310        for bucket in &gradient_buckets {
311            for grad_tensor in bucket {
312                let op = CollectiveOp::AllReduce {
313                    reduction_op: ReductionOp::Average,
314                };
315
316                if let Ok(CollectiveResult::Tensor(_synced_grad)) = comm_runtime.collective_op_f32(
317                    op,
318                    grad_tensor,
319                    Some(&self.process_group.group_id),
320                ) {
321                    // In a complete implementation, we would update the parameter gradients
322                }
323            }
324        }
325
326        Ok(())
327    }
328
329    /// Create gradient buckets for efficient communication
330    fn create_gradient_buckets<'a>(
331        &self,
332        parameters: &'a mut [&'a mut Tensor<f32>],
333    ) -> Result<Vec<Vec<&'a Tensor<f32>>>> {
334        let mut buckets = Vec::new();
335        let mut current_bucket = Vec::new();
336        let mut current_bucket_size = 0;
337
338        for param in parameters {
339            if let Some(grad) = param.grad() {
340                let grad_size = grad.shape().size() * std::mem::size_of::<f32>();
341
342                if current_bucket_size + grad_size > self.bucket_size && !current_bucket.is_empty()
343                {
344                    buckets.push(std::mem::take(&mut current_bucket));
345                    current_bucket_size = 0;
346                }
347
348                current_bucket.push(grad);
349                current_bucket_size += grad_size;
350            }
351        }
352
353        if !current_bucket.is_empty() {
354            buckets.push(current_bucket);
355        }
356
357        Ok(buckets)
358    }
359
360    /// Get process group information
361    pub fn process_group(&self) -> &CommunicationGroup {
362        &self.process_group
363    }
364
365    /// Get local rank within process group
366    pub fn local_rank(&self) -> usize {
367        self.process_group.rank
368    }
369
370    /// Get world size (total number of processes)
371    pub fn world_size(&self) -> usize {
372        self.process_group.world_size
373    }
374
375    /// Set bucket size for gradient communication
376    pub fn set_bucket_size(&mut self, size: usize) {
377        self.bucket_size = size;
378    }
379}
380
381impl Model<f32> for DistributedDataParallel {
382    fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
383        let input_on_device = if *input.device() != self.device {
384            input.to(self.device.clone())?
385        } else {
386            input.clone()
387        };
388
389        self.base_model.read().forward(&input_on_device)
390    }
391
392    fn parameters(&self) -> Vec<&Tensor<f32>> {
393        vec![]
394    }
395
396    fn parameters_mut(&mut self) -> Vec<&mut Tensor<f32>> {
397        vec![]
398    }
399
400    fn set_training(&mut self, training: bool) {
401        self.is_training = training;
402        self.base_model.write().set_training(training);
403    }
404
405    fn zero_grad(&mut self) {
406        self.base_model.write().zero_grad();
407    }
408
409    fn as_any(&self) -> &dyn std::any::Any {
410        self
411    }
412}
413
414/// Utility functions for distributed models
415pub mod utils {
416    use super::super::types::CommunicationBackend;
417    use super::*;
418
419    /// Initialize process group for distributed training
420    pub fn init_process_group(
421        backend: CommunicationBackend,
422        rank: usize,
423        world_size: usize,
424    ) -> Result<(Arc<RwLock<CommunicationRuntime>>, Arc<CommunicationGroup>)> {
425        let mut comm_runtime = CommunicationRuntime::new();
426
427        match backend {
428            CommunicationBackend::Thread => {
429                comm_runtime.register_backend(
430                    CommunicationBackend::Thread,
431                    Box::new(crate::backends::thread::ThreadBackend::new()),
432                );
433            }
434            #[cfg(feature = "nccl")]
435            CommunicationBackend::Nccl => {
436                comm_runtime.register_backend(
437                    CommunicationBackend::Nccl,
438                    Box::new(crate::backends::nccl::NcclBackend::new()),
439                );
440            }
441            _ => {
442                return Err(TensorError::unsupported_operation_simple(format!(
443                    "Backend {backend:?} not supported"
444                )));
445            }
446        }
447
448        let config = BackendConfig::default();
449        comm_runtime.initialize(&config)?;
450
451        let devices = super::super::auto_detect_available_devices();
452        let process_group = Arc::new(CommunicationGroup {
453            group_id: "ddp_main".to_string(),
454            rank,
455            world_size,
456            devices,
457            backend,
458        });
459
460        let runtime = Arc::new(RwLock::new(comm_runtime));
461        runtime.write().create_group((*process_group).clone())?;
462
463        Ok((runtime, process_group))
464    }
465
466    /// Create DataParallel model wrapper with automatic device detection
467    pub fn create_data_parallel(model: Box<dyn Model<f32>>) -> Result<DataParallel> {
468        let devices = super::super::auto_detect_available_devices();
469        let comm_runtime = super::super::utils::init_distributed(0, devices.len(), None)?;
470        let runtime = Arc::new(RwLock::new(comm_runtime));
471
472        DataParallel::new(model, devices, runtime)
473    }
474
475    /// Create DistributedDataParallel model wrapper
476    pub fn create_distributed_data_parallel(
477        model: Box<dyn Model<f32>>,
478        device: Device,
479        backend: CommunicationBackend,
480        rank: usize,
481        world_size: usize,
482    ) -> Result<DistributedDataParallel> {
483        let (comm_runtime, process_group) = init_process_group(backend, rank, world_size)?;
484        let config = DDPConfig::default();
485
486        DistributedDataParallel::new(model, device, process_group, comm_runtime, config)
487    }
488}