Skip to main content

torsh_distributed/zero_3_cpu_offload/
stats.rs

1//! Performance and Memory Statistics for ZeRO-3 CPU Offloading
2//!
3//! This module provides comprehensive statistics collection and analysis for ZeRO-3
4//! (Zero Redundancy Optimizer Stage 3) with CPU offloading. It tracks performance
5//! metrics, memory usage patterns, throughput statistics, and provides detailed
6//! insights for optimization and monitoring.
7
8use std::collections::HashMap;
9use std::time::Duration;
10
11/// Comprehensive performance statistics for ZeRO-3 operations
12///
13/// Tracks all aspects of ZeRO-3 performance including:
14/// - Forward and backward pass timing
15/// - Parameter transfer and optimization statistics
16/// - Memory management performance
17/// - Distributed synchronization metrics
18/// - Throughput and efficiency measurements
19#[derive(Debug, Clone)]
20pub struct Zero3PerformanceStats {
21    /// Number of forward passes completed
22    pub forward_passes: u64,
23    /// Number of backward passes completed
24    pub backward_passes: u64,
25    /// Number of optimizer steps completed
26    pub optimizer_steps: u64,
27    /// Total time spent in forward passes
28    pub total_forward_time: Duration,
29    /// Total time spent in backward passes
30    pub total_backward_time: Duration,
31    /// Total time spent in optimizer steps
32    pub total_optimizer_time: Duration,
33    /// Time spent transferring parameters between CPU/GPU
34    pub parameter_transfer_time: Duration,
35    /// Time spent synchronizing gradients across ranks
36    pub gradient_sync_time: Duration,
37    /// Per-layer execution timings
38    pub layer_timings: HashMap<String, LayerTimingStats>,
39    /// Throughput metrics
40    pub throughput_metrics: ThroughputMetrics,
41    /// Memory transfer performance
42    pub memory_transfer_metrics: MemoryTransferMetrics,
43    /// Distributed communication statistics
44    pub communication_stats: CommunicationStats,
45    /// Optimization efficiency metrics
46    pub optimization_efficiency: OptimizationEfficiency,
47}
48
49impl Zero3PerformanceStats {
50    /// Create new performance statistics
51    pub fn new() -> Self {
52        Self {
53            forward_passes: 0,
54            backward_passes: 0,
55            optimizer_steps: 0,
56            total_forward_time: Duration::ZERO,
57            total_backward_time: Duration::ZERO,
58            total_optimizer_time: Duration::ZERO,
59            parameter_transfer_time: Duration::ZERO,
60            gradient_sync_time: Duration::ZERO,
61            layer_timings: HashMap::new(),
62            throughput_metrics: ThroughputMetrics::new(),
63            memory_transfer_metrics: MemoryTransferMetrics::new(),
64            communication_stats: CommunicationStats::new(),
65            optimization_efficiency: OptimizationEfficiency::new(),
66        }
67    }
68
69    /// Record a completed forward pass
70    pub fn record_forward_pass(&mut self, duration: Duration, num_tokens: usize) {
71        self.forward_passes += 1;
72        self.total_forward_time += duration;
73        self.throughput_metrics
74            .record_forward_pass(duration, num_tokens);
75        self.optimization_efficiency.record_forward_pass(duration);
76    }
77
78    /// Record a completed backward pass
79    pub fn record_backward_pass(&mut self, duration: Duration, num_tokens: usize) {
80        self.backward_passes += 1;
81        self.total_backward_time += duration;
82        self.throughput_metrics
83            .record_backward_pass(duration, num_tokens);
84        self.optimization_efficiency.record_backward_pass(duration);
85    }
86
87    /// Record a completed optimizer step
88    pub fn record_optimizer_step(&mut self, duration: Duration, num_params: usize) {
89        self.optimizer_steps += 1;
90        self.total_optimizer_time += duration;
91        self.optimization_efficiency
92            .record_optimizer_step(duration, num_params);
93    }
94
95    /// Record layer execution timing
96    pub fn record_layer_execution(&mut self, layer_name: String, duration: Duration) {
97        let layer_stats = self.layer_timings.entry(layer_name.clone()).or_default();
98        layer_stats.record_forward_execution(duration);
99    }
100
101    /// Record layer backward pass timing
102    pub fn record_layer_backward(&mut self, layer_name: String, duration: Duration) {
103        let layer_stats = self.layer_timings.entry(layer_name).or_default();
104        layer_stats.record_backward_execution(duration);
105    }
106
107    /// Record parameter transfer operation
108    pub fn record_parameter_transfer(
109        &mut self,
110        duration: Duration,
111        bytes_transferred: usize,
112        direction: TransferDirection,
113    ) {
114        self.parameter_transfer_time += duration;
115        self.memory_transfer_metrics
116            .record_transfer(duration, bytes_transferred, direction);
117    }
118
119    /// Record gradient synchronization
120    pub fn record_gradient_sync(
121        &mut self,
122        duration: Duration,
123        num_gradients: usize,
124        world_size: usize,
125    ) {
126        self.gradient_sync_time += duration;
127        self.communication_stats
128            .record_gradient_sync(duration, num_gradients, world_size);
129    }
130
131    /// Record distributed communication operation
132    pub fn record_communication(
133        &mut self,
134        operation: CommunicationOperation,
135        duration: Duration,
136        bytes: usize,
137    ) {
138        self.communication_stats
139            .record_operation(operation, duration, bytes);
140    }
141
142    /// Get average forward pass time
143    pub fn average_forward_time(&self) -> Duration {
144        if self.forward_passes > 0 {
145            self.total_forward_time / self.forward_passes as u32
146        } else {
147            Duration::ZERO
148        }
149    }
150
151    /// Get average backward pass time
152    pub fn average_backward_time(&self) -> Duration {
153        if self.backward_passes > 0 {
154            self.total_backward_time / self.backward_passes as u32
155        } else {
156            Duration::ZERO
157        }
158    }
159
160    /// Get average optimizer step time
161    pub fn average_optimizer_time(&self) -> Duration {
162        if self.optimizer_steps > 0 {
163            self.total_optimizer_time / self.optimizer_steps as u32
164        } else {
165            Duration::ZERO
166        }
167    }
168
169    /// Get tokens per second throughput
170    pub fn get_tokens_per_second(&self) -> f64 {
171        self.throughput_metrics.get_tokens_per_second()
172    }
173
174    /// Get memory transfer bandwidth in GB/s
175    pub fn get_memory_bandwidth_gbps(&self) -> f64 {
176        self.memory_transfer_metrics.get_bandwidth_gbps()
177    }
178
179    /// Get communication efficiency metrics
180    pub fn get_communication_efficiency(&self) -> f64 {
181        self.communication_stats.get_efficiency()
182    }
183
184    /// Get overall training efficiency score (0.0 to 1.0)
185    pub fn get_training_efficiency(&self) -> f64 {
186        self.optimization_efficiency.get_overall_efficiency()
187    }
188
189    /// Get detailed performance summary
190    pub fn get_performance_summary(&self) -> PerformanceSummary {
191        PerformanceSummary {
192            total_operations: self.forward_passes + self.backward_passes + self.optimizer_steps,
193            average_forward_time: self.average_forward_time(),
194            average_backward_time: self.average_backward_time(),
195            average_optimizer_time: self.average_optimizer_time(),
196            tokens_per_second: self.get_tokens_per_second(),
197            memory_bandwidth_gbps: self.get_memory_bandwidth_gbps(),
198            communication_efficiency: self.get_communication_efficiency(),
199            training_efficiency: self.get_training_efficiency(),
200            memory_transfer_efficiency: self.memory_transfer_metrics.get_efficiency(),
201            layer_performance: self.get_layer_performance_summary(),
202        }
203    }
204
205    /// Get layer performance summary
206    fn get_layer_performance_summary(&self) -> HashMap<String, LayerPerformanceSummary> {
207        self.layer_timings
208            .iter()
209            .map(|(name, stats)| {
210                (
211                    name.clone(),
212                    LayerPerformanceSummary {
213                        total_executions: stats.forward_executions + stats.backward_executions,
214                        average_forward_time: stats.average_forward_time(),
215                        average_backward_time: stats.average_backward_time(),
216                        total_time: stats.total_forward_time + stats.total_backward_time,
217                    },
218                )
219            })
220            .collect()
221    }
222
223    /// Reset all statistics
224    pub fn reset(&mut self) {
225        *self = Self::new();
226    }
227
228    /// Merge statistics from another instance (useful for distributed aggregation)
229    pub fn merge(&mut self, other: &Zero3PerformanceStats) {
230        self.forward_passes += other.forward_passes;
231        self.backward_passes += other.backward_passes;
232        self.optimizer_steps += other.optimizer_steps;
233        self.total_forward_time += other.total_forward_time;
234        self.total_backward_time += other.total_backward_time;
235        self.total_optimizer_time += other.total_optimizer_time;
236        self.parameter_transfer_time += other.parameter_transfer_time;
237        self.gradient_sync_time += other.gradient_sync_time;
238
239        // Merge layer timings
240        for (layer_name, other_stats) in &other.layer_timings {
241            let layer_stats = self.layer_timings.entry(layer_name.clone()).or_default();
242            layer_stats.merge(other_stats);
243        }
244
245        self.throughput_metrics.merge(&other.throughput_metrics);
246        self.memory_transfer_metrics
247            .merge(&other.memory_transfer_metrics);
248        self.communication_stats.merge(&other.communication_stats);
249        self.optimization_efficiency
250            .merge(&other.optimization_efficiency);
251    }
252}
253
254impl Default for Zero3PerformanceStats {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
/// Per-layer timing statistics
///
/// Forward and backward executions are tracked independently. Until the first
/// sample in a direction is recorded, that direction's minimum is
/// `Duration::MAX` and its maximum is `Duration::ZERO`.
#[derive(Debug, Clone)]
pub struct LayerTimingStats {
    /// Number of forward executions
    pub forward_executions: u64,
    /// Number of backward executions
    pub backward_executions: u64,
    /// Total time spent in forward passes
    pub total_forward_time: Duration,
    /// Total time spent in backward passes
    pub total_backward_time: Duration,
    /// Minimum forward execution time (`Duration::MAX` before the first sample)
    pub min_forward_time: Duration,
    /// Maximum forward execution time
    pub max_forward_time: Duration,
    /// Minimum backward execution time (`Duration::MAX` before the first sample)
    pub min_backward_time: Duration,
    /// Maximum backward execution time
    pub max_backward_time: Duration,
}

impl LayerTimingStats {
    /// Create empty timing statistics.
    pub fn new() -> Self {
        Self {
            forward_executions: 0,
            backward_executions: 0,
            total_forward_time: Duration::ZERO,
            total_backward_time: Duration::ZERO,
            // Start at MAX so the first recorded sample always becomes the minimum.
            min_forward_time: Duration::MAX,
            max_forward_time: Duration::ZERO,
            min_backward_time: Duration::MAX,
            max_backward_time: Duration::ZERO,
        }
    }

    /// Record one forward execution of the layer.
    pub fn record_forward_execution(&mut self, duration: Duration) {
        self.forward_executions += 1;
        self.total_forward_time += duration;
        self.min_forward_time = self.min_forward_time.min(duration);
        self.max_forward_time = self.max_forward_time.max(duration);
    }

    /// Record one backward execution of the layer.
    pub fn record_backward_execution(&mut self, duration: Duration) {
        self.backward_executions += 1;
        self.total_backward_time += duration;
        self.min_backward_time = self.min_backward_time.min(duration);
        self.max_backward_time = self.max_backward_time.max(duration);
    }

    /// Mean forward execution time, or `Duration::ZERO` when none was recorded.
    pub fn average_forward_time(&self) -> Duration {
        Self::mean_duration(self.total_forward_time, self.forward_executions)
    }

    /// Mean backward execution time, or `Duration::ZERO` when none was recorded.
    pub fn average_backward_time(&self) -> Duration {
        Self::mean_duration(self.total_backward_time, self.backward_executions)
    }

    /// Fold another layer's statistics into this one.
    ///
    /// Counts and totals are summed; minima and maxima take the extreme of both.
    pub fn merge(&mut self, other: &LayerTimingStats) {
        self.forward_executions += other.forward_executions;
        self.backward_executions += other.backward_executions;
        self.total_forward_time += other.total_forward_time;
        self.total_backward_time += other.total_backward_time;
        self.min_forward_time = self.min_forward_time.min(other.min_forward_time);
        self.max_forward_time = self.max_forward_time.max(other.max_forward_time);
        self.min_backward_time = self.min_backward_time.min(other.min_backward_time);
        self.max_backward_time = self.max_backward_time.max(other.max_backward_time);
    }

    /// Divide `total` by a 64-bit `count`, returning `Duration::ZERO` for a zero count.
    ///
    /// `Duration` only implements division by `u32`; casting the `u64` counter
    /// with `as u32` truncates it, giving a wrong average above `u32::MAX` and
    /// a divide-by-zero panic when the low 32 bits are zero. The division is
    /// therefore done on the nanosecond total.
    fn mean_duration(total: Duration, count: u64) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if count == 0 {
            return Duration::ZERO;
        }
        let nanos = total.as_nanos() / u128::from(count);
        // The quotient never exceeds `total`, so the seconds part fits in u64
        // and the remainder is below one second.
        Duration::new((nanos / NANOS_PER_SEC) as u64, (nanos % NANOS_PER_SEC) as u32)
    }
}

