Skip to main content

tensorlogic_infer/
optimization.rs

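//! Graph optimization and fusion detection utilities.
//!
//! This module provides utilities for analyzing and optimizing `EinsumGraph` structures:
//! - Fusion opportunity detection (combining adjacent producer/consumer operations)
//! - Dead node detection (nodes whose outputs are never consumed)
//! - Redundant computation detection (nodes repeating the same operation on the same inputs)
//! - Fusion planning (choosing a non-conflicting set of fusions to apply)
//! - Operation reordering for better cache locality (not yet implemented here)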
8
9use std::cmp::Reverse;
10use std::collections::{HashMap, HashSet};
11use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
12
13/// Fusion opportunity between two nodes
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct FusionOpportunity {
16    pub producer_idx: usize,
17    pub consumer_idx: usize,
18    pub fusion_type: FusionType,
19    pub estimated_speedup: u32, // Percentage improvement
20}
21
/// Type of fusion that can be applied.
///
/// Derives `Hash` (in addition to `Eq`) so fusion kinds can be used as keys in
/// hash maps and sets, e.g. when grouping opportunities by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionType {
    /// Element-wise producer feeding an element-wise consumer (unary or binary).
    ElementWise,
    /// Reduction whose result feeds an element-wise operation.
    ReductionElementWise,
    /// Multiple reductions on the same input.
    ///
    /// This describes sibling reductions rather than a producer/consumer pair, and
    /// `GraphOptimizer`'s producer/consumer analysis does not currently emit it.
    MultiReduction,
    /// Einsum feeding another einsum.
    ///
    /// Spec compatibility is not checked by `GraphOptimizer`; any einsum-to-einsum
    /// edge is reported with this kind.
    EinsumChain,
}
34
35/// Optimization pass result
36#[derive(Debug, Clone)]
37pub struct OptimizationResult {
38    pub fusion_opportunities: Vec<FusionOpportunity>,
39    pub dead_nodes: Vec<usize>,
40    pub redundant_computations: Vec<(usize, usize)>, // Pairs of equivalent nodes
41    pub estimated_improvement: f64,                  // Overall estimated improvement percentage
42}
43
44impl OptimizationResult {
45    pub fn new() -> Self {
46        OptimizationResult {
47            fusion_opportunities: Vec::new(),
48            dead_nodes: Vec::new(),
49            redundant_computations: Vec::new(),
50            estimated_improvement: 0.0,
51        }
52    }
53
54    pub fn is_empty(&self) -> bool {
55        self.fusion_opportunities.is_empty()
56            && self.dead_nodes.is_empty()
57            && self.redundant_computations.is_empty()
58    }
59
60    pub fn total_opportunities(&self) -> usize {
61        self.fusion_opportunities.len() + self.dead_nodes.len() + self.redundant_computations.len()
62    }
63}
64
65impl Default for OptimizationResult {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
/// Graph optimizer for detecting optimization opportunities.
///
/// Configure it with the `with_*` builder methods, then call `analyze`.
/// Derives `Debug` and `Clone` so configurations can be logged and reused.
#[derive(Debug, Clone)]
pub struct GraphOptimizer {
    /// Whether fusion opportunities are detected.
    enable_fusion: bool,
    /// Whether nodes with unconsumed outputs are reported.
    enable_dead_node_elimination: bool,
    /// Whether duplicated computations are reported.
    enable_redundancy_detection: bool,
    /// Minimum estimated speedup, in percent, for a fusion to be reported.
    min_fusion_benefit: u32,
}
78
79impl GraphOptimizer {
80    pub fn new() -> Self {
81        GraphOptimizer {
82            enable_fusion: true,
83            enable_dead_node_elimination: true,
84            enable_redundancy_detection: true,
85            min_fusion_benefit: 10, // Minimum 10% improvement
86        }
87    }
88
89    pub fn with_fusion(mut self, enabled: bool) -> Self {
90        self.enable_fusion = enabled;
91        self
92    }
93
94    pub fn with_dead_node_elimination(mut self, enabled: bool) -> Self {
95        self.enable_dead_node_elimination = enabled;
96        self
97    }
98
99    pub fn with_redundancy_detection(mut self, enabled: bool) -> Self {
100        self.enable_redundancy_detection = enabled;
101        self
102    }
103
104    pub fn with_min_fusion_benefit(mut self, min_benefit: u32) -> Self {
105        self.min_fusion_benefit = min_benefit;
106        self
107    }
108
109    /// Analyze graph and detect optimization opportunities
110    pub fn analyze(&self, graph: &EinsumGraph) -> OptimizationResult {
111        let mut result = OptimizationResult::new();
112
113        // Build dependency information
114        let tensor_producers = self.build_producer_map(graph);
115        let tensor_consumers = self.build_consumer_map(graph);
116
117        // Detect fusion opportunities
118        if self.enable_fusion {
119            result.fusion_opportunities =
120                self.detect_fusion_opportunities(graph, &tensor_producers, &tensor_consumers);
121        }
122
123        // Detect dead nodes
124        if self.enable_dead_node_elimination {
125            result.dead_nodes = self.detect_dead_nodes(graph, &tensor_consumers);
126        }
127
128        // Detect redundant computations
129        if self.enable_redundancy_detection {
130            result.redundant_computations = self.detect_redundant_computations(graph);
131        }
132
133        // Estimate overall improvement
134        result.estimated_improvement = self.estimate_improvement(&result);
135
136        result
137    }
138
139    /// Build map of which node produces each tensor
140    fn build_producer_map(&self, graph: &EinsumGraph) -> HashMap<usize, usize> {
141        let mut producers = HashMap::new();
142        for (node_idx, node) in graph.nodes.iter().enumerate() {
143            for &output_idx in &node.outputs {
144                producers.insert(output_idx, node_idx);
145            }
146        }
147        producers
148    }
149
150    /// Build map of which nodes consume each tensor
151    fn build_consumer_map(&self, graph: &EinsumGraph) -> HashMap<usize, Vec<usize>> {
152        let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
153        for (node_idx, node) in graph.nodes.iter().enumerate() {
154            for &input_idx in &node.inputs {
155                consumers.entry(input_idx).or_default().push(node_idx);
156            }
157        }
158        consumers
159    }
160
161    /// Detect fusion opportunities between adjacent nodes
162    fn detect_fusion_opportunities(
163        &self,
164        graph: &EinsumGraph,
165        tensor_producers: &HashMap<usize, usize>,
166        tensor_consumers: &HashMap<usize, Vec<usize>>,
167    ) -> Vec<FusionOpportunity> {
168        let mut opportunities = Vec::new();
169
170        for (node_idx, node) in graph.nodes.iter().enumerate() {
171            // Check each input to see if it comes from a fusible producer
172            for &input_idx in &node.inputs {
173                if let Some(&producer_idx) = tensor_producers.get(&input_idx) {
174                    // Check if this tensor is only used by current node (enables fusion)
175                    let is_single_use = tensor_consumers
176                        .get(&input_idx)
177                        .map(|consumers| consumers.len() == 1)
178                        .unwrap_or(false);
179
180                    if is_single_use {
181                        if let Some(fusion_type) = self.can_fuse(&graph.nodes[producer_idx], node) {
182                            let estimated_speedup = self.estimate_fusion_speedup(fusion_type);
183                            if estimated_speedup >= self.min_fusion_benefit {
184                                opportunities.push(FusionOpportunity {
185                                    producer_idx,
186                                    consumer_idx: node_idx,
187                                    fusion_type,
188                                    estimated_speedup,
189                                });
190                            }
191                        }
192                    }
193                }
194            }
195        }
196
197        opportunities
198    }
199
200    /// Check if two nodes can be fused
201    fn can_fuse(&self, producer: &EinsumNode, consumer: &EinsumNode) -> Option<FusionType> {
202        match (&producer.op, &consumer.op) {
203            // Element-wise + Element-wise (unary or binary)
204            (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
205            | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
206            | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
207            | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
208                Some(FusionType::ElementWise)
209            }
210
211            // Reduction + Element-wise
212            (OpType::Reduce { .. }, OpType::ElemUnary { .. })
213            | (OpType::Reduce { .. }, OpType::ElemBinary { .. }) => {
214                Some(FusionType::ReductionElementWise)
215            }
216
217            // Einsum chain (simplified detection)
218            (OpType::Einsum { .. }, OpType::Einsum { .. }) => Some(FusionType::EinsumChain),
219
220            _ => None,
221        }
222    }
223
224    /// Estimate speedup from fusion
225    fn estimate_fusion_speedup(&self, fusion_type: FusionType) -> u32 {
226        match fusion_type {
227            FusionType::ElementWise => 40, // High benefit: eliminates memory round-trip
228            FusionType::ReductionElementWise => 25, // Moderate benefit
229            FusionType::MultiReduction => 30, // Good benefit: shared input loading
230            FusionType::EinsumChain => 20, // Lower benefit: complex to fuse
231        }
232    }
233
234    /// Detect nodes whose outputs are never used
235    fn detect_dead_nodes(
236        &self,
237        graph: &EinsumGraph,
238        tensor_consumers: &HashMap<usize, Vec<usize>>,
239    ) -> Vec<usize> {
240        let mut dead_nodes = Vec::new();
241
242        for (node_idx, node) in graph.nodes.iter().enumerate() {
243            // A node is dead if none of its outputs are consumed
244            let all_outputs_unused = node.outputs.iter().all(|&output_idx| {
245                tensor_consumers
246                    .get(&output_idx)
247                    .map(|consumers| consumers.is_empty())
248                    .unwrap_or(true)
249            });
250
251            if all_outputs_unused {
252                dead_nodes.push(node_idx);
253            }
254        }
255
256        dead_nodes
257    }
258
259    /// Detect redundant computations (nodes with identical inputs and operations)
260    fn detect_redundant_computations(&self, graph: &EinsumGraph) -> Vec<(usize, usize)> {
261        let mut redundant_pairs = Vec::new();
262        let mut seen: HashMap<String, Vec<usize>> = HashMap::new();
263
264        for (node_idx, node) in graph.nodes.iter().enumerate() {
265            // Create signature for this node (op + sorted inputs)
266            let mut signature = format!("{:?}", node.op);
267            let mut sorted_inputs = node.inputs.clone();
268            sorted_inputs.sort_unstable();
269            signature.push_str(&format!("{:?}", sorted_inputs));
270
271            // Check if we've seen this signature before
272            if let Some(previous_nodes) = seen.get(&signature) {
273                for &prev_idx in previous_nodes {
274                    redundant_pairs.push((prev_idx, node_idx));
275                }
276            }
277
278            seen.entry(signature).or_default().push(node_idx);
279        }
280
281        redundant_pairs
282    }
283
284    /// Estimate overall improvement percentage
285    fn estimate_improvement(&self, result: &OptimizationResult) -> f64 {
286        let mut total_improvement = 0.0;
287
288        // Add fusion benefits
289        for fusion in &result.fusion_opportunities {
290            total_improvement += fusion.estimated_speedup as f64;
291        }
292
293        // Add dead node elimination (assume 5% per dead node)
294        total_improvement += result.dead_nodes.len() as f64 * 5.0;
295
296        // Add redundancy elimination (assume 10% per redundant pair)
297        total_improvement += result.redundant_computations.len() as f64 * 10.0;
298
299        total_improvement
300    }
301}
302
303impl Default for GraphOptimizer {
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
/// Fusion planner that selects which fusion opportunities to apply.
///
/// It only chooses a non-conflicting subset of opportunities; it does not rewrite
/// the graph. Derives `Debug` and `Clone` so planner settings can be logged and reused.
#[derive(Debug, Clone)]
pub struct FusionPlanner {
    /// Upper bound on the number of fusions selected by `plan_fusions`.
    max_fusion_depth: usize,
}
313
314impl FusionPlanner {
315    pub fn new() -> Self {
316        FusionPlanner {
317            max_fusion_depth: 3,
318        }
319    }
320
321    pub fn with_max_depth(mut self, depth: usize) -> Self {
322        self.max_fusion_depth = depth;
323        self
324    }
325
326    /// Plan which fusions to apply (considering dependencies and depth limits)
327    pub fn plan_fusions(&self, opportunities: &[FusionOpportunity]) -> Vec<FusionOpportunity> {
328        let mut planned = Vec::new();
329        let mut fused_nodes = HashSet::new();
330
331        // Sort by estimated speedup (highest first)
332        let mut sorted_ops = opportunities.to_vec();
333        sorted_ops.sort_by_key(|b| Reverse(b.estimated_speedup));
334
335        for fusion in sorted_ops {
336            // Skip if either node already part of a fusion
337            if fused_nodes.contains(&fusion.producer_idx)
338                || fused_nodes.contains(&fusion.consumer_idx)
339            {
340                continue;
341            }
342
343            // Check depth limit (simplified: just count current chain length)
344            if planned.len() >= self.max_fusion_depth {
345                break;
346            }
347
348            planned.push(fusion.clone());
349            fused_nodes.insert(fusion.producer_idx);
350            fused_nodes.insert(fusion.consumer_idx);
351        }
352
353        planned
354    }
355
356    /// Validate that planned fusions don't conflict
357    pub fn validate_plan(&self, plan: &[FusionOpportunity]) -> bool {
358        let mut used_nodes = HashSet::new();
359
360        for fusion in plan {
361            if used_nodes.contains(&fusion.producer_idx)
362                || used_nodes.contains(&fusion.consumer_idx)
363            {
364                return false;
365            }
366            used_nodes.insert(fusion.producer_idx);
367            used_nodes.insert(fusion.consumer_idx);
368        }
369
370        true
371    }
372}
373
374impl Default for FusionPlanner {
375    fn default() -> Self {
376        Self::new()
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata-free node from borrowed index lists.
    fn make_node(inputs: &[usize], outputs: &[usize], op: OpType) -> EinsumNode {
        EinsumNode {
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
            op,
            metadata: None,
        }
    }

    /// Appends a node; node indices follow insertion order.
    fn push_node(graph: &mut EinsumGraph, inputs: &[usize], outputs: &[usize], op: OpType) {
        graph.nodes.push(make_node(inputs, outputs, op));
    }

    /// Appends named tensors; tensor indices follow insertion order.
    fn push_tensors(graph: &mut EinsumGraph, names: &[&str]) {
        for name in names {
            graph.tensors.push((*name).to_string());
        }
    }

    fn unary(op: &'static str) -> OpType {
        OpType::ElemUnary { op: op.into() }
    }

    fn binary(op: &'static str) -> OpType {
        OpType::ElemBinary { op: op.into() }
    }

    fn opportunity(
        producer_idx: usize,
        consumer_idx: usize,
        fusion_type: FusionType,
        estimated_speedup: u32,
    ) -> FusionOpportunity {
        FusionOpportunity {
            producer_idx,
            consumer_idx,
            fusion_type,
            estimated_speedup,
        }
    }

    /// x, y --einsum(ab,bc->ac)--> t2 --add--> t3 --mul--> t4
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        push_tensors(&mut graph, &["x", "y", "t2", "t3", "t4"]);

        push_node(
            &mut graph,
            &[0, 1],
            &[2],
            OpType::Einsum {
                spec: "ab,bc->ac".into(),
            },
        );
        push_node(&mut graph, &[2], &[3], unary("add"));
        // Element-wise on t3, fusible with node 1.
        push_node(&mut graph, &[3], &[4], unary("mul"));

        graph
    }

    /// The test graph plus node 3, whose output t5 is never read.
    fn create_graph_with_dead_node() -> EinsumGraph {
        let mut graph = create_test_graph();
        push_tensors(&mut graph, &["t5"]);
        push_node(&mut graph, &[0], &[5], unary("add"));
        graph
    }

    /// Nodes 0 and 1 both compute add(x, y), writing t2 and t3 respectively.
    fn create_graph_with_redundancy() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        push_tensors(&mut graph, &["x", "y", "t2", "t3"]);
        for &output in &[2usize, 3] {
            push_node(&mut graph, &[0, 1], &[output], binary("add"));
        }
        graph
    }

    #[test]
    fn test_optimizer_creation() {
        let GraphOptimizer {
            enable_fusion,
            enable_dead_node_elimination,
            enable_redundancy_detection,
            min_fusion_benefit,
        } = GraphOptimizer::new();

        assert!(enable_fusion);
        assert!(enable_dead_node_elimination);
        assert!(enable_redundancy_detection);
        assert_eq!(min_fusion_benefit, 10);
    }

    #[test]
    fn test_optimizer_builder() {
        let configured = GraphOptimizer::new()
            .with_min_fusion_benefit(20)
            .with_dead_node_elimination(false)
            .with_fusion(false);

        assert!(!configured.enable_fusion);
        assert!(!configured.enable_dead_node_elimination);
        assert_eq!(configured.min_fusion_benefit, 20);
    }

    #[test]
    fn test_producer_map() {
        let producers = GraphOptimizer::new().build_producer_map(&create_test_graph());

        for &(tensor, node) in &[(2usize, 0usize), (3, 1), (4, 2)] {
            assert_eq!(producers.get(&tensor), Some(&node), "producer of tensor {}", tensor);
        }
    }

    #[test]
    fn test_consumer_map() {
        let consumers = GraphOptimizer::new().build_consumer_map(&create_test_graph());

        for &(tensor, node) in &[(0usize, 0usize), (2, 1), (3, 2)] {
            assert_eq!(consumers.get(&tensor), Some(&vec![node]), "consumers of tensor {}", tensor);
        }
    }

    #[test]
    fn test_fusion_detection() {
        let result = GraphOptimizer::new().analyze(&create_test_graph());

        // The element-wise chain (node 1 -> node 2) must be reported.
        let first = result
            .fusion_opportunities
            .first()
            .expect("expected at least one fusion opportunity");
        assert_eq!(first.fusion_type, FusionType::ElementWise);
        assert!(first.estimated_speedup >= 10);
    }

    #[test]
    fn test_dead_node_detection() {
        let dead = GraphOptimizer::new()
            .analyze(&create_graph_with_dead_node())
            .dead_nodes;

        // Node 3's output is never consumed.
        assert!(!dead.is_empty());
        assert!(dead.contains(&3));
    }

    #[test]
    fn test_redundancy_detection() {
        let pairs = GraphOptimizer::new()
            .analyze(&create_graph_with_redundancy())
            .redundant_computations;

        assert!(!pairs.is_empty());
        assert_eq!(pairs.first(), Some(&(0, 1)));
    }

    #[test]
    fn test_optimization_result_empty() {
        let empty = OptimizationResult::new();
        assert!(empty.is_empty());
        assert_eq!(empty.total_opportunities(), 0);
    }

    #[test]
    fn test_optimization_result_nonempty() {
        let result = OptimizationResult {
            fusion_opportunities: vec![opportunity(0, 1, FusionType::ElementWise, 40)],
            dead_nodes: vec![2],
            ..OptimizationResult::default()
        };

        assert!(!result.is_empty());
        assert_eq!(result.total_opportunities(), 2);
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let producer = make_node(&[0], &[1], unary("add"));
        let consumer = make_node(&[1], &[2], unary("mul"));

        assert_eq!(
            GraphOptimizer::new().can_fuse(&producer, &consumer),
            Some(FusionType::ElementWise)
        );
    }

    #[test]
    fn test_fusion_planner_creation() {
        assert_eq!(FusionPlanner::new().max_fusion_depth, 3);
    }

    #[test]
    fn test_fusion_planner_with_depth() {
        assert_eq!(FusionPlanner::new().with_max_depth(5).max_fusion_depth, 5);
    }

    #[test]
    fn test_fusion_planning() {
        let candidates = [
            opportunity(0, 1, FusionType::ElementWise, 40),
            opportunity(2, 3, FusionType::ReductionElementWise, 25),
        ];

        let planner = FusionPlanner::new();
        let plan = planner.plan_fusions(&candidates);

        assert_eq!(plan.len(), 2);
        assert!(planner.validate_plan(&plan));
    }

    #[test]
    fn test_fusion_planning_with_conflicts() {
        let candidates = [
            opportunity(0, 1, FusionType::ElementWise, 40),
            // Producer 1 is already the consumer of the fusion above.
            opportunity(1, 2, FusionType::ElementWise, 35),
        ];

        let plan = FusionPlanner::new().plan_fusions(&candidates);

        // Only the higher-speedup fusion survives.
        assert_eq!(plan.len(), 1);
        assert_eq!(plan[0].producer_idx, 0);
    }

    #[test]
    fn test_estimate_improvement() {
        let result = OptimizationResult {
            fusion_opportunities: vec![opportunity(0, 1, FusionType::ElementWise, 40)],
            dead_nodes: vec![2],
            redundant_computations: vec![(3, 4)],
            ..OptimizationResult::new()
        };

        let improvement = GraphOptimizer::new().estimate_improvement(&result);
        assert!(improvement > 0.0);
        assert_eq!(improvement, 40.0 + 5.0 + 10.0); // fusion + dead + redundant
    }

    #[test]
    fn test_disabled_optimizations() {
        let result = GraphOptimizer::new()
            .with_fusion(false)
            .with_dead_node_elimination(false)
            .with_redundancy_detection(false)
            .analyze(&create_graph_with_dead_node());

        assert!(result.fusion_opportunities.is_empty());
        assert!(result.dead_nodes.is_empty());
        assert!(result.redundant_computations.is_empty());
    }
}