1use crate::metrics::get_global_metrics_collector;
7use crate::profiling::get_global_profiler;
8use crate::{TorshDistributedError, TorshResult};
9use log::info;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
/// Severity of a debug event, ordered from least (`Trace`) to most severe
/// (`Critical`). The derived `Ord` makes threshold comparisons such as
/// `event.level >= LogLevel::Error` valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum LogLevel {
    Trace,
    Debug,
    Info,
    Warn,
    Error,
    Critical,
}
25
26impl std::fmt::Display for LogLevel {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 LogLevel::Trace => write!(f, "TRACE"),
30 LogLevel::Debug => write!(f, "DEBUG"),
31 LogLevel::Info => write!(f, "INFO"),
32 LogLevel::Warn => write!(f, "WARN"),
33 LogLevel::Error => write!(f, "ERROR"),
34 LogLevel::Critical => write!(f, "CRITICAL"),
35 }
36 }
37}
38
/// A single structured debug event captured by the `DistributedDebugger`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugEvent {
    /// Monotonically increasing id assigned by the debugger (0 until logged).
    pub event_id: u64,
    /// Wall-clock time at which the event was created.
    pub timestamp: SystemTime,
    /// Severity of the event.
    pub level: LogLevel,
    /// Component or module that emitted the event.
    pub source: String,
    /// Distributed rank that emitted the event.
    pub rank: u32,
    /// Human-readable message.
    pub message: String,
    /// Additional structured key/value context.
    pub context: HashMap<String, String>,
    /// Captured call stack; currently filled with placeholder frames by
    /// `log_event` when stack capture is enabled.
    pub call_stack: Vec<String>,
    /// Optional duration of the operation this event describes.
    pub duration: Option<Duration>,
}
61
62impl DebugEvent {
63 pub fn new(level: LogLevel, source: String, rank: u32, message: String) -> Self {
65 Self {
66 event_id: 0, timestamp: SystemTime::now(),
68 level,
69 source,
70 rank,
71 message,
72 context: HashMap::new(),
73 call_stack: Vec::new(),
74 duration: None,
75 }
76 }
77
78 pub fn with_context(mut self, key: String, value: String) -> Self {
80 self.context.insert(key, value);
81 self
82 }
83
84 pub fn with_call_stack(mut self, stack: Vec<String>) -> Self {
86 self.call_stack = stack;
87 self
88 }
89
90 pub fn with_duration(mut self, duration: Duration) -> Self {
92 self.duration = Some(duration);
93 self
94 }
95
96 pub fn format(&self) -> String {
98 let timestamp_str = self
99 .timestamp
100 .duration_since(UNIX_EPOCH)
101 .map(|d| format!("{:.3}", d.as_secs_f64()))
102 .unwrap_or_else(|_| "unknown".to_string());
103
104 let duration_str = self
105 .duration
106 .map(|d| format!(" [{}ms]", d.as_millis()))
107 .unwrap_or_default();
108
109 format!(
110 "[{}] {} [{}:{}] {}{}\n",
111 timestamp_str, self.level, self.source, self.rank, self.message, duration_str
112 )
113 }
114}
115
/// Point-in-time snapshot of the distributed system, as produced by
/// `DistributedDebugger::take_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemStateSnapshot {
    /// When the snapshot was taken.
    pub timestamp: SystemTime,
    /// State of the process group.
    pub process_group: ProcessGroupState,
    /// Communication-layer metrics.
    pub communication: CommunicationState,
    /// Host resource utilization.
    pub resources: ResourceState,
    /// Operations in flight at snapshot time.
    pub active_operations: Vec<ActiveOperation>,
    /// Most recent error-level events (up to 10).
    pub recent_errors: Vec<DebugEvent>,
}
132
/// Health summary of the distributed process group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessGroupState {
    /// This process's rank within the group.
    pub rank: u32,
    /// Total number of processes in the group.
    pub world_size: u32,
    /// Communication backend name (currently always "Mock" — see
    /// `capture_process_group_state`).
    pub backend: String,
    /// Free-form health status; "Healthy" is treated as the good state.
    pub health_status: String,
    /// Number of processes currently responsive.
    pub active_processes: u32,
    /// Ranks of processes that have failed.
    pub failed_processes: Vec<u32>,
}
149
/// Communication-layer health metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationState {
    /// Operations started but not yet completed.
    pub pending_operations: u32,
    /// Operations that have failed.
    pub failed_operations: u32,
    /// Average operation latency in milliseconds.
    pub avg_latency_ms: f64,
    /// Average bandwidth in MB/s.
    pub bandwidth_mbps: f64,
    /// Length of the communication queue.
    pub queue_length: u32,
    /// Timestamp of the most recent successful operation, if any.
    pub last_success: Option<SystemTime>,
}
166
/// Host resource utilization summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceState {
    /// CPU utilization, 0–100.
    pub cpu_usage_pct: f64,
    /// Memory utilization, 0–100.
    pub memory_usage_pct: f64,
    /// GPU utilization, 0–100, when a GPU is present.
    pub gpu_usage_pct: Option<f64>,
    /// Combined network receive + transmit rate, bytes per second.
    pub network_io_bps: u64,
    /// Combined disk read + write rate, bytes per second.
    pub disk_io_bps: u64,
    /// Qualitative memory pressure: "High" (>90% used), "Normal", or
    /// "Unknown" when no metrics are available.
    pub memory_pressure: String,
}
183
/// A distributed operation currently being tracked by the debugger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActiveOperation {
    /// Free-form operation label (e.g. a collective's name).
    pub operation_type: String,
    /// When the operation started.
    pub start_time: SystemTime,
    /// Expected duration, if known (never set in this file).
    pub expected_duration: Option<Duration>,
    /// Completion percentage, 0.0–100.0, updated via
    /// `update_operation_progress`.
    pub progress_pct: f64,
    /// Ranks participating in the operation.
    pub ranks: Vec<u32>,
    /// Free-form status string; set to "Running" on start.
    pub status: String,
}
200
/// Outcome of a single diagnostic health check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticResult {
    /// Name of the check (e.g. "Communication Health").
    pub check_name: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Severity assigned to the outcome (`Info` when passed).
    pub severity: LogLevel,
    /// Human-readable summary of the outcome.
    pub description: String,
    /// Suggested remediation steps; empty when the check passed.
    pub remediation: Vec<String>,
    /// Raw key/value data backing the verdict.
    pub data: HashMap<String, String>,
}
217
/// Tunables controlling the `DistributedDebugger`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugConfig {
    /// Master switch; when false, `log_event` drops all events.
    pub enabled: bool,
    /// Events below this severity are discarded.
    pub min_log_level: LogLevel,
    /// Capacity of the in-memory event ring buffer.
    pub max_events: usize,
    /// When true, a (currently placeholder) call stack is attached to events.
    pub capture_call_stacks: bool,
    /// Enables real-time monitoring.
    /// NOTE(review): not referenced in this file; presumably consumed by a
    /// monitoring loop elsewhere — confirm.
    pub real_time_monitoring: bool,
    /// Interval between automatic snapshots, in seconds (not used here).
    pub snapshot_interval_secs: u64,
    /// Interval between automatic diagnostic runs, in seconds (not used here).
    pub auto_diagnosis_interval_secs: u64,
}
236
237impl Default for DebugConfig {
238 fn default() -> Self {
239 Self {
240 enabled: true,
241 min_log_level: LogLevel::Info,
242 max_events: 1000,
243 capture_call_stacks: false, real_time_monitoring: true,
245 snapshot_interval_secs: 30,
246 auto_diagnosis_interval_secs: 60,
247 }
248 }
249}
250
/// Central debugging facility for distributed training: collects events,
/// takes periodic state snapshots, tracks in-flight operations and runs
/// health diagnostics. All interior state is lock-protected, so a shared
/// reference (e.g. behind `Arc`) can be used from multiple threads.
pub struct DistributedDebugger {
    /// Runtime-adjustable configuration.
    config: RwLock<DebugConfig>,
    /// Source of unique, monotonically increasing event ids.
    event_counter: Mutex<u64>,
    /// Ring buffer of captured events (bounded by `config.max_events`).
    events: Mutex<VecDeque<DebugEvent>>,
    /// Ring buffer of recent snapshots (bounded to 20 in `take_snapshot`).
    snapshots: Mutex<VecDeque<SystemStateSnapshot>>,
    /// In-flight operations keyed by generated operation id.
    active_operations: Mutex<HashMap<String, ActiveOperation>>,
    /// Bounded history of diagnostic results (last 100).
    diagnostic_history: Mutex<Vec<DiagnosticResult>>,
    /// Aggregate activity counters.
    stats: Mutex<DebuggerStats>,
}
268
269#[derive(Debug, Default, Serialize, Deserialize)]
271struct DebuggerStats {
272 events_captured: u64,
273 snapshots_taken: u64,
274 diagnostics_run: u64,
275 errors_detected: u64,
276}
277
278impl DistributedDebugger {
279 pub fn new() -> Self {
281 Self::with_config(DebugConfig::default())
282 }
283
284 pub fn with_config(config: DebugConfig) -> Self {
286 Self {
287 config: RwLock::new(config),
288 event_counter: Mutex::new(0),
289 events: Mutex::new(VecDeque::new()),
290 snapshots: Mutex::new(VecDeque::new()),
291 active_operations: Mutex::new(HashMap::new()),
292 diagnostic_history: Mutex::new(Vec::new()),
293 stats: Mutex::new(DebuggerStats::default()),
294 }
295 }
296
297 pub fn log_event(&self, mut event: DebugEvent) -> TorshResult<()> {
299 let config = self
300 .config
301 .read()
302 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
303
304 if !config.enabled || event.level < config.min_log_level {
305 return Ok(());
306 }
307
308 {
310 let mut counter = self
311 .event_counter
312 .lock()
313 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
314 *counter += 1;
315 event.event_id = *counter;
316 }
317
318 if config.capture_call_stacks {
320 event.call_stack = vec!["main".to_string(), "debug_function".to_string()];
322 }
323
324 {
326 let mut events = self
327 .events
328 .lock()
329 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
330 events.push_back(event.clone());
331
332 if events.len() > config.max_events {
334 events.pop_front();
335 }
336 }
337
338 {
340 let mut stats = self
341 .stats
342 .lock()
343 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
344 stats.events_captured += 1;
345 if event.level >= LogLevel::Error {
346 stats.errors_detected += 1;
347 }
348 }
349
350 if event.level >= LogLevel::Critical {
352 info!("CRITICAL: {}", event.format());
353 }
354
355 Ok(())
356 }
357
358 pub fn take_snapshot(&self) -> TorshResult<SystemStateSnapshot> {
360 let snapshot = SystemStateSnapshot {
361 timestamp: SystemTime::now(),
362 process_group: self.capture_process_group_state()?,
363 communication: self.capture_communication_state()?,
364 resources: self.capture_resource_state()?,
365 active_operations: self.get_active_operations(),
366 recent_errors: self.get_recent_errors(10)?,
367 };
368
369 {
371 let mut snapshots = self
372 .snapshots
373 .lock()
374 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
375 snapshots.push_back(snapshot.clone());
376
377 if snapshots.len() > 20 {
379 snapshots.pop_front();
380 }
381 }
382
383 {
385 let mut stats = self
386 .stats
387 .lock()
388 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
389 stats.snapshots_taken += 1;
390 }
391
392 Ok(snapshot)
393 }
394
395 fn capture_process_group_state(&self) -> TorshResult<ProcessGroupState> {
397 Ok(ProcessGroupState {
399 rank: 0, world_size: 1, backend: "Mock".to_string(), health_status: "Healthy".to_string(),
403 active_processes: 1,
404 failed_processes: Vec::new(),
405 })
406 }
407
408 fn capture_communication_state(&self) -> TorshResult<CommunicationState> {
410 let metrics_collector = get_global_metrics_collector();
411
412 if let Ok(comm_history) = metrics_collector.get_communication_history() {
413 if let Some(latest) = comm_history.last() {
414 return Ok(CommunicationState {
415 pending_operations: 0, failed_operations: latest.value.failed_operations as u32,
417 avg_latency_ms: latest.value.avg_latency_ms,
418 bandwidth_mbps: latest.value.avg_bandwidth_mbps,
419 queue_length: 0, last_success: Some(latest.timestamp),
421 });
422 }
423 }
424
425 Ok(CommunicationState {
426 pending_operations: 0,
427 failed_operations: 0,
428 avg_latency_ms: 0.0,
429 bandwidth_mbps: 0.0,
430 queue_length: 0,
431 last_success: None,
432 })
433 }
434
435 fn capture_resource_state(&self) -> TorshResult<ResourceState> {
437 let metrics_collector = get_global_metrics_collector();
438
439 if let Ok(system_history) = metrics_collector.get_system_history() {
440 if let Some(latest) = system_history.last() {
441 return Ok(ResourceState {
442 cpu_usage_pct: latest.value.cpu_usage_pct,
443 memory_usage_pct: latest.value.memory_usage_pct,
444 gpu_usage_pct: latest.value.gpu_usage_pct,
445 network_io_bps: latest.value.network_bytes_rx + latest.value.network_bytes_tx,
446 disk_io_bps: latest.value.disk_bytes_read + latest.value.disk_bytes_write,
447 memory_pressure: if latest.value.memory_usage_pct > 90.0 {
448 "High"
449 } else {
450 "Normal"
451 }
452 .to_string(),
453 });
454 }
455 }
456
457 Ok(ResourceState {
458 cpu_usage_pct: 0.0,
459 memory_usage_pct: 0.0,
460 gpu_usage_pct: None,
461 network_io_bps: 0,
462 disk_io_bps: 0,
463 memory_pressure: "Unknown".to_string(),
464 })
465 }
466
467 fn get_active_operations(&self) -> Vec<ActiveOperation> {
469 self.active_operations
470 .lock()
471 .map(|ops| ops.values().cloned().collect())
472 .unwrap_or_default()
473 }
474
475 pub fn start_operation(&self, operation_type: String, ranks: Vec<u32>) -> TorshResult<String> {
477 let operation_id = format!(
478 "{}_{}",
479 operation_type,
480 SystemTime::now()
481 .duration_since(UNIX_EPOCH)
482 .unwrap_or_default()
483 .as_nanos()
484 );
485
486 let operation = ActiveOperation {
487 operation_type: operation_type.clone(),
488 start_time: SystemTime::now(),
489 expected_duration: None,
490 progress_pct: 0.0,
491 ranks,
492 status: "Running".to_string(),
493 };
494
495 {
496 let mut active_ops = self
497 .active_operations
498 .lock()
499 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
500 active_ops.insert(operation_id.clone(), operation);
501 }
502
503 self.log_event(
504 DebugEvent::new(
505 LogLevel::Debug,
506 "DistributedDebugger".to_string(),
507 0, format!("Started operation: {}", operation_type),
509 )
510 .with_context("operation_id".to_string(), operation_id.clone()),
511 )?;
512
513 Ok(operation_id)
514 }
515
516 pub fn update_operation_progress(
518 &self,
519 operation_id: &str,
520 progress_pct: f64,
521 ) -> TorshResult<()> {
522 let mut active_ops = self
523 .active_operations
524 .lock()
525 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
526
527 if let Some(operation) = active_ops.get_mut(operation_id) {
528 operation.progress_pct = progress_pct;
529 }
530
531 Ok(())
532 }
533
534 pub fn complete_operation(&self, operation_id: &str, success: bool) -> TorshResult<()> {
536 let mut active_ops = self
537 .active_operations
538 .lock()
539 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
540
541 if let Some(operation) = active_ops.remove(operation_id) {
542 let duration = SystemTime::now()
543 .duration_since(operation.start_time)
544 .unwrap_or_default();
545
546 self.log_event(
547 DebugEvent::new(
548 if success {
549 LogLevel::Debug
550 } else {
551 LogLevel::Error
552 },
553 "DistributedDebugger".to_string(),
554 0, format!(
556 "Completed operation: {} ({})",
557 operation.operation_type,
558 if success { "SUCCESS" } else { "FAILED" }
559 ),
560 )
561 .with_context("operation_id".to_string(), operation_id.to_string())
562 .with_duration(duration),
563 )?;
564 }
565
566 Ok(())
567 }
568
569 fn get_recent_errors(&self, count: usize) -> TorshResult<Vec<DebugEvent>> {
571 let events = self
572 .events
573 .lock()
574 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
575
576 Ok(events
577 .iter()
578 .filter(|e| e.level >= LogLevel::Error)
579 .rev()
580 .take(count)
581 .cloned()
582 .collect())
583 }
584
585 pub fn run_diagnostics(&self) -> TorshResult<Vec<DiagnosticResult>> {
587 let mut results = vec![
588 self.check_communication_health()?,
590 ];
591
592 results.push(self.check_resource_utilization()?);
594
595 results.push(self.check_bottlenecks()?);
597
598 results.push(self.check_error_rate()?);
600
601 results.push(self.check_process_group_health()?);
603
604 {
606 let mut diagnostic_history = self
607 .diagnostic_history
608 .lock()
609 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
610 diagnostic_history.extend(results.clone());
611
612 let current_len = diagnostic_history.len();
614 if current_len > 100 {
615 diagnostic_history.drain(0..current_len - 100);
616 }
617 }
618
619 {
621 let mut stats = self
622 .stats
623 .lock()
624 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
625 stats.diagnostics_run += 1;
626 }
627
628 Ok(results)
629 }
630
    /// Diagnoses communication health from the global profiler's per-operation
    /// statistics (operation counts and worst-case latency).
    ///
    /// NOTE(review): `total_failed` is hard-coded to 0 — the profiler stats
    /// consumed here expose no failure count — so `failure_rate` is always
    /// 0.0 and the check effectively passes/fails on latency alone.
    /// TODO: wire in a real failure count when the profiler provides one.
    fn check_communication_health(&self) -> TorshResult<DiagnosticResult> {
        let profiler = get_global_profiler();
        let all_stats = profiler.get_all_operation_stats()?;

        let total_failed = 0u64;
        let mut total_operations = 0u64;
        let mut max_latency: f64 = 0.0;

        // Aggregate across all profiled operation types.
        for stats in all_stats.values() {
            total_operations += stats.count;
            max_latency = max_latency.max(stats.max_latency.as_secs_f64() * 1000.0);
        }

        let failure_rate = if total_operations > 0 {
            total_failed as f64 / total_operations as f64
        } else {
            0.0
        };
        // Healthy: <1% failures and worst-case latency under one second.
        let passed = failure_rate < 0.01 && max_latency < 1000.0;
        Ok(DiagnosticResult {
            check_name: "Communication Health".to_string(),
            passed,
            severity: if !passed {
                LogLevel::Error
            } else {
                LogLevel::Info
            },
            description: if passed {
                "Communication system is healthy".to_string()
            } else {
                format!(
                    "Communication issues detected: {:.2}% failure rate, {:.1}ms max latency",
                    failure_rate * 100.0,
                    max_latency
                )
            },
            remediation: if !passed {
                vec![
                    "Check network connectivity".to_string(),
                    "Verify NCCL/MPI configuration".to_string(),
                    "Monitor bandwidth utilization".to_string(),
                ]
            } else {
                vec![]
            },
            data: {
                let mut data = HashMap::new();
                data.insert("failure_rate".to_string(), failure_rate.to_string());
                data.insert("max_latency_ms".to_string(), max_latency.to_string());
                data.insert("total_operations".to_string(), total_operations.to_string());
                data
            },
        })
    }
689
690 fn check_resource_utilization(&self) -> TorshResult<DiagnosticResult> {
692 let state = self.capture_resource_state()?;
693
694 let high_cpu = state.cpu_usage_pct > 95.0;
695 let high_memory = state.memory_usage_pct > 90.0;
696 let high_gpu = state.gpu_usage_pct.is_some_and(|gpu| gpu > 98.0);
697
698 let passed = !high_cpu && !high_memory && !high_gpu;
699
700 let mut issues = Vec::new();
701 if high_cpu {
702 issues.push(format!("High CPU usage: {:.1}%", state.cpu_usage_pct));
703 }
704 if high_memory {
705 issues.push(format!("High memory usage: {:.1}%", state.memory_usage_pct));
706 }
707 if high_gpu {
708 issues.push(format!(
709 "High GPU usage: {:.1}%",
710 state.gpu_usage_pct.unwrap_or(0.0)
711 ));
712 }
713
714 Ok(DiagnosticResult {
715 check_name: "Resource Utilization".to_string(),
716 passed,
717 severity: if !passed {
718 LogLevel::Warn
719 } else {
720 LogLevel::Info
721 },
722 description: if passed {
723 "Resource utilization is normal".to_string()
724 } else {
725 format!("Resource pressure detected: {}", issues.join(", "))
726 },
727 remediation: if !passed {
728 vec![
729 "Scale to more resources if available".to_string(),
730 "Optimize memory usage with gradient checkpointing".to_string(),
731 "Consider model sharding or parallelism".to_string(),
732 ]
733 } else {
734 vec![]
735 },
736 data: {
737 let mut data = HashMap::new();
738 data.insert("cpu_usage_pct".to_string(), state.cpu_usage_pct.to_string());
739 data.insert(
740 "memory_usage_pct".to_string(),
741 state.memory_usage_pct.to_string(),
742 );
743 if let Some(gpu_usage) = state.gpu_usage_pct {
744 data.insert("gpu_usage_pct".to_string(), gpu_usage.to_string());
745 }
746 data
747 },
748 })
749 }
750
751 fn check_bottlenecks(&self) -> TorshResult<DiagnosticResult> {
753 crate::bottleneck_detection::with_global_bottleneck_detector(|detector| {
754 let recent_bottlenecks = detector
755 .get_bottleneck_history()
756 .iter()
757 .filter(|b| b.detected_at > SystemTime::now() - Duration::from_secs(300)) .collect::<Vec<_>>();
759
760 let critical_bottlenecks = recent_bottlenecks
761 .iter()
762 .filter(|b| {
763 matches!(
764 b.severity,
765 crate::bottleneck_detection::BottleneckSeverity::Critical
766 | crate::bottleneck_detection::BottleneckSeverity::High
767 )
768 })
769 .count();
770
771 let passed = critical_bottlenecks == 0;
772
773 Ok(DiagnosticResult {
774 check_name: "Bottleneck Detection".to_string(),
775 passed,
776 severity: if critical_bottlenecks > 0 {
777 LogLevel::Error
778 } else {
779 LogLevel::Info
780 },
781 description: if passed {
782 "No critical bottlenecks detected".to_string()
783 } else {
784 format!(
785 "{} critical bottlenecks detected in the last 5 minutes",
786 critical_bottlenecks
787 )
788 },
789 remediation: if !passed {
790 vec![
791 "Review bottleneck analysis for specific recommendations".to_string(),
792 "Consider load balancing adjustments".to_string(),
793 "Optimize communication patterns".to_string(),
794 ]
795 } else {
796 vec![]
797 },
798 data: {
799 let mut data = HashMap::new();
800 data.insert(
801 "recent_bottlenecks".to_string(),
802 recent_bottlenecks.len().to_string(),
803 );
804 data.insert(
805 "critical_bottlenecks".to_string(),
806 critical_bottlenecks.to_string(),
807 );
808 data
809 },
810 })
811 })
812 }
813
814 fn check_error_rate(&self) -> TorshResult<DiagnosticResult> {
816 let events = self
817 .events
818 .lock()
819 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
820
821 let recent_events = events
822 .iter()
823 .filter(|e| e.timestamp > SystemTime::now() - Duration::from_secs(300)) .collect::<Vec<_>>();
825
826 let error_events = recent_events
827 .iter()
828 .filter(|e| e.level >= LogLevel::Error)
829 .count();
830
831 let error_rate = if !recent_events.is_empty() {
832 error_events as f64 / recent_events.len() as f64
833 } else {
834 0.0
835 };
836
837 let passed = error_rate < 0.05; Ok(DiagnosticResult {
840 check_name: "Error Rate".to_string(),
841 passed,
842 severity: if !passed {
843 LogLevel::Error
844 } else {
845 LogLevel::Info
846 },
847 description: if passed {
848 "Error rate is within normal limits".to_string()
849 } else {
850 format!(
851 "High error rate detected: {:.1}% ({} errors in {} events)",
852 error_rate * 100.0,
853 error_events,
854 recent_events.len()
855 )
856 },
857 remediation: if !passed {
858 vec![
859 "Review recent error messages for patterns".to_string(),
860 "Check system logs for underlying issues".to_string(),
861 "Verify configuration and environment setup".to_string(),
862 ]
863 } else {
864 vec![]
865 },
866 data: {
867 let mut data = HashMap::new();
868 data.insert("error_rate".to_string(), error_rate.to_string());
869 data.insert("error_count".to_string(), error_events.to_string());
870 data.insert("total_events".to_string(), recent_events.len().to_string());
871 data
872 },
873 })
874 }
875
876 fn check_process_group_health(&self) -> TorshResult<DiagnosticResult> {
878 let state = self.capture_process_group_state()?;
879
880 let passed = state.failed_processes.is_empty() && state.health_status == "Healthy";
881
882 Ok(DiagnosticResult {
883 check_name: "Process Group Health".to_string(),
884 passed,
885 severity: if !passed {
886 LogLevel::Critical
887 } else {
888 LogLevel::Info
889 },
890 description: if passed {
891 format!(
892 "Process group is healthy ({}/{} processes active)",
893 state.active_processes, state.world_size
894 )
895 } else {
896 format!(
897 "Process group issues: {} failed processes, status: {}",
898 state.failed_processes.len(),
899 state.health_status
900 )
901 },
902 remediation: if !passed {
903 vec![
904 "Restart failed processes if possible".to_string(),
905 "Check network connectivity between nodes".to_string(),
906 "Verify resource availability on all nodes".to_string(),
907 ]
908 } else {
909 vec![]
910 },
911 data: {
912 let mut data = HashMap::new();
913 data.insert("world_size".to_string(), state.world_size.to_string());
914 data.insert(
915 "active_processes".to_string(),
916 state.active_processes.to_string(),
917 );
918 data.insert(
919 "failed_processes".to_string(),
920 state.failed_processes.len().to_string(),
921 );
922 data.insert("health_status".to_string(), state.health_status);
923 data
924 },
925 })
926 }
927
    /// Builds a human-readable multi-section report: current system state,
    /// diagnostic results, recent errors and debugger statistics. Sections
    /// whose data cannot be gathered are silently omitted (best effort).
    pub fn generate_debug_report(&self) -> TorshResult<String> {
        let mut report = String::new();
        report.push_str("=== Distributed Training Debug Report ===\n\n");

        // --- Current system state (omitted if snapshotting fails) ---
        if let Ok(snapshot) = self.take_snapshot() {
            report.push_str("=== Current System State ===\n");
            report.push_str(&format!("Timestamp: {:?}\n", snapshot.timestamp));
            report.push_str(&format!(
                "Process Group: Rank {}/{}, Backend: {}, Status: {}\n",
                snapshot.process_group.rank,
                snapshot.process_group.world_size,
                snapshot.process_group.backend,
                snapshot.process_group.health_status
            ));
            report.push_str(&format!(
                "Resources: CPU {:.1}%, Memory {:.1}%",
                snapshot.resources.cpu_usage_pct, snapshot.resources.memory_usage_pct
            ));
            // GPU usage is appended on the same line only when present.
            if let Some(gpu) = snapshot.resources.gpu_usage_pct {
                report.push_str(&format!(", GPU {:.1}%", gpu));
            }
            report.push('\n');
            report.push_str(&format!(
                "Communication: {:.1}ms avg latency, {:.1} MB/s bandwidth\n",
                snapshot.communication.avg_latency_ms, snapshot.communication.bandwidth_mbps
            ));
            report.push_str(&format!(
                "Active Operations: {}\n\n",
                snapshot.active_operations.len()
            ));
        }

        // --- Diagnostics (note: this runs the full diagnostic suite) ---
        if let Ok(diagnostics) = self.run_diagnostics() {
            report.push_str("=== Diagnostic Results ===\n");
            for diagnostic in &diagnostics {
                let status = if diagnostic.passed { "PASS" } else { "FAIL" };
                report.push_str(&format!(
                    "[{}] {}: {}\n",
                    status, diagnostic.check_name, diagnostic.description
                ));

                if !diagnostic.remediation.is_empty() {
                    report.push_str("  Recommended Actions:\n");
                    for action in &diagnostic.remediation {
                        report.push_str(&format!("    - {}\n", action));
                    }
                }
            }
            report.push('\n');
        }

        // --- Recent error events, newest first ---
        if let Ok(errors) = self.get_recent_errors(5) {
            if !errors.is_empty() {
                report.push_str("=== Recent Errors ===\n");
                for error in &errors {
                    report.push_str(&error.format());
                }
                report.push('\n');
            }
        }

        // --- Aggregate counters (skipped if the stats lock is poisoned) ---
        if let Ok(stats) = self.stats.lock() {
            report.push_str("=== Debugger Statistics ===\n");
            report.push_str(&format!("Events Captured: {}\n", stats.events_captured));
            report.push_str(&format!("Snapshots Taken: {}\n", stats.snapshots_taken));
            report.push_str(&format!("Diagnostics Run: {}\n", stats.diagnostics_run));
            report.push_str(&format!("Errors Detected: {}\n", stats.errors_detected));
        }

        Ok(report)
    }
1004
1005 pub fn export_debug_data(&self) -> TorshResult<String> {
1007 #[derive(Serialize)]
1008 struct DebugExport {
1009 config: DebugConfig,
1010 events: Vec<DebugEvent>,
1011 snapshots: Vec<SystemStateSnapshot>,
1012 diagnostic_history: Vec<DiagnosticResult>,
1013 statistics: Option<DebuggerStats>,
1014 }
1015
1016 let config = self
1017 .config
1018 .read()
1019 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1020 .clone();
1021 let events = self
1022 .events
1023 .lock()
1024 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1025 .iter()
1026 .cloned()
1027 .collect();
1028 let snapshots = self
1029 .snapshots
1030 .lock()
1031 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1032 .iter()
1033 .cloned()
1034 .collect();
1035 let diagnostic_history = self
1036 .diagnostic_history
1037 .lock()
1038 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?
1039 .clone();
1040 let statistics = self.stats.lock().ok().map(|s| DebuggerStats {
1041 events_captured: s.events_captured,
1042 snapshots_taken: s.snapshots_taken,
1043 diagnostics_run: s.diagnostics_run,
1044 errors_detected: s.errors_detected,
1045 });
1046
1047 let export = DebugExport {
1048 config,
1049 events,
1050 snapshots,
1051 diagnostic_history,
1052 statistics,
1053 };
1054
1055 serde_json::to_string_pretty(&export).map_err(|e| {
1056 TorshDistributedError::backend_error(
1057 "debugging",
1058 format!("JSON serialization failed: {}", e),
1059 )
1060 })
1061 }
1062
1063 pub fn clear(&self) -> TorshResult<()> {
1065 {
1066 let mut events = self
1067 .events
1068 .lock()
1069 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1070 events.clear();
1071 }
1072
1073 {
1074 let mut snapshots = self
1075 .snapshots
1076 .lock()
1077 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1078 snapshots.clear();
1079 }
1080
1081 {
1082 let mut active_ops = self
1083 .active_operations
1084 .lock()
1085 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1086 active_ops.clear();
1087 }
1088
1089 {
1090 let mut diagnostic_history = self
1091 .diagnostic_history
1092 .lock()
1093 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1094 diagnostic_history.clear();
1095 }
1096
1097 {
1098 let mut stats = self
1099 .stats
1100 .lock()
1101 .map_err(|_| TorshDistributedError::backend_error("debugging", "Lock poisoned"))?;
1102 *stats = DebuggerStats::default();
1103 }
1104
1105 Ok(())
1106 }
1107}
1108
1109impl Default for DistributedDebugger {
1110 fn default() -> Self {
1111 Self::new()
1112 }
1113}
1114
/// Process-wide debugger instance, created lazily on first access.
static GLOBAL_DEBUGGER: std::sync::OnceLock<Arc<DistributedDebugger>> = std::sync::OnceLock::new();

