// tensorlogic_scirs_backend/execution_mode.rs
//! Execution mode abstractions for different execution strategies.
//!
//! This module provides infrastructure for multiple execution modes:
//! - **Eager**: Immediate execution (default, already implemented)
//! - **Graph**: Graph compilation and optimization
//! - **JIT**: Just-in-time compilation (future)

use std::collections::{HashMap, HashSet};
use tensorlogic_ir::{
    fold_constants_aggressive, fuse_elementwise_operations, optimize_layouts, EinsumGraph,
    EinsumNode, OpType,
};

/// Execution mode for the backend.
///
/// Determines how operations submitted to the backend are run; see each
/// variant for the trade-offs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionMode {
    /// Eager execution: operations execute immediately as they're called.
    /// This is the default mode and provides the best debugging experience.
    #[default]
    Eager,

    /// Graph mode: operations are compiled into an optimized graph before execution.
    /// This mode enables graph-level optimizations like operation fusion and memory planning.
    Graph,

    /// JIT mode: operations are compiled to native code at runtime.
    /// This mode provides the best performance but has compilation overhead.
    /// Currently not implemented.
    Jit,
}
32impl ExecutionMode {
33    /// Returns true if this mode is eager execution.
34    pub fn is_eager(&self) -> bool {
35        matches!(self, ExecutionMode::Eager)
36    }
37
38    /// Returns true if this mode requires graph compilation.
39    pub fn requires_compilation(&self) -> bool {
40        matches!(self, ExecutionMode::Graph | ExecutionMode::Jit)
41    }
42
43    /// Returns a human-readable description of this mode.
44    pub fn description(&self) -> &'static str {
45        match self {
46            ExecutionMode::Eager => "Immediate execution with no compilation overhead",
47            ExecutionMode::Graph => "Graph compilation with optimization passes",
48            ExecutionMode::Jit => "Just-in-time compilation to native code",
49        }
50    }
51}
52
53impl std::fmt::Display for ExecutionMode {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            ExecutionMode::Eager => write!(f, "Eager"),
57            ExecutionMode::Graph => write!(f, "Graph"),
58            ExecutionMode::Jit => write!(f, "JIT"),
59        }
60    }
61}
62
/// Compiled graph for optimized execution.
///
/// In Graph mode, the EinsumGraph is analyzed and optimized before execution.
/// This structure holds the compiled representation.
#[derive(Debug, Clone)]
pub struct CompiledGraph {
    /// Original graph, kept unmodified for inspection and debugging.
    pub original: EinsumGraph,

    /// Optimized graph (after passes like fusion, CSE, DCE).
    /// This is the graph that should actually be executed.
    pub optimized: EinsumGraph,

    /// Memory plan for tensor allocation.
    /// `None` when memory planning was disabled in the `OptimizationConfig`.
    pub memory_plan: Option<MemoryPlan>,

    /// Compilation statistics (op counts before/after, timing).
    pub stats: CompilationStats,
}
/// Memory allocation plan for optimized execution.
#[derive(Debug, Clone)]
pub struct MemoryPlan {
    /// Maximum number of tensors alive at any single point of execution.
    pub max_live_tensors: usize,

    /// Peak memory usage estimate (in bytes).
    /// Heuristic only: tensor shapes are unknown at planning time, so an
    /// average per-tensor size is assumed (see `compute_memory_plan`).
    pub peak_memory_bytes: usize,

