// tensorlogic_infer/jit.rs
//! Just-In-Time (JIT) compilation infrastructure.
//!
//! This module provides runtime compilation and adaptive optimization capabilities:
//! - `JitCompiler`: Runtime compilation with hot path detection
//! - `JitCache`: Specialized caching for JIT-compiled graphs
//! - `HotPathDetector`: Identifies frequently executed code paths
//! - `AdaptiveOptimizer`: Progressively optimizes based on runtime profiling
//! - `TlJitExecutor`: Trait for executors that support JIT compilation
//!
//! # JIT Compilation Workflow
//!
//! 1. **First Execution**: Graph is compiled with minimal optimization
//! 2. **Profiling**: Runtime characteristics are collected
//! 3. **Hot Path Detection**: Frequently executed paths are identified
//! 4. **Adaptive Optimization**: Hot paths are recompiled with aggressive optimization
//! 5. **Specialization**: Graphs are specialized for observed shapes/types
//!
//! # Example
//!
//! ```
//! use tensorlogic_infer::jit::{JitCompiler, JitConfig};
//! use tensorlogic_ir::EinsumGraph;
//!
//! let mut jit = JitCompiler::new(JitConfig::default());
//! let graph = EinsumGraph::new();
//!
//! // First execution: minimal compilation
//! let compiled = jit.compile_or_retrieve(&graph, &[]).unwrap();
//!
//! // After profiling, hot paths are recompiled with aggressive optimization
//! jit.optimize_hot_paths();
//! ```
34use crate::compilation::{CompilationConfig, CompiledGraph, GraphCompiler, OptimizationLevel};
35use crate::error::ExecutorError;
36use crate::shape::TensorShape;
37use std::collections::HashMap;
38use std::hash::{Hash, Hasher};
39use std::sync::{Arc, RwLock};
40use std::time::{Duration, Instant};
41use tensorlogic_ir::EinsumGraph;
42
/// Configuration for JIT compilation.
///
/// A single value is cloned into the [`JitCache`], [`HotPathDetector`],
/// and [`AdaptiveOptimizer`], so all components share the same knobs.
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Optimization level used the first time a graph is compiled
    /// (kept cheap so cold code starts executing quickly).
    pub initial_optimization: OptimizationLevel,
    /// Optimization level applied when a hot path is recompiled.
    pub hot_path_optimization: OptimizationLevel,
    /// Minimum execution count for a cached entry to be considered "hot".
    pub hot_path_threshold: usize,
    /// Enable shape specialization (separate cache entries per set of
    /// static input shapes).
    pub enable_specialization: bool,
    /// Maximum number of specialized versions per graph.
    /// NOTE(review): not enforced anywhere in this module — confirm intended use.
    pub max_specializations: usize,
    /// Enable adaptive optimization (hot-path recompilation and
    /// cold-path deoptimization).
    pub enable_adaptive_optimization: bool,
    /// Profiling window size for hot path detection.
    /// NOTE(review): not read by the detector in this module — confirm intended use.
    pub profiling_window: usize,
    /// Cache size limit (number of compiled graphs).
    pub cache_size: usize,
    /// Enable deoptimization (eviction) for rarely used paths.
    pub enable_deoptimization: bool,
    /// Execution-count bound used by `JitEntryStats::is_cold` (entries below
    /// this count that have been idle for a while are considered cold).
    pub deoptimization_threshold: usize,
}
67
68impl Default for JitConfig {
69    fn default() -> Self {
70        JitConfig {
71            initial_optimization: OptimizationLevel::Basic,
72            hot_path_optimization: OptimizationLevel::Aggressive,
73            hot_path_threshold: 10,
74            enable_specialization: true,
75            max_specializations: 5,
76            enable_adaptive_optimization: true,
77            profiling_window: 100,
78            cache_size: 1000,
79            enable_deoptimization: true,
80            deoptimization_threshold: 1,
81        }
82    }
83}
84
/// Key for identifying graphs and their specializations.
///
/// Two keys are equal only if both the structural hash and the (optional)
/// specialization context match, so specialized and generic compilations of
/// the same graph occupy distinct cache slots.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JitKey {
    /// Hash of the graph structure (see `JitCompiler::hash_graph`).
    pub graph_hash: u64,
    /// Specialization context (shapes, if specialization is enabled),
    /// `None` for the generic compilation.
    pub specialization: Option<SpecializationContext>,
}
93
/// Context for graph specialization.
///
/// Captures the concrete input shapes (and optionally a device) a graph
/// was compiled for; used as part of the cache key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationContext {
    /// Static input dimensions per input tensor (dynamic dimensions are
    /// omitted by `from_shapes` — see the note there).
    pub input_shapes: Vec<Vec<usize>>,
    /// Device target (if specified), e.g. "cuda:0".
    pub device: Option<String>,
}
102
103impl SpecializationContext {
104    /// Create a new specialization context from input shapes.
105    pub fn from_shapes(shapes: &[TensorShape]) -> Self {
106        SpecializationContext {
107            input_shapes: shapes
108                .iter()
109                .map(|s| {
110                    s.dims
111                        .iter()
112                        .filter_map(|d| d.as_static())
113                        .collect::<Vec<_>>()
114                })
115                .collect(),
116            device: None,
117        }
118    }
119
120    /// Create a context with device specification.
121    pub fn with_device(mut self, device: String) -> Self {
122        self.device = Some(device);
123        self
124    }
125}
126
/// Statistics for a compiled graph in the JIT cache.
///
/// Updated on every recorded execution; consulted by the hot/cold-path
/// detection logic.
#[derive(Debug, Clone)]
pub struct JitEntryStats {
    /// Number of times this compiled version has been executed.
    pub execution_count: usize,
    /// Total execution time accumulated across all executions.
    pub total_execution_time: Duration,
    /// Average execution time (total / count).
    pub avg_execution_time: Duration,
    /// Optimization level this version was compiled with.
    pub optimization_level: OptimizationLevel,
    /// Timestamp of the most recent execution (used for LRU eviction and
    /// cold-path detection).
    pub last_executed: Instant,
    /// Timestamp when this version was compiled.
    pub compiled_at: Instant,
    /// Whether this is a shape-specialized version.
    pub is_specialized: bool,
}
145
146impl Default for JitEntryStats {
147    fn default() -> Self {
148        JitEntryStats {
149            execution_count: 0,
150            total_execution_time: Duration::from_secs(0),
151            avg_execution_time: Duration::from_secs(0),
152            optimization_level: OptimizationLevel::Basic,
153            last_executed: Instant::now(),
154            compiled_at: Instant::now(),
155            is_specialized: false,
156        }
157    }
158}
159
160impl JitEntryStats {
161    /// Record an execution of this compiled graph.
162    pub fn record_execution(&mut self, duration: Duration) {
163        self.execution_count += 1;
164        self.total_execution_time += duration;
165        self.avg_execution_time = self.total_execution_time / self.execution_count as u32;
166        self.last_executed = Instant::now();
167    }
168
169    /// Check if this entry is "hot" based on execution count.
170    pub fn is_hot(&self, threshold: usize) -> bool {
171        self.execution_count >= threshold
172    }
173
174    /// Check if this entry is cold (rarely used).
175    pub fn is_cold(&self, threshold: usize, window: Duration) -> bool {
176        let time_since_last = Instant::now().duration_since(self.last_executed);
177        time_since_last > window && self.execution_count < threshold
178    }
179}
180
/// Entry in the JIT cache.
///
/// Pairs a compiled graph with the profiling statistics used to decide
/// recompilation and eviction.
#[derive(Debug, Clone)]
pub struct JitCacheEntry {
    /// The compiled graph.
    pub compiled: CompiledGraph,
    /// Profiling statistics for this entry.
    pub stats: JitEntryStats,
}
189
/// Cache for JIT-compiled graphs with profiling support.
///
/// Internally an `Arc<RwLock<HashMap>>`, so the cache can be shared and
/// used through `&self`. Lock acquisition uses `unwrap()` and will panic
/// if the lock was poisoned by a panicking writer.
pub struct JitCache {
    // Shared map from cache key to compiled graph + statistics.
    cache: Arc<RwLock<HashMap<JitKey, JitCacheEntry>>>,
    // Configuration governing capacity and hot/cold thresholds.
    config: JitConfig,
}
195
196impl JitCache {
197    /// Create a new JIT cache.
198    pub fn new(config: JitConfig) -> Self {
199        JitCache {
200            cache: Arc::new(RwLock::new(HashMap::new())),
201            config,
202        }
203    }
204
205    /// Insert a compiled graph into the cache.
206    pub fn insert(&self, key: JitKey, compiled: CompiledGraph, is_specialized: bool) {
207        let mut cache = self.cache.write().unwrap();
208
209        // Evict old entries if cache is full
210        if cache.len() >= self.config.cache_size {
211            self.evict_lru(&mut cache);
212        }
213
214        let stats = JitEntryStats {
215            optimization_level: compiled.config.optimization_level,
216            is_specialized,
217            ..Default::default()
218        };
219
220        cache.insert(key, JitCacheEntry { compiled, stats });
221    }
222
223    /// Retrieve a compiled graph from the cache.
224    pub fn get(&self, key: &JitKey) -> Option<CompiledGraph> {
225        let cache = self.cache.read().unwrap();
226        cache.get(key).map(|entry| entry.compiled.clone())
227    }
228
229    /// Record an execution of a cached graph.
230    pub fn record_execution(&self, key: &JitKey, duration: Duration) {
231        let mut cache = self.cache.write().unwrap();
232        if let Some(entry) = cache.get_mut(key) {
233            entry.stats.record_execution(duration);
234        }
235    }
236
237    /// Get statistics for a cached entry.
238    pub fn get_stats(&self, key: &JitKey) -> Option<JitEntryStats> {
239        let cache = self.cache.read().unwrap();
240        cache.get(key).map(|entry| entry.stats.clone())
241    }
242
243    /// Get all hot paths (frequently executed graphs).
244    pub fn get_hot_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
245        let cache = self.cache.read().unwrap();
246        cache
247            .iter()
248            .filter(|(_, entry)| entry.stats.is_hot(self.config.hot_path_threshold))
249            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
250            .collect()
251    }
252
253    /// Get all cold paths (rarely executed graphs).
254    pub fn get_cold_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
255        let cache = self.cache.read().unwrap();
256        let window = Duration::from_secs(300); // 5 minutes
257        cache
258            .iter()
259            .filter(|(_, entry)| {
260                entry
261                    .stats
262                    .is_cold(self.config.deoptimization_threshold, window)
263            })
264            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
265            .collect()
266    }
267
268    /// Evict least recently used entry.
269    fn evict_lru(&self, cache: &mut HashMap<JitKey, JitCacheEntry>) {
270        if let Some((key, _)) = cache
271            .iter()
272            .min_by_key(|(_, entry)| entry.stats.last_executed)
273        {
274            let key = key.clone();
275            cache.remove(&key);
276        }
277    }
278
279    /// Clear the cache.
280    pub fn clear(&self) {
281        let mut cache = self.cache.write().unwrap();
282        cache.clear();
283    }
284
285    /// Get cache statistics.
286    pub fn cache_stats(&self) -> JitCacheStats {
287        let cache = self.cache.read().unwrap();
288        let total_entries = cache.len();
289        let hot_entries = cache
290            .values()
291            .filter(|e| e.stats.is_hot(self.config.hot_path_threshold))
292            .count();
293        let specialized_entries = cache.values().filter(|e| e.stats.is_specialized).count();
294        let total_executions = cache.values().map(|e| e.stats.execution_count).sum();
295
296        JitCacheStats {
297            total_entries,
298            hot_entries,
299            specialized_entries,
300            total_executions,
301            cache_capacity: self.config.cache_size,
302        }
303    }
304}
305
/// Point-in-time aggregate statistics for the JIT cache.
#[derive(Debug, Clone)]
pub struct JitCacheStats {
    /// Total number of entries currently in the cache.
    pub total_entries: usize,
    /// Number of entries meeting the hot-path threshold.
    pub hot_entries: usize,
    /// Number of shape-specialized entries.
    pub specialized_entries: usize,
    /// Total number of executions recorded across all entries.
    pub total_executions: usize,
    /// Configured cache capacity (maximum number of entries).
    pub cache_capacity: usize,
}
320
/// Hot path detector that identifies frequently executed code paths.
///
/// Stateless apart from its configuration; all data comes from the
/// `JitCache` it inspects.
pub struct HotPathDetector {
    // Thresholds and feature flags for hot/cold classification.
    config: JitConfig,
}
325
326impl HotPathDetector {
327    /// Create a new hot path detector.
328    pub fn new(config: JitConfig) -> Self {
329        HotPathDetector { config }
330    }
331
332    /// Detect hot paths from cache statistics.
333    pub fn detect_hot_paths(&self, cache: &JitCache) -> Vec<JitKey> {
334        cache
335            .get_hot_paths()
336            .into_iter()
337            .map(|(key, _)| key)
338            .collect()
339    }
340
341    /// Recommend recompilation for hot paths.
342    pub fn recommend_recompilation(&self, cache: &JitCache) -> Vec<(JitKey, OptimizationLevel)> {
343        cache
344            .get_hot_paths()
345            .into_iter()
346            .filter_map(|(key, stats)| {
347                // Only recommend recompilation if current optimization is below hot path level
348                if stats.optimization_level < self.config.hot_path_optimization {
349                    Some((key, self.config.hot_path_optimization))
350                } else {
351                    None
352                }
353            })
354            .collect()
355    }
356
357    /// Recommend deoptimization for cold paths.
358    pub fn recommend_deoptimization(&self, cache: &JitCache) -> Vec<JitKey> {
359        if !self.config.enable_deoptimization {
360            return Vec::new();
361        }
362
363        cache
364            .get_cold_paths()
365            .into_iter()
366            .map(|(key, _)| key)
367            .collect()
368    }
369}
370
/// Adaptive optimizer that progressively optimizes based on runtime profiling.
///
/// Combines a `HotPathDetector` (built from the same configuration) with
/// the logic to recompile hot paths and evict cold ones.
pub struct AdaptiveOptimizer {
    // Shared JIT configuration (same values as the detector's).
    config: JitConfig,
    // Detector used to produce recompilation/deoptimization recommendations.
    hot_path_detector: HotPathDetector,
}
376
377impl AdaptiveOptimizer {
378    /// Create a new adaptive optimizer.
379    pub fn new(config: JitConfig) -> Self {
380        AdaptiveOptimizer {
381            hot_path_detector: HotPathDetector::new(config.clone()),
382            config,
383        }
384    }
385
386    /// Analyze runtime behavior and recommend optimizations.
387    pub fn analyze_and_recommend(&self, cache: &JitCache) -> AdaptiveOptimizationPlan {
388        let hot_paths = self.hot_path_detector.recommend_recompilation(cache);
389        let cold_paths = self.hot_path_detector.recommend_deoptimization(cache);
390
391        AdaptiveOptimizationPlan {
392            recompile: hot_paths,
393            deoptimize: cold_paths,
394        }
395    }
396
397    /// Apply adaptive optimizations to the cache.
398    pub fn optimize(&self, cache: &JitCache) -> Result<usize, ExecutorError> {
399        let plan = self.analyze_and_recommend(cache);
400        let mut optimized_count = 0;
401
402        // Recompile hot paths with aggressive optimization
403        for (key, opt_level) in plan.recompile {
404            if let Some(entry) = cache.cache.read().unwrap().get(&key) {
405                let graph = &entry.compiled.graph;
406                let mut config = entry.compiled.config.clone();
407                config.optimization_level = opt_level;
408
409                let mut new_compiler = GraphCompiler::new(config);
410                let recompiled = new_compiler.compile(graph)?;
411
412                // Update cache with recompiled version
413                cache.cache.write().unwrap().get_mut(&key).unwrap().compiled = recompiled;
414                optimized_count += 1;
415            }
416        }
417
418        // Deoptimize cold paths (remove from cache or downgrade)
419        for key in plan.deoptimize {
420            cache.cache.write().unwrap().remove(&key);
421        }
422
423        Ok(optimized_count)
424    }
425
426    /// Get the JIT configuration.
427    pub fn config(&self) -> &JitConfig {
428        &self.config
429    }
430
431    /// Get the hot path detector.
432    pub fn hot_path_detector(&self) -> &HotPathDetector {
433        &self.hot_path_detector
434    }
435}
436
/// Plan produced by `AdaptiveOptimizer::analyze_and_recommend`.
#[derive(Debug, Clone)]
pub struct AdaptiveOptimizationPlan {
    /// Graphs to recompile, each paired with the target optimization level.
    pub recompile: Vec<(JitKey, OptimizationLevel)>,
    /// Graphs to deoptimize (removed from the cache by `optimize`).
    pub deoptimize: Vec<JitKey>,
}
445
/// JIT compiler with runtime compilation and adaptive optimization.
///
/// Front-end tying together the cache and the adaptive optimizer; both are
/// constructed from clones of the same `JitConfig`.
pub struct JitCompiler {
    // Shared configuration (also cloned into cache and optimizer).
    config: JitConfig,
    // Compiled-graph cache keyed by graph hash + specialization.
    cache: JitCache,
    // Drives hot-path recompilation and cold-path eviction.
    adaptive_optimizer: AdaptiveOptimizer,
}
452
453impl JitCompiler {
454    /// Create a new JIT compiler.
455    pub fn new(config: JitConfig) -> Self {
456        JitCompiler {
457            cache: JitCache::new(config.clone()),
458            adaptive_optimizer: AdaptiveOptimizer::new(config.clone()),
459            config,
460        }
461    }
462
463    /// Create a JIT compiler with default configuration.
464    pub fn with_default_config() -> Self {
465        Self::new(JitConfig::default())
466    }
467
468    /// Compile a graph or retrieve from cache.
469    pub fn compile_or_retrieve(
470        &mut self,
471        graph: &EinsumGraph,
472        input_shapes: &[TensorShape],
473    ) -> Result<CompiledGraph, ExecutorError> {
474        let key = self.create_key(graph, input_shapes);
475
476        // Check cache first
477        if let Some(compiled) = self.cache.get(&key) {
478            return Ok(compiled);
479        }
480
481        // Compile with initial optimization level
482        let config = CompilationConfig {
483            optimization_level: self.config.initial_optimization,
484            enable_shape_inference: true,
485            enable_memory_estimation: true,
486            enable_caching: true,
487            enable_parallelism: true,
488            ..Default::default()
489        };
490
491        let mut compiler = GraphCompiler::new(config);
492        let compiled = compiler.compile(graph)?;
493
494        // Cache the compiled graph
495        let is_specialized = self.config.enable_specialization && !input_shapes.is_empty();
496        self.cache.insert(key, compiled.clone(), is_specialized);
497
498        Ok(compiled)
499    }
500
501    /// Record execution of a compiled graph.
502    pub fn record_execution(
503        &self,
504        graph: &EinsumGraph,
505        input_shapes: &[TensorShape],
506        duration: Duration,
507    ) {
508        let key = self.create_key(graph, input_shapes);
509        self.cache.record_execution(&key, duration);
510    }
511
512    /// Optimize hot paths based on profiling data.
513    pub fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
514        if !self.config.enable_adaptive_optimization {
515            return Ok(0);
516        }
517
518        self.adaptive_optimizer.optimize(&self.cache)
519    }
520
521    /// Get JIT cache statistics.
522    pub fn cache_stats(&self) -> JitCacheStats {
523        self.cache.cache_stats()
524    }
525
526    /// Clear the JIT cache.
527    pub fn clear_cache(&self) {
528        self.cache.clear();
529    }
530
531    /// Create a cache key for the graph.
532    fn create_key(&self, graph: &EinsumGraph, input_shapes: &[TensorShape]) -> JitKey {
533        let graph_hash = self.hash_graph(graph);
534        let specialization = if self.config.enable_specialization && !input_shapes.is_empty() {
535            Some(SpecializationContext::from_shapes(input_shapes))
536        } else {
537            None
538        };
539
540        JitKey {
541            graph_hash,
542            specialization,
543        }
544    }
545
546    /// Hash a graph for caching.
547    fn hash_graph(&self, graph: &EinsumGraph) -> u64 {
548        use std::collections::hash_map::DefaultHasher;
549        let mut hasher = DefaultHasher::new();
550        graph.nodes.len().hash(&mut hasher);
551        // Simple hash based on node count and structure
552        // In production, would use more sophisticated hashing
553        hasher.finish()
554    }
555}
556
/// Trait for executors that support JIT compilation.
pub trait TlJitExecutor {
    /// Mutable access to this executor's JIT compiler.
    fn jit_compiler(&mut self) -> &mut JitCompiler;

