1use std::collections::HashMap;
9use std::time::Duration;
10
11#[derive(Debug, Clone)]
20pub struct Zero3PerformanceStats {
21 pub forward_passes: u64,
23 pub backward_passes: u64,
25 pub optimizer_steps: u64,
27 pub total_forward_time: Duration,
29 pub total_backward_time: Duration,
31 pub total_optimizer_time: Duration,
33 pub parameter_transfer_time: Duration,
35 pub gradient_sync_time: Duration,
37 pub layer_timings: HashMap<String, LayerTimingStats>,
39 pub throughput_metrics: ThroughputMetrics,
41 pub memory_transfer_metrics: MemoryTransferMetrics,
43 pub communication_stats: CommunicationStats,
45 pub optimization_efficiency: OptimizationEfficiency,
47}
48
49impl Zero3PerformanceStats {
50 pub fn new() -> Self {
52 Self {
53 forward_passes: 0,
54 backward_passes: 0,
55 optimizer_steps: 0,
56 total_forward_time: Duration::ZERO,
57 total_backward_time: Duration::ZERO,
58 total_optimizer_time: Duration::ZERO,
59 parameter_transfer_time: Duration::ZERO,
60 gradient_sync_time: Duration::ZERO,
61 layer_timings: HashMap::new(),
62 throughput_metrics: ThroughputMetrics::new(),
63 memory_transfer_metrics: MemoryTransferMetrics::new(),
64 communication_stats: CommunicationStats::new(),
65 optimization_efficiency: OptimizationEfficiency::new(),
66 }
67 }
68
69 pub fn record_forward_pass(&mut self, duration: Duration, num_tokens: usize) {
71 self.forward_passes += 1;
72 self.total_forward_time += duration;
73 self.throughput_metrics
74 .record_forward_pass(duration, num_tokens);
75 self.optimization_efficiency.record_forward_pass(duration);
76 }
77
78 pub fn record_backward_pass(&mut self, duration: Duration, num_tokens: usize) {
80 self.backward_passes += 1;
81 self.total_backward_time += duration;
82 self.throughput_metrics
83 .record_backward_pass(duration, num_tokens);
84 self.optimization_efficiency.record_backward_pass(duration);
85 }
86
87 pub fn record_optimizer_step(&mut self, duration: Duration, num_params: usize) {
89 self.optimizer_steps += 1;
90 self.total_optimizer_time += duration;
91 self.optimization_efficiency
92 .record_optimizer_step(duration, num_params);
93 }
94
95 pub fn record_layer_execution(&mut self, layer_name: String, duration: Duration) {
97 let layer_stats = self.layer_timings.entry(layer_name.clone()).or_default();
98 layer_stats.record_forward_execution(duration);
99 }
100
101 pub fn record_layer_backward(&mut self, layer_name: String, duration: Duration) {
103 let layer_stats = self.layer_timings.entry(layer_name).or_default();
104 layer_stats.record_backward_execution(duration);
105 }
106
107 pub fn record_parameter_transfer(
109 &mut self,
110 duration: Duration,
111 bytes_transferred: usize,
112 direction: TransferDirection,
113 ) {
114 self.parameter_transfer_time += duration;
115 self.memory_transfer_metrics
116 .record_transfer(duration, bytes_transferred, direction);
117 }
118
119 pub fn record_gradient_sync(
121 &mut self,
122 duration: Duration,
123 num_gradients: usize,
124 world_size: usize,
125 ) {
126 self.gradient_sync_time += duration;
127 self.communication_stats
128 .record_gradient_sync(duration, num_gradients, world_size);
129 }
130
131 pub fn record_communication(
133 &mut self,
134 operation: CommunicationOperation,
135 duration: Duration,
136 bytes: usize,
137 ) {
138 self.communication_stats
139 .record_operation(operation, duration, bytes);
140 }
141
142 pub fn average_forward_time(&self) -> Duration {
144 if self.forward_passes > 0 {
145 self.total_forward_time / self.forward_passes as u32
146 } else {
147 Duration::ZERO
148 }
149 }
150
151 pub fn average_backward_time(&self) -> Duration {
153 if self.backward_passes > 0 {
154 self.total_backward_time / self.backward_passes as u32
155 } else {
156 Duration::ZERO
157 }
158 }
159
160 pub fn average_optimizer_time(&self) -> Duration {
162 if self.optimizer_steps > 0 {
163 self.total_optimizer_time / self.optimizer_steps as u32
164 } else {
165 Duration::ZERO
166 }
167 }
168
169 pub fn get_tokens_per_second(&self) -> f64 {
171 self.throughput_metrics.get_tokens_per_second()
172 }
173
174 pub fn get_memory_bandwidth_gbps(&self) -> f64 {
176 self.memory_transfer_metrics.get_bandwidth_gbps()
177 }
178
179 pub fn get_communication_efficiency(&self) -> f64 {
181 self.communication_stats.get_efficiency()
182 }
183
184 pub fn get_training_efficiency(&self) -> f64 {
186 self.optimization_efficiency.get_overall_efficiency()
187 }
188
189 pub fn get_performance_summary(&self) -> PerformanceSummary {
191 PerformanceSummary {
192 total_operations: self.forward_passes + self.backward_passes + self.optimizer_steps,
193 average_forward_time: self.average_forward_time(),
194 average_backward_time: self.average_backward_time(),
195 average_optimizer_time: self.average_optimizer_time(),
196 tokens_per_second: self.get_tokens_per_second(),
197 memory_bandwidth_gbps: self.get_memory_bandwidth_gbps(),
198 communication_efficiency: self.get_communication_efficiency(),
199 training_efficiency: self.get_training_efficiency(),
200 memory_transfer_efficiency: self.memory_transfer_metrics.get_efficiency(),
201 layer_performance: self.get_layer_performance_summary(),
202 }
203 }
204
205 fn get_layer_performance_summary(&self) -> HashMap<String, LayerPerformanceSummary> {
207 self.layer_timings
208 .iter()
209 .map(|(name, stats)| {
210 (
211 name.clone(),
212 LayerPerformanceSummary {
213 total_executions: stats.forward_executions + stats.backward_executions,
214 average_forward_time: stats.average_forward_time(),
215 average_backward_time: stats.average_backward_time(),
216 total_time: stats.total_forward_time + stats.total_backward_time,
217 },
218 )
219 })
220 .collect()
221 }
222
223 pub fn reset(&mut self) {
225 *self = Self::new();
226 }
227
228 pub fn merge(&mut self, other: &Zero3PerformanceStats) {
230 self.forward_passes += other.forward_passes;
231 self.backward_passes += other.backward_passes;
232 self.optimizer_steps += other.optimizer_steps;
233 self.total_forward_time += other.total_forward_time;
234 self.total_backward_time += other.total_backward_time;
235 self.total_optimizer_time += other.total_optimizer_time;
236 self.parameter_transfer_time += other.parameter_transfer_time;
237 self.gradient_sync_time += other.gradient_sync_time;
238
239 for (layer_name, other_stats) in &other.layer_timings {
241 let layer_stats = self.layer_timings.entry(layer_name.clone()).or_default();
242 layer_stats.merge(other_stats);
243 }
244
245 self.throughput_metrics.merge(&other.throughput_metrics);
246 self.memory_transfer_metrics
247 .merge(&other.memory_transfer_metrics);
248 self.communication_stats.merge(&other.communication_stats);
249 self.optimization_efficiency
250 .merge(&other.optimization_efficiency);
251 }
252}
253
254impl Default for Zero3PerformanceStats {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
/// Forward/backward timing statistics for a single layer.
#[derive(Debug, Clone)]
pub struct LayerTimingStats {
    /// Number of forward executions recorded.
    pub forward_executions: u64,
    /// Number of backward executions recorded.
    pub backward_executions: u64,
    /// Sum of all recorded forward durations.
    pub total_forward_time: Duration,
    /// Sum of all recorded backward durations.
    pub total_backward_time: Duration,
    /// Shortest forward duration seen; [`Duration::MAX`] until one is recorded.
    pub min_forward_time: Duration,
    /// Longest forward duration seen; [`Duration::ZERO`] until one is recorded.
    pub max_forward_time: Duration,
    /// Shortest backward duration seen; [`Duration::MAX`] until one is recorded.
    pub min_backward_time: Duration,
    /// Longest backward duration seen; [`Duration::ZERO`] until one is recorded.
    pub max_backward_time: Duration,
}

impl LayerTimingStats {
    /// Creates empty statistics.
    ///
    /// The minimum fields start at [`Duration::MAX`] so that the first
    /// recorded sample always replaces them.
    pub fn new() -> Self {
        Self {
            forward_executions: 0,
            backward_executions: 0,
            total_forward_time: Duration::ZERO,
            total_backward_time: Duration::ZERO,
            min_forward_time: Duration::MAX,
            max_forward_time: Duration::ZERO,
            min_backward_time: Duration::MAX,
            max_backward_time: Duration::ZERO,
        }
    }

    /// Records one forward execution that took `duration`.
    pub fn record_forward_execution(&mut self, duration: Duration) {
        self.forward_executions += 1;
        self.total_forward_time += duration;
        self.min_forward_time = self.min_forward_time.min(duration);
        self.max_forward_time = self.max_forward_time.max(duration);
    }

    /// Records one backward execution that took `duration`.
    pub fn record_backward_execution(&mut self, duration: Duration) {
        self.backward_executions += 1;
        self.total_backward_time += duration;
        self.min_backward_time = self.min_backward_time.min(duration);
        self.max_backward_time = self.max_backward_time.max(duration);
    }

    /// Mean forward duration, or [`Duration::ZERO`] when nothing was recorded.
    pub fn average_forward_time(&self) -> Duration {
        Self::mean_duration(self.total_forward_time, self.forward_executions)
    }

    /// Mean backward duration, or [`Duration::ZERO`] when nothing was recorded.
    pub fn average_backward_time(&self) -> Duration {
        Self::mean_duration(self.total_backward_time, self.backward_executions)
    }

    /// Divides `total` by `count` using the full 64-bit count.
    ///
    /// Narrowing the count to `u32` would truncate large values; a count that
    /// is a multiple of 2^32 would become zero and make the division panic.
    fn mean_duration(total: Duration, count: u64) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        if count == 0 {
            return Duration::ZERO;
        }
        let mean_nanos = total.as_nanos() / u128::from(count);
        // The mean never exceeds `total`, so the seconds part fits in `u64`
        // and the remainder is below one second.
        Duration::new(
            (mean_nanos / NANOS_PER_SEC) as u64,
            (mean_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Folds `other` into `self`: counts and totals are summed, minima and
    /// maxima are combined.
    pub fn merge(&mut self, other: &LayerTimingStats) {
        self.forward_executions += other.forward_executions;
        self.backward_executions += other.backward_executions;
        self.total_forward_time += other.total_forward_time;
        self.total_backward_time += other.total_backward_time;
        self.min_forward_time = self.min_forward_time.min(other.min_forward_time);
        self.max_forward_time = self.max_forward_time.max(other.max_forward_time);
        self.min_backward_time = self.min_backward_time.min(other.min_backward_time);
        self.max_backward_time = self.max_backward_time.max(other.max_backward_time);
    }
}

impl Default for LayerTimingStats {
    /// Equivalent to [`LayerTimingStats::new`].
    fn default() -> Self {
        Self::new()
    }
}
342
/// Token throughput bookkeeping for forward and backward passes.
#[derive(Debug, Clone)]
pub struct ThroughputMetrics {
    /// Tokens processed across all recorded forward passes.
    pub total_forward_tokens: usize,
    /// Tokens processed across all recorded backward passes.
    pub total_backward_tokens: usize,
    /// Sum of all recorded forward-pass durations.
    pub total_forward_time: Duration,
    /// Sum of all recorded backward-pass durations.
    pub total_backward_time: Duration,
    /// Highest single-pass tokens-per-second value observed.
    pub peak_tokens_per_second: f64,
    /// Exponentially smoothed tokens-per-second value.
    pub rolling_average_tps: f64,
    /// Number of samples folded into `rolling_average_tps`.
    pub rolling_samples: u32,
}

impl ThroughputMetrics {
    /// Creates metrics with every counter, duration, and rate set to zero.
    pub fn new() -> Self {
        Self {
            total_forward_tokens: 0,
            total_backward_tokens: 0,
            total_forward_time: Duration::ZERO,
            total_backward_time: Duration::ZERO,
            peak_tokens_per_second: 0.0,
            rolling_average_tps: 0.0,
            rolling_samples: 0,
        }
    }

    /// Adds a forward pass of `num_tokens` tokens that took `duration`.
    pub fn record_forward_pass(&mut self, duration: Duration, num_tokens: usize) {
        self.total_forward_tokens += num_tokens;
        self.total_forward_time += duration;
        self.update_rolling_average(duration, num_tokens);
    }

    /// Adds a backward pass of `num_tokens` tokens that took `duration`.
    pub fn record_backward_pass(&mut self, duration: Duration, num_tokens: usize) {
        self.total_backward_tokens += num_tokens;
        self.total_backward_time += duration;
        self.update_rolling_average(duration, num_tokens);
    }

    /// Folds one pass into the peak and the smoothed rate.
    ///
    /// Zero-length passes carry no rate information and are skipped. The first
    /// sample seeds the average; later ones are blended with weight 0.1.
    fn update_rolling_average(&mut self, duration: Duration, num_tokens: usize) {
        if duration.is_zero() {
            return;
        }

        let sample_tps = num_tokens as f64 / duration.as_secs_f64();
        self.peak_tokens_per_second = self.peak_tokens_per_second.max(sample_tps);

        let alpha = 0.1;
        self.rolling_average_tps = match self.rolling_samples {
            0 => sample_tps,
            _ => alpha * sample_tps + (1.0 - alpha) * self.rolling_average_tps,
        };
        self.rolling_samples += 1;
    }

    /// Tokens per second for `tokens` processed in `elapsed`; `0.0` when
    /// either quantity is zero.
    fn rate(tokens: usize, elapsed: Duration) -> f64 {
        if elapsed.is_zero() || tokens == 0 {
            0.0
        } else {
            tokens as f64 / elapsed.as_secs_f64()
        }
    }

    /// Combined forward+backward tokens per second over all recorded passes.
    pub fn get_tokens_per_second(&self) -> f64 {
        let elapsed = self.total_forward_time + self.total_backward_time;
        let tokens = self.total_forward_tokens + self.total_backward_tokens;
        Self::rate(tokens, elapsed)
    }

    /// Tokens per second over the recorded forward passes only.
    pub fn get_forward_tps(&self) -> f64 {
        Self::rate(self.total_forward_tokens, self.total_forward_time)
    }

    /// Tokens per second over the recorded backward passes only.
    pub fn get_backward_tps(&self) -> f64 {
        Self::rate(self.total_backward_tokens, self.total_backward_time)
    }

    /// Folds `other` into `self`.
    ///
    /// Totals are summed, the peak is the larger of the two, and the rolling
    /// averages are combined weighted by their sample counts.
    pub fn merge(&mut self, other: &ThroughputMetrics) {
        self.total_forward_tokens += other.total_forward_tokens;
        self.total_backward_tokens += other.total_backward_tokens;
        self.total_forward_time += other.total_forward_time;
        self.total_backward_time += other.total_backward_time;
        self.peak_tokens_per_second = self
            .peak_tokens_per_second
            .max(other.peak_tokens_per_second);

        let combined_samples = self.rolling_samples + other.rolling_samples;
        if combined_samples == 0 {
            // Neither side has a rolling average to contribute.
            return;
        }

        let own_share = self.rolling_samples as f64 / combined_samples as f64;
        let their_share = other.rolling_samples as f64 / combined_samples as f64;
        self.rolling_average_tps =
            own_share * self.rolling_average_tps + their_share * other.rolling_average_tps;
        self.rolling_samples = combined_samples;
    }
}

impl Default for ThroughputMetrics {
    /// Equivalent to [`ThroughputMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
457
/// Byte and time bookkeeping for transfers between CPU and GPU memory.
#[derive(Debug, Clone)]
pub struct MemoryTransferMetrics {
    /// Bytes moved in the CPU-to-GPU direction.
    pub cpu_to_gpu_bytes: usize,
    /// Bytes moved in the GPU-to-CPU direction.
    pub gpu_to_cpu_bytes: usize,
    /// Time spent on CPU-to-GPU transfers.
    pub cpu_to_gpu_time: Duration,
    /// Time spent on GPU-to-CPU transfers.
    pub gpu_to_cpu_time: Duration,
    /// Number of CPU-to-GPU transfers recorded.
    pub cpu_to_gpu_transfers: u64,
    /// Number of GPU-to-CPU transfers recorded.
    pub gpu_to_cpu_transfers: u64,
    /// Highest single-transfer bandwidth observed, in bytes per second.
    pub peak_bandwidth: f64,
    /// Achieved bandwidth relative to the reference bandwidth, capped at 1.0.
    pub transfer_efficiency: f64,
}

impl MemoryTransferMetrics {
    /// Reference bandwidth, in bytes per second, against which efficiency is
    /// measured.
    const THEORETICAL_BANDWIDTH_BPS: f64 = 1_000_000_000.0;

    /// Divisor used by [`Self::get_bandwidth_gbps`] (1024^3 bytes).
    const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

    /// Creates empty metrics; efficiency starts at 1.0 until a measurable
    /// transfer is recorded.
    pub fn new() -> Self {
        Self {
            cpu_to_gpu_bytes: 0,
            gpu_to_cpu_bytes: 0,
            cpu_to_gpu_time: Duration::ZERO,
            gpu_to_cpu_time: Duration::ZERO,
            cpu_to_gpu_transfers: 0,
            gpu_to_cpu_transfers: 0,
            peak_bandwidth: 0.0,
            transfer_efficiency: 1.0,
        }
    }

    /// Records a transfer of `bytes` bytes in `direction` that took `duration`.
    ///
    /// A zero `duration` still counts towards bytes and transfer counts but
    /// cannot contribute a bandwidth sample to the peak.
    pub fn record_transfer(
        &mut self,
        duration: Duration,
        bytes: usize,
        direction: TransferDirection,
    ) {
        if !duration.is_zero() {
            let bandwidth = bytes as f64 / duration.as_secs_f64();
            self.peak_bandwidth = self.peak_bandwidth.max(bandwidth);
        }

        match direction {
            TransferDirection::CpuToGpu => {
                self.cpu_to_gpu_bytes += bytes;
                self.cpu_to_gpu_time += duration;
                self.cpu_to_gpu_transfers += 1;
            }
            TransferDirection::GpuToCpu => {
                self.gpu_to_cpu_bytes += bytes;
                self.gpu_to_cpu_time += duration;
                self.gpu_to_cpu_transfers += 1;
            }
        }

        self.update_efficiency();
    }

    /// Recomputes `transfer_efficiency` from the aggregate bandwidth.
    ///
    /// When no bandwidth can be measured yet (no bytes or no elapsed time)
    /// the previous value is kept, so an empty collector stays at its
    /// initial 1.0 and is not reset to 0.0.
    fn update_efficiency(&mut self) {
        let achieved = self.get_bandwidth_bps();
        if achieved > 0.0 {
            self.transfer_efficiency = (achieved / Self::THEORETICAL_BANDWIDTH_BPS).min(1.0);
        }
    }

    /// Bytes per second for `bytes` moved in `elapsed`; `0.0` when either is zero.
    fn bandwidth(bytes: usize, elapsed: Duration) -> f64 {
        if elapsed.is_zero() || bytes == 0 {
            0.0
        } else {
            bytes as f64 / elapsed.as_secs_f64()
        }
    }

    /// Aggregate bandwidth in units of 1024^3 bytes per second.
    pub fn get_bandwidth_gbps(&self) -> f64 {
        self.get_bandwidth_bps() / Self::BYTES_PER_GIB
    }

    /// Aggregate bandwidth over both directions, in bytes per second.
    pub fn get_bandwidth_bps(&self) -> f64 {
        let total_bytes = self.cpu_to_gpu_bytes + self.gpu_to_cpu_bytes;
        let total_time = self.cpu_to_gpu_time + self.gpu_to_cpu_time;
        Self::bandwidth(total_bytes, total_time)
    }

    /// CPU-to-GPU bandwidth in bytes per second.
    pub fn get_cpu_to_gpu_bandwidth(&self) -> f64 {
        Self::bandwidth(self.cpu_to_gpu_bytes, self.cpu_to_gpu_time)
    }

    /// GPU-to-CPU bandwidth in bytes per second.
    pub fn get_gpu_to_cpu_bandwidth(&self) -> f64 {
        Self::bandwidth(self.gpu_to_cpu_bytes, self.gpu_to_cpu_time)
    }

    /// Current transfer efficiency in the range `0.0..=1.0`.
    pub fn get_efficiency(&self) -> f64 {
        self.transfer_efficiency
    }

    /// Folds `other` into `self`: bytes, times, and counts are summed, the
    /// peak is the larger of the two, and efficiency is recomputed.
    pub fn merge(&mut self, other: &MemoryTransferMetrics) {
        self.cpu_to_gpu_bytes += other.cpu_to_gpu_bytes;
        self.gpu_to_cpu_bytes += other.gpu_to_cpu_bytes;
        self.cpu_to_gpu_time += other.cpu_to_gpu_time;
        self.gpu_to_cpu_time += other.gpu_to_cpu_time;
        self.cpu_to_gpu_transfers += other.cpu_to_gpu_transfers;
        self.gpu_to_cpu_transfers += other.gpu_to_cpu_transfers;
        self.peak_bandwidth = self.peak_bandwidth.max(other.peak_bandwidth);
        self.update_efficiency();
    }
}

impl Default for MemoryTransferMetrics {
    /// Equivalent to [`MemoryTransferMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Direction of a memory transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    /// From CPU memory to GPU memory.
    CpuToGpu,
    /// From GPU memory to CPU memory.
    GpuToCpu,
}
590
/// Operation counts, durations, and byte totals for collective and
/// point-to-point communication.
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    /// Number of all-reduce operations (including gradient syncs).
    pub allreduce_operations: u64,
    /// Time spent in all-reduce operations.
    pub allreduce_time: Duration,
    /// Bytes attributed to all-reduce operations.
    pub allreduce_bytes: usize,
    /// Number of broadcast operations.
    pub broadcast_operations: u64,
    /// Time spent in broadcast operations.
    pub broadcast_time: Duration,
    /// Bytes attributed to broadcast operations.
    pub broadcast_bytes: usize,
    /// Number of point-to-point operations.
    pub p2p_operations: u64,
    /// Time spent in point-to-point operations.
    pub p2p_time: Duration,
    /// Bytes attributed to point-to-point operations.
    pub p2p_bytes: usize,
    /// Achieved bandwidth relative to the reference bandwidth, capped at 1.0.
    pub communication_efficiency: f64,
}

impl CommunicationStats {
    /// Reference bandwidth, in bytes per second, against which efficiency is
    /// measured.
    const THEORETICAL_BANDWIDTH_BPS: f64 = 10_000_000_000.0;

    /// Creates empty statistics; efficiency starts at 1.0 until a measurable
    /// operation is recorded.
    pub fn new() -> Self {
        Self {
            allreduce_operations: 0,
            allreduce_time: Duration::ZERO,
            allreduce_bytes: 0,
            broadcast_operations: 0,
            broadcast_time: Duration::ZERO,
            broadcast_bytes: 0,
            p2p_operations: 0,
            p2p_time: Duration::ZERO,
            p2p_bytes: 0,
            communication_efficiency: 1.0,
        }
    }

    /// Records a gradient synchronization as one all-reduce operation.
    ///
    /// The byte volume is estimated as `num_gradients * 4 * world_size`
    /// (presumably 4 bytes per gradient element — confirm against callers).
    /// The estimate saturates at `usize::MAX` and cannot overflow.
    pub fn record_gradient_sync(
        &mut self,
        duration: Duration,
        num_gradients: usize,
        world_size: usize,
    ) {
        self.allreduce_operations += 1;
        self.allreduce_time += duration;
        let estimated_bytes = num_gradients.saturating_mul(4).saturating_mul(world_size);
        self.allreduce_bytes = self.allreduce_bytes.saturating_add(estimated_bytes);
        self.update_efficiency();
    }

    /// Records one communication `operation` that moved `bytes` bytes in
    /// `duration`.
    pub fn record_operation(
        &mut self,
        operation: CommunicationOperation,
        duration: Duration,
        bytes: usize,
    ) {
        let (count, time, total_bytes) = match operation {
            CommunicationOperation::AllReduce => (
                &mut self.allreduce_operations,
                &mut self.allreduce_time,
                &mut self.allreduce_bytes,
            ),
            CommunicationOperation::Broadcast => (
                &mut self.broadcast_operations,
                &mut self.broadcast_time,
                &mut self.broadcast_bytes,
            ),
            CommunicationOperation::PointToPoint => (
                &mut self.p2p_operations,
                &mut self.p2p_time,
                &mut self.p2p_bytes,
            ),
        };
        *count += 1;
        *time += duration;
        *total_bytes = total_bytes.saturating_add(bytes);
        self.update_efficiency();
    }

    /// Recomputes `communication_efficiency` from the aggregate bandwidth.
    ///
    /// The previous value is kept while no bandwidth can be measured (no
    /// bytes or no elapsed time).
    fn update_efficiency(&mut self) {
        let total_time = self.allreduce_time + self.broadcast_time + self.p2p_time;
        let total_bytes = self
            .allreduce_bytes
            .saturating_add(self.broadcast_bytes)
            .saturating_add(self.p2p_bytes);

        if !total_time.is_zero() && total_bytes > 0 {
            let achieved_bandwidth = total_bytes as f64 / total_time.as_secs_f64();
            self.communication_efficiency =
                (achieved_bandwidth / Self::THEORETICAL_BANDWIDTH_BPS).min(1.0);
        }
    }

    /// Bytes per second for `bytes` moved in `elapsed`; `0.0` when either is zero.
    fn bandwidth(bytes: usize, elapsed: Duration) -> f64 {
        if elapsed.is_zero() || bytes == 0 {
            0.0
        } else {
            bytes as f64 / elapsed.as_secs_f64()
        }
    }

    /// Current communication efficiency in the range `0.0..=1.0`.
    pub fn get_efficiency(&self) -> f64 {
        self.communication_efficiency
    }

    /// All-reduce bandwidth in bytes per second.
    pub fn get_allreduce_bandwidth(&self) -> f64 {
        Self::bandwidth(self.allreduce_bytes, self.allreduce_time)
    }

    /// Broadcast bandwidth in bytes per second.
    pub fn get_broadcast_bandwidth(&self) -> f64 {
        Self::bandwidth(self.broadcast_bytes, self.broadcast_time)
    }

    /// Point-to-point bandwidth in bytes per second.
    pub fn get_p2p_bandwidth(&self) -> f64 {
        Self::bandwidth(self.p2p_bytes, self.p2p_time)
    }

    /// Folds `other` into `self`: counts, times, and bytes are summed
    /// (bytes saturating) and efficiency is recomputed.
    pub fn merge(&mut self, other: &CommunicationStats) {
        self.allreduce_operations += other.allreduce_operations;
        self.allreduce_time += other.allreduce_time;
        self.allreduce_bytes = self.allreduce_bytes.saturating_add(other.allreduce_bytes);
        self.broadcast_operations += other.broadcast_operations;
        self.broadcast_time += other.broadcast_time;
        self.broadcast_bytes = self.broadcast_bytes.saturating_add(other.broadcast_bytes);
        self.p2p_operations += other.p2p_operations;
        self.p2p_time += other.p2p_time;
        self.p2p_bytes = self.p2p_bytes.saturating_add(other.p2p_bytes);
        self.update_efficiency();
    }
}

impl Default for CommunicationStats {
    /// Equivalent to [`CommunicationStats::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Kind of communication operation tracked by [`CommunicationStats`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommunicationOperation {
    /// Collective all-reduce.
    AllReduce,
    /// Collective broadcast.
    Broadcast,
    /// Point-to-point transfer.
    PointToPoint,
}
732
/// Blended efficiency score built from compute/communication time and two
/// externally supplied efficiency factors.
#[derive(Debug, Clone)]
pub struct OptimizationEfficiency {
    /// Time attributed to compute (forward, backward, optimizer step).
    pub compute_time: Duration,
    /// Time attributed to communication.
    pub communication_time: Duration,
    /// Memory efficiency factor in `0.0..=1.0`.
    pub memory_efficiency: f64,
    /// Parameter-update efficiency factor in `0.0..=1.0`.
    pub parameter_update_efficiency: f64,
    /// Most recently computed overall score in `0.0..=1.0`.
    pub overall_efficiency: f64,
    /// Number of recorded updates that contributed to the score.
    pub measurements: u32,
}

impl OptimizationEfficiency {
    /// Creates a tracker with no time recorded and all factors at 1.0.
    pub fn new() -> Self {
        Self {
            compute_time: Duration::ZERO,
            communication_time: Duration::ZERO,
            memory_efficiency: 1.0,
            parameter_update_efficiency: 1.0,
            overall_efficiency: 1.0,
            measurements: 0,
        }
    }

    /// Adds a forward pass of `duration` to the compute time.
    pub fn record_forward_pass(&mut self, duration: Duration) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Adds a backward pass of `duration` to the compute time.
    pub fn record_backward_pass(&mut self, duration: Duration) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Adds an optimizer step of `duration` to the compute time.
    /// `_num_params` is currently unused.
    pub fn record_optimizer_step(&mut self, duration: Duration, _num_params: usize) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Adds `duration` to the communication time.
    pub fn record_communication(&mut self, duration: Duration) {
        self.communication_time += duration;
        self.update_efficiency();
    }

    /// Counts one measurement and refreshes the overall score.
    fn update_efficiency(&mut self) {
        self.measurements = self.measurements.saturating_add(1);
        self.recompute_overall();
    }

    /// Refreshes `overall_efficiency` without touching the measurement count.
    ///
    /// The score is `0.5 * compute ratio + 0.3 * memory efficiency +
    /// 0.2 * parameter-update efficiency`, clamped to `0.0..=1.0`.
    fn recompute_overall(&mut self) {
        let blended = 0.5 * self.get_compute_ratio()
            + 0.3 * self.memory_efficiency
            + 0.2 * self.parameter_update_efficiency;
        self.overall_efficiency = blended.clamp(0.0, 1.0);
    }

    /// Sets the memory efficiency factor (clamped to `0.0..=1.0`) and
    /// refreshes the overall score.
    pub fn update_memory_efficiency(&mut self, efficiency: f64) {
        self.memory_efficiency = efficiency.clamp(0.0, 1.0);
        self.update_efficiency();
    }

    /// Sets the parameter-update efficiency factor (clamped to `0.0..=1.0`)
    /// and refreshes the overall score.
    pub fn update_parameter_efficiency(&mut self, efficiency: f64) {
        self.parameter_update_efficiency = efficiency.clamp(0.0, 1.0);
        self.update_efficiency();
    }

    /// Fraction of total time spent on compute; `1.0` when no time is recorded.
    pub fn get_compute_ratio(&self) -> f64 {
        let total_time = self.compute_time + self.communication_time;
        if total_time.is_zero() {
            1.0
        } else {
            self.compute_time.as_secs_f64() / total_time.as_secs_f64()
        }
    }

    /// Fraction of total time spent on communication; `0.0` when no time is
    /// recorded.
    pub fn get_communication_ratio(&self) -> f64 {
        let total_time = self.compute_time + self.communication_time;
        if total_time.is_zero() {
            0.0
        } else {
            self.communication_time.as_secs_f64() / total_time.as_secs_f64()
        }
    }

    /// Most recently computed overall score.
    pub fn get_overall_efficiency(&self) -> f64 {
        self.overall_efficiency
    }

    /// Folds `other` into `self`.
    ///
    /// Times are summed; the two efficiency factors are averaged weighted by
    /// each side's measurement count. The merged count is exactly the sum of
    /// both counts: merging does not count as an extra measurement.
    pub fn merge(&mut self, other: &OptimizationEfficiency) {
        self.compute_time += other.compute_time;
        self.communication_time += other.communication_time;

        let own_measurements = self.measurements;
        let combined = u64::from(own_measurements) + u64::from(other.measurements);
        if combined > 0 {
            let combined = combined as f64;
            let self_weight = f64::from(own_measurements) / combined;
            let other_weight = f64::from(other.measurements) / combined;

            self.memory_efficiency =
                self_weight * self.memory_efficiency + other_weight * other.memory_efficiency;
            self.parameter_update_efficiency = self_weight * self.parameter_update_efficiency
                + other_weight * other.parameter_update_efficiency;
        }

        self.measurements = own_measurements.saturating_add(other.measurements);
        self.recompute_overall();
    }
}

impl Default for OptimizationEfficiency {
    /// Equivalent to [`OptimizationEfficiency::new`].
    fn default() -> Self {
        Self::new()
    }
}
858
859#[derive(Debug, Clone)]
861pub struct PerformanceSummary {
862 pub total_operations: u64,
864 pub average_forward_time: Duration,
866 pub average_backward_time: Duration,
868 pub average_optimizer_time: Duration,
870 pub tokens_per_second: f64,
872 pub memory_bandwidth_gbps: f64,
874 pub communication_efficiency: f64,
876 pub training_efficiency: f64,
878 pub memory_transfer_efficiency: f64,
880 pub layer_performance: HashMap<String, LayerPerformanceSummary>,
882}
883
/// Snapshot of timing figures for a single layer.
///
/// Plain data holder; every field is filled in by whoever constructs the
/// summary.
#[derive(Debug, Clone)]
pub struct LayerPerformanceSummary {
    /// Total number of executions counted for the layer.
    pub total_executions: u64,
    /// Mean duration of a forward execution.
    pub average_forward_time: Duration,
    /// Mean duration of a backward execution.
    pub average_backward_time: Duration,
    /// Total time attributed to the layer.
    pub total_time: Duration,
}
896
897pub use super::memory_management::Zero3MemoryStats;
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floating-point values agree within a small tolerance
    /// scaled to the expected magnitude, so the tests do not depend on exact
    /// rounding of intermediate divisions.
    fn assert_close(actual: f64, expected: f64) {
        let tolerance = 1e-9 * expected.abs().max(1.0);
        assert!(
            (actual - expected).abs() <= tolerance,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// A fresh collector starts with zeroed counters and timings.
    #[test]
    fn test_performance_stats_creation() {
        let stats = Zero3PerformanceStats::new();
        assert_eq!(stats.forward_passes, 0);
        assert_eq!(stats.backward_passes, 0);
        assert_eq!(stats.optimizer_steps, 0);
        assert_eq!(stats.total_forward_time, Duration::ZERO);
    }

    /// Recording one forward pass updates the count, total, and average.
    #[test]
    fn test_record_forward_pass() {
        let mut stats = Zero3PerformanceStats::new();
        stats.record_forward_pass(Duration::from_millis(100), 1000);

        assert_eq!(stats.forward_passes, 1);
        assert_eq!(stats.total_forward_time, Duration::from_millis(100));
        assert_eq!(stats.average_forward_time(), Duration::from_millis(100));
    }

    /// Per-layer statistics track forward and backward executions separately.
    #[test]
    fn test_layer_timing_stats() {
        let mut layer_stats = LayerTimingStats::new();

        layer_stats.record_forward_execution(Duration::from_millis(50));
        layer_stats.record_backward_execution(Duration::from_millis(75));

        assert_eq!(layer_stats.forward_executions, 1);
        assert_eq!(layer_stats.backward_executions, 1);
        assert_eq!(
            layer_stats.average_forward_time(),
            Duration::from_millis(50)
        );
        assert_eq!(
            layer_stats.average_backward_time(),
            Duration::from_millis(75)
        );
    }

    /// Throughput is total tokens over total time across both pass kinds.
    #[test]
    fn test_throughput_metrics() {
        let mut metrics = ThroughputMetrics::new();

        metrics.record_forward_pass(Duration::from_secs(1), 1000);
        assert_close(metrics.get_tokens_per_second(), 1000.0);

        // 2000 tokens over 2 seconds is still 1000 tokens per second.
        metrics.record_backward_pass(Duration::from_secs(1), 1000);
        assert_close(metrics.get_tokens_per_second(), 1000.0);
    }

    /// A CPU-to-GPU transfer updates bytes, count, and bandwidth.
    #[test]
    fn test_memory_transfer_metrics() {
        let mut metrics = MemoryTransferMetrics::new();

        metrics.record_transfer(Duration::from_secs(1), 1000, TransferDirection::CpuToGpu);
        assert_eq!(metrics.cpu_to_gpu_bytes, 1000);
        assert_eq!(metrics.cpu_to_gpu_transfers, 1);
        assert_close(metrics.get_cpu_to_gpu_bandwidth(), 1000.0);
    }

    /// An all-reduce operation updates its counters and bandwidth.
    #[test]
    fn test_communication_stats() {
        let mut stats = CommunicationStats::new();

        stats.record_operation(
            CommunicationOperation::AllReduce,
            Duration::from_millis(100),
            1000,
        );
        assert_eq!(stats.allreduce_operations, 1);
        assert_eq!(stats.allreduce_bytes, 1000);
        // 1000 bytes in 0.1 s is 10 000 bytes per second.
        assert_close(stats.get_allreduce_bandwidth(), 10000.0);
    }

    /// Compute and communication ratios split the recorded time.
    #[test]
    fn test_optimization_efficiency() {
        let mut efficiency = OptimizationEfficiency::new();

        efficiency.record_forward_pass(Duration::from_millis(800));
        efficiency.record_communication(Duration::from_millis(200));

        assert_close(efficiency.get_compute_ratio(), 0.8);
        assert_close(efficiency.get_communication_ratio(), 0.2);
    }

    /// Merging sums pass counts and total durations.
    #[test]
    fn test_stats_merging() {
        let mut stats1 = Zero3PerformanceStats::new();
        stats1.record_forward_pass(Duration::from_millis(100), 1000);

        let mut stats2 = Zero3PerformanceStats::new();
        stats2.record_forward_pass(Duration::from_millis(200), 2000);

        stats1.merge(&stats2);
        assert_eq!(stats1.forward_passes, 2);
        assert_eq!(stats1.total_forward_time, Duration::from_millis(300));
    }

    /// The summary counts every recorded operation and reports throughput.
    #[test]
    fn test_performance_summary() {
        let mut stats = Zero3PerformanceStats::new();
        stats.record_forward_pass(Duration::from_millis(100), 1000);
        stats.record_backward_pass(Duration::from_millis(150), 1000);
        stats.record_optimizer_step(Duration::from_millis(50), 100);

        let summary = stats.get_performance_summary();
        assert_eq!(summary.total_operations, 3);
        assert!(summary.tokens_per_second > 0.0);
    }

    /// Transfer directions compare by variant.
    #[test]
    fn test_transfer_direction() {
        assert_eq!(TransferDirection::CpuToGpu, TransferDirection::CpuToGpu);
        assert_ne!(TransferDirection::CpuToGpu, TransferDirection::GpuToCpu);
    }

    /// Communication operations compare by variant.
    #[test]
    fn test_communication_operation() {
        assert_eq!(
            CommunicationOperation::AllReduce,
            CommunicationOperation::AllReduce
        );
        assert_ne!(
            CommunicationOperation::AllReduce,
            CommunicationOperation::Broadcast
        );
    }
}