1use tract_core::ops::change_axes::wire_with_rank_broadcast;
2
3use crate::infer::*;
4use crate::internal::*;
5
/// ONNX GlobalAveragePool: averages each channel over all spatial axes
/// (every axis past N and C), producing size-1 spatial dims.
#[derive(Clone, Debug, new, Hash)]
pub struct GlobalAvgPool;
8
9impl Expansion for GlobalAvgPool {
10 fn name(&self) -> StaticName {
11 "GlobalAvgPool".into()
12 }
13
14 fn rules<'r, 'p: 'r, 's: 'r>(
15 &'s self,
16 solver: &mut Solver<'r>,
17 inputs: &'p [TensorProxy],
18 outputs: &'p [TensorProxy],
19 ) -> InferenceResult {
20 rules(solver, inputs, outputs)
21 }
22
23 fn wire(
24 &self,
25 name: &str,
26 target: &mut TypedModel,
27 inputs: &[OutletId],
28 ) -> TractResult<TVec<OutletId>> {
29 let input = inputs[0];
30 let input_fact = target.outlet_fact(input)?.clone();
31 let axes = (2..input_fact.rank()).collect();
32 let wire = target.wire_node(
33 name.to_string() + ".sum",
34 tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Sum),
35 &[input],
36 )?;
37 let div = tensor0(input_fact.shape.iter().skip(2).product::<TDim>());
38 let div = target.add_const(format!("{name}.div"), div)?;
39 let div = target.wire_node(
40 format!("{name}.casted"),
41 tract_core::ops::cast::cast(input_fact.datum_type),
42 &[div],
43 )?;
44 wire_with_rank_broadcast(
45 format!("{name}.norm"),
46 target,
47 tract_core::ops::math::div(),
48 &[wire[0], div[0]],
49 )
50 }
51}
52
/// ONNX GlobalLpPool; the field is the exponent `p` of the pooling norm
/// (e.g. 2 for GlobalL2Pool).
#[derive(Clone, Debug, new, Hash)]
pub struct GlobalLpPool(usize);
55
impl Expansion for GlobalLpPool {
    fn name(&self) -> StaticName {
        // Renders as e.g. "GlobalL2Pool" for p == 2.
        format!("GlobalL{}Pool", self.0).into()
    }

    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        // Shared shape/type rules for all global pooling expansions.
        rules(solver, inputs, outputs)
    }

    /// Expands to: x^p -> Reduce<Sum> over spatial axes -> divide by element
    /// count -> p-th root. The p == 2 case uses square/sqrt shortcuts.
    ///
    /// NOTE(review): the division by the spatial element count makes this a
    /// p-power mean rather than a plain Lp norm (the ONNX GlobalLpPool spec
    /// has no averaging step) — confirm this is intended.
    /// NOTE(review): no abs() is applied before exponentiation, so an odd p
    /// with negative inputs raises a negative base to the power — verify.
    fn wire(
        &self,
        name: &str,
        target: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let input = inputs[0];
        let input_fact = target.outlet_fact(input)?.clone();
        // Axes 0 (N) and 1 (C) are preserved; everything else is reduced.
        let axes = (2..input_fact.rank()).collect();
        let mut wire = tvec!(input);
        if self.0 == 2 {
            // p == 2: square is cheaper than a generic pow.
            wire = target.wire_node(
                name.to_string() + ".sqr",
                tract_core::ops::math::square(),
                &wire,
            )?;
        } else {
            // Generic p: constant exponent cast to the input's dtype and
            // broadcast to full rank so `pow` wires without rank mismatch.
            let pow = tensor0(self.0 as f64)
                .cast_to_dt(input_fact.datum_type)?
                .into_owned()
                .broadcast_into_rank(input_fact.rank())?
                .into_arc_tensor();
            let pow = target.add_const(name.to_string() + ".pow.cst", pow)?;
            wire = target.wire_node(
                name.to_string() + ".pow",
                tract_core::ops::math::pow(),
                &[wire[0], pow],
            )?;
        }
        // Sum the exponentiated values over the spatial axes.
        wire = target.wire_node(
            name.to_string() + ".sum",
            tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Sum),
            &wire,
        )?;
        // Divisor: number of pooled elements. `to_i64` requires the spatial
        // dims to be concrete at expansion time (symbolic dims would error).
        let div = tensor0(input_fact.shape.iter().skip(2).product::<TDim>().to_i64()? as f64)
            .cast_to_dt(input_fact.datum_type)?
            .into_owned()
            .broadcast_into_rank(input_fact.rank())?;
        let div = target.add_const(name.to_string() + ".div", div)?;
        wire = target.wire_node(
            name.to_string() + ".norm",
            tract_core::ops::math::div(),
            &[wire[0], div],
        )?;
        if self.0 == 2 {
            // Inverse of the square: sqrt.
            wire = target.wire_node(
                name.to_string() + ".sqrt",
                tract_core::ops::math::sqrt(),
                &wire,
            )?;
        } else {
            // Generic inverse: raise to the power 1/p.
            let anti_pow = tensor0((self.0 as f64).recip())
                .cast_to_dt(input_fact.datum_type)?
                .into_owned()
                .broadcast_into_rank(input_fact.rank())?
                .into_arc_tensor();
            let anti_pow = target.add_const(name.to_string() + ".anti_pow", anti_pow)?;
            wire = target.wire_node(
                name.to_string() + ".antipow",
                tract_core::ops::math::pow(),
                &[wire[0], anti_pow],
            )?;
        }
        Ok(wire)
    }
}
136
/// ONNX GlobalMaxPool: takes the max of each channel over all spatial axes.
#[derive(Clone, Debug, new, Hash)]
pub struct GlobalMaxPool;
139
140impl Expansion for GlobalMaxPool {
141 fn name(&self) -> StaticName {
142 "GlobalMaxPool".into()
143 }
144
145 fn rules<'r, 'p: 'r, 's: 'r>(
146 &'s self,
147 solver: &mut Solver<'r>,
148 inputs: &'p [TensorProxy],
149 outputs: &'p [TensorProxy],
150 ) -> InferenceResult {
151 rules(solver, inputs, outputs)
152 }
153
154 fn wire(
155 &self,
156 name: &str,
157 target: &mut TypedModel,
158 inputs: &[OutletId],
159 ) -> TractResult<TVec<OutletId>> {
160 let input = inputs[0];
161 let input_fact = target.outlet_fact(input)?.clone();
162 let axes = (2..input_fact.rank()).collect();
163 target.wire_node(
164 name.to_string() + ".max",
165 tract_core::ops::nn::Reduce::new(axes, tract_core::ops::nn::Reducer::Max),
166 &[input],
167 )
168 }
169}
170
171fn rules<'r, 'p: 'r, 's: 'r>(
172 s: &mut Solver<'r>,
173 inputs: &'p [TensorProxy],
174 outputs: &'p [TensorProxy],
175) -> InferenceResult {
176 check_input_arity(inputs, 1)?;
177 check_output_arity(outputs, 1)?;
178 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
179 s.equals(&outputs[0].rank, &inputs[0].rank)?;
180 s.equals(&outputs[0].shape[0], &inputs[0].shape[0])?;
181 s.equals(&outputs[0].shape[1], &inputs[0].shape[1])?;
182 s.given(&inputs[0].rank, move |s, rank| {
183 for i in 2..rank {
184 s.equals(&outputs[0].shape[i as usize], TDim::from(1))?;
185 }
186 Ok(())
187 })
188}