Skip to main content

tract_core/ops/nn/
silu.rs

1use crate::internal::*;
2use crate::ops::element_wise::ElementWiseOp;
3use crate::ops::math::Mul;
4use crate::ops::nn::Sigmoid;
5
6use tract_data::half::f16;
7
8element_wise!(silu, Silu,
9    [f16] => |_, xs| {
10        xs.iter_mut().for_each(|x| {
11            let xf = x.to_f32();
12            *x = f16::from_f32(xf / (1.0 + (-xf).exp()));
13        });
14        Ok(())
15    },
16    [f32] => |_, xs| { (tract_linalg::ops().silu_f32)().run(xs) };
17    cost: |dt| {tvec!((Cost::FMA(dt), 12), (Cost::Div(dt), 1))};
18    declutter: detect_silu
19);
20
21/// Search pattern => A = A * SIGMOID(A)
22pub fn detect_silu(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
23    rule_if!(node.op_as::<ElementWiseOp>().is_some_and(|op| op.0.is::<Sigmoid>()));
24
25    let in_fact = model.node_input_facts(node.id)?[0];
26    let dt = in_fact.datum_type;
27
28    // Only F16 and F32 is supported.
29    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
30
31    // Identify Mul successor: Sigmoid(A) * A
32    rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(node, &node.inputs[0]));
33
34    let mut patch = TypedModelPatch::default();
35    let silu_input = patch.taps(model, &node.inputs)?;
36    let out = patch.wire_node(format!("{}.silu", node.name), silu(), &silu_input)?;
37    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
38    Ok(Some(patch))
39}