1use crate::error::ExecutorError;
26use crate::memory::MemoryEstimator;
27use crate::optimization::{GraphOptimizer, OptimizationResult};
28use crate::scheduling::{ExecutionSchedule, Scheduler, SchedulingStrategy};
29use crate::shape::ShapeInferenceContext;
30use crate::validation::GraphValidator;
31use std::collections::HashMap;
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime};
34use tensorlogic_ir::EinsumGraph;
35
/// How much optimization effort the compiler should spend on a graph.
///
/// Variants are declared from least to most aggressive, so the derived
/// ordering satisfies `None < Basic < Moderate < Aggressive`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OptimizationLevel {
    /// Skip the optimizer analysis pass entirely.
    None,
    /// Lowest level at which the optimizer analysis pass runs.
    Basic,
    /// Middle level; this is the value returned by `OptimizationLevel::default()`.
    #[default]
    Moderate,
    /// Highest level.
    Aggressive,
}
49
50#[derive(Debug, Clone)]
52pub struct CompilationConfig {
53 pub optimization_level: OptimizationLevel,
55 pub enable_shape_inference: bool,
57 pub enable_memory_estimation: bool,
59 pub target_device: Option<String>,
61 pub memory_budget: Option<usize>,
63 pub enable_caching: bool,
65 pub enable_parallelism: bool,
67}
68
69impl Default for CompilationConfig {
70 fn default() -> Self {
71 CompilationConfig {
72 optimization_level: OptimizationLevel::default(),
73 enable_shape_inference: true,
74 enable_memory_estimation: true,
75 target_device: None,
76 memory_budget: None,
77 enable_caching: true,
78 enable_parallelism: true,
79 }
80 }
81}
82
/// Statistics gathered while compiling a single graph.
///
/// `Default` is derived: every counter starts at zero and `compilation_time`
/// is a zero-length duration.
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Time spent inside the compile call.
    pub compilation_time: Duration,
    /// Number of nodes in the graph handed to the compiler.
    pub original_nodes: usize,
    /// Number of nodes in the graph stored in the compiled result.
    pub optimized_nodes: usize,
    /// Number of fusion opportunities reported by the optimizer analysis.
    pub fusions_applied: usize,
    /// Number of dead nodes reported by the optimizer analysis.
    pub dead_nodes_eliminated: usize,
    /// Sum of the per-node memory estimates, in bytes.
    pub estimated_memory_bytes: usize,
    /// Number of entries in the execution schedule.
    pub execution_steps: usize,
}
115
116#[derive(Debug, Clone)]
121pub struct CompiledGraph {
122 pub graph: EinsumGraph,
124 pub schedule: ExecutionSchedule,
126 pub shapes: HashMap<usize, Vec<usize>>,
128 pub memory_usage: HashMap<usize, usize>,
130 pub config: CompilationConfig,
132 pub stats: CompilationStats,
134 pub compiled_at: SystemTime,
136}
137
138impl CompiledGraph {
139 pub fn node_count(&self) -> usize {
141 self.graph.nodes.len()
142 }
143
144 pub fn total_memory(&self) -> usize {
146 self.memory_usage.values().sum()
147 }
148
149 pub fn is_valid(&self) -> bool {
151 if self.graph.nodes.is_empty() {
153 return false;
154 }
155
156 if self.schedule.execution_order.len() != self.graph.nodes.len() {
158 return false;
159 }
160
161 true
162 }
163
164 pub fn summary(&self) -> String {
166 format!(
167 "CompiledGraph: {} nodes, {} steps, {:.2}MB memory, compiled in {:.2}ms",
168 self.node_count(),
169 self.stats.execution_steps,
170 self.total_memory() as f64 / 1_000_000.0,
171 self.stats.compilation_time.as_secs_f64() * 1000.0
172 )
173 }
174}
175
176pub struct GraphCompiler {
178 config: CompilationConfig,
179 optimizer: GraphOptimizer,
180 validator: GraphValidator,
181 scheduler: Scheduler,
182}
183
184impl GraphCompiler {
185 pub fn new(config: CompilationConfig) -> Self {
187 GraphCompiler {
188 config,
189 optimizer: GraphOptimizer::new(),
190 validator: GraphValidator::new(),
191 scheduler: Scheduler::new(SchedulingStrategy::Balanced),
192 }
193 }
194
195 pub fn with_default_config() -> Self {
197 Self::new(CompilationConfig::default())
198 }
199
200 pub fn compile(&mut self, graph: &EinsumGraph) -> Result<CompiledGraph, ExecutorError> {
202 let start_time = SystemTime::now();
203 let original_nodes = graph.nodes.len();
204
205 let validation_result = self.validator.validate(graph);
207 if !validation_result.is_valid {
208 return Err(ExecutorError::GraphValidationError(format!(
209 "Graph validation failed: {}",
210 validation_result
211 .errors
212 .first()
213 .map(|e| e.as_str())
214 .unwrap_or("unknown error")
215 )));
216 }
217
218 let optimized_graph = graph.clone();
220
221 let opt_result = match self.config.optimization_level {
223 OptimizationLevel::None => OptimizationResult {
224 fusion_opportunities: vec![],
225 dead_nodes: vec![],
226 redundant_computations: vec![],
227 estimated_improvement: 0.0,
228 },
229 OptimizationLevel::Basic
230 | OptimizationLevel::Moderate
231 | OptimizationLevel::Aggressive => {
232 self.optimizer.analyze(&optimized_graph)
234 }
235 };
236
237 let schedule = self.scheduler.schedule(&optimized_graph);
239
240 let shapes = if self.config.enable_shape_inference {
242 let _shape_ctx = ShapeInferenceContext::new();
243 HashMap::new()
246 } else {
247 HashMap::new()
248 };
249
250 let memory_usage = if self.config.enable_memory_estimation {
252 use crate::capabilities::DType;
253 let estimator = MemoryEstimator::new(DType::F32);
254 let estimate = estimator.estimate(&optimized_graph);
255 let mut per_node: HashMap<usize, usize> = HashMap::new();
257 for (idx, mem) in estimate.intermediate_memory.iter().enumerate() {
258 per_node.insert(idx, mem.bytes);
259 }
260 per_node
261 } else {
262 HashMap::new()
263 };
264
265 let compilation_time = start_time.elapsed().unwrap_or(Duration::from_secs(0));
266
267 let stats = CompilationStats {
268 compilation_time,
269 original_nodes,
270 optimized_nodes: optimized_graph.nodes.len(),
271 fusions_applied: opt_result.fusion_opportunities.len(),
272 dead_nodes_eliminated: opt_result.dead_nodes.len(),
273 estimated_memory_bytes: memory_usage.values().sum(),
274 execution_steps: schedule.execution_order.len(),
275 };
276
277 Ok(CompiledGraph {
278 graph: optimized_graph,
279 schedule,
280 shapes,
281 memory_usage,
282 config: self.config.clone(),
283 stats,
284 compiled_at: SystemTime::now(),
285 })
286 }
287
288 pub fn set_config(&mut self, config: CompilationConfig) {
290 self.config = config;
291 }
292
293 pub fn config(&self) -> &CompilationConfig {
295 &self.config
296 }
297}
298
299#[derive(Debug, Clone, PartialEq, Eq, Hash)]
301pub struct CompilationKey {
302 pub graph_hash: u64,
304 pub optimization_level: OptimizationLevel,
306 pub target_device: Option<String>,
308}
309
310impl CompilationKey {
311 pub fn new(graph: &EinsumGraph, config: &CompilationConfig) -> Self {
313 CompilationKey {
314 graph_hash: Self::hash_graph(graph),
315 optimization_level: config.optimization_level,
316 target_device: config.target_device.clone(),
317 }
318 }
319
320 fn hash_graph(graph: &EinsumGraph) -> u64 {
322 use std::collections::hash_map::DefaultHasher;
323 use std::hash::{Hash, Hasher};
324
325 let mut hasher = DefaultHasher::new();
326
327 graph.nodes.len().hash(&mut hasher);
329
330 for node in &graph.nodes {
332 match &node.op {
334 tensorlogic_ir::OpType::Einsum { spec } => {
335 "einsum".hash(&mut hasher);
336 spec.hash(&mut hasher);
337 }
338 tensorlogic_ir::OpType::Reduce { op, axes } => {
339 "reduce".hash(&mut hasher);
340 op.hash(&mut hasher);
341 axes.hash(&mut hasher);
342 }
343 tensorlogic_ir::OpType::ElemUnary { op } => {
344 "elemunary".hash(&mut hasher);
345 op.hash(&mut hasher);
346 }
347 tensorlogic_ir::OpType::ElemBinary { op } => {
348 "elembinary".hash(&mut hasher);
349 op.hash(&mut hasher);
350 }
351 }
352
353 node.inputs.hash(&mut hasher);
355 node.outputs.hash(&mut hasher);
356 }
357
358 hasher.finish()
359 }
360}
361
/// Counters describing how a `CompilationCache` has been used.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: usize,
    /// Number of lookups that found nothing.
    pub misses: usize,
    /// Number of entries in the cache after the most recent insert or clear.
    pub size: usize,
    /// Accumulated compilation time of every entry returned on a hit.
    pub time_saved: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
