Skip to main content

tensorlogic_compiler/
profiling.rs

//! Compilation profiling and performance tracking.
//!
//! This module provides tools for profiling the compilation process,
//! tracking performance metrics, and identifying bottlenecks.
//!
//! # Overview
//!
//! Compilation profiling helps developers:
//! - Identify slow compilation passes
//! - Track memory usage during compilation
//! - Optimize compilation performance
//! - Compare different compilation strategies
//!
//! # Features
//!
//! - **Time Tracking**: Measure time spent in each compilation phase
//! - **Memory Tracking**: Monitor memory allocations and peak usage
//! - **Pass Analysis**: Identify expensive optimization passes
//! - **Cache Statistics**: Track cache hit rates and effectiveness
//!
//! # Examples
//!
//! ```rust
//! use tensorlogic_compiler::profiling::{CompilationProfiler, ProfileConfig};
//! use tensorlogic_compiler::compile_to_einsum;
//! use tensorlogic_ir::{TLExpr, Term};
//!
//! let mut profiler = CompilationProfiler::new();
//! profiler.start_phase("compilation");
//!
//! let expr = TLExpr::pred("p", vec![Term::var("x")]);
//! let _graph = compile_to_einsum(&expr).unwrap();
//!
//! profiler.end_phase("compilation");
//!
//! let report = profiler.generate_report();
//! println!("{}", report);
//! ```
39
40use std::collections::HashMap;
41use std::time::{Duration, Instant};
42
/// Settings that select which measurements the profiler collects.
///
/// Each `track_*` switch enables one independent category of data, and
/// `min_duration_ms` sets the reporting threshold for timed phases.
#[derive(Debug, Clone)]
pub struct ProfileConfig {
    /// Whether phase timings are recorded.
    pub track_time: bool,
    /// Whether memory snapshots are recorded.
    pub track_memory: bool,
    /// Whether per-pass statistics are recorded.
    pub track_passes: bool,
    /// Whether cache lookups are counted.
    pub track_cache: bool,
    /// Shortest phase duration, in milliseconds, that is still reported;
    /// anything quicker is treated as noise.
    pub min_duration_ms: u64,
}

impl Default for ProfileConfig {
    /// Turns every tracking category on and reports phases of 1 ms or longer.
    fn default() -> Self {
        const ENABLED: bool = true;

        ProfileConfig {
            track_time: ENABLED,
            track_memory: ENABLED,
            track_passes: ENABLED,
            track_cache: ENABLED,
            min_duration_ms: 1,
        }
    }
}
69
/// Time spent in a compilation phase.
///
/// A phase can own nested child phases; see
/// [`PhaseTime::total_time_with_children`] for the aggregated duration.
#[derive(Debug, Clone)]
pub struct PhaseTime {
    /// Phase name
    pub name: String,
    /// Total duration accumulated by this phase (children excluded)
    pub duration: Duration,
    /// Number of times this phase was executed
    pub count: usize,
    /// Child phases
    pub children: Vec<PhaseTime>,
}

impl PhaseTime {
    /// Create a new phase time entry.
    ///
    /// The entry starts with an execution count of 1 and no children.
    pub fn new(name: String, duration: Duration) -> Self {
        Self {
            name,
            duration,
            count: 1,
            children: Vec::new(),
        }
    }

    /// Get average duration per execution.
    ///
    /// Returns a zero duration when `count` is 0. The division is carried out
    /// on the total nanosecond value, so it is exact for every `count`; the
    /// divisor is never narrowed to `u32`, which for large counts would
    /// truncate (and panic with a division by zero when the truncated value
    /// is 0).
    pub fn average_duration(&self) -> Duration {
        if self.count == 0 {
            return Duration::from_secs(0);
        }

        const NANOS_PER_SEC: u128 = 1_000_000_000;
        // `usize` always fits in `u128`, so this widening is lossless.
        let avg_nanos = self.duration.as_nanos() / self.count as u128;
        // The quotient cannot exceed the original duration, so the seconds
        // part fits in `u64`; the remainder is below 1e9 and fits in `u32`.
        Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Get total time including children.
    ///
    /// Adds this phase's own duration to the recursively computed totals of
    /// all of its children.
    pub fn total_time_with_children(&self) -> Duration {
        self.children
            .iter()
            .fold(self.duration, |acc, child| {
                acc + child.total_time_with_children()
            })
    }
}
112
/// Memory usage snapshot.
///
/// The counters are plain bookkeeping values updated through
/// [`MemorySnapshot::record_allocation`] and
/// [`MemorySnapshot::record_deallocation`]. Note that the derived `Default`
/// leaves `timestamp` as `None`, whereas [`MemorySnapshot::new`] stamps the
/// current instant.
#[derive(Debug, Clone, Default)]
pub struct MemorySnapshot {
    /// Timestamp of snapshot
    pub timestamp: Option<Instant>,
    /// Estimated heap usage in bytes
    pub heap_bytes: usize,
    /// Number of active allocations
    pub allocation_count: usize,
}

impl MemorySnapshot {
    /// Create a new memory snapshot.
    ///
    /// The snapshot is stamped with the current instant and starts with zero
    /// bytes and zero allocations recorded.
    pub fn new() -> Self {
        Self {
            timestamp: Some(Instant::now()),
            heap_bytes: 0,
            allocation_count: 0,
        }
    }