    /// Enable JIT compilation.
    fn enable_jit(&mut self);

    /// Disable JIT compilation.
    fn disable_jit(&mut self);

    /// Check if JIT is enabled.
    fn is_jit_enabled(&self) -> bool;

    /// Trigger adaptive optimization of hot paths.
    ///
    /// Default implementation delegates to the compiler returned by
    /// `jit_compiler`; returns the number of recompiled entries.
    fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        self.jit_compiler().optimize_hot_paths()
    }

    /// Get JIT statistics.
    fn jit_stats(&self) -> JitCacheStats;
}
579
/// Statistics for JIT compilation performance.
///
/// Every counter starts at zero and every `Duration` at `Duration::ZERO`,
/// which is exactly what the derived `Default` produces — the previous
/// hand-written `impl Default` was redundant and has been replaced by the
/// derive.
#[derive(Debug, Clone, Default)]
pub struct JitStats {
    /// Total number of compilations performed.
    pub total_compilations: usize,
    /// Number of cache hits.
    pub cache_hits: usize,
    /// Number of cache misses.
    pub cache_misses: usize,
    /// Number of recompilations due to hot path optimization.
    pub recompilations: usize,
    /// Number of deoptimizations.
    pub deoptimizations: usize,
    /// Average compilation time.
    pub avg_compilation_time: Duration,
    /// Total time saved by caching.
    pub total_time_saved: Duration,
}

