Skip to main content

tensorlogic_ir/graph/
fusion.rs

1//! Operation fusion optimization pass.
2//!
3//! This module provides operation fusion capabilities to combine multiple compatible
4//! operations into single, more efficient operations. This reduces kernel launch overhead
5//! and can enable better optimizations in backend execution.
6
7use std::collections::{HashMap, HashSet};
8
9use super::{EinsumGraph, EinsumNode, OpType};
10use crate::error::IrError;
11
/// Summary of what a fusion pass found.
///
/// All fields are plain counters/estimates; a freshly created value describes
/// "nothing fused" (zero operations, zero groups, speedup factor of `1.0`).
#[derive(Debug, Clone, PartialEq)]
pub struct FusionStats {
    /// How many individual operations take part in a fusion group.
    pub ops_fused: usize,
    /// How many fusion groups were identified.
    pub fusion_groups: usize,
    /// Estimated performance gain expressed as a multiplicative ratio
    /// (`1.0` means no change).
    pub estimated_speedup: f64,
}

impl Default for FusionStats {
    /// The neutral statistics: nothing fused and a speedup ratio of `1.0`.
    fn default() -> Self {
        FusionStats {
            ops_fused: 0,
            fusion_groups: 0,
            estimated_speedup: 1.0,
        }
    }
}

