// trustformers_debug/utilities/performance.rs
1//! Performance monitoring and profiling utilities
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
7/// Performance monitoring utilities
/// Performance monitoring utilities.
///
/// Times named checkpoints against a wall-clock start (`Instant`) and
/// accumulates completed durations for reporting and analysis.
#[derive(Debug)]
pub struct PerformanceMonitor {
    /// Captured at construction; basis for `total_elapsed`.
    start_time: Instant,
    /// Currently open checkpoints: name -> start instant.
    checkpoints: HashMap<String, Instant>,
    /// Completed checkpoints: name -> measured duration.
    durations: HashMap<String, Duration>,
}
14
15impl PerformanceMonitor {
16    pub fn new() -> Self {
17        Self {
18            start_time: Instant::now(),
19            checkpoints: HashMap::new(),
20            durations: HashMap::new(),
21        }
22    }
23
24    pub fn checkpoint(&mut self, name: &str) {
25        self.checkpoints.insert(name.to_string(), Instant::now());
26    }
27
28    pub fn end_checkpoint(&mut self, name: &str) -> Option<Duration> {
29        if let Some(start) = self.checkpoints.remove(name) {
30            let duration = start.elapsed();
31            self.durations.insert(name.to_string(), duration);
32            Some(duration)
33        } else {
34            None
35        }
36    }
37
38    pub fn total_elapsed(&self) -> Duration {
39        self.start_time.elapsed()
40    }
41
42    pub fn get_durations(&self) -> &HashMap<String, Duration> {
43        &self.durations
44    }
45
46    pub fn performance_report(&self) -> String {
47        let mut report = format!(
48            "Performance Report - Total: {:.2}ms\n",
49            self.total_elapsed().as_millis()
50        );
51
52        for (name, duration) in &self.durations {
53            report.push_str(&format!("  {}: {:.2}ms\n", name, duration.as_millis()));
54        }
55
56        report
57    }
58
59    /// Get detailed performance metrics
60    pub fn get_detailed_metrics(&self) -> PerformanceMetrics {
61        let total_duration = self.total_elapsed();
62        let checkpoint_count = self.durations.len();
63
64        let avg_checkpoint_duration = if checkpoint_count > 0 {
65            self.durations.values().map(|d| d.as_millis() as f64).sum::<f64>()
66                / checkpoint_count as f64
67        } else {
68            0.0
69        };
70
71        let slowest_checkpoint = self
72            .durations
73            .iter()
74            .max_by_key(|(_, duration)| *duration)
75            .map(|(name, duration)| (name.clone(), *duration));
76
77        let fastest_checkpoint = self
78            .durations
79            .iter()
80            .min_by_key(|(_, duration)| *duration)
81            .map(|(name, duration)| (name.clone(), *duration));
82
83        PerformanceMetrics {
84            total_duration,
85            checkpoint_count,
86            avg_checkpoint_duration,
87            slowest_checkpoint,
88            fastest_checkpoint,
89            durations: self.durations.clone(),
90        }
91    }
92
93    /// Analyze performance bottlenecks
94    pub fn analyze_bottlenecks(&self, threshold_percentile: f64) -> BottleneckAnalysis {
95        let mut duration_values: Vec<u128> =
96            self.durations.values().map(|d| d.as_millis()).collect();
97        duration_values.sort();
98
99        let threshold_index = ((duration_values.len() as f64 * threshold_percentile) as usize)
100            .min(duration_values.len().saturating_sub(1));
101        let threshold = duration_values.get(threshold_index).copied().unwrap_or(0);
102
103        let bottlenecks: Vec<PerformanceBottleneck> = self
104            .durations
105            .iter()
106            .filter(|(_, duration)| duration.as_millis() >= threshold)
107            .map(|(name, duration)| PerformanceBottleneck {
108                checkpoint_name: name.clone(),
109                duration: *duration,
110                severity: Self::classify_bottleneck_severity(
111                    duration.as_millis(),
112                    &duration_values,
113                ),
114                recommendation: Self::generate_bottleneck_recommendation(name, *duration),
115            })
116            .collect();
117
118        let total_bottleneck_time: Duration = bottlenecks.iter().map(|b| b.duration).sum();
119
120        BottleneckAnalysis {
121            threshold_ms: threshold,
122            bottlenecks,
123            total_bottleneck_time,
124            bottleneck_percentage: if self.total_elapsed().as_millis() > 0 {
125                (total_bottleneck_time.as_millis() as f64 / self.total_elapsed().as_millis() as f64)
126                    * 100.0
127            } else {
128                0.0
129            },
130        }
131    }
132
133    fn classify_bottleneck_severity(
134        duration_ms: u128,
135        all_durations: &[u128],
136    ) -> BottleneckSeverity {
137        if all_durations.is_empty() {
138            return BottleneckSeverity::Low;
139        }
140
141        let max_duration = all_durations.iter().max().copied().unwrap_or(0);
142        let avg_duration = all_durations.iter().sum::<u128>() / all_durations.len() as u128;
143
144        if duration_ms >= max_duration {
145            BottleneckSeverity::Critical
146        } else if duration_ms > avg_duration * 3 {
147            BottleneckSeverity::High
148        } else if duration_ms > avg_duration * 2 {
149            BottleneckSeverity::Medium
150        } else {
151            BottleneckSeverity::Low
152        }
153    }
154
155    fn generate_bottleneck_recommendation(checkpoint_name: &str, duration: Duration) -> String {
156        let duration_ms = duration.as_millis();
157
158        match checkpoint_name {
159            name if name.contains("forward") => {
160                if duration_ms > 1000 {
161                    "Consider model pruning or quantization to reduce forward pass time".to_string()
162                } else {
163                    "Monitor forward pass efficiency".to_string()
164                }
165            },
166            name if name.contains("backward") => {
167                if duration_ms > 2000 {
168                    "Consider gradient accumulation or mixed precision training".to_string()
169                } else {
170                    "Monitor backward pass efficiency".to_string()
171                }
172            },
173            name if name.contains("data") => {
174                "Consider data loading optimization or caching".to_string()
175            },
176            name if name.contains("io") => {
177                "Consider I/O optimization or async processing".to_string()
178            },
179            _ => {
180                format!(
181                    "Optimize '{}' operation - duration: {}ms",
182                    checkpoint_name, duration_ms
183                )
184            },
185        }
186    }
187}
188
impl Default for PerformanceMonitor {
    /// Equivalent to [`PerformanceMonitor::new`]: the elapsed clock starts
    /// at the moment of construction.
    fn default() -> Self {
        Self::new()
    }
}
194
195/// Detailed performance metrics
/// Detailed performance metrics.
#[derive(Debug, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Wall-clock time since the monitor was created.
    pub total_duration: Duration,
    /// Number of completed checkpoints.
    pub checkpoint_count: usize,
    /// Mean checkpoint duration in milliseconds (0.0 when no checkpoints).
    pub avg_checkpoint_duration: f64,
    /// Longest checkpoint `(name, duration)`, if any were recorded.
    pub slowest_checkpoint: Option<(String, Duration)>,
    /// Shortest checkpoint `(name, duration)`, if any were recorded.
    pub fastest_checkpoint: Option<(String, Duration)>,
    /// All completed checkpoint durations, keyed by name.
    pub durations: HashMap<String, Duration>,
}
205
206/// Performance bottleneck analysis
/// Performance bottleneck analysis.
#[derive(Debug, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    /// Duration cutoff (milliseconds); checkpoints at or above it are
    /// reported as bottlenecks.
    pub threshold_ms: u128,
    /// Checkpoints meeting the threshold, each with severity and advice.
    pub bottlenecks: Vec<PerformanceBottleneck>,
    /// Sum of all bottleneck durations.
    pub total_bottleneck_time: Duration,
    /// Bottleneck time as a percentage of total elapsed time (0.0 when the
    /// total is below one millisecond).
    pub bottleneck_percentage: f64,
}
214
215/// Individual performance bottleneck
/// Individual performance bottleneck.
#[derive(Debug, Serialize, Deserialize)]
pub struct PerformanceBottleneck {
    /// Name of the offending checkpoint.
    pub checkpoint_name: String,
    /// Measured duration of the checkpoint.
    pub duration: Duration,
    /// Severity relative to the other recorded durations.
    pub severity: BottleneckSeverity,
    /// Human-readable optimization suggestion.
    pub recommendation: String,
}
223
224/// Bottleneck severity levels
225#[derive(Debug, Serialize, Deserialize)]
226pub enum BottleneckSeverity {
227    Low,
228    Medium,
229    High,
230    Critical,
231}
232
233/// Memory performance monitoring
234#[derive(Debug)]
235pub struct SystemMemoryProfiler {
236    baseline_memory: usize,
237    peak_memory: usize,
238    checkpoints: HashMap<String, usize>,
239}
240
241impl Default for SystemMemoryProfiler {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
247impl SystemMemoryProfiler {
248    pub fn new() -> Self {
249        Self {
250            baseline_memory: Self::get_current_memory_usage(),
251            peak_memory: 0,
252            checkpoints: HashMap::new(),
253        }
254    }
255
256    pub fn checkpoint(&mut self, name: &str) {
257        let current_memory = Self::get_current_memory_usage();
258        self.checkpoints.insert(name.to_string(), current_memory);
259
260        if current_memory > self.peak_memory {
261            self.peak_memory = current_memory;
262        }
263    }
264
265    pub fn memory_report(&self) -> MemoryReport {
266        let current_memory = Self::get_current_memory_usage();
267        let memory_growth = current_memory.saturating_sub(self.baseline_memory);
268
269        let mut memory_deltas = HashMap::new();
270        let mut prev_memory = self.baseline_memory;
271
272        for (name, memory) in &self.checkpoints {
273            let delta = memory.saturating_sub(prev_memory) as i64;
274            memory_deltas.insert(name.clone(), delta);
275            prev_memory = *memory;
276        }
277
278        MemoryReport {
279            baseline_memory: self.baseline_memory,
280            current_memory,
281            peak_memory: self.peak_memory,
282            memory_growth,
283            checkpoints: self.checkpoints.clone(),
284            memory_deltas,
285        }
286    }
287
288    fn get_current_memory_usage() -> usize {
289        // Simplified memory usage - in practice this would use platform-specific APIs
290        // This is a placeholder implementation
291        0
292    }
293}
294
/// Memory profiling report.
///
/// All values are in the units returned by the profiler's memory probe
/// (placeholder implementation — presumably bytes; confirm once a real
/// platform probe is wired in).
#[derive(Debug, Serialize, Deserialize)]
pub struct MemoryReport {
    /// Usage captured when the profiler was created.
    pub baseline_memory: usize,
    /// Usage at the time the report was generated.
    pub current_memory: usize,
    /// Highest usage observed at any checkpoint.
    pub peak_memory: usize,
    /// Current usage minus baseline (clamped at 0).
    pub memory_growth: usize,
    /// Raw usage recorded at each named checkpoint.
    pub checkpoints: HashMap<String, usize>,
    /// Signed usage change attributed to each checkpoint.
    pub memory_deltas: HashMap<String, i64>,
}
305
306/// Combined performance and memory profiler
307#[derive(Debug)]
308pub struct SystemProfiler {
309    performance_monitor: PerformanceMonitor,
310    memory_profiler: SystemMemoryProfiler,
311}
312
impl Default for SystemProfiler {
    /// Equivalent to [`SystemProfiler::new`].
    fn default() -> Self {
        Self::new()
    }
}
318
319impl SystemProfiler {
320    pub fn new() -> Self {
321        Self {
322            performance_monitor: PerformanceMonitor::new(),
323            memory_profiler: SystemMemoryProfiler::new(),
324        }
325    }
326
327    pub fn checkpoint(&mut self, name: &str) {
328        self.performance_monitor.checkpoint(name);
329        self.memory_profiler.checkpoint(name);
330    }
331
332    pub fn end_checkpoint(&mut self, name: &str) -> Option<Duration> {
333        self.performance_monitor.end_checkpoint(name)
334    }
335
336    pub fn generate_system_report(&self) -> SystemReport {
337        let performance_metrics = self.performance_monitor.get_detailed_metrics();
338        let memory_report = self.memory_profiler.memory_report();
339        let bottleneck_analysis = self.performance_monitor.analyze_bottlenecks(0.8);
340
341        SystemReport {
342            performance_metrics,
343            memory_report,
344            bottleneck_analysis,
345            timestamp: chrono::Utc::now(),
346        }
347    }
348}
349
350/// Comprehensive system profiling report
/// Comprehensive system profiling report.
#[derive(Debug, Serialize, Deserialize)]
pub struct SystemReport {
    /// Timing metrics from the performance monitor.
    pub performance_metrics: PerformanceMetrics,
    /// Memory usage snapshot from the memory profiler.
    pub memory_report: MemoryReport,
    /// Bottleneck analysis derived from the recorded checkpoint durations.
    pub bottleneck_analysis: BottleneckAnalysis,
    /// UTC time at which the report was generated.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}