1use crate::{TorshDistributedError, TorshResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex, RwLock};
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
/// Tag identifying which communication operation a profiling event
/// belongs to. Used as a key in the per-operation statistics tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CommunicationOpType {
    AllReduce,
    AllGather,
    ReduceScatter,
    Broadcast,
    Reduce,
    Scatter,
    Gather,
    Send,
    Recv,
    Barrier,
    AllToAll,
    /// Backend- or user-defined operation identified by a numeric id.
    Custom(u32),
}
28
29impl std::fmt::Display for CommunicationOpType {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 CommunicationOpType::AllReduce => write!(f, "AllReduce"),
33 CommunicationOpType::AllGather => write!(f, "AllGather"),
34 CommunicationOpType::ReduceScatter => write!(f, "ReduceScatter"),
35 CommunicationOpType::Broadcast => write!(f, "Broadcast"),
36 CommunicationOpType::Reduce => write!(f, "Reduce"),
37 CommunicationOpType::Scatter => write!(f, "Scatter"),
38 CommunicationOpType::Gather => write!(f, "Gather"),
39 CommunicationOpType::Send => write!(f, "Send"),
40 CommunicationOpType::Recv => write!(f, "Recv"),
41 CommunicationOpType::Barrier => write!(f, "Barrier"),
42 CommunicationOpType::AllToAll => write!(f, "AllToAll"),
43 CommunicationOpType::Custom(id) => write!(f, "Custom({})", id),
44 }
45 }
46}
47
/// One recorded communication operation, with timing and throughput data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationEvent {
    /// Monotonically increasing id assigned by the profiler.
    pub event_id: u64,
    /// Which kind of operation this event records.
    pub op_type: CommunicationOpType,
    /// Rank that recorded the event.
    pub rank: u32,
    /// World size at the time of the operation.
    pub world_size: u32,
    /// Payload size of the operation, in bytes.
    pub data_size_bytes: usize,
    /// Wall-clock time when the operation started.
    pub start_time: SystemTime,
    /// Measured duration of the operation.
    pub duration: Duration,
    /// Derived throughput in *bytes* per second (0.0 for zero-duration events).
    pub bandwidth_bps: f64,
    /// Free-form key/value annotations (e.g. an "error" key marks failures).
    pub metadata: HashMap<String, String>,
}
70
71impl CommunicationEvent {
72 pub fn new(
74 event_id: u64,
75 op_type: CommunicationOpType,
76 rank: u32,
77 world_size: u32,
78 data_size_bytes: usize,
79 start_time: SystemTime,
80 duration: Duration,
81 ) -> Self {
82 let bandwidth_bps = if duration.as_secs_f64() > 0.0 {
83 data_size_bytes as f64 / duration.as_secs_f64()
84 } else {
85 0.0
86 };
87
88 Self {
89 event_id,
90 op_type,
91 rank,
92 world_size,
93 data_size_bytes,
94 start_time,
95 duration,
96 bandwidth_bps,
97 metadata: HashMap::new(),
98 }
99 }
100
101 pub fn with_metadata(mut self, key: String, value: String) -> Self {
103 self.metadata.insert(key, value);
104 self
105 }
106
107 pub fn latency_ms(&self) -> f64 {
109 self.duration.as_secs_f64() * 1000.0
110 }
111
112 pub fn bandwidth_mbps(&self) -> f64 {
114 self.bandwidth_bps / (1024.0 * 1024.0)
115 }
116}
117
/// Running aggregate statistics for one operation type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationStats {
    /// Number of events folded in so far.
    pub count: u64,
    /// Sum of `data_size_bytes` across all events.
    pub total_bytes: u64,
    /// Sum of event durations.
    pub total_duration: Duration,
    /// Smallest observed latency (`Duration::MAX` until the first event).
    pub min_latency: Duration,
    /// Largest observed latency.
    pub max_latency: Duration,
    /// `total_duration / count`, maintained incrementally.
    pub avg_latency: Duration,
    /// Aggregate throughput: total bytes over total time, bytes/sec.
    pub avg_bandwidth_bps: f64,
    /// 95th-percentile latency; only set by `calculate_percentiles`.
    pub p95_latency: Duration,
    /// 99th-percentile latency; only set by `calculate_percentiles`.
    pub p99_latency: Duration,
}
140
141impl Default for OperationStats {
142 fn default() -> Self {
143 Self {
144 count: 0,
145 total_bytes: 0,
146 total_duration: Duration::ZERO,
147 min_latency: Duration::MAX,
148 max_latency: Duration::ZERO,
149 avg_latency: Duration::ZERO,
150 avg_bandwidth_bps: 0.0,
151 p95_latency: Duration::ZERO,
152 p99_latency: Duration::ZERO,
153 }
154 }
155}
156
157impl OperationStats {
158 pub fn add_event(&mut self, event: &CommunicationEvent) {
160 self.count += 1;
161 self.total_bytes += event.data_size_bytes as u64;
162 self.total_duration += event.duration;
163
164 if event.duration < self.min_latency {
165 self.min_latency = event.duration;
166 }
167 if event.duration > self.max_latency {
168 self.max_latency = event.duration;
169 }
170
171 self.avg_latency = self.total_duration / self.count as u32;
173 if self.total_duration.as_secs_f64() > 0.0 {
174 self.avg_bandwidth_bps = self.total_bytes as f64 / self.total_duration.as_secs_f64();
175 }
176 }
177
178 pub fn calculate_percentiles(&mut self, durations: &mut [Duration]) {
180 if durations.is_empty() {
181 return;
182 }
183
184 durations.sort();
185 let len = durations.len();
186
187 let p95_idx = (len as f64 * 0.95).ceil() as usize - 1;
188 let p99_idx = (len as f64 * 0.99).ceil() as usize - 1;
189
190 self.p95_latency = durations[p95_idx.min(len - 1)];
191 self.p99_latency = durations[p99_idx.min(len - 1)];
192 }
193}
194
/// Runtime-tunable profiler configuration (see `CommunicationProfiler::update_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingConfig {
    /// Master switch; when false, `record_event` is a no-op.
    pub enabled: bool,
    /// Capacity of the event log; the oldest event is evicted when exceeded.
    pub max_events: usize,
    /// Maintain per-operation-type aggregates.
    pub track_per_operation_stats: bool,
    /// Maintain per-rank (and per-operation within rank) aggregates.
    pub track_per_rank_stats: bool,
    /// Fraction of events to keep, in [0.0, 1.0]; 1.0 disables sampling.
    pub sampling_rate: f64,
    /// Events shorter than this many microseconds are dropped.
    pub min_duration_us: u64,
}
211
212impl Default for ProfilingConfig {
213 fn default() -> Self {
214 Self {
215 enabled: true,
216 max_events: 10000,
217 track_per_operation_stats: true,
218 track_per_rank_stats: true,
219 sampling_rate: 1.0,
220 min_duration_us: 0,
221 }
222 }
223}
224
/// Thread-safe collector of communication events and aggregate statistics.
pub struct CommunicationProfiler {
    /// Active configuration; RwLock because reads (per event) vastly
    /// outnumber writes (config updates).
    config: RwLock<ProfilingConfig>,
    /// Source of monotonically increasing event ids.
    event_counter: Mutex<u64>,
    /// Bounded event log; oldest entries evicted past `config.max_events`.
    events: Mutex<Vec<CommunicationEvent>>,
    /// Aggregates keyed by operation type.
    operation_stats: RwLock<HashMap<CommunicationOpType, OperationStats>>,
    /// Aggregates keyed by rank, then operation type.
    rank_stats: RwLock<HashMap<u32, HashMap<CommunicationOpType, OperationStats>>>,
    /// Creation time; used by `generate_summary` to report profiling duration.
    start_time: SystemTime,
}
240
241impl CommunicationProfiler {
242 pub fn new() -> Self {
244 Self::with_config(ProfilingConfig::default())
245 }
246
247 pub fn with_config(config: ProfilingConfig) -> Self {
249 Self {
250 config: RwLock::new(config),
251 event_counter: Mutex::new(0),
252 events: Mutex::new(Vec::new()),
253 operation_stats: RwLock::new(HashMap::new()),
254 rank_stats: RwLock::new(HashMap::new()),
255 start_time: SystemTime::now(),
256 }
257 }
258
259 pub fn start_timing(&self) -> ProfilingTimer {
261 ProfilingTimer::new()
262 }
263
264 pub fn record_event(
266 &self,
267 op_type: CommunicationOpType,
268 rank: u32,
269 world_size: u32,
270 data_size_bytes: usize,
271 timer: ProfilingTimer,
272 ) -> TorshResult<()> {
273 let config = self
274 .config
275 .read()
276 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
277
278 if !config.enabled {
279 return Ok(());
280 }
281
282 let duration = timer.elapsed();
283
284 if duration.as_micros() < config.min_duration_us as u128 {
286 return Ok(());
287 }
288
289 if config.sampling_rate < 1.0 {
291 use std::collections::hash_map::DefaultHasher;
292 use std::hash::{Hash, Hasher};
293
294 let mut hasher = DefaultHasher::new();
295 (
296 rank,
297 SystemTime::now()
298 .duration_since(UNIX_EPOCH)
299 .unwrap_or_default()
300 .as_nanos(),
301 )
302 .hash(&mut hasher);
303 let hash_val = hasher.finish();
304 let sample_threshold = (u64::MAX as f64 * config.sampling_rate) as u64;
305
306 if hash_val > sample_threshold {
307 return Ok(());
308 }
309 }
310
311 let event_id = {
313 let mut counter = self
314 .event_counter
315 .lock()
316 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
317 *counter += 1;
318 *counter
319 };
320
321 let event = CommunicationEvent::new(
323 event_id,
324 op_type,
325 rank,
326 world_size,
327 data_size_bytes,
328 timer.start_time,
329 duration,
330 );
331
332 {
334 let mut events = self
335 .events
336 .lock()
337 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
338 events.push(event.clone());
339
340 if events.len() > config.max_events {
342 events.remove(0);
343 }
344 }
345
346 if config.track_per_operation_stats {
348 let mut stats = self
349 .operation_stats
350 .write()
351 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
352 stats.entry(op_type).or_default().add_event(&event);
353 }
354
355 if config.track_per_rank_stats {
356 let mut rank_stats = self
357 .rank_stats
358 .write()
359 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
360 rank_stats
361 .entry(rank)
362 .or_default()
363 .entry(op_type)
364 .or_default()
365 .add_event(&event);
366 }
367
368 Ok(())
369 }
370
371 pub fn get_operation_stats(
373 &self,
374 op_type: CommunicationOpType,
375 ) -> TorshResult<Option<OperationStats>> {
376 let stats = self
377 .operation_stats
378 .read()
379 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
380 Ok(stats.get(&op_type).cloned())
381 }
382
383 pub fn get_all_operation_stats(
385 &self,
386 ) -> TorshResult<HashMap<CommunicationOpType, OperationStats>> {
387 let stats = self
388 .operation_stats
389 .read()
390 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
391 Ok(stats.clone())
392 }
393
394 pub fn get_rank_stats(
396 &self,
397 rank: u32,
398 ) -> TorshResult<Option<HashMap<CommunicationOpType, OperationStats>>> {
399 let rank_stats = self
400 .rank_stats
401 .read()
402 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
403 Ok(rank_stats.get(&rank).cloned())
404 }
405
406 pub fn get_recent_events(&self, count: usize) -> TorshResult<Vec<CommunicationEvent>> {
408 let events = self
409 .events
410 .lock()
411 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
412 let start_idx = events.len().saturating_sub(count);
413 Ok(events[start_idx..].to_vec())
414 }
415
416 pub fn get_all_events(&self) -> TorshResult<Vec<CommunicationEvent>> {
418 let events = self
419 .events
420 .lock()
421 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
422 Ok(events.clone())
423 }
424
425 pub fn get_failed_operations_count(&self) -> u64 {
427 let events = match self.events.lock() {
428 Ok(events) => events,
429 Err(_) => return 0, };
431
432 events
435 .iter()
436 .filter(|event| {
437 event.duration.as_millis() > 10000 || event.metadata.contains_key("error")
440 })
441 .count() as u64
442 }
443
444 pub fn clear(&self) -> TorshResult<()> {
446 {
447 let mut events = self
448 .events
449 .lock()
450 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
451 events.clear();
452 }
453
454 {
455 let mut stats = self
456 .operation_stats
457 .write()
458 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
459 stats.clear();
460 }
461
462 {
463 let mut rank_stats = self
464 .rank_stats
465 .write()
466 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
467 rank_stats.clear();
468 }
469
470 {
471 let mut counter = self
472 .event_counter
473 .lock()
474 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
475 *counter = 0;
476 }
477
478 Ok(())
479 }
480
481 pub fn update_config(&self, config: ProfilingConfig) -> TorshResult<()> {
483 let mut current_config = self
484 .config
485 .write()
486 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?;
487 *current_config = config;
488 Ok(())
489 }
490
491 pub fn export_json(&self) -> TorshResult<String> {
493 #[derive(Serialize)]
494 struct ExportData {
495 config: ProfilingConfig,
496 events: Vec<CommunicationEvent>,
497 operation_stats: HashMap<CommunicationOpType, OperationStats>,
498 rank_stats: HashMap<u32, HashMap<CommunicationOpType, OperationStats>>,
499 }
500
501 let config = self
502 .config
503 .read()
504 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?
505 .clone();
506 let events = self.get_all_events()?;
507 let operation_stats = self.get_all_operation_stats()?;
508 let rank_stats = self
509 .rank_stats
510 .read()
511 .map_err(|_| TorshDistributedError::backend_error("profiling", "Lock poisoned"))?
512 .clone();
513
514 let export_data = ExportData {
515 config,
516 events,
517 operation_stats,
518 rank_stats,
519 };
520
521 serde_json::to_string_pretty(&export_data).map_err(|e| {
522 TorshDistributedError::backend_error(
523 "profiling",
524 format!("JSON serialization failed: {}", e),
525 )
526 })
527 }
528
    /// Builds a human-readable text report: total event count, elapsed
    /// profiling time since construction, and per-operation aggregates.
    ///
    /// # Errors
    /// Propagates lock-poisoning errors from the underlying getters.
    pub fn generate_summary(&self) -> TorshResult<String> {
        let mut report = String::new();
        report.push_str("=== Communication Profiling Summary ===\n\n");

        let events = self.get_all_events()?;
        let operation_stats = self.get_all_operation_stats()?;

        report.push_str(&format!("Total Events: {}\n", events.len()));
        report.push_str(&format!(
            "Profiling Duration: {:.2}s\n\n",
            // Wall-clock time since the profiler was created; clamps to 0
            // (unwrap_or_default) if the system clock stepped backwards.
            SystemTime::now()
                .duration_since(self.start_time)
                .unwrap_or_default()
                .as_secs_f64()
        ));

        report.push_str("=== Per-Operation Statistics ===\n");
        // HashMap iteration: section order is unspecified run-to-run.
        for (op_type, stats) in operation_stats.iter() {
            report.push_str(&format!("\n{} Operations:\n", op_type));
            report.push_str(&format!(" Count: {}\n", stats.count));
            report.push_str(&format!(
                " Total Data: {:.2} MB\n",
                stats.total_bytes as f64 / (1024.0 * 1024.0)
            ));
            report.push_str(&format!(
                " Avg Latency: {:.2} ms\n",
                stats.avg_latency.as_secs_f64() * 1000.0
            ));
            report.push_str(&format!(
                " Min Latency: {:.2} ms\n",
                stats.min_latency.as_secs_f64() * 1000.0
            ));
            report.push_str(&format!(
                " Max Latency: {:.2} ms\n",
                stats.max_latency.as_secs_f64() * 1000.0
            ));
            report.push_str(&format!(
                " Avg Bandwidth: {:.2} MB/s\n",
                stats.avg_bandwidth_bps / (1024.0 * 1024.0)
            ));
        }

        Ok(report)
    }
574}
575
576impl Default for CommunicationProfiler {
577 fn default() -> Self {
578 Self::new()
579 }
580}
581
/// Timer handed out by `CommunicationProfiler::start_timing`.
///
/// Captures both clocks at construction: wall-clock (`SystemTime`) for the
/// event's timestamp, and monotonic (`Instant`) so elapsed measurement is
/// immune to system-clock adjustments.
pub struct ProfilingTimer {
    /// Wall-clock time when the timer was created.
    start_time: SystemTime,
    /// Monotonic reference used by `elapsed`.
    start_instant: Instant,
}
587
588impl Default for ProfilingTimer {
589 fn default() -> Self {
590 Self::new()
591 }
592}
593
594impl ProfilingTimer {
595 pub fn new() -> Self {
597 Self {
598 start_time: SystemTime::now(),
599 start_instant: Instant::now(),
600 }
601 }
602
603 pub fn elapsed(&self) -> Duration {
605 self.start_instant.elapsed()
606 }
607
608 pub fn start_time(&self) -> SystemTime {
610 self.start_time
611 }
612}
613
/// Process-wide profiler instance, created lazily on first access and
/// settable exactly once (see `init_global_profiler`).
static GLOBAL_PROFILER: std::sync::OnceLock<Arc<CommunicationProfiler>> =
    std::sync::OnceLock::new();
