Skip to main content

tract_core/ops/nn/
gelu_approximate.rs

1use crate::internal::*;
2use crate::ops::binary::TypedBinOp;
3use crate::ops::element_wise::ElementWiseOp;
4use crate::ops::math::{Add, Mul, Pow, Tanh};
5
6use tract_data::half::f16;
7
/// Tanh-based GELU approximation evaluated on a single `f32`.
///
/// Computes `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^pow)))`.
///
/// * `x` - the input value.
/// * `pow` - integer exponent applied to `x` inside the polynomial; callers in
///   this file pass 3 (canonical form) or 2 (fast variant).
fn gelu_approx_f32(x: f32, pow: i32) -> f32 {
    let tanh_scale = (2.0 / std::f32::consts::PI).sqrt();
    let polynomial = x + 0.044715 * x.powi(pow);
    let gate = 1.0 + f32::tanh(tanh_scale * polynomial);
    // Keep the `(0.5 * x) * gate` grouping so rounding matches the one-line formula.
    0.5 * x * gate
}
12
13element_wise!(gelu_approximate, GeluApproximate { fast_impl: bool },
14    [f16] => |op, xs| {
15        let pow = if op.fast_impl { 2 } else { 3 };
16        xs.iter_mut().for_each(|x| {
17            *x = f16::from_f32(gelu_approx_f32(x.to_f32(), pow));
18        });
19        Ok(())
20    },
21    [f32] => |op, xs| {
22        if op.fast_impl {
23            // pow=2 fast path: no linalg kernel yet, scalar fallback.
24            xs.iter_mut().for_each(|x| {
25                *x = gelu_approx_f32(*x, 2);
26            });
27            Ok(())
28        } else {
29            // pow=3 canonical path: linalg NEON kernel composes with tanh.
30            (tract_linalg::ops().gelu_f32)().run(xs)
31        }
32    };
33    cost: |dt| {tvec!((Cost::FMA(dt), 15))}
34);
35
36/// Search pattern => NEW_GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N))); N ∈ {2, 3}
37pub fn detect_gelu_approx(
38    _op: &Pow,
39    model: &TypedModel,
40    node: &TypedNode,
41) -> TractResult<Option<TypedModelPatch>> {
42    let pow_node = node;
43
44    let in_fact = model.node_input_facts(pow_node.id)?[0];
45    let dt = in_fact.datum_type;
46
47    // Only F16 and F32 is supported.
48    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
49
50    rule_if!(
51        model.matches_single_input_const(pow_node, 3.0)
52            || model.matches_single_input_const(pow_node, 2.0)
53    );
54    let fast_impl = model.matches_single_input_const(pow_node, 2.0);
55
56    // 0.044715 * x^N
57    rule_if_some!(mul_coef_a = model.find_succ_bin_with_const::<Mul>(pow_node, 0.044715));
58
59    // x + 0.044715 * x^N
60    rule_if_some!(
61        x_plus_mul_coef_a = model.find_succ_bin_with_outlet::<Add>(mul_coef_a, &pow_node.inputs[0])
62    );
63
64    // sqrt(2/pi) * (x + 0.044715 * x^N)
65    let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
66    rule_if_some!(
67        mul_sqrt_2_over_pi =
68            model.find_succ_bin_with_const::<Mul>(x_plus_mul_coef_a, sqrt_2_over_pi)
69    );
70
71    // tanh(sqrt(2/pi) * (x + 0.044715 * x^N))
72    rule_if_some!(tanh_succ = model.single_succ(mul_sqrt_2_over_pi.id)?);
73    rule_if_some!(tanh_succ_op = tanh_succ.op_as::<ElementWiseOp>());
74    rule_if!(tanh_succ_op.0.is::<Tanh>());
75
76    // 1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N)) N ∈ {2, 3}
77    rule_if_some!(tanh_plus_1 = model.find_succ_bin_with_const::<Add>(tanh_succ, 1.0));
78
79    // Identify Mul
80    rule_if_some!(mul_succ = model.single_succ(tanh_plus_1.id)?);
81    rule_if_some!(mul_succ_op = mul_succ.op_as::<TypedBinOp>());
82    rule_if!(mul_succ_op.0.is::<Mul>());
83
84    // Search first
85    // tmp = x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N)))
86    // out = 0.5 * tmp
87    let last_node_id = if mul_succ.inputs.contains(&pow_node.inputs[0]) {
88        // 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N)))
89        rule_if_some!(last_mul_with_0_5 = model.find_succ_bin_with_const::<Mul>(mul_succ, 0.5));
90        last_mul_with_0_5.id
91    } else {
92        // tmp = 0.5 * x
93        // out = tmp * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * x^N))) N ∈ {2, 3}
94        rule_if_some!(
95            x_mul_0_5 = mul_succ
96                .inputs
97                .iter()
98                .filter_map(|i| {
99                    let n = &model.nodes()[i.node];
100                    let op = n.op_as::<TypedBinOp>()?;
101                    op.0.is::<Mul>().then_some(n)
102                })
103                .next()
104        );
105        rule_if!(model.matches_single_input_const(x_mul_0_5, 0.5));
106        rule_if!(x_mul_0_5.inputs.contains(&pow_node.inputs[0]));
107        mul_succ.id
108    };
109
110    let mut patch = TypedModelPatch::default();
111    let gelu_approx_input = patch.taps(model, &pow_node.inputs)?;
112    let out = patch.wire_node(
113        format!("{}.gelu_approx", pow_node.name),
114        gelu_approximate(fast_impl),
115        &[gelu_approx_input[0]],
116    )?;
117    patch.shunt_outside(model, last_node_id.into(), out[0])?;
118    Ok(Some(patch))
119}