Skip to main content

tract_core/ops/nn/
silu.rs

1use crate::internal::*;
2use crate::ops::element_wise::ElementWiseOp;
3use crate::ops::math::Mul;
4use crate::ops::nn::Sigmoid;
5
6use tract_data::half::f16;
7
8element_wise!(silu, Silu,
9    [f16] => |_, xs| {
10        xs.iter_mut().for_each(|x| {
11            let xf = x.to_f32();
12            *x = f16::from_f32(xf / (1.0 + (-xf).exp()));
13        });
14        Ok(())
15    },
16    [f32] => |_, xs| {
17        let mut sigmoid = xs.to_vec();
18        (tract_linalg::ops().sigmoid_f32)().run(&mut sigmoid)?;
19        xs.iter_mut().zip(sigmoid).for_each(|(x, s)| *x *= s);
20        Ok(())
21    };
22    declutter: detect_silu
23);
24
25/// Search pattern => A = A * SIGMOID(A)
26pub fn detect_silu(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
27    rule_if!(node.op_as::<ElementWiseOp>().is_some_and(|op| op.0.is::<Sigmoid>()));
28
29    let in_fact = model.node_input_facts(node.id)?[0];
30    let dt = in_fact.datum_type;
31
32    // Only F16 and F32 is supported.
33    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
34
35    // Identify Mul successor: Sigmoid(A) * A
36    rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(node, &node.inputs[0]));
37
38    let mut patch = TypedModelPatch::default();
39    let silu_input = patch.taps(model, &node.inputs)?;
40    let out = patch.wire_node(format!("{}.silu", node.name), silu(), &silu_input)?;
41    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
42    Ok(Some(patch))
43}