Skip to main content

tensorlogic_ir/graph/
optimization.rs

//! Graph optimization passes: dead code elimination, common subexpression
//! elimination, and identity-operation simplification.
2
3use std::collections::{HashMap, HashSet};
4
5use crate::{EinsumGraph, EinsumNode, IrError};
6
7/// Dead Code Elimination (DCE) - removes unused tensors and nodes
8pub fn eliminate_dead_code(graph: &mut EinsumGraph) -> Result<usize, IrError> {
9    if graph.outputs.is_empty() {
10        return Ok(0);
11    }
12
13    // Track which tensors are live (needed)
14    let mut live_tensors = HashSet::new();
15    let mut worklist: Vec<usize> = graph.outputs.clone();
16
17    // Mark all output tensors as live
18    for &output_idx in &graph.outputs {
19        live_tensors.insert(output_idx);
20    }
21
22    // Build tensor-to-node mapping (which node produces each tensor)
23    let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
24    for (node_idx, _node) in graph.nodes.iter().enumerate() {
25        // Each node produces a tensor at index equal to the number of tensors
26        // before this node plus its position
27        let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
28        tensor_producers.insert(produced_tensor_idx, node_idx);
29    }
30
31    // Backward pass: mark all dependencies as live
32    while let Some(tensor_idx) = worklist.pop() {
33        if let Some(&node_idx) = tensor_producers.get(&tensor_idx) {
34            let node = &graph.nodes[node_idx];
35            for &input_idx in &node.inputs {
36                if !live_tensors.contains(&input_idx) {
37                    live_tensors.insert(input_idx);
38                    worklist.push(input_idx);
39                }
40            }
41        }
42    }
43
44    // Remove dead tensors and nodes
45    let mut removed_count = 0;
46
47    // Mark dead nodes for removal (nodes whose output is not live)
48    let mut nodes_to_keep = Vec::new();
49    for (node_idx, node) in graph.nodes.iter().enumerate() {
50        let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
51        if live_tensors.contains(&produced_tensor_idx) {
52            nodes_to_keep.push(node.clone());
53        } else {
54            removed_count += 1;
55        }
56    }
57
58    graph.nodes = nodes_to_keep;
59
60    // Note: We don't actually remove tensors from the tensors vector
61    // as this would require renumbering all node inputs and outputs.
62    // Instead, we just remove the nodes that produce unused tensors.
63
64    Ok(removed_count)
65}
66
67#[allow(dead_code)]
68fn count_input_tensors(graph: &EinsumGraph, before_node: usize) -> usize {
69    // Count how many tensors exist before this node
70    // This is a simplified version - in practice, you'd track this more carefully
71    graph
72        .nodes
73        .iter()
74        .take(before_node)
75        .map(|_| 1) // Each node produces one tensor
76        .sum()
77}
78
79/// Common Subexpression Elimination (CSE) - detects and deduplicates identical subgraphs
80pub fn eliminate_common_subexpressions(graph: &mut EinsumGraph) -> Result<usize, IrError> {
81    let mut node_hashes: HashMap<String, usize> = HashMap::new();
82    let mut replacements: HashMap<usize, usize> = HashMap::new();
83    let mut eliminated_count = 0;
84
85    // Build hash for each node (based on operation and inputs)
86    for (node_idx, node) in graph.nodes.iter().enumerate() {
87        let node_hash = compute_node_hash(node);
88
89        if let Some(&existing_idx) = node_hashes.get(&node_hash) {
90            // Found a duplicate - mark for replacement
91            let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
92            let existing_tensor_idx = existing_idx + count_input_tensors(graph, existing_idx);
93            replacements.insert(produced_tensor_idx, existing_tensor_idx);
94            eliminated_count += 1;
95        } else {
96            node_hashes.insert(node_hash, node_idx);
97        }
98    }
99
100    // Apply replacements to all node inputs and outputs
101    for node in &mut graph.nodes {
102        for input_idx in &mut node.inputs {
103            if let Some(&replacement) = replacements.get(input_idx) {
104                *input_idx = replacement;
105            }
106        }
107    }
108
109    for output_idx in &mut graph.outputs {
110        if let Some(&replacement) = replacements.get(output_idx) {
111            *output_idx = replacement;
112        }
113    }
114
115    // Remove duplicate nodes (would require DCE to actually clean up)
116    Ok(eliminated_count)
117}
118
119#[allow(dead_code)]
120fn compute_node_hash(node: &EinsumNode) -> String {
121    // Simple hash based on operation type and inputs
122    // In a real implementation, you'd use a proper hash function
123    format!("{:?}|{:?}", node.op, node.inputs)
124}
125
126/// Simplify identity operations (operations that don't transform their input)
127pub fn simplify_identity_operations(graph: &mut EinsumGraph) -> Result<usize, IrError> {
128    let mut simplified_count = 0;
129    let mut replacements: HashMap<usize, usize> = HashMap::new();
130
131    for (node_idx, node) in graph.nodes.iter().enumerate() {
132        if is_identity_operation(node) && !node.inputs.is_empty() {
133            // Map output to input directly
134            let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
135            replacements.insert(produced_tensor_idx, node.inputs[0]);
136            simplified_count += 1;
137        }
138    }
139
140    // Apply replacements
141    for node in &mut graph.nodes {
142        for input_idx in &mut node.inputs {
143            if let Some(&replacement) = replacements.get(input_idx) {
144                *input_idx = replacement;
145            }
146        }
147    }
148
149    for output_idx in &mut graph.outputs {
150        if let Some(&replacement) = replacements.get(output_idx) {
151            *output_idx = replacement;
152        }
153    }
154
155    Ok(simplified_count)
156}
157
158#[allow(dead_code)]
159fn is_identity_operation(node: &EinsumNode) -> bool {
160    use crate::OpType;
161
162    match &node.op {
163        // Einsum with identity spec (e.g., "a->a")
164        OpType::Einsum { spec } => {
165            if let Some(arrow_pos) = spec.find("->") {
166                let input_axes = &spec[..arrow_pos];
167                let output_axes = &spec[arrow_pos + 2..];
168                input_axes == output_axes && node.inputs.len() == 1
169            } else {
170                false
171            }
172        }
173        // ElemBinary multiply by 1 or add 0 could be detected here
174        _ => false,
175    }
176}
177
178/// Apply all optimization passes to the graph
179pub fn optimize_graph(graph: &mut EinsumGraph) -> Result<OptimizationStats, IrError> {
180    let mut stats = OptimizationStats::default();
181
182    // Multiple passes for maximum effect
183    for _ in 0..3 {
184        let cse_count = eliminate_common_subexpressions(graph)?;
185        stats.cse_eliminated += cse_count;
186
187        let identity_count = simplify_identity_operations(graph)?;
188        stats.identities_simplified += identity_count;
189
190        let dce_count = eliminate_dead_code(graph)?;
191        stats.dead_code_eliminated += dce_count;
192
193        // Stop if no changes
194        if cse_count == 0 && identity_count == 0 && dce_count == 0 {
195            break;
196        }
197    }
198
199    Ok(stats)
200}
201
/// Counters describing how much work each optimization pass performed.
#[derive(Debug, Default, Clone, Copy)]
pub struct OptimizationStats {
    /// Accumulated count reported by the dead code elimination pass.
    pub dead_code_eliminated: usize,
    /// Accumulated count reported by the common subexpression elimination pass.
    pub cse_eliminated: usize,
    /// Accumulated count reported by the identity simplification pass.
    pub identities_simplified: usize,
}