impl JitStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded (avoids 0/0).
    pub fn cache_hit_rate(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        if lookups == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / lookups as f64
    }

    /// One-line human-readable summary of the collected statistics.
    pub fn summary(&self) -> String {
        format!(
            "JIT Stats: {} compilations, {:.1}% cache hit rate, {} recompilations, {:.2}ms avg compile time",
            self.total_compilations,
            self.cache_hit_rate() * 100.0,
            self.recompilations,
            self.avg_compilation_time.as_secs_f64() * 1000.0
        )
    }
}
633
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): several tests below use real `std::thread::sleep` calls
    // to separate timestamps, which makes them timing-sensitive on loaded
    // machines; consider injecting a clock if they prove flaky.

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert_eq!(config.initial_optimization, OptimizationLevel::Basic);
        assert_eq!(config.hot_path_optimization, OptimizationLevel::Aggressive);
        assert_eq!(config.hot_path_threshold, 10);
        assert!(config.enable_specialization);
        assert!(config.enable_adaptive_optimization);
    }

    #[test]
    fn test_specialization_context() {
        let shapes = vec![
            TensorShape::static_shape(vec![2, 3]),
            TensorShape::static_shape(vec![3, 4]),
        ];
        let ctx = SpecializationContext::from_shapes(&shapes);
        assert_eq!(ctx.input_shapes.len(), 2);
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
        assert_eq!(ctx.input_shapes[1], vec![3, 4]);
    }

    #[test]
    fn test_jit_entry_stats() {
        let mut stats = JitEntryStats::default();
        assert_eq!(stats.execution_count, 0);
        assert!(!stats.is_hot(10));

        // Record executions
        for _ in 0..15 {
            stats.record_execution(Duration::from_millis(10));
        }

        assert_eq!(stats.execution_count, 15);
        assert!(stats.is_hot(10));
        assert_eq!(stats.total_execution_time, Duration::from_millis(150));
    }

    #[test]
    fn test_jit_cache_insert_retrieve() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        // Minimal CompiledGraph fixture: empty graph, empty schedule.
        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 12345,
            specialization: None,
        };

        cache.insert(key.clone(), compiled.clone(), false);
        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_jit_cache_eviction() {
        let config = JitConfig {
            cache_size: 2, // Small cache for testing
            ..Default::default()
        };
        let cache = JitCache::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        // Insert 3 entries (should evict oldest)
        for i in 0..3 {
            let key = JitKey {
                graph_hash: i,
                specialization: None,
            };
            cache.insert(key, compiled.clone(), false);
            std::thread::sleep(Duration::from_millis(10)); // Ensure different timestamps
        }

        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 2); // Should only have 2 due to eviction
    }

    #[test]
    fn test_hot_path_detection() {
        let config = JitConfig::default();
        let cache = JitCache::new(config.clone());
        let detector = HotPathDetector::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 123,
            specialization: None,
        };

        cache.insert(key.clone(), compiled, false);

        // Record many executions to make it hot (threshold defaults to 10)
        for _ in 0..15 {
            cache.record_execution(&key, Duration::from_millis(10));
        }

        let hot_paths = detector.detect_hot_paths(&cache);
        assert_eq!(hot_paths.len(), 1);
        assert_eq!(hot_paths[0].graph_hash, 123);
    }

    #[test]
    fn test_jit_compiler_basic() {
        let mut jit = JitCompiler::with_default_config();
        let graph = EinsumGraph::new();
        let shapes = vec![];

        let result = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result.is_ok());

        // Second call should hit cache
        let result2 = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result2.is_ok());
    }

    #[test]
    fn test_jit_stats() {
        let stats = JitStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);

        let stats = JitStats {
            cache_hits: 8,
            cache_misses: 2,
            ..Default::default()
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    #[test]
    fn test_adaptive_optimization_plan() {
        let plan = AdaptiveOptimizationPlan {
            recompile: vec![(
                JitKey {
                    graph_hash: 123,
                    specialization: None,
                },
                OptimizationLevel::Aggressive,
            )],
            deoptimize: vec![],
        };

        assert_eq!(plan.recompile.len(), 1);
        assert_eq!(plan.deoptimize.len(), 0);
    }

    #[test]
    fn test_jit_cache_stats() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.hot_entries, 0);
        assert_eq!(stats.total_executions, 0);
    }

    #[test]
    fn test_specialization_with_device() {
        let shapes = vec![TensorShape::static_shape(vec![2, 3])];
        let ctx = SpecializationContext::from_shapes(&shapes).with_device("cuda:0".to_string());

        assert_eq!(ctx.device, Some("cuda:0".to_string()));
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
    }

    #[test]
    fn test_jit_entry_cold_detection() {
        let mut stats = JitEntryStats::default();

        // Execute once
        stats.record_execution(Duration::from_millis(10));

        // Not cold immediately
        assert!(!stats.is_cold(5, Duration::from_millis(100)));

        // Wait and check (timing-sensitive: relies on a real 150ms sleep)
        std::thread::sleep(Duration::from_millis(150));
        assert!(stats.is_cold(5, Duration::from_millis(100)));
    }

    #[test]
    fn test_jit_cache_clear() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 123,
            specialization: None,
        };

        cache.insert(key.clone(), compiled, false);
        assert_eq!(cache.cache_stats().total_entries, 1);

        cache.clear();
        assert_eq!(cache.cache_stats().total_entries, 0);
    }
}
895}