Skip to main content

tensorlogic_compiler/passes/
loop_fusion.rs

1//! Loop fusion optimization pass.
2//!
3//! This module provides optimization passes that fuse multiple loops/reductions
4//! over the same axes to improve cache locality and reduce memory traffic.
5//!
6//! # Overview
7//!
8//! Loop fusion combines multiple consecutive operations that iterate over
9//! the same axis into a single fused operation. This optimization:
10//! - Reduces memory traffic (fewer intermediate tensors)
11//! - Improves cache locality (better temporal locality)
12//! - Reduces loop overhead (fewer loop iterations)
13//!
14//! # Fusion Criteria
15//!
16//! Two loops can be fused if:
17//! 1. They iterate over the same axis/axes
18//! 2. They have compatible domains
19//! 3. There are no dependencies that prevent fusion
20//! 4. The fused operation doesn't exceed memory constraints
21//!
22//! # Examples
23//!
24//! ```rust
25//! use tensorlogic_compiler::passes::fuse_loops;
26//! use tensorlogic_ir::EinsumGraph;
27//!
28//! let graph = EinsumGraph::new();
29//! let (fused_graph, stats) = fuse_loops(&graph);
30//! ```
31
32use std::collections::{HashMap, HashSet};
33use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
34
/// Counters reported by the loop fusion pass.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LoopFusionStats {
    /// How many fusion groups were counted as fused.
    pub loops_fused: usize,
    /// How many reduction operations were counted as merged.
    pub reductions_merged: usize,
    /// How many intermediate tensors were counted as eliminated.
    pub intermediates_eliminated: usize,
    /// How many nodes the pass looked at in total.
    pub total_processed: usize,
}

impl LoopFusionStats {
    /// Sum of all optimization counters.
    ///
    /// Adds `loops_fused`, `reductions_merged` and `intermediates_eliminated`;
    /// `total_processed` is not an optimization count and is left out.
    pub fn total_optimizations(&self) -> usize {
        [
            self.loops_fused,
            self.reductions_merged,
            self.intermediates_eliminated,
        ]
        .iter()
        .sum::<usize>()
    }
}
54
/// Tunable settings for the loop fusion pass.
#[derive(Debug, Clone)]
pub struct LoopFusionConfig {
    /// Whether reduction operations may be fused with each other.
    pub enable_reduction_fusion: bool,
    /// Whether element-wise operations may be fused with each other.
    pub enable_elementwise_fusion: bool,
    /// Upper bound on the number of operations placed in one fusion group.
    pub max_fusion_size: usize,
    /// Minimum estimated speedup factor considered worthwhile.
    pub min_benefit_threshold: f64,
}

