// torsh_distributed/ddp.rs
1//! Distributed Data Parallel (DDP) implementation
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5#![allow(clippy::await_holding_lock)]
6use crate::backend::ReduceOp;
7use crate::collectives::all_reduce;
8use crate::{process_group::ProcessGroup, TorshResult};
9use log::info;
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13use tokio::sync::{mpsc, Semaphore};
14use tokio::task::JoinHandle;
15use torsh_core::{error::Result, DeviceType};
16use torsh_nn::{Module, Parameter};
17use torsh_tensor::Tensor;
18
/// Configuration for gradient bucketing
///
/// Buckets group several parameters' gradients so that a single all-reduce
/// covers many parameters, amortizing per-call communication latency.
#[derive(Debug, Clone)]
pub struct BucketConfig {
    /// Maximum size of each bucket in MB; `initialize_buckets` closes a
    /// bucket once adding the next parameter would exceed this cap
    pub max_bucket_size_mb: f32,
    /// Whether to enable gradient bucketing
    pub enabled: bool,
    /// Minimum bucket size to avoid tiny buckets
    /// NOTE(review): not consulted by `initialize_buckets` in this file —
    /// confirm whether it is intended for future packing logic
    pub min_bucket_size_mb: f32,
}
29
30impl Default for BucketConfig {
31    fn default() -> Self {
32        Self {
33            max_bucket_size_mb: 25.0,
34            enabled: true,
35            min_bucket_size_mb: 1.0,
36        }
37    }
38}
39
/// A bucket of gradients for efficient communication
///
/// Tracks only which parameters belong to the bucket and the running byte
/// count; the gradient tensors themselves are gathered at sync time.
#[derive(Debug)]
struct GradientBucket {
    /// Names of the parameters assigned to this bucket
    parameters: Vec<String>,
    /// Total size in bytes (computed assuming f32 elements; see
    /// `initialize_buckets`)
    total_size: usize,
    /// Whether this bucket is ready for synchronization (currently unused)
    _ready: bool,
}
50
51impl GradientBucket {
52    fn new() -> Self {
53        Self {
54            parameters: Vec::new(),
55            total_size: 0,
56            _ready: false,
57        }
58    }
59
60    fn add_parameter(&mut self, name: String, size: usize) {
61        self.parameters.push(name);
62        self.total_size += size;
63    }
64
65    fn size_mb(&self) -> f32 {
66        self.total_size as f32 / (1024.0 * 1024.0)
67    }
68}
69
/// Statistics about gradient synchronization
///
/// Snapshot produced by `get_sync_stats`; all counts reflect the state of the
/// wrapped module's parameters at the time of the call.
#[derive(Debug, Clone)]
pub struct GradientSyncStats {
    /// Total number of parameters that require gradients
    pub total_parameters: usize,
    /// Number of parameters that currently have gradients
    pub parameters_with_grad: usize,
    /// Total size of gradients in MB (assuming f32 elements)
    pub total_gradient_size_mb: f32,
    /// Number of gradient buckets
    pub num_buckets: usize,
    /// World size (number of processes)
    pub world_size: u32,
}
84
/// Information about a gradient bucket
///
/// NOTE(review): no constructor for this type is visible in this chunk —
/// presumably built by an inspection/reporting helper elsewhere in the file.
#[derive(Debug, Clone)]
pub struct BucketInfo {
    /// Bucket index
    pub index: usize,
    /// Size of bucket in MB
    pub size_mb: f32,
    /// Number of parameters in bucket
    pub num_parameters: usize,
    /// Names of parameters in bucket
    pub parameter_names: Vec<String>,
}
97
/// Message sent to gradient synchronization worker
///
/// Produced by `register_gradient_async` and consumed by the background task
/// spawned in `start_gradient_sync_worker`.
#[derive(Debug)]
struct GradientMessage {
    /// Parameter name
    param_name: String,
    /// Gradient tensor to synchronize
    gradient: Tensor,
    /// Bucket index this parameter belongs to (0 if the parameter was never
    /// assigned a bucket)
    bucket_index: usize,
}
108
/// State for tracking unused parameters
///
/// `used_parameters` is cleared at the start of each iteration; the set
/// difference `all_parameters - used_parameters` gives the unused ones.
#[derive(Debug, Default)]
struct UnusedParameterTracker {
    /// All parameters that require gradients
    all_parameters: HashSet<String>,
    /// Parameters that have been used in the current iteration
    used_parameters: HashSet<String>,
    /// Whether unused parameter detection is enabled
    enabled: bool,
    /// Iteration counter, incremented by `start_iteration`
    iteration: u64,
}
121
/// Overlap computation configuration
///
/// Controls the background worker that overlaps gradient communication with
/// backward computation.
#[derive(Debug, Clone)]
pub struct OverlapConfig {
    /// Whether to enable computation/communication overlap
    pub enabled: bool,
    /// Maximum number of pending gradient synchronizations (sizes the
    /// semaphore guarding the background worker)
    pub max_pending_syncs: usize,
    /// Timeout for gradient synchronization (in seconds)
    pub sync_timeout_secs: u64,
    /// Whether to enable unused parameter detection
    pub detect_unused_parameters: bool,
}
134
135impl Default for OverlapConfig {
136    fn default() -> Self {
137        Self {
138            enabled: true,
139            max_pending_syncs: 4,
140            sync_timeout_secs: 30,
141            detect_unused_parameters: true,
142        }
143    }
144}
145
/// Distributed Data Parallel wrapper for models
///
/// Wraps a `Module` and synchronizes its gradients across a process group,
/// optionally bucketing gradients and overlapping communication with the
/// backward pass via a background tokio worker.
pub struct DistributedDataParallel<M: Module> {
    // The wrapped model.
    module: M,
    // Process group used for all collective operations.
    process_group: Arc<ProcessGroup>,
    _device_ids: Vec<usize>,
    _output_device: Option<usize>,
    _broadcast_buffers: bool,
    // Retained copy of the bucket cap passed at construction (unused).
    _bucket_cap_mb: f32,
    bucket_config: BucketConfig,
    // Buckets built once by `initialize_buckets`.
    gradient_buckets: Vec<GradientBucket>,
    /// Mapping from parameter name to bucket index
    param_to_bucket: HashMap<String, usize>,
    /// Overlap computation configuration
    overlap_config: OverlapConfig,
    /// Channel for sending gradients to background worker
    gradient_sender: Option<mpsc::UnboundedSender<GradientMessage>>,
    /// Handle to background gradient synchronization task
    sync_task_handle: Option<JoinHandle<()>>,
    /// Semaphore to limit concurrent gradient synchronizations
    sync_semaphore: Arc<Semaphore>,
    /// Unused parameter tracker
    unused_param_tracker: Arc<Mutex<UnusedParameterTracker>>,
    /// Bucket readiness tracking (bucket_index -> ready gradients count)
    bucket_ready_count: Arc<Mutex<HashMap<usize, usize>>>,
}
171
172impl<M: Module> DistributedDataParallel<M> {
173    /// Create a new DDP wrapper
174    pub fn new(
175        module: M,
176        process_group: Arc<ProcessGroup>,
177        device_ids: Vec<usize>,
178        output_device: Option<usize>,
179        broadcast_buffers: bool,
180        bucket_cap_mb: f32,
181    ) -> TorshResult<Self> {
182        let bucket_config = BucketConfig {
183            max_bucket_size_mb: bucket_cap_mb,
184            enabled: true,
185            min_bucket_size_mb: 1.0,
186        };
187
188        Self::new_with_configs(
189            module,
190            process_group,
191            device_ids,
192            output_device,
193            broadcast_buffers,
194            bucket_config,
195            OverlapConfig::default(),
196        )
197    }
198
199    /// Create a new DDP wrapper with custom bucket configuration
200    pub fn new_with_bucket_config(
201        module: M,
202        process_group: Arc<ProcessGroup>,
203        device_ids: Vec<usize>,
204        output_device: Option<usize>,
205        broadcast_buffers: bool,
206        bucket_config: BucketConfig,
207    ) -> TorshResult<Self> {
208        Self::new_with_configs(
209            module,
210            process_group,
211            device_ids,
212            output_device,
213            broadcast_buffers,
214            bucket_config,
215            OverlapConfig::default(),
216        )
217    }
218
219    /// Create a new DDP wrapper with custom configurations
220    pub fn new_with_configs(
221        module: M,
222        process_group: Arc<ProcessGroup>,
223        device_ids: Vec<usize>,
224        output_device: Option<usize>,
225        broadcast_buffers: bool,
226        bucket_config: BucketConfig,
227        overlap_config: OverlapConfig,
228    ) -> TorshResult<Self> {
229        let sync_semaphore = Arc::new(Semaphore::new(overlap_config.max_pending_syncs));
230        let unused_param_tracker = Arc::new(Mutex::new(UnusedParameterTracker::default()));
231        let bucket_ready_count = Arc::new(Mutex::new(HashMap::new()));
232
233        let mut ddp = Self {
234            module,
235            process_group,
236            _device_ids: device_ids,
237            _output_device: output_device,
238            _broadcast_buffers: broadcast_buffers,
239            _bucket_cap_mb: bucket_config.max_bucket_size_mb,
240            bucket_config,
241            gradient_buckets: Vec::new(),
242            param_to_bucket: HashMap::new(),
243            overlap_config,
244            gradient_sender: None,
245            sync_task_handle: None,
246            sync_semaphore,
247            unused_param_tracker,
248            bucket_ready_count,
249        };
250
251        // Initialize gradient buckets
252        ddp.initialize_buckets()?;
253
254        // Initialize unused parameter tracking if enabled
255        if ddp.overlap_config.detect_unused_parameters {
256            ddp.initialize_unused_parameter_tracking()?;
257        }
258
259        // Start background gradient synchronization worker if overlap is enabled
260        if ddp.overlap_config.enabled {
261            ddp.start_gradient_sync_worker()?;
262        }
263
264        Ok(ddp)
265    }
266
267    /// Initialize gradient buckets based on parameter sizes
268    fn initialize_buckets(&mut self) -> TorshResult<()> {
269        if !self.bucket_config.enabled {
270            return Ok(());
271        }
272
273        let parameters = self.module.named_parameters();
274        let mut current_bucket = GradientBucket::new();
275        let mut bucket_index = 0;
276
277        // Sort parameters by size (largest first for better packing)
278        let mut param_sizes: Vec<(String, usize)> = parameters
279            .iter()
280            .map(|(name, param)| {
281                let tensor = param.tensor();
282                let tensor_guard = tensor.read();
283                let size = tensor_guard.numel() * std::mem::size_of::<f32>(); // Assume f32 for now
284                (name.clone(), size)
285            })
286            .collect();
287
288        param_sizes.sort_by(|a, b| b.1.cmp(&a.1)); // Sort descending by size
289
290        for (param_name, param_size) in param_sizes {
291            // Check if adding this parameter would exceed bucket capacity
292            let new_size_mb = (current_bucket.total_size + param_size) as f32 / (1024.0 * 1024.0);
293
294            if new_size_mb > self.bucket_config.max_bucket_size_mb
295                && !current_bucket.parameters.is_empty()
296            {
297                // Finalize current bucket and start a new one
298                self.gradient_buckets.push(current_bucket);
299                current_bucket = GradientBucket::new();
300                bucket_index += 1;
301            }
302
303            // Add parameter to current bucket
304            current_bucket.add_parameter(param_name.clone(), param_size);
305            self.param_to_bucket.insert(param_name, bucket_index);
306        }
307
308        // Add the last bucket if it has parameters
309        if !current_bucket.parameters.is_empty() {
310            self.gradient_buckets.push(current_bucket);
311        }
312
313        info!(
314            "📦 Initialized {} gradient buckets",
315            self.gradient_buckets.len()
316        );
317        for (i, bucket) in self.gradient_buckets.iter().enumerate() {
318            info!(
319                "  Bucket {}: {:.2} MB, {} parameters",
320                i,
321                bucket.size_mb(),
322                bucket.parameters.len()
323            );
324        }
325
326        Ok(())
327    }
328
329    /// Initialize unused parameter tracking
330    fn initialize_unused_parameter_tracking(&mut self) -> TorshResult<()> {
331        let parameters = self.module.named_parameters();
332        let mut tracker = self
333            .unused_param_tracker
334            .lock()
335            .expect("lock should not be poisoned");
336
337        tracker.enabled = true;
338        tracker.all_parameters.clear();
339        tracker.used_parameters.clear();
340
341        // Add all parameters that require gradients
342        for (name, param) in parameters {
343            let tensor = param.tensor();
344            let tensor_guard = tensor.read();
345            if tensor_guard.requires_grad() {
346                tracker.all_parameters.insert(name);
347            }
348        }
349
350        info!(
351            "🔍 Initialized unused parameter detection for {} parameters",
352            tracker.all_parameters.len()
353        );
354
355        Ok(())
356    }
357
358    /// Start the background gradient synchronization worker
359    fn start_gradient_sync_worker(&mut self) -> TorshResult<()> {
360        let (sender, mut receiver) = mpsc::unbounded_channel::<GradientMessage>();
361
362        let process_group = Arc::clone(&self.process_group);
363        let sync_semaphore = Arc::clone(&self.sync_semaphore);
364        let _bucket_ready_count = Arc::clone(&self.bucket_ready_count);
365        let _gradient_buckets_len = self.gradient_buckets.len();
366        let timeout_duration = Duration::from_secs(self.overlap_config.sync_timeout_secs);
367
368        // Clone bucket info for the worker
369        let bucket_param_counts: HashMap<usize, usize> = self
370            .gradient_buckets
371            .iter()
372            .enumerate()
373            .map(|(i, bucket)| (i, bucket.parameters.len()))
374            .collect();
375
376        let handle = tokio::spawn(async move {
377            let mut pending_gradients: HashMap<usize, Vec<(String, Tensor)>> = HashMap::new();
378
379            while let Some(grad_msg) = receiver.recv().await {
380                // Acquire semaphore permit to limit concurrent operations
381                let _permit =
382                    match tokio::time::timeout(timeout_duration, sync_semaphore.acquire()).await {
383                        Ok(Ok(permit)) => permit,
384                        Ok(Err(_)) => {
385                            info!("  Gradient sync semaphore closed, stopping worker");
386                            break;
387                        }
388                        Err(_) => {
389                            info!(
390                                "  Gradient sync timeout, dropping gradient for {}",
391                                grad_msg.param_name
392                            );
393                            continue;
394                        }
395                    };
396
397                let bucket_index = grad_msg.bucket_index;
398
399                // Add gradient to pending list for this bucket
400                pending_gradients
401                    .entry(bucket_index)
402                    .or_insert_with(Vec::new)
403                    .push((grad_msg.param_name, grad_msg.gradient));
404
405                // Check if bucket is ready (all parameters have gradients)
406                let expected_count = bucket_param_counts.get(&bucket_index).copied().unwrap_or(0);
407                let current_count = pending_gradients
408                    .get(&bucket_index)
409                    .map(|v| v.len())
410                    .unwrap_or(0);
411
412                if current_count >= expected_count && expected_count > 0 {
413                    // Bucket is ready - process all gradients in this bucket
414                    if let Some(bucket_gradients) = pending_gradients.remove(&bucket_index) {
415                        // Process bucket asynchronously with improved synchronization
416                        let pg = Arc::clone(&process_group);
417                        tokio::spawn(async move {
418                            match Self::sync_bucket_gradients(bucket_gradients, &pg).await {
419                                Ok(synchronized_gradients) => {
420                                    info!(
421                                        " Successfully synchronized bucket {} with {} gradients",
422                                        bucket_index,
423                                        synchronized_gradients.len()
424                                    );
425                                    // Note: In this async worker context, we cannot directly set gradients back to parameters
426                                    // This would need to be handled by the main thread or through a callback mechanism
427                                }
428                                Err(e) => {
429                                    info!("  Failed to sync bucket {}: {}", bucket_index, e);
430                                }
431                            }
432                        });
433                    }
434                }
435            }
436
437            info!(" Gradient synchronization worker stopped");
438        });
439
440        self.gradient_sender = Some(sender);
441        self.sync_task_handle = Some(handle);
442
443        info!(" Started background gradient synchronization worker");
444        Ok(())
445    }
446
447    /// Synchronize a bucket of gradients with efficient flattening
448    async fn sync_bucket_gradients(
449        gradients: Vec<(String, Tensor)>,
450        process_group: &ProcessGroup,
451    ) -> TorshResult<Vec<(String, Tensor)>> {
452        let start_time = Instant::now();
453
454        if gradients.is_empty() {
455            return Ok(Vec::new());
456        }
457
458        // Improved implementation with efficient bucket flattening and synchronization
459        if gradients.len() > 1 {
460            // Multiple gradients - use efficient bucket flattening
461            Self::sync_bucket_gradients_flattened(gradients, process_group).await
462        } else {
463            // Single gradient - direct synchronization
464            Self::sync_single_gradient(gradients, process_group).await
465        }
466        .inspect(|_result| {
467            let elapsed = start_time.elapsed();
468            if elapsed > Duration::from_millis(100) {
469                info!(
470                    "⏱️  Bucket sync took {:.2}ms",
471                    elapsed.as_secs_f32() * 1000.0
472                );
473            }
474        })
475    }
476
477    /// Synchronize gradients using flattening for efficiency
478    async fn sync_bucket_gradients_flattened(
479        gradients: Vec<(String, Tensor)>,
480        process_group: &ProcessGroup,
481    ) -> TorshResult<Vec<(String, Tensor)>> {
482        // Step 1: Flatten all gradients into a single tensor for efficient communication
483        let mut gradient_shapes = Vec::new();
484        let mut gradient_sizes = Vec::new();
485        let mut flattened_data = Vec::new();
486
487        for (param_name, grad) in &gradients {
488            let shape = grad.shape();
489            let numel = grad.numel();
490            gradient_shapes.push((param_name.clone(), shape.dims().to_vec()));
491            gradient_sizes.push(numel);
492
493            // Flatten the gradient and extract its data
494            let flattened_grad = grad.flatten()?;
495            let grad_data = flattened_grad.data()?;
496            flattened_data.extend_from_slice(&grad_data);
497        }
498
499        // Step 2: Create a single flattened tensor containing all gradients
500        let total_size = flattened_data.len();
501        let mut flattened_tensor =
502            Tensor::from_data(flattened_data, vec![total_size], gradients[0].1.device())?;
503
504        // Step 3: Perform a single all-reduce operation on the flattened tensor
505        all_reduce(&mut flattened_tensor, ReduceOp::Sum, process_group).await?;
506
507        // Average by world size
508        let world_size = process_group.world_size() as f32;
509        flattened_tensor = flattened_tensor.div_scalar(world_size)?;
510
511        // Step 4: Unflatten and distribute back to individual gradients
512        let flattened_data = flattened_tensor.data()?;
513        let mut result_gradients = Vec::new();
514        let mut current_offset = 0;
515
516        for ((param_name, original_shape), size) in
517            gradient_shapes.iter().zip(gradient_sizes.iter())
518        {
519            // Extract data for this gradient
520            let grad_data = &flattened_data[current_offset..current_offset + size];
521
522            // Reconstruct the gradient tensor with original shape
523            let reconstructed_grad = Tensor::from_data(
524                grad_data.to_vec(),
525                original_shape.clone(),
526                gradients[0].1.device(),
527            )?;
528
529            result_gradients.push((param_name.clone(), reconstructed_grad));
530            current_offset += size;
531        }
532
533        Ok(result_gradients)
534    }
535
536    /// Synchronize a single gradient directly
537    async fn sync_single_gradient(
538        mut gradients: Vec<(String, Tensor)>,
539        process_group: &ProcessGroup,
540    ) -> TorshResult<Vec<(String, Tensor)>> {
541        if let Some((param_name, mut grad)) = gradients.pop() {
542            all_reduce(&mut grad, ReduceOp::Sum, process_group).await?;
543
544            // Average by world size
545            let world_size = process_group.world_size() as f32;
546            grad = grad.div_scalar(world_size)?;
547
548            Ok(vec![(param_name, grad)])
549        } else {
550            Ok(Vec::new())
551        }
552    }
553
554    /// Synchronize gradients across all processes
555    pub async fn sync_gradients(&mut self) -> TorshResult<()> {
556        if self.bucket_config.enabled && !self.gradient_buckets.is_empty() {
557            // Use bucketed gradient synchronization for better performance
558            self.sync_gradients_bucketed().await
559        } else {
560            // Fall back to naive synchronization
561            self.sync_gradients_naive().await
562        }
563    }
564
565    /// Synchronize gradients using naive approach (one parameter at a time)
566    async fn sync_gradients_naive(&mut self) -> TorshResult<()> {
567        #[allow(clippy::await_holding_lock)]
568        let parameters = self.module.parameters();
569
570        for (_name, param) in parameters {
571            let tensor = param.tensor();
572            let tensor_guard = tensor.read();
573
574            // Check if this parameter requires gradients and has a gradient
575            if tensor_guard.requires_grad() {
576                if let Some(mut grad) = tensor_guard.grad() {
577                    // Perform all-reduce on the gradient
578                    all_reduce(&mut grad, ReduceOp::Sum, &self.process_group).await?;
579
580                    // Average by world size (divide by number of processes)
581                    let world_size = self.process_group.world_size() as f32;
582                    grad = grad.div_scalar(world_size)?;
583
584                    // Set the synchronized gradient back to the parameter
585                    tensor_guard.set_grad(Some(grad));
586                }
587            }
588        }
589
590        Ok(())
591    }
592
593    /// Synchronize gradients using bucketing for better communication efficiency
594    async fn sync_gradients_bucketed(&mut self) -> TorshResult<()> {
595        let parameters = self.module.named_parameters();
596
597        // Process each bucket
598        for bucket in &self.gradient_buckets {
599            let mut bucket_gradients = Vec::new();
600            let mut bucket_params = Vec::new();
601
602            // Collect gradients for this bucket
603            for param_name in &bucket.parameters {
604                if let Some(param) = parameters.get(param_name) {
605                    let tensor = param.tensor();
606                    let tensor_guard = tensor.read();
607
608                    if tensor_guard.requires_grad() {
609                        if let Some(grad) = tensor_guard.grad() {
610                            bucket_gradients.push(grad);
611                            bucket_params.push(param_name.clone());
612                        }
613                    }
614                }
615            }
616
617            if !bucket_gradients.is_empty() {
618                // Sophisticated implementation with efficient bucket flattening:
619                // 1. Flatten all gradients in the bucket into a single tensor
620                // 2. Perform a single all-reduce operation on the flattened tensor
621                // 3. Unflatten and distribute back to individual gradients
622
623                // Prepare gradients with parameter names for synchronization
624                let gradients_with_names: Vec<(String, Tensor)> = bucket_gradients
625                    .into_iter()
626                    .zip(bucket_params.iter())
627                    .map(|(grad, param_name)| (param_name.clone(), grad))
628                    .collect();
629
630                // Synchronize the entire bucket efficiently
631                match Self::sync_bucket_gradients_flattened(
632                    gradients_with_names,
633                    &self.process_group,
634                )
635                .await
636                {
637                    Ok(synchronized_gradients) => {
638                        // Set the synchronized gradients back to their parameters
639                        for (param_name, synchronized_grad) in synchronized_gradients {
640                            if let Some(param) = parameters.get(&param_name) {
641                                let tensor = param.tensor();
642                                let tensor_guard = tensor.read();
643                                tensor_guard.set_grad(Some(synchronized_grad));
644                            }
645                        }
646                    }
647                    Err(e) => {
648                        info!("  Failed to sync bucket gradients efficiently, falling back to individual sync: {}", e);
649
650                        #[allow(clippy::await_holding_lock)]
651                        // Fallback to individual gradient synchronization
652                        for param_name in &bucket_params {
653                            if let Some(param) = parameters.get(param_name) {
654                                let tensor = param.tensor();
655                                let tensor_guard = tensor.read();
656
657                                if let Some(mut grad) = tensor_guard.grad() {
658                                    all_reduce(&mut grad, ReduceOp::Sum, &self.process_group)
659                                        .await?;
660                                    let world_size = self.process_group.world_size() as f32;
661                                    grad = grad.div_scalar(world_size)?;
662                                    tensor_guard.set_grad(Some(grad));
663                                }
664                            }
665                        }
666                    }
667                }
668            }
669        }
670
671        Ok(())
672    }
673
    /// Register gradient synchronization hooks
    /// This should be called during the backward pass to automatically sync gradients
    ///
    /// NOTE(review): this is currently a no-op placeholder — it registers
    /// nothing and always returns `Ok(())`. Gradient synchronization must be
    /// triggered manually via `sync_gradients` (or `register_gradient_async`
    /// in overlap mode).
    pub fn register_gradient_hooks(&self) -> TorshResult<()> {
        // In a complete implementation, this would register hooks on each parameter
        // that automatically call all_reduce when gradients are computed

        // For now, we'll use a simpler approach where sync_gradients is called manually
        // after backward() but before optimizer.step()
        Ok(())
    }
684
685    /// Register a gradient for asynchronous synchronization (overlap mode)
686    /// This should be called when a gradient becomes available during backward pass
687    pub fn register_gradient_async(&self, param_name: &str, gradient: Tensor) -> TorshResult<()> {
688        if !self.overlap_config.enabled {
689            return Ok(()); // Overlap not enabled, skip
690        }
691
692        // Mark parameter as used for unused parameter detection
693        if self.overlap_config.detect_unused_parameters {
694            let mut tracker = self
695                .unused_param_tracker
696                .lock()
697                .expect("lock should not be poisoned");
698            if tracker.enabled {
699                tracker.used_parameters.insert(param_name.to_string());
700            }
701        }
702
703        // Get bucket index for this parameter
704        let bucket_index = self.param_to_bucket.get(param_name).copied().unwrap_or(0);
705
706        // Send gradient to background worker
707        if let Some(sender) = &self.gradient_sender {
708            let message = GradientMessage {
709                param_name: param_name.to_string(),
710                gradient,
711                bucket_index,
712            };
713
714            if let Err(e) = sender.send(message) {
715                info!("  Failed to send gradient for {}: {}", param_name, e);
716            }
717        }
718
719        Ok(())
720    }
721
722    /// Check for unused parameters and issue warnings
723    pub fn check_unused_parameters(&self) -> TorshResult<Vec<String>> {
724        if !self.overlap_config.detect_unused_parameters {
725            return Ok(Vec::new());
726        }
727
728        let tracker = self
729            .unused_param_tracker
730            .lock()
731            .expect("lock should not be poisoned");
732        if !tracker.enabled {
733            return Ok(Vec::new());
734        }
735
736        let unused: Vec<String> = tracker
737            .all_parameters
738            .difference(&tracker.used_parameters)
739            .cloned()
740            .collect();
741
742        if !unused.is_empty() {
743            info!(
744                "  Found {} unused parameters in iteration {}:",
745                unused.len(),
746                tracker.iteration
747            );
748            for param in &unused {
749                info!("    - {}", param);
750            }
751        }
752
753        Ok(unused)
754    }
755
756    /// Start a new iteration (reset unused parameter tracking)
757    pub fn start_iteration(&self) -> TorshResult<()> {
758        if self.overlap_config.detect_unused_parameters {
759            let mut tracker = self
760                .unused_param_tracker
761                .lock()
762                .expect("lock should not be poisoned");
763            if tracker.enabled {
764                tracker.used_parameters.clear();
765                tracker.iteration += 1;
766            }
767        }
768        Ok(())
769    }
770
771    /// Get overlap computation statistics
772    pub fn get_overlap_stats(&self) -> HashMap<String, serde_json::Value> {
773        let mut stats = HashMap::new();
774
775        stats.insert(
776            "overlap_enabled".to_string(),
777            serde_json::Value::Bool(self.overlap_config.enabled),
778        );
779        stats.insert(
780            "max_pending_syncs".to_string(),
781            serde_json::Value::Number(serde_json::Number::from(
782                self.overlap_config.max_pending_syncs,
783            )),
784        );
785        stats.insert(
786            "sync_timeout_secs".to_string(),
787            serde_json::Value::Number(serde_json::Number::from(
788                self.overlap_config.sync_timeout_secs,
789            )),
790        );
791        stats.insert(
792            "unused_param_detection".to_string(),
793            serde_json::Value::Bool(self.overlap_config.detect_unused_parameters),
794        );
795
796        if let Ok(tracker) = self.unused_param_tracker.lock() {
797            stats.insert(
798                "total_params".to_string(),
799                serde_json::Value::Number(serde_json::Number::from(tracker.all_parameters.len())),
800            );
801            stats.insert(
802                "used_params".to_string(),
803                serde_json::Value::Number(serde_json::Number::from(tracker.used_parameters.len())),
804            );
805            stats.insert(
806                "current_iteration".to_string(),
807                serde_json::Value::Number(serde_json::Number::from(tracker.iteration)),
808            );
809        }
810
811        // Semaphore availability
812        let available_permits = self.sync_semaphore.available_permits();
813        stats.insert(
814            "available_sync_permits".to_string(),
815            serde_json::Value::Number(serde_json::Number::from(available_permits)),
816        );
817
818        stats
819    }
820
821    /// Check if any parameters have gradients
822    pub fn has_gradients(&self) -> bool {
823        let parameters = self.module.parameters();
824
825        for (_name, param) in parameters {
826            let tensor = param.tensor();
827            let tensor_guard = tensor.read();
828
829            if tensor_guard.requires_grad() && tensor_guard.has_grad() {
830                return true;
831            }
832        }
833
834        false
835    }
836
837    /// Zero all gradients
838    pub fn zero_grad(&mut self) -> TorshResult<()> {
839        let parameters = self.module.parameters();
840
841        for (_name, param) in parameters {
842            let tensor = param.tensor();
843            let tensor_guard = tensor.read();
844
845            if tensor_guard.requires_grad() {
846                tensor_guard.set_grad(None);
847            }
848        }
849
850        Ok(())
851    }
852
853    /// Get gradient synchronization statistics
854    pub fn get_sync_stats(&self) -> GradientSyncStats {
855        let parameters = self.module.named_parameters();
856        let mut total_parameters = 0;
857        let mut parameters_with_grad = 0;
858        let mut total_gradient_size = 0;
859
860        for (_name, param) in parameters {
861            let tensor = param.tensor();
862            let tensor_guard = tensor.read();
863
864            if tensor_guard.requires_grad() {
865                total_parameters += 1;
866
867                if tensor_guard.has_grad() {
868                    parameters_with_grad += 1;
869                    total_gradient_size += tensor_guard.numel() * std::mem::size_of::<f32>();
870                }
871            }
872        }
873
874        GradientSyncStats {
875            total_parameters,
876            parameters_with_grad,
877            total_gradient_size_mb: total_gradient_size as f32 / (1024.0 * 1024.0),
878            num_buckets: self.gradient_buckets.len(),
879            world_size: self.process_group.world_size(),
880        }
881    }
882
883    /// Enable/disable gradient bucketing at runtime
884    pub fn set_bucketing_enabled(&mut self, enabled: bool) -> TorshResult<()> {
885        self.bucket_config.enabled = enabled;
886
887        if enabled && self.gradient_buckets.is_empty() {
888            // Re-initialize buckets if they were disabled
889            self.initialize_buckets()?;
890        }
891
892        Ok(())
893    }
894
895    /// Get bucket information for debugging
896    pub fn get_bucket_info(&self) -> Vec<BucketInfo> {
897        self.gradient_buckets
898            .iter()
899            .enumerate()
900            .map(|(i, bucket)| BucketInfo {
901                index: i,
902                size_mb: bucket.size_mb(),
903                num_parameters: bucket.parameters.len(),
904                parameter_names: bucket.parameters.clone(),
905            })
906            .collect()
907    }
908
    /// Perform a gradient consistency check across all processes
    /// This is useful for debugging distributed training issues
    ///
    /// NOTE: currently a stub — it performs no communication and always
    /// reports consistency. Callers must not rely on this for correctness.
    pub async fn check_gradient_consistency(&self) -> TorshResult<bool> {
        // In a complete implementation, this would:
        // 1. Compute checksums of gradients on each process
        // 2. Use all_gather to collect checksums from all processes
        // 3. Compare checksums to detect inconsistencies
        // 4. Report which parameters have mismatched gradients

        // For now, just return true as a placeholder
        Ok(true)
    }
921}
922
// `Module` implementation: DDP is a transparent wrapper, so every trait
// method delegates directly to the wrapped module. Gradient synchronization
// is handled separately and is not part of this trait surface.
impl<M: Module> Module for DistributedDataParallel<M> {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Forward through underlying module
        self.module.forward(input)
    }

    fn parameters(&self) -> HashMap<String, Parameter> {
        self.module.parameters()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.module.named_parameters()
    }

    fn training(&self) -> bool {
        self.module.training()
    }

    fn train(&mut self) {
        self.module.train()
    }

    fn eval(&mut self) {
        self.module.eval()
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        // Moves only the inner module; distributed state (process group,
        // buckets) is unaffected by a device change.
        self.module.to_device(device)
    }
}