617
618pub fn get_global_profiler() -> &'static Arc<CommunicationProfiler> {
620 GLOBAL_PROFILER.get_or_init(|| Arc::new(CommunicationProfiler::new()))
621}
622
623pub fn init_global_profiler(config: ProfilingConfig) -> TorshResult<()> {
625 let profiler = Arc::new(CommunicationProfiler::with_config(config));
626 GLOBAL_PROFILER.set(profiler).map_err(|_| {
627 TorshDistributedError::backend_error("profiling", "Global profiler already initialized")
628 })?;
629 Ok(())
630}
631
/// Times `$code` and records it on the global profiler.
///
/// Recording errors are deliberately discarded (`let _`): profiling must
/// never change the outcome of the wrapped communication code, whose
/// result is returned unchanged.
#[macro_export]
macro_rules! profile_communication {
    ($op_type:expr, $rank:expr, $world_size:expr, $data_size:expr, $code:block) => {{
        let profiler = $crate::profiling::get_global_profiler();
        let timer = profiler.start_timing();
        let result = $code;
        let _ = profiler.record_event($op_type, $rank, $world_size, $data_size, timer);
        result
    }};
}
643
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh profiler starts with no per-operation aggregates.
    #[test]
    fn test_profiler_creation() {
        let profiler = CommunicationProfiler::new();
        let stats = profiler.get_all_operation_stats().unwrap();
        assert!(stats.is_empty());
    }

    // A recorded event lands in the log with its op type and payload size.
    #[test]
    fn test_event_recording() {
        let profiler = CommunicationProfiler::new();
        let timer = profiler.start_timing();
        std::thread::sleep(Duration::from_millis(10));

        profiler
            .record_event(CommunicationOpType::AllReduce, 0, 4, 1024, timer)
            .unwrap();

        let events = profiler.get_all_events().unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].op_type, CommunicationOpType::AllReduce);
        assert_eq!(events[0].data_size_bytes, 1024);
    }

    // Aggregates accumulate across multiple events of the same op type.
    #[test]
    fn test_operation_stats() {
        let profiler = CommunicationProfiler::new();

        for i in 0..5 {
            let timer = profiler.start_timing();
            std::thread::sleep(Duration::from_millis(1));
            profiler
                .record_event(CommunicationOpType::AllReduce, 0, 4, 1024 * (i + 1), timer)
                .unwrap();
        }

        let stats = profiler
            .get_operation_stats(CommunicationOpType::AllReduce)
            .unwrap();
        assert!(stats.is_some());
        let stats = stats.unwrap();
        assert_eq!(stats.count, 5);
        assert_eq!(stats.total_bytes, 1024 + 2048 + 3072 + 4096 + 5120);
    }

    // The macro returns the wrapped block's value and records on the
    // *global* profiler — note this shares state with any other test that
    // touches the global instance.
    #[test]
    fn test_profiler_macro() {
        let result = profile_communication!(CommunicationOpType::Broadcast, 0, 4, 2048, {
            std::thread::sleep(Duration::from_millis(5));
            42
        });

        assert_eq!(result, 42);

        let profiler = get_global_profiler();
        let events = profiler.get_all_events().unwrap();
        assert!(!events.is_empty());
    }

    // JSON export includes recorded op names and payload sizes.
    #[test]
    fn test_json_export() {
        let profiler = CommunicationProfiler::new();
        let timer = profiler.start_timing();
        std::thread::sleep(Duration::from_millis(1));

        profiler
            .record_event(CommunicationOpType::AllGather, 0, 4, 512, timer)
            .unwrap();

        let json = profiler.export_json().unwrap();
        assert!(json.contains("AllGather"));
        assert!(json.contains("512"));
    }

    // The text summary names the report and each recorded op section.
    #[test]
    fn test_summary_generation() {
        let profiler = CommunicationProfiler::new();
        let timer = profiler.start_timing();
        std::thread::sleep(Duration::from_millis(1));

        profiler
            .record_event(CommunicationOpType::Reduce, 0, 4, 256, timer)
            .unwrap();

        let summary = profiler.generate_summary().unwrap();
        assert!(summary.contains("Communication Profiling Summary"));
        assert!(summary.contains("Reduce Operations"));
    }
}