impl Default for LayerTimingStats {
    fn default() -> Self {
        Self::new()
    }
}
342
/// Token-throughput accounting for forward and backward passes
///
/// Keeps cumulative token counts and times per direction, the highest
/// single-pass rate seen, and an exponentially smoothed per-pass rate.
#[derive(Debug, Clone)]
pub struct ThroughputMetrics {
    /// Total tokens processed in forward passes
    pub total_forward_tokens: usize,
    /// Total tokens processed in backward passes
    pub total_backward_tokens: usize,
    /// Total time spent in forward passes
    pub total_forward_time: Duration,
    /// Total time spent in backward passes
    pub total_backward_time: Duration,
    /// Highest single-pass tokens-per-second value observed
    pub peak_tokens_per_second: f64,
    /// Exponentially smoothed tokens per second across recorded passes
    pub rolling_average_tps: f64,
    /// Number of passes that contributed to `rolling_average_tps`
    pub rolling_samples: u32,
}

impl ThroughputMetrics {
    /// Build an empty metrics object.
    pub fn new() -> Self {
        Self {
            total_forward_tokens: 0,
            total_backward_tokens: 0,
            total_forward_time: Duration::ZERO,
            total_backward_time: Duration::ZERO,
            peak_tokens_per_second: 0.0,
            rolling_average_tps: 0.0,
            rolling_samples: 0,
        }
    }

