tract_hir/ops/nn/
global_pools.rs

1use tract_core::ops::change_axes::wire_with_rank_broadcast;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Clone, Debug, new, Hash)]
7pub struct GlobalAvgPool;
8
9impl Expansion for GlobalAvgPool {
10    fn name(&self) -> StaticName {
11        "GlobalAvgPool".into()
12    }
13
14    fn rules<'r, 'p: 'r, 's: 'r>(
15        &'s self,
16        solver: &mut Solver<'r>,
17        inputs: &'p [TensorProxy],
18        outputs: &'p [TensorProxy],
19    ) -> InferenceResult {
20        rules(solver, inputs, outputs)
21    }
22
23    fn wire(
24        &self,
25        name: &str,
26        target: &mut TypedModel,
27        inputs: &[OutletId],
28    ) -> TractResult<TVec<OutletId>> {
29        let input = inputs[0];
30        let input_fact = target.outlet_fact(input)?.clone();
31        let axes = (2..input_fact.rank()).collect();
32        let wire = target.wire_node(
33            name.to_string() + ".sum",
34            tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Sum),
35            &[input],
36        )?;
37        let div = tensor0(input_fact.shape.iter().skip(2).product::<TDim>());
38        let div = target.add_const(format!("{name}.div"), div)?;
39        let div = target.wire_node(
40            format!("{name}.casted"),
41            tract_core::ops::cast::cast(input_fact.datum_type),
42            &[div],
43        )?;
44        wire_with_rank_broadcast(
45            format!("{name}.norm"),
46            target,
47            tract_core::ops::math::div(),
48            &[wire[0], div[0]],
49        )
50    }
51}
52
53#[derive(Clone, Debug, new, Hash)]
54pub struct GlobalLpPool(usize);
55
56impl Expansion for GlobalLpPool {
57    fn name(&self) -> StaticName {
58        format!("GlobalL{}Pool", self.0).into()
59    }
60
61    fn rules<'r, 'p: 'r, 's: 'r>(
62        &'s self,
63        solver: &mut Solver<'r>,
64        inputs: &'p [TensorProxy],
65        outputs: &'p [TensorProxy],
66    ) -> InferenceResult {
67        rules(solver, inputs, outputs)
68    }
69
70    fn wire(
71        &self,
72        name: &str,
73        target: &mut TypedModel,
74        inputs: &[OutletId],
75    ) -> TractResult<TVec<OutletId>> {
76        let input = inputs[0];
77        let input_fact = target.outlet_fact(input)?.clone();
78        let axes = (2..input_fact.rank()).collect();
79        let mut wire = tvec!(input);
80        if self.0 == 2 {
81            wire = target.wire_node(
82                name.to_string() + ".sqr",
83                tract_core::ops::math::square(),
84                &wire,
85            )?;
86        } else {
87            let pow = tensor0(self.0 as f64)
88                .cast_to_dt(input_fact.datum_type)?
89                .into_owned()
90                .broadcast_into_rank(input_fact.rank())?
91                .into_arc_tensor();
92            let pow = target.add_const(name.to_string() + ".pow.cst", pow)?;
93            wire = target.wire_node(
94                name.to_string() + ".pow",
95                tract_core::ops::math::pow(),
96                &[wire[0], pow],
97            )?;
98        }
99        wire = target.wire_node(
100            name.to_string() + ".sum",
101            tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Sum),
102            &wire,
103        )?;
104        let div = tensor0(input_fact.shape.iter().skip(2).product::<TDim>().to_i64()? as f64)
105            .cast_to_dt(input_fact.datum_type)?
106            .into_owned()
107            .broadcast_into_rank(input_fact.rank())?;
108        let div = target.add_const(name.to_string() + ".div", div)?;
109        wire = target.wire_node(
110            name.to_string() + ".norm",
111            tract_core::ops::math::div(),
112            &[wire[0], div],
113        )?;
114        if self.0 == 2 {
115            wire = target.wire_node(
116                name.to_string() + ".sqrt",
117                tract_core::ops::math::sqrt(),
118                &wire,
119            )?;
120        } else {
121            let anti_pow = tensor0((self.0 as f64).recip())
122                .cast_to_dt(input_fact.datum_type)?
123                .into_owned()
124                .broadcast_into_rank(input_fact.rank())?
125                .into_arc_tensor();
126            let anti_pow = target.add_const(name.to_string() + ".anti_pow", anti_pow)?;
127            wire = target.wire_node(
128                name.to_string() + ".antipow",
129                tract_core::ops::math::pow(),
130                &[wire[0], anti_pow],
131            )?;
132        }
133        Ok(wire)
134    }
135}
136
137#[derive(Clone, Debug, new, Hash)]
138pub struct GlobalMaxPool;
139
140impl Expansion for GlobalMaxPool {
141    fn name(&self) -> StaticName {
142        "GlobalMaxPool".into()
143    }
144
145    fn rules<'r, 'p: 'r, 's: 'r>(
146        &'s self,
147        solver: &mut Solver<'r>,
148        inputs: &'p [TensorProxy],
149        outputs: &'p [TensorProxy],
150    ) -> InferenceResult {
151        rules(solver, inputs, outputs)
152    }
153
154    fn wire(
155        &self,
156        name: &str,
157        target: &mut TypedModel,
158        inputs: &[OutletId],
159    ) -> TractResult<TVec<OutletId>> {
160        let input = inputs[0];
161        let input_fact = target.outlet_fact(input)?.clone();
162        let axes = (2..input_fact.rank()).collect();
163        target.wire_node(
164            name.to_string() + ".max",
165            tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Max),
166            &[input],
167        )
168    }
169}
170
171fn rules<'r, 'p: 'r, 's: 'r>(
172    s: &mut Solver<'r>,
173    inputs: &'p [TensorProxy],
174    outputs: &'p [TensorProxy],
175) -> InferenceResult {
176    check_input_arity(inputs, 1)?;
177    check_output_arity(outputs, 1)?;
178    s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
179    s.equals(&outputs[0].rank, &inputs[0].rank)?;
180    s.equals(&outputs[0].shape[0], &inputs[0].shape[0])?;
181    s.equals(&outputs[0].shape[1], &inputs[0].shape[1])?;
182    s.given(&inputs[0].rank, move |s, rank| {
183        for i in 2..rank {
184            s.equals(&outputs[0].shape[i as usize], TDim::from(1))?;
185        }
186        Ok(())
187    })
188}