// torsh_core/profiling.rs

//! Performance Profiling Hooks for ToRSh Operations
//!
//! This module provides comprehensive performance profiling capabilities for tensor operations,
//! including operation timing, memory bandwidth tracking, cache analysis, and performance
//! bottleneck identification.
7use crate::error::{Result, TorshError};
8use std::collections::{HashMap, VecDeque};
9use std::fmt;
10use std::sync::{Arc, Mutex, OnceLock};
11use std::thread;
12use std::time::{Duration, Instant};
13
/// Global profiler instance, lazily initialized on first access via
/// `get_profiler()` or explicitly (once) via `init_profiler()`.
static PROFILER: OnceLock<Arc<Mutex<PerformanceProfiler>>> = OnceLock::new();
16
/// Performance profiling configuration
///
/// Controls what the profiler records and how much memory it may use.
/// Passed to `init_profiler` or `PerformanceProfiler::update_config`.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Whether profiling is enabled; when false, start/finish calls are no-ops
    pub enabled: bool,
    /// Maximum number of operation records to keep (oldest are evicted first)
    pub max_records: usize,
    /// Whether to capture stack traces for operations (adds per-operation overhead)
    pub capture_stack_traces: bool,
    /// Whether to track memory bandwidth (derived from byte counts / duration)
    pub track_memory_bandwidth: bool,
    /// Whether to track cache performance (hit rates supplied by callers)
    pub track_cache_performance: bool,
    /// Minimum operation duration to record, in nanoseconds
    /// (filters out very fast operations)
    pub min_duration_ns: u64,
    /// Whether to aggregate similar operations
    // NOTE(review): this flag is not consulted anywhere in this module —
    // stats aggregation always runs. Confirm intended behavior.
    pub aggregate_similar_ops: bool,
}
35
36impl Default for ProfilerConfig {
37    fn default() -> Self {
38        Self {
39            enabled: true,
40            max_records: 10_000,
41            capture_stack_traces: false,
42            track_memory_bandwidth: true,
43            track_cache_performance: true,
44            min_duration_ns: 1_000, // 1 microsecond
45            aggregate_similar_ops: true,
46        }
47    }
48}
49
/// Type of operation being profiled
///
/// Each variant carries the concrete operation name; the `Display` impl
/// renders values as `"Category::name"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Tensor creation operations
    Creation(String),
    /// Mathematical operations
    Math(String),
    /// Memory operations (copy, move, etc.)
    Memory(String),
    /// Shape operations (reshape, transpose, etc.)
    Shape(String),
    /// Reduction operations (sum, mean, etc.)
    Reduction(String),
    /// Neural network operations
    Neural(String),
    /// Backend operations
    Backend(String),
    /// Custom operation
    Custom(String),
}
70
71impl fmt::Display for OperationType {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        match self {
74            OperationType::Creation(name) => write!(f, "Creation::{name}"),
75            OperationType::Math(name) => write!(f, "Math::{name}"),
76            OperationType::Memory(name) => write!(f, "Memory::{name}"),
77            OperationType::Shape(name) => write!(f, "Shape::{name}"),
78            OperationType::Reduction(name) => write!(f, "Reduction::{name}"),
79            OperationType::Neural(name) => write!(f, "Neural::{name}"),
80            OperationType::Backend(name) => write!(f, "Backend::{name}"),
81            OperationType::Custom(name) => write!(f, "Custom::{name}"),
82        }
83    }
84}
85
/// Performance record for a single operation
#[derive(Debug, Clone)]
pub struct OperationRecord {
    /// Unique operation ID (monotonically increasing per profiler)
    pub id: u64,
    /// Type of operation
    pub operation_type: OperationType,
    /// Duration of the operation
    pub duration: Duration,
    /// Memory bandwidth in GB/s if tracked (computed by
    /// `OperationContext::calculate_memory_bandwidth`, which divides by 1e9)
    pub memory_bandwidth: Option<f64>,
    /// Cache hit rate (0.0-1.0) if tracked; reports display it as a percentage
    pub cache_hit_rate: Option<f64>,
    /// Input tensor sizes in bytes
    pub input_sizes: Vec<usize>,
    /// Output tensor size in bytes
    pub output_size: Option<usize>,
    /// Thread ID that executed the operation
    pub thread_id: thread::ThreadId,
    /// Timestamp when the operation started
    pub timestamp: Instant,
    /// Stack trace if captured (placeholder text in the current implementation)
    pub stack_trace: Option<String>,
    /// Custom metadata supplied via `OperationContext::with_metadata`
    pub metadata: HashMap<String, String>,
}
112
/// Aggregated performance statistics for operation types
///
/// Count/total/min/max/avg are lifetime counters; percentiles are computed
/// from the profiler's bounded window of retained records.
#[derive(Debug, Clone)]
pub struct OperationStats {
    /// Operation type
    pub operation_type: OperationType,
    /// Number of times this operation was executed
    pub count: u64,
    /// Total time spent in this operation
    pub total_duration: Duration,
    /// Minimum execution time
    pub min_duration: Duration,
    /// Maximum execution time
    pub max_duration: Duration,
    /// Average execution time (total / count)
    pub avg_duration: Duration,
    /// 50th percentile (median) execution time
    pub p50_duration: Duration,
    /// 95th percentile execution time
    pub p95_duration: Duration,
    /// 99th percentile execution time
    pub p99_duration: Duration,
    /// Average memory bandwidth in GB/s (when any record carried a sample)
    pub avg_memory_bandwidth: Option<f64>,
    /// Average cache hit rate, 0.0-1.0 (when any record carried a sample)
    pub avg_cache_hit_rate: Option<f64>,
    /// Total bytes processed (inputs + outputs)
    pub total_bytes: usize,
}
141
/// Performance bottleneck identification
///
/// Produced by `PerformanceProfiler::identify_bottlenecks` for operation
/// types consuming more than 5% of total profiled time.
#[derive(Debug, Clone)]
pub struct PerformanceBottleneck {
    /// Operation type causing the bottleneck
    pub operation_type: OperationType,
    /// Percentage of total elapsed time spent in this operation (0-100)
    pub time_percentage: f64,
    /// Number of times this operation was called
    pub call_count: u64,
    /// Average duration per call
    pub avg_duration: Duration,
    /// Suggested optimization (see `generate_optimization_suggestion`)
    pub optimization_suggestion: String,
}
156
/// Main performance profiler
///
/// Not internally synchronized; the global instance is shared behind
/// `Arc<Mutex<...>>` (see `get_profiler`).
pub struct PerformanceProfiler {
    /// Configuration
    config: ProfilerConfig,
    /// Bounded window of recent operation records (oldest evicted first)
    records: VecDeque<OperationRecord>,
    /// Aggregated statistics per operation type
    stats: HashMap<OperationType, OperationStats>,
    /// Next operation ID to hand out
    next_id: u64,
    /// Accumulated time spent inside the profiler's own bookkeeping, in nanoseconds
    overhead_ns: u64,
    /// Profiler start (or last reset) time
    start_time: Instant,
}
172
173impl PerformanceProfiler {
174    /// Create a new performance profiler
175    pub fn new(config: ProfilerConfig) -> Self {
176        Self {
177            config,
178            records: VecDeque::new(),
179            stats: HashMap::new(),
180            next_id: 1,
181            overhead_ns: 0,
182            start_time: Instant::now(),
183        }
184    }
185
186    /// Start profiling an operation
187    pub fn start_operation(&mut self, operation_type: OperationType) -> OperationHandle {
188        if !self.config.enabled {
189            return OperationHandle::disabled();
190        }
191
192        let start_time = Instant::now();
193        let id = self.next_id;
194        self.next_id += 1;
195
196        OperationHandle {
197            id,
198            operation_type,
199            start_time,
200            enabled: true,
201        }
202    }
203
204    /// Finish profiling an operation
205    pub fn finish_operation(&mut self, handle: OperationHandle, context: OperationContext) {
206        if !handle.enabled || !self.config.enabled {
207            return;
208        }
209
210        let profile_start = Instant::now();
211        let duration = handle.start_time.elapsed();
212
213        // Filter out very fast operations if configured
214        if duration.as_nanos() < self.config.min_duration_ns as u128 {
215            self.overhead_ns += profile_start.elapsed().as_nanos() as u64;
216            return;
217        }
218
219        let memory_bandwidth = if self.config.track_memory_bandwidth {
220            context.calculate_memory_bandwidth(duration)
221        } else {
222            None
223        };
224
225        let cache_hit_rate = if self.config.track_cache_performance {
226            context.cache_hit_rate
227        } else {
228            None
229        };
230
231        let stack_trace = if self.config.capture_stack_traces {
232            Some(capture_stack_trace())
233        } else {
234            None
235        };
236
237        let record = OperationRecord {
238            id: handle.id,
239            operation_type: handle.operation_type.clone(),
240            duration,
241            memory_bandwidth,
242            cache_hit_rate,
243            input_sizes: context.input_sizes,
244            output_size: context.output_size,
245            thread_id: thread::current().id(),
246            timestamp: handle.start_time,
247            stack_trace,
248            metadata: context.metadata,
249        };
250
251        // Add to records
252        self.records.push_back(record.clone());
253
254        // Maintain max records limit
255        if self.records.len() > self.config.max_records {
256            self.records.pop_front();
257        }
258
259        // Update aggregated statistics
260        self.update_stats(&record);
261
262        self.overhead_ns += profile_start.elapsed().as_nanos() as u64;
263    }
264
265    /// Get aggregated statistics for all operations
266    pub fn get_stats(&self) -> HashMap<OperationType, OperationStats> {
267        self.stats.clone()
268    }
269
270    /// Get all operation records
271    pub fn get_records(&self) -> Vec<OperationRecord> {
272        self.records.iter().cloned().collect()
273    }
274
275    /// Generate a performance report
276    pub fn generate_report(&self) -> String {
277        let mut report = String::new();
278        report.push_str("=== Performance Profile Report ===\n\n");
279
280        let total_duration = self.start_time.elapsed();
281        report.push_str(&format!("Profiling Duration: {total_duration:.2?}\n"));
282        let total_ops = self.records.len();
283        report.push_str(&format!("Total Operations: {total_ops}\n"));
284        let overhead_us = self.overhead_ns as f64 / 1000.0;
285        report.push_str(&format!("Profiling Overhead: {overhead_us:.2} µs\n"));
286
287        // Top operations by total time
288        let mut sorted_stats: Vec<_> = self.stats.values().collect();
289        sorted_stats.sort_by(|a, b| b.total_duration.cmp(&a.total_duration));
290
291        report.push_str("\nTop Operations by Total Time:\n");
292        for (i, stat) in sorted_stats.iter().take(10).enumerate() {
293            let percentage =
294                (stat.total_duration.as_nanos() as f64 / total_duration.as_nanos() as f64) * 100.0;
295            let idx = i + 1;
296            let op_type = &stat.operation_type;
297            let total_dur = stat.total_duration;
298            let count = stat.count;
299            let avg_dur = stat.avg_duration;
300            report.push_str(&format!(
301                "  {idx}. {op_type} - {total_dur:.2?} ({percentage:.1}%, {count} calls, avg: {avg_dur:.2?})\n"
302            ));
303        }
304
305        // Performance bottlenecks
306        let bottlenecks = self.identify_bottlenecks();
307        if !bottlenecks.is_empty() {
308            report.push_str("\nPerformance Bottlenecks:\n");
309            for bottleneck in bottlenecks.iter().take(5) {
310                let op_type = &bottleneck.operation_type;
311                let time_pct = bottleneck.time_percentage;
312                let call_count = bottleneck.call_count;
313                let suggestion = &bottleneck.optimization_suggestion;
314                report.push_str(&format!(
315                    "  - {op_type}: {time_pct:.1}% of total time ({call_count} calls)\n"
316                ));
317                report.push_str(&format!("    Suggestion: {suggestion}\n"));
318            }
319        }
320
321        // Memory bandwidth analysis
322        let avg_bandwidth = self.calculate_average_bandwidth();
323        if let Some(bandwidth) = avg_bandwidth {
324            report.push_str(&format!(
325                "\nAverage Memory Bandwidth: {bandwidth:.2} GB/s\n"
326            ));
327        }
328
329        // Cache performance
330        let avg_cache_hit_rate = self.calculate_average_cache_hit_rate();
331        if let Some(hit_rate) = avg_cache_hit_rate {
332            let hit_rate_percent = hit_rate * 100.0;
333            report.push_str(&format!("Average Cache Hit Rate: {hit_rate_percent:.1}%\n"));
334        }
335
336        report
337    }
338
339    /// Reset profiler state
340    pub fn reset(&mut self) {
341        self.records.clear();
342        self.stats.clear();
343        self.next_id = 1;
344        self.overhead_ns = 0;
345        self.start_time = Instant::now();
346    }
347
348    /// Update configuration
349    pub fn update_config(&mut self, config: ProfilerConfig) {
350        self.config = config;
351    }
352
353    fn update_stats(&mut self, record: &OperationRecord) {
354        let entry = self
355            .stats
356            .entry(record.operation_type.clone())
357            .or_insert_with(|| OperationStats {
358                operation_type: record.operation_type.clone(),
359                count: 0,
360                total_duration: Duration::ZERO,
361                min_duration: Duration::MAX,
362                max_duration: Duration::ZERO,
363                avg_duration: Duration::ZERO,
364                p50_duration: Duration::ZERO,
365                p95_duration: Duration::ZERO,
366                p99_duration: Duration::ZERO,
367                avg_memory_bandwidth: None,
368                avg_cache_hit_rate: None,
369                total_bytes: 0,
370            });
371
372        entry.count += 1;
373        entry.total_duration += record.duration;
374        entry.min_duration = entry.min_duration.min(record.duration);
375        entry.max_duration = entry.max_duration.max(record.duration);
376        entry.avg_duration = entry.total_duration / entry.count as u32;
377
378        if let Some(bandwidth) = record.memory_bandwidth {
379            entry.avg_memory_bandwidth = Some(
380                entry.avg_memory_bandwidth.unwrap_or(0.0)
381                    + (bandwidth - entry.avg_memory_bandwidth.unwrap_or(0.0)) / entry.count as f64,
382            );
383        }
384
385        if let Some(cache_rate) = record.cache_hit_rate {
386            entry.avg_cache_hit_rate = Some(
387                entry.avg_cache_hit_rate.unwrap_or(0.0)
388                    + (cache_rate - entry.avg_cache_hit_rate.unwrap_or(0.0)) / entry.count as f64,
389            );
390        }
391
392        // Update percentiles (simplified calculation)
393        let durations: Vec<Duration> = self
394            .records
395            .iter()
396            .filter(|r| r.operation_type == record.operation_type)
397            .map(|r| r.duration)
398            .collect();
399
400        if !durations.is_empty() {
401            let mut sorted_durations = durations.clone();
402            sorted_durations.sort();
403
404            let p50_idx = (sorted_durations.len() * 50) / 100;
405            let p95_idx = (sorted_durations.len() * 95) / 100;
406            let p99_idx = (sorted_durations.len() * 99) / 100;
407
408            entry.p50_duration = sorted_durations
409                .get(p50_idx)
410                .copied()
411                .unwrap_or(Duration::ZERO);
412            entry.p95_duration = sorted_durations
413                .get(p95_idx)
414                .copied()
415                .unwrap_or(Duration::ZERO);
416            entry.p99_duration = sorted_durations
417                .get(p99_idx)
418                .copied()
419                .unwrap_or(Duration::ZERO);
420        }
421
422        // Update total bytes
423        let total_input_bytes: usize = record.input_sizes.iter().sum();
424        let total_bytes = total_input_bytes + record.output_size.unwrap_or(0);
425        entry.total_bytes += total_bytes;
426    }
427
428    fn identify_bottlenecks(&self) -> Vec<PerformanceBottleneck> {
429        let total_time = self.start_time.elapsed();
430        let mut bottlenecks = Vec::new();
431
432        for stat in self.stats.values() {
433            let time_percentage =
434                (stat.total_duration.as_nanos() as f64 / total_time.as_nanos() as f64) * 100.0;
435
436            if time_percentage > 5.0 {
437                // Consider anything >5% of total time as a potential bottleneck
438                let suggestion = generate_optimization_suggestion(&stat.operation_type, stat);
439
440                bottlenecks.push(PerformanceBottleneck {
441                    operation_type: stat.operation_type.clone(),
442                    time_percentage,
443                    call_count: stat.count,
444                    avg_duration: stat.avg_duration,
445                    optimization_suggestion: suggestion,
446                });
447            }
448        }
449
450        bottlenecks.sort_by(|a, b| {
451            b.time_percentage
452                .partial_cmp(&a.time_percentage)
453                .unwrap_or(std::cmp::Ordering::Equal)
454        });
455        bottlenecks
456    }
457
458    fn calculate_average_bandwidth(&self) -> Option<f64> {
459        let bandwidths: Vec<f64> = self
460            .records
461            .iter()
462            .filter_map(|r| r.memory_bandwidth)
463            .collect();
464
465        if bandwidths.is_empty() {
466            None
467        } else {
468            Some(bandwidths.iter().sum::<f64>() / bandwidths.len() as f64)
469        }
470    }
471
472    fn calculate_average_cache_hit_rate(&self) -> Option<f64> {
473        let hit_rates: Vec<f64> = self
474            .records
475            .iter()
476            .filter_map(|r| r.cache_hit_rate)
477            .collect();
478
479        if hit_rates.is_empty() {
480            None
481        } else {
482            Some(hit_rates.iter().sum::<f64>() / hit_rates.len() as f64)
483        }
484    }
485}
486
/// Handle for an operation being profiled
///
/// Returned by `PerformanceProfiler::start_operation` and consumed by
/// `finish_operation`. A disabled handle makes `finish_operation` a no-op.
pub struct OperationHandle {
    // Operation id assigned by the profiler (0 for disabled handles)
    id: u64,
    // Type of the operation being timed
    operation_type: OperationType,
    // Instant at which start_operation was called
    start_time: Instant,
    // False when profiling was disabled at start time
    enabled: bool,
}
494
495impl OperationHandle {
496    fn disabled() -> Self {
497        Self {
498            id: 0,
499            operation_type: OperationType::Custom("disabled".to_string()),
500            start_time: Instant::now(),
501            enabled: false,
502        }
503    }
504}
505
/// Context information for an operation
///
/// Built by callers (typically via the `with_*` builder methods) and passed
/// to `PerformanceProfiler::finish_operation`.
pub struct OperationContext {
    /// Input tensor sizes in bytes
    pub input_sizes: Vec<usize>,
    /// Output tensor size in bytes
    pub output_size: Option<usize>,
    /// Cache hit rate if available (0.0-1.0; reports display it as a percentage)
    pub cache_hit_rate: Option<f64>,
    /// Custom metadata attached to the resulting record
    pub metadata: HashMap<String, String>,
}
517
518impl OperationContext {
519    pub fn new() -> Self {
520        Self {
521            input_sizes: Vec::new(),
522            output_size: None,
523            cache_hit_rate: None,
524            metadata: HashMap::new(),
525        }
526    }
527
528    pub fn with_input_size(mut self, size: usize) -> Self {
529        self.input_sizes.push(size);
530        self
531    }
532
533    pub fn with_output_size(mut self, size: usize) -> Self {
534        self.output_size = Some(size);
535        self
536    }
537
538    pub fn with_cache_hit_rate(mut self, rate: f64) -> Self {
539        self.cache_hit_rate = Some(rate);
540        self
541    }
542
543    pub fn with_metadata(mut self, key: String, value: String) -> Self {
544        self.metadata.insert(key, value);
545        self
546    }
547
548    fn calculate_memory_bandwidth(&self, duration: Duration) -> Option<f64> {
549        let total_bytes: usize =
550            self.input_sizes.iter().sum::<usize>() + self.output_size.unwrap_or(0);
551
552        if total_bytes == 0 || duration.is_zero() {
553            return None;
554        }
555
556        let duration_secs = duration.as_secs_f64();
557        let bandwidth_bytes_per_sec = total_bytes as f64 / duration_secs;
558        let bandwidth_gb_per_sec = bandwidth_bytes_per_sec / 1_000_000_000.0;
559
560        Some(bandwidth_gb_per_sec)
561    }
562}
563
564impl Default for OperationContext {
565    fn default() -> Self {
566        Self::new()
567    }
568}
569
570/// Generate optimization suggestions based on operation type and statistics
571fn generate_optimization_suggestion(op_type: &OperationType, stats: &OperationStats) -> String {
572    match op_type {
573        OperationType::Math(name) => {
574            if stats.avg_duration > Duration::from_millis(10) {
575                format!("Consider using SIMD optimizations for {name} operations")
576            } else if let Some(bandwidth) = stats.avg_memory_bandwidth {
577                if bandwidth < 10.0 {
578                    "Memory bandwidth is low - consider batching operations".to_string()
579                } else {
580                    "Consider using more efficient algorithms or caching".to_string()
581                }
582            } else {
583                "Consider optimizing algorithm or using specialized libraries".to_string()
584            }
585        }
586        OperationType::Memory(name) => {
587            if let Some(bandwidth) = stats.avg_memory_bandwidth {
588                if bandwidth < 20.0 {
589                    format!(
590                        "Memory bandwidth for {name} is low - consider memory layout optimization"
591                    )
592                } else {
593                    "Consider reducing memory allocations or using memory pools".to_string()
594                }
595            } else {
596                "Consider optimizing memory access patterns".to_string()
597            }
598        }
599        OperationType::Shape(name) => {
600            if stats.count > 1000 {
601                format!("High frequency {name} operations - consider caching or batching")
602            } else {
603                "Consider optimizing shape operations with compile-time checks".to_string()
604            }
605        }
606        OperationType::Neural(name) => {
607            format!("Consider using specialized neural network libraries for {name} operations")
608        }
609        _ => "Consider profiling individual sub-operations to identify bottlenecks".to_string(),
610    }
611}
612
/// Capture stack trace (simplified implementation)
///
/// Placeholder: a real implementation would walk the call stack (e.g. via a
/// backtrace facility). For now only the current thread's name is recorded.
fn capture_stack_trace() -> String {
    let current = std::thread::current();
    let thread_name = current.name().unwrap_or("unknown");
    format!("Stack trace captured at {thread_name}")
}
621
622/// Global profiler access functions
623pub fn get_profiler() -> Arc<Mutex<PerformanceProfiler>> {
624    PROFILER
625        .get_or_init(|| {
626            Arc::new(Mutex::new(PerformanceProfiler::new(
627                ProfilerConfig::default(),
628            )))
629        })
630        .clone()
631}
632
633/// Initialize the global profiler with custom configuration
634pub fn init_profiler(config: ProfilerConfig) -> Result<()> {
635    if PROFILER.get().is_some() {
636        return Err(TorshError::InvalidState(
637            "Profiler already initialized".to_string(),
638        ));
639    }
640
641    PROFILER
642        .set(Arc::new(Mutex::new(PerformanceProfiler::new(config))))
643        .map_err(|_| TorshError::InvalidState("Failed to initialize profiler".to_string()))?;
644
645    Ok(())
646}
647
/// Convenience macro for profiling operations
///
/// Locks the global profiler to start an operation of `$op_type`, evaluates
/// `$body` with the lock released, then locks again to finish the operation
/// with `$context`. The whole expression evaluates to `$body`'s value.
///
/// NOTE(review): `lock().unwrap()` panics if the profiler mutex is poisoned,
/// and a panic inside `$body` skips `finish_operation` (that operation is
/// never recorded).
#[macro_export]
macro_rules! profile_operation {
    ($op_type:expr, $context:expr, $body:expr) => {{
        let profiler = $crate::profiling::get_profiler();
        let handle = {
            let mut p = profiler.lock().unwrap();
            p.start_operation($op_type)
        };

        let result = $body;

        {
            let mut p = profiler.lock().unwrap();
            p.finish_operation(handle, $context);
        }

        result
    }};
}
668
669/// Convenience function for profiling a closure
670pub fn profile_closure<F, R>(op_type: OperationType, context: OperationContext, closure: F) -> R
671where
672    F: FnOnce() -> R,
673{
674    let profiler = get_profiler();
675    let handle = {
676        let mut p = profiler.lock().unwrap();
677        p.start_operation(op_type)
678    };
679
680    let result = closure();
681
682    {
683        let mut p = profiler.lock().unwrap();
684        p.finish_operation(handle, context);
685    }
686
687    result
688}
689
/// Shape-specific performance metrics collection
#[derive(Debug, Clone, Default)]
pub struct ShapeMetrics {
    /// Number of dimensions
    pub ndim: usize,
    /// Total number of elements
    pub numel: usize,
    /// Memory layout efficiency (0.0-1.0)
    pub layout_efficiency: f64,
    /// Broadcasting complexity score (>= 0.0)
    pub broadcast_complexity: f64,
    /// SIMD vectorization efficiency (0.0-1.0)
    pub simd_efficiency: Option<f64>,
    /// Cache locality score (0.0-1.0)
    pub cache_locality: Option<f64>,
}

