trustformers_debug/utilities/
performance.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
/// Wall-clock timer that tracks named checkpoints.
///
/// A checkpoint is opened with `checkpoint(name)` and closed with
/// `end_checkpoint(name)`; closed checkpoints are kept as durations.
#[derive(Debug)]
pub struct PerformanceMonitor {
    /// Instant captured when the monitor was constructed.
    start_time: Instant,
    /// Checkpoints that have been started but not yet ended, keyed by name.
    checkpoints: HashMap<String, Instant>,
    /// Measured duration of every ended checkpoint, keyed by name.
    durations: HashMap<String, Duration>,
}
14
15impl PerformanceMonitor {
16 pub fn new() -> Self {
17 Self {
18 start_time: Instant::now(),
19 checkpoints: HashMap::new(),
20 durations: HashMap::new(),
21 }
22 }
23
24 pub fn checkpoint(&mut self, name: &str) {
25 self.checkpoints.insert(name.to_string(), Instant::now());
26 }
27
28 pub fn end_checkpoint(&mut self, name: &str) -> Option<Duration> {
29 if let Some(start) = self.checkpoints.remove(name) {
30 let duration = start.elapsed();
31 self.durations.insert(name.to_string(), duration);
32 Some(duration)
33 } else {
34 None
35 }
36 }
37
38 pub fn total_elapsed(&self) -> Duration {
39 self.start_time.elapsed()
40 }
41
42 pub fn get_durations(&self) -> &HashMap<String, Duration> {
43 &self.durations
44 }
45
46 pub fn performance_report(&self) -> String {
47 let mut report = format!(
48 "Performance Report - Total: {:.2}ms\n",
49 self.total_elapsed().as_millis()
50 );
51
52 for (name, duration) in &self.durations {
53 report.push_str(&format!(" {}: {:.2}ms\n", name, duration.as_millis()));
54 }
55
56 report
57 }
58
59 pub fn get_detailed_metrics(&self) -> PerformanceMetrics {
61 let total_duration = self.total_elapsed();
62 let checkpoint_count = self.durations.len();
63
64 let avg_checkpoint_duration = if checkpoint_count > 0 {
65 self.durations.values().map(|d| d.as_millis() as f64).sum::<f64>()
66 / checkpoint_count as f64
67 } else {
68 0.0
69 };
70
71 let slowest_checkpoint = self
72 .durations
73 .iter()
74 .max_by_key(|(_, duration)| *duration)
75 .map(|(name, duration)| (name.clone(), *duration));
76
77 let fastest_checkpoint = self
78 .durations
79 .iter()
80 .min_by_key(|(_, duration)| *duration)
81 .map(|(name, duration)| (name.clone(), *duration));
82
83 PerformanceMetrics {
84 total_duration,
85 checkpoint_count,
86 avg_checkpoint_duration,
87 slowest_checkpoint,
88 fastest_checkpoint,
89 durations: self.durations.clone(),
90 }
91 }
92
93 pub fn analyze_bottlenecks(&self, threshold_percentile: f64) -> BottleneckAnalysis {
95 let mut duration_values: Vec<u128> =
96 self.durations.values().map(|d| d.as_millis()).collect();
97 duration_values.sort();
98
99 let threshold_index = ((duration_values.len() as f64 * threshold_percentile) as usize)
100 .min(duration_values.len().saturating_sub(1));
101 let threshold = duration_values.get(threshold_index).copied().unwrap_or(0);
102
103 let bottlenecks: Vec<PerformanceBottleneck> = self
104 .durations
105 .iter()
106 .filter(|(_, duration)| duration.as_millis() >= threshold)
107 .map(|(name, duration)| PerformanceBottleneck {
108 checkpoint_name: name.clone(),
109 duration: *duration,
110 severity: Self::classify_bottleneck_severity(
111 duration.as_millis(),
112 &duration_values,
113 ),
114 recommendation: Self::generate_bottleneck_recommendation(name, *duration),
115 })
116 .collect();
117
118 let total_bottleneck_time: Duration = bottlenecks.iter().map(|b| b.duration).sum();
119
120 BottleneckAnalysis {
121 threshold_ms: threshold,
122 bottlenecks,
123 total_bottleneck_time,
124 bottleneck_percentage: if self.total_elapsed().as_millis() > 0 {
125 (total_bottleneck_time.as_millis() as f64 / self.total_elapsed().as_millis() as f64)
126 * 100.0
127 } else {
128 0.0
129 },
130 }
131 }
132
133 fn classify_bottleneck_severity(
134 duration_ms: u128,
135 all_durations: &[u128],
136 ) -> BottleneckSeverity {
137 if all_durations.is_empty() {
138 return BottleneckSeverity::Low;
139 }
140
141 let max_duration = all_durations.iter().max().copied().unwrap_or(0);
142 let avg_duration = all_durations.iter().sum::<u128>() / all_durations.len() as u128;
143
144 if duration_ms >= max_duration {
145 BottleneckSeverity::Critical
146 } else if duration_ms > avg_duration * 3 {
147 BottleneckSeverity::High
148 } else if duration_ms > avg_duration * 2 {
149 BottleneckSeverity::Medium
150 } else {
151 BottleneckSeverity::Low
152 }
153 }
154
155 fn generate_bottleneck_recommendation(checkpoint_name: &str, duration: Duration) -> String {
156 let duration_ms = duration.as_millis();
157
158 match checkpoint_name {
159 name if name.contains("forward") => {
160 if duration_ms > 1000 {
161 "Consider model pruning or quantization to reduce forward pass time".to_string()
162 } else {
163 "Monitor forward pass efficiency".to_string()
164 }
165 },
166 name if name.contains("backward") => {
167 if duration_ms > 2000 {
168 "Consider gradient accumulation or mixed precision training".to_string()
169 } else {
170 "Monitor backward pass efficiency".to_string()
171 }
172 },
173 name if name.contains("data") => {
174 "Consider data loading optimization or caching".to_string()
175 },
176 name if name.contains("io") => {
177 "Consider I/O optimization or async processing".to_string()
178 },
179 _ => {
180 format!(
181 "Optimize '{}' operation - duration: {}ms",
182 checkpoint_name, duration_ms
183 )
184 },
185 }
186 }
187}
188
189impl Default for PerformanceMonitor {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195#[derive(Debug, Serialize, Deserialize)]
197pub struct PerformanceMetrics {
198 pub total_duration: Duration,
199 pub checkpoint_count: usize,
200 pub avg_checkpoint_duration: f64,
201 pub slowest_checkpoint: Option<(String, Duration)>,
202 pub fastest_checkpoint: Option<(String, Duration)>,
203 pub durations: HashMap<String, Duration>,
204}
205
206#[derive(Debug, Serialize, Deserialize)]
208pub struct BottleneckAnalysis {
209 pub threshold_ms: u128,
210 pub bottlenecks: Vec<PerformanceBottleneck>,
211 pub total_bottleneck_time: Duration,
212 pub bottleneck_percentage: f64,
213}
214
215#[derive(Debug, Serialize, Deserialize)]
217pub struct PerformanceBottleneck {
218 pub checkpoint_name: String,
219 pub duration: Duration,
220 pub severity: BottleneckSeverity,
221 pub recommendation: String,
222}
223
224#[derive(Debug, Serialize, Deserialize)]
226pub enum BottleneckSeverity {
227 Low,
228 Medium,
229 High,
230 Critical,
231}
232
233#[derive(Debug)]
235pub struct SystemMemoryProfiler {
236 baseline_memory: usize,
237 peak_memory: usize,
238 checkpoints: HashMap<String, usize>,
239}
240
241impl Default for SystemMemoryProfiler {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
247impl SystemMemoryProfiler {
248 pub fn new() -> Self {
249 Self {
250 baseline_memory: Self::get_current_memory_usage(),
251 peak_memory: 0,
252 checkpoints: HashMap::new(),
253 }
254 }
255
256 pub fn checkpoint(&mut self, name: &str) {
257 let current_memory = Self::get_current_memory_usage();
258 self.checkpoints.insert(name.to_string(), current_memory);
259
260 if current_memory > self.peak_memory {
261 self.peak_memory = current_memory;
262 }
263 }
264
265 pub fn memory_report(&self) -> MemoryReport {
266 let current_memory = Self::get_current_memory_usage();
267 let memory_growth = current_memory.saturating_sub(self.baseline_memory);
268
269 let mut memory_deltas = HashMap::new();
270 let mut prev_memory = self.baseline_memory;
271
272 for (name, memory) in &self.checkpoints {
273 let delta = memory.saturating_sub(prev_memory) as i64;
274 memory_deltas.insert(name.clone(), delta);
275 prev_memory = *memory;
276 }
277
278 MemoryReport {
279 baseline_memory: self.baseline_memory,
280 current_memory,
281 peak_memory: self.peak_memory,
282 memory_growth,
283 checkpoints: self.checkpoints.clone(),
284 memory_deltas,
285 }
286 }
287
288 fn get_current_memory_usage() -> usize {
289 0
292 }
293}
294
295#[derive(Debug, Serialize, Deserialize)]
297pub struct MemoryReport {
298 pub baseline_memory: usize,
299 pub current_memory: usize,
300 pub peak_memory: usize,
301 pub memory_growth: usize,
302 pub checkpoints: HashMap<String, usize>,
303 pub memory_deltas: HashMap<String, i64>,
304}
305
306#[derive(Debug)]
308pub struct SystemProfiler {
309 performance_monitor: PerformanceMonitor,
310 memory_profiler: SystemMemoryProfiler,
311}
312
313impl Default for SystemProfiler {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
319impl SystemProfiler {
320 pub fn new() -> Self {
321 Self {
322 performance_monitor: PerformanceMonitor::new(),
323 memory_profiler: SystemMemoryProfiler::new(),
324 }
325 }
326
327 pub fn checkpoint(&mut self, name: &str) {
328 self.performance_monitor.checkpoint(name);
329 self.memory_profiler.checkpoint(name);
330 }
331
332 pub fn end_checkpoint(&mut self, name: &str) -> Option<Duration> {
333 self.performance_monitor.end_checkpoint(name)
334 }
335
336 pub fn generate_system_report(&self) -> SystemReport {
337 let performance_metrics = self.performance_monitor.get_detailed_metrics();
338 let memory_report = self.memory_profiler.memory_report();
339 let bottleneck_analysis = self.performance_monitor.analyze_bottlenecks(0.8);
340
341 SystemReport {
342 performance_metrics,
343 memory_report,
344 bottleneck_analysis,
345 timestamp: chrono::Utc::now(),
346 }
347 }
348}
349
350#[derive(Debug, Serialize, Deserialize)]
352pub struct SystemReport {
353 pub performance_metrics: PerformanceMetrics,
354 pub memory_report: MemoryReport,
355 pub bottleneck_analysis: BottleneckAnalysis,
356 pub timestamp: chrono::DateTime<chrono::Utc>,
357}