    /// Tensor reuse opportunities: pairs `(source_tensor, dest_tensor)`
    /// whose lifetimes never overlap and could share an allocation.
    pub reuse_opportunities: Vec<(usize, usize)>,
}
/// Configuration for graph optimization passes.
///
/// Each flag gates one phase of `CompiledGraph::compile_with_config`.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Enable constant folding (phase 1).
    pub enable_constant_folding: bool,

    /// Enable elementwise operation fusion (phase 2).
    pub enable_fusion: bool,

    /// Enable dead code elimination (phase 3).
    pub enable_dce: bool,

    /// Enable common subexpression elimination (phase 4).
    pub enable_cse: bool,

    /// Enable layout optimization (phase 5).
    pub enable_layout_opt: bool,

    /// Enable memory planning (phase 6).
    pub enable_memory_planning: bool,
}
117impl Default for OptimizationConfig {
118    fn default() -> Self {
119        Self {
120            enable_constant_folding: true,
121            enable_fusion: true,
122            enable_dce: true,
123            enable_cse: true,
124            enable_layout_opt: true,
125            enable_memory_planning: true,
126        }
127    }
128}
129
130impl OptimizationConfig {
131    /// Create a new configuration with all optimizations enabled.
132    pub fn aggressive() -> Self {
133        Self::default()
134    }
135
136    /// Create a new configuration with only safe optimizations.
137    pub fn conservative() -> Self {
138        Self {
139            enable_constant_folding: true,
140            enable_fusion: false,
141            enable_dce: true,
142            enable_cse: false,
143            enable_layout_opt: false,
144            enable_memory_planning: false,
145        }
146    }
147
148    /// Create a new configuration with no optimizations.
149    pub fn none() -> Self {
150        Self {
151            enable_constant_folding: false,
152            enable_fusion: false,
153            enable_dce: false,
154            enable_cse: false,
155            enable_layout_opt: false,
156            enable_memory_planning: false,
157        }
158    }
159}
160
/// Statistics from graph compilation.
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Number of operations in the original (pre-optimization) graph.
    pub original_ops: usize,

    /// Number of operations remaining after optimization.
    pub optimized_ops: usize,

    /// Number of operations eliminated (by DCE and CSE combined).
    pub eliminated_ops: usize,

    /// Number of operations fused by elementwise fusion.
    pub fused_ops: usize,

    /// Total compilation time in milliseconds.
    pub compilation_time_ms: f64,
}
180impl CompiledGraph {
181    /// Create a new compiled graph from an EinsumGraph.
182    ///
183    /// This performs optimization passes on the graph.
184    pub fn compile(graph: EinsumGraph) -> Self {
185        Self::compile_with_config(graph, &OptimizationConfig::default())
186    }
187
188    /// Create a new compiled graph with custom optimization configuration.
189    pub fn compile_with_config(graph: EinsumGraph, config: &OptimizationConfig) -> Self {
190        let start = std::time::Instant::now();
191        let original_ops = graph.nodes.len();
192
193        let mut optimized = graph.clone();
194        let mut fused_count = 0;
195        let mut eliminated_count = 0;
196
197        // Phase 1: Constant folding (if enabled)
198        if config.enable_constant_folding {
199            if let Ok(_stats) = fold_constants_aggressive(&mut optimized) {
200                // Constant folding succeeded
201            }
202        }
203
204        // Phase 2: Operation fusion (if enabled)
205        if config.enable_fusion {
206            if let Ok(stats) = fuse_elementwise_operations(&mut optimized) {
207                fused_count = stats.ops_fused;
208            }
209        }
210
211        // Phase 3: Dead code elimination (if enabled)
212        if config.enable_dce {
213            if let Ok(removed) = eliminate_dead_code(&mut optimized) {
214                eliminated_count += removed;
215            }
216        }
217
218        // Phase 4: Common subexpression elimination (if enabled)
219        if config.enable_cse {
220            if let Ok(removed) = eliminate_common_subexpressions(&mut optimized) {
221                eliminated_count += removed;
222            }
223        }
224
225        // Phase 5: Layout optimization (if enabled)
226        if config.enable_layout_opt {
227            if let Ok(_result) = optimize_layouts(&optimized) {
228                // Layout optimization succeeded
229            }
230        }
231
232        let optimized_ops = optimized.nodes.len();
233        let compilation_time_ms = start.elapsed().as_secs_f64() * 1000.0;
234
235        // Phase 6: Memory planning (if enabled)
236        let memory_plan = if config.enable_memory_planning {
237            Some(compute_memory_plan(&optimized))
238        } else {
239            None
240        };
241
242        CompiledGraph {
243            original: graph,
244            optimized,
245            memory_plan,
246            stats: CompilationStats {
247                original_ops,
248                optimized_ops,
249                eliminated_ops: eliminated_count,
250                fused_ops: fused_count,
251                compilation_time_ms,
252            },
253        }
254    }
255
256    /// Get the graph to execute (optimized version).
257    pub fn graph(&self) -> &EinsumGraph {
258        &self.optimized
259    }
260
261    /// Get compilation statistics.
262    pub fn stats(&self) -> &CompilationStats {
263        &self.stats
264    }
265}
266
267impl std::fmt::Display for CompilationStats {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        write!(
270            f,
271            "CompilationStats {{ original: {}, optimized: {}, eliminated: {}, fused: {}, time: {:.2}ms }}",
272            self.original_ops,
273            self.optimized_ops,
274            self.eliminated_ops,
275            self.fused_ops,
276            self.compilation_time_ms
277        )
278    }
279}
280
/// Execution configuration combining mode and device settings.
#[derive(Debug, Clone)]
pub struct ExecutionConfig {
    /// Execution mode (eager, graph, or JIT).
    pub mode: ExecutionMode,

