tract_core/ops/nn/
silu.rs1use crate::internal::*;
2use crate::ops::element_wise::ElementWiseOp;
3use crate::ops::math::Mul;
4use crate::ops::nn::Sigmoid;
5
6use tract_data::half::f16;
7
8element_wise!(silu, Silu,
9 [f16] => |_, xs| {
10 xs.iter_mut().for_each(|x| {
11 let xf = x.to_f32();
12 *x = f16::from_f32(xf / (1.0 + (-xf).exp()));
13 });
14 Ok(())
15 },
16 [f32] => |_, xs| {
17 let mut sigmoid = xs.to_vec();
18 (tract_linalg::ops().sigmoid_f32)().run(&mut sigmoid)?;
19 xs.iter_mut().zip(sigmoid).for_each(|(x, s)| *x *= s);
20 Ok(())
21 };
22 declutter: detect_silu
23);
24
25pub fn detect_silu(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
27 rule_if!(node.op_as::<ElementWiseOp>().is_some_and(|op| op.0.is::<Sigmoid>()));
28
29 let in_fact = model.node_input_facts(node.id)?[0];
30 let dt = in_fact.datum_type;
31
32 rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
34
35 rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(node, &node.inputs[0]));
37
38 let mut patch = TypedModelPatch::default();
39 let silu_input = patch.taps(model, &node.inputs)?;
40 let out = patch.wire_node(format!("{}.silu", node.name), silu(), &silu_input)?;
41 patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
42 Ok(Some(patch))
43}