use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use super::config::RankMapping;
11
/// Collects timing, memory, and communication metrics for a 3D-parallel
/// (data / tensor / pipeline) training run and derives efficiency analyses
/// and bottleneck reports from them.
///
/// All trackers are behind `Arc<Mutex<_>>` so the monitor can be shared and
/// updated from multiple recording sites.
pub struct Performance3DMonitor {
    // Static description of the DP/TP/PP layout this monitor observes.
    rank_mapping: RankMapping,
    // Aggregate counters: pass counts, token totals, cumulative times.
    stats: Arc<Mutex<Performance3DStats>>,
    // Rolling per-pass timing samples plus wall-clock tracking.
    timing_history: Arc<Mutex<TimingHistory>>,
    // Memory-usage history with peak/average tracking.
    memory_tracker: Arc<Mutex<MemoryTracker>>,
    // Per-category communication durations and byte counts.
    communication_metrics: Arc<Mutex<CommunicationMetrics>>,
}
25
26impl Performance3DMonitor {
27 pub fn new(rank_mapping: &RankMapping) -> Self {
29 Self {
30 rank_mapping: rank_mapping.clone(),
31 stats: Arc::new(Mutex::new(Performance3DStats::new())),
32 timing_history: Arc::new(Mutex::new(TimingHistory::new())),
33 memory_tracker: Arc::new(Mutex::new(MemoryTracker::new())),
34 communication_metrics: Arc::new(Mutex::new(CommunicationMetrics::new())),
35 }
36 }
37
38 pub async fn record_forward_pass(&self, duration: Duration, num_tokens: usize) {
40 let mut stats = self.stats.lock().expect("lock should not be poisoned");
41 stats.forward_passes += 1;
42 stats.total_forward_time += duration;
43 stats.total_tokens_processed += num_tokens as u64;
44
45 if !stats.total_forward_time.is_zero() {
47 stats.tokens_per_second =
48 stats.total_tokens_processed as f64 / stats.total_forward_time.as_secs_f64();
49 }
50
51 let mut history = self
53 .timing_history
54 .lock()
55 .expect("lock should not be poisoned");
56 history.record_forward_pass(duration, num_tokens);
57
58 stats.computation_time += duration;
60 }
61
62 pub async fn record_backward_pass(&self, duration: Duration, num_tokens: usize) {
64 let mut stats = self.stats.lock().expect("lock should not be poisoned");
65 stats.backward_passes += 1;
66 stats.total_backward_time += duration;
67
68 let mut history = self
70 .timing_history
71 .lock()
72 .expect("lock should not be poisoned");
73 history.record_backward_pass(duration, num_tokens);
74
75 stats.computation_time += duration;
77 }
78
79 pub async fn record_communication(
81 &self,
82 comm_type: CommunicationType,
83 duration: Duration,
84 bytes: usize,
85 ) {
86 let mut stats = self.stats.lock().expect("lock should not be poisoned");
87 stats.communication_time += duration;
88
89 let mut comm_metrics = self
90 .communication_metrics
91 .lock()
92 .expect("lock should not be poisoned");
93 comm_metrics.record_communication(comm_type, duration, bytes);
94 }
95
96 pub fn record_memory_usage(&self, usage_mb: f64) {
98 let mut stats = self.stats.lock().expect("lock should not be poisoned");
99 stats.memory_usage_mb = usage_mb;
100
101 let mut memory_tracker = self
102 .memory_tracker
103 .lock()
104 .expect("lock should not be poisoned");
105 memory_tracker.record_usage(usage_mb);
106 }
107
108 pub fn get_stats(&self) -> Performance3DStats {
110 self.stats
111 .lock()
112 .expect("lock should not be poisoned")
113 .clone()
114 }
115
116 pub fn get_performance_analysis(&self) -> PerformanceAnalysis {
118 let stats = self.stats.lock().expect("lock should not be poisoned");
119 let timing_history = self
120 .timing_history
121 .lock()
122 .expect("lock should not be poisoned");
123 let memory_tracker = self
124 .memory_tracker
125 .lock()
126 .expect("lock should not be poisoned");
127 let comm_metrics = self
128 .communication_metrics
129 .lock()
130 .expect("lock should not be poisoned");
131
132 PerformanceAnalysis {
133 overall_throughput: stats.tokens_per_second,
134 forward_pass_avg_ms: timing_history.avg_forward_time_ms(),
135 backward_pass_avg_ms: timing_history.avg_backward_time_ms(),
136 communication_overhead_percent: self.calculate_communication_overhead(&stats),
137 memory_efficiency: memory_tracker.efficiency(),
138 pipeline_utilization: self.calculate_pipeline_utilization(&timing_history),
139 tensor_parallel_efficiency: self.calculate_tp_efficiency(&comm_metrics),
140 data_parallel_efficiency: self.calculate_dp_efficiency(&comm_metrics),
141 bottlenecks: self.identify_bottlenecks(&stats, &timing_history, &comm_metrics),
142 }
143 }
144
145 fn calculate_communication_overhead(&self, stats: &Performance3DStats) -> f32 {
147 let total_time = stats.computation_time + stats.communication_time;
148 if total_time.is_zero() {
149 0.0
150 } else {
151 (stats.communication_time.as_secs_f32() / total_time.as_secs_f32()) * 100.0
152 }
153 }
154
155 fn calculate_pipeline_utilization(&self, timing_history: &TimingHistory) -> f32 {
157 let ideal_time = timing_history.total_forward_time + timing_history.total_backward_time;
159 if ideal_time.is_zero() {
160 0.0
161 } else {
162 let actual_time = timing_history.wall_clock_time;
163 (ideal_time.as_secs_f32() / actual_time.as_secs_f32()).min(1.0) * 100.0
164 }
165 }
166
167 fn calculate_tp_efficiency(&self, comm_metrics: &CommunicationMetrics) -> f32 {
169 if self.rank_mapping.config.tp_size <= 1 {
171 100.0
172 } else {
173 let tp_comm_time = comm_metrics.get_communication_time(CommunicationType::AllReduceTP);
174 let total_comm_time = comm_metrics.total_communication_time();
175
176 if total_comm_time.is_zero() {
177 100.0
178 } else {
179 let ideal_ratio = 1.0 / self.rank_mapping.config.tp_size as f32;
180 let actual_ratio = tp_comm_time.as_secs_f32() / total_comm_time.as_secs_f32();
181 ((ideal_ratio / actual_ratio.max(ideal_ratio)) * 100.0).min(100.0)
182 }
183 }
184 }
185
186 fn calculate_dp_efficiency(&self, comm_metrics: &CommunicationMetrics) -> f32 {
188 if self.rank_mapping.config.dp_size <= 1 {
189 100.0
190 } else {
191 let dp_comm_time = comm_metrics.get_communication_time(CommunicationType::AllReduceDP);
193 let computation_time = self
194 .stats
195 .lock()
196 .expect("lock should not be poisoned")
197 .computation_time;
198
199 if computation_time.is_zero() {
200 100.0
201 } else {
202 let comm_ratio = dp_comm_time.as_secs_f32() / computation_time.as_secs_f32();
203 ((1.0 / (1.0 + comm_ratio)) * 100.0).min(100.0)
204 }
205 }
206 }
207
208 fn identify_bottlenecks(
210 &self,
211 stats: &Performance3DStats,
212 timing_history: &TimingHistory,
213 comm_metrics: &CommunicationMetrics,
214 ) -> Vec<PerformanceBottleneck> {
215 let mut bottlenecks = Vec::new();
216
217 let comm_overhead = self.calculate_communication_overhead(stats);
219 if comm_overhead > 30.0 {
220 bottlenecks.push(PerformanceBottleneck {
221 category: "Communication".to_string(),
222 description: format!("High communication overhead: {:.1}%", comm_overhead),
223 severity: BottleneckSeverity::High,
224 suggested_fix:
225 "Consider increasing micro-batch size or optimizing communication patterns"
226 .to_string(),
227 });
228 }
229
230 if stats.memory_usage_mb
232 > 0.9 * (self.rank_mapping.config.max_memory_per_device as f64) * 1024.0
233 {
234 bottlenecks.push(PerformanceBottleneck {
235 category: "Memory".to_string(),
236 description: "Memory usage near capacity".to_string(),
237 severity: BottleneckSeverity::Critical,
238 suggested_fix: "Enable gradient checkpointing or reduce model size".to_string(),
239 });
240 }
241
242 let pipeline_util = self.calculate_pipeline_utilization(timing_history);
244 if pipeline_util < 70.0 {
245 bottlenecks.push(PerformanceBottleneck {
246 category: "Pipeline".to_string(),
247 description: format!("Low pipeline utilization: {:.1}%", pipeline_util),
248 severity: BottleneckSeverity::Medium,
249 suggested_fix: "Adjust micro-batch size or pipeline schedule".to_string(),
250 });
251 }
252
253 let tp_efficiency = self.calculate_tp_efficiency(comm_metrics);
255 if tp_efficiency < 80.0 && self.rank_mapping.config.tp_size > 1 {
256 bottlenecks.push(PerformanceBottleneck {
257 category: "TensorParallel".to_string(),
258 description: format!("Low tensor parallel efficiency: {:.1}%", tp_efficiency),
259 severity: BottleneckSeverity::Medium,
260 suggested_fix: "Optimize tensor parallel communication or reduce TP size"
261 .to_string(),
262 });
263 }
264
265 bottlenecks
266 }
267
268 pub fn generate_report(&self) -> String {
270 let analysis = self.get_performance_analysis();
271 let stats = self.get_stats();
272
273 format!(
274 "🚀 3D Parallelism Performance Report\n\
275 ===================================\n\
276 \n\
277 📊 Overall Performance:\n\
278 • Throughput: {:.1} tokens/second\n\
279 • Forward Pass: {:.2}ms avg\n\
280 • Backward Pass: {:.2}ms avg\n\
281 • Communication Overhead: {:.1}%\n\
282 \n\
283 💾 Memory Metrics:\n\
284 • Current Usage: {:.1} MB\n\
285 • Memory Efficiency: {:.1}%\n\
286 \n\
287 🔄 Parallelism Efficiency:\n\
288 • Pipeline Utilization: {:.1}%\n\
289 • Tensor Parallel Efficiency: {:.1}%\n\
290 • Data Parallel Efficiency: {:.1}%\n\
291 \n\
292 ⚠️ Bottlenecks Identified:\n\
293 {}\n\
294 \n\
295 📈 Statistics:\n\
296 • Forward Passes: {}\n\
297 • Backward Passes: {}\n\
298 • Total Tokens Processed: {}\n\
299 • Total Computation Time: {:.2}s\n\
300 • Total Communication Time: {:.2}s\n",
301 analysis.overall_throughput,
302 analysis.forward_pass_avg_ms,
303 analysis.backward_pass_avg_ms,
304 analysis.communication_overhead_percent,
305 stats.memory_usage_mb,
306 analysis.memory_efficiency,
307 analysis.pipeline_utilization,
308 analysis.tensor_parallel_efficiency,
309 analysis.data_parallel_efficiency,
310 self.format_bottlenecks(&analysis.bottlenecks),
311 stats.forward_passes,
312 stats.backward_passes,
313 stats.total_tokens_processed,
314 stats.computation_time.as_secs_f64(),
315 stats.communication_time.as_secs_f64()
316 )
317 }
318
319 fn format_bottlenecks(&self, bottlenecks: &[PerformanceBottleneck]) -> String {
321 if bottlenecks.is_empty() {
322 "No significant bottlenecks detected".to_string()
323 } else {
324 bottlenecks
325 .iter()
326 .map(|b| {
327 format!(
328 "• {}: {} ({})",
329 b.category,
330 b.description,
331 b.severity.as_str()
332 )
333 })
334 .collect::<Vec<_>>()
335 .join("\n")
336 }
337 }
338
339 pub fn reset_stats(&self) {
341 let mut stats = self.stats.lock().expect("lock should not be poisoned");
342 *stats = Performance3DStats::new();
343
344 let mut history = self
345 .timing_history
346 .lock()
347 .expect("lock should not be poisoned");
348 *history = TimingHistory::new();
349
350 let mut memory_tracker = self
351 .memory_tracker
352 .lock()
353 .expect("lock should not be poisoned");
354 *memory_tracker = MemoryTracker::new();
355
356 let mut comm_metrics = self
357 .communication_metrics
358 .lock()
359 .expect("lock should not be poisoned");
360 *comm_metrics = CommunicationMetrics::new();
361 }
362}
363
/// Aggregate counters for a training run: pass counts, token totals, and
/// cumulative compute / communication times.
///
/// All fields default to zero, so `Default` is derived and `new` delegates
/// to it instead of duplicating the initializer.
#[derive(Debug, Clone, Default)]
pub struct Performance3DStats {
    /// Number of forward passes recorded.
    pub forward_passes: u64,
    /// Number of backward passes recorded.
    pub backward_passes: u64,
    /// Sum of all recorded forward-pass durations.
    pub total_forward_time: Duration,
    /// Sum of all recorded backward-pass durations.
    pub total_backward_time: Duration,
    /// Total tokens seen across forward passes.
    pub total_tokens_processed: u64,
    /// Tokens per second, computed against forward time.
    pub tokens_per_second: f64,
    /// Cumulative time spent in communication operations.
    pub communication_time: Duration,
    /// Cumulative time spent in forward + backward computation.
    pub computation_time: Duration,
    /// Most recently reported memory usage, in MB (per the recording API).
    pub memory_usage_mb: f64,
}