    /// Account for one forward pass of `num_tokens` tokens taking `duration`.
    pub fn record_forward_pass(&mut self, duration: Duration, num_tokens: usize) {
        self.total_forward_tokens += num_tokens;
        self.total_forward_time += duration;
        self.update_rolling_average(duration, num_tokens);
    }

    /// Account for one backward pass of `num_tokens` tokens taking `duration`.
    pub fn record_backward_pass(&mut self, duration: Duration, num_tokens: usize) {
        self.total_backward_tokens += num_tokens;
        self.total_backward_time += duration;
        self.update_rolling_average(duration, num_tokens);
    }

    /// Fold a single pass into the peak and the smoothed rate.
    ///
    /// Passes with a zero duration carry no rate information and are ignored.
    fn update_rolling_average(&mut self, duration: Duration, num_tokens: usize) {
        // Weight given to the newest sample in the exponential average.
        const SMOOTHING: f64 = 0.1;

        if duration.is_zero() {
            return;
        }

        let sample = num_tokens as f64 / duration.as_secs_f64();
        self.peak_tokens_per_second = self.peak_tokens_per_second.max(sample);

        self.rolling_average_tps = match self.rolling_samples {
            // The very first sample seeds the average directly.
            0 => sample,
            _ => SMOOTHING * sample + (1.0 - SMOOTHING) * self.rolling_average_tps,
        };
        self.rolling_samples += 1;
    }

