1use crate::error::{Result, TorshError};
8use std::collections::{HashMap, VecDeque};
9use std::fmt;
10use std::sync::{Arc, Mutex, OnceLock};
11use std::thread;
12use std::time::{Duration, Instant};
13
/// Process-wide profiler instance, lazily created by `get_profiler` or
/// explicitly installed (once) by `init_profiler`.
static PROFILER: OnceLock<Arc<Mutex<PerformanceProfiler>>> = OnceLock::new();
16
#[derive(Debug, Clone)]
/// Configuration controlling what the profiler records and retains.
pub struct ProfilerConfig {
    /// Master switch; when `false`, start/finish calls are near no-ops.
    pub enabled: bool,
    /// Maximum number of operation records kept (oldest evicted first).
    pub max_records: usize,
    /// Whether to attach a captured stack trace to each record.
    pub capture_stack_traces: bool,
    /// Whether to derive memory bandwidth from input/output sizes.
    pub track_memory_bandwidth: bool,
    /// Whether to record caller-supplied cache hit rates.
    pub track_cache_performance: bool,
    /// Operations shorter than this (nanoseconds) are discarded.
    pub min_duration_ns: u64,
    /// Aggregate similar operations together.
    /// NOTE(review): not read anywhere in this module — confirm intended use.
    pub aggregate_similar_ops: bool,
}
35
36impl Default for ProfilerConfig {
37 fn default() -> Self {
38 Self {
39 enabled: true,
40 max_records: 10_000,
41 capture_stack_traces: false,
42 track_memory_bandwidth: true,
43 track_cache_performance: true,
44 min_duration_ns: 1_000, aggregate_similar_ops: true,
46 }
47 }
48}
49
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Category of a profiled operation; the `String` payload names the
/// specific operation within the category (e.g. `Math("add")`).
pub enum OperationType {
    /// Tensor/object creation operations.
    Creation(String),
    /// Mathematical operations.
    Math(String),
    /// Memory operations (copies, allocations, ...).
    Memory(String),
    /// Shape manipulation operations.
    Shape(String),
    /// Reductions (sum, mean, ...).
    Reduction(String),
    /// Neural-network operations.
    Neural(String),
    /// Backend-specific operations.
    Backend(String),
    /// Anything that does not fit the categories above.
    Custom(String),
}
70
71impl fmt::Display for OperationType {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 OperationType::Creation(name) => write!(f, "Creation::{name}"),
75 OperationType::Math(name) => write!(f, "Math::{name}"),
76 OperationType::Memory(name) => write!(f, "Memory::{name}"),
77 OperationType::Shape(name) => write!(f, "Shape::{name}"),
78 OperationType::Reduction(name) => write!(f, "Reduction::{name}"),
79 OperationType::Neural(name) => write!(f, "Neural::{name}"),
80 OperationType::Backend(name) => write!(f, "Backend::{name}"),
81 OperationType::Custom(name) => write!(f, "Custom::{name}"),
82 }
83 }
84}
85
#[derive(Debug, Clone)]
/// One finished, recorded operation.
pub struct OperationRecord {
    /// Monotonically increasing id assigned at `start_operation`.
    pub id: u64,
    /// Category and name of the operation.
    pub operation_type: OperationType,
    /// Wall-clock time between start and finish.
    pub duration: Duration,
    /// Derived memory bandwidth in GB/s, if tracking was enabled and
    /// sizes were provided.
    pub memory_bandwidth: Option<f64>,
    /// Caller-supplied cache hit rate in `[0, 1]`, if tracked.
    pub cache_hit_rate: Option<f64>,
    /// Sizes (bytes) of the operation's inputs, as supplied by the caller.
    pub input_sizes: Vec<usize>,
    /// Size (bytes) of the output, if supplied.
    pub output_size: Option<usize>,
    /// Thread on which `finish_operation` ran.
    pub thread_id: thread::ThreadId,
    /// Start time of the operation (the handle's start instant).
    pub timestamp: Instant,
    /// Captured stack trace, if `capture_stack_traces` was enabled.
    pub stack_trace: Option<String>,
    /// Free-form caller-supplied key/value annotations.
    pub metadata: HashMap<String, String>,
}
112
#[derive(Debug, Clone)]
/// Aggregated statistics for one `OperationType`.
pub struct OperationStats {
    /// The operation these statistics describe.
    pub operation_type: OperationType,
    /// Number of recorded invocations.
    pub count: u64,
    /// Sum of all recorded durations.
    pub total_duration: Duration,
    /// Shortest recorded duration.
    pub min_duration: Duration,
    /// Longest recorded duration.
    pub max_duration: Duration,
    /// Mean duration (`total_duration / count`).
    pub avg_duration: Duration,
    /// Median duration over the retained record window.
    pub p50_duration: Duration,
    /// 95th-percentile duration over the retained record window.
    pub p95_duration: Duration,
    /// 99th-percentile duration over the retained record window.
    pub p99_duration: Duration,
    /// Running mean memory bandwidth (GB/s), if any was recorded.
    pub avg_memory_bandwidth: Option<f64>,
    /// Running mean cache hit rate, if any was recorded.
    pub avg_cache_hit_rate: Option<f64>,
    /// Total bytes (inputs + outputs) across all invocations.
    pub total_bytes: usize,
}
141
#[derive(Debug, Clone)]
/// An operation identified as consuming a significant share of total time.
pub struct PerformanceBottleneck {
    /// The offending operation.
    pub operation_type: OperationType,
    /// Share of total profiled wall time, in percent.
    pub time_percentage: f64,
    /// Number of recorded calls.
    pub call_count: u64,
    /// Mean duration per call.
    pub avg_duration: Duration,
    /// Heuristic advice produced by `generate_optimization_suggestion`.
    pub optimization_suggestion: String,
}
156
/// Collects per-operation timing records and aggregate statistics.
pub struct PerformanceProfiler {
    // Behavior switches and retention limits.
    config: ProfilerConfig,
    // Bounded ring buffer of individual records (oldest evicted first).
    records: VecDeque<OperationRecord>,
    // Aggregates keyed by operation type.
    stats: HashMap<OperationType, OperationStats>,
    // Next record id to hand out.
    next_id: u64,
    // Total nanoseconds spent inside the profiler itself.
    overhead_ns: u64,
    // When this profiler was created (or last reset).
    start_time: Instant,
}
172
173impl PerformanceProfiler {
174 pub fn new(config: ProfilerConfig) -> Self {
176 Self {
177 config,
178 records: VecDeque::new(),
179 stats: HashMap::new(),
180 next_id: 1,
181 overhead_ns: 0,
182 start_time: Instant::now(),
183 }
184 }
185
186 pub fn start_operation(&mut self, operation_type: OperationType) -> OperationHandle {
188 if !self.config.enabled {
189 return OperationHandle::disabled();
190 }
191
192 let start_time = Instant::now();
193 let id = self.next_id;
194 self.next_id += 1;
195
196 OperationHandle {
197 id,
198 operation_type,
199 start_time,
200 enabled: true,
201 }
202 }
203
204 pub fn finish_operation(&mut self, handle: OperationHandle, context: OperationContext) {
206 if !handle.enabled || !self.config.enabled {
207 return;
208 }
209
210 let profile_start = Instant::now();
211 let duration = handle.start_time.elapsed();
212
213 if duration.as_nanos() < self.config.min_duration_ns as u128 {
215 self.overhead_ns += profile_start.elapsed().as_nanos() as u64;
216 return;
217 }
218
219 let memory_bandwidth = if self.config.track_memory_bandwidth {
220 context.calculate_memory_bandwidth(duration)
221 } else {
222 None
223 };
224
225 let cache_hit_rate = if self.config.track_cache_performance {
226 context.cache_hit_rate
227 } else {
228 None
229 };
230
231 let stack_trace = if self.config.capture_stack_traces {
232 Some(capture_stack_trace())
233 } else {
234 None
235 };
236
237 let record = OperationRecord {
238 id: handle.id,
239 operation_type: handle.operation_type.clone(),
240 duration,
241 memory_bandwidth,
242 cache_hit_rate,
243 input_sizes: context.input_sizes,
244 output_size: context.output_size,
245 thread_id: thread::current().id(),
246 timestamp: handle.start_time,
247 stack_trace,
248 metadata: context.metadata,
249 };
250
251 self.records.push_back(record.clone());
253
254 if self.records.len() > self.config.max_records {
256 self.records.pop_front();
257 }
258
259 self.update_stats(&record);
261
262 self.overhead_ns += profile_start.elapsed().as_nanos() as u64;
263 }
264
265 pub fn get_stats(&self) -> HashMap<OperationType, OperationStats> {
267 self.stats.clone()
268 }
269
270 pub fn get_records(&self) -> Vec<OperationRecord> {
272 self.records.iter().cloned().collect()
273 }
274
275 pub fn generate_report(&self) -> String {
277 let mut report = String::new();
278 report.push_str("=== Performance Profile Report ===\n\n");
279
280 let total_duration = self.start_time.elapsed();
281 report.push_str(&format!("Profiling Duration: {total_duration:.2?}\n"));
282 let total_ops = self.records.len();
283 report.push_str(&format!("Total Operations: {total_ops}\n"));
284 let overhead_us = self.overhead_ns as f64 / 1000.0;
285 report.push_str(&format!("Profiling Overhead: {overhead_us:.2} µs\n"));
286
287 let mut sorted_stats: Vec<_> = self.stats.values().collect();
289 sorted_stats.sort_by(|a, b| b.total_duration.cmp(&a.total_duration));
290
291 report.push_str("\nTop Operations by Total Time:\n");
292 for (i, stat) in sorted_stats.iter().take(10).enumerate() {
293 let percentage =
294 (stat.total_duration.as_nanos() as f64 / total_duration.as_nanos() as f64) * 100.0;
295 let idx = i + 1;
296 let op_type = &stat.operation_type;
297 let total_dur = stat.total_duration;
298 let count = stat.count;
299 let avg_dur = stat.avg_duration;
300 report.push_str(&format!(
301 " {idx}. {op_type} - {total_dur:.2?} ({percentage:.1}%, {count} calls, avg: {avg_dur:.2?})\n"
302 ));
303 }
304
305 let bottlenecks = self.identify_bottlenecks();
307 if !bottlenecks.is_empty() {
308 report.push_str("\nPerformance Bottlenecks:\n");
309 for bottleneck in bottlenecks.iter().take(5) {
310 let op_type = &bottleneck.operation_type;
311 let time_pct = bottleneck.time_percentage;
312 let call_count = bottleneck.call_count;
313 let suggestion = &bottleneck.optimization_suggestion;
314 report.push_str(&format!(
315 " - {op_type}: {time_pct:.1}% of total time ({call_count} calls)\n"
316 ));
317 report.push_str(&format!(" Suggestion: {suggestion}\n"));
318 }
319 }
320
321 let avg_bandwidth = self.calculate_average_bandwidth();
323 if let Some(bandwidth) = avg_bandwidth {
324 report.push_str(&format!(
325 "\nAverage Memory Bandwidth: {bandwidth:.2} GB/s\n"
326 ));
327 }
328
329 let avg_cache_hit_rate = self.calculate_average_cache_hit_rate();
331 if let Some(hit_rate) = avg_cache_hit_rate {
332 let hit_rate_percent = hit_rate * 100.0;
333 report.push_str(&format!("Average Cache Hit Rate: {hit_rate_percent:.1}%\n"));
334 }
335
336 report
337 }
338
339 pub fn reset(&mut self) {
341 self.records.clear();
342 self.stats.clear();
343 self.next_id = 1;
344 self.overhead_ns = 0;
345 self.start_time = Instant::now();
346 }
347
348 pub fn update_config(&mut self, config: ProfilerConfig) {
350 self.config = config;
351 }
352
353 fn update_stats(&mut self, record: &OperationRecord) {
354 let entry = self
355 .stats
356 .entry(record.operation_type.clone())
357 .or_insert_with(|| OperationStats {
358 operation_type: record.operation_type.clone(),
359 count: 0,
360 total_duration: Duration::ZERO,
361 min_duration: Duration::MAX,
362 max_duration: Duration::ZERO,
363 avg_duration: Duration::ZERO,
364 p50_duration: Duration::ZERO,
365 p95_duration: Duration::ZERO,
366 p99_duration: Duration::ZERO,
367 avg_memory_bandwidth: None,
368 avg_cache_hit_rate: None,
369 total_bytes: 0,
370 });
371
372 entry.count += 1;
373 entry.total_duration += record.duration;
374 entry.min_duration = entry.min_duration.min(record.duration);
375 entry.max_duration = entry.max_duration.max(record.duration);
376 entry.avg_duration = entry.total_duration / entry.count as u32;
377
378 if let Some(bandwidth) = record.memory_bandwidth {
379 entry.avg_memory_bandwidth = Some(
380 entry.avg_memory_bandwidth.unwrap_or(0.0)
381 + (bandwidth - entry.avg_memory_bandwidth.unwrap_or(0.0)) / entry.count as f64,
382 );
383 }
384
385 if let Some(cache_rate) = record.cache_hit_rate {
386 entry.avg_cache_hit_rate = Some(
387 entry.avg_cache_hit_rate.unwrap_or(0.0)
388 + (cache_rate - entry.avg_cache_hit_rate.unwrap_or(0.0)) / entry.count as f64,
389 );
390 }
391
392 let durations: Vec<Duration> = self
394 .records
395 .iter()
396 .filter(|r| r.operation_type == record.operation_type)
397 .map(|r| r.duration)
398 .collect();
399
400 if !durations.is_empty() {
401 let mut sorted_durations = durations.clone();
402 sorted_durations.sort();
403
404 let p50_idx = (sorted_durations.len() * 50) / 100;
405 let p95_idx = (sorted_durations.len() * 95) / 100;
406 let p99_idx = (sorted_durations.len() * 99) / 100;
407
408 entry.p50_duration = sorted_durations
409 .get(p50_idx)
410 .copied()
411 .unwrap_or(Duration::ZERO);
412 entry.p95_duration = sorted_durations
413 .get(p95_idx)
414 .copied()
415 .unwrap_or(Duration::ZERO);
416 entry.p99_duration = sorted_durations
417 .get(p99_idx)
418 .copied()
419 .unwrap_or(Duration::ZERO);
420 }
421
422 let total_input_bytes: usize = record.input_sizes.iter().sum();
424 let total_bytes = total_input_bytes + record.output_size.unwrap_or(0);
425 entry.total_bytes += total_bytes;
426 }
427
428 fn identify_bottlenecks(&self) -> Vec<PerformanceBottleneck> {
429 let total_time = self.start_time.elapsed();
430 let mut bottlenecks = Vec::new();
431
432 for stat in self.stats.values() {
433 let time_percentage =
434 (stat.total_duration.as_nanos() as f64 / total_time.as_nanos() as f64) * 100.0;
435
436 if time_percentage > 5.0 {
437 let suggestion = generate_optimization_suggestion(&stat.operation_type, stat);
439
440 bottlenecks.push(PerformanceBottleneck {
441 operation_type: stat.operation_type.clone(),
442 time_percentage,
443 call_count: stat.count,
444 avg_duration: stat.avg_duration,
445 optimization_suggestion: suggestion,
446 });
447 }
448 }
449
450 bottlenecks.sort_by(|a, b| {
451 b.time_percentage
452 .partial_cmp(&a.time_percentage)
453 .unwrap_or(std::cmp::Ordering::Equal)
454 });
455 bottlenecks
456 }
457
458 fn calculate_average_bandwidth(&self) -> Option<f64> {
459 let bandwidths: Vec<f64> = self
460 .records
461 .iter()
462 .filter_map(|r| r.memory_bandwidth)
463 .collect();
464
465 if bandwidths.is_empty() {
466 None
467 } else {
468 Some(bandwidths.iter().sum::<f64>() / bandwidths.len() as f64)
469 }
470 }
471
472 fn calculate_average_cache_hit_rate(&self) -> Option<f64> {
473 let hit_rates: Vec<f64> = self
474 .records
475 .iter()
476 .filter_map(|r| r.cache_hit_rate)
477 .collect();
478
479 if hit_rates.is_empty() {
480 None
481 } else {
482 Some(hit_rates.iter().sum::<f64>() / hit_rates.len() as f64)
483 }
484 }
485}
486
/// Token returned by `start_operation`; pass it to `finish_operation`
/// to record the measurement.
pub struct OperationHandle {
    // Record id assigned at start (0 for disabled handles).
    id: u64,
    // Operation this handle is timing.
    operation_type: OperationType,
    // When timing began.
    start_time: Instant,
    // When false, `finish_operation` ignores this handle entirely.
    enabled: bool,
}
494
495impl OperationHandle {
496 fn disabled() -> Self {
497 Self {
498 id: 0,
499 operation_type: OperationType::Custom("disabled".to_string()),
500 start_time: Instant::now(),
501 enabled: false,
502 }
503 }
504}
505
/// Caller-supplied context describing a finished operation's data sizes,
/// cache behaviour, and free-form annotations.
pub struct OperationContext {
    /// Sizes (bytes) of each input buffer.
    pub input_sizes: Vec<usize>,
    /// Size (bytes) of the output, if known.
    pub output_size: Option<usize>,
    /// Cache hit rate in `[0, 1]`, if measured by the caller.
    pub cache_hit_rate: Option<f64>,
    /// Arbitrary key/value annotations attached to the record.
    pub metadata: HashMap<String, String>,
}
517
518impl OperationContext {
519 pub fn new() -> Self {
520 Self {
521 input_sizes: Vec::new(),
522 output_size: None,
523 cache_hit_rate: None,
524 metadata: HashMap::new(),
525 }
526 }
527
528 pub fn with_input_size(mut self, size: usize) -> Self {
529 self.input_sizes.push(size);
530 self
531 }
532
533 pub fn with_output_size(mut self, size: usize) -> Self {
534 self.output_size = Some(size);
535 self
536 }
537
538 pub fn with_cache_hit_rate(mut self, rate: f64) -> Self {
539 self.cache_hit_rate = Some(rate);
540 self
541 }
542
543 pub fn with_metadata(mut self, key: String, value: String) -> Self {
544 self.metadata.insert(key, value);
545 self
546 }
547
548 fn calculate_memory_bandwidth(&self, duration: Duration) -> Option<f64> {
549 let total_bytes: usize =
550 self.input_sizes.iter().sum::<usize>() + self.output_size.unwrap_or(0);
551
552 if total_bytes == 0 || duration.is_zero() {
553 return None;
554 }
555
556 let duration_secs = duration.as_secs_f64();
557 let bandwidth_bytes_per_sec = total_bytes as f64 / duration_secs;
558 let bandwidth_gb_per_sec = bandwidth_bytes_per_sec / 1_000_000_000.0;
559
560 Some(bandwidth_gb_per_sec)
561 }
562}
563
impl Default for OperationContext {
    /// Same as [`OperationContext::new`]: an empty context.
    fn default() -> Self {
        Self::new()
    }
}
569
570fn generate_optimization_suggestion(op_type: &OperationType, stats: &OperationStats) -> String {
572 match op_type {
573 OperationType::Math(name) => {
574 if stats.avg_duration > Duration::from_millis(10) {
575 format!("Consider using SIMD optimizations for {name} operations")
576 } else if let Some(bandwidth) = stats.avg_memory_bandwidth {
577 if bandwidth < 10.0 {
578 "Memory bandwidth is low - consider batching operations".to_string()
579 } else {
580 "Consider using more efficient algorithms or caching".to_string()
581 }
582 } else {
583 "Consider optimizing algorithm or using specialized libraries".to_string()
584 }
585 }
586 OperationType::Memory(name) => {
587 if let Some(bandwidth) = stats.avg_memory_bandwidth {
588 if bandwidth < 20.0 {
589 format!(
590 "Memory bandwidth for {name} is low - consider memory layout optimization"
591 )
592 } else {
593 "Consider reducing memory allocations or using memory pools".to_string()
594 }
595 } else {
596 "Consider optimizing memory access patterns".to_string()
597 }
598 }
599 OperationType::Shape(name) => {
600 if stats.count > 1000 {
601 format!("High frequency {name} operations - consider caching or batching")
602 } else {
603 "Consider optimizing shape operations with compile-time checks".to_string()
604 }
605 }
606 OperationType::Neural(name) => {
607 format!("Consider using specialized neural network libraries for {name} operations")
608 }
609 _ => "Consider profiling individual sub-operations to identify bottlenecks".to_string(),
610 }
611}
612
/// Capture a stack trace for the current thread.
///
/// The previous implementation only formatted the thread name and captured
/// no actual trace. This version uses `std::backtrace::Backtrace`;
/// `force_capture` ignores `RUST_BACKTRACE`/`RUST_LIB_BACKTRACE` so a trace
/// is produced whenever `capture_stack_traces` is enabled in the profiler
/// config (on platforms without backtrace support the formatted value says
/// so instead of being empty).
fn capture_stack_trace() -> String {
    let thread = std::thread::current();
    let thread_name = thread.name().unwrap_or("unknown");
    let backtrace = std::backtrace::Backtrace::force_capture();
    format!("Stack trace captured at {thread_name}:\n{backtrace}")
}
621
622pub fn get_profiler() -> Arc<Mutex<PerformanceProfiler>> {
624 PROFILER
625 .get_or_init(|| {
626 Arc::new(Mutex::new(PerformanceProfiler::new(
627 ProfilerConfig::default(),
628 )))
629 })
630 .clone()
631}
632
633pub fn init_profiler(config: ProfilerConfig) -> Result<()> {
635 if PROFILER.get().is_some() {
636 return Err(TorshError::InvalidState(
637 "Profiler already initialized".to_string(),
638 ));
639 }
640
641 PROFILER
642 .set(Arc::new(Mutex::new(PerformanceProfiler::new(config))))
643 .map_err(|_| TorshError::InvalidState("Failed to initialize profiler".to_string()))?;
644
645 Ok(())
646}
647
#[macro_export]
/// Profile an expression: times `$body` and records it against `$op_type`
/// with the supplied `$context` via the global profiler.
///
/// The profiler mutex is released while `$body` runs, so profiled code may
/// itself use the profiler without deadlocking.
macro_rules! profile_operation {
    ($op_type:expr, $context:expr, $body:expr) => {{
        let profiler = $crate::profiling::get_profiler();
        let handle = {
            // NOTE: panics if the profiler mutex was poisoned by a prior panic.
            let mut p = profiler.lock().unwrap();
            p.start_operation($op_type)
        };

        let result = $body;

        {
            let mut p = profiler.lock().unwrap();
            p.finish_operation(handle, $context);
        }

        result
    }};
}
668
669pub fn profile_closure<F, R>(op_type: OperationType, context: OperationContext, closure: F) -> R
671where
672 F: FnOnce() -> R,
673{
674 let profiler = get_profiler();
675 let handle = {
676 let mut p = profiler.lock().unwrap();
677 p.start_operation(op_type)
678 };
679
680 let result = closure();
681
682 {
683 let mut p = profiler.lock().unwrap();
684 p.finish_operation(handle, context);
685 }
686
687 result
688}
689
#[derive(Debug, Clone, Default)]
/// Metrics describing how efficiently a shape operation used its data.
pub struct ShapeMetrics {
    /// Number of dimensions of the tensor involved.
    pub ndim: usize,
    /// Total number of elements.
    pub numel: usize,
    /// Memory-layout efficiency in `[0, 1]` (1.0 = ideal).
    pub layout_efficiency: f64,
    /// Broadcasting cost, `>= 0` (0.0 = no broadcasting).
    pub broadcast_complexity: f64,
    /// SIMD utilization in `[0, 1]`, if measured.
    pub simd_efficiency: Option<f64>,
    /// Cache locality in `[0, 1]`, if measured.
    pub cache_locality: Option<f64>,
}
706
707impl ShapeMetrics {
708 pub fn new(ndim: usize, numel: usize) -> Self {
710 Self {
711 ndim,
712 numel,
713 layout_efficiency: 1.0, broadcast_complexity: 0.0, simd_efficiency: None,
716 cache_locality: None,
717 }
718 }
719
720 pub fn with_layout_efficiency(mut self, efficiency: f64) -> Self {
722 self.layout_efficiency = efficiency.clamp(0.0, 1.0);
723 self
724 }
725
726 pub fn with_broadcast_complexity(mut self, complexity: f64) -> Self {
728 self.broadcast_complexity = complexity.max(0.0);
729 self
730 }
731
732 pub fn with_simd_efficiency(mut self, efficiency: f64) -> Self {
734 self.simd_efficiency = Some(efficiency.clamp(0.0, 1.0));
735 self
736 }
737
738 pub fn with_cache_locality(mut self, locality: f64) -> Self {
740 self.cache_locality = Some(locality.clamp(0.0, 1.0));
741 self
742 }
743
744 pub fn performance_score(&self) -> f64 {
746 let mut score = self.layout_efficiency;
747
748 score *= 1.0 - (self.broadcast_complexity / 10.0).min(0.5);
750
751 if let Some(simd) = self.simd_efficiency {
753 score *= 1.0 + simd * 0.2;
754 }
755
756 if let Some(cache) = self.cache_locality {
758 score *= 1.0 + cache * 0.1;
759 }
760
761 score.clamp(0.0, 1.0)
762 }
763}
764
#[derive(Debug)]
/// Records and aggregates shape-operation timings, keyed by operation name.
pub struct ShapePerformanceTracker {
    // Bounded ring buffer of individual records (oldest evicted first).
    records: VecDeque<ShapeOperationRecord>,
    // Capacity of `records`.
    max_records: usize,
    // Running aggregates keyed by operation name.
    aggregates: HashMap<String, ShapeOperationAggregate>,
}
775
#[derive(Debug, Clone)]
/// One recorded shape operation.
pub struct ShapeOperationRecord {
    /// Operation name used as the aggregation key.
    pub operation: String,
    /// Wall-clock time the operation took.
    pub duration: Duration,
    /// Shape metrics supplied for this invocation.
    pub metrics: ShapeMetrics,
    /// When the record was created.
    pub timestamp: Instant,
    /// Thread that recorded the operation.
    pub thread_id: std::thread::ThreadId,
}
790
#[derive(Debug, Clone)]
/// Running aggregate over all invocations of one shape operation.
pub struct ShapeOperationAggregate {
    /// Number of recorded invocations.
    pub count: usize,
    /// Sum of all durations.
    pub total_duration: Duration,
    /// Mean duration.
    pub avg_duration: Duration,
    /// Shortest observed duration.
    pub min_duration: Duration,
    /// Longest observed duration.
    pub max_duration: Duration,
    /// Running mean of the performance score.
    pub avg_performance_score: f64,
    /// Highest observed performance score.
    pub best_performance_score: f64,
    /// Lowest observed performance score.
    pub worst_performance_score: f64,
}
811
812impl ShapePerformanceTracker {
813 pub fn new(max_records: usize) -> Self {
815 Self {
816 records: VecDeque::with_capacity(max_records),
817 max_records,
818 aggregates: HashMap::new(),
819 }
820 }
821
822 pub fn record_operation(
824 &mut self,
825 operation: String,
826 duration: Duration,
827 metrics: ShapeMetrics,
828 ) {
829 let record = ShapeOperationRecord {
830 operation: operation.clone(),
831 duration,
832 metrics: metrics.clone(),
833 timestamp: Instant::now(),
834 thread_id: std::thread::current().id(),
835 };
836
837 if self.records.len() >= self.max_records {
839 self.records.pop_front();
840 }
841 self.records.push_back(record);
842
843 let performance_score = metrics.performance_score();
845 let aggregate =
846 self.aggregates
847 .entry(operation)
848 .or_insert_with(|| ShapeOperationAggregate {
849 count: 0,
850 total_duration: Duration::ZERO,
851 avg_duration: Duration::ZERO,
852 min_duration: duration,
853 max_duration: duration,
854 avg_performance_score: performance_score,
855 best_performance_score: performance_score,
856 worst_performance_score: performance_score,
857 });
858
859 aggregate.count += 1;
860 aggregate.total_duration += duration;
861 aggregate.avg_duration = aggregate.total_duration / aggregate.count as u32;
862 aggregate.min_duration = aggregate.min_duration.min(duration);
863 aggregate.max_duration = aggregate.max_duration.max(duration);
864
865 let total_score =
867 aggregate.avg_performance_score * (aggregate.count - 1) as f64 + performance_score;
868 aggregate.avg_performance_score = total_score / aggregate.count as f64;
869 aggregate.best_performance_score = aggregate.best_performance_score.max(performance_score);
870 aggregate.worst_performance_score =
871 aggregate.worst_performance_score.min(performance_score);
872 }
873
874 pub fn get_records(&self) -> Vec<ShapeOperationRecord> {
876 self.records.iter().cloned().collect()
877 }
878
879 pub fn get_aggregates(&self) -> &HashMap<String, ShapeOperationAggregate> {
881 &self.aggregates
882 }
883
884 pub fn generate_report(&self) -> String {
886 let mut report = String::new();
887 report.push_str("=== Shape Operations Performance Report ===\n\n");
888
889 report.push_str(&format!("Total Records: {}\n", self.records.len()));
890 report.push_str(&format!("Operation Types: {}\n\n", self.aggregates.len()));
891
892 let mut sorted_ops: Vec<_> = self.aggregates.iter().collect();
894 sorted_ops.sort_by(|a, b| {
895 a.1.avg_performance_score
896 .partial_cmp(&b.1.avg_performance_score)
897 .unwrap()
898 });
899
900 report.push_str("Performance Summary (worst to best):\n");
901 for (op_name, aggregate) in sorted_ops {
902 report.push_str(&format!(
903 " {}: {:.3} avg score, {:.2}ms avg time, {} calls\n",
904 op_name,
905 aggregate.avg_performance_score,
906 aggregate.avg_duration.as_secs_f64() * 1000.0,
907 aggregate.count
908 ));
909 }
910
911 report.push_str("\nDetailed Statistics:\n");
912 for (op_name, aggregate) in &self.aggregates {
913 report.push_str(&format!("\n{op_name}:\n"));
914 report.push_str(&format!(" Count: {}\n", aggregate.count));
915 report.push_str(&format!(
916 " Avg Duration: {:.2}ms\n",
917 aggregate.avg_duration.as_secs_f64() * 1000.0
918 ));
919 report.push_str(&format!(
920 " Min Duration: {:.2}ms\n",
921 aggregate.min_duration.as_secs_f64() * 1000.0
922 ));
923 report.push_str(&format!(
924 " Max Duration: {:.2}ms\n",
925 aggregate.max_duration.as_secs_f64() * 1000.0
926 ));
927 report.push_str(&format!(
928 " Avg Performance: {:.3}\n",
929 aggregate.avg_performance_score
930 ));
931 report.push_str(&format!(
932 " Best Performance: {:.3}\n",
933 aggregate.best_performance_score
934 ));
935 report.push_str(&format!(
936 " Worst Performance: {:.3}\n",
937 aggregate.worst_performance_score
938 ));
939 }
940
941 report
942 }
943
944 pub fn find_bottlenecks(&self) -> Vec<(String, String)> {
946 let mut bottlenecks = Vec::new();
947
948 for (op_name, aggregate) in &self.aggregates {
949 if aggregate.avg_performance_score < 0.5 {
951 bottlenecks.push((
952 op_name.clone(),
953 format!(
954 "Low performance score: {:.3}",
955 aggregate.avg_performance_score
956 ),
957 ));
958 }
959
960 let duration_ratio =
962 aggregate.max_duration.as_secs_f64() / aggregate.min_duration.as_secs_f64();
963 if duration_ratio > 5.0 && aggregate.count > 10 {
964 bottlenecks.push((
965 op_name.clone(),
966 format!(
967 "High variance: {duration_ratio:.1}x difference between min/max duration"
968 ),
969 ));
970 }
971
972 if aggregate.count > 100 && aggregate.avg_duration.as_millis() > 1 {
974 bottlenecks.push((
975 op_name.clone(),
976 format!(
977 "Frequent expensive operation: {} calls, {:.2}ms avg",
978 aggregate.count,
979 aggregate.avg_duration.as_secs_f64() * 1000.0
980 ),
981 ));
982 }
983 }
984
985 bottlenecks
986 }
987
988 pub fn get_optimization_suggestions(&self) -> Vec<String> {
990 let mut suggestions = Vec::new();
991 let bottlenecks = self.find_bottlenecks();
992
993 for (op_name, issue) in bottlenecks {
994 if issue.contains("Low performance score") {
995 suggestions.push(format!(
996 "Consider optimizing {op_name} - check memory layout and broadcasting efficiency"
997 ));
998 } else if issue.contains("High variance") {
999 suggestions.push(format!(
1000 "Investigate {op_name} for inconsistent performance - possible cache/memory pressure issues"
1001 ));
1002 } else if issue.contains("Frequent expensive") {
1003 suggestions.push(format!(
1004 "Profile {op_name} for optimization opportunities - consider caching or vectorization"
1005 ));
1006 }
1007 }
1008
1009 if suggestions.is_empty() {
1010 suggestions.push("No performance issues detected - good job!".to_string());
1011 }
1012
1013 suggestions
1014 }
1015}
1016
/// Process-wide shape-operation tracker, lazily created on first use.
static SHAPE_TRACKER: OnceLock<Arc<Mutex<ShapePerformanceTracker>>> = OnceLock::new();