/// Returns the global debugger, initializing it with the default
/// configuration on first use (unless `init_global_debugger` ran earlier).
pub fn get_global_debugger() -> &'static Arc<DistributedDebugger> {
    GLOBAL_DEBUGGER.get_or_init(|| Arc::new(DistributedDebugger::new()))
}
1122
1123pub fn init_global_debugger(config: DebugConfig) -> TorshResult<()> {
1125 let debugger = Arc::new(DistributedDebugger::with_config(config));
1126 GLOBAL_DEBUGGER.set(debugger).map_err(|_| {
1127 TorshDistributedError::backend_error("debugging", "Global debugger already initialized")
1128 })?;
1129 Ok(())
1130}
1131
/// Logs a message through the global `DistributedDebugger`.
///
/// Usage: `debug_log!(level, source, rank, msg)` or with trailing
/// `key => value` pairs that are attached as structured context.
#[macro_export]
macro_rules! debug_log {
    // The double braces make each arm expand to a single block expression.
    // The original expanded to a bare statement list, which could not be used
    // in expression position (match arms, `if`/`else` values) and was
    // inconsistent with `debug_trace_operation!`.
    ($level:expr, $source:expr, $rank:expr, $msg:expr) => {{
        let debugger = $crate::debugging::get_global_debugger();
        let event = $crate::debugging::DebugEvent::new(
            $level,
            $source.to_string(),
            $rank,
            $msg.to_string(),
        );
        let _ = debugger.log_event(event);
    }};
    ($level:expr, $source:expr, $rank:expr, $msg:expr, $($key:expr => $value:expr),+) => {{
        let debugger = $crate::debugging::get_global_debugger();
        let mut event = $crate::debugging::DebugEvent::new(
            $level,
            $source.to_string(),
            $rank,
            $msg.to_string(),
        );
        $(
            event = event.with_context($key.to_string(), $value.to_string());
        )+
        let _ = debugger.log_event(event);
    }};
}
1149
/// Runs `$code` while tracking it as a named operation on the global debugger.
///
/// NOTE(review): the operation is always completed with `success = true`; if
/// `$code` can fail, that failure is not reflected in the operation record.
#[macro_export]
macro_rules! debug_trace_operation {
    ($op_type:expr, $ranks:expr, $code:block) => {{
        let debugger = $crate::debugging::get_global_debugger();
        // An empty id (on start failure) is harmless: `complete_operation`
        // silently ignores unknown ids.
        let op_id = debugger.start_operation($op_type.to_string(), $ranks).unwrap_or_default();
        let result = $code;
        let _ = debugger.complete_operation(&op_id, true);
        result
    }};
}
1160
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods should populate all optional event fields.
    #[test]
    fn test_debug_event_creation() {
        let event = DebugEvent::new(
            LogLevel::Info,
            "test_module".to_string(),
            0,
            "Test message".to_string(),
        )
        .with_context("key".to_string(), "value".to_string())
        .with_duration(Duration::from_millis(100));

        assert_eq!(event.level, LogLevel::Info);
        assert_eq!(event.source, "test_module");
        assert_eq!(event.message, "Test message");
        assert_eq!(event.context.get("key"), Some(&"value".to_string()));
        assert_eq!(event.duration, Some(Duration::from_millis(100)));
    }

    // A fresh debugger starts with no active operations.
    #[test]
    fn test_debugger_creation() {
        let debugger = DistributedDebugger::new();
        assert!(debugger.get_active_operations().is_empty());
    }

    // A logged Info event (at the default Info threshold) lands in the buffer.
    #[test]
    fn test_event_logging() {
        let debugger = DistributedDebugger::new();
        let event = DebugEvent::new(
            LogLevel::Info,
            "test".to_string(),
            0,
            "Test event".to_string(),
        );

        debugger.log_event(event).unwrap();

        // Peek at the private buffer directly (same-module access).
        let events = debugger.events.lock().expect("lock should not be poisoned");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].message, "Test event");
    }

    // start -> update -> complete removes the operation from the active set.
    #[test]
    fn test_operation_tracking() {
        let debugger = DistributedDebugger::new();

        let op_id = debugger
            .start_operation("test_op".to_string(), vec![0, 1])
            .unwrap();
        assert!(debugger.get_active_operations().len() == 1);

        debugger.update_operation_progress(&op_id, 50.0).unwrap();
        debugger.complete_operation(&op_id, true).unwrap();

        assert!(debugger.get_active_operations().is_empty());
    }

    // Snapshots use the mock process-group backend and start error-free.
    #[test]
    fn test_snapshot_taking() {
        let debugger = DistributedDebugger::new();
        let snapshot = debugger.take_snapshot().unwrap();

        assert_eq!(snapshot.process_group.backend, "Mock");
        assert!(snapshot.recent_errors.is_empty());
    }

    // The diagnostic suite includes at least the two core checks.
    #[test]
    fn test_diagnostics() {
        let debugger = DistributedDebugger::new();
        let results = debugger.run_diagnostics().unwrap();

        assert!(!results.is_empty());
        assert!(results
            .iter()
            .any(|r| r.check_name == "Communication Health"));
        assert!(results
            .iter()
            .any(|r| r.check_name == "Resource Utilization"));
    }

    // The report contains each expected section header.
    #[test]
    fn test_debug_report_generation() {
        let debugger = DistributedDebugger::new();
        let report = debugger.generate_debug_report().unwrap();

        assert!(report.contains("Distributed Training Debug Report"));
        assert!(report.contains("Current System State"));
        assert!(report.contains("Diagnostic Results"));
    }

    // The JSON export round-trips logged event text and top-level keys.
    #[test]
    fn test_json_export() {
        let debugger = DistributedDebugger::new();
        let event = DebugEvent::new(
            LogLevel::Info,
            "test".to_string(),
            0,
            "Export test".to_string(),
        );
        debugger.log_event(event).unwrap();

        let json = debugger.export_debug_data().unwrap();
        assert!(json.contains("Export test"));
        assert!(json.contains("events"));
        assert!(json.contains("config"));
    }

    // The derived Ord must rank levels strictly by severity.
    #[test]
    fn test_log_level_ordering() {
        assert!(LogLevel::Critical > LogLevel::Error);
        assert!(LogLevel::Error > LogLevel::Warn);
        assert!(LogLevel::Warn > LogLevel::Info);
        assert!(LogLevel::Info > LogLevel::Debug);
        assert!(LogLevel::Debug > LogLevel::Trace);
    }
}