tensorlogic_compiler/passes/
loop_fusion.rs1use std::collections::{HashMap, HashSet};
33use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
34
/// Counters reported by the loop-fusion pass.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct LoopFusionStats {
    /// Number of fusion groups (two or more nodes each) counted as fused.
    pub loops_fused: usize,
    /// Number of `Reduce` nodes found inside counted fusion groups.
    pub reductions_merged: usize,
    /// Sum over counted groups of `group size - 1`.
    pub intermediates_eliminated: usize,
    /// Number of nodes in the analyzed graph.
    pub total_processed: usize,
}

impl LoopFusionStats {
    /// Sum of `loops_fused`, `reductions_merged` and `intermediates_eliminated`.
    ///
    /// `total_processed` is deliberately excluded: it counts inputs, not optimizations.
    pub fn total_optimizations(&self) -> usize {
        let counters = [
            self.loops_fused,
            self.reductions_merged,
            self.intermediates_eliminated,
        ];
        counters.iter().sum()
    }
}
54
/// Tuning knobs for the loop-fusion pass.
#[derive(Clone, Debug)]
pub struct LoopFusionConfig {
    /// Allow grouping `Reduce` nodes that share the same reduction op and axes.
    pub enable_reduction_fusion: bool,
    /// Allow grouping `ElemUnary` nodes together, and `ElemBinary` nodes together.
    pub enable_elementwise_fusion: bool,
    /// Upper bound on the number of nodes placed in a single fusion group.
    pub max_fusion_size: usize,
    /// Intended minimum estimated benefit (a speedup factor) for a group to be fused.
    /// NOTE(review): confirm which pass actually consults this value.
    pub min_benefit_threshold: f64,
}

