Skip to main content

trustformers_debug/distributed_debugger/
functions.rs

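//! Auto-generated module: convenience constructors for master/worker distributed
//! debuggers and the `monitor_gradient_sync!` macro.
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)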
4
5use super::types::*;
6/// Convenience functions for distributed debugging
7/// Create a distributed debugger for a master node
8pub fn create_master_debugger(rank: u32, hostname: String) -> DistributedDebugger {
9    let node_id = NodeId::new(rank, hostname);
10    DistributedDebugger::new(DistributedDebugConfig::default(), node_id)
11}
12/// Create a distributed debugger for a worker node
13pub fn create_worker_debugger(rank: u32, hostname: String) -> DistributedDebugger {
14    let node_id = NodeId::new(rank, hostname);
15    let mut config = DistributedDebugConfig::default();
16    config.enable_auto_recovery = false;
17    DistributedDebugger::new(config, node_id)
18}
/// Build a `GradientSyncEvent` and feed it to `$debugger.monitor_gradient_sync(..)`,
/// awaiting the call.
///
/// Arguments, in order:
/// * `$debugger` — expression exposing an async `monitor_gradient_sync` method.
/// * `$sync_round` — value for the event's `sync_round` field.
/// * `$nodes` — value for `participating_nodes`.
/// * `$sync_time` — value for `total_sync_time`.
///
/// A trailing comma after the last argument is accepted.
///
/// The remaining fields are filled with fixed values:
/// * `timestamp` — the current `SystemTime`.
/// * `gradient_sizes` — an empty `HashMap`.
/// * `compression_ratio` — `1.0`.
/// * `sync_algorithm` — `SyncAlgorithm::AllReduce`.
///
/// The expansion ends in `.await`, so the macro can only be used inside an `async`
/// context. It evaluates to whatever `monitor_gradient_sync` returns.
///
/// Standard-library paths are fully qualified, so callers do not need to import
/// `HashMap` or `SystemTime`. `GradientSyncEvent` and `SyncAlgorithm` are still
/// referenced unqualified and must be in scope at the call site.
/// NOTE(review): switch these to `$crate::…` paths once their public re-export path
/// is confirmed.
#[macro_export]
macro_rules! monitor_gradient_sync {
    ($debugger:expr, $sync_round:expr, $nodes:expr, $sync_time:expr $(,)?) => {{
        let sync_event = GradientSyncEvent {
            timestamp: ::std::time::SystemTime::now(),
            sync_round: $sync_round,
            participating_nodes: $nodes,
            total_sync_time: $sync_time,
            // Fully qualified: a `#[macro_export]` macro expands in the caller's scope,
            // which may not have `HashMap` imported.
            gradient_sizes: ::std::collections::HashMap::new(),
            compression_ratio: 1.0,
            sync_algorithm: SyncAlgorithm::AllReduce,
        };
        $debugger.monitor_gradient_sync(sync_event).await
    }};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    /// Constructing a debugger records the rank it was given.
    ///
    /// Nothing is awaited here, and `DistributedDebugger::new` is already called without
    /// an async runtime in `test_convenience_functions`, so a plain `#[test]` suffices.
    #[test]
    fn test_distributed_debugger_creation() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        assert_eq!(debugger.node_id.rank, 0);
    }

    /// `create_node_info` echoes the rank and address and reports the node as healthy.
    #[tokio::test]
    async fn test_node_info_creation() {
        let node_id = NodeId::new(1, "worker-1".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let node_info = debugger.create_node_info(addr).await.expect("async operation failed");
        assert_eq!(node_info.node_id.rank, 1);
        assert_eq!(node_info.status, NodeStatus::Healthy);
        assert_eq!(node_info.address, addr);
    }

    /// Feeding a single-node sync event to `monitor_gradient_sync` succeeds.
    #[tokio::test]
    async fn test_gradient_sync_monitoring() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id.clone());
        let sync_event = GradientSyncEvent {
            timestamp: std::time::SystemTime::now(),
            sync_round: 1,
            participating_nodes: vec![node_id],
            total_sync_time: Duration::from_millis(100),
            gradient_sizes: HashMap::new(),
            compression_ratio: 0.8,
            sync_algorithm: SyncAlgorithm::AllReduce,
        };
        let result = debugger.monitor_gradient_sync(sync_event).await;
        assert!(result.is_ok(), "monitor_gradient_sync returned an error");
    }

    /// Smoke test: cluster performance analysis completes without error on a fresh debugger.
    #[tokio::test]
    async fn test_cluster_analysis() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let _report = debugger.analyze_cluster_performance().await.expect("async operation failed");
    }

    /// Smoke test: fault detection completes without error on a fresh debugger.
    #[tokio::test]
    async fn test_fault_detection() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let _faults = debugger.detect_faults().await.expect("async operation failed");
    }

    /// The full debug report always carries at least one recommendation.
    #[tokio::test]
    async fn test_distributed_debug_report() {
        let node_id = NodeId::new(0, "test-node".to_string());
        let debugger = DistributedDebugger::new(DistributedDebugConfig::default(), node_id);
        let report = debugger
            .generate_distributed_debug_report()
            .await
            .expect("async operation failed");
        assert!(!report.recommendations.is_empty());
    }

    /// The master/worker constructors set the rank.
    ///
    /// The master keeps the default auto-recovery setting, while the worker disables it.
    #[test]
    fn test_convenience_functions() {
        let master = create_master_debugger(0, "master".to_string());
        let worker = create_worker_debugger(1, "worker-1".to_string());
        assert_eq!(master.node_id.rank, 0);
        assert_eq!(worker.node_id.rank, 1);
        // The master must not diverge from the default configuration.
        assert_eq!(
            master.config.enable_auto_recovery,
            DistributedDebugConfig::default().enable_auto_recovery
        );
        assert!(!worker.config.enable_auto_recovery);
    }
}