1use tract_ndarray::prelude::*;
2use tract_nnef::internal::*;
3
/// ONNX-style Local Response Normalization operator.
///
/// Each element is normalized by the sum of squares of its neighbours along axis 1,
/// within a window of `size` channels.
#[derive(Clone, Debug, Default)]
pub struct Lrn {
    /// Scale applied to the windowed sum of squares (divided by `size` at evaluation).
    pub alpha: f32,
    /// Exponent applied to the normalization denominator.
    pub beta: f32,
    /// Additive term in the normalization denominator.
    pub bias: f32,
    /// Width, in channels, of the normalization window.
    pub size: usize,
}
11
12impl Lrn {
13 fn eval_t<T>(&self, input: TValue) -> TractResult<TVec<TValue>>
14 where
15 T: Datum + tract_num_traits::Float + ::std::iter::Sum,
16 {
17 let input = input.to_array_view::<T>()?;
18 let channels = input.shape()[1];
19 let output = Array::from_shape_fn(input.shape(), |mut coords| {
20 let c = coords[1];
21 let x = input[&coords];
22 let c_min = c.saturating_sub((self.size - 1) / 2);
23 let c_max = (c + ((self.size - 1).divceil(2))).min(channels - 1);
24 let square_sum: T = (c_min..=c_max)
25 .map(|c| {
26 coords[1] = c;
27 input[&coords].powi(2)
28 })
29 .sum();
30 x / (T::from(self.bias).unwrap()
31 + T::from(self.alpha).unwrap() / T::from(self.size).unwrap() * square_sum)
32 .powf(T::from(self.beta).unwrap())
33 });
34 Ok(tvec!(output.into_tvalue()))
35 }
36}
37
38impl Op for Lrn {
39 fn name(&self) -> StaticName {
40 "Lrn".into()
41 }
42
43 op_as_typed_op!();
44}
45
46impl EvalOp for Lrn {
47 fn is_stateless(&self) -> bool {
48 true
49 }
50
51 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
52 let input = args_1!(inputs);
53 dispatch_floatlike!(Self::eval_t(input.datum_type())(self, input))
54 }
55}
56
57impl TypedOp for Lrn {
58 as_op!();
59
60 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
61 Ok(tvec!(inputs[0].clone()))
62 }
63}
64
65pub fn parameters() -> Vec<Parameter> {
66 vec![
67 TypeName::Scalar.tensor().named("input"),
68 TypeName::Scalar.named("alpha").default(0.0001),
69 TypeName::Scalar.named("beta").default(0.75),
70 TypeName::Scalar.named("bias").default(1.0),
71 TypeName::Integer.named("size"),
72 ]
73}
74
75pub fn dump(ast: &mut IntoAst, node: &TypedNode, lrn: &Lrn) -> TractResult<Option<Arc<RValue>>> {
76 let input = ast.mapping[&node.inputs[0]].clone();
77 Ok(Some(invocation(
78 "tract_onnx_lrn",
79 &[input],
80 &[
81 ("alpha", numeric(lrn.alpha)),
82 ("beta", numeric(lrn.beta)),
83 ("bias", numeric(lrn.bias)),
84 ("size", numeric(lrn.size)),
85 ],
86 )))
87}
88
89pub fn load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
90 let input = invocation.named_arg_as(builder, "input")?;
91 let alpha = invocation.named_arg_as(builder, "alpha")?;
92 let beta = invocation.named_arg_as(builder, "beta")?;
93 let bias = invocation.named_arg_as(builder, "bias")?;
94 let size = invocation.named_arg_as(builder, "size")?;
95 let op = Lrn { alpha, beta, bias, size };
96 builder.wire(op, &[input])
97}