Skip to main content

torsh_distributed/
profiling.rs

1//! Communication profiling and performance monitoring for distributed training
2//!
3//! This module provides comprehensive profiling capabilities for distributed communication
4//! operations, including timing measurements, bandwidth analysis, and performance statistics.
5
6use crate::{TorshDistributedError, TorshResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex, RwLock};
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
12/// Type of communication operation being profiled
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub enum CommunicationOpType {
15    AllReduce,
16    AllGather,
17    ReduceScatter,
18    Broadcast,
19    Reduce,
20    Scatter,
21    Gather,
22    Send,
23    Recv,
24    Barrier,
25    AllToAll,
26    Custom(u32),
27}
28
29impl std::fmt::Display for CommunicationOpType {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            CommunicationOpType::AllReduce => write!(f, "AllReduce"),
33            CommunicationOpType::AllGather => write!(f, "AllGather"),
34            CommunicationOpType::ReduceScatter => write!(f, "ReduceScatter"),
35            CommunicationOpType::Broadcast => write!(f, "Broadcast"),
36            CommunicationOpType::Reduce => write!(f, "Reduce"),
37            CommunicationOpType::Scatter => write!(f, "Scatter"),
38            CommunicationOpType::Gather => write!(f, "Gather"),
39            CommunicationOpType::Send => write!(f, "Send"),
40            CommunicationOpType::Recv => write!(f, "Recv"),
41            CommunicationOpType::Barrier => write!(f, "Barrier"),
42            CommunicationOpType::AllToAll => write!(f, "AllToAll"),
43            CommunicationOpType::Custom(id) => write!(f, "Custom({})", id),
44        }
45    }
46}
47
48/// Individual communication event record
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct CommunicationEvent {
51    /// Unique event identifier
52    pub event_id: u64,
53    /// Type of communication operation
54    pub op_type: CommunicationOpType,
55    /// Rank of the process that initiated the operation
56    pub rank: u32,
57    /// World size at the time of operation
58    pub world_size: u32,
59    /// Size of data transferred in bytes
60    pub data_size_bytes: usize,
61    /// Start timestamp
62    pub start_time: SystemTime,
63    /// Duration of the operation
64    pub duration: Duration,
65    /// Bandwidth achieved (bytes per second)
66    pub bandwidth_bps: f64,
67    /// Additional metadata
68    pub metadata: HashMap<String, String>,
69}
70
71impl CommunicationEvent {
72    /// Create a new communication event
73    pub fn new(
74        event_id: u64,
75        op_type: CommunicationOpType,
76        rank: u32,
77        world_size: u32,
78        data_size_bytes: usize,
79        start_time: SystemTime,
80        duration: Duration,
81    ) -> Self {
82        let bandwidth_bps = if duration.as_secs_f64() > 0.0 {
83            data_size_bytes as f64 / duration.as_secs_f64()
84        } else {
85            0.0
86        };
87
88        Self {
89            event_id,
90            op_type,
91            rank,
92            world_size,
93            data_size_bytes,
94            start_time,
95            duration,
96            bandwidth_bps,
97            metadata: HashMap::new(),
98        }
99    }
100
101    /// Add metadata to the event
102    pub fn with_metadata(mut self, key: String, value: String) -> Self {
103        self.metadata.insert(key, value);
104        self
105    }
106
107    /// Get latency in milliseconds
108    pub fn latency_ms(&self) -> f64 {
109        self.duration.as_secs_f64() * 1000.0
110    }
111
112    /// Get bandwidth in MB/s
113    pub fn bandwidth_mbps(&self) -> f64 {
114        self.bandwidth_bps / (1024.0 * 1024.0)
115    }
116}
117
118/// Statistics for a specific communication operation type
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct OperationStats {
121    /// Total number of operations
122    pub count: u64,
123    /// Total data transferred (bytes)
124    pub total_bytes: u64,
125    /// Total time spent (duration)
126    pub total_duration: Duration,
127    /// Minimum latency observed
128    pub min_latency: Duration,
129    /// Maximum latency observed
130    pub max_latency: Duration,
131    /// Average latency
132    pub avg_latency: Duration,
133    /// Average bandwidth (bytes per second)
134    pub avg_bandwidth_bps: f64,
135    /// 95th percentile latency
136    pub p95_latency: Duration,
137    /// 99th percentile latency
138    pub p99_latency: Duration,
139}
140
141impl Default for OperationStats {
142    fn default() -> Self {
143        Self {
144            count: 0,
145            total_bytes: 0,
146            total_duration: Duration::ZERO,
147            min_latency: Duration::MAX,
148            max_latency: Duration::ZERO,
149            avg_latency: Duration::ZERO,
150            avg_bandwidth_bps: 0.0,
151            p95_latency: Duration::ZERO,
152            p99_latency: Duration::ZERO,
153        }
154    }
155}
156
157impl OperationStats {
158    /// Add a new event to the statistics
159    pub fn add_event(&mut self, event: &CommunicationEvent) {
160        self.count += 1;
161        self.total_bytes += event.data_size_bytes as u64;
162        self.total_duration += event.duration;
163
164        if event.duration < self.min_latency {
165            self.min_latency = event.duration;
166        }
167        if event.duration > self.max_latency {
168            self.max_latency = event.duration;
169        }
170
171        // Recalculate averages
172        self.avg_latency = self.total_duration / self.count as u32;
173        if self.total_duration.as_secs_f64() > 0.0 {
174            self.avg_bandwidth_bps = self.total_bytes as f64 / self.total_duration.as_secs_f64();
175        }
176    }
177
178    /// Calculate percentiles from a list of durations
179    pub fn calculate_percentiles(&mut self, durations: &mut [Duration]) {
180        if durations.is_empty() {
181            return;
182        }
183
184        durations.sort();
185        let len = durations.len();
186
187        let p95_idx = (len as f64 * 0.95).ceil() as usize - 1;
188        let p99_idx = (len as f64 * 0.99).ceil() as usize - 1;
189
190        self.p95_latency = durations[p95_idx.min(len - 1)];
191        self.p99_latency = durations[p99_idx.min(len - 1)];
192    }
193}
194
195/// Profiling configuration
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct ProfilingConfig {
198    /// Whether profiling is enabled
199    pub enabled: bool,
200    /// Maximum number of events to keep in memory
201    pub max_events: usize,
202    /// Whether to track detailed per-operation statistics
203    pub track_per_operation_stats: bool,
204    /// Whether to track per-rank statistics
205    pub track_per_rank_stats: bool,
206    /// Sampling rate (0.0 to 1.0, 1.0 means profile all operations)
207    pub sampling_rate: f64,
208    /// Minimum operation duration to record (microseconds)
209    pub min_duration_us: u64,
210}
211
212impl Default for ProfilingConfig {
213    fn default() -> Self {
214        Self {
215            enabled: true,
216            max_events: 10000,
217            track_per_operation_stats: true,
218            track_per_rank_stats: true,
219            sampling_rate: 1.0,
220            min_duration_us: 0,
221        }
222    }
223}
224
225/// Thread-safe communication profiler
226pub struct CommunicationProfiler {
227    /// Configuration
228    config: RwLock<ProfilingConfig>,
229    /// Event counter for unique IDs
230    event_counter: Mutex<u64>,
231    /// Circular buffer of recent events
232    events: Mutex<Vec<CommunicationEvent>>,
233    /// Statistics per operation type
234    operation_stats: RwLock<HashMap<CommunicationOpType, OperationStats>>,
235    /// Statistics per rank
236    rank_stats: RwLock<HashMap<u32, HashMap<CommunicationOpType, OperationStats>>>,
237    /// Global start time for relative timestamps
238    start_time: SystemTime,
239}
240
241impl CommunicationProfiler {
242    /// Create a new profiler with default configuration
243    pub fn new() -> Self {
244        Self::with_config(ProfilingConfig::default())
245    }
246
247    /// Create a new profiler with custom configuration
248    pub fn with_config(config: ProfilingConfig) -> Self {
249        Self {
250            config: RwLock::new(config),
251            event_counter: Mutex::new(0),
252            events: Mutex::new(Vec::new()),
253            operation_stats: RwLock::new(HashMap::new()),
254            rank_stats: RwLock::new(HashMap::new()),
255            start_time: SystemTime::now(),
256        }
257    }
258
259    /// Start timing a communication operation
260    pub fn start_timing(&self) -> ProfilingTimer {
261        ProfilingTimer::new()
262    }
263
264    /// Record a communication event
265    pub fn record_event(
266        &self,
267        op_type: CommunicationOpType,
268        rank: u32,
269        world_size: u32,
270        data_size_bytes: usize,
271        timer: ProfilingTimer,
272    ) -> TorshResult<()> {
273        let config = self
274            .config
275            .read()
276            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
277
278        if !config.enabled {
279            return Ok(());
280        }
281
282        let duration = timer.elapsed();
283
284        // Skip if duration is below threshold
285        if duration.as_micros() < config.min_duration_us as u128 {
286            return Ok(());
287        }
288
289        // Apply sampling
290        if config.sampling_rate < 1.0 {
291            use std::collections::hash_map::DefaultHasher;
292            use std::hash::{Hash, Hasher};
293
294            let mut hasher = DefaultHasher::new();
295            (
296                rank,
297                SystemTime::now()
298                    .duration_since(UNIX_EPOCH)
299                    .unwrap_or_default()
300                    .as_nanos(),
301            )
302                .hash(&mut hasher);
303            let hash_val = hasher.finish();
304            let sample_threshold = (u64::MAX as f64 * config.sampling_rate) as u64;
305
306            if hash_val > sample_threshold {
307                return Ok(());
308            }
309        }
310
311        // Generate unique event ID
312        let event_id = {
313            let mut counter = self
314                .event_counter
315                .lock()
316                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
317            *counter += 1;
318            *counter
319        };
320
321        // Create event
322        let event = CommunicationEvent::new(
323            event_id,
324            op_type,
325            rank,
326            world_size,
327            data_size_bytes,
328            timer.start_time,
329            duration,
330        );
331
332        // Store event
333        {
334            let mut events = self
335                .events
336                .lock()
337                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
338            events.push(event.clone());
339
340            // Maintain circular buffer
341            if events.len() > config.max_events {
342                events.remove(0);
343            }
344        }
345
346        // Update statistics
347        if config.track_per_operation_stats {
348            let mut stats = self
349                .operation_stats
350                .write()
351                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
352            stats.entry(op_type).or_default().add_event(&event);
353        }
354
355        if config.track_per_rank_stats {
356            let mut rank_stats = self
357                .rank_stats
358                .write()
359                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
360            rank_stats
361                .entry(rank)
362                .or_default()
363                .entry(op_type)
364                .or_default()
365                .add_event(&event);
366        }
367
368        Ok(())
369    }
370
371    /// Get statistics for a specific operation type
372    pub fn get_operation_stats(
373        &self,
374        op_type: CommunicationOpType,
375    ) -> TorshResult<Option<OperationStats>> {
376        let stats = self
377            .operation_stats
378            .read()
379            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
380        Ok(stats.get(&op_type).cloned())
381    }
382
383    /// Get all operation statistics
384    pub fn get_all_operation_stats(
385        &self,
386    ) -> TorshResult<HashMap<CommunicationOpType, OperationStats>> {
387        let stats = self
388            .operation_stats
389            .read()
390            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
391        Ok(stats.clone())
392    }
393
394    /// Get statistics for a specific rank
395    pub fn get_rank_stats(
396        &self,
397        rank: u32,
398    ) -> TorshResult<Option<HashMap<CommunicationOpType, OperationStats>>> {
399        let rank_stats = self
400            .rank_stats
401            .read()
402            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
403        Ok(rank_stats.get(&rank).cloned())
404    }
405
406    /// Get recent events (last N events)
407    pub fn get_recent_events(&self, count: usize) -> TorshResult<Vec<CommunicationEvent>> {
408        let events = self
409            .events
410            .lock()
411            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
412        let start_idx = events.len().saturating_sub(count);
413        Ok(events[start_idx..].to_vec())
414    }
415
416    /// Get all events
417    pub fn get_all_events(&self) -> TorshResult<Vec<CommunicationEvent>> {
418        let events = self
419            .events
420            .lock()
421            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
422        Ok(events.clone())
423    }
424
425    /// Get the count of failed operations across all ranks and operation types
426    pub fn get_failed_operations_count(&self) -> u64 {
427        let events = match self.events.lock() {
428            Ok(events) => events,
429            Err(_) => return 0, // Return 0 if lock is poisoned
430        };
431
432        // Count events that indicate failures (placeholder implementation)
433        // In a real implementation, you would track operation success/failure explicitly
434        events
435            .iter()
436            .filter(|event| {
437                // Consider events with very high latency as potential failures
438                // This is a heuristic approach for demonstration
439                event.duration.as_millis() > 10000 || event.metadata.contains_key("error")
440            })
441            .count() as u64
442    }
443
444    /// Clear all profiling data
445    pub fn clear(&self) -> TorshResult<()> {
446        {
447            let mut events = self
448                .events
449                .lock()
450                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
451            events.clear();
452        }
453
454        {
455            let mut stats = self
456                .operation_stats
457                .write()
458                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
459            stats.clear();
460        }
461
462        {
463            let mut rank_stats = self
464                .rank_stats
465                .write()
466                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
467            rank_stats.clear();
468        }
469
470        {
471            let mut counter = self
472                .event_counter
473                .lock()
474                .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
475            *counter = 0;
476        }
477
478        Ok(())
479    }
480
481    /// Update configuration
482    pub fn update_config(&self, config: ProfilingConfig) -> TorshResult<()> {
483        let mut current_config = self
484            .config
485            .write()
486            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
487        *current_config = config;
488        Ok(())
489    }
490
491    /// Export profiling data to JSON
492    pub fn export_json(&self) -> TorshResult<String> {
493        #[derive(Serialize)]
494        struct ExportData {
495            config: ProfilingConfig,
496            events: Vec<CommunicationEvent>,
497            operation_stats: HashMap<CommunicationOpType, OperationStats>,
498            rank_stats: HashMap<u32, HashMap<CommunicationOpType, OperationStats>>,
499        }
500
501        let config = self
502            .config
503            .read()
504            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?
505            .clone();
506        let events = self.get_all_events()?;
507        let operation_stats = self.get_all_operation_stats()?;
508        let rank_stats = self
509            .rank_stats
510            .read()
511            .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?
512            .clone();
513
514        let export_data = ExportData {
515            config,
516            events,
517            operation_stats,
518            rank_stats,
519        };
520
521        serde_json::to_string_pretty(&export_data).map_err(|e| {
522            TorshDistributedError::backend_error(
523                "profiling",
524                format!("JSON serialization failed: {}", e),
525            )
526        })
527    }
528
529    /// Generate a summary report
530    pub fn generate_summary(&self) -> TorshResult<String> {
531        let mut report = String::new();
532        report.push_str("=== Communication Profiling Summary ===\n\n");
533
534        let events = self.get_all_events()?;
535        let operation_stats = self.get_all_operation_stats()?;
536
537        report.push_str(&format!("Total Events: {}\n", events.len()));
538        report.push_str(&format!(
539            "Profiling Duration: {:.2}s\n\n",
540            SystemTime::now()
541                .duration_since(self.start_time)
542                .unwrap_or_default()
543                .as_secs_f64()
544        ));
545
546        report.push_str("=== Per-Operation Statistics ===\n");
547        for (op_type, stats) in operation_stats.iter() {
548            report.push_str(&format!("\n{} Operations:\n", op_type));
549            report.push_str(&format!("  Count: {}\n", stats.count));
550            report.push_str(&format!(
551                "  Total Data: {:.2} MB\n",
552                stats.total_bytes as f64 / (1024.0 * 1024.0)
553            ));
554            report.push_str(&format!(
555                "  Avg Latency: {:.2} ms\n",
556                stats.avg_latency.as_secs_f64() * 1000.0
557            ));
558            report.push_str(&format!(
559                "  Min Latency: {:.2} ms\n",
560                stats.min_latency.as_secs_f64() * 1000.0
561            ));
562            report.push_str(&format!(
563                "  Max Latency: {:.2} ms\n",
564                stats.max_latency.as_secs_f64() * 1000.0
565            ));
566            report.push_str(&format!(
567                "  Avg Bandwidth: {:.2} MB/s\n",
568                stats.avg_bandwidth_bps / (1024.0 * 1024.0)
569            ));
570        }
571
572        Ok(report)
573    }
574}
575
576impl Default for CommunicationProfiler {
577    fn default() -> Self {
578        Self::new()
579    }
580}
581
/// Timer for measuring communication operation duration.
///
/// Two clocks are captured at construction: a wall-clock `SystemTime` that
/// serves as the event's start timestamp, and a monotonic `Instant` that is
/// used to measure the elapsed duration.
#[derive(Debug)]
pub struct ProfilingTimer {
    /// Wall-clock time at which the timer was created
    start_time: SystemTime,
    /// Monotonic reference point used by `elapsed`
    start_instant: Instant,
}