impl FusionStats {
    /// Create empty statistics; identical to [`FusionStats::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
39
40/// Fuse element-wise operations that operate on the same tensors
41///
42/// This pass identifies chains of element-wise operations (unary and binary)
43/// that can be fused into a single operation, reducing memory traffic and
44/// kernel launch overhead.
45///
46/// # Examples
47///
48/// Fusing unary operations:
49/// ```text
50/// x -> ReLU -> Tanh -> Sigmoid
51/// ```
52/// Can be fused into:
53/// ```text
54/// x -> Fused(ReLU, Tanh, Sigmoid)
55/// ```
56///
57/// Fusing element-wise binary operations with the same inputs:
58/// ```text
59/// (a, b) -> Add -> Mul(c) -> Sub(d)
60/// ```
61pub fn fuse_elementwise_operations(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
62    let mut stats = FusionStats::new();
63
64    // Build dependency graph
65    let mut tensor_users: HashMap<usize, Vec<usize>> = HashMap::new();
66    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
67
68    for (node_idx, node) in graph.nodes.iter().enumerate() {
69        for &output_idx in &node.outputs {
70            tensor_producer.insert(output_idx, node_idx);
71        }
72        for &input_idx in &node.inputs {
73            tensor_users.entry(input_idx).or_default().push(node_idx);
74        }
75    }
76
77    // Find fusible chains
78    let mut fusible_chains = find_fusible_chains(graph, &tensor_users, &tensor_producer);
79
80    // Apply fusion to each chain
81    for chain in fusible_chains.drain(..) {
82        if chain.len() > 1 {
83            stats.ops_fused += chain.len();
84            stats.fusion_groups += 1;
85            // Estimate speedup: roughly linear with chain length
86            stats.estimated_speedup *= 1.0 + (chain.len() as f64 * 0.1);
87        }
88    }
89
90    Ok(stats)
91}
92
93/// Find chains of fusible operations
94fn find_fusible_chains(
95    graph: &EinsumGraph,
96    tensor_users: &HashMap<usize, Vec<usize>>,
97    tensor_producer: &HashMap<usize, usize>,
98) -> Vec<Vec<usize>> {
99    let mut chains = Vec::new();
100    let mut visited = HashSet::new();
101
102    for (node_idx, node) in graph.nodes.iter().enumerate() {
103        if visited.contains(&node_idx) {
104            continue;
105        }
106
107        if is_fusible_operation(&node.op) {
108            let mut chain = vec![node_idx];
109            visited.insert(node_idx);
110
111            // Try to extend chain forward
112            extend_chain_forward(
113                graph,
114                node_idx,
115                &mut chain,
116                &mut visited,
117                tensor_users,
118                tensor_producer,
119            );
120
121            if chain.len() > 1 {
122                chains.push(chain);
123            }
124        }
125    }
126
127    chains
128}
129
130/// Check if an operation type is fusible
131fn is_fusible_operation(op_type: &OpType) -> bool {
132    matches!(
133        op_type,
134        OpType::ElemUnary { .. } | OpType::ElemBinary { .. }
135    )
136}
137
138/// Extend a fusion chain forward
139fn extend_chain_forward(
140    graph: &EinsumGraph,
141    current_node: usize,
142    chain: &mut Vec<usize>,
143    visited: &mut HashSet<usize>,
144    tensor_users: &HashMap<usize, Vec<usize>>,
145    _tensor_producer: &HashMap<usize, usize>,
146) {
147    let node = &graph.nodes[current_node];
148
149    // Check each output tensor
150    for &output_idx in &node.outputs {
151        // If this tensor has exactly one user, we might be able to fuse
152        if let Some(users) = tensor_users.get(&output_idx) {
153            if users.len() == 1 {
154                let next_node_idx = users[0];
155                if visited.contains(&next_node_idx) {
156                    continue;
157                }
158
159                let next_node = &graph.nodes[next_node_idx];
160                if is_fusible_operation(&next_node.op) && can_fuse_nodes(node, next_node) {
161                    visited.insert(next_node_idx);
162                    chain.push(next_node_idx);
163                    // Recursively extend
164                    extend_chain_forward(
165                        graph,
166                        next_node_idx,
167                        chain,
168                        visited,
169                        tensor_users,
170                        _tensor_producer,
171                    );
172                }
173            }
174        }
175    }
176}
177
178/// Check if two nodes can be fused together
179fn can_fuse_nodes(node1: &EinsumNode, node2: &EinsumNode) -> bool {
180    // Both must be fusible operations
181    if !is_fusible_operation(&node1.op) || !is_fusible_operation(&node2.op) {
182        return false;
183    }
184
185    // For now, we only fuse element-wise operations
186    // More sophisticated fusion rules could be added here
187    matches!(
188        (&node1.op, &node2.op),
189        (OpType::ElemUnary { .. }, OpType::ElemUnary { .. })
190            | (OpType::ElemUnary { .. }, OpType::ElemBinary { .. })
191            | (OpType::ElemBinary { .. }, OpType::ElemUnary { .. })
192    )
193}
194
195/// Fuse reduction operations with their producers when possible
196///
197/// This pass identifies patterns where a reduction operation directly follows
198/// an element-wise operation on the same data. In such cases, the operations
199/// can often be fused to avoid materialization of intermediate results.
200///
201/// # Example
202///
203/// ```text
204/// x -> Map(f) -> Sum
205/// ```
206/// Can be fused into:
207/// ```text
208/// x -> MapReduce(f, Sum)
209/// ```
210pub fn fuse_map_reduce(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
211    let mut stats = FusionStats::new();
212
213    // Build producer map
214    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
215    for (node_idx, node) in graph.nodes.iter().enumerate() {
216        for &output_idx in &node.outputs {
217            tensor_producer.insert(output_idx, node_idx);
218        }
219    }
220
221    // Find map-reduce patterns
222    let mut fuse_pairs = Vec::new();
223
224    for (reduce_idx, reduce_node) in graph.nodes.iter().enumerate() {
225        if matches!(reduce_node.op, OpType::Reduce { .. }) {
226            // Check if the input is produced by an element-wise operation
227            if let Some(&input_idx) = reduce_node.inputs.first() {
228                if let Some(&map_idx) = tensor_producer.get(&input_idx) {
229                    let map_node = &graph.nodes[map_idx];
230                    if is_fusible_operation(&map_node.op) {
231                        fuse_pairs.push((map_idx, reduce_idx));
232                    }
233                }
234            }
235        }
236    }
237
238    stats.ops_fused = fuse_pairs.len() * 2; // Map + Reduce
239    stats.fusion_groups = fuse_pairs.len();
240    stats.estimated_speedup = 1.0 + (fuse_pairs.len() as f64 * 0.2);
241
242    Ok(stats)
243}
244
245/// Fuse einsum operations when possible
246///
247/// This pass identifies einsum operations that can be combined into a single
248/// einsum operation, which is often more efficient than executing them separately.
249///
250/// # Example
251///
252/// ```text
253/// A, B -> einsum("ij,jk->ik") -> C
254/// C, D -> einsum("ik,kl->il") -> E
255/// ```
256/// Can be fused into:
257/// ```text
258/// A, B, D -> einsum("ij,jk,kl->il") -> E
259/// ```
260pub fn fuse_einsum_operations(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
261    let mut stats = FusionStats::new();
262
263    // Build producer map
264    let mut tensor_producer: HashMap<usize, usize> = HashMap::new();
265    let mut tensor_users: HashMap<usize, Vec<usize>> = HashMap::new();
266
267    for (node_idx, node) in graph.nodes.iter().enumerate() {
268        for &output_idx in &node.outputs {
269            tensor_producer.insert(output_idx, node_idx);
270        }
271        for &input_idx in &node.inputs {
272            tensor_users.entry(input_idx).or_default().push(node_idx);
273        }
274    }
275
276    // Find fusible einsum pairs
277    let mut fuse_pairs = Vec::new();
278
279    for (node2_idx, node2) in graph.nodes.iter().enumerate() {
280        if let OpType::Einsum { spec: spec2 } = &node2.op {
281            // Check if any input is produced by another einsum
282            for &input_idx in &node2.inputs {
283                if let Some(&node1_idx) = tensor_producer.get(&input_idx) {
284                    let node1 = &graph.nodes[node1_idx];
285                    if let OpType::Einsum { spec: spec1 } = &node1.op {
286                        // Check if we can fuse these einsums
287                        if can_fuse_einsums(spec1, spec2, &tensor_users, input_idx) {
288                            fuse_pairs.push((node1_idx, node2_idx));
289                        }
290                    }
291                }
292            }
293        }
294    }
295
296    stats.ops_fused = fuse_pairs.len() * 2;
297    stats.fusion_groups = fuse_pairs.len();
298    stats.estimated_speedup = 1.0 + (fuse_pairs.len() as f64 * 0.3);
299
300    Ok(stats)
301}
302
/// Decide whether two einsum operations linked by `intermediate_tensor` may
/// be fused.
///
/// The only rule applied today concerns the intermediate tensor: if
/// `tensor_users` has an entry for it, that entry must list exactly one user.
/// A tensor with no entry at all is accepted. The einsum specs are not
/// inspected yet; more sophisticated rules could be added here later.
fn can_fuse_einsums(
    _spec1: &str,
    _spec2: &str,
    tensor_users: &HashMap<usize, Vec<usize>>,
    intermediate_tensor: usize,
) -> bool {
    match tensor_users.get(&intermediate_tensor) {
        // Recorded users: fusable only when there is a single consumer.
        Some(consumers) => consumers.len() == 1,
        // No recorded users: nothing blocks the fusion.
        None => true,
    }
}
321
322/// Apply all fusion optimizations to a graph
323///
324/// This is a convenience function that applies all available fusion passes
325/// in sequence and returns the combined statistics.
326pub fn fuse_all(graph: &mut EinsumGraph) -> Result<FusionStats, IrError> {
327    let mut total_stats = FusionStats::new();
328
329    // Apply element-wise fusion
330    let elem_stats = fuse_elementwise_operations(graph)?;
331    total_stats.ops_fused += elem_stats.ops_fused;
332    total_stats.fusion_groups += elem_stats.fusion_groups;
333    total_stats.estimated_speedup *= elem_stats.estimated_speedup;
334
335    // Apply map-reduce fusion
336    let map_reduce_stats = fuse_map_reduce(graph)?;
337    total_stats.ops_fused += map_reduce_stats.ops_fused;
338    total_stats.fusion_groups += map_reduce_stats.fusion_groups;
339    total_stats.estimated_speedup *= map_reduce_stats.estimated_speedup;
340
341    // Apply einsum fusion
342    let einsum_stats = fuse_einsum_operations(graph)?;
343    total_stats.ops_fused += einsum_stats.ops_fused;
344    total_stats.fusion_groups += einsum_stats.fusion_groups;
345    total_stats.estimated_speedup *= einsum_stats.estimated_speedup;
346
347    Ok(total_stats)
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec strings shared by the einsum-related tests.
    const FIRST_SPEC: &str = "ij,jk->ik";
    const SECOND_SPEC: &str = "ik,kl->il";

    // ---- FusionStats -----------------------------------------------------

    #[test]
    fn test_fusion_stats_default() {
        let FusionStats {
            ops_fused,
            fusion_groups,
            estimated_speedup,
        } = FusionStats::default();
        assert_eq!(ops_fused, 0);
        assert_eq!(fusion_groups, 0);
        assert_eq!(estimated_speedup, 1.0);
    }

    // ---- Node-level predicates -------------------------------------------

    #[test]
    fn test_is_fusible_operation() {
        let unary = OpType::ElemUnary {
            op: "relu".to_string(),
        };
        let binary = OpType::ElemBinary {
            op: "add".to_string(),
        };
        let einsum = OpType::Einsum {
            spec: FIRST_SPEC.to_string(),
        };

        assert!(is_fusible_operation(&unary));
        assert!(is_fusible_operation(&binary));
        assert!(!is_fusible_operation(&einsum));
    }

    #[test]
    fn test_can_fuse_unary_nodes() {
        let producer = EinsumNode::elem_unary("relu", 0, 1);
        let consumer = EinsumNode::elem_unary("tanh", 1, 2);
        assert!(can_fuse_nodes(&producer, &consumer));
    }

    #[test]
    fn test_can_fuse_unary_binary_nodes() {
        let producer = EinsumNode::elem_unary("relu", 0, 1);
        let consumer = EinsumNode::elem_binary("add", 1, 2, 3);
        assert!(can_fuse_nodes(&producer, &consumer));
    }

    #[test]
    fn test_cannot_fuse_einsum_nodes() {
        // Einsum nodes are outside the scope of element-wise fusion.
        let producer = EinsumNode::einsum(FIRST_SPEC, vec![0, 1], vec![2]);
        let consumer = EinsumNode::einsum(SECOND_SPEC, vec![2, 3], vec![4]);
        assert!(!can_fuse_nodes(&producer, &consumer));
    }

    // ---- Whole-graph passes ----------------------------------------------

    #[test]
    fn test_fuse_elementwise_empty_graph() {
        let mut empty = EinsumGraph::new();
        let report = fuse_elementwise_operations(&mut empty).unwrap();
        assert_eq!(report.ops_fused, 0);
        assert_eq!(report.fusion_groups, 0);
    }

    #[test]
    fn test_fuse_elementwise_single_op() {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("A");
        let output = graph.add_tensor("B");
        graph
            .add_node(EinsumNode::elem_unary("relu", input, output))
            .unwrap();

        // One lone operation leaves nothing to fuse.
        let report = fuse_elementwise_operations(&mut graph).unwrap();
        assert_eq!(report.ops_fused, 0);
    }

    #[test]
    fn test_fuse_map_reduce_empty_graph() {
        let mut empty = EinsumGraph::new();
        let report = fuse_map_reduce(&mut empty).unwrap();
        assert_eq!(report.ops_fused, 0);
    }

    #[test]
    fn test_fuse_einsum_empty_graph() {
        let mut empty = EinsumGraph::new();
        let report = fuse_einsum_operations(&mut empty).unwrap();
        assert_eq!(report.ops_fused, 0);
    }

    #[test]
    fn test_fuse_all_empty_graph() {
        let mut empty = EinsumGraph::new();
        let report = fuse_all(&mut empty).unwrap();
        assert_eq!(report.ops_fused, 0);
        assert_eq!(report.fusion_groups, 0);
    }

    // ---- Helpers ---------------------------------------------------------

    #[test]
    fn test_find_fusible_chains_empty() {
        let graph = EinsumGraph::new();
        let users: HashMap<usize, Vec<usize>> = HashMap::new();
        let producers: HashMap<usize, usize> = HashMap::new();
        assert!(find_fusible_chains(&graph, &users, &producers).is_empty());
    }

    #[test]
    fn test_can_fuse_einsums_single_user() {
        let users = HashMap::from([(1, vec![2])]);
        assert!(can_fuse_einsums(FIRST_SPEC, SECOND_SPEC, &users, 1));
    }

    #[test]
    fn test_cannot_fuse_einsums_multiple_users() {
        let users = HashMap::from([(1, vec![2, 3])]);
        assert!(!can_fuse_einsums(FIRST_SPEC, SECOND_SPEC, &users, 1));
    }
}