impl Performance3DStats {
    /// Returns zeroed stats; identical to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
399
/// Derived performance indicators produced by
/// `Performance3DMonitor::get_performance_analysis`.
#[derive(Debug, Clone)]
pub struct PerformanceAnalysis {
    // Tokens per second (from the aggregate stats).
    pub overall_throughput: f64,
    // Mean forward-pass latency over the retained sample window, in ms.
    pub forward_pass_avg_ms: f32,
    // Mean backward-pass latency over the retained sample window, in ms.
    pub backward_pass_avg_ms: f32,
    // Communication time as a percentage of compute + communication.
    pub communication_overhead_percent: f32,
    // Average/peak memory-usage ratio, as a percentage.
    pub memory_efficiency: f32,
    // Ideal compute time vs. wall clock, as a percentage (capped at 100).
    pub pipeline_utilization: f32,
    // Heuristic TP efficiency percentage (100 when tp_size <= 1).
    pub tensor_parallel_efficiency: f32,
    // Heuristic DP efficiency percentage (100 when dp_size <= 1).
    pub data_parallel_efficiency: f32,
    // Threshold-based findings, empty when nothing exceeded its limit.
    pub bottlenecks: Vec<PerformanceBottleneck>,
}
413
/// A single detected performance problem, with severity and a suggested
/// remediation, as produced by the bottleneck scan.
#[derive(Debug, Clone)]
pub struct PerformanceBottleneck {
    // Short area label, e.g. "Communication", "Memory", "Pipeline".
    pub category: String,
    // Human-readable description, often including the measured value.
    pub description: String,
    // How urgent the finding is.
    pub severity: BottleneckSeverity,
    // Suggested action to mitigate the bottleneck.
    pub suggested_fix: String,
}
422
/// Severity ranking for a detected bottleneck.
///
/// Fieldless enum, so `Copy` and `Eq` are derived alongside the original
/// `Clone`/`PartialEq` (clippy: `derive_partial_eq_without_eq`); both are
/// backward-compatible additions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BottleneckSeverity {
    Low,
    Medium,
    High,
    Critical,
}

