// tensorlogic_infer/src/jit.rs
//! Just-In-Time (JIT) compilation infrastructure.
//!
//! This module provides runtime compilation and adaptive optimization capabilities:
//! - `JitCompiler`: Runtime compilation with hot path detection
//! - `JitCache`: Specialized caching for JIT-compiled graphs
//! - `HotPathDetector`: Identifies frequently executed code paths
//! - `AdaptiveOptimizer`: Progressively optimizes based on runtime profiling
//! - `TlJitExecutor`: Trait for executors that support JIT compilation
//!
//! # JIT Compilation Workflow
//!
//! 1. **First Execution**: Graph is compiled with minimal optimization
//! 2. **Profiling**: Runtime characteristics are collected
//! 3. **Hot Path Detection**: Frequently executed paths are identified
//! 4. **Adaptive Optimization**: Hot paths are recompiled with aggressive optimization
//! 5. **Specialization**: Graphs are specialized for observed shapes/types
//!
//! # Example
//!
//! ```
//! use tensorlogic_infer::jit::{JitCompiler, JitConfig};
//! use tensorlogic_ir::EinsumGraph;
//!
//! let mut jit = JitCompiler::new(JitConfig::default());
//! let graph = EinsumGraph::new();
//!
//! // First execution: minimal compilation
//! let compiled = jit.compile_or_retrieve(&graph, &[]).unwrap();
//!
//! // After profiling, hot paths are recompiled with aggressive optimization
//! jit.optimize_hot_paths();
//! ```

use crate::compilation::{CompilationConfig, CompiledGraph, GraphCompiler, OptimizationLevel};
use crate::error::ExecutorError;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tensorlogic_ir::EinsumGraph;

43/// Configuration for JIT compilation.
44#[derive(Debug, Clone)]
45pub struct JitConfig {
46    /// Initial optimization level for first compilation
47    pub initial_optimization: OptimizationLevel,
48    /// Hot path optimization level
49    pub hot_path_optimization: OptimizationLevel,
50    /// Minimum execution count to consider a path "hot"
51    pub hot_path_threshold: usize,
52    /// Enable shape specialization
53    pub enable_specialization: bool,
54    /// Maximum number of specialized versions per graph
55    pub max_specializations: usize,
56    /// Enable adaptive optimization
57    pub enable_adaptive_optimization: bool,
58    /// Profiling window size for hot path detection
59    pub profiling_window: usize,
60    /// Cache size limit (number of compiled graphs)
61    pub cache_size: usize,
62    /// Enable deoptimization for rarely used paths
63    pub enable_deoptimization: bool,
64    /// Threshold for deoptimization (executions per time window)
65    pub deoptimization_threshold: usize,
66}
67
68impl Default for JitConfig {
69    fn default() -> Self {
70        JitConfig {
71            initial_optimization: OptimizationLevel::Basic,
72            hot_path_optimization: OptimizationLevel::Aggressive,
73            hot_path_threshold: 10,
74            enable_specialization: true,
75            max_specializations: 5,
76            enable_adaptive_optimization: true,
77            profiling_window: 100,
78            cache_size: 1000,
79            enable_deoptimization: true,
80            deoptimization_threshold: 1,
81        }
82    }
83}
84
85/// Key for identifying graphs and their specializations.
86#[derive(Debug, Clone, PartialEq, Eq, Hash)]
87pub struct JitKey {
88    /// Hash of the graph structure
89    pub graph_hash: u64,
90    /// Specialization context (shapes, if enabled)
91    pub specialization: Option<SpecializationContext>,
92}
93
/// Shape/device context a compiled graph was specialized for.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationContext {
    /// Static dimensions of each input, one inner `Vec` per input tensor.
    pub input_shapes: Vec<Vec<usize>>,
    /// Optional device target (e.g. a device identifier string).
    pub device: Option<String>,
}

