1use crate::error::ExecutorError;
26use crate::memory::MemoryEstimator;
27use crate::optimization::{GraphOptimizer, OptimizationResult};
28use crate::scheduling::{ExecutionSchedule, Scheduler, SchedulingStrategy};
29use crate::shape::ShapeInferenceContext;
30use crate::validation::GraphValidator;
31use std::collections::HashMap;
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime};
34use tensorlogic_ir::EinsumGraph;
35
/// How much optimization analysis [`GraphCompiler::compile`] runs.
///
/// Variants are declared from weakest to strongest. The derived `PartialOrd`/`Ord`
/// therefore order them `None < Basic < Moderate < Aggressive`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum OptimizationLevel {
    /// Skip the optimizer; compilation records an empty analysis result.
    None,
    /// Basic optimization level.
    Basic,
    /// Moderate optimization level (the `Default`).
    #[default]
    Moderate,
    /// Most aggressive optimization level.
    Aggressive,
}
49
50#[derive(Debug, Clone)]
52pub struct CompilationConfig {
53 pub optimization_level: OptimizationLevel,
55 pub enable_shape_inference: bool,
57 pub enable_memory_estimation: bool,
59 pub target_device: Option<String>,
61 pub memory_budget: Option<usize>,
63 pub enable_caching: bool,
65 pub enable_parallelism: bool,
67}
68
69impl Default for CompilationConfig {
70 fn default() -> Self {
71 CompilationConfig {
72 optimization_level: OptimizationLevel::default(),
73 enable_shape_inference: true,
74 enable_memory_estimation: true,
75 target_device: None,
76 memory_budget: None,
77 enable_caching: true,
78 enable_parallelism: true,
79 }
80 }
81}
82
/// Statistics recorded by [`GraphCompiler::compile`].
///
/// `Default` is derived: every counter is `0` and `compilation_time` is a zero
/// [`Duration`].
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Time spent compiling, measured from the start of `compile` until just before
    /// the result is assembled.
    pub compilation_time: Duration,
    /// Node count of the input graph.
    pub original_nodes: usize,
    /// Node count of the graph stored in the compiled result.
    pub optimized_nodes: usize,
    /// Number of fusion opportunities reported by the optimizer's analysis.
    pub fusions_applied: usize,
    /// Number of dead nodes reported by the optimizer's analysis.
    pub dead_nodes_eliminated: usize,
    /// Sum of the per-entry memory estimates, in bytes.
    pub estimated_memory_bytes: usize,
    /// Length of the schedule's execution order.
    pub execution_steps: usize,
}
115
116#[derive(Debug, Clone)]
121pub struct CompiledGraph {
122 pub graph: EinsumGraph,
124 pub schedule: ExecutionSchedule,
126 pub shapes: HashMap<usize, Vec<usize>>,
128 pub memory_usage: HashMap<usize, usize>,
130 pub config: CompilationConfig,
132 pub stats: CompilationStats,
134 pub compiled_at: SystemTime,
136}
137
138impl CompiledGraph {
139 pub fn node_count(&self) -> usize {
141 self.graph.nodes.len()
142 }
143
144 pub fn total_memory(&self) -> usize {
146 self.memory_usage.values().sum()
147 }
148
149 pub fn is_valid(&self) -> bool {
151 if self.graph.nodes.is_empty() {
153 return false;
154 }
155
156 if self.schedule.execution_order.len() != self.graph.nodes.len() {
158 return false;
159 }
160
161 true
162 }
163
164 pub fn summary(&self) -> String {
166 format!(
167 "CompiledGraph: {} nodes, {} steps, {:.2}MB memory, compiled in {:.2}ms",
168 self.node_count(),
169 self.stats.execution_steps,
170 self.total_memory() as f64 / 1_000_000.0,
171 self.stats.compilation_time.as_secs_f64() * 1000.0
172 )
173 }
174}
175
176pub struct GraphCompiler {
178 config: CompilationConfig,
179 optimizer: GraphOptimizer,
180 validator: GraphValidator,
181 scheduler: Scheduler,
182}
183
184impl GraphCompiler {
185 pub fn new(config: CompilationConfig) -> Self {
187 GraphCompiler {
188 config,
189 optimizer: GraphOptimizer::new(),
190 validator: GraphValidator::new(),
191 scheduler: Scheduler::new(SchedulingStrategy::Balanced),
192 }
193 }
194
195 pub fn with_default_config() -> Self {
197 Self::new(CompilationConfig::default())
198 }
199
200 pub fn compile(&mut self, graph: &EinsumGraph) -> Result<CompiledGraph, ExecutorError> {
202 let start_time = SystemTime::now();
203 let original_nodes = graph.nodes.len();
204
205 let validation_result = self.validator.validate(graph);
207 if !validation_result.is_valid {
208 return Err(ExecutorError::GraphValidationError(format!(
209 "Graph validation failed: {}",
210 validation_result
211 .errors
212 .first()
213 .map(|e| e.as_str())
214 .unwrap_or("unknown error")
215 )));
216 }
217
218 let optimized_graph = graph.clone();
220
221 let opt_result = match self.config.optimization_level {
223 OptimizationLevel::None => OptimizationResult {
224 fusion_opportunities: vec![],
225 dead_nodes: vec![],
226 redundant_computations: vec![],
227 estimated_improvement: 0.0,
228 },
229 OptimizationLevel::Basic
230 | OptimizationLevel::Moderate
231 | OptimizationLevel::Aggressive => {
232 self.optimizer.analyze(&optimized_graph)
234 }
235 };
236
237 let schedule = self.scheduler.schedule(&optimized_graph);
239
240 let shapes = if self.config.enable_shape_inference {
242 let _shape_ctx = ShapeInferenceContext::new();
243 HashMap::new()
246 } else {
247 HashMap::new()
248 };
249
250 let memory_usage = if self.config.enable_memory_estimation {
252 use crate::capabilities::DType;
253 let estimator = MemoryEstimator::new(DType::F32);
254 let estimate = estimator.estimate(&optimized_graph);
255 let mut per_node: HashMap<usize, usize> = HashMap::new();
257 for (idx, mem) in estimate.intermediate_memory.iter().enumerate() {
258 per_node.insert(idx, mem.bytes);
259 }
260 per_node
261 } else {
262 HashMap::new()
263 };
264
265 let compilation_time = start_time.elapsed().unwrap_or(Duration::from_secs(0));
266
267 let stats = CompilationStats {
268 compilation_time,
269 original_nodes,
270 optimized_nodes: optimized_graph.nodes.len(),
271 fusions_applied: opt_result.fusion_opportunities.len(),
272 dead_nodes_eliminated: opt_result.dead_nodes.len(),
273 estimated_memory_bytes: memory_usage.values().sum(),
274 execution_steps: schedule.execution_order.len(),
275 };
276
277 Ok(CompiledGraph {
278 graph: optimized_graph,
279 schedule,
280 shapes,
281 memory_usage,
282 config: self.config.clone(),
283 stats,
284 compiled_at: SystemTime::now(),
285 })
286 }
287
288 pub fn set_config(&mut self, config: CompilationConfig) {
290 self.config = config;
291 }
292
293 pub fn config(&self) -> &CompilationConfig {
295 &self.config
296 }
297}
298
299#[derive(Debug, Clone, PartialEq, Eq, Hash)]
301pub struct CompilationKey {
302 pub graph_hash: u64,
304 pub optimization_level: OptimizationLevel,
306 pub target_device: Option<String>,
308}
309
310impl CompilationKey {
311 pub fn new(graph: &EinsumGraph, config: &CompilationConfig) -> Self {
313 CompilationKey {
314 graph_hash: Self::hash_graph(graph),
315 optimization_level: config.optimization_level,
316 target_device: config.target_device.clone(),
317 }
318 }
319
320 fn hash_graph(graph: &EinsumGraph) -> u64 {
322 use std::collections::hash_map::DefaultHasher;
323 use std::hash::{Hash, Hasher};
324
325 let mut hasher = DefaultHasher::new();
326
327 graph.nodes.len().hash(&mut hasher);
329
330 for node in &graph.nodes {
332 match &node.op {
334 tensorlogic_ir::OpType::Einsum { spec } => {
335 "einsum".hash(&mut hasher);
336 spec.hash(&mut hasher);
337 }
338 tensorlogic_ir::OpType::Reduce { op, axes } => {
339 "reduce".hash(&mut hasher);
340 op.hash(&mut hasher);
341 axes.hash(&mut hasher);
342 }
343 tensorlogic_ir::OpType::ElemUnary { op } => {
344 "elemunary".hash(&mut hasher);
345 op.hash(&mut hasher);
346 }
347 tensorlogic_ir::OpType::ElemBinary { op } => {
348 "elembinary".hash(&mut hasher);
349 op.hash(&mut hasher);
350 }
351 }
352
353 node.inputs.hash(&mut hasher);
355 node.outputs.hash(&mut hasher);
356 }
357
358 hasher.finish()
359 }
360}
361
/// Usage counters for a [`CompilationCache`].
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entry count as of the most recent insert or clear.
    pub size: usize,
    /// Sum of the recorded compilation times of every entry served by a hit.
    pub time_saved: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