impl BottleneckSeverity {
    /// Returns the human-readable label used in reports.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Low => "Low",
            Self::Medium => "Medium",
            Self::High => "High",
            Self::Critical => "Critical",
        }
    }
}
442
/// Rolling window of per-pass timings plus lifetime totals and wall-clock
/// elapsed time since construction.
#[derive(Debug, Clone)]
struct TimingHistory {
    // Last `MAX_SAMPLES` forward durations. VecDeque gives O(1) front
    // eviction; the previous Vec::remove(0) was O(n) per recorded pass.
    forward_times: VecDeque<Duration>,
    // Last `MAX_SAMPLES` backward durations.
    backward_times: VecDeque<Duration>,
    // Lifetime sum of forward durations (not windowed).
    total_forward_time: Duration,
    // Lifetime sum of backward durations (not windowed).
    total_backward_time: Duration,
    // Elapsed real time since `start_time`, refreshed on each record.
    wall_clock_time: Duration,
    // Set at construction; basis for wall-clock measurement.
    start_time: Option<Instant>,
}

impl TimingHistory {
    /// Cap on retained per-pass samples for the moving averages.
    const MAX_SAMPLES: usize = 1000;

    fn new() -> Self {
        Self {
            forward_times: VecDeque::new(),
            backward_times: VecDeque::new(),
            total_forward_time: Duration::ZERO,
            total_backward_time: Duration::ZERO,
            wall_clock_time: Duration::ZERO,
            start_time: Some(Instant::now()),
        }
    }

