tract_core/ops/nn/
gelu_approximate.rs1use crate::internal::*;
2use crate::ops::binary::TypedBinOp;
3use crate::ops::element_wise::ElementWiseOp;
4use crate::ops::math::{Add, Mul, Pow, Tanh};
5
6use tract_data::half::f16;
7
/// Scalar tanh-based GELU approximation, evaluated in `f32`.
///
/// Computes `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^pow)))`.
/// With `pow == 3` this is the standard tanh approximation of GELU; the
/// `GeluApproximate` operator in this file passes `pow == 2` when its
/// `fast_impl` flag is set.
fn gelu_approx_f32(x: f32, pow: i32) -> f32 {
    let scale = (2.0 / std::f32::consts::PI).sqrt();
    // Kept as separate steps, but in the same evaluation order as the
    // one-line formula, so results are bit-for-bit identical.
    let poly_term = 0.044715 * x.powi(pow);
    let tanh_arg = scale * (x + poly_term);
    let half_x = 0.5 * x;
    half_x * (1.0 + tanh_arg.tanh())
}
12
13element_wise!(gelu_approximate, GeluApproximate { fast_impl: bool },
14 [f16] => |op, xs| {
15 let pow = if op.fast_impl { 2 } else { 3 };
16 xs.iter_mut().for_each(|x| {
17 *x = f16::from_f32(gelu_approx_f32(x.to_f32(), pow));
18 });
19 Ok(())
20 },
21 [f32] => |op, xs| {
22 let pow = if op.fast_impl { 2 } else { 3 };
23 xs.iter_mut().for_each(|x| {
24 *x = gelu_approx_f32(*x, pow);
25 });
26 Ok(())
27 };
28 cost: |dt| {tvec!((Cost::FMA(dt), 15))}
29);
30
31pub fn detect_gelu_approx(
33 _op: &Pow,
34 model: &TypedModel,
35 node: &TypedNode,
36) -> TractResult<Option<TypedModelPatch>> {
37 let pow_node = node;
38
39 let in_fact = model.node_input_facts(pow_node.id)?[0];
40 let dt = in_fact.datum_type;
41
42 rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
44
45 rule_if!(
46 model.matches_single_input_const(pow_node, 3.0)
47 || model.matches_single_input_const(pow_node, 2.0)
48 );
49 let fast_impl = model.matches_single_input_const(pow_node, 2.0);
50
51 rule_if_some!(mul_coef_a = model.find_succ_bin_with_const::<Mul>(pow_node, 0.044715));
53
54 rule_if_some!(
56 x_plus_mul_coef_a = model.find_succ_bin_with_outlet::<Add>(mul_coef_a, &pow_node.inputs[0])
57 );
58
59 let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
61 rule_if_some!(
62 mul_sqrt_2_over_pi =
63 model.find_succ_bin_with_const::<Mul>(x_plus_mul_coef_a, sqrt_2_over_pi)
64 );
65
66 rule_if_some!(tanh_succ = model.single_succ(mul_sqrt_2_over_pi.id)?);
68 rule_if_some!(tanh_succ_op = tanh_succ.op_as::<ElementWiseOp>());
69 rule_if!(tanh_succ_op.0.is::<Tanh>());
70
71 rule_if_some!(tanh_plus_1 = model.find_succ_bin_with_const::<Add>(tanh_succ, 1.0));
73
74 rule_if_some!(mul_succ = model.single_succ(tanh_plus_1.id)?);
76 rule_if_some!(mul_succ_op = mul_succ.op_as::<TypedBinOp>());
77 rule_if!(mul_succ_op.0.is::<Mul>());
78
79 let last_node_id = if mul_succ.inputs.contains(&pow_node.inputs[0]) {
83 rule_if_some!(last_mul_with_0_5 = model.find_succ_bin_with_const::<Mul>(mul_succ, 0.5));
85 last_mul_with_0_5.id
86 } else {
87 rule_if_some!(
90 x_mul_0_5 = mul_succ
91 .inputs
92 .iter()
93 .filter_map(|i| {
94 let n = &model.nodes()[i.node];
95 let op = n.op_as::<TypedBinOp>()?;
96 op.0.is::<Mul>().then_some(n)
97 })
98 .next()
99 );
100 rule_if!(model.matches_single_input_const(x_mul_0_5, 0.5));
101 rule_if!(x_mul_0_5.inputs.contains(&pow_node.inputs[0]));
102 mul_succ.id
103 };
104
105 let mut patch = TypedModelPatch::default();
106 let gelu_approx_input = patch.taps(model, &pow_node.inputs)?;
107 let out = patch.wire_node(
108 format!("{}.gelu_approx", pow_node.name),
109 gelu_approximate(fast_impl),
110 &[gelu_approx_input[0]],
111 )?;
112 patch.shunt_outside(model, last_node_id.into(), out[0])?;
113 Ok(Some(patch))
114}