    /// Record an allocation.
    ///
    /// Adds `size` bytes and one allocation to the counters. Both additions
    /// saturate at `usize::MAX` rather than overflowing, mirroring the
    /// saturating subtraction used by
    /// [`MemorySnapshot::record_deallocation`].
    pub fn record_allocation(&mut self, size: usize) {
        self.heap_bytes = self.heap_bytes.saturating_add(size);
        self.allocation_count = self.allocation_count.saturating_add(1);
    }

    /// Record a deallocation.
    ///
    /// Removes `size` bytes and one allocation from the counters, clamping
    /// both at zero if more is released than was recorded.
    pub fn record_deallocation(&mut self, size: usize) {
        self.heap_bytes = self.heap_bytes.saturating_sub(size);
        self.allocation_count = self.allocation_count.saturating_sub(1);
    }
}
146
/// Pass-level profiling information.
///
/// Accumulates how often a pass ran, how long it took in total and how many
/// optimizations it applied.
#[derive(Debug, Clone)]
pub struct PassProfile {
    /// Pass name
    pub name: String,
    /// Number of times executed
    pub execution_count: usize,
    /// Total time spent
    pub total_time: Duration,
    /// Number of optimizations applied
    pub optimizations_applied: usize,
    /// Memory allocated during pass
    pub memory_allocated: usize,
}

impl PassProfile {
    /// Create a new pass profile.
    ///
    /// All counters start at zero.
    pub fn new(name: String) -> Self {
        Self {
            name,
            execution_count: 0,
            total_time: Duration::from_secs(0),
            optimizations_applied: 0,
            memory_allocated: 0,
        }
    }

    /// Record an execution of this pass.
    ///
    /// Increments the execution count and adds `duration` and
    /// `optimizations` to the running totals. `memory_allocated` is not
    /// touched here.
    pub fn record_execution(&mut self, duration: Duration, optimizations: usize) {
        self.execution_count += 1;
        self.total_time += duration;
        self.optimizations_applied += optimizations;
    }

