tensorlogic_ir/graph/
optimization.rs1use std::collections::{HashMap, HashSet};
4
5use crate::{EinsumGraph, EinsumNode, IrError};
6
7pub fn eliminate_dead_code(graph: &mut EinsumGraph) -> Result<usize, IrError> {
9 if graph.outputs.is_empty() {
10 return Ok(0);
11 }
12
13 let mut live_tensors = HashSet::new();
15 let mut worklist: Vec<usize> = graph.outputs.clone();
16
17 for &output_idx in &graph.outputs {
19 live_tensors.insert(output_idx);
20 }
21
22 let mut tensor_producers: HashMap<usize, usize> = HashMap::new();
24 for (node_idx, _node) in graph.nodes.iter().enumerate() {
25 let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
28 tensor_producers.insert(produced_tensor_idx, node_idx);
29 }
30
31 while let Some(tensor_idx) = worklist.pop() {
33 if let Some(&node_idx) = tensor_producers.get(&tensor_idx) {
34 let node = &graph.nodes[node_idx];
35 for &input_idx in &node.inputs {
36 if !live_tensors.contains(&input_idx) {
37 live_tensors.insert(input_idx);
38 worklist.push(input_idx);
39 }
40 }
41 }
42 }
43
44 let mut removed_count = 0;
46
47 let mut nodes_to_keep = Vec::new();
49 for (node_idx, node) in graph.nodes.iter().enumerate() {
50 let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
51 if live_tensors.contains(&produced_tensor_idx) {
52 nodes_to_keep.push(node.clone());
53 } else {
54 removed_count += 1;
55 }
56 }
57
58 graph.nodes = nodes_to_keep;
59
60 Ok(removed_count)
65}
66
67#[allow(dead_code)]
68fn count_input_tensors(graph: &EinsumGraph, before_node: usize) -> usize {
69 graph
72 .nodes
73 .iter()
74 .take(before_node)
75 .map(|_| 1) .sum()
77}
78
79pub fn eliminate_common_subexpressions(graph: &mut EinsumGraph) -> Result<usize, IrError> {
81 let mut node_hashes: HashMap<String, usize> = HashMap::new();
82 let mut replacements: HashMap<usize, usize> = HashMap::new();
83 let mut eliminated_count = 0;
84
85 for (node_idx, node) in graph.nodes.iter().enumerate() {
87 let node_hash = compute_node_hash(node);
88
89 if let Some(&existing_idx) = node_hashes.get(&node_hash) {
90 let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
92 let existing_tensor_idx = existing_idx + count_input_tensors(graph, existing_idx);
93 replacements.insert(produced_tensor_idx, existing_tensor_idx);
94 eliminated_count += 1;
95 } else {
96 node_hashes.insert(node_hash, node_idx);
97 }
98 }
99
100 for node in &mut graph.nodes {
102 for input_idx in &mut node.inputs {
103 if let Some(&replacement) = replacements.get(input_idx) {
104 *input_idx = replacement;
105 }
106 }
107 }
108
109 for output_idx in &mut graph.outputs {
110 if let Some(&replacement) = replacements.get(output_idx) {
111 *output_idx = replacement;
112 }
113 }
114
115 Ok(eliminated_count)
117}
118
119#[allow(dead_code)]
120fn compute_node_hash(node: &EinsumNode) -> String {
121 format!("{:?}|{:?}", node.op, node.inputs)
124}
125
126pub fn simplify_identity_operations(graph: &mut EinsumGraph) -> Result<usize, IrError> {
128 let mut simplified_count = 0;
129 let mut replacements: HashMap<usize, usize> = HashMap::new();
130
131 for (node_idx, node) in graph.nodes.iter().enumerate() {
132 if is_identity_operation(node) && !node.inputs.is_empty() {
133 let produced_tensor_idx = node_idx + count_input_tensors(graph, node_idx);
135 replacements.insert(produced_tensor_idx, node.inputs[0]);
136 simplified_count += 1;
137 }
138 }
139
140 for node in &mut graph.nodes {
142 for input_idx in &mut node.inputs {
143 if let Some(&replacement) = replacements.get(input_idx) {
144 *input_idx = replacement;
145 }
146 }
147 }
148
149 for output_idx in &mut graph.outputs {
150 if let Some(&replacement) = replacements.get(output_idx) {
151 *output_idx = replacement;
152 }
153 }
154
155 Ok(simplified_count)
156}
157
158#[allow(dead_code)]
159fn is_identity_operation(node: &EinsumNode) -> bool {
160 use crate::OpType;
161
162 match &node.op {
163 OpType::Einsum { spec } => {
165 if let Some(arrow_pos) = spec.find("->") {
166 let input_axes = &spec[..arrow_pos];
167 let output_axes = &spec[arrow_pos + 2..];
168 input_axes == output_axes && node.inputs.len() == 1
169 } else {
170 false
171 }
172 }
173 _ => false,
175 }
176}
177
178pub fn optimize_graph(graph: &mut EinsumGraph) -> Result<OptimizationStats, IrError> {
180 let mut stats = OptimizationStats::default();
181
182 for _ in 0..3 {
184 let cse_count = eliminate_common_subexpressions(graph)?;
185 stats.cse_eliminated += cse_count;
186
187 let identity_count = simplify_identity_operations(graph)?;
188 stats.identities_simplified += identity_count;
189
190 let dce_count = eliminate_dead_code(graph)?;
191 stats.dead_code_eliminated += dce_count;
192
193 if cse_count == 0 && identity_count == 0 && dce_count == 0 {
195 break;
196 }
197 }
198
199 Ok(stats)
200}
201
/// Per-pass counters reported by `optimize_graph`.
#[derive(Debug, Default, Clone, Copy)]
pub struct OptimizationStats {
    /// Sum of the counts returned by `eliminate_dead_code`.
    pub dead_code_eliminated: usize,
    /// Sum of the counts returned by `eliminate_common_subexpressions`.
    pub cse_eliminated: usize,
    /// Sum of the counts returned by `simplify_identity_operations`.
    pub identities_simplified: usize,
}

