Skip to main content

trustformers_core/monitoring/
profiler.rs

1// Performance profiling utilities
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
/// Model profiler for tracking performance metrics
#[derive(Debug, Clone)]
pub struct ModelProfiler {
    // Behaviour switches: what to track and how many samples to keep.
    config: ProfilerConfig,
    // In-flight profiling sessions keyed by caller-supplied session id.
    active_sessions: HashMap<String, ProfilingSession>,
}
13
/// Configuration switches controlling what the profiler records.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilerConfig {
    /// Master switch; when `false` every profiling call is a no-op.
    pub enabled: bool,
    /// Record per-layer execution timings.
    pub track_layer_times: bool,
    /// Record samples taken via `sample_memory`.
    pub track_memory_usage: bool,
    /// Record samples taken via `sample_compute` (expensive; off by default).
    pub track_compute_utilization: bool,
    /// Intended interval between samples, in milliseconds.
    pub sample_interval_ms: u64,
    /// Upper bound on stored memory/compute samples per session.
    pub max_samples: usize,
}
23
24impl Default for ProfilerConfig {
25    fn default() -> Self {
26        Self {
27            enabled: true,
28            track_layer_times: true,
29            track_memory_usage: true,
30            track_compute_utilization: false, // Expensive
31            sample_interval_ms: 10,
32            max_samples: 10000,
33        }
34    }
35}
36
/// Mutable state accumulated while a profiling session is active.
#[derive(Debug, Clone)]
struct ProfilingSession {
    // Caller-supplied session identifier.
    id: String,
    // Wall-clock start; all sample timestamps are relative to this.
    start_time: Instant,
    // Recorded durations per layer name.
    layer_timings: HashMap<String, Vec<Duration>>,
    // Recorded durations per operation name.
    operation_timings: HashMap<String, Vec<Duration>>,
    // Memory samples, capped at `ProfilerConfig::max_samples`.
    memory_samples: Vec<MemorySample>,
    // Compute samples, capped at `ProfilerConfig::max_samples`.
    compute_samples: Vec<ComputeSample>,
}
46
/// Point-in-time memory usage snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySample {
    /// Time since session start when the sample was taken.
    pub timestamp: Duration,
    /// Host (CPU) memory in use, in megabytes.
    pub cpu_usage_mb: f64,
    /// Device (GPU) memory in use, in megabytes.
    pub gpu_usage_mb: f64,
    /// High-water mark at sampling time, in megabytes.
    pub peak_usage_mb: f64,
}
54
/// Point-in-time compute utilization snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeSample {
    /// Time since session start when the sample was taken.
    pub timestamp: Duration,
    /// CPU utilization as a fraction (0.0–1.0).
    pub cpu_utilization: f64,
    /// GPU utilization as a fraction (0.0–1.0).
    pub gpu_utilization: f64,
    /// Memory bandwidth in use (units defined by the sampler).
    pub memory_bandwidth: f64,
    /// Observed floating-point throughput (units defined by the sampler).
    pub flops: f64,
}
63
/// Complete profiling report for a session
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingReport {
    /// Id of the session the report was generated from.
    pub session_id: String,
    /// Wall-clock time from `start_profiling` to `end_profiling`.
    pub total_duration: Duration,
    /// Per-layer timing aggregates.
    pub layer_performance: LayerPerformanceReport,
    /// Per-operation timing aggregates.
    pub operation_performance: OperationPerformanceReport,
    /// Memory usage aggregates.
    pub memory_profile: MemoryProfile,
    /// Compute utilization aggregates.
    pub compute_profile: ComputeProfile,
    /// Bottleneck detection results.
    pub bottleneck_analysis: BottleneckAnalysis,
    /// Generated optimization recommendations.
    pub optimization_suggestions: Vec<OptimizationSuggestion>,
}
76
/// Aggregated per-layer timing statistics for a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerPerformanceReport {
    /// Summary statistics keyed by layer name.
    pub layer_timings: HashMap<String, LayerTiming>,
    /// Sum of all recorded layer time.
    pub total_layer_time: Duration,
    /// Up to ten layers ranked by descending total time.
    pub slowest_layers: Vec<(String, Duration)>,
    /// Per-layer stability score (inverse coefficient of variation).
    pub layer_efficiency: HashMap<String, f64>,
}
84
/// Timing statistics for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerTiming {
    /// Layer name (duplicates the map key for self-contained serialization).
    pub layer_name: String,
    /// Mean duration per call.
    pub average_time: Duration,
    /// Fastest recorded call.
    pub min_time: Duration,
    /// Slowest recorded call.
    pub max_time: Duration,
    /// Population standard deviation of call durations.
    pub std_deviation: Duration,
    /// Number of recorded calls.
    pub call_count: usize,
    /// Sum of all call durations.
    pub total_time: Duration,
    /// Share of the whole session spent in this layer, 0–100.
    pub percentage_of_total: f64,
}
96
/// Aggregated per-operation timing statistics for a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationPerformanceReport {
    /// Summary statistics keyed by operation name.
    pub operation_timings: HashMap<String, OperationTiming>,
    /// Sum of all recorded operation time.
    pub total_operation_time: Duration,
    /// Up to ten operations ranked by descending total time.
    pub slowest_operations: Vec<(String, Duration)>,
    /// Per-operation stability score (inverse coefficient of variation).
    pub operation_efficiency: HashMap<String, f64>,
}
104
/// Timing statistics for a single operation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationTiming {
    /// Operation name (duplicates the map key for self-contained serialization).
    pub operation_name: String,
    /// Mean duration per call.
    pub average_time: Duration,
    /// Fastest recorded call.
    pub min_time: Duration,
    /// Slowest recorded call.
    pub max_time: Duration,
    /// Population standard deviation of call durations.
    pub std_deviation: Duration,
    /// Number of recorded calls.
    pub call_count: usize,
    /// Sum of all call durations.
    pub total_time: Duration,
    /// Share of the whole session spent in this operation, 0–100.
    pub percentage_of_total: f64,
}
116
/// Aggregated memory statistics for a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryProfile {
    /// Highest observed usage, in megabytes.
    pub peak_memory_usage: f64,
    /// Mean combined (CPU + GPU) usage, in megabytes.
    pub average_memory_usage: f64,
    /// Average usage relative to peak usage.
    pub memory_efficiency: f64,
    /// Fragmentation estimate (currently a placeholder constant).
    pub memory_fragmentation: f64,
    /// Raw samples in the order they were taken.
    pub memory_timeline: Vec<MemorySample>,
    /// Observed allocation patterns (currently placeholder data).
    pub allocation_patterns: Vec<AllocationPattern>,
}
126
/// Summary of a recurring allocation pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocationPattern {
    /// Kind of allocation, e.g. "Tensor" or "Weight".
    pub pattern_type: String,
    /// How many allocations matched the pattern.
    pub frequency: usize,
    /// Mean allocation size, in megabytes.
    pub average_size_mb: f64,
    /// Total amount allocated by the pattern, in megabytes.
    pub total_size_mb: f64,
}
134
/// Aggregated compute-utilization statistics for a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeProfile {
    /// Mean CPU utilization (0.0–1.0).
    pub average_cpu_utilization: f64,
    /// Mean GPU utilization (0.0–1.0).
    pub average_gpu_utilization: f64,
    /// Highest observed FLOPS value.
    pub peak_flops: f64,
    /// Mean FLOPS value.
    pub average_flops: f64,
    /// Average FLOPS relative to peak FLOPS.
    pub compute_efficiency: f64,
    /// Raw samples in the order they were taken.
    pub utilization_timeline: Vec<ComputeSample>,
}
144
/// Result of bottleneck detection over a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckAnalysis {
    /// Dominant bottleneck category (`None` when nothing stands out).
    pub primary_bottleneck: BottleneckType,
    /// Severity as a fraction in [0.0, 1.0].
    pub bottleneck_severity: f64,
    /// Names of layers/operations implicated in the bottleneck.
    pub affected_operations: Vec<String>,
    /// Chronological bottleneck events (currently left empty by the analyzer).
    pub bottleneck_timeline: Vec<BottleneckEvent>,
}
152
/// Categories of performance bottleneck the analyzer can report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BottleneckType {
    /// Limited by memory capacity or bandwidth.
    Memory,
    /// Limited by raw compute throughput.
    Compute,
    /// Limited by disk or other I/O.
    IO,
    /// Limited by network transfer.
    Network,
    /// Limited by synchronization or waiting.
    Synchronization,
    /// No dominant bottleneck detected.
    None,
}
162
/// A single bottleneck occurrence on the session timeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BottleneckEvent {
    /// Time since session start when the event began.
    pub timestamp: Duration,
    /// Category of the bottleneck.
    pub bottleneck_type: BottleneckType,
    /// Severity as a fraction in [0.0, 1.0].
    pub severity: f64,
    /// How long the bottleneck lasted.
    pub duration: Duration,
    /// Operation affected by the event.
    pub affected_operation: String,
}
171
/// An actionable recommendation derived from profiling results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSuggestion {
    /// Category of the proposed change.
    pub suggestion_type: OptimizationType,
    /// How urgent the suggestion is.
    pub priority: OptimizationPriority,
    /// Human-readable explanation of the suggestion.
    pub description: String,
    /// Estimated relative improvement as a fraction (e.g. 0.3 = 30%).
    pub expected_improvement: f64,
    /// Estimated effort required to implement the suggestion.
    pub implementation_complexity: ComplexityLevel,
}
180
/// Categories of optimization the profiler can suggest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationType {
    /// Reduce memory footprint (e.g. checkpointing, smaller batches).
    MemoryOptimization,
    /// Improve compute throughput (e.g. kernel fusion, parallelism).
    ComputeOptimization,
    /// Change the model architecture itself.
    ArchitecturalChange,
    /// Replace an algorithm with a more efficient one.
    AlgorithmicImprovement,
    /// Make better use of the available hardware.
    HardwareUtilization,
    /// Rearrange data for better locality/access patterns.
    DataLayout,
}
190
/// Urgency ranking for an [`OptimizationSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationPriority {
    Critical,
    High,
    Medium,
    Low,
}
198
/// Estimated implementation effort for an [`OptimizationSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComplexityLevel {
    Low,
    Medium,
    High,
    VeryHigh,
}
206
207impl Default for ProfilingReport {
208    fn default() -> Self {
209        Self {
210            session_id: String::new(),
211            total_duration: Duration::from_secs(0),
212            layer_performance: LayerPerformanceReport::default(),
213            operation_performance: OperationPerformanceReport::default(),
214            memory_profile: MemoryProfile::default(),
215            compute_profile: ComputeProfile::default(),
216            bottleneck_analysis: BottleneckAnalysis::default(),
217            optimization_suggestions: Vec::new(),
218        }
219    }
220}
221
222impl Default for LayerPerformanceReport {
223    fn default() -> Self {
224        Self {
225            layer_timings: HashMap::new(),
226            total_layer_time: Duration::from_secs(0),
227            slowest_layers: Vec::new(),
228            layer_efficiency: HashMap::new(),
229        }
230    }
231}
232
233impl Default for OperationPerformanceReport {
234    fn default() -> Self {
235        Self {
236            operation_timings: HashMap::new(),
237            total_operation_time: Duration::from_secs(0),
238            slowest_operations: Vec::new(),
239            operation_efficiency: HashMap::new(),
240        }
241    }
242}
243
244impl Default for MemoryProfile {
245    fn default() -> Self {
246        Self {
247            peak_memory_usage: 0.0,
248            average_memory_usage: 0.0,
249            memory_efficiency: 0.0,
250            memory_fragmentation: 0.0,
251            memory_timeline: Vec::new(),
252            allocation_patterns: Vec::new(),
253        }
254    }
255}
256
257impl Default for ComputeProfile {
258    fn default() -> Self {
259        Self {
260            average_cpu_utilization: 0.0,
261            average_gpu_utilization: 0.0,
262            peak_flops: 0.0,
263            average_flops: 0.0,
264            compute_efficiency: 0.0,
265            utilization_timeline: Vec::new(),
266        }
267    }
268}
269
270impl Default for BottleneckAnalysis {
271    fn default() -> Self {
272        Self {
273            primary_bottleneck: BottleneckType::None,
274            bottleneck_severity: 0.0,
275            affected_operations: Vec::new(),
276            bottleneck_timeline: Vec::new(),
277        }
278    }
279}
280
impl Default for ModelProfiler {
    /// Equivalent to [`ModelProfiler::new`]: default config, no sessions.
    fn default() -> Self {
        Self::new()
    }
}
286
287impl ModelProfiler {
288    pub fn new() -> Self {
289        Self {
290            config: ProfilerConfig::default(),
291            active_sessions: HashMap::new(),
292        }
293    }
294
    /// Create a profiler that uses `config` instead of the defaults, with
    /// no active sessions.
    pub fn with_config(config: ProfilerConfig) -> Self {
        Self {
            config,
            active_sessions: HashMap::new(),
        }
    }
301
302    /// Start profiling a session
303    pub fn start_profiling(&mut self, session_id: &str) -> Result<()> {
304        if !self.config.enabled {
305            return Ok(());
306        }
307
308        let session = ProfilingSession {
309            id: session_id.to_string(),
310            start_time: Instant::now(),
311            layer_timings: HashMap::new(),
312            operation_timings: HashMap::new(),
313            memory_samples: Vec::new(),
314            compute_samples: Vec::new(),
315        };
316
317        self.active_sessions.insert(session_id.to_string(), session);
318        Ok(())
319    }
320
321    /// Profile a layer execution
322    pub fn profile_layer<T, F>(
323        &mut self,
324        session_id: &str,
325        layer_name: &str,
326        operation: F,
327    ) -> Result<T>
328    where
329        F: FnOnce() -> Result<T>,
330    {
331        if !self.config.enabled {
332            return operation();
333        }
334
335        let start_time = Instant::now();
336        let result = operation()?;
337        let duration = start_time.elapsed();
338
339        if let Some(session) = self.active_sessions.get_mut(session_id) {
340            session.layer_timings.entry(layer_name.to_string()).or_default().push(duration);
341        }
342
343        Ok(result)
344    }
345
346    /// Profile an operation execution
347    pub fn profile_operation<T, F>(
348        &mut self,
349        session_id: &str,
350        operation_name: &str,
351        operation: F,
352    ) -> Result<T>
353    where
354        F: FnOnce() -> Result<T>,
355    {
356        if !self.config.enabled {
357            return operation();
358        }
359
360        let start_time = Instant::now();
361        let result = operation()?;
362        let duration = start_time.elapsed();
363
364        if let Some(session) = self.active_sessions.get_mut(session_id) {
365            session
366                .operation_timings
367                .entry(operation_name.to_string())
368                .or_default()
369                .push(duration);
370        }
371
372        Ok(result)
373    }
374
375    /// Take a memory sample
376    pub fn sample_memory(&mut self, session_id: &str) -> Result<()> {
377        if !self.config.enabled || !self.config.track_memory_usage {
378            return Ok(());
379        }
380
381        let timestamp = if let Some(session) = self.active_sessions.get(session_id) {
382            session.start_time.elapsed()
383        } else {
384            return Ok(());
385        };
386
387        let sample = self.get_memory_sample(timestamp)?;
388
389        if let Some(session) = self.active_sessions.get_mut(session_id) {
390            if session.memory_samples.len() < self.config.max_samples {
391                session.memory_samples.push(sample);
392            }
393        }
394
395        Ok(())
396    }
397
398    /// Take a compute sample
399    pub fn sample_compute(&mut self, session_id: &str) -> Result<()> {
400        if !self.config.enabled || !self.config.track_compute_utilization {
401            return Ok(());
402        }
403
404        let timestamp = if let Some(session) = self.active_sessions.get(session_id) {
405            session.start_time.elapsed()
406        } else {
407            return Ok(());
408        };
409
410        let sample = self.get_compute_sample(timestamp)?;
411
412        if let Some(session) = self.active_sessions.get_mut(session_id) {
413            if session.compute_samples.len() < self.config.max_samples {
414                session.compute_samples.push(sample);
415            }
416        }
417
418        Ok(())
419    }
420
    /// End profiling and generate report
    ///
    /// Removes the session from the active set and aggregates everything it
    /// recorded into a [`ProfilingReport`].
    ///
    /// # Errors
    /// Returns an error if there is no active session for `session_id`.
    pub fn end_profiling(&mut self, session_id: &str) -> Result<ProfilingReport> {
        let session = self
            .active_sessions
            .remove(session_id)
            .ok_or_else(|| anyhow::anyhow!("Session not found: {}", session_id))?;

        // Wall-clock time from session start until now.
        let total_duration = session.start_time.elapsed();

        let layer_performance = self.analyze_layer_performance(&session, total_duration)?;
        let operation_performance = self.analyze_operation_performance(&session, total_duration)?;
        let memory_profile = self.analyze_memory_profile(&session)?;
        let compute_profile = self.analyze_compute_profile(&session)?;
        // Bottleneck detection and suggestions consume the aggregates above,
        // so they must run after them.
        let bottleneck_analysis =
            self.analyze_bottlenecks(&session, &layer_performance, &operation_performance)?;
        let optimization_suggestions = self.generate_optimization_suggestions(
            &bottleneck_analysis,
            &memory_profile,
            &compute_profile,
        )?;

        Ok(ProfilingReport {
            session_id: session.id,
            total_duration,
            layer_performance,
            operation_performance,
            memory_profile,
            compute_profile,
            bottleneck_analysis,
            optimization_suggestions,
        })
    }
453
    /// Clear all profiling data
    ///
    /// Drops every active session. Always succeeds; the `Result` return
    /// keeps the signature consistent with the other profiler methods.
    pub fn clear(&mut self) -> Result<()> {
        self.active_sessions.clear();
        Ok(())
    }
459
460    /// Get current memory sample
461    fn get_memory_sample(&self, timestamp: Duration) -> Result<MemorySample> {
462        // Simplified implementation - would use actual system monitoring
463        Ok(MemorySample {
464            timestamp,
465            cpu_usage_mb: 1024.0 + (timestamp.as_millis() as f64 * 0.1) % 512.0,
466            gpu_usage_mb: 2048.0 + (timestamp.as_millis() as f64 * 0.05) % 1024.0,
467            peak_usage_mb: 3072.0,
468        })
469    }
470
471    /// Get current compute sample
472    fn get_compute_sample(&self, timestamp: Duration) -> Result<ComputeSample> {
473        // Simplified implementation - would use actual system monitoring
474        let phase = (timestamp.as_millis() as f64 * 0.01) % (2.0 * std::f64::consts::PI);
475
476        Ok(ComputeSample {
477            timestamp,
478            cpu_utilization: 0.6 + 0.3 * phase.sin(),
479            gpu_utilization: 0.8 + 0.2 * phase.cos(),
480            memory_bandwidth: 200.0 + 50.0 * phase.sin(),
481            flops: 1000.0 + 200.0 * phase.cos(),
482        })
483    }
484
    /// Analyze layer performance
    ///
    /// Aggregates the raw per-layer duration lists into summary statistics
    /// (mean/min/max/std-dev, call count, share of total session time),
    /// ranks the ten slowest layers by cumulative time, and computes a
    /// per-layer stability score.
    fn analyze_layer_performance(
        &self,
        session: &ProfilingSession,
        total_duration: Duration,
    ) -> Result<LayerPerformanceReport> {
        let mut layer_timings = HashMap::new();
        let mut total_layer_time = Duration::from_secs(0);
        let mut slowest_layers = Vec::new();
        let mut layer_efficiency = HashMap::new();

        for (layer_name, timings) in &session.layer_timings {
            // `timings` is non-empty by construction (map entries are only
            // created when a duration is pushed), so the division is safe.
            let total_time: Duration = timings.iter().sum();
            let average_time = total_time / timings.len() as u32;
            let min_time = *timings.iter().min().unwrap_or(&Duration::from_secs(0));
            let max_time = *timings.iter().max().unwrap_or(&Duration::from_secs(0));

            // Calculate standard deviation (population variance over all
            // recorded calls, in nanoseconds).
            let mean_nanos = average_time.as_nanos() as f64;
            let variance = timings
                .iter()
                .map(|t| {
                    let diff = t.as_nanos() as f64 - mean_nanos;
                    diff * diff
                })
                .sum::<f64>()
                / timings.len() as f64;
            let std_dev_nanos = variance.sqrt() as u64;
            let std_deviation = Duration::from_nanos(std_dev_nanos);

            // Share of the whole session spent in this layer, as a percent.
            let percentage_of_total = if total_duration.as_nanos() > 0 {
                (total_time.as_nanos() as f64 / total_duration.as_nanos() as f64) * 100.0
            } else {
                0.0
            };

            layer_timings.insert(
                layer_name.clone(),
                LayerTiming {
                    layer_name: layer_name.clone(),
                    average_time,
                    min_time,
                    max_time,
                    std_deviation,
                    call_count: timings.len(),
                    total_time,
                    percentage_of_total,
                },
            );

            total_layer_time += total_time;
            slowest_layers.push((layer_name.clone(), total_time));

            // Calculate efficiency (inverse of coefficient of variation);
            // higher means more consistent timings. 1.0 when there is no
            // measurable spread. Note this value is unbounded above.
            let efficiency = if std_dev_nanos > 0 && mean_nanos > 0.0 {
                1.0 / (std_dev_nanos as f64 / mean_nanos)
            } else {
                1.0
            };
            layer_efficiency.insert(layer_name.clone(), efficiency);
        }

        // Sort slowest layers by total time, descending.
        slowest_layers.sort_by_key(|item| std::cmp::Reverse(item.1));
        slowest_layers.truncate(10); // Keep top 10

        Ok(LayerPerformanceReport {
            layer_timings,
            total_layer_time,
            slowest_layers,
            layer_efficiency,
        })
    }
558
559    /// Analyze operation performance
560    fn analyze_operation_performance(
561        &self,
562        session: &ProfilingSession,
563        total_duration: Duration,
564    ) -> Result<OperationPerformanceReport> {
565        let mut operation_timings = HashMap::new();
566        let mut total_operation_time = Duration::from_secs(0);
567        let mut slowest_operations = Vec::new();
568        let mut operation_efficiency = HashMap::new();
569
570        for (operation_name, timings) in &session.operation_timings {
571            let total_time: Duration = timings.iter().sum();
572            let average_time = total_time / timings.len() as u32;
573            let min_time = *timings.iter().min().unwrap_or(&Duration::from_secs(0));
574            let max_time = *timings.iter().max().unwrap_or(&Duration::from_secs(0));
575
576            // Calculate standard deviation
577            let mean_nanos = average_time.as_nanos() as f64;
578            let variance = timings
579                .iter()
580                .map(|t| {
581                    let diff = t.as_nanos() as f64 - mean_nanos;
582                    diff * diff
583                })
584                .sum::<f64>()
585                / timings.len() as f64;
586            let std_dev_nanos = variance.sqrt() as u64;
587            let std_deviation = Duration::from_nanos(std_dev_nanos);
588
589            let percentage_of_total = if total_duration.as_nanos() > 0 {
590                (total_time.as_nanos() as f64 / total_duration.as_nanos() as f64) * 100.0
591            } else {
592                0.0
593            };
594
595            operation_timings.insert(
596                operation_name.clone(),
597                OperationTiming {
598                    operation_name: operation_name.clone(),
599                    average_time,
600                    min_time,
601                    max_time,
602                    std_deviation,
603                    call_count: timings.len(),
604                    total_time,
605                    percentage_of_total,
606                },
607            );
608
609            total_operation_time += total_time;
610            slowest_operations.push((operation_name.clone(), total_time));
611
612            // Calculate efficiency
613            let efficiency = if std_dev_nanos > 0 && mean_nanos > 0.0 {
614                1.0 / (std_dev_nanos as f64 / mean_nanos)
615            } else {
616                1.0
617            };
618            operation_efficiency.insert(operation_name.clone(), efficiency);
619        }
620
621        // Sort slowest operations
622        slowest_operations.sort_by_key(|item| std::cmp::Reverse(item.1));
623        slowest_operations.truncate(10);
624
625        Ok(OperationPerformanceReport {
626            operation_timings,
627            total_operation_time,
628            slowest_operations,
629            operation_efficiency,
630        })
631    }
632
633    /// Analyze memory profile
634    fn analyze_memory_profile(&self, session: &ProfilingSession) -> Result<MemoryProfile> {
635        if session.memory_samples.is_empty() {
636            return Ok(MemoryProfile::default());
637        }
638
639        let peak_memory_usage = session
640            .memory_samples
641            .iter()
642            .map(|s| s.cpu_usage_mb.max(s.gpu_usage_mb))
643            .fold(0.0, f64::max);
644
645        let average_memory_usage = session
646            .memory_samples
647            .iter()
648            .map(|s| s.cpu_usage_mb + s.gpu_usage_mb)
649            .sum::<f64>()
650            / session.memory_samples.len() as f64;
651
652        let memory_efficiency = if peak_memory_usage > 0.0 {
653            average_memory_usage / peak_memory_usage
654        } else {
655            0.0
656        };
657
658        let memory_fragmentation = 0.1; // Simplified calculation
659
660        let allocation_patterns = vec![
661            AllocationPattern {
662                pattern_type: "Tensor".to_string(),
663                frequency: 100,
664                average_size_mb: 10.0,
665                total_size_mb: 1000.0,
666            },
667            AllocationPattern {
668                pattern_type: "Weight".to_string(),
669                frequency: 50,
670                average_size_mb: 20.0,
671                total_size_mb: 1000.0,
672            },
673        ];
674
675        Ok(MemoryProfile {
676            peak_memory_usage,
677            average_memory_usage,
678            memory_efficiency,
679            memory_fragmentation,
680            memory_timeline: session.memory_samples.clone(),
681            allocation_patterns,
682        })
683    }
684
685    /// Analyze compute profile
686    fn analyze_compute_profile(&self, session: &ProfilingSession) -> Result<ComputeProfile> {
687        if session.compute_samples.is_empty() {
688            return Ok(ComputeProfile::default());
689        }
690
691        let average_cpu_utilization =
692            session.compute_samples.iter().map(|s| s.cpu_utilization).sum::<f64>()
693                / session.compute_samples.len() as f64;
694
695        let average_gpu_utilization =
696            session.compute_samples.iter().map(|s| s.gpu_utilization).sum::<f64>()
697                / session.compute_samples.len() as f64;
698
699        let peak_flops = session.compute_samples.iter().map(|s| s.flops).fold(0.0, f64::max);
700
701        let average_flops = session.compute_samples.iter().map(|s| s.flops).sum::<f64>()
702            / session.compute_samples.len() as f64;
703
704        let compute_efficiency = if peak_flops > 0.0 { average_flops / peak_flops } else { 0.0 };
705
706        Ok(ComputeProfile {
707            average_cpu_utilization,
708            average_gpu_utilization,
709            peak_flops,
710            average_flops,
711            compute_efficiency,
712            utilization_timeline: session.compute_samples.clone(),
713        })
714    }
715
716    /// Analyze bottlenecks
717    fn analyze_bottlenecks(
718        &self,
719        _session: &ProfilingSession,
720        layer_performance: &LayerPerformanceReport,
721        _operation_performance: &OperationPerformanceReport,
722    ) -> Result<BottleneckAnalysis> {
723        let mut primary_bottleneck = BottleneckType::None;
724        let mut bottleneck_severity = 0.0;
725        let mut affected_operations = Vec::new();
726
727        // Find the slowest layers as potential bottlenecks
728        if let Some((slowest_layer, duration)) = layer_performance.slowest_layers.first() {
729            let total_time = layer_performance.total_layer_time;
730            if total_time.as_nanos() > 0 {
731                let percentage =
732                    (duration.as_nanos() as f64 / total_time.as_nanos() as f64) * 100.0;
733                if percentage > 30.0 {
734                    primary_bottleneck = BottleneckType::Compute;
735                    bottleneck_severity = percentage / 100.0;
736                    affected_operations.push(slowest_layer.clone());
737                }
738            }
739        }
740
741        Ok(BottleneckAnalysis {
742            primary_bottleneck,
743            bottleneck_severity,
744            affected_operations,
745            bottleneck_timeline: Vec::new(),
746        })
747    }
748
    /// Generate optimization suggestions
    ///
    /// Emits rule-based recommendations from the aggregated profiles:
    /// low memory efficiency (< 0.7) and low compute efficiency (< 0.6)
    /// each add a High-priority suggestion, and a detected Memory or
    /// Compute bottleneck adds a Critical one. May return an empty list.
    fn generate_optimization_suggestions(
        &self,
        bottleneck_analysis: &BottleneckAnalysis,
        memory_profile: &MemoryProfile,
        compute_profile: &ComputeProfile,
    ) -> Result<Vec<OptimizationSuggestion>> {
        let mut suggestions = Vec::new();

        // Memory optimization suggestions
        if memory_profile.memory_efficiency < 0.7 {
            suggestions.push(OptimizationSuggestion {
                suggestion_type: OptimizationType::MemoryOptimization,
                priority: OptimizationPriority::High,
                description: "Consider implementing gradient checkpointing to reduce memory usage"
                    .to_string(),
                expected_improvement: 0.3,
                implementation_complexity: ComplexityLevel::Medium,
            });
        }

        // Compute optimization suggestions
        if compute_profile.compute_efficiency < 0.6 {
            suggestions.push(OptimizationSuggestion {
                suggestion_type: OptimizationType::ComputeOptimization,
                priority: OptimizationPriority::High,
                description:
                    "Improve compute efficiency with kernel fusion and better parallelization"
                        .to_string(),
                expected_improvement: 0.4,
                implementation_complexity: ComplexityLevel::High,
            });
        }

        // Bottleneck-specific suggestions (other bottleneck kinds currently
        // produce no suggestion).
        match bottleneck_analysis.primary_bottleneck {
            BottleneckType::Memory => {
                suggestions.push(OptimizationSuggestion {
                    suggestion_type: OptimizationType::MemoryOptimization,
                    priority: OptimizationPriority::Critical,
                    description: "Memory bottleneck detected. Consider reducing batch size or using gradient accumulation".to_string(),
                    expected_improvement: 0.5,
                    implementation_complexity: ComplexityLevel::Low,
                });
            },
            BottleneckType::Compute => {
                suggestions.push(OptimizationSuggestion {
                    suggestion_type: OptimizationType::ComputeOptimization,
                    priority: OptimizationPriority::Critical,
                    description: "Compute bottleneck detected. Consider using mixed precision training or model parallelism".to_string(),
                    expected_improvement: 0.4,
                    implementation_complexity: ComplexityLevel::Medium,
                });
            },
            _ => {},
        }

        Ok(suggestions)
    }
808}
809
810impl ProfilingReport {
811    /// Print a summary of the profiling report
812    pub fn print_summary(&self) {
813        println!("Profiling Report Summary");
814        println!("=======================");
815        println!("Total Duration: {:.2}ms", self.total_duration.as_millis());
816        println!("Layer Performance:");
817        println!(
818            "  Total Layer Time: {:.2}ms",
819            self.layer_performance.total_layer_time.as_millis()
820        );
821        println!(
822            "  Slowest Layers: {}",
823            self.layer_performance.slowest_layers.len()
824        );
825
826        if let Some((slowest_layer, duration)) = self.layer_performance.slowest_layers.first() {
827            println!(
828                "  Slowest Layer: {} ({:.2}ms)",
829                slowest_layer,
830                duration.as_millis()
831            );
832        }
833
834        println!("Memory Profile:");
835        println!(
836            "  Peak Usage: {:.1} MB",
837            self.memory_profile.peak_memory_usage
838        );
839        println!(
840            "  Average Usage: {:.1} MB",
841            self.memory_profile.average_memory_usage
842        );
843        println!(
844            "  Memory Efficiency: {:.1}%",
845            self.memory_profile.memory_efficiency * 100.0
846        );
847
848        println!("Compute Profile:");
849        println!(
850            "  Average CPU Utilization: {:.1}%",
851            self.compute_profile.average_cpu_utilization * 100.0
852        );
853        println!(
854            "  Average GPU Utilization: {:.1}%",
855            self.compute_profile.average_gpu_utilization * 100.0
856        );
857        println!(
858            "  Compute Efficiency: {:.1}%",
859            self.compute_profile.compute_efficiency * 100.0
860        );
861
862        println!("Bottleneck Analysis:");
863        println!(
864            "  Primary Bottleneck: {:?}",
865            self.bottleneck_analysis.primary_bottleneck
866        );
867        println!(
868            "  Severity: {:.1}%",
869            self.bottleneck_analysis.bottleneck_severity * 100.0
870        );
871
872        if !self.optimization_suggestions.is_empty() {
873            println!(
874                "Optimization Suggestions: {}",
875                self.optimization_suggestions.len()
876            );
877            for (i, suggestion) in self.optimization_suggestions.iter().take(3).enumerate() {
878                println!(
879                    "  {}. [{:?}] {}",
880                    i + 1,
881                    suggestion.priority,
882                    suggestion.description
883                );
884            }
885        }
886    }
887}
888
889#[cfg(test)]
890mod tests {
891    use super::*;
892
    /// Default construction enables the cheap tracking options.
    #[test]
    fn test_profiler_creation() {
        let profiler = ModelProfiler::new();
        assert!(profiler.config.enabled);
        assert!(profiler.config.track_layer_times);
    }
899
900    #[test]
901    fn test_profiler_with_config() {
902        let config = ProfilerConfig {
903            enabled: true,
904            track_layer_times: true,
905            track_memory_usage: false,
906            track_compute_utilization: true,
907            sample_interval_ms: 50,
908            max_samples: 5000,
909        };
910
911        let profiler = ModelProfiler::with_config(config.clone());
912        assert!(!profiler.config.track_memory_usage);
913        assert!(profiler.config.track_compute_utilization);
914        assert_eq!(profiler.config.max_samples, 5000);
915    }
916
    /// End-to-end smoke test: start a session, profile one layer and one
    /// operation, and check that the generated report reflects both.
    #[test]
    fn test_profiling_session() -> Result<()> {
        let mut profiler = ModelProfiler::new();
        let session_id = "test_session";

        profiler.start_profiling(session_id)?;

        // Profile a layer (the sleep stands in for real work).
        let _result = profiler.profile_layer(session_id, "attention", || {
            std::thread::sleep(Duration::from_millis(10));
            Ok(42)
        })?;

        // Profile an operation
        let _result = profiler.profile_operation(session_id, "matmul", || {
            std::thread::sleep(Duration::from_millis(5));
            Ok("done".to_string())
        })?;

        let report = profiler.end_profiling(session_id)?;

        assert_eq!(report.session_id, session_id);
        assert!(report.total_duration > Duration::from_millis(10));
        assert!(report.layer_performance.layer_timings.contains_key("attention"));
        assert!(report.operation_performance.operation_timings.contains_key("matmul"));

        Ok(())
    }
945
946    #[test]
947    fn test_memory_sampling() -> Result<()> {
948        let mut profiler = ModelProfiler::new();
949        let session_id = "test_session";
950
951        profiler.start_profiling(session_id)?;
952        profiler.sample_memory(session_id)?;
953        profiler.sample_memory(session_id)?;
954
955        let report = profiler.end_profiling(session_id)?;
956
957        assert!(report.memory_profile.memory_timeline.len() >= 2);
958        assert!(report.memory_profile.peak_memory_usage > 0.0);
959
960        Ok(())
961    }
962
963    #[test]
964    fn test_compute_sampling() -> Result<()> {
965        let mut profiler = ModelProfiler::with_config(ProfilerConfig {
966            track_compute_utilization: true,
967            ..Default::default()
968        });
969
970        let session_id = "test_session";
971
972        profiler.start_profiling(session_id)?;
973        profiler.sample_compute(session_id)?;
974        profiler.sample_compute(session_id)?;
975
976        let report = profiler.end_profiling(session_id)?;
977
978        assert!(report.compute_profile.utilization_timeline.len() >= 2);
979        assert!(report.compute_profile.average_cpu_utilization > 0.0);
980
981        Ok(())
982    }
983
984    #[test]
985    fn test_optimization_suggestions() {
986        let suggestion = OptimizationSuggestion {
987            suggestion_type: OptimizationType::MemoryOptimization,
988            priority: OptimizationPriority::High,
989            description: "Test suggestion".to_string(),
990            expected_improvement: 0.3,
991            implementation_complexity: ComplexityLevel::Medium,
992        };
993
994        assert_eq!(suggestion.expected_improvement, 0.3);
995        assert!(matches!(suggestion.priority, OptimizationPriority::High));
996        assert!(matches!(
997            suggestion.implementation_complexity,
998            ComplexityLevel::Medium
999        ));
1000    }
1001
1002    #[test]
1003    fn test_bottleneck_analysis() {
1004        let analysis = BottleneckAnalysis {
1005            primary_bottleneck: BottleneckType::Memory,
1006            bottleneck_severity: 0.8,
1007            affected_operations: vec!["attention".to_string()],
1008            bottleneck_timeline: Vec::new(),
1009        };
1010
1011        assert!(matches!(
1012            analysis.primary_bottleneck,
1013            BottleneckType::Memory
1014        ));
1015        assert_eq!(analysis.bottleneck_severity, 0.8);
1016        assert_eq!(analysis.affected_operations.len(), 1);
1017    }
1018
1019    #[test]
1020    fn test_layer_timing_calculation() -> Result<()> {
1021        let mut profiler = ModelProfiler::new();
1022        let session_id = "test_session";
1023
1024        profiler.start_profiling(session_id)?;
1025
1026        // Profile the same layer multiple times
1027        for _ in 0..5 {
1028            profiler.profile_layer(session_id, "test_layer", || {
1029                std::thread::sleep(Duration::from_millis(10));
1030                Ok(())
1031            })?;
1032        }
1033
1034        let report = profiler.end_profiling(session_id)?;
1035
1036        if let Some(timing) = report.layer_performance.layer_timings.get("test_layer") {
1037            assert_eq!(timing.call_count, 5);
1038            assert!(timing.average_time >= Duration::from_millis(8)); // Allow some variance
1039            assert!(timing.total_time >= Duration::from_millis(40));
1040        } else {
1041            panic!("Layer timing not found");
1042        }
1043
1044        Ok(())
1045    }
1046}