    /// Appends a forward sample, updates totals, and evicts the oldest
    /// sample once the window is full.
    fn record_forward_pass(&mut self, duration: Duration, _num_tokens: usize) {
        self.forward_times.push_back(duration);
        self.total_forward_time += duration;
        self.update_wall_clock_time();

        if self.forward_times.len() > Self::MAX_SAMPLES {
            self.forward_times.pop_front();
        }
    }

    /// Appends a backward sample; mirrors `record_forward_pass`.
    fn record_backward_pass(&mut self, duration: Duration, _num_tokens: usize) {
        self.backward_times.push_back(duration);
        self.total_backward_time += duration;
        self.update_wall_clock_time();

        if self.backward_times.len() > Self::MAX_SAMPLES {
            self.backward_times.pop_front();
        }
    }

    /// Refreshes `wall_clock_time` from the construction instant.
    fn update_wall_clock_time(&mut self) {
        if let Some(start) = self.start_time {
            self.wall_clock_time = start.elapsed();
        }
    }

    /// Mean forward latency over the retained window, in ms (0 if empty).
    fn avg_forward_time_ms(&self) -> f32 {
        if self.forward_times.is_empty() {
            0.0
        } else {
            let total: Duration = self.forward_times.iter().sum();
            total.as_secs_f32() * 1000.0 / self.forward_times.len() as f32
        }
    }

