tract_onnx_opl/
lrn.rs

use tract_ndarray::prelude::*;
use tract_nnef::internal::*;

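/// Local Response Normalization (the ONNX `LRN` operator).
///
/// Each element is divided by a power of the sum of squares taken over a
/// window of `size` neighbouring channels:
/// `y = x / (bias + alpha / size * sum(x_i^2))^beta`.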
#[derive(Debug, Clone, Default)]
pub struct Lrn {
    pub alpha: f32,
    pub beta: f32,
    pub bias: f32,
    pub size: usize,
}

impl Lrn {
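    // Generic kernel. The input is expected to be laid out with channels on
    // axis 1 (N x C x D1 ... Dk), as in the ONNX specification.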
    fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
    where
        T: Datum + tract_num_traits::Float + ::std::iter::Sum,
    {
        let input = input.to_array_view::<T>()?;
        let channels = input.shape()[1];
        let output = Array::from_shape_fn(input.shape(), |mut coords| {
            let c = coords[1];
            let x = input[&coords];
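            // Channel window around `c`: half the window (rounded down) below,
            // half (rounded up) above, clamped to the valid channel range.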
            let c_min = c.saturating_sub((self.size - 1) / 2);
            let c_max = (c + ((self.size - 1).divceil(2))).min(channels - 1);
            let square_sum: T = (c_min..=c_max)
                .map(|c| {
                    coords[1] = c;
                    input[&coords].powi(2)
                })
                .sum();
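            // y = x / (bias + alpha / size * square_sum)^beta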
            x / (T::from(self.bias).unwrap()
                + T::from(self.alpha).unwrap() / T::from(self.size).unwrap() * square_sum)
                .powf(T::from(self.beta).unwrap())
        });
        Ok(tvec!(output.into_tvalue()))
    }
}

impl Op for Lrn {
    fn name(&self) -> StaticName {
        "Lrn".into()
    }

    op_as_typed_op!();
}

impl EvalOp for Lrn {
    fn is_stateless(&self) -> bool {
        true
    }

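    // Pick a concrete float instantiation of `eval_t` from the input's datum type.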
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        dispatch_floatlike!(Self::eval_t(input.datum_type())(self, input))
    }
}

impl TypedOp for Lrn {
    as_op!();

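    // LRN preserves the shape and datum type of its input.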
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(inputs[0].clone()))
    }
}

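/// NNEF parameters of the `tract_onnx_lrn` fragment. Defaults for `alpha`,
/// `beta` and `bias` match the ONNX defaults; `size` is mandatory.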
pub fn parameters() -> Vec<Parameter> {
    vec![
        TypeName::Scalar.tensor().named("input"),
        TypeName::Scalar.named("alpha").default(0.0001),
        TypeName::Scalar.named("beta").default(0.75),
        TypeName::Scalar.named("bias").default(1.0),
        TypeName::Integer.named("size"),
    ]
}

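/// Serializes an `Lrn` node as a `tract_onnx_lrn` invocation in an NNEF graph.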
pub fn dump(ast: &mut IntoAst, node: &TypedNode, lrn: &Lrn) -> TractResult<Option<Arc<RValue>>> {
    let input = ast.mapping[&node.inputs[0]].clone();
    Ok(Some(invocation(
        "tract_onnx_lrn",
        &[input],
        &[
            ("alpha", numeric(lrn.alpha)),
            ("beta", numeric(lrn.beta)),
            ("bias", numeric(lrn.bias)),
            ("size", numeric(lrn.size)),
        ],
    )))
}

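/// Rebuilds an `Lrn` op from a `tract_onnx_lrn` invocation and wires it into the model.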
pub fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
    let input = invocation.named_arg_as(builder, "input")?;
    let alpha = invocation.named_arg_as(builder, "alpha")?;
    let beta = invocation.named_arg_as(builder, "beta")?;
    let bias = invocation.named_arg_as(builder, "bias")?;
    let size = invocation.named_arg_as(builder, "size")?;
    let op = Lrn { alpha, beta, bias, size };
    builder.wire(op, &[input])
}
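
// Minimal sanity-check sketch (an addition, not part of the original file).
// It only relies on calls the implementation above already uses
// (`into_tvalue`, `to_array_view`, `tvec!`); helper availability may vary
// across tract versions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lrn_middle_channel() -> TractResult<()> {
        // ONNX default attributes, window of 3 channels.
        let op = Lrn { alpha: 0.0001, beta: 0.75, bias: 1.0, size: 3 };
        // N x C x H x W input with 3 channels.
        let input = ArrayD::from_shape_vec(IxDyn(&[1, 3, 1, 1]), vec![1.0f32, 2.0, 3.0])?;
        let output = op.eval(tvec!(input.into_tvalue()))?;
        let view = output[0].to_array_view::<f32>()?;
        // For the middle channel the window covers all three channels:
        // y = 2 / (bias + alpha / size * (1 + 4 + 9))^beta
        let expected = 2.0f32 / (1.0 + 0.0001 / 3.0 * 14.0).powf(0.75);
        assert!((view[[0, 1, 0, 0]] - expected).abs() < 1e-6);
        Ok(())
    }
}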