impl ShapeMetrics {
    /// Create metrics for a shape with `ndim` dimensions and `numel` elements,
    /// assuming perfect layout efficiency and no broadcasting.
    pub fn new(ndim: usize, numel: usize) -> Self {
        ShapeMetrics {
            ndim,
            numel,
            layout_efficiency: 1.0,
            broadcast_complexity: 0.0,
            // SIMD / cache scores start unknown.
            ..ShapeMetrics::default()
        }
    }

    /// Builder: set the layout efficiency score, clamped into 0.0..=1.0.
    pub fn with_layout_efficiency(self, efficiency: f64) -> Self {
        ShapeMetrics {
            layout_efficiency: efficiency.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Builder: set the broadcasting complexity score, floored at 0.0.
    pub fn with_broadcast_complexity(self, complexity: f64) -> Self {
        ShapeMetrics {
            broadcast_complexity: complexity.max(0.0),
            ..self
        }
    }

    /// Builder: set the SIMD efficiency score, clamped into 0.0..=1.0.
    pub fn with_simd_efficiency(self, efficiency: f64) -> Self {
        ShapeMetrics {
            simd_efficiency: Some(efficiency.clamp(0.0, 1.0)),
            ..self
        }
    }

    /// Builder: set the cache locality score, clamped into 0.0..=1.0.
    pub fn with_cache_locality(self, locality: f64) -> Self {
        ShapeMetrics {
            cache_locality: Some(locality.clamp(0.0, 1.0)),
            ..self
        }
    }

    /// Overall performance score in 0.0..=1.0.
    ///
    /// Starts from the layout efficiency, penalizes broadcasting complexity
    /// (capped at a 50% penalty), then boosts by up to 20% for SIMD
    /// efficiency and up to 10% for cache locality.
    pub fn performance_score(&self) -> f64 {
        let broadcast_penalty = 1.0 - (self.broadcast_complexity / 10.0).min(0.5);
        let simd_boost = self.simd_efficiency.map_or(1.0, |s| 1.0 + s * 0.2);
        let cache_boost = self.cache_locality.map_or(1.0, |c| 1.0 + c * 0.1);

        (self.layout_efficiency * broadcast_penalty * simd_boost * cache_boost).clamp(0.0, 1.0)
    }
}
764
/// Shape operation performance tracker
///
/// Standalone tracker (separate from the global `PerformanceProfiler`)
/// keeping a bounded window of shape-operation records plus per-operation
/// aggregates keyed by operation name.
#[derive(Debug)]
pub struct ShapePerformanceTracker {
    /// Bounded window of shape operation records (oldest evicted first)
    records: VecDeque<ShapeOperationRecord>,
    /// Maximum number of records to keep
    max_records: usize,
    /// Aggregate statistics keyed by operation name
    aggregates: HashMap<String, ShapeOperationAggregate>,
}
775
/// Record for a shape operation
#[derive(Debug, Clone)]
pub struct ShapeOperationRecord {
    /// Operation name (aggregation key)
    pub operation: String,
    /// Operation duration
    pub duration: Duration,
    /// Shape metrics captured for this operation
    pub metrics: ShapeMetrics,
    /// Timestamp when the record was created
    pub timestamp: Instant,
    /// Thread ID that recorded the operation
    pub thread_id: std::thread::ThreadId,
}
790
/// Aggregate statistics for a shape operation type
///
/// Performance scores come from `ShapeMetrics::performance_score`
/// (0.0-1.0, higher is better).
#[derive(Debug, Clone)]
pub struct ShapeOperationAggregate {
    /// Number of operations
    pub count: usize,
    /// Total duration
    pub total_duration: Duration,
    /// Average duration (total / count)
    pub avg_duration: Duration,
    /// Min duration
    pub min_duration: Duration,
    /// Max duration
    pub max_duration: Duration,
    /// Running average performance score
    pub avg_performance_score: f64,
    /// Best (highest) performance score seen
    pub best_performance_score: f64,
    /// Worst (lowest) performance score seen
    pub worst_performance_score: f64,
}
811
812impl ShapePerformanceTracker {
813    /// Create a new shape performance tracker
814    pub fn new(max_records: usize) -> Self {
815        Self {
816            records: VecDeque::with_capacity(max_records),
817            max_records,
818            aggregates: HashMap::new(),
819        }
820    }
821
822    /// Record a shape operation
823    pub fn record_operation(
824        &mut self,
825        operation: String,
826        duration: Duration,
827        metrics: ShapeMetrics,
828    ) {
829        let record = ShapeOperationRecord {
830            operation: operation.clone(),
831            duration,
832            metrics: metrics.clone(),
833            timestamp: Instant::now(),
834            thread_id: std::thread::current().id(),
835        };
836
837        // Add to records (with size limit)
838        if self.records.len() >= self.max_records {
839            self.records.pop_front();
840        }
841        self.records.push_back(record);
842
843        // Update aggregates
844        let performance_score = metrics.performance_score();
845        let aggregate =
846            self.aggregates
847                .entry(operation)
848                .or_insert_with(|| ShapeOperationAggregate {
849                    count: 0,
850                    total_duration: Duration::ZERO,
851                    avg_duration: Duration::ZERO,
852                    min_duration: duration,
853                    max_duration: duration,
854                    avg_performance_score: performance_score,
855                    best_performance_score: performance_score,
856                    worst_performance_score: performance_score,
857                });
858
859        aggregate.count += 1;
860        aggregate.total_duration += duration;
861        aggregate.avg_duration = aggregate.total_duration / aggregate.count as u32;
862        aggregate.min_duration = aggregate.min_duration.min(duration);
863        aggregate.max_duration = aggregate.max_duration.max(duration);
864
865        // Update performance scores
866        let total_score =
867            aggregate.avg_performance_score * (aggregate.count - 1) as f64 + performance_score;
868        aggregate.avg_performance_score = total_score / aggregate.count as f64;
869        aggregate.best_performance_score = aggregate.best_performance_score.max(performance_score);
870        aggregate.worst_performance_score =
871            aggregate.worst_performance_score.min(performance_score);
872    }
873
874    /// Get recent records
875    pub fn get_records(&self) -> Vec<ShapeOperationRecord> {
876        self.records.iter().cloned().collect()
877    }
878
879    /// Get aggregate statistics
880    pub fn get_aggregates(&self) -> &HashMap<String, ShapeOperationAggregate> {
881        &self.aggregates
882    }
883
884    /// Generate performance report
885    pub fn generate_report(&self) -> String {
886        let mut report = String::new();
887        report.push_str("=== Shape Operations Performance Report ===\n\n");
888
889        report.push_str(&format!("Total Records: {}\n", self.records.len()));
890        report.push_str(&format!("Operation Types: {}\n\n", self.aggregates.len()));
891
892        // Sort aggregates by average performance score (worst first)
893        let mut sorted_ops: Vec<_> = self.aggregates.iter().collect();
894        sorted_ops.sort_by(|a, b| {
895            a.1.avg_performance_score
896                .partial_cmp(&b.1.avg_performance_score)
897                .unwrap()
898        });
899
900        report.push_str("Performance Summary (worst to best):\n");
901        for (op_name, aggregate) in sorted_ops {
902            report.push_str(&format!(
903                "  {}: {:.3} avg score, {:.2}ms avg time, {} calls\n",
904                op_name,
905                aggregate.avg_performance_score,
906                aggregate.avg_duration.as_secs_f64() * 1000.0,
907                aggregate.count
908            ));
909        }
910
911        report.push_str("\nDetailed Statistics:\n");
912        for (op_name, aggregate) in &self.aggregates {
913            report.push_str(&format!("\n{op_name}:\n"));
914            report.push_str(&format!("  Count: {}\n", aggregate.count));
915            report.push_str(&format!(
916                "  Avg Duration: {:.2}ms\n",
917                aggregate.avg_duration.as_secs_f64() * 1000.0
918            ));
919            report.push_str(&format!(
920                "  Min Duration: {:.2}ms\n",
921                aggregate.min_duration.as_secs_f64() * 1000.0
922            ));
923            report.push_str(&format!(
924                "  Max Duration: {:.2}ms\n",
925                aggregate.max_duration.as_secs_f64() * 1000.0
926            ));
927            report.push_str(&format!(
928                "  Avg Performance: {:.3}\n",
929                aggregate.avg_performance_score
930            ));
931            report.push_str(&format!(
932                "  Best Performance: {:.3}\n",
933                aggregate.best_performance_score
934            ));
935            report.push_str(&format!(
936                "  Worst Performance: {:.3}\n",
937                aggregate.worst_performance_score
938            ));
939        }
940
941        report
942    }
943
944    /// Find performance bottlenecks
945    pub fn find_bottlenecks(&self) -> Vec<(String, String)> {
946        let mut bottlenecks = Vec::new();
947
948        for (op_name, aggregate) in &self.aggregates {
949            // Check for poor performance scores
950            if aggregate.avg_performance_score < 0.5 {
951                bottlenecks.push((
952                    op_name.clone(),
953                    format!(
954                        "Low performance score: {:.3}",
955                        aggregate.avg_performance_score
956                    ),
957                ));
958            }
959
960            // Check for high variance in execution time
961            let duration_ratio =
962                aggregate.max_duration.as_secs_f64() / aggregate.min_duration.as_secs_f64();
963            if duration_ratio > 5.0 && aggregate.count > 10 {
964                bottlenecks.push((
965                    op_name.clone(),
966                    format!(
967                        "High variance: {duration_ratio:.1}x difference between min/max duration"
968                    ),
969                ));
970            }
971
972            // Check for frequent operations that could benefit from optimization
973            if aggregate.count > 100 && aggregate.avg_duration.as_millis() > 1 {
974                bottlenecks.push((
975                    op_name.clone(),
976                    format!(
977                        "Frequent expensive operation: {} calls, {:.2}ms avg",
978                        aggregate.count,
979                        aggregate.avg_duration.as_secs_f64() * 1000.0
980                    ),
981                ));
982            }
983        }
984
985        bottlenecks
986    }
987
988    /// Get optimization suggestions
989    pub fn get_optimization_suggestions(&self) -> Vec<String> {
990        let mut suggestions = Vec::new();
991        let bottlenecks = self.find_bottlenecks();
992
993        for (op_name, issue) in bottlenecks {
994            if issue.contains("Low performance score") {
995                suggestions.push(format!(
996                    "Consider optimizing {op_name} - check memory layout and broadcasting efficiency"
997                ));
998            } else if issue.contains("High variance") {
999                suggestions.push(format!(
1000                    "Investigate {op_name} for inconsistent performance - possible cache/memory pressure issues"
1001                ));
1002            } else if issue.contains("Frequent expensive") {
1003                suggestions.push(format!(
1004                    "Profile {op_name} for optimization opportunities - consider caching or vectorization"
1005                ));
1006            }
1007        }
1008
1009        if suggestions.is_empty() {
1010            suggestions.push("No performance issues detected - good job!".to_string());
1011        }
1012
1013        suggestions
1014    }
1015}
1016
1017/// Global shape performance tracker
1018static SHAPE_TRACKER: OnceLock<Arc<Mutex<ShapePerformanceTracker>>> = OnceLock::new();
1019
1020/// Get or initialize the global shape performance tracker
1021pub fn get_shape_tracker() -> &'static Arc<Mutex<ShapePerformanceTracker>> {
1022    SHAPE_TRACKER.get_or_init(|| Arc::new(Mutex::new(ShapePerformanceTracker::new(10_000))))
1023}
1024
1025/// Profile a shape operation with automatic metrics collection
1026pub fn profile_shape_operation<F, R>(operation_name: &str, ndim: usize, numel: usize, f: F) -> R
1027where
1028    F: FnOnce() -> R,
1029{
1030    let start = Instant::now();
1031    let result = f();
1032    let duration = start.elapsed();
1033
1034    let metrics = ShapeMetrics::new(ndim, numel);
1035
1036    let tracker = get_shape_tracker();
1037    if let Ok(mut tracker) = tracker.lock() {
1038        tracker.record_operation(operation_name.to_string(), duration, metrics);
1039    }
1040
1041    result
1042}
1043
1044/// Profile a shape operation with custom metrics
1045pub fn profile_shape_operation_with_metrics<F, R>(
1046    operation_name: &str,
1047    metrics: ShapeMetrics,
1048    f: F,
1049) -> R
1050where
1051    F: FnOnce() -> R,
1052{
1053    let start = Instant::now();
1054    let result = f();
1055    let duration = start.elapsed();
1056
1057    let tracker = get_shape_tracker();
1058    if let Ok(mut tracker) = tracker.lock() {
1059        tracker.record_operation(operation_name.to_string(), duration, metrics);
1060    }
1061
1062    result
1063}
1064
/// Macro for easy shape operation profiling
///
/// Two forms:
/// - `profile_shape_op!(name, ndim, numel, body)` — derives metrics from the
///   dimension/element counts via `profile_shape_operation`.
/// - `profile_shape_op!(name, metrics, body)` — forwards caller-built
///   `ShapeMetrics` via `profile_shape_operation_with_metrics`.
///
/// In both forms `body` is evaluated inside a closure and its value returned.
#[macro_export]
macro_rules! profile_shape_op {
    // (name, ndim, numel, body): metrics computed from raw shape info.
    ($op_name:expr, $ndim:expr, $numel:expr, $body:expr) => {
        $crate::profiling::profile_shape_operation($op_name, $ndim, $numel, || $body)
    };
    // (name, metrics, body): caller supplies the metrics directly.
    ($op_name:expr, $metrics:expr, $body:expr) => {
        $crate::profiling::profile_shape_operation_with_metrics($op_name, $metrics, || $body)
    };
}
1075
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // A freshly-created profiler must start with no records and no stats.
    #[test]
    fn test_profiler_creation() {
        let profiler = PerformanceProfiler::new(ProfilerConfig::default());
        assert_eq!(profiler.records.len(), 0);
        assert_eq!(profiler.stats.len(), 0);
    }

