Skip to main content

torsh_distributed/
debugging.rs

1//! Debugging utilities for distributed training systems
2//!
3//! This module provides comprehensive debugging tools including operation tracing,
4//! state inspection, diagnostic tools, and automated troubleshooting capabilities.
5
6use crate::metrics::get_global_metrics_collector;
7use crate::profiling::get_global_profiler;
8use crate::{TorshDistributedError, TorshResult};
9use log::info;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15/// Logging levels for debugging
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
17pub enum LogLevel {
18    Trace,
19    Debug,
20    Info,
21    Warn,
22    Error,
23    Critical,
24}
25
26impl std::fmt::Display for LogLevel {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            LogLevel::Trace => write!(f, "TRACE"),
30            LogLevel::Debug => write!(f, "DEBUG"),
31            LogLevel::Info => write!(f, "INFO"),
32            LogLevel::Warn => write!(f, "WARN"),
33            LogLevel::Error => write!(f, "ERROR"),
34            LogLevel::Critical => write!(f, "CRITICAL"),
35        }
36    }
37}
38
39/// Debug event for tracking system operations
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct DebugEvent {
42    /// Unique event identifier
43    pub event_id: u64,
44    /// Event timestamp
45    pub timestamp: SystemTime,
46    /// Log level
47    pub level: LogLevel,
48    /// Source module/component
49    pub source: String,
50    /// Rank that generated the event
51    pub rank: u32,
52    /// Event message
53    pub message: String,
54    /// Additional context data
55    pub context: HashMap<String, String>,
56    /// Call stack trace (if available)
57    pub call_stack: Vec<String>,
58    /// Duration (for operation events)
59    pub duration: Option<Duration>,
60}
61
62impl DebugEvent {
63    /// Create a new debug event
64    pub fn new(level: LogLevel, source: String, rank: u32, message: String) -> Self {
65        Self {
66            event_id: 0, // Will be set by the debugger
67            timestamp: SystemTime::now(),
68            level,
69            source,
70            rank,
71            message,
72            context: HashMap::new(),
73            call_stack: Vec::new(),
74            duration: None,
75        }
76    }
77
78    /// Add context information
79    pub fn with_context(mut self, key: String, value: String) -> Self {
80        self.context.insert(key, value);
81        self
82    }
83
84    /// Add call stack
85    pub fn with_call_stack(mut self, stack: Vec<String>) -> Self {
86        self.call_stack = stack;
87        self
88    }
89
90    /// Set duration
91    pub fn with_duration(mut self, duration: Duration) -> Self {
92        self.duration = Some(duration);
93        self
94    }
95
96    /// Format as a human-readable string
97    pub fn format(&self) -> String {
98        let timestamp_str = self
99            .timestamp
100            .duration_since(UNIX_EPOCH)
101            .map(|d| format!("{:.3}", d.as_secs_f64()))
102            .unwrap_or_else(|_| "unknown".to_string());
103
104        let duration_str = self
105            .duration
106            .map(|d| format!(" [{}ms]", d.as_millis()))
107            .unwrap_or_default();
108
109        format!(
110            "[{}] {} [{}:{}] {}{}\n",
111            timestamp_str, self.level, self.source, self.rank, self.message, duration_str
112        )
113    }
114}
115
116/// System state snapshot for debugging
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct SystemStateSnapshot {
119    /// Timestamp when snapshot was taken
120    pub timestamp: SystemTime,
121    /// Process group information
122    pub process_group: ProcessGroupState,
123    /// Communication state
124    pub communication: CommunicationState,
125    /// Resource utilization
126    pub resources: ResourceState,
127    /// Active operations
128    pub active_operations: Vec<ActiveOperation>,
129    /// Recent errors
130    pub recent_errors: Vec<DebugEvent>,
131}
132
133/// Process group state information
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct ProcessGroupState {
136    /// Current rank
137    pub rank: u32,
138    /// World size
139    pub world_size: u32,
140    /// Backend type
141    pub backend: String,
142    /// Process group health status
143    pub health_status: String,
144    /// Active process count
145    pub active_processes: u32,
146    /// Failed processes
147    pub failed_processes: Vec<u32>,
148}
149
150/// Communication state information
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct CommunicationState {
153    /// Pending operations count
154    pub pending_operations: u32,
155    /// Failed operations count
156    pub failed_operations: u32,
157    /// Average latency (ms)
158    pub avg_latency_ms: f64,
159    /// Current bandwidth utilization (MB/s)
160    pub bandwidth_mbps: f64,
161    /// Communication queue length
162    pub queue_length: u32,
163    /// Last successful communication timestamp
164    pub last_success: Option<SystemTime>,
165}
166
167/// Resource state information
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct ResourceState {
170    /// CPU usage percentage
171    pub cpu_usage_pct: f64,
172    /// Memory usage percentage
173    pub memory_usage_pct: f64,
174    /// GPU usage percentage (if available)
175    pub gpu_usage_pct: Option<f64>,
176    /// Network I/O (bytes/sec)
177    pub network_io_bps: u64,
178    /// Disk I/O (bytes/sec)
179    pub disk_io_bps: u64,
180    /// Memory pressure indicator
181    pub memory_pressure: String,
182}
183
184/// Active operation information
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ActiveOperation {
187    /// Operation type
188    pub operation_type: String,
189    /// Start time
190    pub start_time: SystemTime,
191    /// Expected duration
192    pub expected_duration: Option<Duration>,
193    /// Progress percentage (0-100)
194    pub progress_pct: f64,
195    /// Rank(s) involved
196    pub ranks: Vec<u32>,
197    /// Operation status
198    pub status: String,
199}
200
201/// Diagnostic check result
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct DiagnosticResult {
204    /// Check name
205    pub check_name: String,
206    /// Whether the check passed
207    pub passed: bool,
208    /// Severity if check failed
209    pub severity: LogLevel,
210    /// Description of the issue (if any)
211    pub description: String,
212    /// Suggested remediation
213    pub remediation: Vec<String>,
214    /// Supporting data
215    pub data: HashMap<String, String>,
216}
217
218/// Configuration for debugging utilities
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct DebugConfig {
221    /// Whether debugging is enabled
222    pub enabled: bool,
223    /// Minimum log level to capture
224    pub min_log_level: LogLevel,
225    /// Maximum number of events to keep in memory
226    pub max_events: usize,
227    /// Whether to capture call stacks
228    pub capture_call_stacks: bool,
229    /// Whether to enable real-time monitoring
230    pub real_time_monitoring: bool,
231    /// Snapshot interval (seconds)
232    pub snapshot_interval_secs: u64,
233    /// Auto-diagnosis interval (seconds)
234    pub auto_diagnosis_interval_secs: u64,
235}
236
237impl Default for DebugConfig {
238    fn default() -> Self {
239        Self {
240            enabled: true,
241            min_log_level: LogLevel::Info,
242            max_events: 1000,
243            capture_call_stacks: false, // Expensive operation
244            real_time_monitoring: true,
245            snapshot_interval_secs: 30,
246            auto_diagnosis_interval_secs: 60,
247        }
248    }
249}
250
251/// Comprehensive debugging system for distributed training
252pub struct DistributedDebugger {
253    /// Configuration
254    config: RwLock<DebugConfig>,
255    /// Event counter for unique IDs
256    event_counter: Mutex<u64>,
257    /// Circular buffer of debug events
258    events: Mutex<VecDeque<DebugEvent>>,
259    /// System state snapshots
260    snapshots: Mutex<VecDeque<SystemStateSnapshot>>,
261    /// Active operation tracking
262    active_operations: Mutex<HashMap<String, ActiveOperation>>,
263    /// Diagnostic results history
264    diagnostic_history: Mutex<Vec<DiagnosticResult>>,
265    /// Performance statistics
266    stats: Mutex<DebuggerStats>,
267}
268
269/// Statistics for the debugger itself
270#[derive(Debug, Default, Serialize, Deserialize)]
271struct DebuggerStats {
272    events_captured: u64,
273    snapshots_taken: u64,
274    diagnostics_run: u64,
275    errors_detected: u64,
276}
277
278impl DistributedDebugger {
279    /// Create a new distributed debugger
280    pub fn new() -> Self {
281        Self::with_config(DebugConfig::default())
282    }
283
284    /// Create a new distributed debugger with custom configuration
285    pub fn with_config(config: DebugConfig) -> Self {
286        Self {
287            config: RwLock::new(config),
288            event_counter: Mutex::new(0),
289            events: Mutex::new(VecDeque::new()),
290            snapshots: Mutex::new(VecDeque::new()),
291            active_operations: Mutex::new(HashMap::new()),
292            diagnostic_history: Mutex::new(Vec::new()),
293            stats: Mutex::new(DebuggerStats::default()),
294        }
295    }
296
297    /// Log a debug event
298    pub fn log_event(&self, mut event: DebugEvent) -> TorshResult<()> {
299        let config = self
300            .config
301            .read()
302            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
303
304        if !config.enabled || event.level < config.min_log_level {
305            return Ok(());
306        }
307
308        // Assign unique event ID
309        {
310            let mut counter = self
311                .event_counter
312                .lock()
313                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
314            *counter += 1;
315            event.event_id = *counter;
316        }
317
318        // Capture call stack if enabled
319        if config.capture_call_stacks {
320            // In a real implementation, you would capture the actual call stack
321            event.call_stack = vec!["main".to_string(), "debug_function".to_string()];
322        }
323
324        // Store event
325        {
326            let mut events = self
327                .events
328                .lock()
329                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
330            events.push_back(event.clone());
331
332            // Maintain circular buffer
333            if events.len() > config.max_events {
334                events.pop_front();
335            }
336        }
337
338        // Update statistics
339        {
340            let mut stats = self
341                .stats
342                .lock()
343                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
344            stats.events_captured += 1;
345            if event.level >= LogLevel::Error {
346                stats.errors_detected += 1;
347            }
348        }
349
350        // Print to console if critical
351        if event.level >= LogLevel::Critical {
352            info!("CRITICAL: {}", event.format());
353        }
354
355        Ok(())
356    }
357
358    /// Take a system state snapshot
359    pub fn take_snapshot(&self) -> TorshResult<SystemStateSnapshot> {
360        let snapshot = SystemStateSnapshot {
361            timestamp: SystemTime::now(),
362            process_group: self.capture_process_group_state()?,
363            communication: self.capture_communication_state()?,
364            resources: self.capture_resource_state()?,
365            active_operations: self.get_active_operations(),
366            recent_errors: self.get_recent_errors(10)?,
367        };
368
369        // Store snapshot
370        {
371            let mut snapshots = self
372                .snapshots
373                .lock()
374                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
375            snapshots.push_back(snapshot.clone());
376
377            // Keep only last 20 snapshots
378            if snapshots.len() > 20 {
379                snapshots.pop_front();
380            }
381        }
382
383        // Update statistics
384        {
385            let mut stats = self
386                .stats
387                .lock()
388                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
389            stats.snapshots_taken += 1;
390        }
391
392        Ok(snapshot)
393    }
394
395    /// Capture process group state
396    fn capture_process_group_state(&self) -> TorshResult<ProcessGroupState> {
397        // In a real implementation, this would query the actual process group
398        Ok(ProcessGroupState {
399            rank: 0,                     // Would get from actual process group
400            world_size: 1,               // Would get from actual process group
401            backend: "Mock".to_string(), // Would get from actual process group
402            health_status: "Healthy".to_string(),
403            active_processes: 1,
404            failed_processes: Vec::new(),
405        })
406    }
407
408    /// Capture communication state
409    fn capture_communication_state(&self) -> TorshResult<CommunicationState> {
410        let metrics_collector = get_global_metrics_collector();
411
412        if let Ok(comm_history) = metrics_collector.get_communication_history() {
413            if let Some(latest) = comm_history.last() {
414                return Ok(CommunicationState {
415                    pending_operations: 0, // Would track from actual communication system
416                    failed_operations: latest.value.failed_operations as u32,
417                    avg_latency_ms: latest.value.avg_latency_ms,
418                    bandwidth_mbps: latest.value.avg_bandwidth_mbps,
419                    queue_length: 0, // Would get from actual queue
420                    last_success: Some(latest.timestamp),
421                });
422            }
423        }
424
425        Ok(CommunicationState {
426            pending_operations: 0,
427            failed_operations: 0,
428            avg_latency_ms: 0.0,
429            bandwidth_mbps: 0.0,
430            queue_length: 0,
431            last_success: None,
432        })
433    }
434
435    /// Capture resource state
436    fn capture_resource_state(&self) -> TorshResult<ResourceState> {
437        let metrics_collector = get_global_metrics_collector();
438
439        if let Ok(system_history) = metrics_collector.get_system_history() {
440            if let Some(latest) = system_history.last() {
441                return Ok(ResourceState {
442                    cpu_usage_pct: latest.value.cpu_usage_pct,
443                    memory_usage_pct: latest.value.memory_usage_pct,
444                    gpu_usage_pct: latest.value.gpu_usage_pct,
445                    network_io_bps: latest.value.network_bytes_rx + latest.value.network_bytes_tx,
446                    disk_io_bps: latest.value.disk_bytes_read + latest.value.disk_bytes_write,
447                    memory_pressure: if latest.value.memory_usage_pct > 90.0 {
448                        "High"
449                    } else {
450                        "Normal"
451                    }
452                    .to_string(),
453                });
454            }
455        }
456
457        Ok(ResourceState {
458            cpu_usage_pct: 0.0,
459            memory_usage_pct: 0.0,
460            gpu_usage_pct: None,
461            network_io_bps: 0,
462            disk_io_bps: 0,
463            memory_pressure: "Unknown".to_string(),
464        })
465    }
466
467    /// Get active operations
468    fn get_active_operations(&self) -> Vec<ActiveOperation> {
469        self.active_operations
470            .lock()
471            .map(|ops| ops.values().cloned().collect())
472            .unwrap_or_default()
473    }
474
475    /// Track an active operation
476    pub fn start_operation(&self, operation_type: String, ranks: Vec<u32>) -> TorshResult<String> {
477        let operation_id = format!(
478            "{}_{}",
479            operation_type,
480            SystemTime::now()
481                .duration_since(UNIX_EPOCH)
482                .unwrap_or_default()
483                .as_nanos()
484        );
485
486        let operation = ActiveOperation {
487            operation_type: operation_type.clone(),
488            start_time: SystemTime::now(),
489            expected_duration: None,
490            progress_pct: 0.0,
491            ranks,
492            status: "Running".to_string(),
493        };
494
495        {
496            let mut active_ops = self
497                .active_operations
498                .lock()
499                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
500            active_ops.insert(operation_id.clone(), operation);
501        }
502
503        self.log_event(
504            DebugEvent::new(
505                LogLevel::Debug,
506                "DistributedDebugger".to_string(),
507                0, // Would get actual rank
508                format!("Started operation: {}", operation_type),
509            )
510            .with_context("operation_id".to_string(), operation_id.clone()),
511        )?;
512
513        Ok(operation_id)
514    }
515
516    /// Update operation progress
517    pub fn update_operation_progress(
518        &self,
519        operation_id: &str,
520        progress_pct: f64,
521    ) -> TorshResult<()> {
522        let mut active_ops = self
523            .active_operations
524            .lock()
525            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
526
527        if let Some(operation) = active_ops.get_mut(operation_id) {
528            operation.progress_pct = progress_pct;
529        }
530
531        Ok(())
532    }
533
534    /// Complete an operation
535    pub fn complete_operation(&self, operation_id: &str, success: bool) -> TorshResult<()> {
536        let mut active_ops = self
537            .active_operations
538            .lock()
539            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
540
541        if let Some(operation) = active_ops.remove(operation_id) {
542            let duration = SystemTime::now()
543                .duration_since(operation.start_time)
544                .unwrap_or_default();
545
546            self.log_event(
547                DebugEvent::new(
548                    if success {
549                        LogLevel::Debug
550                    } else {
551                        LogLevel::Error
552                    },
553                    "DistributedDebugger".to_string(),
554                    0, // Would get actual rank
555                    format!(
556                        "Completed operation: {} ({})",
557                        operation.operation_type,
558                        if success { "SUCCESS" } else { "FAILED" }
559                    ),
560                )
561                .with_context("operation_id".to_string(), operation_id.to_string())
562                .with_duration(duration),
563            )?;
564        }
565
566        Ok(())
567    }
568
569    /// Get recent error events
570    fn get_recent_errors(&self, count: usize) -> TorshResult<Vec<DebugEvent>> {
571        let events = self
572            .events
573            .lock()
574            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
575
576        Ok(events
577            .iter()
578            .filter(|e| e.level >= LogLevel::Error)
579            .rev()
580            .take(count)
581            .cloned()
582            .collect())
583    }
584
585    /// Run comprehensive system diagnostics
586    pub fn run_diagnostics(&self) -> TorshResult<Vec<DiagnosticResult>> {
587        let mut results = vec![
588            // Communication health check
589            self.check_communication_health()?,
590        ];
591
592        // Resource utilization check
593        results.push(self.check_resource_utilization()?);
594
595        // Bottleneck detection check
596        results.push(self.check_bottlenecks()?);
597
598        // Error rate check
599        results.push(self.check_error_rate()?);
600
601        // Process group health check
602        results.push(self.check_process_group_health()?);
603
604        // Store results
605        {
606            let mut diagnostic_history = self
607                .diagnostic_history
608                .lock()
609                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
610            diagnostic_history.extend(results.clone());
611
612            // Keep only last 100 diagnostic results
613            let current_len = diagnostic_history.len();
614            if current_len > 100 {
615                diagnostic_history.drain(0..current_len - 100);
616            }
617        }
618
619        // Update statistics
620        {
621            let mut stats = self
622                .stats
623                .lock()
624                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
625            stats.diagnostics_run += 1;
626        }
627
628        Ok(results)
629    }
630
631    /// Check communication health
632    fn check_communication_health(&self) -> TorshResult<DiagnosticResult> {
633        let profiler = get_global_profiler();
634        let all_stats = profiler.get_all_operation_stats()?;
635
636        let total_failed = 0u64;
637        let mut total_operations = 0u64;
638        let mut max_latency: f64 = 0.0;
639
640        for stats in all_stats.values() {
641            total_operations += stats.count;
642            max_latency = max_latency.max(stats.max_latency.as_secs_f64() * 1000.0);
643            // Note: We don't have a direct failed count in the current profiler,
644            // so this is a placeholder
645        }
646
647        let failure_rate = if total_operations > 0 {
648            total_failed as f64 / total_operations as f64
649        } else {
650            0.0
651        };
652        let passed = failure_rate < 0.01 && max_latency < 1000.0; // Less than 1% failures and < 1s max latency
653
654        Ok(DiagnosticResult {
655            check_name: "Communication Health".to_string(),
656            passed,
657            severity: if !passed {
658                LogLevel::Error
659            } else {
660                LogLevel::Info
661            },
662            description: if passed {
663                "Communication system is healthy".to_string()
664            } else {
665                format!(
666                    "Communication issues detected: {:.2}% failure rate, {:.1}ms max latency",
667                    failure_rate * 100.0,
668                    max_latency
669                )
670            },
671            remediation: if !passed {
672                vec![
673                    "Check network connectivity".to_string(),
674                    "Verify NCCL/MPI configuration".to_string(),
675                    "Monitor bandwidth utilization".to_string(),
676                ]
677            } else {
678                vec![]
679            },
680            data: {
681                let mut data = HashMap::new();
682                data.insert("failure_rate".to_string(), failure_rate.to_string());
683                data.insert("max_latency_ms".to_string(), max_latency.to_string());
684                data.insert("total_operations".to_string(), total_operations.to_string());
685                data
686            },
687        })
688    }
689
690    /// Check resource utilization
691    fn check_resource_utilization(&self) -> TorshResult<DiagnosticResult> {
692        let state = self.capture_resource_state()?;
693
694        let high_cpu = state.cpu_usage_pct > 95.0;
695        let high_memory = state.memory_usage_pct > 90.0;
696        let high_gpu = state.gpu_usage_pct.is_some_and(|gpu| gpu > 98.0);
697
698        let passed = !high_cpu && !high_memory && !high_gpu;
699
700        let mut issues = Vec::new();
701        if high_cpu {
702            issues.push(format!("High CPU usage: {:.1}%", state.cpu_usage_pct));
703        }
704        if high_memory {
705            issues.push(format!("High memory usage: {:.1}%", state.memory_usage_pct));
706        }
707        if high_gpu {
708            issues.push(format!(
709                "High GPU usage: {:.1}%",
710                state.gpu_usage_pct.unwrap_or(0.0)
711            ));
712        }
713
714        Ok(DiagnosticResult {
715            check_name: "Resource Utilization".to_string(),
716            passed,
717            severity: if !passed {
718                LogLevel::Warn
719            } else {
720                LogLevel::Info
721            },
722            description: if passed {
723                "Resource utilization is normal".to_string()
724            } else {
725                format!("Resource pressure detected: {}", issues.join(", "))
726            },
727            remediation: if !passed {
728                vec![
729                    "Scale to more resources if available".to_string(),
730                    "Optimize memory usage with gradient checkpointing".to_string(),
731                    "Consider model sharding or parallelism".to_string(),
732                ]
733            } else {
734                vec![]
735            },
736            data: {
737                let mut data = HashMap::new();
738                data.insert("cpu_usage_pct".to_string(), state.cpu_usage_pct.to_string());
739                data.insert(
740                    "memory_usage_pct".to_string(),
741                    state.memory_usage_pct.to_string(),
742                );
743                if let Some(gpu_usage) = state.gpu_usage_pct {
744                    data.insert("gpu_usage_pct".to_string(), gpu_usage.to_string());
745                }
746                data
747            },
748        })
749    }
750
751    /// Check for bottlenecks
752    fn check_bottlenecks(&self) -> TorshResult<DiagnosticResult> {
753        crate::bottleneck_detection::with_global_bottleneck_detector(|detector| {
754            let recent_bottlenecks = detector
755                .get_bottleneck_history()
756                .iter()
757                .filter(|b| b.detected_at > SystemTime::now() - Duration::from_secs(300)) // Last 5 minutes
758                .collect::<Vec<_>>();
759
760            let critical_bottlenecks = recent_bottlenecks
761                .iter()
762                .filter(|b| {
763                    matches!(
764                        b.severity,
765                        crate::bottleneck_detection::BottleneckSeverity::Critical
766                            | crate::bottleneck_detection::BottleneckSeverity::High
767                    )
768                })
769                .count();
770
771            let passed = critical_bottlenecks == 0;
772
773            Ok(DiagnosticResult {
774                check_name: "Bottleneck Detection".to_string(),
775                passed,
776                severity: if critical_bottlenecks > 0 {
777                    LogLevel::Error
778                } else {
779                    LogLevel::Info
780                },
781                description: if passed {
782                    "No critical bottlenecks detected".to_string()
783                } else {
784                    format!(
785                        "{} critical bottlenecks detected in the last 5 minutes",
786                        critical_bottlenecks
787                    )
788                },
789                remediation: if !passed {
790                    vec![
791                        "Review bottleneck analysis for specific recommendations".to_string(),
792                        "Consider load balancing adjustments".to_string(),
793                        "Optimize communication patterns".to_string(),
794                    ]
795                } else {
796                    vec![]
797                },
798                data: {
799                    let mut data = HashMap::new();
800                    data.insert(
801                        "recent_bottlenecks".to_string(),
802                        recent_bottlenecks.len().to_string(),
803                    );
804                    data.insert(
805                        "critical_bottlenecks".to_string(),
806                        critical_bottlenecks.to_string(),
807                    );
808                    data
809                },
810            })
811        })
812    }
813
814    /// Check error rate
815    fn check_error_rate(&self) -> TorshResult<DiagnosticResult> {
816        let events = self
817            .events
818            .lock()
819            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
820
821        let recent_events = events
822            .iter()
823            .filter(|e| e.timestamp > SystemTime::now() - Duration::from_secs(300)) // Last 5 minutes
824            .collect::<Vec<_>>();
825
826        let error_events = recent_events
827            .iter()
828            .filter(|e| e.level >= LogLevel::Error)
829            .count();
830
831        let error_rate = if !recent_events.is_empty() {
832            error_events as f64 / recent_events.len() as f64
833        } else {
834            0.0
835        };
836
837        let passed = error_rate < 0.05; // Less than 5% error rate
838
839        Ok(DiagnosticResult {
840            check_name: "Error Rate".to_string(),
841            passed,
842            severity: if !passed {
843                LogLevel::Error
844            } else {
845                LogLevel::Info
846            },
847            description: if passed {
848                "Error rate is within normal limits".to_string()
849            } else {
850                format!(
851                    "High error rate detected: {:.1}% ({} errors in {} events)",
852                    error_rate * 100.0,
853                    error_events,
854                    recent_events.len()
855                )
856            },
857            remediation: if !passed {
858                vec![
859                    "Review recent error messages for patterns".to_string(),
860                    "Check system logs for underlying issues".to_string(),
861                    "Verify configuration and environment setup".to_string(),
862                ]
863            } else {
864                vec![]
865            },
866            data: {
867                let mut data = HashMap::new();
868                data.insert("error_rate".to_string(), error_rate.to_string());
869                data.insert("error_count".to_string(), error_events.to_string());
870                data.insert("total_events".to_string(), recent_events.len().to_string());
871                data
872            },
873        })
874    }
875
876    /// Check process group health
877    fn check_process_group_health(&self) -> TorshResult<DiagnosticResult> {
878        let state = self.capture_process_group_state()?;
879
880        let passed = state.failed_processes.is_empty() && state.health_status == "Healthy";
881
882        Ok(DiagnosticResult {
883            check_name: "Process Group Health".to_string(),
884            passed,
885            severity: if !passed {
886                LogLevel::Critical
887            } else {
888                LogLevel::Info
889            },
890            description: if passed {
891                format!(
892                    "Process group is healthy ({}/{} processes active)",
893                    state.active_processes, state.world_size
894                )
895            } else {
896                format!(
897                    "Process group issues: {} failed processes, status: {}",
898                    state.failed_processes.len(),
899                    state.health_status
900                )
901            },
902            remediation: if !passed {
903                vec![
904                    "Restart failed processes if possible".to_string(),
905                    "Check network connectivity between nodes".to_string(),
906                    "Verify resource availability on all nodes".to_string(),
907                ]
908            } else {
909                vec![]
910            },
911            data: {
912                let mut data = HashMap::new();
913                data.insert("world_size".to_string(), state.world_size.to_string());
914                data.insert(
915                    "active_processes".to_string(),
916                    state.active_processes.to_string(),
917                );
918                data.insert(
919                    "failed_processes".to_string(),
920                    state.failed_processes.len().to_string(),
921                );
922                data.insert("health_status".to_string(), state.health_status);
923                data
924            },
925        })
926    }
927
928    /// Generate comprehensive debug report
929    pub fn generate_debug_report(&self) -> TorshResult<String> {
930        let mut report = String::new();
931        report.push_str("=== Distributed Training Debug Report ===\n\n");
932
933        // System state
934        if let Ok(snapshot) = self.take_snapshot() {
935            report.push_str("=== Current System State ===\n");
936            report.push_str(&format!("Timestamp: {:?}\n", snapshot.timestamp));
937            report.push_str(&format!(
938                "Process Group: Rank {}/{}, Backend: {}, Status: {}\n",
939                snapshot.process_group.rank,
940                snapshot.process_group.world_size,
941                snapshot.process_group.backend,
942                snapshot.process_group.health_status
943            ));
944            report.push_str(&format!(
945                "Resources: CPU {:.1}%, Memory {:.1}%",
946                snapshot.resources.cpu_usage_pct, snapshot.resources.memory_usage_pct
947            ));
948            if let Some(gpu) = snapshot.resources.gpu_usage_pct {
949                report.push_str(&format!(", GPU {:.1}%", gpu));
950            }
951            report.push('\n');
952            report.push_str(&format!(
953                "Communication: {:.1}ms avg latency, {:.1} MB/s bandwidth\n",
954                snapshot.communication.avg_latency_ms, snapshot.communication.bandwidth_mbps
955            ));
956            report.push_str(&format!(
957                "Active Operations: {}\n\n",
958                snapshot.active_operations.len()
959            ));
960        }
961
962        // Diagnostic results
963        if let Ok(diagnostics) = self.run_diagnostics() {
964            report.push_str("=== Diagnostic Results ===\n");
965            for diagnostic in &diagnostics {
966                let status = if diagnostic.passed { "PASS" } else { "FAIL" };
967                report.push_str(&format!(
968                    "[{}] {}: {}\n",
969                    status, diagnostic.check_name, diagnostic.description
970                ));
971
972                if !diagnostic.remediation.is_empty() {
973                    report.push_str("  Recommended Actions:\n");
974                    for action in &diagnostic.remediation {
975                        report.push_str(&format!("  - {}\n", action));
976                    }
977                }
978            }
979            report.push('\n');
980        }
981
982        // Recent errors
983        if let Ok(errors) = self.get_recent_errors(5) {
984            if !errors.is_empty() {
985                report.push_str("=== Recent Errors ===\n");
986                for error in &errors {
987                    report.push_str(&error.format());
988                }
989                report.push('\n');
990            }
991        }
992
993        // Statistics
994        if let Ok(stats) = self.stats.lock() {
995            report.push_str("=== Debugger Statistics ===\n");
996            report.push_str(&format!("Events Captured: {}\n", stats.events_captured));
997            report.push_str(&format!("Snapshots Taken: {}\n", stats.snapshots_taken));
998            report.push_str(&format!("Diagnostics Run: {}\n", stats.diagnostics_run));
999            report.push_str(&format!("Errors Detected: {}\n", stats.errors_detected));
1000        }
1001
1002        Ok(report)
1003    }
1004
1005    /// Export debug data to JSON
1006    pub fn export_debug_data(&self) -> TorshResult<String> {
1007        #[derive(Serialize)]
1008        struct DebugExport {
1009            config: DebugConfig,
1010            events: Vec<DebugEvent>,
1011            snapshots: Vec<SystemStateSnapshot>,
1012            diagnostic_history: Vec<DiagnosticResult>,
1013            statistics: Option<DebuggerStats>,
1014        }
1015
1016        let config = self
1017            .config
1018            .read()
1019            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1020            .clone();
1021        let events = self
1022            .events
1023            .lock()
1024            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1025            .iter()
1026            .cloned()
1027            .collect();
1028        let snapshots = self
1029            .snapshots
1030            .lock()
1031            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1032            .iter()
1033            .cloned()
1034            .collect();
1035        let diagnostic_history = self
1036            .diagnostic_history
1037            .lock()
1038            .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1039            .clone();
1040        let statistics = self.stats.lock().ok().map(|s| DebuggerStats {
1041            events_captured: s.events_captured,
1042            snapshots_taken: s.snapshots_taken,
1043            diagnostics_run: s.diagnostics_run,
1044            errors_detected: s.errors_detected,
1045        });
1046
1047        let export = DebugExport {
1048            config,
1049            events,
1050            snapshots,
1051            diagnostic_history,
1052            statistics,
1053        };
1054
1055        serde_json::to_string_pretty(&export).map_err(|e| {
1056            TorshDistributedError::backend_error(
1057                "debugging",
1058                format!("JSON serialization failed: {}", e),
1059            )
1060        })
1061    }
1062
1063    /// Clear all debug data
1064    pub fn clear(&self) -> TorshResult<()> {
1065        {
1066            let mut events = self
1067                .events
1068                .lock()
1069                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1070            events.clear();
1071        }
1072
1073        {
1074            let mut snapshots = self
1075                .snapshots
1076                .lock()
1077                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1078            snapshots.clear();
1079        }
1080
1081        {
1082            let mut active_ops = self
1083                .active_operations
1084                .lock()
1085                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1086            active_ops.clear();
1087        }
1088
1089        {
1090            let mut diagnostic_history = self
1091                .diagnostic_history
1092                .lock()
1093                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1094            diagnostic_history.clear();
1095        }
1096
1097        {
1098            let mut stats = self
1099                .stats
1100                .lock()
1101                .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1102            *stats = DebuggerStats::default();
1103        }
1104
1105        Ok(())
1106    }
1107}
1108
1109impl Default for DistributedDebugger {
1110    fn default() -> Self {
1111        Self::new()
1112    }
1113}
1114
1115/// Global debugger instance
1116static GLOBAL_DEBUGGER: std::sync::OnceLock<Arc<DistributedDebugger>> = std::sync::OnceLock::new();
1117
1118/// Get the global debugger instance
1119pub fn get_global_debugger() -> &'static Arc<DistributedDebugger> {
1120    GLOBAL_DEBUGGER.get_or_init(|| Arc::new(DistributedDebugger::new()))
1121}
1122
1123/// Initialize the global debugger with custom configuration
1124pub fn init_global_debugger(config: DebugConfig) -> TorshResult<()> {
1125    let debugger = Arc::new(DistributedDebugger::with_config(config));
1126    GLOBAL_DEBUGGER.set(debugger).map_err(|_| {
1127        TorshDistributedError::backend_error("debugging", "Global debugger already initialized")
1128    })?;
1129    Ok(())
1130}
1131
/// Log a debug event through the global debugger.
///
/// Two forms are accepted (a trailing comma is allowed in both):
///
/// - `debug_log!(level, source, rank, msg)`
/// - `debug_log!(level, source, rank, msg, key => value, ...)`
///
/// `source`, `msg`, and every `key` / `value` are converted with
/// `to_string()`. The result of `log_event` is deliberately ignored, so
/// logging never fails the caller.
///
/// The expansion is a single block expression of type `()`. This makes the
/// macro usable in expression position (for example as a `match` arm) as well
/// as in statement position, and keeps its temporaries scoped to the call.
#[macro_export]
macro_rules! debug_log {
    ($level:expr, $source:expr, $rank:expr, $msg:expr $(,)?) => {{
        let debugger = $crate::debugging::get_global_debugger();
        let event = $crate::debugging::DebugEvent::new(
            $level,
            $source.to_string(),
            $rank,
            $msg.to_string(),
        );
        let _ = debugger.log_event(event);
    }};
    ($level:expr, $source:expr, $rank:expr, $msg:expr, $($key:expr => $value:expr),+ $(,)?) => {{
        let debugger = $crate::debugging::get_global_debugger();
        let mut event = $crate::debugging::DebugEvent::new(
            $level,
            $source.to_string(),
            $rank,
            $msg.to_string(),
        );
        // Attach each `key => value` pair in the order written.
        $(
            event = event.with_context($key.to_string(), $value.to_string());
        )+
        let _ = debugger.log_event(event);
    }};
}
1149
/// Run a block of code as a tracked operation on the global debugger.
///
/// Usage: `debug_trace_operation!(op_type, ranks, { ... })`.
///
/// The operation is registered with `start_operation` before the block runs
/// and closed with `complete_operation` afterwards; the macro evaluates to
/// the block's value. If `start_operation` fails, the default operation id is
/// used instead, and the result of `complete_operation` is ignored.
///
/// NOTE: completion is always reported as successful (`true`), whatever the
/// block produced. If the block leaves the enclosing function early (`return`,
/// `?`), the completion call is not reached.
#[macro_export]
macro_rules! debug_trace_operation {
    ($op_type:expr, $ranks:expr, $code:block) => {{
        let tracer = $crate::debugging::get_global_debugger();
        let operation_id = tracer
            .start_operation($op_type.to_string(), $ranks)
            .unwrap_or_default();
        let outcome = $code;
        // Success is assumed here; the block's value is not inspected.
        let _ = tracer.complete_operation(&operation_id, true);
        outcome
    }};
}
1160
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods on `DebugEvent` populate the corresponding fields.
    #[test]
    fn test_debug_event_creation() {
        let elapsed = Duration::from_millis(100);
        let event = DebugEvent::new(
            LogLevel::Info,
            "test_module".to_string(),
            0,
            "Test message".to_string(),
        )
        .with_context("key".to_string(), "value".to_string())
        .with_duration(elapsed);

        assert_eq!(event.level, LogLevel::Info);
        assert_eq!(event.source, "test_module");
        assert_eq!(event.message, "Test message");
        assert_eq!(event.context.get("key").map(String::as_str), Some("value"));
        assert_eq!(event.duration, Some(elapsed));
    }

    /// A fresh debugger has no operations in flight.
    #[test]
    fn test_debugger_creation() {
        let debugger = DistributedDebugger::new();
        let nothing_active = debugger.get_active_operations().is_empty();
        assert!(nothing_active);
    }

    /// A logged event ends up in the event buffer.
    #[test]
    fn test_event_logging() {
        let debugger = DistributedDebugger::new();
        debugger
            .log_event(DebugEvent::new(
                LogLevel::Info,
                "test".to_string(),
                0,
                "Test event".to_string(),
            ))
            .unwrap();

        let logged = debugger.events.lock().expect("lock should not be poisoned");
        assert_eq!(logged.len(), 1);
        assert_eq!(logged[0].message, "Test event");
    }

    /// An operation is active between start and completion, and gone afterwards.
    #[test]
    fn test_operation_tracking() {
        let debugger = DistributedDebugger::new();

        let op_id = debugger
            .start_operation("test_op".to_string(), vec![0, 1])
            .unwrap();
        let active_after_start = debugger.get_active_operations().len();
        assert!(active_after_start == 1);

        debugger.update_operation_progress(&op_id, 50.0).unwrap();
        debugger.complete_operation(&op_id, true).unwrap();

        let nothing_active = debugger.get_active_operations().is_empty();
        assert!(nothing_active);
    }

    /// A snapshot of a fresh debugger reports the mock backend and no errors.
    #[test]
    fn test_snapshot_taking() {
        let snapshot = DistributedDebugger::new().take_snapshot().unwrap();

        assert_eq!(snapshot.process_group.backend, "Mock");
        assert!(snapshot.recent_errors.is_empty());
    }

    /// Diagnostics include the communication and resource checks.
    #[test]
    fn test_diagnostics() {
        let debugger = DistributedDebugger::new();
        let results = debugger.run_diagnostics().unwrap();
        let has_check = |name: &str| results.iter().any(|r| r.check_name == name);

        assert!(!results.is_empty());
        assert!(has_check("Communication Health"));
        assert!(has_check("Resource Utilization"));
    }

    /// The text report contains its title and the main section headers.
    #[test]
    fn test_debug_report_generation() {
        let debugger = DistributedDebugger::new();
        let report = debugger.generate_debug_report().unwrap();

        for expected in [
            "Distributed Training Debug Report",
            "Current System State",
            "Diagnostic Results",
        ] {
            assert!(report.contains(expected));
        }
    }

    /// The JSON export contains logged events and the top-level keys.
    #[test]
    fn test_json_export() {
        let debugger = DistributedDebugger::new();
        debugger
            .log_event(DebugEvent::new(
                LogLevel::Info,
                "test".to_string(),
                0,
                "Export test".to_string(),
            ))
            .unwrap();

        let json = debugger.export_debug_data().unwrap();
        for expected in ["Export test", "events", "config"] {
            assert!(json.contains(expected));
        }
    }

    /// Log levels are strictly ordered from `Trace` up to `Critical`.
    #[test]
    fn test_log_level_ordering() {
        let ascending = [
            LogLevel::Trace,
            LogLevel::Debug,
            LogLevel::Info,
            LogLevel::Warn,
            LogLevel::Error,
            LogLevel::Critical,
        ];

        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }
}