tensorlogic_scirs_backend/graph_optimizer.rs

//! Graph optimization passes for improved execution performance.
//!
//! This module provides optimization passes that transform `EinsumGraph`s
//! to improve execution performance through constant folding, common
//! subexpression elimination, and subgraph caching.
//!
//! ## Features
//!
//! - **Constant Folding**: Pre-compute operations with constant inputs
//! - **Subgraph Caching**: Cache and reuse repeated subgraph results
//! - **Algebraic Simplification**: Apply mathematical identities
//! - **Dead Code Elimination**: Remove unused operations
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_scirs_backend::graph_optimizer::{GraphOptimizer, OptimizationPass};
//! use tensorlogic_ir::EinsumGraph;
//!
//! let mut optimizer = GraphOptimizer::new();
//! optimizer.add_pass(OptimizationPass::ConstantFolding);
//! optimizer.add_pass(OptimizationPass::SubgraphCaching);
//!
//! let optimized_graph = optimizer.optimize(&graph)?;
//! println!("Optimizations applied: {:?}", optimizer.stats());
//! ```
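//!
//! Preconfigured optimizers are also available; a minimal sketch using the
//! constructors defined below:
//!
//! ```rust,ignore
//! // All standard passes:
//! let optimizer = GraphOptimizer::with_all_passes();
//! // Standard passes plus operation reordering:
//! let aggressive = GraphOptimizer::aggressive();
//! ```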

use crate::{Scirs2Tensor, TlBackendResult};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};

/// Available optimization passes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Pre-compute operations with constant inputs
    ConstantFolding,

    /// Cache and reuse repeated subgraph results
    SubgraphCaching,

    /// Apply mathematical identity simplifications
    AlgebraicSimplification,

    /// Remove operations that produce unused results
    DeadCodeElimination,

    /// Reorder operations for better memory access
    OperationReordering,
}

/// Statistics from optimization passes.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Number of constants folded
    pub constants_folded: usize,

    /// Number of subgraphs cached
    pub subgraphs_cached: usize,

    /// Number of algebraic simplifications
    pub simplifications: usize,

    /// Number of dead operations eliminated
    pub dead_code_eliminated: usize,

    /// Number of operations reordered
    pub operations_reordered: usize,

    /// Total nodes before optimization
    pub nodes_before: usize,

    /// Total nodes after optimization
    pub nodes_after: usize,
}

impl OptimizationStats {
    /// Calculate the reduction percentage.
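    ///
    /// Returns 0.0 when `nodes_before` is zero. A doc sketch of the
    /// arithmetic (10 nodes reduced to 7 is a 30% reduction):
    ///
    /// ```rust,ignore
    /// let stats = OptimizationStats {
    ///     nodes_before: 10,
    ///     nodes_after: 7,
    ///     ..Default::default()
    /// };
    /// assert!((stats.reduction_percentage() - 30.0).abs() < 1e-9);
    /// ```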
    pub fn reduction_percentage(&self) -> f64 {
        if self.nodes_before == 0 {
            0.0
        } else {
            // saturating_sub guards against a pathological pass that grows
            // the graph, which would otherwise underflow the usize subtraction.
            (self.nodes_before.saturating_sub(self.nodes_after) as f64
                / self.nodes_before as f64)
                * 100.0
        }
    }
}

impl std::fmt::Display for OptimizationStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Optimization Statistics:")?;
        writeln!(f, "  Constants folded: {}", self.constants_folded)?;
        writeln!(f, "  Subgraphs cached: {}", self.subgraphs_cached)?;
        writeln!(f, "  Simplifications: {}", self.simplifications)?;
        writeln!(f, "  Dead code eliminated: {}", self.dead_code_eliminated)?;
        writeln!(f, "  Operations reordered: {}", self.operations_reordered)?;
        writeln!(
            f,
            "  Nodes: {} -> {} ({:.1}% reduction)",
            self.nodes_before,
            self.nodes_after,
            self.reduction_percentage()
        )
    }
}

/// Graph optimizer with configurable passes.
pub struct GraphOptimizer {
    /// Enabled optimization passes
    passes: Vec<OptimizationPass>,

    /// Cache for folded constants
    constant_cache: HashMap<usize, Scirs2Tensor>,

    /// Cache for subgraph results
    subgraph_cache: HashMap<u64, usize>,

    /// Statistics from last optimization
    stats: OptimizationStats,
}

impl Default for GraphOptimizer {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphOptimizer {
    /// Create a new optimizer with no passes enabled.
    pub fn new() -> Self {
        Self {
            passes: Vec::new(),
            constant_cache: HashMap::new(),
            subgraph_cache: HashMap::new(),
            stats: OptimizationStats::default(),
        }
    }

    /// Create an optimizer with all standard passes enabled.
    pub fn with_all_passes() -> Self {
        let mut optimizer = Self::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);
        optimizer.add_pass(OptimizationPass::SubgraphCaching);
        optimizer
    }

    /// Create an optimizer with all standard passes plus operation reordering.
    pub fn aggressive() -> Self {
        let mut optimizer = Self::with_all_passes();
        optimizer.add_pass(OptimizationPass::OperationReordering);
        optimizer
    }

    /// Add an optimization pass (duplicates are ignored).
    pub fn add_pass(&mut self, pass: OptimizationPass) {
        if !self.passes.contains(&pass) {
            self.passes.push(pass);
        }
    }

    /// Remove an optimization pass.
    pub fn remove_pass(&mut self, pass: OptimizationPass) {
        self.passes.retain(|p| *p != pass);
    }

    /// Get statistics from the last optimization.
    pub fn stats(&self) -> &OptimizationStats {
        &self.stats
    }

    /// Clear all caches.
    pub fn clear_caches(&mut self) {
        self.constant_cache.clear();
        self.subgraph_cache.clear();
    }

    /// Optimize a graph with all enabled passes.
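    ///
    /// Passes run in the order they were added. Statistics are reset at the
    /// start of each call, so `stats()` reflects only the most recent run.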
    pub fn optimize(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        self.stats = OptimizationStats {
            nodes_before: graph.nodes.len(),
            ..Default::default()
        };

        let mut optimized = graph.clone();

        // `OptimizationPass` is `Copy`; cloning the list sidesteps borrowing
        // `self` immutably while the pass methods borrow it mutably.
        for pass in self.passes.clone() {
            optimized = match pass {
                OptimizationPass::ConstantFolding => self.fold_constants(&optimized)?,
                OptimizationPass::SubgraphCaching => self.cache_subgraphs(&optimized)?,
                OptimizationPass::AlgebraicSimplification => self.simplify_algebra(&optimized)?,
                OptimizationPass::DeadCodeElimination => self.eliminate_dead_code(&optimized)?,
                OptimizationPass::OperationReordering => self.reorder_operations(&optimized)?,
            };
        }

        self.stats.nodes_after = optimized.nodes.len();

        Ok(optimized)
    }

    /// Constant folding pass.
    ///
    /// Currently a counting pass: it identifies fold candidates and records
    /// placeholder cache entries, but does not yet rewrite the graph.
    fn fold_constants(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let result = graph.clone();

        // A node is a folding candidate when every input index points into
        // the initial tensor table (a coarse proxy for "all inputs constant").
        let num_tensors = graph.tensors.len();

        for node in &graph.nodes {
            let all_inputs_constant = node.inputs.iter().all(|&input| input < num_tensors);

            if all_inputs_constant {
                // This node operates only on input tensors - candidate for folding.
                self.stats.constants_folded += 1;
            }

            // Track output indices; the zero arrays are placeholders, not
            // actual folded values.
            for &output in &node.outputs {
                self.constant_cache
                    .entry(output)
                    .or_insert_with(|| scirs2_core::ndarray::ArrayD::zeros(vec![1]));
            }
        }

        Ok(result)
    }

    /// Subgraph caching pass.
    fn cache_subgraphs(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let result = graph.clone();

        // Walk nodes in graph order so the first occurrence of an operation
        // becomes the canonical cached node. (Iterating a HashMap of hashes
        // here would make that choice nondeterministic.)
        let mut hash_to_first: HashMap<u64, usize> = HashMap::new();

        for (idx, node) in graph.nodes.iter().enumerate() {
            let hash = self.compute_node_hash(node);

            if let Some(&existing) = hash_to_first.get(&hash) {
                if existing != idx {
                    // Found a duplicate subgraph; remember which node already
                    // computes this result.
                    self.stats.subgraphs_cached += 1;
                    self.subgraph_cache.insert(hash, existing);
                }
            } else {
                hash_to_first.insert(hash, idx);
            }
        }

        Ok(result)
    }

    /// Compute a hash for a node based on its operation and inputs.
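    ///
    /// Two nodes hash identically exactly when they apply the same operation
    /// to the same input indices; output indices are deliberately excluded,
    /// which is what lets `cache_subgraphs` detect duplicated work.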
    fn compute_node_hash(&self, node: &EinsumNode) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();

        // Hash the operation type
        match &node.op {
            OpType::Einsum { spec } => {
                "einsum".hash(&mut hasher);
                spec.hash(&mut hasher);
            }
            OpType::ElemUnary { op } => {
                "unary".hash(&mut hasher);
                op.hash(&mut hasher);
            }
            OpType::ElemBinary { op } => {
                "binary".hash(&mut hasher);
                op.hash(&mut hasher);
            }
            OpType::Reduce { op, axes } => {
                "reduce".hash(&mut hasher);
                op.hash(&mut hasher);
                axes.hash(&mut hasher);
            }
        }

        // Hash input indices
        node.inputs.hash(&mut hasher);

        hasher.finish()
    }

    /// Algebraic simplification pass.
    fn simplify_algebra(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let mut result = graph.clone();

        for node in &mut result.nodes {
            if self.try_simplify_node(node) {
                self.stats.simplifications += 1;
            }
        }

        Ok(result)
    }

    /// Try to simplify a node using algebraic identities.
    ///
    /// Currently detection-only: it reports whether a known identity applies
    /// but leaves the node unchanged.
    fn try_simplify_node(&self, node: &mut EinsumNode) -> bool {
        match &node.op {
            OpType::ElemBinary { op } => {
                // Candidates for identities like x + 0, x * 1, or x * 0.
                match op.as_str() {
                    "add" | "multiply" | "subtract" => {
                        // Would simplify if one operand were the identity
                        // element, but operand values are not visible here.
                        false
                    }
                    _ => false,
                }
            }
            OpType::Einsum { spec } => {
                // Identity einsums such as "i->i" or "ij->ij" copy their input.
                spec == "i->i" || spec == "ij->ij"
            }
            _ => false,
        }
    }

    /// Dead code elimination pass.
    fn eliminate_dead_code(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        let mut result = graph.clone();

        // Find all used tensor indices, seeded from the graph's declared
        // outputs (more robust than assuming the last node is the root).
        let mut used_tensors: HashSet<usize> = HashSet::new();
        used_tensors.extend(&result.outputs);

        // Work backwards to find all used tensors: if any output of a node
        // is used, mark all of its inputs as used too.
        for node in result.nodes.iter().rev() {
            let outputs_used = node.outputs.iter().any(|o| used_tensors.contains(o));

            if outputs_used {
                for &input in &node.inputs {
                    used_tensors.insert(input);
                }
            }
        }

        // Count and remove dead nodes
        let original_count = result.nodes.len();
        result
            .nodes
            .retain(|n| n.outputs.iter().any(|o| used_tensors.contains(o)));

        self.stats.dead_code_eliminated = original_count - result.nodes.len();

        Ok(result)
    }

    /// Operation reordering pass.
    fn reorder_operations(&mut self, graph: &EinsumGraph) -> TlBackendResult<EinsumGraph> {
        // Placeholder: a full implementation could reorder for memory
        // locality or to reduce intermediate allocations. For now the graph
        // passes through unchanged and `operations_reordered` stays zero.
        let result = graph.clone();
        Ok(result)
    }
}

/// Builder for graph optimization configuration.
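///
/// # Example
///
/// A minimal sketch using only the builder methods defined below:
///
/// ```rust,ignore
/// let optimizer = GraphOptimizerBuilder::new()
///     .with_constant_folding()
///     .with_dead_code_elimination()
///     .build();
/// ```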
pub struct GraphOptimizerBuilder {
    passes: Vec<OptimizationPass>,
}

impl Default for GraphOptimizerBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphOptimizerBuilder {
    /// Create a new builder.
    pub fn new() -> Self {
        Self { passes: Vec::new() }
    }

    /// Enable constant folding.
    pub fn with_constant_folding(mut self) -> Self {
        self.passes.push(OptimizationPass::ConstantFolding);
        self
    }

    /// Enable subgraph caching.
    pub fn with_subgraph_caching(mut self) -> Self {
        self.passes.push(OptimizationPass::SubgraphCaching);
        self
    }

    /// Enable algebraic simplification.
    pub fn with_algebraic_simplification(mut self) -> Self {
        self.passes.push(OptimizationPass::AlgebraicSimplification);
        self
    }

    /// Enable dead code elimination.
    pub fn with_dead_code_elimination(mut self) -> Self {
        self.passes.push(OptimizationPass::DeadCodeElimination);
        self
    }

    /// Enable operation reordering.
    pub fn with_operation_reordering(mut self) -> Self {
        self.passes.push(OptimizationPass::OperationReordering);
        self
    }

    /// Build the optimizer.
    pub fn build(self) -> GraphOptimizer {
        let mut optimizer = GraphOptimizer::new();
        for pass in self.passes {
            optimizer.add_pass(pass);
        }
        optimizer
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_simple_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string(), "z".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0, 1],
                outputs: vec![2],
                op: OpType::ElemBinary {
                    op: "add".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0, 1],
            outputs: vec![2],
            tensor_metadata: HashMap::new(),
        }
    }

    fn create_graph_with_dead_code() -> EinsumGraph {
        EinsumGraph {
            tensors: vec![
                "x".to_string(),
                "y".to_string(),
                "dead".to_string(),
                "result".to_string(),
            ],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![1],
                    outputs: vec![3],
                    op: OpType::ElemUnary {
                        op: "sigmoid".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0, 1],
            outputs: vec![3],
            tensor_metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_optimizer_new() {
        let optimizer = GraphOptimizer::new();
        assert!(optimizer.passes.is_empty());
    }

    #[test]
    fn test_optimizer_with_all_passes() {
        let optimizer = GraphOptimizer::with_all_passes();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::AlgebraicSimplification));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    #[test]
    fn test_add_remove_pass() {
        let mut optimizer = GraphOptimizer::new();

        optimizer.add_pass(OptimizationPass::ConstantFolding);
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));

        optimizer.remove_pass(OptimizationPass::ConstantFolding);
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
    }

    #[test]
    fn test_optimize_empty_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = EinsumGraph {
            tensors: vec![],
            nodes: vec![],
            inputs: vec![],
            outputs: vec![],
            tensor_metadata: HashMap::new(),
        };

        let result = optimizer.optimize(&graph).unwrap();
        assert!(result.nodes.is_empty());
    }

    #[test]
    fn test_optimize_simple_graph() {
        let mut optimizer = GraphOptimizer::with_all_passes();
        let graph = create_simple_graph();

        let result = optimizer.optimize(&graph).unwrap();
        assert_eq!(result.nodes.len(), 1);
    }

    #[test]
    fn test_dead_code_elimination() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);

        let graph = create_graph_with_dead_code();
        let result = optimizer.optimize(&graph).unwrap();

        // Should have eliminated the dead node (the first one; output 2 is never used)
        assert_eq!(optimizer.stats().dead_code_eliminated, 1);
        assert_eq!(result.nodes.len(), 1);
    }

    #[test]
    fn test_optimization_stats() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::DeadCodeElimination);

        let graph = create_graph_with_dead_code();
        optimizer.optimize(&graph).unwrap();

        let stats = optimizer.stats();
        assert_eq!(stats.nodes_before, 2);
        assert_eq!(stats.nodes_after, 1);
        assert!((stats.reduction_percentage() - 50.0).abs() < 0.1);
    }

    #[test]
    fn test_builder() {
        let optimizer = GraphOptimizerBuilder::new()
            .with_constant_folding()
            .with_dead_code_elimination()
            .build();

        assert!(optimizer
            .passes
            .contains(&OptimizationPass::ConstantFolding));
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::DeadCodeElimination));
        assert!(!optimizer
            .passes
            .contains(&OptimizationPass::SubgraphCaching));
    }

    #[test]
    fn test_clear_caches() {
        let mut optimizer = GraphOptimizer::new();
        optimizer
            .constant_cache
            .insert(0, scirs2_core::ndarray::ArrayD::zeros(vec![1]));

        assert!(!optimizer.constant_cache.is_empty());
        optimizer.clear_caches();
        assert!(optimizer.constant_cache.is_empty());
    }

    #[test]
    fn test_aggressive_optimizer() {
        let optimizer = GraphOptimizer::aggressive();
        assert!(optimizer
            .passes
            .contains(&OptimizationPass::OperationReordering));
    }

    #[test]
    fn test_stats_display() {
        let stats = OptimizationStats {
            constants_folded: 5,
            subgraphs_cached: 3,
            simplifications: 2,
            dead_code_eliminated: 1,
            operations_reordered: 0,
            nodes_before: 10,
            nodes_after: 7,
        };

        let display = format!("{}", stats);
        assert!(display.contains("Constants folded: 5"));
        assert!(display.contains("30.0% reduction"));
    }

    #[test]
    fn test_subgraph_caching() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::SubgraphCaching);

        // Graph with duplicate operations
        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y1".to_string(), "y2".to_string()],
            nodes: vec![
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![1],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
                EinsumNode {
                    inputs: vec![0],
                    outputs: vec![2],
                    op: OpType::ElemUnary {
                        op: "relu".to_string(),
                    },
                    metadata: None,
                },
            ],
            inputs: vec![0],
            outputs: vec![1, 2],
            tensor_metadata: HashMap::new(),
        };

        let _result = optimizer.optimize(&graph).unwrap();
        // Both nodes apply the same operation to the same input - should be cached
        assert!(optimizer.stats().subgraphs_cached > 0);
    }

    #[test]
    fn test_algebraic_simplification() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::AlgebraicSimplification);

        let graph = EinsumGraph {
            tensors: vec!["x".to_string(), "y".to_string()],
            nodes: vec![EinsumNode {
                inputs: vec![0],
                outputs: vec![1],
                op: OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            }],
            inputs: vec![0],
            outputs: vec![1],
            tensor_metadata: HashMap::new(),
        };

        let _result = optimizer.optimize(&graph).unwrap();
        // The identity einsum should be detected as a simplification
        assert!(optimizer.stats().simplifications > 0);
    }
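
    // Extra illustrative checks (added sketches exercising only the APIs
    // defined in this module).
    #[test]
    fn test_reduction_percentage_empty() {
        // With no nodes before optimization, the reduction is defined as 0%.
        let stats = OptimizationStats::default();
        assert_eq!(stats.reduction_percentage(), 0.0);
    }

    #[test]
    fn test_constant_folding_counts_candidates() {
        let mut optimizer = GraphOptimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);

        // The simple graph's single node reads only tensors from the initial
        // tensor table, so it is counted as a folding candidate.
        let graph = create_simple_graph();
        optimizer.optimize(&graph).unwrap();
        assert_eq!(optimizer.stats().constants_folded, 1);
    }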
}