    /// Enable graph optimizations (only applies to Graph mode).
    pub enable_optimizations: bool,

    /// Enable memory planning (only applies to Graph mode).
    pub enable_memory_planning: bool,
}
294impl Default for ExecutionConfig {
295    fn default() -> Self {
296        Self {
297            mode: ExecutionMode::Eager,
298            enable_optimizations: true,
299            enable_memory_planning: true,
300        }
301    }
302}
303
304impl ExecutionConfig {
305    /// Create a new configuration with eager mode.
306    pub fn eager() -> Self {
307        Self {
308            mode: ExecutionMode::Eager,
309            enable_optimizations: false,
310            enable_memory_planning: false,
311        }
312    }
313
314    /// Create a new configuration with graph mode.
315    pub fn graph() -> Self {
316        Self {
317            mode: ExecutionMode::Graph,
318            enable_optimizations: true,
319            enable_memory_planning: true,
320        }
321    }
322
323    /// Enable or disable optimizations.
324    pub fn with_optimizations(mut self, enable: bool) -> Self {
325        self.enable_optimizations = enable;
326        self
327    }
328
329    /// Enable or disable memory planning.
330    pub fn with_memory_planning(mut self, enable: bool) -> Self {
331        self.enable_memory_planning = enable;
332        self
333    }
334}
335
336/// Dead Code Elimination (DCE) - removes unused tensors and nodes.
337fn eliminate_dead_code(graph: &mut EinsumGraph) -> Result<usize, String> {
338    if graph.outputs.is_empty() {
339        return Ok(0);
340    }
341
342    // Track which tensors are live (needed)
343    let mut live_tensors = HashSet::new();
344    let mut worklist: Vec<usize> = graph.outputs.clone();
345
346    // Mark all output tensors as live
347    for &output_idx in &graph.outputs {
348        live_tensors.insert(output_idx);
349    }
350
351    // Build tensor-to-node mapping (which node produces each tensor)
352    let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
353    for (node_idx, node) in graph.nodes.iter().enumerate() {
354        for &output_idx in &node.outputs {
355            tensor_producers.insert(output_idx, node_idx);
356        }
357    }
358
359    // Backward pass: mark all dependencies as live
360    while let Some(tensor_idx) = worklist.pop() {
361        if let Some(&node_idx) = tensor_producers.get(&tensor_idx) {
362            let node = &graph.nodes[node_idx];
363            for &input_idx in &node.inputs {
364                if !live_tensors.contains(&input_idx) {
365                    live_tensors.insert(input_idx);
366                    worklist.push(input_idx);
367                }
368            }
369        }
370    }
371
372    // Remove dead nodes (nodes whose output is not live)
373    let initial_count = graph.nodes.len();
374    let mut nodes_to_keep = Vec::new();
375    for node in &graph.nodes {
376        let all_outputs_live = node
377            .outputs
378            .iter()
379            .any(|out_idx| live_tensors.contains(out_idx));
380        if all_outputs_live {
381            nodes_to_keep.push(node.clone());
382        }
383    }
384
385    graph.nodes = nodes_to_keep;
386    let removed_count = initial_count - graph.nodes.len();
387
388    Ok(removed_count)
389}
390
391/// Common Subexpression Elimination (CSE) - detects and deduplicates identical subgraphs.
392fn eliminate_common_subexpressions(graph: &mut EinsumGraph) -> Result<usize, String> {
393    let mut node_hashes: HashMap<String, usize> = HashMap::new();
394    let mut replacements: HashMap<usize, usize> = HashMap::new();
395    let mut eliminated_count = 0;
396
397    // Build hash for each node (based on operation and inputs)
398    for (node_idx, node) in graph.nodes.iter().enumerate() {
399        let node_hash = compute_node_hash(node);
400
401        if let Some(&existing_idx) = node_hashes.get(&node_hash) {
402            // Found a duplicate - mark for replacement
403            if !node.outputs.is_empty() && !graph.nodes[existing_idx].outputs.is_empty() {
404                let produced_tensor_idx = node.outputs[0];
405                let existing_tensor_idx = graph.nodes[existing_idx].outputs[0];
406                replacements.insert(produced_tensor_idx, existing_tensor_idx);
407                eliminated_count += 1;
408            }
409        } else {
410            node_hashes.insert(node_hash, node_idx);
411        }
412    }
413
414    // Apply replacements (update all node inputs that reference eliminated tensors)
415    if !replacements.is_empty() {
416        for node in &mut graph.nodes {
417            for input_idx in &mut node.inputs {
418                if let Some(&replacement_idx) = replacements.get(input_idx) {
419                    *input_idx = replacement_idx;
420                }
421            }
422        }
423
424        // Update outputs
425        for output_idx in &mut graph.outputs {
426            if let Some(&replacement_idx) = replacements.get(output_idx) {
427                *output_idx = replacement_idx;
428            }
429        }
430    }
431
432    Ok(eliminated_count)
433}
434
435/// Compute a hash for a node based on its operation and inputs.
436fn compute_node_hash(node: &EinsumNode) -> String {
437    let op_str = match &node.op {
438        OpType::Einsum { spec } => format!("einsum:{}", spec),
439        OpType::ElemUnary { op } => format!("unary:{}", op),
440        OpType::ElemBinary { op } => format!("binary:{}", op),
441        OpType::Reduce { op, axes } => format!("reduce:{}:{:?}", op, axes),
442    };
443
444    format!("{}|inputs:{:?}", op_str, node.inputs)
445}
446
447/// Compute memory plan for a graph.
448fn compute_memory_plan(graph: &EinsumGraph) -> MemoryPlan {
449    // Build liveness analysis
450    let total_tensors = graph.tensors.len();
451    let mut live_at_step: Vec<HashSet<usize>> = Vec::new();
452    let mut current_live = HashSet::new();
453
454    // Add input tensors as initially live
455    for &input_idx in &graph.inputs {
456        current_live.insert(input_idx);
457    }
458
459    // Process each node in order
460    for node in &graph.nodes {
461        // Mark outputs as live
462        for &output_idx in &node.outputs {
463            current_live.insert(output_idx);
464        }
465
466        // Check if inputs are still needed later
467        for &input_idx in &node.inputs {
468            let mut still_needed = false;
469            // Check if this input is used by later nodes
470            for later_node in graph.nodes.iter().skip(1) {
471                if later_node.inputs.contains(&input_idx) {
472                    still_needed = true;
473                    break;
474                }
475            }
476            // Check if it's an output
477            if graph.outputs.contains(&input_idx) {
478                still_needed = true;
479            }
480            if !still_needed {
481                current_live.remove(&input_idx);
482            }
483        }
484
485        live_at_step.push(current_live.clone());
486    }
487
488    // Compute max live tensors
489    let max_live_tensors = live_at_step
490        .iter()
491        .map(|live_set| live_set.len())
492        .max()
493        .unwrap_or(0);
494
495    // Estimate peak memory (assuming 8 bytes per element, 1000 elements per tensor on average)
496    let avg_tensor_size = 8 * 1000; // 8KB average
497    let peak_memory_bytes = max_live_tensors * avg_tensor_size;
498
499    // Identify reuse opportunities (tensors with non-overlapping lifetimes)
500    let mut reuse_opportunities = Vec::new();
501    for i in 0..total_tensors {
502        for j in (i + 1)..total_tensors {
503            // Check if lifetimes don't overlap
504            let mut i_live = false;
505            let mut j_live = false;
506            let mut overlap = false;
507
508            for live_set in &live_at_step {
509                let i_in_this = live_set.contains(&i);
510                let j_in_this = live_set.contains(&j);
511
512                if i_in_this {
513                    i_live = true;
514                }
515                if j_in_this {
516                    j_live = true;
517                }
518                if i_in_this && j_in_this {
519                    overlap = true;
520                    break;
521                }
522            }
523
524            if i_live && j_live && !overlap {
525                reuse_opportunities.push((i, j));
526            }
527        }
528    }
529
530    MemoryPlan {
531        max_live_tensors,
532        peak_memory_bytes,
533        reuse_opportunities,
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;

    // --- ExecutionMode ---

    #[test]
    fn test_execution_mode_default() {
        let mode = ExecutionMode::default();
        assert_eq!(mode, ExecutionMode::Eager);
        assert!(mode.is_eager());
        assert!(!mode.requires_compilation());
    }

    #[test]
    fn test_execution_mode_properties() {
        assert!(ExecutionMode::Eager.is_eager());
        assert!(!ExecutionMode::Graph.is_eager());
        assert!(!ExecutionMode::Jit.is_eager());

        assert!(!ExecutionMode::Eager.requires_compilation());
        assert!(ExecutionMode::Graph.requires_compilation());
        assert!(ExecutionMode::Jit.requires_compilation());
    }

    #[test]
    fn test_execution_mode_display() {
        assert_eq!(ExecutionMode::Eager.to_string(), "Eager");
        assert_eq!(ExecutionMode::Graph.to_string(), "Graph");
        assert_eq!(ExecutionMode::Jit.to_string(), "JIT");
    }

    // --- ExecutionConfig ---

    #[test]
    fn test_execution_config_default() {
        let config = ExecutionConfig::default();
        assert_eq!(config.mode, ExecutionMode::Eager);
        assert!(config.enable_optimizations);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_eager() {
        let config = ExecutionConfig::eager();
        assert_eq!(config.mode, ExecutionMode::Eager);
        assert!(!config.enable_optimizations);
        assert!(!config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_graph() {
        let config = ExecutionConfig::graph();
        assert_eq!(config.mode, ExecutionMode::Graph);
        assert!(config.enable_optimizations);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_execution_config_builder() {
        let config = ExecutionConfig::graph()
            .with_optimizations(false)
            .with_memory_planning(false);

        assert_eq!(config.mode, ExecutionMode::Graph);
        assert!(!config.enable_optimizations);
        assert!(!config.enable_memory_planning);
    }

    // --- CompiledGraph / compilation pipeline ---

    #[test]
    fn test_compiled_graph_basic() {
        use tensorlogic_ir::{EinsumNode, OpType};

        // Single relu node a -> b; nothing to optimize away.
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(b_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert_eq!(compiled.stats.original_ops, 1);
        assert_eq!(compiled.stats.optimized_ops, 1);
        assert_eq!(compiled.stats.eliminated_ops, 0);
    }

    #[test]
    fn test_compilation_stats_display() {
        let stats = CompilationStats {
            original_ops: 10,
            optimized_ops: 8,
            eliminated_ops: 2,
            fused_ops: 1,
            compilation_time_ms: 1.5,
        };

        let display = stats.to_string();
        assert!(display.contains("original: 10"));
        assert!(display.contains("optimized: 8"));
        assert!(display.contains("eliminated: 2"));
    }

    // --- OptimizationConfig presets ---

    #[test]
    fn test_optimization_config_default() {
        let config = OptimizationConfig::default();
        assert!(config.enable_constant_folding);
        assert!(config.enable_fusion);
        assert!(config.enable_dce);
        assert!(config.enable_cse);
        assert!(config.enable_layout_opt);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_optimization_config_aggressive() {
        let config = OptimizationConfig::aggressive();
        assert!(config.enable_constant_folding);
        assert!(config.enable_fusion);
        assert!(config.enable_dce);
        assert!(config.enable_cse);
        assert!(config.enable_layout_opt);
        assert!(config.enable_memory_planning);
    }

    #[test]
    fn test_optimization_config_conservative() {
        let config = OptimizationConfig::conservative();
        assert!(config.enable_constant_folding);
        assert!(!config.enable_fusion);
        assert!(config.enable_dce);
        assert!(!config.enable_cse);
        assert!(!config.enable_layout_opt);
        assert!(!config.enable_memory_planning);
    }

    #[test]
    fn test_optimization_config_none() {
        let config = OptimizationConfig::none();
        assert!(!config.enable_constant_folding);
        assert!(!config.enable_fusion);
        assert!(!config.enable_dce);
        assert!(!config.enable_cse);
        assert!(!config.enable_layout_opt);
        assert!(!config.enable_memory_planning);
    }

    #[test]
    fn test_compiled_graph_with_optimization() {
        use tensorlogic_ir::{EinsumNode, OpType};

        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();

        // Add a ReLU node
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        // Add another ReLU node (duplicate for CSE testing)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(b_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert_eq!(compiled.stats.original_ops, 2);
        // Note: The optimized ops might be less if CSE works
        assert!(compiled.stats.compilation_time_ms >= 0.0);
    }

    #[test]
    fn test_compiled_graph_with_custom_config() {
        use tensorlogic_ir::{EinsumNode, OpType};

        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(b_idx).unwrap();

        // With all passes disabled the graph must come through unchanged
        // and no memory plan is produced.
        let config = OptimizationConfig::none();
        let compiled = CompiledGraph::compile_with_config(graph, &config);

        assert_eq!(compiled.stats.original_ops, 1);
        assert_eq!(compiled.stats.optimized_ops, 1);
        assert_eq!(compiled.stats.eliminated_ops, 0);
        assert_eq!(compiled.stats.fused_ops, 0);
        assert!(compiled.memory_plan.is_none());
    }

    // --- Memory planning ---

    #[test]
    fn test_memory_plan_basic() {
        use tensorlogic_ir::{EinsumNode, OpType};

        // Simple chain a -> relu -> b -> sigmoid -> c.
        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "sigmoid".to_string(),
                },
                inputs: vec![b_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();
        graph.add_output(c_idx).unwrap();

        let compiled = CompiledGraph::compile(graph);

        assert!(compiled.memory_plan.is_some());
        let plan = compiled.memory_plan.unwrap();
        assert!(plan.max_live_tensors > 0);
        assert!(plan.peak_memory_bytes > 0);
    }

    // --- DCE / CSE passes directly ---

    #[test]
    fn test_dce_removes_dead_code() {
        use tensorlogic_ir::{EinsumNode, OpType};

        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");
        let d_idx = graph.add_tensor("d");

        graph.add_input(a_idx).unwrap();

        // Node that produces b (will be used)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        // Dead node that produces c (not used)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "sigmoid".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        // Node that uses b to produce d
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "oneminus".to_string(),
                },
                inputs: vec![b_idx],
                outputs: vec![d_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(d_idx).unwrap();

        let initial_nodes = graph.nodes.len();
        let removed = eliminate_dead_code(&mut graph).unwrap();

        // Should remove the dead sigmoid node
        assert!(removed > 0 || graph.nodes.len() < initial_nodes);
    }

    #[test]
    fn test_cse_deduplicates_nodes() {
        use tensorlogic_ir::{EinsumNode, OpType};

        let mut graph = EinsumGraph::new();
        let a_idx = graph.add_tensor("a");
        let b_idx = graph.add_tensor("b");
        let c_idx = graph.add_tensor("c");

        graph.add_input(a_idx).unwrap();

        // First ReLU
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![b_idx],
                metadata: None,
            })
            .unwrap();

        // Duplicate ReLU (same operation, same input)
        graph
            .add_node(EinsumNode {
                op: OpType::ElemUnary {
                    op: "relu".to_string(),
                },
                inputs: vec![a_idx],
                outputs: vec![c_idx],
                metadata: None,
            })
            .unwrap();

        graph.add_output(b_idx).unwrap();
        graph.add_output(c_idx).unwrap();

        let eliminated = eliminate_common_subexpressions(&mut graph).unwrap();

        // Should detect the duplicate (CSE may or may not eliminate it depending on implementation)
        // At minimum, the function should not error
        let _ = eliminated; // Use the value to avoid unused variable warning
    }
}