impl Default for ProfilingTimer {
    /// Equivalent to [`ProfilingTimer::new`]: the timer starts immediately.
    fn default() -> Self {
        Self::new()
    }
}

impl ProfilingTimer {
    /// Create a new timer and start timing
    pub fn new() -> Self {
        Self {
            start_time: SystemTime::now(),
            start_instant: Instant::now(),
        }
    }

    /// Get the duration elapsed since the timer was created
    pub fn elapsed(&self) -> Duration {
        self.start_instant.elapsed()
    }

    /// Get the wall-clock time at which the timer was created
    pub fn start_time(&self) -> SystemTime {
        self.start_time
    }
}
613
614/// Global profiler instance
615static GLOBAL_PROFILER: std::sync::OnceLock<Arc<CommunicationProfiler>> =
616    std::sync::OnceLock::new();
617
618/// Get the global profiler instance
619pub fn get_global_profiler() -> &'static Arc<CommunicationProfiler> {
620    GLOBAL_PROFILER.get_or_init(|| Arc::new(CommunicationProfiler::new()))
621}
622
623/// Initialize the global profiler with custom configuration
624pub fn init_global_profiler(config: ProfilingConfig) -> TorshResult<()> {
625    let profiler = Arc::new(CommunicationProfiler::with_config(config));
626    GLOBAL_PROFILER.set(profiler).map_err(|_| {
627        TorshDistributedError::backend_error("profiling", "Global profiler already initialized")
628    })?;
629    Ok(())
630}
631
/// Convenience macro for profiling communication operations
///
/// Runs `$code` while timing it with the global profiler, records one event
/// with the given operation type, rank, world size and data size, and
/// evaluates to the value of `$code`.
#[macro_export]
macro_rules! profile_communication {
    ($op_type:expr, $rank:expr, $world_size:expr, $data_size:expr, $code:block) => {{
        let global = $crate::profiling::get_global_profiler();
        let stopwatch = global.start_timing();
        let output = $code;
        // Recording is best-effort: a profiler error must not affect the
        // result of the profiled code, so it is deliberately discarded.
        let _ = global.record_event($op_type, $rank, $world_size, $data_size, stopwatch);
        output
    }};
}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a timer, sleeps for `pause`, then records the operation on
    /// `profiler` as rank 0 in a world of 4.
    fn record_after_sleep(
        profiler: &CommunicationProfiler,
        op_type: CommunicationOpType,
        data_size_bytes: usize,
        pause: Duration,
    ) {
        let timer = profiler.start_timing();
        std::thread::sleep(pause);
        profiler
            .record_event(op_type, 0, 4, data_size_bytes, timer)
            .unwrap();
    }

    /// A fresh profiler has no statistics.
    #[test]
    fn test_profiler_creation() {
        let stats = CommunicationProfiler::new()
            .get_all_operation_stats()
            .unwrap();
        assert!(stats.is_empty());
    }

    /// A recorded event is stored with its operation type and size.
    #[test]
    fn test_event_recording() {
        let profiler = CommunicationProfiler::new();
        record_after_sleep(
            &profiler,
            CommunicationOpType::AllReduce,
            1024,
            Duration::from_millis(10),
        );

        let events = profiler.get_all_events().unwrap();
        assert_eq!(events.len(), 1);

        let recorded = &events[0];
        assert_eq!(recorded.op_type, CommunicationOpType::AllReduce);
        assert_eq!(recorded.data_size_bytes, 1024);
    }

    /// Count and byte totals accumulate across several events.
    #[test]
    fn test_operation_stats() {
        let profiler = CommunicationProfiler::new();

        // Five events of 1 KiB, 2 KiB, ... 5 KiB.
        for multiplier in 1..=5 {
            record_after_sleep(
                &profiler,
                CommunicationOpType::AllReduce,
                1024 * multiplier,
                Duration::from_millis(1),
            );
        }

        let maybe_stats = profiler
            .get_operation_stats(CommunicationOpType::AllReduce)
            .unwrap();
        assert!(maybe_stats.is_some());

        let stats = maybe_stats.unwrap();
        assert_eq!(stats.count, 5);
        assert_eq!(stats.total_bytes, 1024 + 2048 + 3072 + 4096 + 5120);
    }

    /// The macro yields the block's value and records on the global profiler.
    #[test]
    fn test_profiler_macro() {
        let result = profile_communication!(CommunicationOpType::Broadcast, 0, 4, 2048, {
            std::thread::sleep(Duration::from_millis(5));
            42
        });
        assert_eq!(result, 42);

        let recorded = get_global_profiler().get_all_events().unwrap();
        assert!(!recorded.is_empty());
    }

    /// The JSON export mentions the recorded operation and its size.
    #[test]
    fn test_json_export() {
        let profiler = CommunicationProfiler::new();
        record_after_sleep(
            &profiler,
            CommunicationOpType::AllGather,
            512,
            Duration::from_millis(1),
        );

        let json = profiler.export_json().unwrap();
        for needle in ["AllGather", "512"] {
            assert!(json.contains(needle));
        }
    }

    /// The summary contains its header and a section for the operation.
    #[test]
    fn test_summary_generation() {
        let profiler = CommunicationProfiler::new();
        record_after_sleep(
            &profiler,
            CommunicationOpType::Reduce,
            256,
            Duration::from_millis(1),
        );

        let summary = profiler.generate_summary().unwrap();
        for needle in ["Communication Profiling Summary", "Reduce Operations"] {
            assert!(summary.contains(needle));
        }
    }
}