    /// Cumulative tokens per second over forward and backward passes combined.
    pub fn get_tokens_per_second(&self) -> f64 {
        Self::rate(
            self.total_forward_tokens + self.total_backward_tokens,
            self.total_forward_time + self.total_backward_time,
        )
    }

    /// Cumulative tokens per second over forward passes only.
    pub fn get_forward_tps(&self) -> f64 {
        Self::rate(self.total_forward_tokens, self.total_forward_time)
    }

    /// Cumulative tokens per second over backward passes only.
    pub fn get_backward_tps(&self) -> f64 {
        Self::rate(self.total_backward_tokens, self.total_backward_time)
    }

    /// Combine another metrics object into this one.
    ///
    /// Totals are summed, the peak is the larger of the two, and the smoothed
    /// rates are averaged weighted by each side's sample count.
    pub fn merge(&mut self, other: &ThroughputMetrics) {
        self.total_forward_tokens += other.total_forward_tokens;
        self.total_backward_tokens += other.total_backward_tokens;
        self.total_forward_time += other.total_forward_time;
        self.total_backward_time += other.total_backward_time;
        self.peak_tokens_per_second = self
            .peak_tokens_per_second
            .max(other.peak_tokens_per_second);

        let combined_samples = self.rolling_samples + other.rolling_samples;
        if combined_samples == 0 {
            // Neither side has a smoothed rate to contribute.
            return;
        }

        let own_share = self.rolling_samples as f64 / combined_samples as f64;
        let their_share = other.rolling_samples as f64 / combined_samples as f64;
        self.rolling_average_tps =
            own_share * self.rolling_average_tps + their_share * other.rolling_average_tps;
        self.rolling_samples = combined_samples;
    }

    /// Tokens per second, or `0.0` when either the token count or the time is zero.
    fn rate(tokens: usize, elapsed: Duration) -> f64 {
        if elapsed.is_zero() || tokens == 0 {
            0.0
        } else {
            tokens as f64 / elapsed.as_secs_f64()
        }
    }
}

impl Default for ThroughputMetrics {
    fn default() -> Self {
        Self::new()
    }
}
457
/// Memory transfer performance metrics
///
/// Tracks bytes, time and operation counts separately for each transfer
/// direction, plus the peak single-transfer bandwidth and an efficiency
/// estimate relative to a fixed reference bandwidth.
#[derive(Debug, Clone)]
pub struct MemoryTransferMetrics {
    /// Total bytes transferred CPU to GPU
    pub cpu_to_gpu_bytes: usize,
    /// Total bytes transferred GPU to CPU
    pub gpu_to_cpu_bytes: usize,
    /// Time spent in CPU to GPU transfers
    pub cpu_to_gpu_time: Duration,
    /// Time spent in GPU to CPU transfers
    pub gpu_to_cpu_time: Duration,
    /// Number of CPU to GPU transfers
    pub cpu_to_gpu_transfers: u64,
    /// Number of GPU to CPU transfers
    pub gpu_to_cpu_transfers: u64,
    /// Peak transfer bandwidth observed (bytes/sec)
    pub peak_bandwidth: f64,
    /// Transfer efficiency (actual vs theoretical bandwidth), in `0.0..=1.0`
    pub transfer_efficiency: f64,
}

impl MemoryTransferMetrics {
    /// Reference bandwidth used for the efficiency estimate, in bytes/sec.
    ///
    /// This is a simplified placeholder (1 GB/s); a real implementation would
    /// use hardware specs.
    const THEORETICAL_BANDWIDTH_BPS: f64 = 1_000_000_000.0;

    /// Create empty metrics. Efficiency starts at `1.0` until bandwidth is measured.
    pub fn new() -> Self {
        Self {
            cpu_to_gpu_bytes: 0,
            gpu_to_cpu_bytes: 0,
            cpu_to_gpu_time: Duration::ZERO,
            gpu_to_cpu_time: Duration::ZERO,
            cpu_to_gpu_transfers: 0,
            gpu_to_cpu_transfers: 0,
            peak_bandwidth: 0.0,
            transfer_efficiency: 1.0,
        }
    }

    /// Record one transfer of `bytes` bytes in `direction` that took `duration`.
    ///
    /// Zero-duration transfers still count towards bytes and operation totals
    /// but cannot contribute a peak-bandwidth sample.
    pub fn record_transfer(
        &mut self,
        duration: Duration,
        bytes: usize,
        direction: TransferDirection,
    ) {
        if !duration.is_zero() {
            let bandwidth = bytes as f64 / duration.as_secs_f64();
            self.peak_bandwidth = self.peak_bandwidth.max(bandwidth);
        }

        match direction {
            TransferDirection::CpuToGpu => {
                self.cpu_to_gpu_bytes += bytes;
                self.cpu_to_gpu_time += duration;
                self.cpu_to_gpu_transfers += 1;
            }
            TransferDirection::GpuToCpu => {
                self.gpu_to_cpu_bytes += bytes;
                self.gpu_to_cpu_time += duration;
                self.gpu_to_cpu_transfers += 1;
            }
        }

        self.update_efficiency();
    }