386
387pub struct CompilationCache {
392 cache: Arc<RwLock<HashMap<CompilationKey, Arc<CompiledGraph>>>>,
393 stats: Arc<RwLock<CacheStats>>,
394 max_size: usize,
395}
396
397impl CompilationCache {
398 pub fn new(max_size: usize) -> Self {
400 CompilationCache {
401 cache: Arc::new(RwLock::new(HashMap::new())),
402 stats: Arc::new(RwLock::new(CacheStats::default())),
403 max_size,
404 }
405 }
406
407 pub fn with_default_size() -> Self {
409 Self::new(100)
410 }
411
412 pub fn get(&self, key: &CompilationKey) -> Option<Arc<CompiledGraph>> {
414 let cache = self.cache.read().unwrap();
415 let result = cache.get(key).cloned();
416
417 let mut stats = self.stats.write().unwrap();
419 if let Some(ref compiled) = result {
420 stats.hits += 1;
421 stats.time_saved += compiled.stats.compilation_time;
422 } else {
423 stats.misses += 1;
424 }
425
426 result
427 }
428
429 pub fn insert(&self, key: CompilationKey, compiled: CompiledGraph) {
431 let mut cache = self.cache.write().unwrap();
432
433 if cache.len() >= self.max_size && !cache.contains_key(&key) {
435 if let Some(oldest_key) = cache.keys().next().cloned() {
436 cache.remove(&oldest_key);
437 }
438 }
439
440 cache.insert(key, Arc::new(compiled));
441
442 let mut stats = self.stats.write().unwrap();
444 stats.size = cache.len();
445 }
446
447 pub fn clear(&self) {
449 let mut cache = self.cache.write().unwrap();
450 cache.clear();
451
452 let mut stats = self.stats.write().unwrap();
453 stats.size = 0;
454 }
455
456 pub fn stats(&self) -> CacheStats {
458 self.stats.read().unwrap().clone()
459 }
460
461 pub fn len(&self) -> usize {
463 self.cache.read().unwrap().len()
464 }
465
466 pub fn is_empty(&self) -> bool {
468 self.len() == 0
469 }
470}
471
472pub trait TlCompilableExecutor {
477 fn compile_graph(
482 &mut self,
483 graph: &EinsumGraph,
484 config: &CompilationConfig,
485 ) -> Result<CompiledGraph, ExecutorError>;
486
487 fn execute_compiled(
492 &mut self,
493 compiled: &CompiledGraph,
494 inputs: &HashMap<usize, Box<dyn std::any::Any>>,
495 ) -> Result<HashMap<usize, Box<dyn std::any::Any>>, ExecutorError>;
496
497 fn supports_compilation(&self) -> bool {
499 true
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Three-node chain over tensors 0 -> 1 -> 2 -> 3, with tensor 0 registered as
    /// the graph input and tensor 3 as the graph output.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        graph.tensors.push("input".to_string());
        graph.inputs.push(0);

        graph.nodes.push(EinsumNode::new("ij->ij", vec![0], vec![1]));
        graph.nodes.push(EinsumNode::new("ij,jk->ik", vec![1], vec![2]));
        graph.nodes.push(EinsumNode::new("ik->ik", vec![2], vec![3]));

        graph.outputs.push(3);

        graph
    }

    /// Default config with only the optimization level overridden.
    fn config_at(level: OptimizationLevel) -> CompilationConfig {
        CompilationConfig {
            optimization_level: level,
            ..Default::default()
        }
    }

    #[test]
    fn test_compilation_key_equality() {
        let config = CompilationConfig::default();
        let key1 = CompilationKey::new(&create_test_graph(), &config);
        let key2 = CompilationKey::new(&create_test_graph(), &config);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_compilation_key_different_graphs() {
        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));

        let config = CompilationConfig::default();
        assert_ne!(
            CompilationKey::new(&graph1, &config),
            CompilationKey::new(&graph2, &config)
        );
    }

    #[test]
    fn test_compilation_key_different_config() {
        let graph = create_test_graph();
        let key1 = CompilationKey::new(&graph, &config_at(OptimizationLevel::Basic));
        let key2 = CompilationKey::new(&graph, &config_at(OptimizationLevel::Aggressive));
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_graph_compiler_basic() {
        let mut compiler = GraphCompiler::new(config_at(OptimizationLevel::Basic));
        let compiled = compiler
            .compile(&create_test_graph())
            .expect("basic compilation should succeed");

        assert!(compiled.is_valid());
        assert_eq!(compiled.stats.original_nodes, 3);
    }

    #[test]
    fn test_graph_compiler_moderate() {
        let mut compiler = GraphCompiler::new(config_at(OptimizationLevel::Moderate));
        let compiled = compiler
            .compile(&create_test_graph())
            .expect("moderate compilation should succeed");

        assert!(compiled.is_valid());
        assert!(compiled.stats.compilation_time > Duration::from_secs(0));
    }

    #[test]
    fn test_graph_compiler_aggressive() {
        let mut compiler = GraphCompiler::new(config_at(OptimizationLevel::Aggressive));
        let compiled = compiler
            .compile(&create_test_graph())
            .expect("aggressive compilation should succeed");

        assert!(compiled.is_valid());
        assert_eq!(compiled.node_count(), compiled.stats.optimized_nodes);
    }

    #[test]
    fn test_compiled_graph_summary() {
        let compiled = GraphCompiler::with_default_config()
            .compile(&create_test_graph())
            .unwrap();

        let summary = compiled.summary();
        for fragment in ["CompiledGraph", "nodes", "MB"] {
            assert!(summary.contains(fragment), "summary missing {:?}", fragment);
        }
    }

    #[test]
    fn test_compilation_cache_basic() {
        let cache = CompilationCache::new(10);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        let graph = create_test_graph();
        let key = CompilationKey::new(&graph, &CompilationConfig::default());

        // Nothing inserted yet.
        assert!(cache.get(&key).is_none());

        let compiled = GraphCompiler::with_default_config().compile(&graph).unwrap();
        cache.insert(key.clone(), compiled);

        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
        assert_eq!(cache.stats().size, 1);

        // The entry just inserted is served back.
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_compilation_cache_eviction() {
        let cache = CompilationCache::new(2);
        let config = CompilationConfig::default();
        let mut compiler = GraphCompiler::with_default_config();

        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
        let mut graph3 = create_test_graph();
        graph3.nodes.push(EinsumNode::new("ij->ji", vec![3], vec![5]));

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);
        let key3 = CompilationKey::new(&graph3, &config);

        cache.insert(key1, compiler.compile(&graph1).unwrap());
        cache.insert(key2, compiler.compile(&graph2).unwrap());
        assert_eq!(cache.len(), 2);

        // A third distinct key forces an eviction to stay within capacity...
        cache.insert(key3.clone(), compiler.compile(&graph3).unwrap());
        assert_eq!(cache.len(), 2);
        // ...and the just-inserted entry is retained.
        assert!(cache.get(&key3).is_some());
    }

    #[test]
    fn test_compilation_cache_stats() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();
        let key = CompilationKey::new(&graph, &CompilationConfig::default());

        let initial = cache.stats();
        assert_eq!(initial.hits, 0);
        assert_eq!(initial.misses, 0);
        assert_eq!(initial.hit_rate(), 0.0);

        // A lookup before insertion is a miss.
        let _ = cache.get(&key);
        assert_eq!(cache.stats().misses, 1);

        let compiled = GraphCompiler::with_default_config().compile(&graph).unwrap();
        cache.insert(key.clone(), compiled);
        let _ = cache.get(&key);

        let after = cache.stats();
        assert_eq!(after.hits, 1);
        assert_eq!(after.misses, 1);
        assert_eq!(after.hit_rate(), 0.5);
    }

    #[test]
    fn test_compilation_cache_clear() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();
        let key = CompilationKey::new(&graph, &CompilationConfig::default());

        let compiled = GraphCompiler::with_default_config().compile(&graph).unwrap();
        cache.insert(key, compiled);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_optimization_levels() {
        let graph = create_test_graph();

        for level in [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Moderate,
            OptimizationLevel::Aggressive,
        ] {
            let result = GraphCompiler::new(config_at(level)).compile(&graph);
            assert!(result.is_ok(), "Compilation failed for level {:?}", level);
            assert!(result.unwrap().is_valid());
        }
    }

    #[test]
    fn test_compiled_graph_memory_estimation() {
        let mut compiler = GraphCompiler::new(CompilationConfig {
            enable_memory_estimation: true,
            ..Default::default()
        });

        let compiled = compiler.compile(&create_test_graph()).unwrap();
        // The aggregate in the stats must agree with the per-entry breakdown.
        assert_eq!(compiled.total_memory(), compiled.stats.estimated_memory_bytes);
    }

    #[test]
    fn test_config_update() {
        let mut compiler = GraphCompiler::with_default_config();

        let new_config = CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            enable_parallelism: false,
            ..Default::default()
        };
        compiler.set_config(new_config);

        let config = compiler.config();
        assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);
        assert!(!config.enable_parallelism);
    }
}