/// Shared handle to the global shape tracker (10k-record capacity).
pub fn get_shape_tracker() -> &'static Arc<Mutex<ShapePerformanceTracker>> {
    SHAPE_TRACKER.get_or_init(|| Arc::new(Mutex::new(ShapePerformanceTracker::new(10_000))))
}
1024
1025pub fn profile_shape_operation<F, R>(operation_name: &str, ndim: usize, numel: usize, f: F) -> R
1027where
1028 F: FnOnce() -> R,
1029{
1030 let start = Instant::now();
1031 let result = f();
1032 let duration = start.elapsed();
1033
1034 let metrics = ShapeMetrics::new(ndim, numel);
1035
1036 let tracker = get_shape_tracker();
1037 if let Ok(mut tracker) = tracker.lock() {
1038 tracker.record_operation(operation_name.to_string(), duration, metrics);
1039 }
1040
1041 result
1042}
1043
1044pub fn profile_shape_operation_with_metrics<F, R>(
1046 operation_name: &str,
1047 metrics: ShapeMetrics,
1048 f: F,
1049) -> R
1050where
1051 F: FnOnce() -> R,
1052{
1053 let start = Instant::now();
1054 let result = f();
1055 let duration = start.elapsed();
1056
1057 let tracker = get_shape_tracker();
1058 if let Ok(mut tracker) = tracker.lock() {
1059 tracker.record_operation(operation_name.to_string(), duration, metrics);
1060 }
1061
1062 result
1063}
1064
#[macro_export]
/// Profile a shape operation expression.
///
/// Two forms: `(name, ndim, numel, body)` uses default metrics;
/// `(name, metrics, body)` uses caller-supplied [`ShapeMetrics`].
macro_rules! profile_shape_op {
    ($op_name:expr, $ndim:expr, $numel:expr, $body:expr) => {
        $crate::profiling::profile_shape_operation($op_name, $ndim, $numel, || $body)
    };
    ($op_name:expr, $metrics:expr, $body:expr) => {
        $crate::profiling::profile_shape_operation_with_metrics($op_name, $metrics, || $body)
    };
}
1075
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // A fresh profiler starts empty.
    #[test]
    fn test_profiler_creation() {
        let profiler = PerformanceProfiler::new(ProfilerConfig::default());
        assert_eq!(profiler.records.len(), 0);
        assert_eq!(profiler.stats.len(), 0);
    }

    // A start/finish pair above the min-duration floor produces one record
    // and one stats entry.
    #[test]
    fn test_operation_profiling() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let op_type = OperationType::Math("add".to_string());

        let handle = profiler.start_operation(op_type.clone());
        thread::sleep(Duration::from_millis(1));

        let context = OperationContext::new()
            .with_input_size(1000)
            .with_output_size(1000);

        profiler.finish_operation(handle, context);

        assert_eq!(profiler.records.len(), 1);
        assert!(profiler.stats.contains_key(&op_type));
    }

    // Aggregate count and durations accumulate across repeated operations.
    #[test]
    fn test_profiler_statistics() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let op_type = OperationType::Math("multiply".to_string());

        for _ in 0..3 {
            let handle = profiler.start_operation(op_type.clone());
            thread::sleep(Duration::from_millis(1));

            let context = OperationContext::new()
                .with_input_size(500)
                .with_output_size(500);

            profiler.finish_operation(handle, context);
        }

        let stats = profiler.get_stats();
        let multiply_stats = stats.get(&op_type).unwrap();

        assert_eq!(multiply_stats.count, 3);
        assert!(multiply_stats.total_duration > Duration::ZERO);
        assert!(multiply_stats.avg_duration > Duration::ZERO);
    }

    // An operation dominating wall time (>5%) is flagged as a bottleneck.
    #[test]
    fn test_bottleneck_identification() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());
        let slow_op = OperationType::Math("slow_operation".to_string());
        let fast_op = OperationType::Math("fast_operation".to_string());

        let handle = profiler.start_operation(slow_op.clone());
        thread::sleep(Duration::from_millis(10));
        profiler.finish_operation(handle, OperationContext::new());

        for _ in 0..5 {
            let handle = profiler.start_operation(fast_op.clone());
            thread::sleep(Duration::from_millis(1));
            profiler.finish_operation(handle, OperationContext::new());
        }

        let bottlenecks = profiler.identify_bottlenecks();
        assert!(!bottlenecks.is_empty());

        assert!(bottlenecks.iter().any(|b| b.operation_type == slow_op));
    }

    // Nonzero bytes over a nonzero duration yields a positive bandwidth.
    #[test]
    fn test_memory_bandwidth_calculation() {
        let context = OperationContext::new()
            .with_input_size(1000)
            .with_output_size(1000);

        let duration = Duration::from_millis(1);
        let bandwidth = context.calculate_memory_bandwidth(duration);

        assert!(bandwidth.is_some());
        assert!(bandwidth.unwrap() > 0.0);
    }

    // profile_closure returns the closure's value and records it in the
    // global profiler.
    #[test]
    fn test_profile_closure() {
        let _profiler = get_profiler();

        let result = profile_closure(
            OperationType::Math("test".to_string()),
            OperationContext::new(),
            || {
                thread::sleep(Duration::from_millis(1));
                42
            },
        );

        assert_eq!(result, 42);

        let profiler = get_profiler();
        let records = {
            let p = profiler.lock().unwrap();
            p.get_records()
        };

        assert!(!records.is_empty());
    }

    // The textual report mentions the header, the op count, and the op name.
    #[test]
    fn test_profiler_report_generation() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig::default());

        let handle = profiler.start_operation(OperationType::Math("add".to_string()));
        thread::sleep(Duration::from_millis(1));
        profiler.finish_operation(handle, OperationContext::new());

        let report = profiler.generate_report();
        assert!(report.contains("Performance Profile Report"));
        assert!(report.contains("Total Operations: 1"));
        assert!(report.contains("Math::add"));
    }
}