    /// Mean backward latency over the retained window, in ms (0 if empty).
    fn avg_backward_time_ms(&self) -> f32 {
        if self.backward_times.is_empty() {
            0.0
        } else {
            let total: Duration = self.backward_times.iter().sum();
            total.as_secs_f32() * 1000.0 / self.backward_times.len() as f32
        }
    }
}
512
/// Tracks reported memory usage over time: bounded history plus running
/// peak and average.
#[derive(Debug, Clone)]
struct MemoryTracker {
    // Recent usage samples (MB). VecDeque gives O(1) front eviction; the
    // previous Vec::remove(0) was O(n) per sample once the window filled.
    usage_history: VecDeque<f64>,
    // Highest usage ever reported (not windowed).
    peak_usage: f64,
    // Mean over the samples currently in `usage_history`.
    average_usage: f64,
}

impl MemoryTracker {
    /// Cap on retained usage samples.
    const MAX_SAMPLES: usize = 1000;

    fn new() -> Self {
        Self {
            usage_history: VecDeque::new(),
            peak_usage: 0.0,
            average_usage: 0.0,
        }
    }

    /// Records one usage sample, updating peak and window average.
    fn record_usage(&mut self, usage_mb: f64) {
        self.usage_history.push_back(usage_mb);
        self.peak_usage = self.peak_usage.max(usage_mb);

        // History is never empty here (we just pushed), so the average is
        // computed unconditionally — the original's is_empty guard was dead.
        self.average_usage =
            self.usage_history.iter().sum::<f64>() / self.usage_history.len() as f64;

        if self.usage_history.len() > Self::MAX_SAMPLES {
            self.usage_history.pop_front();
        }
    }