    /// Re-estimate efficiency from the cumulative bandwidth.
    ///
    /// When no bandwidth can be computed yet (no bytes or no elapsed time) the
    /// previous value is kept. Otherwise merging or recording on empty metrics
    /// would replace the initial `1.0` with a misleading `0.0`. This matches
    /// the guard used by `CommunicationStats`.
    fn update_efficiency(&mut self) {
        let actual_bandwidth = self.get_bandwidth_bps();
        if actual_bandwidth > 0.0 {
            self.transfer_efficiency =
                (actual_bandwidth / Self::THEORETICAL_BANDWIDTH_BPS).min(1.0);
        }
    }

    /// Cumulative bandwidth over both directions.
    ///
    /// The divisor is 1024^3, so the unit is binary gigabytes (GiB) per second.
    pub fn get_bandwidth_gbps(&self) -> f64 {
        self.get_bandwidth_bps() / (1024.0 * 1024.0 * 1024.0)
    }

    /// Cumulative bandwidth over both directions in bytes/sec (`0.0` when unknown).
    pub fn get_bandwidth_bps(&self) -> f64 {
        let total_bytes = self.cpu_to_gpu_bytes + self.gpu_to_cpu_bytes;
        let total_time = self.cpu_to_gpu_time + self.gpu_to_cpu_time;

        if !total_time.is_zero() && total_bytes > 0 {
            total_bytes as f64 / total_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Cumulative CPU-to-GPU bandwidth in bytes/sec (`0.0` when unknown).
    pub fn get_cpu_to_gpu_bandwidth(&self) -> f64 {
        if !self.cpu_to_gpu_time.is_zero() && self.cpu_to_gpu_bytes > 0 {
            self.cpu_to_gpu_bytes as f64 / self.cpu_to_gpu_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Cumulative GPU-to-CPU bandwidth in bytes/sec (`0.0` when unknown).
    pub fn get_gpu_to_cpu_bandwidth(&self) -> f64 {
        if !self.gpu_to_cpu_time.is_zero() && self.gpu_to_cpu_bytes > 0 {
            self.gpu_to_cpu_bytes as f64 / self.gpu_to_cpu_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Current efficiency estimate.
    pub fn get_efficiency(&self) -> f64 {
        self.transfer_efficiency
    }

    /// Fold another metrics object into this one and refresh the efficiency.
    pub fn merge(&mut self, other: &MemoryTransferMetrics) {
        self.cpu_to_gpu_bytes += other.cpu_to_gpu_bytes;
        self.gpu_to_cpu_bytes += other.gpu_to_cpu_bytes;
        self.cpu_to_gpu_time += other.cpu_to_gpu_time;
        self.gpu_to_cpu_time += other.gpu_to_cpu_time;
        self.cpu_to_gpu_transfers += other.cpu_to_gpu_transfers;
        self.gpu_to_cpu_transfers += other.gpu_to_cpu_transfers;
        self.peak_bandwidth = self.peak_bandwidth.max(other.peak_bandwidth);
        self.update_efficiency();
    }
}

impl Default for MemoryTransferMetrics {
    fn default() -> Self {
        Self::new()
    }
}

/// Direction of memory transfer
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    /// Host memory to device memory
    CpuToGpu,
    /// Device memory to host memory
    GpuToCpu,
}
590
/// Distributed communication statistics
///
/// Holds operation counts, elapsed time and byte volume for all-reduce,
/// broadcast and point-to-point traffic, plus an efficiency estimate relative
/// to a fixed reference bandwidth.
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    /// Number of all-reduce operations
    pub allreduce_operations: u64,
    /// Total time spent in all-reduce
    pub allreduce_time: Duration,
    /// Total bytes all-reduced
    pub allreduce_bytes: usize,
    /// Number of broadcast operations
    pub broadcast_operations: u64,
    /// Total time spent in broadcast
    pub broadcast_time: Duration,
    /// Total bytes broadcast
    pub broadcast_bytes: usize,
    /// Number of point-to-point communications
    pub p2p_operations: u64,
    /// Total time spent in point-to-point communication
    pub p2p_time: Duration,
    /// Total bytes in point-to-point communication
    pub p2p_bytes: usize,
    /// Communication efficiency (achieved vs theoretical)
    pub communication_efficiency: f64,
}

impl CommunicationStats {
    /// Build empty statistics; efficiency starts at `1.0`.
    pub fn new() -> Self {
        Self {
            communication_efficiency: 1.0,
            allreduce_operations: 0,
            broadcast_operations: 0,
            p2p_operations: 0,
            allreduce_bytes: 0,
            broadcast_bytes: 0,
            p2p_bytes: 0,
            allreduce_time: Duration::ZERO,
            broadcast_time: Duration::ZERO,
            p2p_time: Duration::ZERO,
        }
    }

    /// Record a gradient synchronization as one all-reduce.
    ///
    /// The byte volume is an estimate: `num_gradients * 4 * world_size`, i.e.
    /// gradients are assumed to be `f32`.
    pub fn record_gradient_sync(
        &mut self,
        duration: Duration,
        num_gradients: usize,
        world_size: usize,
    ) {
        let approx_bytes = num_gradients * 4 * world_size;
        self.record_operation(CommunicationOperation::AllReduce, duration, approx_bytes);
    }

    /// Record one communication operation of the given kind.
    pub fn record_operation(
        &mut self,
        operation: CommunicationOperation,
        duration: Duration,
        bytes: usize,
    ) {
        // Select the counter triple for this operation kind, then update it once.
        let (count, elapsed, volume) = match operation {
            CommunicationOperation::AllReduce => (
                &mut self.allreduce_operations,
                &mut self.allreduce_time,
                &mut self.allreduce_bytes,
            ),
            CommunicationOperation::Broadcast => (
                &mut self.broadcast_operations,
                &mut self.broadcast_time,
                &mut self.broadcast_bytes,
            ),
            CommunicationOperation::PointToPoint => (
                &mut self.p2p_operations,
                &mut self.p2p_time,
                &mut self.p2p_bytes,
            ),
        };
        *count += 1;
        *elapsed += duration;
        *volume += bytes;

        self.update_efficiency();
    }

    /// Re-estimate efficiency from the cumulative bandwidth over all kinds.
    ///
    /// The previous value is kept while no bandwidth can be computed.
    fn update_efficiency(&mut self) {
        // Simplified reference: 10 GB/s theoretical network.
        const THEORETICAL_BANDWIDTH: f64 = 10_000_000_000.0;

        let elapsed = self.allreduce_time + self.broadcast_time + self.p2p_time;
        let volume = self.allreduce_bytes + self.broadcast_bytes + self.p2p_bytes;
        if elapsed.is_zero() || volume == 0 {
            return;
        }

        let achieved = volume as f64 / elapsed.as_secs_f64();
        self.communication_efficiency = (achieved / THEORETICAL_BANDWIDTH).min(1.0);
    }

    /// Current efficiency estimate.
    pub fn get_efficiency(&self) -> f64 {
        self.communication_efficiency
    }

    /// Cumulative all-reduce bandwidth in bytes/sec (`0.0` when unknown).
    pub fn get_allreduce_bandwidth(&self) -> f64 {
        Self::throughput(self.allreduce_bytes, self.allreduce_time)
    }

    /// Cumulative broadcast bandwidth in bytes/sec (`0.0` when unknown).
    pub fn get_broadcast_bandwidth(&self) -> f64 {
        Self::throughput(self.broadcast_bytes, self.broadcast_time)
    }

    /// Sum another statistics object into this one and refresh the efficiency.
    pub fn merge(&mut self, other: &CommunicationStats) {
        self.allreduce_operations += other.allreduce_operations;
        self.allreduce_time += other.allreduce_time;
        self.allreduce_bytes += other.allreduce_bytes;

        self.broadcast_operations += other.broadcast_operations;
        self.broadcast_time += other.broadcast_time;
        self.broadcast_bytes += other.broadcast_bytes;

        self.p2p_operations += other.p2p_operations;
        self.p2p_time += other.p2p_time;
        self.p2p_bytes += other.p2p_bytes;

        self.update_efficiency();
    }

    /// Bytes per second, or `0.0` when either the byte count or the time is zero.
    fn throughput(bytes: usize, elapsed: Duration) -> f64 {
        if elapsed.is_zero() || bytes == 0 {
            0.0
        } else {
            bytes as f64 / elapsed.as_secs_f64()
        }
    }
}

impl Default for CommunicationStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Types of distributed communication operations
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommunicationOperation {
    /// Collective all-reduce
    AllReduce,
    /// Collective broadcast
    Broadcast,
    /// Point-to-point send/receive
    PointToPoint,
}
732
/// Optimization efficiency metrics
///
/// Combines the share of time spent computing (as opposed to communicating)
/// with externally supplied memory and parameter-update efficiencies into a
/// single score in `0.0..=1.0`.
#[derive(Debug, Clone)]
pub struct OptimizationEfficiency {
    /// Time spent in computation (forward, backward and optimizer steps)
    pub compute_time: Duration,
    /// Time spent in communication
    pub communication_time: Duration,
    /// Memory utilization efficiency (0.0 to 1.0)
    pub memory_efficiency: f64,
    /// Parameter update efficiency
    pub parameter_update_efficiency: f64,
    /// Overall training efficiency score
    pub overall_efficiency: f64,
    /// Number of efficiency measurements
    pub measurements: u32,
}

impl OptimizationEfficiency {
    /// Create a tracker with no recorded time and all efficiencies at `1.0`.
    pub fn new() -> Self {
        Self {
            compute_time: Duration::ZERO,
            communication_time: Duration::ZERO,
            memory_efficiency: 1.0,
            parameter_update_efficiency: 1.0,
            overall_efficiency: 1.0,
            measurements: 0,
        }
    }

    /// Count a forward pass as compute time.
    pub fn record_forward_pass(&mut self, duration: Duration) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Count a backward pass as compute time.
    pub fn record_backward_pass(&mut self, duration: Duration) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Count an optimizer step as compute time. `_num_params` is currently unused.
    pub fn record_optimizer_step(&mut self, duration: Duration, _num_params: usize) {
        self.compute_time += duration;
        self.update_efficiency();
    }

    /// Count a communication operation.
    pub fn record_communication(&mut self, duration: Duration) {
        self.communication_time += duration;
        self.update_efficiency();
    }

    /// Register one new measurement and refresh the overall score.
    fn update_efficiency(&mut self) {
        self.measurements = self.measurements.saturating_add(1);
        self.recompute_overall();
    }

    /// Refresh `overall_efficiency` from the current inputs without counting a measurement.
    fn recompute_overall(&mut self) {
        // Overall efficiency is a weighted combination of the three factors.
        let weighted = 0.5 * self.get_compute_ratio()
            + 0.3 * self.memory_efficiency
            + 0.2 * self.parameter_update_efficiency;
        self.overall_efficiency = weighted.clamp(0.0, 1.0);
    }

    /// Set the memory efficiency (clamped to `0.0..=1.0`) and refresh the score.
    pub fn update_memory_efficiency(&mut self, efficiency: f64) {
        self.memory_efficiency = efficiency.clamp(0.0, 1.0);
        self.update_efficiency();
    }

    /// Set the parameter-update efficiency (clamped to `0.0..=1.0`) and refresh the score.
    pub fn update_parameter_efficiency(&mut self, efficiency: f64) {
        self.parameter_update_efficiency = efficiency.clamp(0.0, 1.0);
        self.update_efficiency();
    }

    /// Fraction of total time spent computing (`1.0` when no time is recorded).
    pub fn get_compute_ratio(&self) -> f64 {
        let total_time = self.compute_time + self.communication_time;
        if !total_time.is_zero() {
            self.compute_time.as_secs_f64() / total_time.as_secs_f64()
        } else {
            1.0
        }
    }

    /// Fraction of total time spent communicating (`0.0` when no time is recorded).
    pub fn get_communication_ratio(&self) -> f64 {
        let total_time = self.compute_time + self.communication_time;
        if !total_time.is_zero() {
            self.communication_time.as_secs_f64() / total_time.as_secs_f64()
        } else {
            0.0
        }
    }

    /// Current overall efficiency score.
    pub fn get_overall_efficiency(&self) -> f64 {
        self.overall_efficiency
    }

    /// Fold another tracker into this one.
    ///
    /// Times are summed and the memory / parameter efficiencies are averaged
    /// weighted by each side's measurement count. The merged count is exactly
    /// the sum of both counts: merging is not itself a measurement, and
    /// counting it as one would skew the weights of later merges.
    pub fn merge(&mut self, other: &OptimizationEfficiency) {
        self.compute_time += other.compute_time;
        self.communication_time += other.communication_time;

        let own_measurements = self.measurements;
        let total_measurements = own_measurements.saturating_add(other.measurements);

        if total_measurements > 0 {
            let self_weight = own_measurements as f64 / total_measurements as f64;
            let other_weight = other.measurements as f64 / total_measurements as f64;

            self.memory_efficiency =
                self_weight * self.memory_efficiency + other_weight * other.memory_efficiency;
            self.parameter_update_efficiency = self_weight * self.parameter_update_efficiency
                + other_weight * other.parameter_update_efficiency;
        }

        self.measurements = total_measurements;
        self.recompute_overall();
    }
}

impl Default for OptimizationEfficiency {
    fn default() -> Self {
        Self::new()
    }
}
858
/// High-level performance summary
///
/// A flat snapshot of aggregate metrics: operation counts, average timings,
/// throughput, bandwidth, efficiency scores and a per-layer breakdown. The
/// tests in this file obtain one through
/// `Zero3PerformanceStats::get_performance_summary()`.
#[derive(Debug, Clone)]
pub struct PerformanceSummary {
    /// Total number of operations (forward + backward + optimizer)
    pub total_operations: u64,
    /// Average forward pass time
    pub average_forward_time: Duration,
    /// Average backward pass time
    pub average_backward_time: Duration,
    /// Average optimizer step time
    pub average_optimizer_time: Duration,
    /// Tokens processed per second
    pub tokens_per_second: f64,
    /// Memory bandwidth in GB/s
    // NOTE(review): whether "GB" means 1e9 or 2^30 bytes is decided where this
    // field is filled in, which is not visible from this declaration.
    pub memory_bandwidth_gbps: f64,
    /// Communication efficiency (0.0 to 1.0)
    pub communication_efficiency: f64,
    /// Overall training efficiency (0.0 to 1.0)
    pub training_efficiency: f64,
    /// Memory transfer efficiency (0.0 to 1.0)
    pub memory_transfer_efficiency: f64,
    /// Per-layer performance summary
    // Keyed by `String`; presumably the layer name -- confirm against the code
    // that populates this map.
    pub layer_performance: HashMap<String, LayerPerformanceSummary>,
}
883
/// Per-layer performance summary
///
/// A small value type holding the execution count and timing aggregates for a
/// single layer. All fields are plain `u64`/`Duration` values, so the type
/// derives `PartialEq` and `Eq`; two summaries can be compared directly with
/// `==` or `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerPerformanceSummary {
    /// Total number of executions (forward + backward)
    pub total_executions: u64,
    /// Average forward execution time
    pub average_forward_time: Duration,
    /// Average backward execution time
    pub average_backward_time: Duration,
    /// Total time spent in this layer
    pub total_time: Duration,
}
896
897/// Memory statistics for ZeRO-3 (re-exported from memory_management module)
898pub use super::memory_management::Zero3MemoryStats;
899
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed stats object starts with zeroed counters and timers.
    #[test]
    fn test_performance_stats_creation() {
        let fresh = Zero3PerformanceStats::new();

        assert_eq!(
            (
                fresh.forward_passes,
                fresh.backward_passes,
                fresh.optimizer_steps
            ),
            (0, 0, 0)
        );
        assert_eq!(fresh.total_forward_time, Duration::ZERO);
    }

    /// One recorded forward pass shows up in the count, total and average.
    #[test]
    fn test_record_forward_pass() {
        let elapsed = Duration::from_millis(100);
        let mut tracker = Zero3PerformanceStats::new();

        tracker.record_forward_pass(elapsed, 1000);

        assert_eq!(tracker.forward_passes, 1);
        assert_eq!(tracker.total_forward_time, elapsed);
        assert_eq!(tracker.average_forward_time(), elapsed);
    }

    /// Forward and backward executions of a layer are tracked independently.
    #[test]
    fn test_layer_timing_stats() {
        let forward_cost = Duration::from_millis(50);
        let backward_cost = Duration::from_millis(75);
        let mut timings = LayerTimingStats::new();

        timings.record_forward_execution(forward_cost);
        timings.record_backward_execution(backward_cost);

        assert_eq!(timings.forward_executions, 1);
        assert_eq!(timings.backward_executions, 1);
        assert_eq!(timings.average_forward_time(), forward_cost);
        assert_eq!(timings.average_backward_time(), backward_cost);
    }

    /// Token throughput stays constant when tokens and time grow together.
    #[test]
    fn test_throughput_metrics() {
        let one_second = Duration::from_secs(1);
        let mut throughput = ThroughputMetrics::new();

        throughput.record_forward_pass(one_second, 1000);
        assert_eq!(throughput.get_tokens_per_second(), 1000.0);

        // 2000 tokens in 2 seconds
        throughput.record_backward_pass(one_second, 1000);
        assert_eq!(throughput.get_tokens_per_second(), 1000.0);
    }

    /// A CPU-to-GPU transfer updates the byte count, transfer count and bandwidth.
    #[test]
    fn test_memory_transfer_metrics() {
        let mut transfers = MemoryTransferMetrics::new();

        transfers.record_transfer(Duration::from_secs(1), 1000, TransferDirection::CpuToGpu);

        assert_eq!(transfers.cpu_to_gpu_bytes, 1000);
        assert_eq!(transfers.cpu_to_gpu_transfers, 1);
        assert_eq!(transfers.get_cpu_to_gpu_bandwidth(), 1000.0);
    }

    /// An all-reduce operation is counted and contributes to its bandwidth figure.
    #[test]
    fn test_communication_stats() {
        let mut comms = CommunicationStats::new();
        let operation = CommunicationOperation::AllReduce;

        comms.record_operation(operation, Duration::from_millis(100), 1000);

        assert_eq!(comms.allreduce_operations, 1);
        assert_eq!(comms.allreduce_bytes, 1000);
        // 1000 bytes / 0.1 seconds = 10000 bytes/sec
        assert_eq!(comms.get_allreduce_bandwidth(), 10000.0);
    }

    /// 800 ms of compute against 200 ms of communication splits 0.8 / 0.2.
    #[test]
    fn test_optimization_efficiency() {
        let mut tracker = OptimizationEfficiency::new();

        tracker.record_forward_pass(Duration::from_millis(800));
        tracker.record_communication(Duration::from_millis(200));

        assert_eq!(tracker.get_compute_ratio(), 0.8);
        assert_eq!(tracker.get_communication_ratio(), 0.2);
    }

    /// Merging two stats objects sums their pass counts and total times.
    #[test]
    fn test_stats_merging() {
        let mut accumulated = Zero3PerformanceStats::new();
        accumulated.record_forward_pass(Duration::from_millis(100), 1000);

        let mut incoming = Zero3PerformanceStats::new();
        incoming.record_forward_pass(Duration::from_millis(200), 2000);

        accumulated.merge(&incoming);

        assert_eq!(accumulated.forward_passes, 2);
        assert_eq!(accumulated.total_forward_time, Duration::from_millis(300));
    }

    /// The summary counts one forward, one backward and one optimizer operation.
    #[test]
    fn test_performance_summary() {
        let mut tracker = Zero3PerformanceStats::new();
        tracker.record_forward_pass(Duration::from_millis(100), 1000);
        tracker.record_backward_pass(Duration::from_millis(150), 1000);
        tracker.record_optimizer_step(Duration::from_millis(50), 100);

        let report = tracker.get_performance_summary();

        assert_eq!(report.total_operations, 3);
        assert!(report.tokens_per_second > 0.0);
    }

    /// Transfer directions compare equal only to themselves.
    #[test]
    fn test_transfer_direction() {
        let upload = TransferDirection::CpuToGpu;

        assert_eq!(upload, TransferDirection::CpuToGpu);
        assert_ne!(upload, TransferDirection::GpuToCpu);
    }

    /// Communication operations compare equal only to themselves.
    #[test]
    fn test_communication_operation() {
        let all_reduce = CommunicationOperation::AllReduce;

        assert_eq!(all_reduce, CommunicationOperation::AllReduce);
        assert_ne!(all_reduce, CommunicationOperation::Broadcast);
    }
}