//! Just-in-time (JIT) compilation: a statistics-tracking compilation cache,
//! hot-path detection, and adaptive re-optimization of einsum graphs.

use crate::compilation::{CompilationConfig, CompiledGraph, GraphCompiler, OptimizationLevel};
use crate::error::ExecutorError;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tensorlogic_ir::EinsumGraph;

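/// Tunables for JIT compilation and the adaptive-optimization loop.
///
/// A minimal sketch of overriding a single knob (marked `ignore` so it is
/// not run as a doctest; values mirror the defaults below):
///
/// ```ignore
/// let config = JitConfig {
///     hot_path_threshold: 5, // promote paths after 5 executions
///     ..JitConfig::default()
/// };
/// let jit = JitCompiler::new(config);
/// ```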
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Optimization level used when a graph is first compiled.
    pub initial_optimization: OptimizationLevel,
    /// Optimization level applied when a path is recompiled as hot.
    pub hot_path_optimization: OptimizationLevel,
    /// Execution count at which a cache entry is considered hot.
    pub hot_path_threshold: usize,
    /// Whether compilations may be specialized on concrete input shapes.
    pub enable_specialization: bool,
    /// Maximum number of shape specializations kept per graph.
    pub max_specializations: usize,
    /// Whether hot paths are re-optimized adaptively at runtime.
    pub enable_adaptive_optimization: bool,
    /// Number of executions considered when profiling a path.
    pub profiling_window: usize,
    /// Maximum number of entries held in the JIT cache.
    pub cache_size: usize,
    /// Whether stale entries may be deoptimized (evicted from the cache).
    pub enable_deoptimization: bool,
    /// Execution count below which a stale entry counts as cold.
    pub deoptimization_threshold: usize,
}

impl Default for JitConfig {
    fn default() -> Self {
        JitConfig {
            initial_optimization: OptimizationLevel::Basic,
            hot_path_optimization: OptimizationLevel::Aggressive,
            hot_path_threshold: 10,
            enable_specialization: true,
            max_specializations: 5,
            enable_adaptive_optimization: true,
            profiling_window: 100,
            cache_size: 1000,
            enable_deoptimization: true,
            deoptimization_threshold: 1,
        }
    }
}

/// Cache key for a compiled graph: a coarse structural hash of the graph
/// plus an optional shape/device specialization.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JitKey {
    /// Structural hash of the source graph.
    pub graph_hash: u64,
    /// Shape/device context this entry was specialized for, if any.
    pub specialization: Option<SpecializationContext>,
}

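/// Shape/device context that a compilation can be specialized on.
///
/// A small sketch (marked `ignore`; `TensorShape::static_shape` is the
/// constructor used by the tests in this file):
///
/// ```ignore
/// let shapes = [TensorShape::static_shape(vec![2, 3])];
/// let ctx = SpecializationContext::from_shapes(&shapes)
///     .with_device("cuda:0".to_string());
/// assert_eq!(ctx.input_shapes, vec![vec![2, 3]]);
/// ```
///
/// Dynamic dimensions are dropped by `from_shapes`, so only statically
/// known extents participate in cache keys.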
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationContext {
    /// Static dimensions of each input; dynamic dimensions are omitted.
    pub input_shapes: Vec<Vec<usize>>,
    /// Target device identifier, if the compilation is device-specific.
    pub device: Option<String>,
}

impl SpecializationContext {
    /// Builds a context from input shapes, keeping only statically known
    /// dimensions.
    pub fn from_shapes(shapes: &[TensorShape]) -> Self {
        SpecializationContext {
            input_shapes: shapes
                .iter()
                .map(|s| {
                    s.dims
                        .iter()
                        .filter_map(|d| d.as_static())
                        .collect::<Vec<_>>()
                })
                .collect(),
            device: None,
        }
    }

    /// Attaches a device identifier to the context.
    pub fn with_device(mut self, device: String) -> Self {
        self.device = Some(device);
        self
    }
}

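/// Runtime statistics for a single cached compilation.
///
/// A minimal sketch of the bookkeeping (marked `ignore`; mirrors
/// `test_jit_entry_stats` below):
///
/// ```ignore
/// let mut stats = JitEntryStats::default();
/// stats.record_execution(Duration::from_millis(10));
/// assert_eq!(stats.execution_count, 1);
/// assert!(!stats.is_hot(10)); // still below the hot-path threshold
/// ```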
#[derive(Debug, Clone)]
pub struct JitEntryStats {
    /// Number of times this entry has been executed.
    pub execution_count: usize,
    /// Total wall-clock time spent executing this entry.
    pub total_execution_time: Duration,
    /// Mean execution time (total time / execution count).
    pub avg_execution_time: Duration,
    /// Optimization level the entry was compiled at.
    pub optimization_level: OptimizationLevel,
    /// When the entry was last executed.
    pub last_executed: Instant,
    /// When the entry was compiled.
    pub compiled_at: Instant,
    /// Whether the entry is specialized on input shapes.
    pub is_specialized: bool,
}

impl Default for JitEntryStats {
    fn default() -> Self {
        JitEntryStats {
            execution_count: 0,
            total_execution_time: Duration::from_secs(0),
            avg_execution_time: Duration::from_secs(0),
            optimization_level: OptimizationLevel::Basic,
            last_executed: Instant::now(),
            compiled_at: Instant::now(),
            is_specialized: false,
        }
    }
}

impl JitEntryStats {
    /// Records one execution and refreshes the running average and the
    /// last-executed timestamp.
    pub fn record_execution(&mut self, duration: Duration) {
        self.execution_count += 1;
        self.total_execution_time += duration;
        self.avg_execution_time = self.total_execution_time / self.execution_count as u32;
        self.last_executed = Instant::now();
    }

    /// True once the entry has executed at least `threshold` times.
    pub fn is_hot(&self, threshold: usize) -> bool {
        self.execution_count >= threshold
    }

    /// True if the entry has been idle longer than `window` and has fewer
    /// than `threshold` executions.
    pub fn is_cold(&self, threshold: usize, window: Duration) -> bool {
        let time_since_last = Instant::now().duration_since(self.last_executed);
        time_since_last > window && self.execution_count < threshold
    }
}

/// A compiled graph together with its runtime statistics.
#[derive(Debug, Clone)]
pub struct JitCacheEntry {
    pub compiled: CompiledGraph,
    pub stats: JitEntryStats,
}

/// Thread-safe cache of compiled graphs with LRU eviction.
pub struct JitCache {
    cache: Arc<RwLock<HashMap<JitKey, JitCacheEntry>>>,
    config: JitConfig,
}

impl JitCache {
    /// Creates an empty cache with the given configuration.
    pub fn new(config: JitConfig) -> Self {
        JitCache {
            cache: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Inserts a compiled graph, evicting the least recently used entry if
    /// the cache is at capacity.
    pub fn insert(&self, key: JitKey, compiled: CompiledGraph, is_specialized: bool) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");

        if cache.len() >= self.config.cache_size {
            self.evict_lru(&mut cache);
        }

        let stats = JitEntryStats {
            optimization_level: compiled.config.optimization_level,
            is_specialized,
            ..Default::default()
        };

        cache.insert(key, JitCacheEntry { compiled, stats });
    }

    /// Returns a clone of the compiled graph for `key`, if cached.
    pub fn get(&self, key: &JitKey) -> Option<CompiledGraph> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache.get(key).map(|entry| entry.compiled.clone())
    }

    /// Records an execution against the entry for `key`, if present.
    pub fn record_execution(&self, key: &JitKey, duration: Duration) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        if let Some(entry) = cache.get_mut(key) {
            entry.stats.record_execution(duration);
        }
    }

    /// Returns a snapshot of the statistics for `key`, if cached.
    pub fn get_stats(&self, key: &JitKey) -> Option<JitEntryStats> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache.get(key).map(|entry| entry.stats.clone())
    }

    /// Returns all entries whose execution count meets the hot-path
    /// threshold.
    pub fn get_hot_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache
            .iter()
            .filter(|(_, entry)| entry.stats.is_hot(self.config.hot_path_threshold))
            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
            .collect()
    }

    /// Returns all entries that are candidates for deoptimization.
    pub fn get_cold_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        // Entries idle for more than five minutes are considered stale.
        let window = Duration::from_secs(300);
        cache
            .iter()
            .filter(|(_, entry)| {
                entry
                    .stats
                    .is_cold(self.config.deoptimization_threshold, window)
            })
            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
            .collect()
    }

    /// Evicts the least recently executed entry.
    fn evict_lru(&self, cache: &mut HashMap<JitKey, JitCacheEntry>) {
        if let Some((key, _)) = cache
            .iter()
            .min_by_key(|(_, entry)| entry.stats.last_executed)
        {
            let key = key.clone();
            cache.remove(&key);
        }
    }

    /// Removes every entry from the cache.
    pub fn clear(&self) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        cache.clear();
    }

    /// Summarizes the current contents of the cache.
    pub fn cache_stats(&self) -> JitCacheStats {
        let cache = self.cache.read().expect("lock should not be poisoned");
        let total_entries = cache.len();
        let hot_entries = cache
            .values()
            .filter(|e| e.stats.is_hot(self.config.hot_path_threshold))
            .count();
        let specialized_entries = cache.values().filter(|e| e.stats.is_specialized).count();
        let total_executions = cache.values().map(|e| e.stats.execution_count).sum();

        JitCacheStats {
            total_entries,
            hot_entries,
            specialized_entries,
            total_executions,
            cache_capacity: self.config.cache_size,
        }
    }
}

/// Point-in-time summary of a [`JitCache`].
#[derive(Debug, Clone)]
pub struct JitCacheStats {
    pub total_entries: usize,
    pub hot_entries: usize,
    pub specialized_entries: usize,
    pub total_executions: usize,
    pub cache_capacity: usize,
}

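/// Detects hot and cold paths in a [`JitCache`] and recommends
/// recompilation or deoptimization.
///
/// A usage sketch (marked `ignore`; assumes `cache` is a populated
/// [`JitCache`]):
///
/// ```ignore
/// let detector = HotPathDetector::new(JitConfig::default());
/// let hot: Vec<JitKey> = detector.detect_hot_paths(&cache);
/// let plan = detector.recommend_recompilation(&cache);
/// ```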
pub struct HotPathDetector {
    config: JitConfig,
}

impl HotPathDetector {
    pub fn new(config: JitConfig) -> Self {
        HotPathDetector { config }
    }

    /// Returns the keys of all hot paths in `cache`.
    pub fn detect_hot_paths(&self, cache: &JitCache) -> Vec<JitKey> {
        cache
            .get_hot_paths()
            .into_iter()
            .map(|(key, _)| key)
            .collect()
    }

    /// Recommends a higher optimization level for hot paths that were
    /// compiled below the configured hot-path level.
    pub fn recommend_recompilation(&self, cache: &JitCache) -> Vec<(JitKey, OptimizationLevel)> {
        cache
            .get_hot_paths()
            .into_iter()
            .filter_map(|(key, stats)| {
                if stats.optimization_level < self.config.hot_path_optimization {
                    Some((key, self.config.hot_path_optimization))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Recommends cold paths for eviction; empty when deoptimization is
    /// disabled.
    pub fn recommend_deoptimization(&self, cache: &JitCache) -> Vec<JitKey> {
        if !self.config.enable_deoptimization {
            return Vec::new();
        }

        cache
            .get_cold_paths()
            .into_iter()
            .map(|(key, _)| key)
            .collect()
    }
}

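/// Drives the recompile/deoptimize cycle over a [`JitCache`].
///
/// A sketch of one optimization pass (marked `ignore`; assumes `cache` is
/// a populated [`JitCache`]):
///
/// ```ignore
/// let optimizer = AdaptiveOptimizer::new(JitConfig::default());
/// let plan = optimizer.analyze_and_recommend(&cache); // inspect only
/// let recompiled = optimizer.optimize(&cache)?;       // apply the plan
/// ```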
pub struct AdaptiveOptimizer {
    config: JitConfig,
    hot_path_detector: HotPathDetector,
}

impl AdaptiveOptimizer {
    pub fn new(config: JitConfig) -> Self {
        AdaptiveOptimizer {
            hot_path_detector: HotPathDetector::new(config.clone()),
            config,
        }
    }

    /// Builds a plan of hot paths to recompile and cold paths to evict.
    pub fn analyze_and_recommend(&self, cache: &JitCache) -> AdaptiveOptimizationPlan {
        let hot_paths = self.hot_path_detector.recommend_recompilation(cache);
        let cold_paths = self.hot_path_detector.recommend_deoptimization(cache);

        AdaptiveOptimizationPlan {
            recompile: hot_paths,
            deoptimize: cold_paths,
        }
    }

    /// Applies the optimization plan: recompiles hot paths at a higher
    /// optimization level and evicts cold paths. Returns the number of
    /// entries that were recompiled.
    pub fn optimize(&self, cache: &JitCache) -> Result<usize, ExecutorError> {
        let plan = self.analyze_and_recommend(cache);
        let mut optimized_count = 0;

        for (key, opt_level) in plan.recompile {
            // Clone what we need and release the read lock before
            // compiling; taking the write lock below while still holding
            // a read lock on the same RwLock would deadlock.
            let source = {
                let guard = cache.cache.read().expect("lock should not be poisoned");
                guard
                    .get(&key)
                    .map(|entry| (entry.compiled.graph.clone(), entry.compiled.config.clone()))
            };

            if let Some((graph, mut config)) = source {
                config.optimization_level = opt_level;

                let mut new_compiler = GraphCompiler::new(config);
                let recompiled = new_compiler.compile(&graph)?;

                // The entry may have been evicted concurrently; only update
                // it if it is still present.
                if let Some(entry) = cache
                    .cache
                    .write()
                    .expect("lock should not be poisoned")
                    .get_mut(&key)
                {
                    entry.compiled = recompiled;
                    optimized_count += 1;
                }
            }
        }

        for key in plan.deoptimize {
            cache
                .cache
                .write()
                .expect("lock should not be poisoned")
                .remove(&key);
        }

        Ok(optimized_count)
    }

    /// Returns the optimizer's configuration.
    pub fn config(&self) -> &JitConfig {
        &self.config
    }

    /// Returns the underlying hot-path detector.
    pub fn hot_path_detector(&self) -> &HotPathDetector {
        &self.hot_path_detector
    }
}

/// Output of [`AdaptiveOptimizer::analyze_and_recommend`].
#[derive(Debug, Clone)]
pub struct AdaptiveOptimizationPlan {
    /// Hot paths to recompile, with their target optimization levels.
    pub recompile: Vec<(JitKey, OptimizationLevel)>,
    /// Cold paths to evict from the cache.
    pub deoptimize: Vec<JitKey>,
}

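/// Top-level JIT compiler: caches compiled graphs, specializes on input
/// shapes, and re-optimizes hot paths on demand.
///
/// A sketch of the intended lifecycle (marked `ignore`; uses the empty
/// `EinsumGraph::new()` as the tests below do):
///
/// ```ignore
/// let mut jit = JitCompiler::with_default_config();
/// let graph = EinsumGraph::new();
///
/// let compiled = jit.compile_or_retrieve(&graph, &[])?;
/// // ... execute `compiled`, timing it ...
/// jit.record_execution(&graph, &[], Duration::from_millis(3));
///
/// // Periodically promote hot paths to a higher optimization level.
/// let recompiled = jit.optimize_hot_paths()?;
/// ```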
pub struct JitCompiler {
    config: JitConfig,
    cache: JitCache,
    adaptive_optimizer: AdaptiveOptimizer,
}

impl JitCompiler {
    pub fn new(config: JitConfig) -> Self {
        JitCompiler {
            cache: JitCache::new(config.clone()),
            adaptive_optimizer: AdaptiveOptimizer::new(config.clone()),
            config,
        }
    }

    /// Creates a JIT compiler with [`JitConfig::default`].
    pub fn with_default_config() -> Self {
        Self::new(JitConfig::default())
    }

    /// Returns a cached compilation for `graph` (specialized on
    /// `input_shapes` when enabled), compiling and caching it on a miss.
    pub fn compile_or_retrieve(
        &mut self,
        graph: &EinsumGraph,
        input_shapes: &[TensorShape],
    ) -> Result<CompiledGraph, ExecutorError> {
        let key = self.create_key(graph, input_shapes);

        if let Some(compiled) = self.cache.get(&key) {
            return Ok(compiled);
        }

        let config = CompilationConfig {
            optimization_level: self.config.initial_optimization,
            enable_shape_inference: true,
            enable_memory_estimation: true,
            enable_caching: true,
            enable_parallelism: true,
            ..Default::default()
        };

        let mut compiler = GraphCompiler::new(config);
        let compiled = compiler.compile(graph)?;

        let is_specialized = self.config.enable_specialization && !input_shapes.is_empty();
        self.cache.insert(key, compiled.clone(), is_specialized);

        Ok(compiled)
    }

    /// Records an execution of `graph` so hot-path detection can see it.
    pub fn record_execution(
        &self,
        graph: &EinsumGraph,
        input_shapes: &[TensorShape],
        duration: Duration,
    ) {
        let key = self.create_key(graph, input_shapes);
        self.cache.record_execution(&key, duration);
    }

    /// Runs one adaptive-optimization pass; returns the number of entries
    /// recompiled, or 0 when adaptive optimization is disabled.
    pub fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        if !self.config.enable_adaptive_optimization {
            return Ok(0);
        }

        self.adaptive_optimizer.optimize(&self.cache)
    }

    pub fn cache_stats(&self) -> JitCacheStats {
        self.cache.cache_stats()
    }

    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    fn create_key(&self, graph: &EinsumGraph, input_shapes: &[TensorShape]) -> JitKey {
        let graph_hash = self.hash_graph(graph);
        let specialization = if self.config.enable_specialization && !input_shapes.is_empty() {
            Some(SpecializationContext::from_shapes(input_shapes))
        } else {
            None
        };

        JitKey {
            graph_hash,
            specialization,
        }
    }

    /// Computes a coarse structural hash of the graph.
    ///
    /// Note: only the node count is hashed, so distinct graphs with the
    /// same number of nodes collide and share a cache entry.
    fn hash_graph(&self, graph: &EinsumGraph) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        graph.nodes.len().hash(&mut hasher);
        hasher.finish()
    }
}

/// Adds JIT compilation hooks to an executor.
pub trait TlJitExecutor {
    /// Returns the executor's JIT compiler.
    fn jit_compiler(&mut self) -> &mut JitCompiler;

    /// Enables JIT compilation.
    fn enable_jit(&mut self);

    /// Disables JIT compilation.
    fn disable_jit(&mut self);

    /// Returns whether JIT compilation is currently enabled.
    fn is_jit_enabled(&self) -> bool;

    /// Re-optimizes hot paths; returns the number of recompiled entries.
    fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        self.jit_compiler().optimize_hot_paths()
    }

    /// Returns statistics for the executor's JIT cache.
    fn jit_stats(&self) -> JitCacheStats;
}

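/// Aggregate JIT counters for an executor to maintain and report.
///
/// A small sketch (marked `ignore`; mirrors `test_jit_stats` below):
///
/// ```ignore
/// let stats = JitStats {
///     cache_hits: 8,
///     cache_misses: 2,
///     ..Default::default()
/// };
/// assert_eq!(stats.cache_hit_rate(), 0.8); // 8 / (8 + 2)
/// ```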
#[derive(Debug, Clone)]
pub struct JitStats {
    pub total_compilations: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
    pub recompilations: usize,
    pub deoptimizations: usize,
    pub avg_compilation_time: Duration,
    pub total_time_saved: Duration,
}

impl Default for JitStats {
    fn default() -> Self {
        JitStats {
            total_compilations: 0,
            cache_hits: 0,
            cache_misses: 0,
            recompilations: 0,
            deoptimizations: 0,
            avg_compilation_time: Duration::from_secs(0),
            total_time_saved: Duration::from_secs(0),
        }
    }
}

impl JitStats {
    /// Fraction of lookups served from the cache; 0.0 when no lookups
    /// have been recorded.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.cache_hits + self.cache_misses == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
    }

    /// Renders a one-line human-readable summary.
    pub fn summary(&self) -> String {
        format!(
            "JIT Stats: {} compilations, {:.1}% cache hit rate, {} recompilations, {:.2}ms avg compile time",
            self.total_compilations,
            self.cache_hit_rate() * 100.0,
            self.recompilations,
            self.avg_compilation_time.as_secs_f64() * 1000.0
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert_eq!(config.initial_optimization, OptimizationLevel::Basic);
        assert_eq!(config.hot_path_optimization, OptimizationLevel::Aggressive);
        assert_eq!(config.hot_path_threshold, 10);
        assert!(config.enable_specialization);
        assert!(config.enable_adaptive_optimization);
    }

    #[test]
    fn test_specialization_context() {
        let shapes = vec![
            TensorShape::static_shape(vec![2, 3]),
            TensorShape::static_shape(vec![3, 4]),
        ];
        let ctx = SpecializationContext::from_shapes(&shapes);
        assert_eq!(ctx.input_shapes.len(), 2);
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
        assert_eq!(ctx.input_shapes[1], vec![3, 4]);
    }

    #[test]
    fn test_jit_entry_stats() {
        let mut stats = JitEntryStats::default();
        assert_eq!(stats.execution_count, 0);
        assert!(!stats.is_hot(10));

        for _ in 0..15 {
            stats.record_execution(Duration::from_millis(10));
        }

        assert_eq!(stats.execution_count, 15);
        assert!(stats.is_hot(10));
        assert_eq!(stats.total_execution_time, Duration::from_millis(150));
    }

    #[test]
    fn test_jit_cache_insert_retrieve() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 12345,
            specialization: None,
        };

        cache.insert(key.clone(), compiled.clone(), false);
        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_jit_cache_eviction() {
        let config = JitConfig {
            cache_size: 2,
            ..Default::default()
        };
        let cache = JitCache::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        for i in 0..3 {
            let key = JitKey {
                graph_hash: i,
                specialization: None,
            };
            cache.insert(key, compiled.clone(), false);
            // Stagger insertions so last_executed orders the LRU eviction.
            std::thread::sleep(Duration::from_millis(10));
        }

        // Capacity is 2, so inserting a third entry evicts the oldest.
        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 2);
    }

    #[test]
    fn test_hot_path_detection() {
        let config = JitConfig::default();
        let cache = JitCache::new(config.clone());
        let detector = HotPathDetector::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 123,
            specialization: None,
        };

        cache.insert(key.clone(), compiled, false);

        for _ in 0..15 {
            cache.record_execution(&key, Duration::from_millis(10));
        }

        let hot_paths = detector.detect_hot_paths(&cache);
        assert_eq!(hot_paths.len(), 1);
        assert_eq!(hot_paths[0].graph_hash, 123);
    }

    #[test]
    fn test_jit_compiler_basic() {
        let mut jit = JitCompiler::with_default_config();
        let graph = EinsumGraph::new();
        let shapes = vec![];

        let result = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result.is_ok());

        // The second call should be served from the cache.
        let result2 = jit.compile_or_retrieve(&graph, &shapes);
        assert!(result2.is_ok());
    }

    #[test]
    fn test_jit_stats() {
        let stats = JitStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);

        let stats = JitStats {
            cache_hits: 8,
            cache_misses: 2,
            ..Default::default()
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    #[test]
    fn test_adaptive_optimization_plan() {
        let plan = AdaptiveOptimizationPlan {
            recompile: vec![(
                JitKey {
                    graph_hash: 123,
                    specialization: None,
                },
                OptimizationLevel::Aggressive,
            )],
            deoptimize: vec![],
        };

        assert_eq!(plan.recompile.len(), 1);
        assert_eq!(plan.deoptimize.len(), 0);
    }

    #[test]
    fn test_jit_cache_stats() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.hot_entries, 0);
        assert_eq!(stats.total_executions, 0);
    }

    #[test]
    fn test_specialization_with_device() {
        let shapes = vec![TensorShape::static_shape(vec![2, 3])];
        let ctx = SpecializationContext::from_shapes(&shapes).with_device("cuda:0".to_string());

        assert_eq!(ctx.device, Some("cuda:0".to_string()));
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
    }

    #[test]
    fn test_jit_entry_cold_detection() {
        let mut stats = JitEntryStats::default();

        stats.record_execution(Duration::from_millis(10));

        // Freshly executed, so not cold yet.
        assert!(!stats.is_cold(5, Duration::from_millis(100)));

        // After idling past the window, the entry counts as cold.
        std::thread::sleep(Duration::from_millis(150));
        assert!(stats.is_cold(5, Duration::from_millis(100)));
    }

    #[test]
    fn test_jit_cache_clear() {
        let config = JitConfig::default();
        let cache = JitCache::new(config);

        let graph = EinsumGraph::new();
        let compiled = CompiledGraph {
            graph: graph.clone(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        };

        let key = JitKey {
            graph_hash: 123,
            specialization: None,
        };

        cache.insert(key.clone(), compiled, false);
        assert_eq!(cache.cache_stats().total_entries, 1);

        cache.clear();
        assert_eq!(cache.cache_stats().total_entries, 0);
    }
}