103impl SpecializationContext {
104    /// Create a new specialization context from input shapes.
105    pub fn from_shapes(shapes: &[TensorShape]) -> Self {
106        SpecializationContext {
107            input_shapes: shapes
108                .iter()
109                .map(|s| {
110                    s.dims
111                        .iter()
112                        .filter_map(|d| d.as_static())
113                        .collect::<Vec<_>>()
114                })
115                .collect(),
116            device: None,
117        }
118    }
119
120    /// Create a context with device specification.
121    pub fn with_device(mut self, device: String) -> Self {
122        self.device = Some(device);
123        self
124    }
125}
126
127/// Statistics for a compiled graph in the JIT cache.
128#[derive(Debug, Clone)]
129pub struct JitEntryStats {
130    /// Number of times this compiled version has been executed
131    pub execution_count: usize,
132    /// Total execution time for this version
133    pub total_execution_time: Duration,
134    /// Average execution time
135    pub avg_execution_time: Duration,
136    /// Optimization level used
137    pub optimization_level: OptimizationLevel,
138    /// Timestamp of last execution
139    pub last_executed: Instant,
140    /// Timestamp when compiled
141    pub compiled_at: Instant,
142    /// Whether this is a specialized version
143    pub is_specialized: bool,
144}
145
146impl Default for JitEntryStats {
147    fn default() -> Self {
148        JitEntryStats {
149            execution_count: 0,
150            total_execution_time: Duration::from_secs(0),
151            avg_execution_time: Duration::from_secs(0),
152            optimization_level: OptimizationLevel::Basic,
153            last_executed: Instant::now(),
154            compiled_at: Instant::now(),
155            is_specialized: false,
156        }
157    }
158}
159
160impl JitEntryStats {
161    /// Record an execution of this compiled graph.
162    pub fn record_execution(&mut self, duration: Duration) {
163        self.execution_count += 1;
164        self.total_execution_time += duration;
165        self.avg_execution_time = self.total_execution_time / self.execution_count as u32;
166        self.last_executed = Instant::now();
167    }
168
169    /// Check if this entry is "hot" based on execution count.
170    pub fn is_hot(&self, threshold: usize) -> bool {
171        self.execution_count >= threshold
172    }
173
174    /// Check if this entry is cold (rarely used).
175    pub fn is_cold(&self, threshold: usize, window: Duration) -> bool {
176        let time_since_last = Instant::now().duration_since(self.last_executed);
177        time_since_last > window && self.execution_count < threshold
178    }
179}
180
181/// Entry in the JIT cache.
182#[derive(Debug, Clone)]
183pub struct JitCacheEntry {
184    /// The compiled graph
185    pub compiled: CompiledGraph,
186    /// Statistics for this entry
187    pub stats: JitEntryStats,
188}
189
190/// Cache for JIT-compiled graphs with profiling support.
191pub struct JitCache {
192    cache: Arc<RwLock<HashMap<JitKey, JitCacheEntry>>>,
193    config: JitConfig,
194}
195
196impl JitCache {
197    /// Create a new JIT cache.
198    pub fn new(config: JitConfig) -> Self {
199        JitCache {
200            cache: Arc::new(RwLock::new(HashMap::new())),
201            config,
202        }
203    }
204
205    /// Insert a compiled graph into the cache.
206    pub fn insert(&self, key: JitKey, compiled: CompiledGraph, is_specialized: bool) {
207        let mut cache = self.cache.write().expect("lock should not be poisoned");
208
209        // Evict old entries if cache is full
210        if cache.len() >= self.config.cache_size {
211            self.evict_lru(&mut cache);
212        }
213
214        let stats = JitEntryStats {
215            optimization_level: compiled.config.optimization_level,
216            is_specialized,
217            ..Default::default()
218        };
219
220        cache.insert(key, JitCacheEntry { compiled, stats });
221    }
222
223    /// Retrieve a compiled graph from the cache.
224    pub fn get(&self, key: &JitKey) -> Option<CompiledGraph> {
225        let cache = self.cache.read().expect("lock should not be poisoned");
226        cache.get(key).map(|entry| entry.compiled.clone())
227    }
228
229    /// Record an execution of a cached graph.
230    pub fn record_execution(&self, key: &JitKey, duration: Duration) {
231        let mut cache = self.cache.write().expect("lock should not be poisoned");
232        if let Some(entry) = cache.get_mut(key) {
233            entry.stats.record_execution(duration);
234        }
235    }
236
237    /// Get statistics for a cached entry.
238    pub fn get_stats(&self, key: &JitKey) -> Option<JitEntryStats> {
239        let cache = self.cache.read().expect("lock should not be poisoned");
240        cache.get(key).map(|entry| entry.stats.clone())
241    }
242
243    /// Get all hot paths (frequently executed graphs).
244    pub fn get_hot_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
245        let cache = self.cache.read().expect("lock should not be poisoned");
246        cache
247            .iter()
248            .filter(|(_, entry)| entry.stats.is_hot(self.config.hot_path_threshold))
249            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
250            .collect()
251    }
252
253    /// Get all cold paths (rarely executed graphs).
254    pub fn get_cold_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
255        let cache = self.cache.read().expect("lock should not be poisoned");
256        let window = Duration::from_secs(300); // 5 minutes
257        cache
258            .iter()
259            .filter(|(_, entry)| {
260                entry
261                    .stats
262                    .is_cold(self.config.deoptimization_threshold, window)
263            })
264            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
265            .collect()
266    }
267
268    /// Evict least recently used entry.
269    fn evict_lru(&self, cache: &mut HashMap<JitKey, JitCacheEntry>) {
270        if let Some((key, _)) = cache
271            .iter()
272            .min_by_key(|(_, entry)| entry.stats.last_executed)
273        {
274            let key = key.clone();
275            cache.remove(&key);
276        }
277    }
278
279    /// Clear the cache.
280    pub fn clear(&self) {
281        let mut cache = self.cache.write().expect("lock should not be poisoned");
282        cache.clear();
283    }
284
285    /// Get cache statistics.
286    pub fn cache_stats(&self) -> JitCacheStats {
287        let cache = self.cache.read().expect("lock should not be poisoned");
288        let total_entries = cache.len();
289        let hot_entries = cache
290            .values()
291            .filter(|e| e.stats.is_hot(self.config.hot_path_threshold))
292            .count();
293        let specialized_entries = cache.values().filter(|e| e.stats.is_specialized).count();
294        let total_executions = cache.values().map(|e| e.stats.execution_count).sum();
295
296        JitCacheStats {
297            total_entries,
298            hot_entries,
299            specialized_entries,
300            total_executions,
301            cache_capacity: self.config.cache_size,
302        }
303    }
304}
305
/// Aggregate counters describing the JIT cache at a point in time.
#[derive(Debug, Clone)]
pub struct JitCacheStats {
    /// Entries currently held in the cache.
    pub total_entries: usize,
    /// Entries that have crossed the hot-path threshold.
    pub hot_entries: usize,
    /// Entries that are shape-specialized versions.
    pub specialized_entries: usize,
    /// Sum of execution counts across all entries.
    pub total_executions: usize,
    /// Configured maximum number of entries.
    pub cache_capacity: usize,
}