    /// Average/peak ratio as a percentage; 100 when nothing recorded yet.
    fn efficiency(&self) -> f32 {
        if self.peak_usage == 0.0 {
            100.0
        } else {
            (self.average_usage / self.peak_usage * 100.0) as f32
        }
    }
}
554
555#[derive(Debug, Clone)]
557struct CommunicationMetrics {
558 communication_times: HashMap<CommunicationType, Vec<Duration>>,
559 bytes_transferred: HashMap<CommunicationType, Vec<usize>>,
560}
561
562impl CommunicationMetrics {
563 fn new() -> Self {
564 Self {
565 communication_times: HashMap::new(),
566 bytes_transferred: HashMap::new(),
567 }
568 }
569
570 fn record_communication(
571 &mut self,
572 comm_type: CommunicationType,
573 duration: Duration,
574 bytes: usize,
575 ) {
576 self.communication_times
577 .entry(comm_type)
578 .or_default()
579 .push(duration);
580
581 self.bytes_transferred
582 .entry(comm_type)
583 .or_default()
584 .push(bytes);
585 }
586
587 fn get_communication_time(&self, comm_type: CommunicationType) -> Duration {
588 self.communication_times
589 .get(&comm_type)
590 .map(|times| times.iter().sum())
591 .unwrap_or(Duration::ZERO)
592 }
593
594 fn total_communication_time(&self) -> Duration {
595 self.communication_times
596 .values()
597 .flat_map(|times| times.iter())
598 .sum()
599 }
600}
601
/// Categories of communication operations tracked by the monitor.
/// `Eq + Hash` so values can key the per-category metric maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommunicationType {
    // All-reduce across the data-parallel group.
    AllReduceDP,
    // All-reduce across the tensor-parallel group.
    AllReduceTP,
    // All-gather across the tensor-parallel group.
    AllGatherTP,
    // Reduce-scatter across the tensor-parallel group.
    ReduceScatterTP,
    // Point-to-point send (e.g. between pipeline stages).
    Send,
    // Point-to-point receive.
    Recv,
}
612
/// Per-category memory accounting for a 3D-parallel model (units are
/// whatever the caller reports; not established in this file).
///
/// All fields default to zero, so `Default` is derived and `new` delegates
/// to it instead of duplicating the initializer.
#[derive(Debug, Clone, Default)]
pub struct Memory3DStats {
    /// Memory held by model parameters.
    pub model_memory: usize,
    /// Memory held by activations.
    pub activation_memory: usize,
    /// Memory held by gradients.
    pub gradient_memory: usize,
    /// Memory held by optimizer state.
    pub optimizer_memory: usize,
    /// Current total across categories.
    pub total_memory: usize,
    /// Highest total observed.
    pub peak_memory: usize,
    /// Utilization ratio as a percentage.
    pub memory_efficiency: f32,
}

impl Memory3DStats {
    /// Returns zeroed stats; identical to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}