impl Default for LoopFusionConfig {
    /// Both fusion kinds enabled, groups capped at 8 nodes, benefit threshold 1.1.
    fn default() -> Self {
        let max_fusion_size = 8;
        let min_benefit_threshold = 1.1;
        LoopFusionConfig {
            enable_reduction_fusion: true,
            enable_elementwise_fusion: true,
            max_fusion_size,
            min_benefit_threshold,
        }
    }
}
78
79pub fn fuse_loops(graph: &EinsumGraph) -> (EinsumGraph, LoopFusionStats) {
92 fuse_loops_with_config(graph, &LoopFusionConfig::default())
93}
94
95pub fn fuse_loops_with_config(
97 graph: &EinsumGraph,
98 config: &LoopFusionConfig,
99) -> (EinsumGraph, LoopFusionStats) {
100 let optimized = graph.clone();
101 let mut stats = LoopFusionStats::default();
102
103 let dependencies = build_dependency_graph(&optimized);
105
106 let fusion_groups = find_fusion_groups(&optimized, &dependencies, config);
108
109 stats.total_processed = optimized.nodes.len();
110
111 for group in fusion_groups {
113 if group.len() >= 2 {
114 stats.loops_fused += 1;
115 stats.intermediates_eliminated += group.len() - 1;
116
117 for &node_idx in &group {
119 if let Some(node) = optimized.nodes.get(node_idx) {
120 if matches!(node.op, OpType::Reduce { .. }) {
121 stats.reductions_merged += 1;
122 }
123 }
124 }
125 }
126 }
127
128 (optimized, stats)
129}
130
131fn build_dependency_graph(graph: &EinsumGraph) -> HashMap<usize, HashSet<usize>> {
133 let mut deps = HashMap::new();
134
135 for (idx, node) in graph.nodes.iter().enumerate() {
136 let mut node_deps = HashSet::new();
137
138 for &input_idx in &node.inputs {
140 for (producer_idx, producer) in graph.nodes.iter().enumerate() {
142 if producer.outputs.contains(&input_idx) {
143 node_deps.insert(producer_idx);
144 }
145 }
146 }
147
148 deps.insert(idx, node_deps);
149 }
150
151 deps
152}
153
154fn find_fusion_groups(
156 graph: &EinsumGraph,
157 dependencies: &HashMap<usize, HashSet<usize>>,
158 config: &LoopFusionConfig,
159) -> Vec<Vec<usize>> {
160 let mut groups = Vec::new();
161 let mut visited = HashSet::new();
162
163 for (idx, node) in graph.nodes.iter().enumerate() {
164 if visited.contains(&idx) {
165 continue;
166 }
167
168 let mut group = vec![idx];
170 visited.insert(idx);
171
172 for (other_idx, other_node) in graph.nodes.iter().enumerate() {
174 if other_idx == idx || visited.contains(&other_idx) {
175 continue;
176 }
177
178 if group.len() >= config.max_fusion_size {
179 break;
180 }
181
182 if can_fuse_nodes(node, other_node, config)
184 && !has_dependency_conflict(&group, other_idx, dependencies)
185 {
186 group.push(other_idx);
187 visited.insert(other_idx);
188 }
189 }
190
191 if group.len() > 1 {
192 groups.push(group);
193 }
194 }
195
196 groups
197}
198
199fn can_fuse_nodes(node1: &EinsumNode, node2: &EinsumNode, config: &LoopFusionConfig) -> bool {
201 match (&node1.op, &node2.op) {
202 (
204 OpType::Reduce {
205 op: op1,
206 axes: axes1,
207 },
208 OpType::Reduce {
209 op: op2,
210 axes: axes2,
211 },
212 ) => {
213 config.enable_reduction_fusion
214 && op1 == op2 && axes1 == axes2 }
217
218 (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
220 | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
221 config.enable_elementwise_fusion
222 }
223
224 _ => false,
225 }
226}
227
/// Returns `true` when `candidate` and some member of `group` depend on one another.
///
/// Only the edges recorded in `dependencies` are consulted. Both directions are
/// checked: candidate -> member and member -> candidate. Nodes missing from the map
/// are treated as having no dependencies.
fn has_dependency_conflict(
    group: &[usize],
    candidate: usize,
    dependencies: &HashMap<usize, HashSet<usize>>,
) -> bool {
    let depends_on = |node: usize, other: usize| {
        matches!(dependencies.get(&node), Some(deps) if deps.contains(&other))
    };

    group
        .iter()
        .any(|&member| depends_on(candidate, member) || depends_on(member, candidate))
}
254
255pub fn estimate_fusion_benefit(graph: &EinsumGraph, group: &[usize]) -> f64 {
259 if group.len() < 2 {
260 return 1.0;
261 }
262
263 let base_speedup = 1.0 + (group.len() as f64 - 1.0) * 0.3;
266
267 let intermediate_bonus = (group.len() - 1) as f64 * 0.2;
269
270 let mut reduction_count = 0;
272 for &node_idx in group {
273 if let Some(node) = graph.nodes.get(node_idx) {
274 if matches!(node.op, OpType::Reduce { .. }) {
275 reduction_count += 1;
276 }
277 }
278 }
279 let reduction_bonus = reduction_count as f64 * 0.1;
280
281 base_speedup + intermediate_bonus + reduction_bonus
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// A graph holding two tensors (`t0`, `t1`) and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let _first = graph.add_tensor("t0");
        let _second = graph.add_tensor("t1");
        graph
    }

    #[test]
    fn test_build_dependency_graph() {
        // No nodes means no dependency entries at all.
        let deps = build_dependency_graph(&create_test_graph());
        assert!(deps.is_empty());
    }

    #[test]
    fn test_can_fuse_same_reductions() {
        let config = LoopFusionConfig::default();
        let first = EinsumNode::reduce("sum", vec![0], 0, 1);
        let second = EinsumNode::reduce("sum", vec![0], 2, 3);

        assert!(can_fuse_nodes(&first, &second, &config));
    }

    #[test]
    fn test_cannot_fuse_different_axes() {
        let config = LoopFusionConfig::default();
        let over_axis_0 = EinsumNode::reduce("sum", vec![0], 0, 1);
        let over_axis_1 = EinsumNode::reduce("sum", vec![1], 2, 3);

        assert!(!can_fuse_nodes(&over_axis_0, &over_axis_1, &config));
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let config = LoopFusionConfig::default();
        let exp = EinsumNode::elem_unary("exp", 0, 1);
        let log = EinsumNode::elem_unary("log", 2, 3);

        assert!(can_fuse_nodes(&exp, &log, &config));
    }

    #[test]
    fn test_estimate_fusion_benefit() {
        let graph = create_test_graph();

        // A single node is never worth fusing.
        assert_eq!(estimate_fusion_benefit(&graph, &[0]), 1.0);

        let pair = estimate_fusion_benefit(&graph, &[0, 1]);
        assert!(pair > 1.0);
        assert!(pair < 3.0);
    }

    #[test]
    fn test_fuse_loops_stats() {
        let (_, stats) = fuse_loops(&create_test_graph());
        assert_eq!(stats.total_processed, 0);
    }

    #[test]
    fn test_config_builder() {
        let config = LoopFusionConfig {
            enable_reduction_fusion: false,
            enable_elementwise_fusion: true,
            max_fusion_size: 4,
            min_benefit_threshold: 1.5,
        };

        assert!(!config.enable_reduction_fusion);
        assert!(config.enable_elementwise_fusion);
        assert_eq!(config.max_fusion_size, 4);
        assert_eq!(config.min_benefit_threshold, 1.5);
    }

    #[test]
    fn test_dependency_conflict_detection() {
        let mut deps = HashMap::new();
        deps.insert(0, HashSet::new());
        deps.insert(1, std::iter::once(0).collect::<HashSet<_>>());

        // Node 1 consumes node 0's output, so they conflict.
        assert!(has_dependency_conflict(&[0], 1, &deps));
        // Node 2 has no recorded edges at all.
        assert!(!has_dependency_conflict(&[0], 2, &deps));
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = LoopFusionStats {
            loops_fused: 2,
            reductions_merged: 3,
            intermediates_eliminated: 1,
            total_processed: 10,
        };

        // total_processed is not part of the sum.
        assert_eq!(stats.total_optimizations(), 6);
    }
}