321/// Hot path detector that identifies frequently executed code paths.
322pub struct HotPathDetector {
323    config: JitConfig,
324}
325
326impl HotPathDetector {
327    /// Create a new hot path detector.
328    pub fn new(config: JitConfig) -> Self {
329        HotPathDetector { config }
330    }
331
332    /// Detect hot paths from cache statistics.
333    pub fn detect_hot_paths(&self, cache: &JitCache) -> Vec<JitKey> {
334        cache
335            .get_hot_paths()
336            .into_iter()
337            .map(|(key, _)| key)
338            .collect()
339    }
340
341    /// Recommend recompilation for hot paths.
342    pub fn recommend_recompilation(&self, cache: &JitCache) -> Vec<(JitKey, OptimizationLevel)> {
343        cache
344            .get_hot_paths()
345            .into_iter()
346            .filter_map(|(key, stats)| {
347                // Only recommend recompilation if current optimization is below hot path level
348                if stats.optimization_level < self.config.hot_path_optimization {
349                    Some((key, self.config.hot_path_optimization))
350                } else {
351                    None
352                }
353            })
354            .collect()
355    }
356
357    /// Recommend deoptimization for cold paths.
358    pub fn recommend_deoptimization(&self, cache: &JitCache) -> Vec<JitKey> {
359        if !self.config.enable_deoptimization {
360            return Vec::new();
361        }
362
363        cache
364            .get_cold_paths()
365            .into_iter()
366            .map(|(key, _)| key)
367            .collect()
368    }
369}
370
371/// Adaptive optimizer that progressively optimizes based on runtime profiling.
372pub struct AdaptiveOptimizer {
373    config: JitConfig,
374    hot_path_detector: HotPathDetector,
375}
376
377impl AdaptiveOptimizer {
378    /// Create a new adaptive optimizer.
379    pub fn new(config: JitConfig) -> Self {
380        AdaptiveOptimizer {
381            hot_path_detector: HotPathDetector::new(config.clone()),
382            config,
383        }
384    }
385
386    /// Analyze runtime behavior and recommend optimizations.
387    pub fn analyze_and_recommend(&self, cache: &JitCache) -> AdaptiveOptimizationPlan {
388        let hot_paths = self.hot_path_detector.recommend_recompilation(cache);
389        let cold_paths = self.hot_path_detector.recommend_deoptimization(cache);
390
391        AdaptiveOptimizationPlan {
392            recompile: hot_paths,
393            deoptimize: cold_paths,
394        }
395    }
396
397    /// Apply adaptive optimizations to the cache.
398    pub fn optimize(&self, cache: &JitCache) -> Result<usize, ExecutorError> {
399        let plan = self.analyze_and_recommend(cache);
400        let mut optimized_count = 0;
401
402        // Recompile hot paths with aggressive optimization
403        for (key, opt_level) in plan.recompile {
404            if let Some(entry) = cache
405                .cache
406                .read()
407                .expect("lock should not be poisoned")
408                .get(&key)
409            {
410                let graph = &entry.compiled.graph;
411                let mut config = entry.compiled.config.clone();
412                config.optimization_level = opt_level;
413
414                let mut new_compiler = GraphCompiler::new(config);
415                let recompiled = new_compiler.compile(graph)?;
416
417                // Update cache with recompiled version
418                cache
419                    .cache
420                    .write()
421                    .expect("lock should not be poisoned")
422                    .get_mut(&key)
423                    .expect("key just retrieved from cache")
424                    .compiled = recompiled;
425                optimized_count += 1;
426            }
427        }
428
429        // Deoptimize cold paths (remove from cache or downgrade)
430        for key in plan.deoptimize {
431            cache
432                .cache
433                .write()
434                .expect("lock should not be poisoned")
435                .remove(&key);
436        }
437
438        Ok(optimized_count)
439    }
440
441    /// Get the JIT configuration.
442    pub fn config(&self) -> &JitConfig {
443        &self.config
444    }
445
446    /// Get the hot path detector.
447    pub fn hot_path_detector(&self) -> &HotPathDetector {
448        &self.hot_path_detector
449    }
450}
451
452/// Plan for adaptive optimization.
453#[derive(Debug, Clone)]
454pub struct AdaptiveOptimizationPlan {
455    /// Graphs to recompile with higher optimization
456    pub recompile: Vec<(JitKey, OptimizationLevel)>,
457    /// Graphs to deoptimize (remove or downgrade)
458    pub deoptimize: Vec<JitKey>,
459}
460
461/// JIT compiler with runtime compilation and adaptive optimization.
462pub struct JitCompiler {
463    config: JitConfig,
464    cache: JitCache,
465    adaptive_optimizer: AdaptiveOptimizer,
466}
467
468impl JitCompiler {
469    /// Create a new JIT compiler.
470    pub fn new(config: JitConfig) -> Self {
471        JitCompiler {
472            cache: JitCache::new(config.clone()),
473            adaptive_optimizer: AdaptiveOptimizer::new(config.clone()),
474            config,
475        }
476    }
477
478    /// Create a JIT compiler with default configuration.
479    pub fn with_default_config() -> Self {
480        Self::new(JitConfig::default())
481    }
482
483    /// Compile a graph or retrieve from cache.
484    pub fn compile_or_retrieve(
485        &mut self,
486        graph: &EinsumGraph,
487        input_shapes: &[TensorShape],
488    ) -> Result<CompiledGraph, ExecutorError> {
489        let key = self.create_key(graph, input_shapes);
490
491        // Check cache first
492        if let Some(compiled) = self.cache.get(&key) {
493            return Ok(compiled);
494        }
495
496        // Compile with initial optimization level
497        let config = CompilationConfig {
498            optimization_level: self.config.initial_optimization,
499            enable_shape_inference: true,
500            enable_memory_estimation: true,
501            enable_caching: true,
502            enable_parallelism: true,
503            ..Default::default()
504        };
505
506        let mut compiler = GraphCompiler::new(config);
507        let compiled = compiler.compile(graph)?;
508
509        // Cache the compiled graph
510        let is_specialized = self.config.enable_specialization && !input_shapes.is_empty();
511        self.cache.insert(key, compiled.clone(), is_specialized);
512
513        Ok(compiled)
514    }
515
516    /// Record execution of a compiled graph.
517    pub fn record_execution(
518        &self,
519        graph: &EinsumGraph,
520        input_shapes: &[TensorShape],
521        duration: Duration,
522    ) {
523        let key = self.create_key(graph, input_shapes);
524        self.cache.record_execution(&key, duration);
525    }
526
527    /// Optimize hot paths based on profiling data.
528    pub fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
529        if !self.config.enable_adaptive_optimization {
530            return Ok(0);
531        }
532
533        self.adaptive_optimizer.optimize(&self.cache)
534    }
535
536    /// Get JIT cache statistics.
537    pub fn cache_stats(&self) -> JitCacheStats {
538        self.cache.cache_stats()
539    }
540
541    /// Clear the JIT cache.
542    pub fn clear_cache(&self) {
543        self.cache.clear();
544    }
545
546    /// Create a cache key for the graph.
547    fn create_key(&self, graph: &EinsumGraph, input_shapes: &[TensorShape]) -> JitKey {
548        let graph_hash = self.hash_graph(graph);
549        let specialization = if self.config.enable_specialization && !input_shapes.is_empty() {
550            Some(SpecializationContext::from_shapes(input_shapes))
551        } else {
552            None
553        };
554
555        JitKey {
556            graph_hash,
557            specialization,
558        }
559    }
560
561    /// Hash a graph for caching.
562    fn hash_graph(&self, graph: &EinsumGraph) -> u64 {
563        use std::collections::hash_map::DefaultHasher;
564        let mut hasher = DefaultHasher::new();
565        graph.nodes.len().hash(&mut hasher);
566        // Simple hash based on node count and structure
567        // In production, would use more sophisticated hashing
568        hasher.finish()
569    }
570}
571
572/// Trait for executors that support JIT compilation.
573pub trait TlJitExecutor {
574    /// Get the JIT compiler for this executor.
575    fn jit_compiler(&mut self) -> &mut JitCompiler;
576
577    /// Enable JIT compilation.
578    fn enable_jit(&mut self);
579
580    /// Disable JIT compilation.
581    fn disable_jit(&mut self);
582
583    /// Check if JIT is enabled.
584    fn is_jit_enabled(&self) -> bool;
585
586    /// Trigger adaptive optimization of hot paths.
587    fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
588        self.jit_compiler().optimize_hot_paths()
589    }
590
591    /// Get JIT statistics.
592    fn jit_stats(&self) -> JitCacheStats;
593}
594
/// Counters describing overall JIT compilation performance.
#[derive(Debug, Clone)]
pub struct JitStats {
    /// Compilations performed in total.
    pub total_compilations: usize,
    /// Lookups served from the cache.
    pub cache_hits: usize,
    /// Lookups that required a fresh compilation.
    pub cache_misses: usize,
    /// Recompilations triggered by hot-path optimization.
    pub recompilations: usize,
    /// Entries removed by deoptimization.
    pub deoptimizations: usize,
    /// Mean time spent per compilation.
    pub avg_compilation_time: Duration,
    /// Estimated total time saved by cache hits.
    pub total_time_saved: Duration,
}

