Skip to main content

tensorlogic_infer/
optimization.rs

1//! Graph optimization and fusion detection utilities.
2//!
3//! This module provides utilities for analyzing and optimizing EinsumGraph structures:
4//! - Fusion opportunities detection (combining adjacent operations)
5//! - Dead node elimination
6//! - Redundant computation detection
7//! - Operation reordering for better cache locality
8
9use std::collections::{HashMap, HashSet};
10use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
11
12/// Fusion opportunity between two nodes
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct FusionOpportunity {
15    pub producer_idx: usize,
16    pub consumer_idx: usize,
17    pub fusion_type: FusionType,
18    pub estimated_speedup: u32, // Percentage improvement
19}
20
/// Category of fusion that can be applied to a producer/consumer pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionType {
    /// An element-wise operation feeding another element-wise operation.
    ElementWise,
    /// A reduction whose result feeds an element-wise operation.
    ReductionElementWise,
    /// Several reductions over the same input.
    MultiReduction,
    /// An einsum feeding another einsum with a compatible spec.
    EinsumChain,
}
33
34/// Optimization pass result
35#[derive(Debug, Clone)]
36pub struct OptimizationResult {
37    pub fusion_opportunities: Vec<FusionOpportunity>,
38    pub dead_nodes: Vec<usize>,
39    pub redundant_computations: Vec<(usize, usize)>, // Pairs of equivalent nodes
40    pub estimated_improvement: f64,                  // Overall estimated improvement percentage
41}
42
43impl OptimizationResult {
44    pub fn new() -> Self {
45        OptimizationResult {
46            fusion_opportunities: Vec::new(),
47            dead_nodes: Vec::new(),
48            redundant_computations: Vec::new(),
49            estimated_improvement: 0.0,
50        }
51    }
52
53    pub fn is_empty(&self) -> bool {
54        self.fusion_opportunities.is_empty()
55            && self.dead_nodes.is_empty()
56            && self.redundant_computations.is_empty()
57    }
58
59    pub fn total_opportunities(&self) -> usize {
60        self.fusion_opportunities.len() + self.dead_nodes.len() + self.redundant_computations.len()
61    }
62}
63
64impl Default for OptimizationResult {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
/// Analyzer that detects optimization opportunities in a graph.
pub struct GraphOptimizer {
    /// Whether fusion detection runs.
    enable_fusion: bool,
    /// Whether dead-node detection runs.
    enable_dead_node_elimination: bool,
    /// Whether redundant-computation detection runs.
    enable_redundancy_detection: bool,
    /// Minimum estimated speedup (percent) a fusion needs in order to be reported.
    min_fusion_benefit: u32,
}
77
78impl GraphOptimizer {
79    pub fn new() -> Self {
80        GraphOptimizer {
81            enable_fusion: true,
82            enable_dead_node_elimination: true,
83            enable_redundancy_detection: true,
84            min_fusion_benefit: 10, // Minimum 10% improvement
85        }
86    }
87
88    pub fn with_fusion(mut self, enabled: bool) -> Self {
89        self.enable_fusion = enabled;
90        self
91    }
92
93    pub fn with_dead_node_elimination(mut self, enabled: bool) -> Self {
94        self.enable_dead_node_elimination = enabled;
95        self
96    }
97
98    pub fn with_redundancy_detection(mut self, enabled: bool) -> Self {
99        self.enable_redundancy_detection = enabled;
100        self
101    }
102
103    pub fn with_min_fusion_benefit(mut self, min_benefit: u32) -> Self {
104        self.min_fusion_benefit = min_benefit;
105        self
106    }
107
108    /// Analyze graph and detect optimization opportunities
109    pub fn analyze(&self, graph: &EinsumGraph) -> OptimizationResult {
110        let mut result = OptimizationResult::new();
111
112        // Build dependency information
113        let tensor_producers = self.build_producer_map(graph);
114        let tensor_consumers = self.build_consumer_map(graph);
115
116        // Detect fusion opportunities
117        if self.enable_fusion {
118            result.fusion_opportunities =
119                self.detect_fusion_opportunities(graph, &tensor_producers, &tensor_consumers);
120        }
121
122        // Detect dead nodes
123        if self.enable_dead_node_elimination {
124            result.dead_nodes = self.detect_dead_nodes(graph, &tensor_consumers);
125        }
126
127        // Detect redundant computations
128        if self.enable_redundancy_detection {
129            result.redundant_computations = self.detect_redundant_computations(graph);
130        }
131
132        // Estimate overall improvement
133        result.estimated_improvement = self.estimate_improvement(&result);
134
135        result
136    }
137
138    /// Build map of which node produces each tensor
139    fn build_producer_map(&self, graph: &EinsumGraph) -> HashMap<usize, usize> {
140        let mut producers = HashMap::new();
141        for (node_idx, node) in graph.nodes.iter().enumerate() {
142            for &output_idx in &node.outputs {
143                producers.insert(output_idx, node_idx);
144            }
145        }
146        producers
147    }
148
149    /// Build map of which nodes consume each tensor
150    fn build_consumer_map(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
151        let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
152        for (node_idx, node) in graph.nodes.iter().enumerate() {
153            for &input_idx in &node.inputs {
154                consumers.entry(input_idx).or_default().push(node_idx);
155            }
156        }
157        consumers
158    }
159
160    /// Detect fusion opportunities between adjacent nodes
161    fn detect_fusion_opportunities(
162        &self,
163        graph: &EinsumGraph,
164        tensor_producers: &HashMap<usize, usize>,
165        tensor_consumers: &HashMap<usize, Vec<usize>>,
166    ) -> Vec<FusionOpportunity> {
167        let mut opportunities = Vec::new();
168
169        for (node_idx, node) in graph.nodes.iter().enumerate() {
170            // Check each input to see if it comes from a fusible producer
171            for &input_idx in &node.inputs {
172                if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
173                    // Check if this tensor is only used by current node (enables fusion)
174                    let is_single_use = tensor_consumers
175                        .get(&input_idx)
176                        .map(|consumers| consumers.len() == 1)
177                        .unwrap_or(false);
178
179                    if is_single_use {
180                        if let Some(fusion_type) = self.can_fuse(&graph.nodes[producer_idx], node) {
181                            let estimated_speedup = self.estimate_fusion_speedup(fusion_type);
182                            if estimated_speedup >= self.min_fusion_benefit {
183                                opportunities.push(FusionOpportunity {
184                                    producer_idx,
185                                    consumer_idx: node_idx,
186                                    fusion_type,
187                                    estimated_speedup,
188                                });
189                            }
190                        }
191                    }
192                }
193            }
194        }
195
196        opportunities
197    }
198
199    /// Check if two nodes can be fused
200    fn can_fuse(&self, producer: &EinsumNode, consumer: &EinsumNode) -> Option<FusionType> {
201        match (&producer.op, &consumer.op) {
202            // Element-wise + Element-wise (unary or binary)
203            (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
204            | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
205            | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
206            | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
207                Some(FusionType::ElementWise)
208            }
209
210            // Reduction + Element-wise
211            (OpType::Reduce { .. }, OpType::ElemUnary { .. })
212            | (OpType::Reduce { .. }, OpType::ElemBinary { .. }) => {
213                Some(FusionType::ReductionElementWise)
214            }
215
216            // Einsum chain (simplified detection)
217            (OpType::Einsum { .. }, OpType::Einsum { .. }) => Some(FusionType::EinsumChain),
218
219            _ => None,
220        }
221    }
222
223    /// Estimate speedup from fusion
224    fn estimate_fusion_speedup(&self, fusion_type: FusionType) -> u32 {
225        match fusion_type {
226            FusionType::ElementWise => 40, // High benefit: eliminates memory round-trip
227            FusionType::ReductionElementWise => 25, // Moderate benefit
228            FusionType::MultiReduction => 30, // Good benefit: shared input loading
229            FusionType::EinsumChain => 20, // Lower benefit: complex to fuse
230        }
231    }
232
233    /// Detect nodes whose outputs are never used
234    fn detect_dead_nodes(
235        &self,
236        graph: &EinsumGraph,
237        tensor_consumers: &HashMap<usize, Vec<usize>>,
238    ) -> Vec<usize> {
239        let mut dead_nodes = Vec::new();
240
241        for (node_idx, node) in graph.nodes.iter().enumerate() {
242            // A node is dead if none of its outputs are consumed
243            let all_outputs_unused = node.outputs.iter().all(|&output_idx| {
244                tensor_consumers
245                    .get(&output_idx)
246                    .map(|consumers| consumers.is_empty())
247                    .unwrap_or(true)
248            });
249
250            if all_outputs_unused {
251                dead_nodes.push(node_idx);
252            }
253        }
254
255        dead_nodes
256    }
257
258    /// Detect redundant computations (nodes with identical inputs and operations)
259    fn detect_redundant_computations(&self, graph: &EinsumGraph) -> Vec<(usize, usize)> {
260        let mut redundant_pairs = Vec::new();
261        let mut seen: HashMap<String, Vec<usize>> = HashMap::new();
262
263        for (node_idx, node) in graph.nodes.iter().enumerate() {
264            // Create signature for this node (op + sorted inputs)
265            let mut signature = format!("{:?}", node.op);
266            let mut sorted_inputs = node.inputs.clone();
267            sorted_inputs.sort_unstable();
268            signature.push_str(&format!("{:?}", sorted_inputs));
269
270            // Check if we've seen this signature before
271            if let Some(previous_nodes) = seen.get(&signature) {
272                for &prev_idx in previous_nodes {
273                    redundant_pairs.push((prev_idx, node_idx));
274                }
275            }
276
277            seen.entry(signature).or_default().push(node_idx);
278        }
279
280        redundant_pairs
281    }
282
283    /// Estimate overall improvement percentage
284    fn estimate_improvement(&self, result: &OptimizationResult) -> f64 {
285        let mut total_improvement = 0.0;
286
287        // Add fusion benefits
288        for fusion in &result.fusion_opportunities {
289            total_improvement += fusion.estimated_speedup as f64;
290        }
291
292        // Add dead node elimination (assume 5% per dead node)
293        total_improvement += result.dead_nodes.len() as f64 * 5.0;
294
295        // Add redundancy elimination (assume 10% per redundant pair)
296        total_improvement += result.redundant_computations.len() as f64 * 10.0;
297
298        total_improvement
299    }
300}
301
302impl Default for GraphOptimizer {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
/// Planner that selects which detected fusion opportunities to apply.
pub struct FusionPlanner {
    /// Limit consulted by `plan_fusions`; it currently caps how many fusions
    /// are planned.
    max_fusion_depth: usize,
}
312
313impl FusionPlanner {
314    pub fn new() -> Self {
315        FusionPlanner {
316            max_fusion_depth: 3,
317        }
318    }
319
320    pub fn with_max_depth(mut self, depth: usize) -> Self {
321        self.max_fusion_depth = depth;
322        self
323    }
324
325    /// Plan which fusions to apply (considering dependencies and depth limits)
326    pub fn plan_fusions(&self, opportunities: &[FusionOpportunity]) -> Vec<FusionOpportunity> {
327        let mut planned = Vec::new();
328        let mut fused_nodes = HashSet::new();
329
330        // Sort by estimated speedup (highest first)
331        let mut sorted_ops = opportunities.to_vec();
332        sorted_ops.sort_by(|a, b| b.estimated_speedup.cmp(&a.estimated_speedup));
333
334        for fusion in sorted_ops {
335            // Skip if either node already part of a fusion
336            if fused_nodes.contains(&fusion.producer_idx)
337                || fused_nodes.contains(&fusion.consumer_idx)
338            {
339                continue;
340            }
341
342            // Check depth limit (simplified: just count current chain length)
343            if planned.len() >= self.max_fusion_depth {
344                break;
345            }
346
347            planned.push(fusion.clone());
348            fused_nodes.insert(fusion.producer_idx);
349            fused_nodes.insert(fusion.consumer_idx);
350        }
351
352        planned
353    }
354
355    /// Validate that planned fusions don't conflict
356    pub fn validate_plan(&self, plan: &[FusionOpportunity]) -> bool {
357        let mut used_nodes = HashSet::new();
358
359        for fusion in plan {
360            if used_nodes.contains(&fusion.producer_idx)
361                || used_nodes.contains(&fusion.consumer_idx)
362            {
363                return false;
364            }
365            used_nodes.insert(fusion.producer_idx);
366            used_nodes.insert(fusion.consumer_idx);
367        }
368
369        true
370    }
371}
372
373impl Default for FusionPlanner {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node without metadata.
    fn make_node(inputs: Vec<usize>, outputs: Vec<usize>, op: OpType) -> EinsumNode {
        EinsumNode {
            inputs,
            outputs,
            op,
            metadata: None,
        }
    }

    /// Shorthand for a unary element-wise operation.
    fn unary(name: &'static str) -> OpType {
        OpType::ElemUnary { op: name.into() }
    }

    /// Shorthand for a binary element-wise operation.
    fn binary(name: &'static str) -> OpType {
        OpType::ElemBinary { op: name.into() }
    }

    /// Appends tensor names to the graph in the given order.
    fn add_tensors(graph: &mut EinsumGraph, names: &[&str]) {
        for name in names {
            graph.tensors.push((*name).to_string());
        }
    }

    /// Shorthand for a fusion opportunity literal.
    fn opportunity(
        producer_idx: usize,
        consumer_idx: usize,
        fusion_type: FusionType,
        estimated_speedup: u32,
    ) -> FusionOpportunity {
        FusionOpportunity {
            producer_idx,
            consumer_idx,
            fusion_type,
            estimated_speedup,
        }
    }

    /// Einsum (node 0) -> unary "add" (node 1) -> unary "mul" (node 2).
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        // Tensors 0 and 1 are inputs; tensors 2..=4 are produced by nodes 0..=2.
        add_tensors(&mut graph, &["x", "y", "t2", "t3", "t4"]);

        graph.nodes.push(make_node(
            vec![0, 1],
            vec![2],
            OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
        ));
        graph.nodes.push(make_node(vec![2], vec![3], unary("add")));
        // Fusible with node 1.
        graph.nodes.push(make_node(vec![3], vec![4], unary("mul")));

        graph
    }

    /// The test graph plus node 3, whose output (tensor 5) is never consumed.
    fn create_graph_with_dead_node() -> EinsumGraph {
        let mut graph = create_test_graph();

        add_tensors(&mut graph, &["t5"]);
        graph.nodes.push(make_node(vec![0], vec![5], unary("add")));

        graph
    }

    /// Two identical binary "add" nodes over tensors 0 and 1.
    fn create_graph_with_redundancy() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        add_tensors(&mut graph, &["x", "y", "t2", "t3"]);

        graph.nodes.push(make_node(vec![0, 1], vec![2], binary("add")));
        // Duplicate of node 0.
        graph.nodes.push(make_node(vec![0, 1], vec![3], binary("add")));

        graph
    }

    #[test]
    fn test_optimizer_creation() {
        let optimizer = GraphOptimizer::new();

        assert!(optimizer.enable_fusion);
        assert!(optimizer.enable_dead_node_elimination);
        assert!(optimizer.enable_redundancy_detection);
        assert_eq!(optimizer.min_fusion_benefit, 10);
    }

    #[test]
    fn test_optimizer_builder() {
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_min_fusion_benefit(20);

        assert!(!optimizer.enable_fusion);
        assert!(!optimizer.enable_dead_node_elimination);
        assert_eq!(optimizer.min_fusion_benefit, 20);
    }

    #[test]
    fn test_producer_map() {
        let graph = create_test_graph();
        let producers = GraphOptimizer::new().build_producer_map(&graph);

        // (tensor, producing node)
        let expected: [(usize, usize); 3] = [(2, 0), (3, 1), (4, 2)];
        for &(tensor, node) in expected.iter() {
            assert_eq!(producers.get(&tensor), Some(&node));
        }
    }

    #[test]
    fn test_consumer_map() {
        let graph = create_test_graph();
        let consumers = GraphOptimizer::new().build_consumer_map(&graph);

        // (tensor, its only consuming node)
        let expected: [(usize, usize); 3] = [(0, 0), (2, 1), (3, 2)];
        for &(tensor, node) in expected.iter() {
            assert_eq!(consumers.get(&tensor), Some(&vec![node]));
        }
    }

    #[test]
    fn test_fusion_detection() {
        let graph = create_test_graph();
        let result = GraphOptimizer::new().analyze(&graph);

        // Should detect fusion opportunity between nodes 1 and 2 (element-wise chain)
        assert!(!result.fusion_opportunities.is_empty());
        let first = &result.fusion_opportunities[0];
        assert_eq!(first.fusion_type, FusionType::ElementWise);
        assert!(first.estimated_speedup >= 10);
    }

    #[test]
    fn test_dead_node_detection() {
        let graph = create_graph_with_dead_node();
        let result = GraphOptimizer::new().analyze(&graph);

        // Should detect node 3 as dead (output never consumed)
        assert!(!result.dead_nodes.is_empty());
        assert!(result.dead_nodes.contains(&3));
    }

    #[test]
    fn test_redundancy_detection() {
        let graph = create_graph_with_redundancy();
        let result = GraphOptimizer::new().analyze(&graph);

        // Should detect nodes 0 and 1 as redundant
        assert!(!result.redundant_computations.is_empty());
        assert_eq!(result.redundant_computations[0], (0, 1));
    }

    #[test]
    fn test_optimization_result_empty() {
        let result = OptimizationResult::new();

        assert!(result.is_empty());
        assert_eq!(result.total_opportunities(), 0);
    }

    #[test]
    fn test_optimization_result_nonempty() {
        let mut result = OptimizationResult::new();
        result
            .fusion_opportunities
            .push(opportunity(0, 1, FusionType::ElementWise, 40));
        result.dead_nodes.push(2);

        assert!(!result.is_empty());
        assert_eq!(result.total_opportunities(), 2);
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let optimizer = GraphOptimizer::new();
        let producer = make_node(vec![0], vec![1], unary("add"));
        let consumer = make_node(vec![1], vec![2], unary("mul"));

        assert_eq!(
            optimizer.can_fuse(&producer, &consumer),
            Some(FusionType::ElementWise)
        );
    }

    #[test]
    fn test_fusion_planner_creation() {
        assert_eq!(FusionPlanner::new().max_fusion_depth, 3);
    }

    #[test]
    fn test_fusion_planner_with_depth() {
        assert_eq!(FusionPlanner::new().with_max_depth(5).max_fusion_depth, 5);
    }

    #[test]
    fn test_fusion_planning() {
        let opportunities = vec![
            opportunity(0, 1, FusionType::ElementWise, 40),
            opportunity(2, 3, FusionType::ReductionElementWise, 25),
        ];

        let planner = FusionPlanner::new();
        let plan = planner.plan_fusions(&opportunities);

        assert_eq!(plan.len(), 2);
        assert!(planner.validate_plan(&plan));
    }

    #[test]
    fn test_fusion_planning_with_conflicts() {
        let opportunities = vec![
            opportunity(0, 1, FusionType::ElementWise, 40),
            // Producer 1 conflicts with the previous consumer.
            opportunity(1, 2, FusionType::ElementWise, 35),
        ];

        let plan = FusionPlanner::new().plan_fusions(&opportunities);

        // Should only include the first fusion (higher speedup)
        assert_eq!(plan.len(), 1);
        assert_eq!(plan[0].producer_idx, 0);
    }

    #[test]
    fn test_estimate_improvement() {
        let optimizer = GraphOptimizer::new();

        let mut result = OptimizationResult::new();
        result
            .fusion_opportunities
            .push(opportunity(0, 1, FusionType::ElementWise, 40));
        result.dead_nodes.push(2);
        result.redundant_computations.push((3, 4));

        let improvement = optimizer.estimate_improvement(&result);
        assert!(improvement > 0.0);
        // fusion (40) + dead node (5) + redundant pair (10)
        assert_eq!(improvement, 55.0);
    }

    #[test]
    fn test_disabled_optimizations() {
        let graph = create_graph_with_dead_node();
        let optimizer = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_redundancy_detection(false);

        let result = optimizer.analyze(&graph);

        assert!(result.fusion_opportunities.is_empty());
        assert!(result.dead_nodes.is_empty());
        assert!(result.redundant_computations.is_empty());
    }
}