impl OptimizationStats {
    /// Returns the sum of all three counters.
    pub fn total_optimizations(&self) -> usize {
        [
            self.dead_code_eliminated,
            self.cse_eliminated,
            self.identities_simplified,
        ]
        .iter()
        .sum()
    }
}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpType;

    /// Builds a metadata-free einsum node with the given spec and tensor indices.
    fn einsum_node(spec: &str, inputs: Vec<usize>, outputs: Vec<usize>) -> EinsumNode {
        EinsumNode {
            op: OpType::Einsum {
                spec: spec.to_string(),
            },
            inputs,
            outputs,
            metadata: None,
        }
    }

    // DCE on a completely empty graph removes nothing.
    #[test]
    fn test_dead_code_elimination_empty_graph() {
        let mut graph = EinsumGraph::new();
        assert_eq!(eliminate_dead_code(&mut graph).unwrap(), 0);
    }

    // Without declared outputs DCE returns early, even when tensors exist.
    #[test]
    fn test_dead_code_elimination_no_outputs() {
        let mut graph = EinsumGraph::new();
        graph.add_tensor("a[i]");
        graph.add_tensor("b[i]");
        assert_eq!(eliminate_dead_code(&mut graph).unwrap(), 0);
    }

    // "a->a" with one input is an identity; "ab->a" is not.
    #[test]
    fn test_identity_operation_detection() {
        let identity_node = einsum_node("a->a", vec![0], vec![1]);
        assert!(is_identity_operation(&identity_node));

        let non_identity_node = einsum_node("ab->a", vec![0], vec![1]);
        assert!(!is_identity_operation(&non_identity_node));
    }

    // Equal op + inputs give equal keys; a different spec gives a different key.
    #[test]
    fn test_node_hash_computation() {
        let first = einsum_node("ab->a", vec![0], vec![1]);
        let same_as_first = einsum_node("ab->a", vec![0], vec![1]);
        let different_spec = einsum_node("ab->b", vec![0], vec![1]);

        assert_eq!(compute_node_hash(&first), compute_node_hash(&same_as_first));
        assert_ne!(compute_node_hash(&first), compute_node_hash(&different_spec));
    }

    // The total is the sum of the three counters.
    #[test]
    fn test_optimization_stats() {
        let stats = OptimizationStats {
            dead_code_eliminated: 2,
            cse_eliminated: 3,
            identities_simplified: 1,
        };
        assert_eq!(stats.total_optimizations(), 6);
    }

    // Smoke test: the whole pipeline runs on a one-node graph without failing.
    #[test]
    fn test_full_optimization_pipeline() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input[a]");
        let t1 = graph.add_tensor("output[a]");

        let _n1 = graph
            .add_node(einsum_node("a->a", vec![t0], vec![t1]))
            .unwrap();

        graph.add_output(t1).unwrap();

        let stats = optimize_graph(&mut graph).unwrap();
        // Only checks that the stats can be computed without panicking.
        let _total = stats.total_optimizations();
    }
}