impl Default for LoopFusionConfig {
    /// Both fusion kinds enabled, groups of up to 8 operations, and a
    /// minimum estimated speedup of 1.1 (i.e. at least 10%).
    fn default() -> Self {
        LoopFusionConfig {
            min_benefit_threshold: 1.1,
            max_fusion_size: 8,
            enable_elementwise_fusion: true,
            enable_reduction_fusion: true,
        }
    }
}
78
79/// Fuse loops in an einsum graph.
80///
81/// This function identifies opportunities to fuse multiple loops/reductions
82/// over the same axes and combines them into single fused operations.
83///
84/// # Arguments
85///
86/// * `graph` - The einsum graph to optimize
87///
88/// # Returns
89///
90/// A tuple of (optimized_graph, statistics)
91pub fn fuse_loops(graph: &EinsumGraph) -> (EinsumGraph, LoopFusionStats) {
92    fuse_loops_with_config(graph, &LoopFusionConfig::default())
93}
94
95/// Fuse loops with custom configuration.
96pub fn fuse_loops_with_config(
97    graph: &EinsumGraph,
98    config: &LoopFusionConfig,
99) -> (EinsumGraph, LoopFusionStats) {
100    let optimized = graph.clone();
101    let mut stats = LoopFusionStats::default();
102
103    // Build dependency graph
104    let dependencies = build_dependency_graph(&optimized);
105
106    // Find fusible loop groups
107    let fusion_groups = find_fusion_groups(&optimized, &dependencies, config);
108
109    stats.total_processed = optimized.nodes.len();
110
111    // Count potential fusions
112    for group in fusion_groups {
113        if group.len() >= 2 {
114            stats.loops_fused += 1;
115            stats.intermediates_eliminated += group.len() - 1;
116
117            // Check if we would fuse reductions
118            for &node_idx in &group {
119                if let Some(node) = optimized.nodes.get(node_idx) {
120                    if matches!(node.op, OpType::Reduce { .. }) {
121                        stats.reductions_merged += 1;
122                    }
123                }
124            }
125        }
126    }
127
128    (optimized, stats)
129}
130
131/// Build a dependency graph showing which nodes depend on which.
132fn build_dependency_graph(graph: &EinsumGraph) -> HashMap<usize, HashSet<usize>> {
133    let mut deps = HashMap::new();
134
135    for (idx, node) in graph.nodes.iter().enumerate() {
136        let mut node_deps = HashSet::new();
137
138        // Add dependencies from input tensors
139        for &input_idx in &node.inputs {
140            // Find which node produced this tensor
141            for (producer_idx, producer) in graph.nodes.iter().enumerate() {
142                if producer.outputs.contains(&input_idx) {
143                    node_deps.insert(producer_idx);
144                }
145            }
146        }
147
148        deps.insert(idx, node_deps);
149    }
150
151    deps
152}
153
154/// Find groups of nodes that can be fused together.
155fn find_fusion_groups(
156    graph: &EinsumGraph,
157    dependencies: &HashMap<usize, HashSet<usize>>,
158    config: &LoopFusionConfig,
159) -> Vec<Vec<usize>> {
160    let mut groups = Vec::new();
161    let mut visited = HashSet::new();
162
163    for (idx, node) in graph.nodes.iter().enumerate() {
164        if visited.contains(&idx) {
165            continue;
166        }
167
168        // Start a new potential fusion group
169        let mut group = vec![idx];
170        visited.insert(idx);
171
172        // Try to find compatible nodes to fuse with
173        for (other_idx, other_node) in graph.nodes.iter().enumerate() {
174            if other_idx == idx || visited.contains(&other_idx) {
175                continue;
176            }
177
178            if group.len() >= config.max_fusion_size {
179                break;
180            }
181
182            // Check if nodes are fusible
183            if can_fuse_nodes(node, other_node, config)
184                && !has_dependency_conflict(&group, other_idx, dependencies)
185            {
186                group.push(other_idx);
187                visited.insert(other_idx);
188            }
189        }
190
191        if group.len() > 1 {
192            groups.push(group);
193        }
194    }
195
196    groups
197}
198
199/// Check if two nodes can be fused together.
200fn can_fuse_nodes(node1: &EinsumNode, node2: &EinsumNode, config: &LoopFusionConfig) -> bool {
201    match (&node1.op, &node2.op) {
202        // Fuse reductions over the same axes
203        (
204            OpType::Reduce {
205                op: op1,
206                axes: axes1,
207            },
208            OpType::Reduce {
209                op: op2,
210                axes: axes2,
211            },
212        ) => {
213            config.enable_reduction_fusion
214                && op1 == op2 // Same reduction operation
215                && axes1 == axes2 // Same axes
216        }
217
218        // Fuse element-wise operations
219        (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
220        | (OpType::ElemBinary { .. }, OpType::ElemBinary { .. }) => {
221            config.enable_elementwise_fusion
222        }
223
224        _ => false,
225    }
226}
227
/// Check if adding a node to a group would create dependency conflicts.
///
/// A conflict exists when `candidate` depends on any member of `group`, or
/// any member of `group` depends on `candidate`, either directly or
/// transitively through other nodes. The transitive case matters: if
/// `A -> X -> B` with `X` outside the group, fusing `A` and `B` would make the
/// fused node both a producer for and a consumer of `X`, i.e. a cycle.
///
/// # Arguments
///
/// * `group` - Node indices already in the fusion group
/// * `candidate` - Node index being considered for the group
/// * `dependencies` - For each node, the set of nodes it directly depends on;
///   nodes missing from the map are treated as having no dependencies
fn has_dependency_conflict(
    group: &[usize],
    candidate: usize,
    dependencies: &HashMap<usize, HashSet<usize>>,
) -> bool {
    // Candidate (transitively) depends on a group member ...
    depends_on_any(&[candidate], group, dependencies)
        // ... or a group member (transitively) depends on the candidate.
        || depends_on_any(group, &[candidate], dependencies)
}

/// Returns `true` if any node in `starts` reaches a node in `targets` by
/// following one or more dependency edges.
fn depends_on_any(
    starts: &[usize],
    targets: &[usize],
    dependencies: &HashMap<usize, HashSet<usize>>,
) -> bool {
    let mut pending: Vec<usize> = starts.to_vec();
    // Tracks expanded nodes so shared ancestors and cycles are walked once.
    let mut seen: HashSet<usize> = starts.iter().copied().collect();

    while let Some(node) = pending.pop() {
        if let Some(node_deps) = dependencies.get(&node) {
            for &dep in node_deps {
                if targets.contains(&dep) {
                    return true;
                }
                if seen.insert(dep) {
                    pending.push(dep);
                }
            }
        }
    }

    false
}
254
255/// Estimate the benefit of fusing a group of nodes.
256///
257/// Returns an estimated speedup factor (1.0 = no benefit, 2.0 = 2x speedup).
258pub fn estimate_fusion_benefit(graph: &EinsumGraph, group: &[usize]) -> f64 {
259    if group.len() < 2 {
260        return 1.0;
261    }
262
263    // Simple heuristic: speedup ~ number of fused operations
264    // In practice, fusion reduces memory traffic and loop overhead
265    let base_speedup = 1.0 + (group.len() as f64 - 1.0) * 0.3;
266
267    // Bonus for reducing intermediate tensors
268    let intermediate_bonus = (group.len() - 1) as f64 * 0.2;
269
270    // Check if we're fusing reductions (higher benefit)
271    let mut reduction_count = 0;
272    for &node_idx in group {
273        if let Some(node) = graph.nodes.get(node_idx) {
274            if matches!(node.op, OpType::Reduce { .. }) {
275                reduction_count += 1;
276            }
277        }
278    }
279    let reduction_bonus = reduction_count as f64 * 0.1;
280
281    base_speedup + intermediate_bonus + reduction_bonus
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph that holds two tensors and no nodes.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        let _first = graph.add_tensor("t0");
        let _second = graph.add_tensor("t1");

        graph
    }

    #[test]
    fn test_build_dependency_graph() {
        let graph = create_test_graph();

        // The graph has no nodes, so there is nothing to map.
        let deps = build_dependency_graph(&graph);
        assert_eq!(deps.len(), 0);
    }

    #[test]
    fn test_can_fuse_same_reductions() {
        let config = LoopFusionConfig::default();
        let first = EinsumNode::reduce("sum", vec![0], 0, 1);
        let second = EinsumNode::reduce("sum", vec![0], 2, 3);

        let fusible = can_fuse_nodes(&first, &second, &config);
        assert!(fusible);
    }

    #[test]
    fn test_cannot_fuse_different_axes() {
        let config = LoopFusionConfig::default();
        let over_axis_0 = EinsumNode::reduce("sum", vec![0], 0, 1);
        let over_axis_1 = EinsumNode::reduce("sum", vec![1], 2, 3);

        let fusible = can_fuse_nodes(&over_axis_0, &over_axis_1, &config);
        assert!(!fusible);
    }

    #[test]
    fn test_can_fuse_elementwise() {
        let config = LoopFusionConfig::default();
        let exp_node = EinsumNode::elem_unary("exp", 0, 1);
        let log_node = EinsumNode::elem_unary("log", 2, 3);

        let fusible = can_fuse_nodes(&exp_node, &log_node, &config);
        assert!(fusible);
    }

    #[test]
    fn test_estimate_fusion_benefit() {
        let graph = create_test_graph();

        // A single node cannot be fused with anything: no benefit.
        let single = estimate_fusion_benefit(&graph, &[0]);
        assert_eq!(single, 1.0);

        // A pair yields a benefit above 1.0 but below 3.0.
        let pair = estimate_fusion_benefit(&graph, &[0, 1]);
        assert!(pair > 1.0);
        assert!(pair < 3.0);
    }

    #[test]
    fn test_fuse_loops_stats() {
        let graph = create_test_graph();

        // No nodes in the graph, so nothing is processed.
        let (_optimized, stats) = fuse_loops(&graph);
        assert_eq!(stats.total_processed, 0);
    }

    #[test]
    fn test_config_builder() {
        let config = LoopFusionConfig {
            min_benefit_threshold: 1.5,
            max_fusion_size: 4,
            enable_elementwise_fusion: true,
            enable_reduction_fusion: false,
        };

        assert!(!config.enable_reduction_fusion);
        assert!(config.enable_elementwise_fusion);
        assert_eq!(config.max_fusion_size, 4);
        assert_eq!(config.min_benefit_threshold, 1.5);
    }

    #[test]
    fn test_dependency_conflict_detection() {
        // Node 1 depends on node 0; node 0 depends on nothing.
        let mut deps: HashMap<usize, HashSet<usize>> = HashMap::new();
        deps.insert(0, HashSet::new());
        deps.insert(1, std::iter::once(0).collect());

        // Node 1 depends on node 0, so they cannot be fused.
        assert!(has_dependency_conflict(&[0], 1, &deps));
        // Node 2 is unknown to the map and conflicts with nothing.
        assert!(!has_dependency_conflict(&[0], 2, &deps));
    }

    #[test]
    fn test_stats_total_optimizations() {
        let stats = LoopFusionStats {
            total_processed: 10,
            intermediates_eliminated: 1,
            reductions_merged: 3,
            loops_fused: 2,
        };

        assert_eq!(stats.total_optimizations(), 6);
    }
}