impl OptimizationStats {
    /// Returns the sum of all three counters.
    pub fn total_optimizations(&self) -> usize {
        let Self {
            dead_code_eliminated,
            cse_eliminated,
            identities_simplified,
        } = *self;

        dead_code_eliminated + cse_eliminated + identities_simplified
    }
}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpType;

    /// Builds an einsum node with the given spec and tensor indices and no metadata.
    fn einsum_node(spec: &str, inputs: Vec<usize>, outputs: Vec<usize>) -> EinsumNode {
        EinsumNode {
            op: OpType::Einsum {
                spec: spec.to_string(),
            },
            inputs,
            outputs,
            metadata: None,
        }
    }

    #[test]
    fn test_dead_code_elimination_empty_graph() {
        let mut graph = EinsumGraph::new();
        assert_eq!(eliminate_dead_code(&mut graph).unwrap(), 0);
    }

    #[test]
    fn test_dead_code_elimination_no_outputs() {
        let mut graph = EinsumGraph::new();
        graph.add_tensor("a[i]");
        graph.add_tensor("b[i]");

        // Without graph outputs the pass must not remove anything.
        assert_eq!(eliminate_dead_code(&mut graph).unwrap(), 0);
    }

    #[test]
    fn test_identity_operation_detection() {
        let identity_node = einsum_node("a->a", vec![0], vec![1]);
        let non_identity_node = einsum_node("ab->a", vec![0], vec![1]);

        assert!(is_identity_operation(&identity_node));
        assert!(!is_identity_operation(&non_identity_node));
    }

    #[test]
    fn test_node_hash_computation() {
        let first = einsum_node("ab->a", vec![0], vec![1]);
        let twin = einsum_node("ab->a", vec![0], vec![1]);
        let different_spec = einsum_node("ab->b", vec![0], vec![1]);

        assert_eq!(compute_node_hash(&first), compute_node_hash(&twin));
        assert_ne!(compute_node_hash(&first), compute_node_hash(&different_spec));
    }

    #[test]
    fn test_optimization_stats() {
        let stats = OptimizationStats {
            dead_code_eliminated: 2,
            cse_eliminated: 3,
            identities_simplified: 1,
        };
        assert_eq!(stats.total_optimizations(), 6);
    }

    #[test]
    fn test_full_optimization_pipeline() {
        let mut graph = EinsumGraph::new();
        let source = graph.add_tensor("input[a]");
        let sink = graph.add_tensor("output[a]");

        let _node = graph
            .add_node(einsum_node("a->a", vec![source], vec![sink]))
            .unwrap();
        graph.add_output(sink).unwrap();

        // Smoke test: the pipeline must run to completion without an error.
        let stats = optimize_graph(&mut graph).unwrap();
        let _total = stats.total_optimizations();
    }
}