1use crate::compilation::{CompilationConfig, CompiledGraph, GraphCompiler, OptimizationLevel};
35use crate::error::ExecutorError;
36use crate::shape::TensorShape;
37use std::collections::HashMap;
38use std::hash::{Hash, Hasher};
39use std::sync::{Arc, RwLock};
40use std::time::{Duration, Instant};
41use tensorlogic_ir::EinsumGraph;
42
/// Configuration for the JIT compilation pipeline: when to specialize,
/// when to promote hot paths, and how large the compiled-graph cache may grow.
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Optimization level applied the first time a graph is compiled.
    pub initial_optimization: OptimizationLevel,
    /// Optimization level used when a detected hot path is recompiled.
    pub hot_path_optimization: OptimizationLevel,
    /// Minimum execution count before an entry counts as "hot".
    pub hot_path_threshold: usize,
    /// Whether cache keys include the concrete input shapes.
    pub enable_specialization: bool,
    /// Upper bound on shape specializations per graph.
    /// NOTE(review): not enforced anywhere in this file — confirm the limit
    /// is applied elsewhere.
    pub max_specializations: usize,
    /// Whether `optimize_hot_paths` may recompile/deoptimize at runtime.
    pub enable_adaptive_optimization: bool,
    /// Profiling window size (in executions) for adaptive decisions.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub profiling_window: usize,
    /// Maximum number of entries in the JIT cache before LRU eviction.
    pub cache_size: usize,
    /// Whether cold entries may be dropped (deoptimized) from the cache.
    pub enable_deoptimization: bool,
    /// Execution-count bound below which an idle entry counts as "cold".
    pub deoptimization_threshold: usize,
}
67
68impl Default for JitConfig {
69 fn default() -> Self {
70 JitConfig {
71 initial_optimization: OptimizationLevel::Basic,
72 hot_path_optimization: OptimizationLevel::Aggressive,
73 hot_path_threshold: 10,
74 enable_specialization: true,
75 max_specializations: 5,
76 enable_adaptive_optimization: true,
77 profiling_window: 100,
78 cache_size: 1000,
79 enable_deoptimization: true,
80 deoptimization_threshold: 1,
81 }
82 }
83}
84
/// Cache key for a compiled graph: a hash of the graph plus an optional
/// shape/device specialization context.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JitKey {
    /// Hash of the graph (produced by `JitCompiler::hash_graph`).
    pub graph_hash: u64,
    /// Concrete shapes/device this entry was specialized for, if any.
    pub specialization: Option<SpecializationContext>,
}
93
/// Concrete execution context a compilation may be specialized for.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationContext {
    /// Statically-known dimensions of each input tensor
    /// (dynamic dimensions are omitted by `from_shapes`).
    pub input_shapes: Vec<Vec<usize>>,
    /// Target device identifier (e.g. "cuda:0"), if pinned.
    pub device: Option<String>,
}
102
103impl SpecializationContext {
104 pub fn from_shapes(shapes: &[TensorShape]) -> Self {
106 SpecializationContext {
107 input_shapes: shapes
108 .iter()
109 .map(|s| {
110 s.dims
111 .iter()
112 .filter_map(|d| d.as_static())
113 .collect::<Vec<_>>()
114 })
115 .collect(),
116 device: None,
117 }
118 }
119
120 pub fn with_device(mut self, device: String) -> Self {
122 self.device = Some(device);
123 self
124 }
125}
126
/// Per-entry runtime statistics driving hot/cold classification and LRU
/// eviction.
#[derive(Debug, Clone)]
pub struct JitEntryStats {
    /// Number of recorded executions.
    pub execution_count: usize,
    /// Sum of all recorded execution durations.
    pub total_execution_time: Duration,
    /// Running average (`total_execution_time / execution_count`).
    pub avg_execution_time: Duration,
    /// Optimization level the cached compilation was produced at.
    pub optimization_level: OptimizationLevel,
    /// Timestamp of the most recent recorded execution.
    pub last_executed: Instant,
    /// Timestamp at which the entry's compilation was created.
    pub compiled_at: Instant,
    /// Whether the entry was compiled for specific input shapes.
    pub is_specialized: bool,
}
145
146impl Default for JitEntryStats {
147 fn default() -> Self {
148 JitEntryStats {
149 execution_count: 0,
150 total_execution_time: Duration::from_secs(0),
151 avg_execution_time: Duration::from_secs(0),
152 optimization_level: OptimizationLevel::Basic,
153 last_executed: Instant::now(),
154 compiled_at: Instant::now(),
155 is_specialized: false,
156 }
157 }
158}
159
160impl JitEntryStats {
161 pub fn record_execution(&mut self, duration: Duration) {
163 self.execution_count += 1;
164 self.total_execution_time += duration;
165 self.avg_execution_time = self.total_execution_time / self.execution_count as u32;
166 self.last_executed = Instant::now();
167 }
168
169 pub fn is_hot(&self, threshold: usize) -> bool {
171 self.execution_count >= threshold
172 }
173
174 pub fn is_cold(&self, threshold: usize, window: Duration) -> bool {
176 let time_since_last = Instant::now().duration_since(self.last_executed);
177 time_since_last > window && self.execution_count < threshold
178 }
179}
180
/// A cached compilation paired with its runtime statistics.
#[derive(Debug, Clone)]
pub struct JitCacheEntry {
    /// The compiled graph served on cache hits.
    pub compiled: CompiledGraph,
    /// Execution statistics for hot/cold classification and LRU eviction.
    pub stats: JitEntryStats,
}
189
/// Thread-safe, LRU-evicting cache of compiled graphs keyed by `JitKey`.
pub struct JitCache {
    // Shared map behind an RwLock: lookups take the read lock,
    // insert/evict/clear take the write lock.
    cache: Arc<RwLock<HashMap<JitKey, JitCacheEntry>>>,
    // Capacity and hot/cold thresholds.
    config: JitConfig,
}
195
196impl JitCache {
197 pub fn new(config: JitConfig) -> Self {
199 JitCache {
200 cache: Arc::new(RwLock::new(HashMap::new())),
201 config,
202 }
203 }
204
205 pub fn insert(&self, key: JitKey, compiled: CompiledGraph, is_specialized: bool) {
207 let mut cache = self.cache.write().unwrap();
208
209 if cache.len() >= self.config.cache_size {
211 self.evict_lru(&mut cache);
212 }
213
214 let stats = JitEntryStats {
215 optimization_level: compiled.config.optimization_level,
216 is_specialized,
217 ..Default::default()
218 };
219
220 cache.insert(key, JitCacheEntry { compiled, stats });
221 }
222
223 pub fn get(&self, key: &JitKey) -> Option<CompiledGraph> {
225 let cache = self.cache.read().unwrap();
226 cache.get(key).map(|entry| entry.compiled.clone())
227 }
228
229 pub fn record_execution(&self, key: &JitKey, duration: Duration) {
231 let mut cache = self.cache.write().unwrap();
232 if let Some(entry) = cache.get_mut(key) {
233 entry.stats.record_execution(duration);
234 }
235 }
236
237 pub fn get_stats(&self, key: &JitKey) -> Option<JitEntryStats> {
239 let cache = self.cache.read().unwrap();
240 cache.get(key).map(|entry| entry.stats.clone())
241 }
242
243 pub fn get_hot_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
245 let cache = self.cache.read().unwrap();
246 cache
247 .iter()
248 .filter(|(_, entry)| entry.stats.is_hot(self.config.hot_path_threshold))
249 .map(|(key, entry)| (key.clone(), entry.stats.clone()))
250 .collect()
251 }
252
253 pub fn get_cold_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
255 let cache = self.cache.read().unwrap();
256 let window = Duration::from_secs(300); cache
258 .iter()
259 .filter(|(_, entry)| {
260 entry
261 .stats
262 .is_cold(self.config.deoptimization_threshold, window)
263 })
264 .map(|(key, entry)| (key.clone(), entry.stats.clone()))
265 .collect()
266 }
267
268 fn evict_lru(&self, cache: &mut HashMap<JitKey, JitCacheEntry>) {
270 if let Some((key, _)) = cache
271 .iter()
272 .min_by_key(|(_, entry)| entry.stats.last_executed)
273 {
274 let key = key.clone();
275 cache.remove(&key);
276 }
277 }
278
279 pub fn clear(&self) {
281 let mut cache = self.cache.write().unwrap();
282 cache.clear();
283 }
284
285 pub fn cache_stats(&self) -> JitCacheStats {
287 let cache = self.cache.read().unwrap();
288 let total_entries = cache.len();
289 let hot_entries = cache
290 .values()
291 .filter(|e| e.stats.is_hot(self.config.hot_path_threshold))
292 .count();
293 let specialized_entries = cache.values().filter(|e| e.stats.is_specialized).count();
294 let total_executions = cache.values().map(|e| e.stats.execution_count).sum();
295
296 JitCacheStats {
297 total_entries,
298 hot_entries,
299 specialized_entries,
300 total_executions,
301 cache_capacity: self.config.cache_size,
302 }
303 }
304}
305
/// Snapshot of JIT-cache occupancy and usage counters.
#[derive(Debug, Clone)]
pub struct JitCacheStats {
    /// Number of entries currently cached.
    pub total_entries: usize,
    /// Entries whose execution count meets the hot-path threshold.
    pub hot_entries: usize,
    /// Entries compiled for specific input shapes.
    pub specialized_entries: usize,
    /// Sum of execution counts across all entries.
    pub total_executions: usize,
    /// Configured maximum number of entries.
    pub cache_capacity: usize,
}
320
/// Identifies hot and cold cache entries and recommends (de)optimization.
pub struct HotPathDetector {
    // Thresholds and target optimization levels.
    config: JitConfig,
}
325
326impl HotPathDetector {
327 pub fn new(config: JitConfig) -> Self {
329 HotPathDetector { config }
330 }
331
332 pub fn detect_hot_paths(&self, cache: &JitCache) -> Vec<JitKey> {
334 cache
335 .get_hot_paths()
336 .into_iter()
337 .map(|(key, _)| key)
338 .collect()
339 }
340
341 pub fn recommend_recompilation(&self, cache: &JitCache) -> Vec<(JitKey, OptimizationLevel)> {
343 cache
344 .get_hot_paths()
345 .into_iter()
346 .filter_map(|(key, stats)| {
347 if stats.optimization_level < self.config.hot_path_optimization {
349 Some((key, self.config.hot_path_optimization))
350 } else {
351 None
352 }
353 })
354 .collect()
355 }
356
357 pub fn recommend_deoptimization(&self, cache: &JitCache) -> Vec<JitKey> {
359 if !self.config.enable_deoptimization {
360 return Vec::new();
361 }
362
363 cache
364 .get_cold_paths()
365 .into_iter()
366 .map(|(key, _)| key)
367 .collect()
368 }
369}
370
/// Applies adaptive optimization: recompiles hot paths at a higher level and
/// evicts cold entries.
pub struct AdaptiveOptimizer {
    // Shared JIT configuration (also cloned into the detector).
    config: JitConfig,
    // Classifies cache entries as hot or cold.
    hot_path_detector: HotPathDetector,
}
376
377impl AdaptiveOptimizer {
378 pub fn new(config: JitConfig) -> Self {
380 AdaptiveOptimizer {
381 hot_path_detector: HotPathDetector::new(config.clone()),
382 config,
383 }
384 }
385
386 pub fn analyze_and_recommend(&self, cache: &JitCache) -> AdaptiveOptimizationPlan {
388 let hot_paths = self.hot_path_detector.recommend_recompilation(cache);
389 let cold_paths = self.hot_path_detector.recommend_deoptimization(cache);
390
391 AdaptiveOptimizationPlan {
392 recompile: hot_paths,
393 deoptimize: cold_paths,
394 }
395 }
396
397 pub fn optimize(&self, cache: &JitCache) -> Result<usize, ExecutorError> {
399 let plan = self.analyze_and_recommend(cache);
400 let mut optimized_count = 0;
401
402 for (key, opt_level) in plan.recompile {
404 if let Some(entry) = cache.cache.read().unwrap().get(&key) {
405 let graph = &entry.compiled.graph;
406 let mut config = entry.compiled.config.clone();
407 config.optimization_level = opt_level;
408
409 let mut new_compiler = GraphCompiler::new(config);
410 let recompiled = new_compiler.compile(graph)?;
411
412 cache.cache.write().unwrap().get_mut(&key).unwrap().compiled = recompiled;
414 optimized_count += 1;
415 }
416 }
417
418 for key in plan.deoptimize {
420 cache.cache.write().unwrap().remove(&key);
421 }
422
423 Ok(optimized_count)
424 }
425
426 pub fn config(&self) -> &JitConfig {
428 &self.config
429 }
430
431 pub fn hot_path_detector(&self) -> &HotPathDetector {
433 &self.hot_path_detector
434 }
435}
436
/// Result of analyzing the cache: what to recompile and what to evict.
#[derive(Debug, Clone)]
pub struct AdaptiveOptimizationPlan {
    /// Entries to recompile, each with its target optimization level.
    pub recompile: Vec<(JitKey, OptimizationLevel)>,
    /// Cold entries to remove from the cache.
    pub deoptimize: Vec<JitKey>,
}
445
/// Front door of the JIT pipeline: compiles graphs on demand, caches the
/// results, and periodically reoptimizes hot paths.
pub struct JitCompiler {
    // Shared configuration (also cloned into cache and optimizer).
    config: JitConfig,
    // Compiled-graph cache keyed by graph hash + shape specialization.
    cache: JitCache,
    // Recompiles hot paths / evicts cold ones.
    adaptive_optimizer: AdaptiveOptimizer,
}
452
453impl JitCompiler {
454 pub fn new(config: JitConfig) -> Self {
456 JitCompiler {
457 cache: JitCache::new(config.clone()),
458 adaptive_optimizer: AdaptiveOptimizer::new(config.clone()),
459 config,
460 }
461 }
462
463 pub fn with_default_config() -> Self {
465 Self::new(JitConfig::default())
466 }
467
468 pub fn compile_or_retrieve(
470 &mut self,
471 graph: &EinsumGraph,
472 input_shapes: &[TensorShape],
473 ) -> Result<CompiledGraph, ExecutorError> {
474 let key = self.create_key(graph, input_shapes);
475
476 if let Some(compiled) = self.cache.get(&key) {
478 return Ok(compiled);
479 }
480
481 let config = CompilationConfig {
483 optimization_level: self.config.initial_optimization,
484 enable_shape_inference: true,
485 enable_memory_estimation: true,
486 enable_caching: true,
487 enable_parallelism: true,
488 ..Default::default()
489 };
490
491 let mut compiler = GraphCompiler::new(config);
492 let compiled = compiler.compile(graph)?;
493
494 let is_specialized = self.config.enable_specialization && !input_shapes.is_empty();
496 self.cache.insert(key, compiled.clone(), is_specialized);
497
498 Ok(compiled)
499 }
500
501 pub fn record_execution(
503 &self,
504 graph: &EinsumGraph,
505 input_shapes: &[TensorShape],
506 duration: Duration,
507 ) {
508 let key = self.create_key(graph, input_shapes);
509 self.cache.record_execution(&key, duration);
510 }
511
512 pub fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
514 if !self.config.enable_adaptive_optimization {
515 return Ok(0);
516 }
517
518 self.adaptive_optimizer.optimize(&self.cache)
519 }
520
521 pub fn cache_stats(&self) -> JitCacheStats {
523 self.cache.cache_stats()
524 }
525
526 pub fn clear_cache(&self) {
528 self.cache.clear();
529 }
530
531 fn create_key(&self, graph: &EinsumGraph, input_shapes: &[TensorShape]) -> JitKey {
533 let graph_hash = self.hash_graph(graph);
534 let specialization = if self.config.enable_specialization && !input_shapes.is_empty() {
535 Some(SpecializationContext::from_shapes(input_shapes))
536 } else {
537 None
538 };
539
540 JitKey {
541 graph_hash,
542 specialization,
543 }
544 }
545
546 fn hash_graph(&self, graph: &EinsumGraph) -> u64 {
548 use std::collections::hash_map::DefaultHasher;
549 let mut hasher = DefaultHasher::new();
550 graph.nodes.len().hash(&mut hasher);
551 hasher.finish()
554 }
555}
556
/// Interface for executors that embed a JIT compiler.
pub trait TlJitExecutor {
    /// Mutable access to the underlying JIT compiler.
    fn jit_compiler(&mut self) -> &mut JitCompiler;