    /// Get average time per execution.
    ///
    /// Returns a zero duration when the pass has never executed. The
    /// division is carried out on the total nanosecond value, so it is exact
    /// for every execution count; the divisor is never narrowed to `u32`,
    /// which for large counts would truncate (and panic with a division by
    /// zero when the truncated value is 0).
    pub fn average_time(&self) -> Duration {
        if self.execution_count == 0 {
            return Duration::from_secs(0);
        }

        const NANOS_PER_SEC: u128 = 1_000_000_000;
        // `usize` always fits in `u128`, so this widening is lossless.
        let avg_nanos = self.total_time.as_nanos() / self.execution_count as u128;
        // The quotient cannot exceed the original total, so the seconds part
        // fits in `u64`; the remainder is below 1e9 and fits in `u32`.
        Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Get optimizations per execution.
    ///
    /// Returns `0.0` when the pass has never executed.
    pub fn optimizations_per_execution(&self) -> f64 {
        if self.execution_count == 0 {
            0.0
        } else {
            self.optimizations_applied as f64 / self.execution_count as f64
        }
    }
}
199
/// Cache statistics.
///
/// Plain counters; `record_lookup` keeps `lookups`, `hits` and `misses`
/// in step, while `evictions` is only ever set directly by the owner.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total cache lookups
    pub lookups: usize,
    /// Cache hits
    pub hits: usize,
    /// Cache misses
    pub misses: usize,
    /// Cache evictions
    pub evictions: usize,
}

impl CacheStats {
    /// Calculate hit rate as a percentage.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        if self.lookups == 0 {
            0.0
        } else {
            (self.hits as f64 / self.lookups as f64) * 100.0
        }
    }

    /// Calculate miss rate as a percentage.
    ///
    /// Returns `0.0` when no lookups have been recorded, so an unused cache
    /// is not reported as missing 100% of the time. Otherwise the result is
    /// the complement of [`CacheStats::hit_rate`].
    pub fn miss_rate(&self) -> f64 {
        if self.lookups == 0 {
            0.0
        } else {
            100.0 - self.hit_rate()
        }
    }

    /// Record a cache lookup.
    ///
    /// Always increments `lookups`, then either `hits` or `misses` depending
    /// on `hit`.
    pub fn record_lookup(&mut self, hit: bool) {
        self.lookups += 1;
        if hit {
            self.hits += 1;
        } else {
            self.misses += 1;
        }
    }
}
238
239/// Main compilation profiler.
240pub struct CompilationProfiler {
241    config: ProfileConfig,
242    phases: Vec<PhaseTime>,
243    active_phases: Vec<(String, Instant)>,
244    memory_snapshots: Vec<MemorySnapshot>,
245    pass_profiles: HashMap<String, PassProfile>,
246    cache_stats: CacheStats,
247    start_time: Option<Instant>,
248}
249
250impl CompilationProfiler {
251    /// Create a new profiler with default configuration.
252    pub fn new() -> Self {
253        Self::with_config(ProfileConfig::default())
254    }
255
256    /// Create a new profiler with custom configuration.
257    pub fn with_config(config: ProfileConfig) -> Self {
258        Self {
259            config,
260            phases: Vec::new(),
261            active_phases: Vec::new(),
262            memory_snapshots: Vec::new(),
263            pass_profiles: HashMap::new(),
264            cache_stats: CacheStats::default(),
265            start_time: None,
266        }
267    }
268
269    /// Start overall compilation profiling.
270    pub fn start(&mut self) {
271        self.start_time = Some(Instant::now());
272        self.phases.clear();
273        self.active_phases.clear();
274    }
275
276    /// Start profiling a compilation phase.
277    pub fn start_phase(&mut self, name: &str) {
278        if !self.config.track_time {
279            return;
280        }
281
282        self.active_phases.push((name.to_string(), Instant::now()));
283    }
284
285    /// End profiling a compilation phase.
286    pub fn end_phase(&mut self, name: &str) {
287        if !self.config.track_time {
288            return;
289        }
290
291        if let Some(pos) = self.active_phases.iter().rposition(|(n, _)| n == name) {
292            let (phase_name, start_time) = self.active_phases.remove(pos);
293            let duration = start_time.elapsed();
294
295            if duration.as_millis() >= self.config.min_duration_ms as u128 {
296                self.phases.push(PhaseTime::new(phase_name, duration));
297            }
298        }
299    }
300
301    /// Record a pass execution.
302    pub fn record_pass(&mut self, pass_name: &str, duration: Duration, optimizations: usize) {
303        if !self.config.track_passes {
304            return;
305        }
306
307        let profile = self
308            .pass_profiles
309            .entry(pass_name.to_string())
310            .or_insert_with(|| PassProfile::new(pass_name.to_string()));
311
312        profile.record_execution(duration, optimizations);
313    }
314
315    /// Take a memory snapshot.
316    pub fn snapshot_memory(&mut self) {
317        if !self.config.track_memory {
318            return;
319        }
320
321        self.memory_snapshots.push(MemorySnapshot::new());
322    }
323
324    /// Record a cache lookup.
325    pub fn record_cache_lookup(&mut self, hit: bool) {
326        if !self.config.track_cache {
327            return;
328        }
329
330        self.cache_stats.record_lookup(hit);
331    }
332
333    /// Get total compilation time.
334    pub fn total_time(&self) -> Option<Duration> {
335        self.start_time.map(|start| start.elapsed())
336    }
337
338    /// Get peak memory usage.
339    pub fn peak_memory(&self) -> usize {
340        self.memory_snapshots
341            .iter()
342            .map(|s| s.heap_bytes)
343            .max()
344            .unwrap_or(0)
345    }
346
347    /// Get the slowest compilation phase.
348    pub fn slowest_phase(&self) -> Option<&PhaseTime> {
349        self.phases.iter().max_by_key(|p| p.duration)
350    }
351
352    /// Get the most expensive pass (by total time).
353    pub fn most_expensive_pass(&self) -> Option<&PassProfile> {
354        self.pass_profiles.values().max_by_key(|p| p.total_time)
355    }
356
357    /// Generate a human-readable profiling report.
358    pub fn generate_report(&self) -> String {
359        let mut report = String::new();
360
361        report.push_str("=== Compilation Profiling Report ===\n\n");
362
363        // Overall stats
364        if let Some(total) = self.total_time() {
365            report.push_str(&format!("Total Time: {:.2?}\n", total));
366        }
367
368        if self.config.track_memory {
369            report.push_str(&format!("Peak Memory: {} bytes\n", self.peak_memory()));
370        }
371
372        report.push('\n');
373
374        // Phase breakdown
375        if self.config.track_time && !self.phases.is_empty() {
376            report.push_str("=== Phase Breakdown ===\n");
377            for phase in &self.phases {
378                report.push_str(&format!(
379                    "  {}: {:.2?} ({} times, avg: {:.2?})\n",
380                    phase.name,
381                    phase.duration,
382                    phase.count,
383                    phase.average_duration()
384                ));
385            }
386            report.push('\n');
387        }
388
389        // Pass profiles
390        if self.config.track_passes && !self.pass_profiles.is_empty() {
391            report.push_str("=== Optimization Passes ===\n");
392            let mut passes: Vec<_> = self.pass_profiles.values().collect();
393            passes.sort_by_key(|p| std::cmp::Reverse(p.total_time));
394
395            for pass in passes.iter().take(10) {
396                report.push_str(&format!(
397                    "  {}: {:.2?} ({} execs, {:.1} opts/exec)\n",
398                    pass.name,
399                    pass.total_time,
400                    pass.execution_count,
401                    pass.optimizations_per_execution()
402                ));
403            }
404            report.push('\n');
405        }
406
407        // Cache statistics
408        if self.config.track_cache && self.cache_stats.lookups > 0 {
409            report.push_str("=== Cache Statistics ===\n");
410            report.push_str(&format!("  Lookups: {}\n", self.cache_stats.lookups));
411            report.push_str(&format!("  Hits: {}\n", self.cache_stats.hits));
412            report.push_str(&format!("  Misses: {}\n", self.cache_stats.misses));
413            report.push_str(&format!(
414                "  Hit Rate: {:.1}%\n",
415                self.cache_stats.hit_rate()
416            ));
417            report.push('\n');
418        }
419
420        // Recommendations
421        if let Some(slowest) = self.slowest_phase() {
422            report.push_str("=== Recommendations ===\n");
423            report.push_str(&format!(
424                "  Slowest phase: {} ({:.2?})\n",
425                slowest.name, slowest.duration
426            ));
427
428            if let Some(expensive_pass) = self.most_expensive_pass() {
429                report.push_str(&format!(
430                    "  Most expensive pass: {} ({:.2?})\n",
431                    expensive_pass.name, expensive_pass.total_time
432                ));
433            }
434
435            if self.config.track_cache && self.cache_stats.hit_rate() < 50.0 {
436                report.push_str("  Consider increasing cache size (low hit rate)\n");
437            }
438        }
439
440        report
441    }
442
443    /// Generate JSON profiling report.
444    pub fn generate_json_report(&self) -> String {
445        // Simple JSON serialization
446        let mut json = String::from("{\n");
447
448        if let Some(total) = self.total_time() {
449            json.push_str(&format!("  \"total_time_ms\": {},\n", total.as_millis()));
450        }
451
452        json.push_str(&format!(
453            "  \"peak_memory_bytes\": {},\n",
454            self.peak_memory()
455        ));
456
457        // Phases
458        json.push_str("  \"phases\": [\n");
459        for (i, phase) in self.phases.iter().enumerate() {
460            json.push_str(&format!(
461                "    {{\"name\": \"{}\", \"duration_ms\": {}, \"count\": {}}}",
462                phase.name,
463                phase.duration.as_millis(),
464                phase.count
465            ));
466            if i < self.phases.len() - 1 {
467                json.push(',');
468            }
469            json.push('\n');
470        }
471        json.push_str("  ],\n");
472
473        // Cache stats
474        json.push_str("  \"cache\": {\n");
475        json.push_str(&format!("    \"lookups\": {},\n", self.cache_stats.lookups));
476        json.push_str(&format!("    \"hits\": {},\n", self.cache_stats.hits));
477        json.push_str(&format!(
478            "    \"hit_rate\": {:.2}\n",
479            self.cache_stats.hit_rate()
480        ));
481        json.push_str("  }\n");
482
483        json.push_str("}\n");
484        json
485    }
486}
487
488impl Default for CompilationProfiler {
489    fn default() -> Self {
490        Self::new()
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Blocks the current thread for `ms` milliseconds.
    fn pause(ms: u64) {
        thread::sleep(Duration::from_millis(ms));
    }

    /// A phase longer than the default threshold ends up in the phase list.
    #[test]
    fn test_profiler_basic() {
        let mut prof = CompilationProfiler::new();
        prof.start();

        prof.start_phase("test_phase");
        pause(10);
        prof.end_phase("test_phase");

        assert!(!prof.phases.is_empty());
    }

    /// A new `PhaseTime` has count 1, so its average equals its duration.
    #[test]
    fn test_phase_time() {
        let one_second = Duration::from_secs(1);
        let entry = PhaseTime::new("test".to_string(), one_second);

        assert_eq!(entry.name, "test");
        assert_eq!(entry.count, 1);
        assert_eq!(entry.average_duration(), Duration::from_secs(1));
    }

    /// Allocation and deallocation bookkeeping on a snapshot.
    #[test]
    fn test_memory_snapshot() {
        let mut snap = MemorySnapshot::new();
        for size in [1000, 500] {
            snap.record_allocation(size);
        }

        assert_eq!(snap.heap_bytes, 1500);
        assert_eq!(snap.allocation_count, 2);

        snap.record_deallocation(500);
        assert_eq!(snap.heap_bytes, 1000);
        assert_eq!(snap.allocation_count, 1);
    }

    /// Two recorded executions are summed and averaged.
    #[test]
    fn test_pass_profile() {
        let mut pass = PassProfile::new("constant_folding".to_string());
        pass.record_execution(Duration::from_millis(10), 5);
        pass.record_execution(Duration::from_millis(15), 3);

        assert_eq!(pass.execution_count, 2);
        assert_eq!(pass.optimizations_applied, 8);
        assert!(pass.average_time().as_millis() >= 10);
        assert_eq!(pass.optimizations_per_execution(), 4.0);
    }

    /// Two hits and one miss give a hit rate of roughly 66.67%.
    #[test]
    fn test_cache_stats() {
        let mut cache = CacheStats::default();
        // hit, hit, miss
        for outcome in [true, true, false] {
            cache.record_lookup(outcome);
        }

        assert_eq!(cache.lookups, 3);
        assert_eq!(cache.hits, 2);
        assert_eq!(cache.misses, 1);
        assert!((cache.hit_rate() - 66.67).abs() < 0.1);
    }

    /// The text report carries its banner and, after `start`, the total time.
    #[test]
    fn test_generate_report() {
        let mut prof = CompilationProfiler::new();
        prof.start();
        prof.start_phase("compilation");
        pause(10);
        prof.end_phase("compilation");

        let text = prof.generate_report();
        assert!(text.contains("Compilation Profiling Report"));
        assert!(text.contains("Total Time"));
    }

    /// Of two sequential phases, the one that slept longer is the slowest.
    #[test]
    fn test_slowest_phase() {
        let mut prof = CompilationProfiler::new();
        prof.start();

        prof.start_phase("fast");
        pause(5);
        prof.end_phase("fast");

        prof.start_phase("slow");
        pause(20);
        prof.end_phase("slow");

        let winner = prof.slowest_phase().unwrap();
        assert_eq!(winner.name, "slow");
    }

    /// The pass with the larger total time is reported as most expensive.
    #[test]
    fn test_most_expensive_pass() {
        let mut prof = CompilationProfiler::new();
        prof.record_pass("pass1", Duration::from_millis(10), 5);
        prof.record_pass("pass2", Duration::from_millis(50), 10);

        let costly = prof.most_expensive_pass().unwrap();
        assert_eq!(costly.name, "pass2");
    }

    /// The JSON report exposes the total time and the cache section.
    #[test]
    fn test_json_report() {
        let mut prof = CompilationProfiler::new();
        prof.start();
        prof.record_cache_lookup(true);
        prof.record_cache_lookup(false);

        let rendered = prof.generate_json_report();
        assert!(rendered.contains("total_time_ms"));
        assert!(rendered.contains("cache"));
        assert!(rendered.contains("hit_rate"));
    }

    /// Phases shorter than `min_duration_ms` are dropped.
    #[test]
    fn test_config_filtering() {
        let strict = ProfileConfig {
            track_time: true,
            track_memory: false,
            track_passes: true,
            track_cache: false,
            min_duration_ms: 100,
        };

        let mut prof = CompilationProfiler::with_config(strict);
        prof.start();

        // Short phase should be filtered out
        prof.start_phase("short");
        pause(1);
        prof.end_phase("short");

        assert!(prof.phases.is_empty());
    }

    /// An inner phase closed before its outer phase yields two entries.
    #[test]
    fn test_nested_phases() {
        let mut prof = CompilationProfiler::new();
        prof.start();

        prof.start_phase("outer");
        pause(5);

        prof.start_phase("inner");
        pause(5);
        prof.end_phase("inner");

        prof.end_phase("outer");

        assert_eq!(prof.phases.len(), 2);
    }
}