386
387pub struct CompilationCache {
392 cache: Arc<RwLock<HashMap<CompilationKey, Arc<CompiledGraph>>>>,
393 stats: Arc<RwLock<CacheStats>>,
394 max_size: usize,
395}
396
397impl CompilationCache {
398 pub fn new(max_size: usize) -> Self {
400 CompilationCache {
401 cache: Arc::new(RwLock::new(HashMap::new())),
402 stats: Arc::new(RwLock::new(CacheStats::default())),
403 max_size,
404 }
405 }
406
407 pub fn with_default_size() -> Self {
409 Self::new(100)
410 }
411
412 pub fn get(&self, key: &CompilationKey) -> Option<Arc<CompiledGraph>> {
414 let cache = self.cache.read().expect("lock should not be poisoned");
415 let result = cache.get(key).cloned();
416
417 let mut stats = self.stats.write().expect("lock should not be poisoned");
419 if let Some(ref compiled) = result {
420 stats.hits += 1;
421 stats.time_saved += compiled.stats.compilation_time;
422 } else {
423 stats.misses += 1;
424 }
425
426 result
427 }
428
429 pub fn insert(&self, key: CompilationKey, compiled: CompiledGraph) {
431 let mut cache = self.cache.write().expect("lock should not be poisoned");
432
433 if cache.len() >= self.max_size && !cache.contains_key(&key) {
435 if let Some(oldest_key) = cache.keys().next().cloned() {
436 cache.remove(&oldest_key);
437 }
438 }
439
440 cache.insert(key, Arc::new(compiled));
441
442 let mut stats = self.stats.write().expect("lock should not be poisoned");
444 stats.size = cache.len();
445 }
446
447 pub fn clear(&self) {
449 let mut cache = self.cache.write().expect("lock should not be poisoned");
450 cache.clear();
451
452 let mut stats = self.stats.write().expect("lock should not be poisoned");
453 stats.size = 0;
454 }
455
456 pub fn stats(&self) -> CacheStats {
458 self.stats
459 .read()
460 .expect("lock should not be poisoned")
461 .clone()
462 }
463
464 pub fn len(&self) -> usize {
466 self.cache
467 .read()
468 .expect("lock should not be poisoned")
469 .len()
470 }
471
472 pub fn is_empty(&self) -> bool {
474 self.len() == 0
475 }
476}
477
478pub trait TlCompilableExecutor {
483 fn compile_graph(
488 &mut self,
489 graph: &EinsumGraph,
490 config: &CompilationConfig,
491 ) -> Result<CompiledGraph, ExecutorError>;
492
493 fn execute_compiled(
498 &mut self,
499 compiled: &CompiledGraph,
500 inputs: &HashMap<usize, Box<dyn std::any::Any>>,
501 ) -> Result<HashMap<usize, Box<dyn std::any::Any>>, ExecutorError>;
502
503 fn supports_compilation(&self) -> bool {
505 true
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Builds a three-node chain: tensor 0 -> 1 -> 2 -> 3, with tensor 0 as
    /// the graph input and tensor 3 as the graph output.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        graph.tensors.push("input".to_string());
        graph.inputs.push(0);

        graph
            .nodes
            .push(EinsumNode::new("ij->ij", vec![0], vec![1]));
        graph
            .nodes
            .push(EinsumNode::new("ij,jk->ik", vec![1], vec![2]));
        graph
            .nodes
            .push(EinsumNode::new("ik->ik", vec![2], vec![3]));

        graph.outputs.push(3);

        graph
    }

    /// Compiles `graph` with a compiler using the default configuration.
    fn compile_default(graph: &EinsumGraph) -> CompiledGraph {
        GraphCompiler::with_default_config()
            .compile(graph)
            .expect("test graph should compile with the default config")
    }

    /// Compiles the standard test graph at the given optimization level.
    fn compile_at(level: OptimizationLevel) -> CompiledGraph {
        let mut compiler = GraphCompiler::new(CompilationConfig {
            optimization_level: level,
            ..Default::default()
        });

        compiler
            .compile(&create_test_graph())
            .unwrap_or_else(|err| panic!("Compilation failed for level {:?}: {:?}", level, err))
    }

    #[test]
    fn test_compilation_key_equality() {
        let graph1 = create_test_graph();
        let graph2 = create_test_graph();
        let config = CompilationConfig::default();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_compilation_key_different_graphs() {
        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));

        let config = CompilationConfig::default();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_compilation_key_different_config() {
        let graph = create_test_graph();

        let config1 = CompilationConfig {
            optimization_level: OptimizationLevel::Basic,
            ..Default::default()
        };
        let config2 = CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            ..Default::default()
        };

        let key1 = CompilationKey::new(&graph, &config1);
        let key2 = CompilationKey::new(&graph, &config2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_graph_compiler_basic() {
        let compiled = compile_at(OptimizationLevel::Basic);

        assert!(compiled.is_valid());
        assert_eq!(compiled.stats.original_nodes, 3);
    }

    #[test]
    fn test_graph_compiler_moderate() {
        let compiled = compile_at(OptimizationLevel::Moderate);

        assert!(compiled.is_valid());
        assert!(compiled.stats.compilation_time > Duration::from_secs(0));
    }

    #[test]
    fn test_graph_compiler_aggressive() {
        let compiled = compile_at(OptimizationLevel::Aggressive);

        assert!(compiled.is_valid());
        assert_eq!(compiled.node_count(), compiled.stats.optimized_nodes);
    }

    #[test]
    fn test_compiled_graph_summary() {
        let compiled = compile_default(&create_test_graph());

        let summary = compiled.summary();
        assert!(summary.contains("CompiledGraph"));
        assert!(summary.contains("nodes"));
        assert!(summary.contains("MB"));
    }

    #[test]
    fn test_compilation_cache_basic() {
        let cache = CompilationCache::new(10);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        // Nothing has been inserted yet.
        assert!(cache.get(&key).is_none());

        cache.insert(key.clone(), compile_default(&graph));

        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_compilation_cache_eviction() {
        let cache = CompilationCache::new(2);

        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
        let mut graph3 = create_test_graph();
        graph3
            .nodes
            .push(EinsumNode::new("ij->ji", vec![3], vec![5]));

        let config = CompilationConfig::default();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);
        let key3 = CompilationKey::new(&graph3, &config);

        cache.insert(key1, compile_default(&graph1));
        cache.insert(key2, compile_default(&graph2));
        assert_eq!(cache.len(), 2);

        // A third distinct key must evict one entry rather than grow the cache,
        // and the newly inserted entry must be the one that stays.
        cache.insert(key3.clone(), compile_default(&graph3));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&key3).is_some());
    }

    #[test]
    fn test_compilation_cache_stats() {
        let cache = CompilationCache::new(10);

        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);

        // First lookup misses.
        cache.get(&key);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);

        // After inserting, the same lookup hits.
        cache.insert(key.clone(), compile_default(&graph));
        cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_compilation_cache_clear() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        cache.insert(key, compile_default(&graph));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_optimization_levels() {
        let levels = [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Moderate,
            OptimizationLevel::Aggressive,
        ];

        for level in levels {
            let compiled = compile_at(level);
            assert!(compiled.is_valid());
        }
    }

    #[test]
    fn test_compiled_graph_memory_estimation() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            enable_memory_estimation: true,
            ..Default::default()
        });

        let compiled = compiler
            .compile(&graph)
            .expect("test graph should compile with memory estimation enabled");

        // The total reported by the compiled graph must agree with the figure
        // recorded in its statistics.
        assert_eq!(
            compiled.total_memory(),
            compiled.stats.estimated_memory_bytes
        );
    }

    #[test]
    fn test_config_update() {
        let mut compiler = GraphCompiler::with_default_config();

        let new_config = CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            enable_parallelism: false,
            ..Default::default()
        };

        compiler.set_config(new_config);

        let config = compiler.config();
        assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);
        assert!(!config.enable_parallelism);
    }
}