// trustformers_debug/distributed_debugger/functions.rs
use super::types::*;
6pub fn create_master_debugger(rank: u32, hostname: String) -> DistributedDebugger {
9 let node_id = NodeId::new(rank, hostname);
10 DistributedDebugger::new(DistributedDebugConfig::default(), node_id)
11}
12pub fn create_worker_debugger(rank: u32, hostname: String) -> DistributedDebugger {
14 let node_id = NodeId::new(rank, hostname);
15 let mut config = DistributedDebugConfig::default();
16 config.enable_auto_recovery = false;
17 DistributedDebugger::new(config, node_id)
18}
/// Builds a `GradientSyncEvent` and awaits
/// `$debugger.monitor_gradient_sync(event)`, evaluating to its result.
///
/// Arguments, in order:
/// - `$debugger`: a value exposing an async `monitor_gradient_sync` method.
/// - `$sync_round`: stored in the event's `sync_round` field.
/// - `$nodes`: stored in the event's `participating_nodes` field.
/// - `$sync_time`: stored in the event's `total_sync_time` field.
///
/// The remaining fields are filled with fixed values:
/// - `timestamp` is the current `SystemTime`.
/// - `gradient_sizes` is an empty map.
/// - `compression_ratio` is `1.0`.
/// - `sync_algorithm` is `SyncAlgorithm::AllReduce`.
///
/// The expansion ends in `.await`, so the macro may only be invoked inside an
/// `async` function or block.
///
/// Standard-library items are referenced by absolute `::std::` paths, so the
/// caller does not need to import `HashMap` or `SystemTime`.
// NOTE(review): `GradientSyncEvent` and `SyncAlgorithm` are still resolved at
// the call site and must be in scope there; consider `$crate::…` paths once
// their public re-export location is confirmed.
#[macro_export]
macro_rules! monitor_gradient_sync {
    ($debugger:expr, $sync_round:expr, $nodes:expr, $sync_time:expr $(,)?) => {{
        let sync_event = GradientSyncEvent {
            timestamp: ::std::time::SystemTime::now(),
            sync_round: $sync_round,
            participating_nodes: $nodes,
            total_sync_time: $sync_time,
            // Absolute path: an exported macro must not rely on the caller
            // having `HashMap` imported.
            gradient_sizes: ::std::collections::HashMap::new(),
            compression_ratio: 1.0,
            sync_algorithm: SyncAlgorithm::AllReduce,
        };
        $debugger.monitor_gradient_sync(sync_event).await
    }};
}
#[cfg(test)]
mod tests {
    // Unit tests for the distributed debugger helpers in this file.
    use super::*;
    use std::collections::HashMap;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    #[tokio::test]
    async fn test_distributed_debugger_creation() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        assert_eq!(debugger.node_id.rank, 0);
    }

    #[tokio::test]
    async fn test_node_info_creation() {
        let node_id = NodeId::new(1, "worker-1".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let node_info = debugger.create_node_info(addr).await.expect("async operation failed");
        assert_eq!(node_info.node_id.rank, 1);
        assert_eq!(node_info.status, NodeStatus::Healthy);
        assert_eq!(node_info.address, addr);
    }

    #[tokio::test]
    async fn test_gradient_sync_monitoring() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id.clone());
        let sync_event = GradientSyncEvent {
            timestamp: std::time::SystemTime::now(),
            sync_round: 1,
            participating_nodes: vec![node_id],
            total_sync_time: Duration::from_millis(100),
            gradient_sizes: HashMap::new(),
            compression_ratio: 0.8,
            sync_algorithm: SyncAlgorithm::AllReduce,
        };
        let result = debugger.monitor_gradient_sync(sync_event).await;
        assert!(result.is_ok());
    }

    // Exercises the exported `monitor_gradient_sync!` macro, which the
    // original suite never invoked.
    #[tokio::test]
    async fn test_monitor_gradient_sync_macro() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id.clone());
        let result = monitor_gradient_sync!(debugger, 2, vec![node_id], Duration::from_millis(50));
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_cluster_analysis() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let _report = debugger.analyze_cluster_performance().await.expect("async operation failed");
    }

    #[tokio::test]
    async fn test_fault_detection() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let _faults = debugger.detect_faults().await.expect("async operation failed");
    }

    #[tokio::test]
    async fn test_distributed_debug_report() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let report = debugger
            .generate_distributed_debug_report()
            .await
            .expect("async operation failed");
        assert!(!report.recommendations.is_empty());
    }

    #[test]
    fn test_convenience_functions() {
        let master = create_master_debugger(0, "master".to_string());
        let worker = create_worker_debugger(1, "worker-1".to_string());
        assert_eq!(master.node_id.rank, 0);
        assert_eq!(worker.node_id.rank, 1);
        assert!(!worker.config.enable_auto_recovery);
        // The master must keep the default auto-recovery setting; only the
        // worker constructor overrides it.
        assert_eq!(
            master.config.enable_auto_recovery,
            DistributedDebugConfig::default().enable_auto_recovery
        );
    }
}