use std::collections::{HashMap, HashSet};

use tensorlogic_ir::{
    fold_constants_aggressive, fuse_elementwise_operations, optimize_layouts, EinsumGraph,
    EinsumNode, OpType,
};

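/// Strategy for executing a compiled tensor program.
///
/// A minimal sketch of the mode queries, mirroring the unit tests at the
/// bottom of this module (crate path assumed, hence `ignore`):
///
/// ```ignore
/// let mode = ExecutionMode::default();
/// assert!(mode.is_eager());
/// assert!(!mode.requires_compilation());
/// assert!(ExecutionMode::Jit.requires_compilation());
/// ```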
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionMode {
    /// Immediate execution with no compilation overhead.
    #[default]
    Eager,

    /// Compile the whole graph and run optimization passes before executing.
    Graph,

    /// Just-in-time compilation of the graph to native code.
    Jit,
}

impl ExecutionMode {
    /// Returns true if this mode executes operations immediately.
    pub fn is_eager(&self) -> bool {
        matches!(self, ExecutionMode::Eager)
    }

    /// Returns true if this mode requires a compilation step before execution.
    pub fn requires_compilation(&self) -> bool {
        matches!(self, ExecutionMode::Graph | ExecutionMode::Jit)
    }

    /// Human-readable description of the mode.
    pub fn description(&self) -> &'static str {
        match self {
            ExecutionMode::Eager => "Immediate execution with no compilation overhead",
            ExecutionMode::Graph => "Graph compilation with optimization passes",
            ExecutionMode::Jit => "Just-in-time compilation to native code",
        }
    }
}

impl std::fmt::Display for ExecutionMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionMode::Eager => write!(f, "Eager"),
            ExecutionMode::Graph => write!(f, "Graph"),
            ExecutionMode::Jit => write!(f, "JIT"),
        }
    }
}

#[derive(Debug, Clone)]
pub struct CompiledGraph {
    /// The graph as originally provided, before optimization.
    pub original: EinsumGraph,

    /// The graph after the enabled optimization passes have run.
    pub optimized: EinsumGraph,

    /// Memory plan, present only if memory planning was enabled.
    pub memory_plan: Option<MemoryPlan>,

    /// Statistics gathered during compilation.
    pub stats: CompilationStats,
}

#[derive(Debug, Clone)]
pub struct MemoryPlan {
    /// Maximum number of tensors live at any point in the schedule.
    pub max_live_tensors: usize,

    /// Rough estimate of peak memory usage in bytes.
    pub peak_memory_bytes: usize,

    /// Pairs of tensor indices whose lifetimes never overlap, so they
    /// could share a buffer.
    pub reuse_opportunities: Vec<(usize, usize)>,
}

#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Fold constant subexpressions at compile time.
    pub enable_constant_folding: bool,

    /// Fuse chains of elementwise operations.
    pub enable_fusion: bool,

    /// Remove nodes whose outputs are never used (dead code elimination).
    pub enable_dce: bool,

    /// Deduplicate identical computations (common subexpression elimination).
    pub enable_cse: bool,

    /// Optimize tensor memory layouts.
    pub enable_layout_opt: bool,

    /// Compute a memory plan for the optimized graph.
    pub enable_memory_planning: bool,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            enable_constant_folding: true,
            enable_fusion: true,
            enable_dce: true,
            enable_cse: true,
            enable_layout_opt: true,
            enable_memory_planning: true,
        }
    }
}

impl OptimizationConfig {
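    /// All optimization passes enabled. Currently identical to
    /// [`OptimizationConfig::default`]; kept as a separate constructor so
    /// more aggressive passes can be added without changing the default.
    ///
    /// A small sketch of how the presets differ (crate path assumed):
    ///
    /// ```ignore
    /// let fast = OptimizationConfig::aggressive();   // everything on
    /// let safe = OptimizationConfig::conservative(); // folding + DCE only
    /// assert!(fast.enable_fusion && !safe.enable_fusion);
    /// ```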
    pub fn aggressive() -> Self {
        Self::default()
    }

    /// Only the safest passes: constant folding and dead code elimination.
    pub fn conservative() -> Self {
        Self {
            enable_constant_folding: true,
            enable_fusion: false,
            enable_dce: true,
            enable_cse: false,
            enable_layout_opt: false,
            enable_memory_planning: false,
        }
    }

    /// Disable every optimization pass.
    pub fn none() -> Self {
        Self {
            enable_constant_folding: false,
            enable_fusion: false,
            enable_dce: false,
            enable_cse: false,
            enable_layout_opt: false,
            enable_memory_planning: false,
        }
    }
}

#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Number of operations before optimization.
    pub original_ops: usize,

    /// Number of operations after optimization.
    pub optimized_ops: usize,

    /// Operations removed by DCE and CSE.
    pub eliminated_ops: usize,

    /// Operations merged by fusion.
    pub fused_ops: usize,

    /// Wall-clock compilation time in milliseconds.
    pub compilation_time_ms: f64,
}

impl CompiledGraph {
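    /// Compile a graph with the default optimization configuration.
    ///
    /// A minimal usage sketch, mirroring `test_compiled_graph_basic` below
    /// (crate path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let mut graph = EinsumGraph::new();
    /// let a = graph.add_tensor("a");
    /// let b = graph.add_tensor("b");
    /// graph.add_input(a).unwrap();
    /// graph.add_node(EinsumNode {
    ///     op: OpType::ElemUnary { op: "relu".to_string() },
    ///     inputs: vec![a],
    ///     outputs: vec![b],
    ///     metadata: None,
    /// }).unwrap();
    /// graph.add_output(b).unwrap();
    ///
    /// let compiled = CompiledGraph::compile(graph);
    /// assert_eq!(compiled.stats.original_ops, 1);
    /// ```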
    pub fn compile(graph: EinsumGraph) -> Self {
        Self::compile_with_config(graph, &OptimizationConfig::default())
    }

    /// Compile a graph with an explicit optimization configuration.
    pub fn compile_with_config(graph: EinsumGraph, config: &OptimizationConfig) -> Self {
        let start = std::time::Instant::now();
        let original_ops = graph.nodes.len();

        let mut optimized = graph.clone();
        let mut fused_count = 0;
        let mut eliminated_count = 0;

        // Constant folding: a failed pass is skipped, not a hard error.
        if config.enable_constant_folding {
            let _ = fold_constants_aggressive(&mut optimized);
        }

        if config.enable_fusion {
            if let Ok(stats) = fuse_elementwise_operations(&mut optimized) {
                fused_count = stats.ops_fused;
            }
        }

        if config.enable_dce {
            if let Ok(removed) = eliminate_dead_code(&mut optimized) {
                eliminated_count += removed;
            }
        }

        if config.enable_cse {
            if let Ok(removed) = eliminate_common_subexpressions(&mut optimized) {
                eliminated_count += removed;
            }
        }

        // Layout optimization is analysis-only here; its result is not yet
        // applied back to the graph.
        if config.enable_layout_opt {
            let _ = optimize_layouts(&optimized);
        }

        let optimized_ops = optimized.nodes.len();
        let compilation_time_ms = start.elapsed().as_secs_f64() * 1000.0;

        let memory_plan = if config.enable_memory_planning {
            Some(compute_memory_plan(&optimized))
        } else {
            None
        };

        CompiledGraph {
            original: graph,
            optimized,
            memory_plan,
            stats: CompilationStats {
                original_ops,
                optimized_ops,
                eliminated_ops: eliminated_count,
                fused_ops: fused_count,
                compilation_time_ms,
            },
        }
    }

    /// The optimized graph that should actually be executed.
    pub fn graph(&self) -> &EinsumGraph {
        &self.optimized
    }

    /// Statistics gathered during compilation.
    pub fn stats(&self) -> &CompilationStats {
        &self.stats
    }
}

impl std::fmt::Display for CompilationStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CompilationStats {{ original: {}, optimized: {}, eliminated: {}, fused: {}, time: {:.2}ms }}",
            self.original_ops,
            self.optimized_ops,
            self.eliminated_ops,
            self.fused_ops,
            self.compilation_time_ms
        )
    }
}

#[derive(Debug, Clone)]
pub struct ExecutionConfig {
    /// The execution strategy to use.
    pub mode: ExecutionMode,

    /// Whether to run graph optimization passes.
    pub enable_optimizations: bool,

    /// Whether to compute a memory plan.
    pub enable_memory_planning: bool,
}

impl Default for ExecutionConfig {
    fn default() -> Self {
        Self {
            mode: ExecutionMode::Eager,
            enable_optimizations: true,
            enable_memory_planning: true,
        }
    }
}

impl ExecutionConfig {
    /// Eager preset: execute immediately, skip optimization and planning.
    pub fn eager() -> Self {
        Self {
            mode: ExecutionMode::Eager,
            enable_optimizations: false,
            enable_memory_planning: false,
        }
    }

    /// Graph preset: compile with optimizations and memory planning.
    pub fn graph() -> Self {
        Self {
            mode: ExecutionMode::Graph,
            enable_optimizations: true,
            enable_memory_planning: true,
        }
    }

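    /// Override whether optimization passes run.
    ///
    /// A builder-chain sketch, mirroring `test_execution_config_builder`
    /// below (crate path assumed):
    ///
    /// ```ignore
    /// let config = ExecutionConfig::graph()
    ///     .with_optimizations(false)
    ///     .with_memory_planning(false);
    /// assert!(!config.enable_optimizations);
    /// assert!(!config.enable_memory_planning);
    /// ```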
    pub fn with_optimizations(mut self, enable: bool) -> Self {
        self.enable_optimizations = enable;
        self
    }

    /// Override whether memory planning runs.
    pub fn with_memory_planning(mut self, enable: bool) -> Self {
        self.enable_memory_planning = enable;
        self
    }
}

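/// Dead code elimination: remove nodes whose outputs can never reach a graph
/// output.
///
/// Liveness is propagated backwards from `graph.outputs` through each
/// tensor's producing node; a node is kept iff at least one of its outputs
/// is live. Returns the number of nodes removed.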
fn eliminate_dead_code(graph: &mut EinsumGraph) -> Result<usize, String> {
    if graph.outputs.is_empty() {
        return Ok(0);
    }

    // Seed liveness with the graph outputs.
    let mut live_tensors = HashSet::new();
    let mut worklist: Vec<usize> = graph.outputs.clone();

    for &output_idx in &graph.outputs {
        live_tensors.insert(output_idx);
    }

    // Map each tensor to the node that produces it.
    let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
    for (node_idx, node) in graph.nodes.iter().enumerate() {
        for &output_idx in &node.outputs {
            tensor_producers.insert(output_idx, node_idx);
        }
    }

    // Propagate backwards: a live tensor keeps its producer's inputs alive.
    while let Some(tensor_idx) = worklist.pop() {
        if let Some(&node_idx) = tensor_producers.get(&tensor_idx) {
            let node = &graph.nodes[node_idx];
            for &input_idx in &node.inputs {
                if !live_tensors.contains(&input_idx) {
                    live_tensors.insert(input_idx);
                    worklist.push(input_idx);
                }
            }
        }
    }

    // Keep a node if any of its outputs is live.
    let initial_count = graph.nodes.len();
    let mut nodes_to_keep = Vec::new();
    for node in &graph.nodes {
        let any_output_live = node
            .outputs
            .iter()
            .any(|out_idx| live_tensors.contains(out_idx));
        if any_output_live {
            nodes_to_keep.push(node.clone());
        }
    }

    graph.nodes = nodes_to_keep;
    let removed_count = initial_count - graph.nodes.len();

    Ok(removed_count)
}

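/// Common subexpression elimination: deduplicate nodes that apply the same
/// operation to the same inputs.
///
/// The first occurrence becomes the canonical producer; consumers (and graph
/// outputs) referencing a duplicate's result are rewired to it, and the
/// duplicate node is dropped. Returns the number of duplicates eliminated.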
fn eliminate_common_subexpressions(graph: &mut EinsumGraph) -> Result<usize, String> {
    let mut node_hashes: HashMap<String, usize> = HashMap::new();
    let mut replacements: HashMap<usize, usize> = HashMap::new();
    let mut duplicates: HashSet<usize> = HashSet::new();

    for (node_idx, node) in graph.nodes.iter().enumerate() {
        let node_hash = compute_node_hash(node);

        if let Some(&existing_idx) = node_hashes.get(&node_hash) {
            // Only single-output nodes are deduplicated; remapping multiple
            // outputs would require a pairwise tensor mapping.
            if node.outputs.len() == 1 && graph.nodes[existing_idx].outputs.len() == 1 {
                let produced_tensor_idx = node.outputs[0];
                let existing_tensor_idx = graph.nodes[existing_idx].outputs[0];
                replacements.insert(produced_tensor_idx, existing_tensor_idx);
                duplicates.insert(node_idx);
            }
        } else {
            node_hashes.insert(node_hash, node_idx);
        }
    }

    if !replacements.is_empty() {
        // Rewire consumers onto the canonical producer's output.
        for node in &mut graph.nodes {
            for input_idx in &mut node.inputs {
                if let Some(&replacement_idx) = replacements.get(input_idx) {
                    *input_idx = replacement_idx;
                }
            }
        }

        for output_idx in &mut graph.outputs {
            if let Some(&replacement_idx) = replacements.get(output_idx) {
                *output_idx = replacement_idx;
            }
        }

        // Drop the duplicate nodes; nothing references their outputs anymore.
        let mut idx = 0;
        graph.nodes.retain(|_| {
            let keep = !duplicates.contains(&idx);
            idx += 1;
            keep
        });
    }

    Ok(duplicates.len())
}

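/// Structural key for CSE: two nodes with equal keys perform the same
/// operation on the same input tensors, so they compute the same value.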
fn compute_node_hash(node: &EinsumNode) -> String {
    let op_str = match &node.op {
        OpType::Einsum { spec } => format!("einsum:{}", spec),
        OpType::ElemUnary { op } => format!("unary:{}", op),
        OpType::ElemBinary { op } => format!("binary:{}", op),
        OpType::Reduce { op, axes } => format!("reduce:{}:{:?}", op, axes),
    };

    format!("{}|inputs:{:?}", op_str, node.inputs)
}

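/// Plan memory for a graph by sweeping the schedule once and tracking which
/// tensors are live after each node.
///
/// Peak memory is a rough estimate built from a placeholder per-tensor size,
/// since tensor shapes are not tracked at this level; `reuse_opportunities`
/// lists tensor pairs whose lifetimes never overlap and could share a buffer.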
fn compute_memory_plan(graph: &EinsumGraph) -> MemoryPlan {
    let total_tensors = graph.tensors.len();
    let mut live_at_step: Vec<HashSet<usize>> = Vec::new();
    let mut current_live = HashSet::new();

    // Graph inputs are live from the start.
    for &input_idx in &graph.inputs {
        current_live.insert(input_idx);
    }

    for (node_idx, node) in graph.nodes.iter().enumerate() {
        for &output_idx in &node.outputs {
            current_live.insert(output_idx);
        }

        // An input dies after this node unless a later node or a graph
        // output still needs it.
        for &input_idx in &node.inputs {
            let mut still_needed = false;
            for later_node in graph.nodes.iter().skip(node_idx + 1) {
                if later_node.inputs.contains(&input_idx) {
                    still_needed = true;
                    break;
                }
            }
            if graph.outputs.contains(&input_idx) {
                still_needed = true;
            }
            if !still_needed {
                current_live.remove(&input_idx);
            }
        }

        live_at_step.push(current_live.clone());
    }

    let max_live_tensors = live_at_step
        .iter()
        .map(|live_set| live_set.len())
        .max()
        .unwrap_or(0);

    // Placeholder estimate: 1000 elements of 8 bytes per tensor, since
    // actual shapes are unknown at this level.
    let avg_tensor_size = 8 * 1000;
    let peak_memory_bytes = max_live_tensors * avg_tensor_size;

    // Two tensors can share a buffer if both are live at some point, but
    // never at the same time.
    let mut reuse_opportunities = Vec::new();
    for i in 0..total_tensors {
        for j in (i + 1)..total_tensors {
            let mut i_live = false;
            let mut j_live = false;
            let mut overlap = false;

            for live_set in &live_at_step {
                let i_in_this = live_set.contains(&i);
                let j_in_this = live_set.contains(&j);

                if i_in_this {
                    i_live = true;
                }
                if j_in_this {
                    j_live = true;
                }
                if i_in_this && j_in_this {
                    overlap = true;
                    break;
                }
            }

            if i_live && j_live && !overlap {
                reuse_opportunities.push((i, j));
            }
        }
    }

    MemoryPlan {
        max_live_tensors,
        peak_memory_bytes,
        reuse_opportunities,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_mode_default() {
        let mode = ExecutionMode::default();
        assert_eq!(mode, ExecutionMode::Eager);
        assert!(mode.is_eager());
        assert!(!mode.requires_compilation());
    }

    #[test]
    fn test_execution_mode_properties() {
        assert!(ExecutionMode::Eager.is_eager());
        assert!(!ExecutionMode::Graph.is_eager());
        assert!(!ExecutionMode::Jit.is_eager());

        assert!(!ExecutionMode::Eager.requires_compilation());
        assert!(ExecutionMode::Graph.requires_compilation());
        assert!(ExecutionMode::Jit.requires_compilation());
    }

    #[test]
    fn test_execution_mode_display() {
        assert_eq!(ExecutionMode::Eager.to_string(), "Eager");
        assert_eq!(ExecutionMode::Graph.to_string(), "Graph");
        assert_eq!(ExecutionMode::Jit.to_string(), "JIT");
    }

    #[test]
    fn test_execution_config_default() {
        let config = ExecutionConfig::default();
        assert_eq!(config.mode, ExecutionMode::Eager);
        assert!(config.enable_optimizations);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_eager() {
        let config = ExecutionConfig::eager();
        assert_eq!(config.mode, ExecutionMode::Eager);
        assert!(!config.enable_optimizations);
        assert!(!config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_graph() {
        let config = ExecutionConfig::graph();
        assert_eq!(config.mode, ExecutionMode::Graph);
        assert!(config.enable_optimizations);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::graph()
            .with_optimizations(false)
            .with_memory_planning(false);

        assert_eq!(config.mode, ExecutionMode::Graph);
        assert!(!config.enable_optimizations);
        assert!(!config.enable_memory_planning);
    }

    #[test]
    fn test_compiled_graph_basic() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(b_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert_eq!(compiled.stats.original_ops, 1);
        assert_eq!(compiled.stats.optimized_ops, 1);
        assert_eq!(compiled.stats.eliminated_ops, 0);
    }

    #[test]
    fn test_compilation_stats_display() {
        let stats = CompilationStats {
            original_ops: 10,
            optimized_ops: 8,
            eliminated_ops: 2,
            fused_ops: 1,
            compilation_time_ms: 1.5,
        };

        let display = stats.to_string();
        assert!(display.contains("original: 10"));
        assert!(display.contains("optimized: 8"));
        assert!(display.contains("eliminated: 2"));
    }

645
646 #[test]
647 fn test_optimization_config_default() {
648 let config = OptimizationConfig::default();
649 assert!(config.enable_constant_folding);
650 assert!(config.enable_fusion);
651 assert!(config.enable_dce);
652 assert!(config.enable_cse);
653 assert!(config.enable_layout_opt);
654 assert!(config.enable_memory_planning);
655 }
656
657 #[test]
658 fn test_optimization_config_aggressive() {
659 let config = OptimizationConfig::aggressive();
660 assert!(config.enable_constant_folding);
661 assert!(config.enable_fusion);
662 assert!(config.enable_dce);
663 assert!(config.enable_cse);
664 assert!(config.enable_layout_opt);
665 assert!(config.enable_memory_planning);
666 }
667
668 #[test]
669 fn test_optimization_config_conservative() {
670 let config = OptimizationConfig::conservative();
671 assert!(config.enable_constant_folding);
672 assert!(!config.enable_fusion);
673 assert!(config.enable_dce);
674 assert!(!config.enable_cse);
675 assert!(!config.enable_layout_opt);
676 assert!(!config.enable_memory_planning);
677 }
678
679 #[test]
680 fn test_optimization_config_none() {
681 let config = OptimizationConfig::none();
682 assert!(!config.enable_constant_folding);
683 assert!(!config.enable_fusion);
684 assert!(!config.enable_dce);
685 assert!(!config.enable_cse);
686 assert!(!config.enable_layout_opt);
687 assert!(!config.enable_memory_planning);
688 }
689
    #[test]
    fn test_compiled_graph_with_optimization() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();

        // First relu(a), feeding the graph output.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        // Identical relu(a) whose result is never used.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(b_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert_eq!(compiled.stats.original_ops, 2);
        assert!(compiled.stats.compilation_time_ms >= 0.0);
    }

    #[test]
    fn test_compiled_graph_with_custom_config() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(b_idx).unwrap();

        let config = OptimizationConfig::none();
        let compiled = CompiledGraph::compile_with_config(graph, &config);

        assert_eq!(compiled.stats.original_ops, 1);
        assert_eq!(compiled.stats.optimized_ops, 1);
        assert_eq!(compiled.stats.eliminated_ops, 0);
        assert_eq!(compiled.stats.fused_ops, 0);
        assert!(compiled.memory_plan.is_none());
    }

    #[test]
    fn test_memory_plan_basic() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "sigmoid".to_string(),
                },
                inputs: vec![b_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(c_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert!(compiled.memory_plan.is_some());
        let plan = compiled.memory_plan.unwrap();
        assert!(plan.max_live_tensors > 0);
        assert!(plan.peak_memory_bytes > 0);
    }

    #[test]
    fn test_dce_removes_dead_code() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");
        let d_idx = graph.add_tensor("d");

        graph.add_input(a_idx).unwrap();

        // relu(a) -> b, which feeds the output chain.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        // sigmoid(a) -> c, which nothing consumes: dead code.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "sigmoid".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        // oneminus(b) -> d, the graph output.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "oneminus".to_string(),
                },
                inputs: vec![b_idx],
                outputs: vec![d_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(d_idx).unwrap();

        let initial_nodes = graph.nodes.len();
        let removed = eliminate_dead_code(&mut graph).unwrap();

        // Only the sigmoid node is dead.
        assert_eq!(removed, 1);
        assert_eq!(graph.nodes.len(), initial_nodes - 1);
    }

    #[test]
    fn test_cse_deduplicates_nodes() {
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();

        // Two identical relu(a) nodes; CSE should keep only the first.
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(b_idx).unwrap();
        graph.add_output(c_idx).unwrap();

        let eliminated = eliminate_common_subexpressions(&mut graph).unwrap();

        // The duplicate is removed, and both graph outputs now point at the
        // canonical tensor.
        assert_eq!(eliminated, 1);
        assert_eq!(graph.nodes.len(), 1);
        assert!(graph.outputs.iter().all(|&out| out == b_idx));
    }
}