    /// Turns JIT compilation on.
    fn enable_jit(&mut self);

    /// Turns JIT compilation off.
    fn disable_jit(&mut self);

    /// Whether JIT compilation is currently enabled.
    fn is_jit_enabled(&self) -> bool;

    /// Runs one adaptive-optimization pass on the embedded compiler;
    /// returns how many entries were recompiled.
    fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        self.jit_compiler().optimize_hot_paths()
    }

    /// Snapshot of the JIT cache statistics.
    fn jit_stats(&self) -> JitCacheStats;
}
579
/// Aggregate counters describing JIT compiler behavior over its lifetime.
/// NOTE(review): nothing in this file updates these counters — confirm the
/// producer lives elsewhere.
#[derive(Debug, Clone)]
pub struct JitStats {
    /// Total number of compilations performed.
    pub total_compilations: usize,
    /// Lookups served from the cache.
    pub cache_hits: usize,
    /// Lookups that required a fresh compilation.
    pub cache_misses: usize,
    /// Hot-path recompilations performed.
    pub recompilations: usize,
    /// Cold entries evicted (deoptimized).
    pub deoptimizations: usize,
    /// Mean time spent per compilation.
    pub avg_compilation_time: Duration,
    /// Estimated total time saved by cache hits.
    pub total_time_saved: Duration,
}
598
599impl Default for JitStats {
600 fn default() -> Self {
601 JitStats {
602 total_compilations: 0,
603 cache_hits: 0,
604 cache_misses: 0,
605 recompilations: 0,
606 deoptimizations: 0,
607 avg_compilation_time: Duration::from_secs(0),
608 total_time_saved: Duration::from_secs(0),
609 }
610 }
611}
612
613impl JitStats {
614 pub fn cache_hit_rate(&self) -> f64 {
616 if self.cache_hits + self.cache_misses == 0 {
617 return 0.0;
618 }
619 self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
620 }
621
622 pub fn summary(&self) -> String {
624 format!(
625 "JIT Stats: {} compilations, {:.1}% cache hit rate, {} recompilations, {:.2}ms avg compile time",
626 self.total_compilations,
627 self.cache_hit_rate() * 100.0,
628 self.recompilations,
629 self.avg_compilation_time.as_secs_f64() * 1000.0
630 )
631 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal `CompiledGraph` fixture (empty graph, empty
    /// schedule). Extracted because four tests previously duplicated this
    /// construction verbatim.
    fn test_compiled_graph() -> CompiledGraph {
        CompiledGraph {
            graph: EinsumGraph::new(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        }
    }

    /// Shorthand for an unspecialized key with the given hash.
    fn key(graph_hash: u64) -> JitKey {
        JitKey {
            graph_hash,
            specialization: None,
        }
    }

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert_eq!(config.initial_optimization, OptimizationLevel::Basic);
        assert_eq!(config.hot_path_optimization, OptimizationLevel::Aggressive);
        assert_eq!(config.hot_path_threshold, 10);
        assert!(config.enable_specialization);
        assert!(config.enable_adaptive_optimization);
    }

    #[test]
    fn test_specialization_context() {
        let shapes = vec![
            TensorShape::static_shape(vec![2, 3]),
            TensorShape::static_shape(vec![3, 4]),
        ];
        let ctx = SpecializationContext::from_shapes(&shapes);
        assert_eq!(ctx.input_shapes.len(), 2);
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
        assert_eq!(ctx.input_shapes[1], vec![3, 4]);
    }

    #[test]
    fn test_jit_entry_stats() {
        let mut stats = JitEntryStats::default();
        assert_eq!(stats.execution_count, 0);
        assert!(!stats.is_hot(10));

        for _ in 0..15 {
            stats.record_execution(Duration::from_millis(10));
        }

        assert_eq!(stats.execution_count, 15);
        assert!(stats.is_hot(10));
        assert_eq!(stats.total_execution_time, Duration::from_millis(150));
    }

    #[test]
    fn test_jit_cache_insert_retrieve() {
        let cache = JitCache::new(JitConfig::default());

        cache.insert(key(12345), test_compiled_graph(), false);

        assert!(cache.get(&key(12345)).is_some());
    }

    #[test]
    fn test_jit_cache_eviction() {
        let config = JitConfig {
            cache_size: 2,
            ..Default::default()
        };
        let cache = JitCache::new(config);

        for i in 0..3 {
            cache.insert(key(i), test_compiled_graph(), false);
            // Distinct `last_executed` timestamps keep LRU ordering
            // deterministic on platforms with coarse clocks.
            std::thread::sleep(Duration::from_millis(10));
        }

        // Third insert must have evicted the least-recently-used entry.
        assert_eq!(cache.cache_stats().total_entries, 2);
    }

    #[test]
    fn test_hot_path_detection() {
        let config = JitConfig::default();
        let cache = JitCache::new(config.clone());
        let detector = HotPathDetector::new(config);

        cache.insert(key(123), test_compiled_graph(), false);

        // Cross the default hot-path threshold of 10.
        for _ in 0..15 {
            cache.record_execution(&key(123), Duration::from_millis(10));
        }

        let hot_paths = detector.detect_hot_paths(&cache);
        assert_eq!(hot_paths.len(), 1);
        assert_eq!(hot_paths[0].graph_hash, 123);
    }

    #[test]
    fn test_jit_compiler_basic() {
        let mut jit = JitCompiler::with_default_config();
        let graph = EinsumGraph::new();
        let shapes = vec![];

        let result = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result.is_ok());

        // Second call should be served from the cache.
        let result2 = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result2.is_ok());
    }

    #[test]
    fn test_jit_stats() {
        let stats = JitStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);

        let stats = JitStats {
            cache_hits: 8,
            cache_misses: 2,
            ..Default::default()
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    #[test]
    fn test_adaptive_optimization_plan() {
        let plan = AdaptiveOptimizationPlan {
            recompile: vec![(key(123), OptimizationLevel::Aggressive)],
            deoptimize: vec![],
        };

        assert_eq!(plan.recompile.len(), 1);
        assert_eq!(plan.deoptimize.len(), 0);
    }

    #[test]
    fn test_jit_cache_stats() {
        let cache = JitCache::new(JitConfig::default());

        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.hot_entries, 0);
        assert_eq!(stats.total_executions, 0);
    }

    #[test]
    fn test_specialization_with_device() {
        let shapes = vec![TensorShape::static_shape(vec![2, 3])];
        let ctx = SpecializationContext::from_shapes(&shapes).with_device("cuda:0".to_string());

        assert_eq!(ctx.device, Some("cuda:0".to_string()));
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
    }

    #[test]
    fn test_jit_entry_cold_detection() {
        let mut stats = JitEntryStats::default();

        stats.record_execution(Duration::from_millis(10));
        assert!(!stats.is_cold(5, Duration::from_millis(100)));

        // Let the entry go idle past the 100ms window.
        std::thread::sleep(Duration::from_millis(150));
        assert!(stats.is_cold(5, Duration::from_millis(100)));
    }

    #[test]
    fn test_jit_cache_clear() {
        let cache = JitCache::new(JitConfig::default());

        cache.insert(key(123), test_compiled_graph(), false);
        assert_eq!(cache.cache_stats().total_entries, 1);

        cache.clear();
        assert_eq!(cache.cache_stats().total_entries, 0);
    }
}