Skip to main content

torsh_functional/fusion/
engine.rs

//! Basic fusion engine infrastructure and pattern detection.
//!
//! This module contains the core `OpFusionEngine` and the pattern-detection
//! functionality used to identify fusible operation sequences.
5
6use super::core::FusedOp;
7
8/// Auto-detection and fusion of operation patterns
9pub struct OpFusionEngine {
10    pub enabled: bool,
11    pub fusion_threshold: usize,
12}
13
14impl OpFusionEngine {
15    pub fn new() -> Self {
16        Self {
17            enabled: true,
18            fusion_threshold: 2, // Minimum number of operations to consider for fusion
19        }
20    }
21
22    /// Analyze a sequence of operations and suggest fusion opportunities
23    pub fn analyze_sequence(&self, ops: &[&str]) -> Vec<FusedOp> {
24        let mut fused_ops = Vec::new();
25
26        for window in ops.windows(2) {
27            match window {
28                ["add", "relu"] => fused_ops.push(FusedOp::ReluAdd),
29                ["mul", "add"] => fused_ops.push(FusedOp::MulAdd),
30                ["add", "mul"] => fused_ops.push(FusedOp::AddMul),
31                ["sigmoid", "mul"] => fused_ops.push(FusedOp::SigmoidMul),
32                ["tanh", "scale"] => fused_ops.push(FusedOp::TanhScale),
33                _ => {}
34            }
35        }
36
37        // Look for longer patterns
38        for window in ops.windows(3) {
39            match window {
40                ["add", "relu", "mul"] => {
41                    // Remove the individual operations if they were detected
42                    fused_ops.retain(|op| !matches!(op, FusedOp::ReluAdd));
43                    fused_ops.push(FusedOp::AddReluMul);
44                }
45                _ => {}
46            }
47        }
48
49        fused_ops
50    }
51
52    /// Check if fusion is beneficial for the given operation sequence
53    pub fn should_fuse(&self, ops: &[&str]) -> bool {
54        self.enabled && ops.len() >= self.fusion_threshold
55    }
56}
57
58impl Default for OpFusionEngine {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64/// Pattern matcher for common operation sequences
65pub fn detect_fusible_patterns(operations: &[&str]) -> Vec<(usize, FusedOp)> {
66    let mut patterns = Vec::new();
67
68    // Detect 2-operation patterns
69    for (i, window) in operations.windows(2).enumerate() {
70        match window {
71            ["add", "relu"] => patterns.push((i, FusedOp::ReluAdd)),
72            ["mul", "add"] => patterns.push((i, FusedOp::MulAdd)),
73            ["add", "mul"] => patterns.push((i, FusedOp::AddMul)),
74            ["sigmoid", "mul"] => patterns.push((i, FusedOp::SigmoidMul)),
75            ["tanh", "scale"] => patterns.push((i, FusedOp::TanhScale)),
76            _ => {}
77        }
78    }
79
80    // Detect 3-operation patterns (and remove conflicting 2-op patterns)
81    for (i, window) in operations.windows(3).enumerate() {
82        match window {
83            ["add", "relu", "mul"] => {
84                // Remove conflicting 2-op patterns
85                patterns.retain(|(pos, op)| {
86                    !(*pos == i && matches!(op, FusedOp::ReluAdd))
87                        && !(*pos == i + 1 && matches!(op, FusedOp::AddMul))
88                });
89                patterns.push((i, FusedOp::AddReluMul));
90            }
91            _ => {}
92        }
93    }
94
95    patterns
96}