    // A single start/finish cycle produces exactly one record and a stats
    // entry keyed by the operation type.
    #[test]
    fn test_operation_profiling() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let op_type = OperationType::Math("add".to_string());

        let handle = profiler.start_operation(op_type.clone());
        // Sleep so the duration clears the profiler's min_duration_ns filter.
        thread::sleep(Duration::from_millis(1));

        let context = OperationContext::new()
            .with_input_size(1000)
            .with_output_size(1000);

        profiler.finish_operation(handle, context);

        assert_eq!(profiler.records.len(), 1);
        assert!(profiler.stats.contains_key(&op_type));
    }

    // Repeated operations of the same type aggregate into one stats entry
    // with an accurate count and non-zero total/average durations.
    #[test]
    fn test_profiler_statistics() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let op_type = OperationType::Math("multiply".to_string());

        // Profile multiple operations
        for _ in 0..3 {
            let handle = profiler.start_operation(op_type.clone());
            thread::sleep(Duration::from_millis(1));

            let context = OperationContext::new()
                .with_input_size(500)
                .with_output_size(500);

            profiler.finish_operation(handle, context);
        }

        let stats = profiler.get_stats();
        let multiply_stats = stats.get(&op_type).unwrap();

        assert_eq!(multiply_stats.count, 3);
        assert!(multiply_stats.total_duration > Duration::ZERO);
        assert!(multiply_stats.avg_duration > Duration::ZERO);
    }

    // A single markedly slower operation should be surfaced by
    // identify_bottlenecks.
    // NOTE(review): relies on real sleeps and relative timing; could be
    // flaky on a heavily loaded CI machine — confirm acceptable.
    #[test]
    fn test_bottleneck_identification() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let slow_op = OperationType::Math("slow_operation".to_string());
        let fast_op = OperationType::Math("fast_operation".to_string());

        // Create a slow operation
        let handle = profiler.start_operation(slow_op.clone());
        thread::sleep(Duration::from_millis(10));
        profiler.finish_operation(handle, OperationContext::new());

        // Create fast operations
        for _ in 0..5 {
            let handle = profiler.start_operation(fast_op.clone());
            thread::sleep(Duration::from_millis(1));
            profiler.finish_operation(handle, OperationContext::new());
        }

        let bottlenecks = profiler.identify_bottlenecks();
        assert!(!bottlenecks.is_empty());

        // The slow operation should be identified as a bottleneck
        assert!(bottlenecks.iter().any(|b| b.operation_type == slow_op));
    }

    // With both input and output sizes set, bandwidth over a non-zero
    // duration must be computable and positive.
    #[test]
    fn test_memory_bandwidth_calculation() {
        let context = OperationContext::new()
            .with_input_size(1000)
            .with_output_size(1000);

        let duration = Duration::from_millis(1);
        let bandwidth = context.calculate_memory_bandwidth(duration);

        assert!(bandwidth.is_some());
        assert!(bandwidth.unwrap() > 0.0);
    }

    // profile_closure must pass the closure's return value through and
    // record the operation in the global profiler.
    // NOTE(review): uses global profiler state, so this test can observe
    // records from other tests running in the same process.
    #[test]
    fn test_profile_closure() {
        let _profiler = get_profiler();

        let result = profile_closure(
            OperationType::Math("test".to_string()),
            OperationContext::new(),
            || {
                thread::sleep(Duration::from_millis(1));
                42
            },
        );

        assert_eq!(result, 42);

        // Check that the operation was recorded
        let profiler = get_profiler();
        let records = {
            // Scope the lock so the guard drops before the assertion.
            let p = profiler.lock().unwrap();
            p.get_records()
        };

        assert!(!records.is_empty());
    }

    // The generated report must contain the header, the total-operation
    // count, and the profiled operation's name.
    #[test]
    fn test_profiler_report_generation() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());

        // Add some operations
        let handle = profiler.start_operation(OperationType::Math("add".to_string()));
        thread::sleep(Duration::from_millis(1));
        profiler.finish_operation(handle, OperationContext::new());

        let report = profiler.generate_report();
        assert!(report.contains("Performance Profile Report"));
        assert!(report.contains("Total Operations: 1"));
        assert!(report.contains("Math::add"));
    }
}