torsh_functional/fusion/
engine.rs1use super::core::FusedOp;
7
/// Configuration for detecting and gating operator fusion over a sequence of op names.
#[derive(Debug, Clone)]
pub struct OpFusionEngine {
    /// When `false`, [`OpFusionEngine::should_fuse`] always returns `false`.
    pub enabled: bool,
    /// Minimum number of operations a sequence must contain for
    /// [`OpFusionEngine::should_fuse`] to return `true`.
    pub fusion_threshold: usize,
}
13
14impl OpFusionEngine {
15 pub fn new() -> Self {
16 Self {
17 enabled: true,
18 fusion_threshold: 2, }
20 }
21
22 pub fn analyze_sequence(&self, ops: &[&str]) -> Vec<FusedOp> {
24 let mut fused_ops = Vec::new();
25
26 for window in ops.windows(2) {
27 match window {
28 ["add", "relu"] => fused_ops.push(FusedOp::ReluAdd),
29 ["mul", "add"] => fused_ops.push(FusedOp::MulAdd),
30 ["add", "mul"] => fused_ops.push(FusedOp::AddMul),
31 ["sigmoid", "mul"] => fused_ops.push(FusedOp::SigmoidMul),
32 ["tanh", "scale"] => fused_ops.push(FusedOp::TanhScale),
33 _ => {}
34 }
35 }
36
37 for window in ops.windows(3) {
39 match window {
40 ["add", "relu", "mul"] => {
41 fused_ops.retain(|op| !matches!(op, FusedOp::ReluAdd));
43 fused_ops.push(FusedOp::AddReluMul);
44 }
45 _ => {}
46 }
47 }
48
49 fused_ops
50 }
51
52 pub fn should_fuse(&self, ops: &[&str]) -> bool {
54 self.enabled && ops.len() >= self.fusion_threshold
55 }
56}
57
58impl Default for OpFusionEngine {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64pub fn detect_fusible_patterns(operations: &[&str]) -> Vec<(usize, FusedOp)> {
66 let mut patterns = Vec::new();
67
68 for (i, window) in operations.windows(2).enumerate() {
70 match window {
71 ["add", "relu"] => patterns.push((i, FusedOp::ReluAdd)),
72 ["mul", "add"] => patterns.push((i, FusedOp::MulAdd)),
73 ["add", "mul"] => patterns.push((i, FusedOp::AddMul)),
74 ["sigmoid", "mul"] => patterns.push((i, FusedOp::SigmoidMul)),
75 ["tanh", "scale"] => patterns.push((i, FusedOp::TanhScale)),
76 _ => {}
77 }
78 }
79
80 for (i, window) in operations.windows(3).enumerate() {
82 match window {
83 ["add", "relu", "mul"] => {
84 patterns.retain(|(pos, op)| {
86 !(*pos == i && matches!(op, FusedOp::ReluAdd))
87 && !(*pos == i + 1 && matches!(op, FusedOp::AddMul))
88 });
89 patterns.push((i, FusedOp::AddReluMul));
90 }
91 _ => {}
92 }
93 }
94
95 patterns
96}