614impl Default for JitStats {
615    fn default() -> Self {
616        JitStats {
617            total_compilations: 0,
618            cache_hits: 0,
619            cache_misses: 0,
620            recompilations: 0,
621            deoptimizations: 0,
622            avg_compilation_time: Duration::from_secs(0),
623            total_time_saved: Duration::from_secs(0),
624        }
625    }
626}
627
628impl JitStats {
629    /// Calculate cache hit rate.
630    pub fn cache_hit_rate(&self) -> f64 {
631        if self.cache_hits + self.cache_misses == 0 {
632            return 0.0;
633        }
634        self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
635    }
636
637    /// Get a summary of JIT statistics.
638    pub fn summary(&self) -> String {
639        format!(
640            "JIT Stats: {} compilations, {:.1}% cache hit rate, {} recompilations, {:.2}ms avg compile time",
641            self.total_compilations,
642            self.cache_hit_rate() * 100.0,
643            self.recompilations,
644            self.avg_compilation_time.as_secs_f64() * 1000.0
645        )
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `CompiledGraph` over an empty einsum graph, suitable
    /// for exercising cache plumbing without a real compilation pipeline.
    fn empty_compiled_graph() -> CompiledGraph {
        CompiledGraph {
            graph: EinsumGraph::new(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        }
    }

    /// Shorthand for a cache key with no shape specialization.
    fn unspecialized_key(graph_hash: u64) -> JitKey {
        JitKey {
            graph_hash,
            specialization: None,
        }
    }

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert_eq!(config.initial_optimization, OptimizationLevel::Basic);
        assert_eq!(config.hot_path_optimization, OptimizationLevel::Aggressive);
        assert_eq!(config.hot_path_threshold, 10);
        assert!(config.enable_specialization);
        assert!(config.enable_adaptive_optimization);
    }

    #[test]
    fn test_specialization_context() {
        let shapes = [
            TensorShape::static_shape(vec![2, 3]),
            TensorShape::static_shape(vec![3, 4]),
        ];
        let ctx = SpecializationContext::from_shapes(&shapes);
        assert_eq!(ctx.input_shapes.len(), 2);
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
        assert_eq!(ctx.input_shapes[1], vec![3, 4]);
    }

    #[test]
    fn test_jit_entry_stats() {
        let mut stats = JitEntryStats::default();
        assert_eq!(stats.execution_count, 0);
        assert!(!stats.is_hot(10));

        for _ in 0..15 {
            stats.record_execution(Duration::from_millis(10));
        }

        assert_eq!(stats.execution_count, 15);
        assert!(stats.is_hot(10));
        assert_eq!(stats.total_execution_time, Duration::from_millis(150));
    }

    #[test]
    fn test_jit_cache_insert_retrieve() {
        let cache = JitCache::new(JitConfig::default());
        let key = unspecialized_key(12345);

        cache.insert(key.clone(), empty_compiled_graph(), false);
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_jit_cache_eviction() {
        // A two-entry cache forces the third insert to evict the oldest.
        let cache = JitCache::new(JitConfig {
            cache_size: 2,
            ..Default::default()
        });

        for i in 0..3 {
            cache.insert(unspecialized_key(i), empty_compiled_graph(), false);
            std::thread::sleep(Duration::from_millis(10)); // distinct timestamps
        }

        assert_eq!(cache.cache_stats().total_entries, 2);
    }

    #[test]
    fn test_hot_path_detection() {
        let config = JitConfig::default();
        let cache = JitCache::new(config.clone());
        let detector = HotPathDetector::new(config);
        let key = unspecialized_key(123);

        cache.insert(key.clone(), empty_compiled_graph(), false);

        // Cross the default hot-path threshold of 10 executions.
        for _ in 0..15 {
            cache.record_execution(&key, Duration::from_millis(10));
        }

        let hot_paths = detector.detect_hot_paths(&cache);
        assert_eq!(hot_paths.len(), 1);
        assert_eq!(hot_paths[0].graph_hash, 123);
    }

    #[test]
    fn test_jit_compiler_basic() {
        let mut jit = JitCompiler::with_default_config();
        let graph = EinsumGraph::new();

        assert!(jit.compile_or_retrieve(&graph, &[]).is_ok());
        // Second call should be served from the cache.
        assert!(jit.compile_or_retrieve(&graph, &[]).is_ok());
    }

    #[test]
    fn test_jit_stats() {
        assert_eq!(JitStats::default().cache_hit_rate(), 0.0);

        let stats = JitStats {
            cache_hits: 8,
            cache_misses: 2,
            ..Default::default()
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    #[test]
    fn test_adaptive_optimization_plan() {
        let plan = AdaptiveOptimizationPlan {
            recompile: vec![(unspecialized_key(123), OptimizationLevel::Aggressive)],
            deoptimize: vec![],
        };

        assert_eq!(plan.recompile.len(), 1);
        assert_eq!(plan.deoptimize.len(), 0);
    }

    #[test]
    fn test_jit_cache_stats() {
        let stats = JitCache::new(JitConfig::default()).cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.hot_entries, 0);
        assert_eq!(stats.total_executions, 0);
    }

    #[test]
    fn test_specialization_with_device() {
        let shapes = [TensorShape::static_shape(vec![2, 3])];
        let ctx = SpecializationContext::from_shapes(&shapes).with_device("cuda:0".to_string());

        assert_eq!(ctx.device, Some("cuda:0".to_string()));
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
    }

    #[test]
    fn test_jit_entry_cold_detection() {
        let mut stats = JitEntryStats::default();
        stats.record_execution(Duration::from_millis(10));

        // Freshly executed: not cold yet.
        assert!(!stats.is_cold(5, Duration::from_millis(100)));

        // After the idle window passes, it becomes cold.
        std::thread::sleep(Duration::from_millis(150));
        assert!(stats.is_cold(5, Duration::from_millis(100)));
    }

    #[test]
    fn test_jit_cache_clear() {
        let cache = JitCache::new(JitConfig::default());
        let key = unspecialized_key(123);

        cache.insert(key, empty_compiled_graph(), false);
        assert_eq!(cache.cache_stats().total_entries, 1);

        cache.clear();
        assert_eq!(cache.cache_stats().total_entries, 0);
    }
}