tract_core/ops/cnn/
sumpool.rs1use crate::internal::*;
2use num_traits::AsPrimitive;
3use std::iter::Sum;
4
5use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
6
7#[derive(Debug, Clone, new, Hash)]
8pub struct SumPool {
9 pub pool_spec: PoolSpec,
10 pub count_include_pad: bool,
11 pub normalize: bool,
12}
13
14impl Op for SumPool {
15 fn name(&self) -> Cow<str> {
16 "SumPool".into()
17 }
18
19 fn info(&self) -> TractResult<Vec<String>> {
20 Ok(self.pool_spec.info())
21 }
22
23 fn validation(&self) -> Validation {
24 Validation::Rounding
25 }
26
27 op_as_typed_op!();
28}
29
30impl EvalOp for SumPool {
31 fn is_stateless(&self) -> bool {
32 true
33 }
34
35 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
36 let shape: TVec<TDim> = inputs[0].shape().iter().map(|d| d.to_dim()).collect();
37 self.to_optimized(&shape)?.eval(inputs)
38 }
39}
40
41impl TypedOp for SumPool {
42 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
43 self.pool_spec.output_facts(inputs)
44 }
45
46 fn declutter(
47 &self,
48 model: &TypedModel,
49 node: &TypedNode,
50 ) -> TractResult<Option<TypedModelPatch>> {
51 let fact = model.outlet_fact(node.inputs[0])?;
52 if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
53 return Ok(Some(TypedModelPatch::replace_single_op(
54 model,
55 node,
56 &node.inputs,
57 Self { pool_spec, ..self.clone() },
58 )?));
59 }
60 Ok(None)
61 }
62
63 as_op!();
64}
65
66impl SumPool {
67 fn to_optimized(&self, input_shape: &[TDim]) -> TractResult<OptSumPool> {
68 Ok(OptSumPool {
69 pool_spec: self.pool_spec.clone(),
70 count_include_pad: self.count_include_pad,
71 normalize: self.normalize,
72 geometry: self.pool_spec.compute_geo(input_shape)?,
73 })
74 }
75}
76
77#[derive(Debug, Clone, new, Hash)]
78pub struct OptSumPool {
79 pub pool_spec: PoolSpec,
80 pub count_include_pad: bool,
81 pub normalize: bool,
82 pub geometry: PoolGeometry,
83}
84
85impl Op for OptSumPool {
86 fn name(&self) -> Cow<str> {
87 "OptSumPool".into()
88 }
89
90 fn info(&self) -> TractResult<Vec<String>> {
91 Ok(self.pool_spec.info())
92 }
93
94 fn validation(&self) -> Validation {
95 Validation::Rounding
96 }
97
98 op_as_typed_op!();
99}
100
101impl EvalOp for OptSumPool {
102 fn is_stateless(&self) -> bool {
103 true
104 }
105
106 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
107 let input = args_1!(inputs);
108 let geo = self.geometry.to_concrete(input.shape())?;
109 let values = if input.datum_type().is_float() {
110 let mut values =
111 unsafe { Tensor::uninitialized_dt(input.datum_type(), &geo.output_shape.shape)? };
112 dispatch_floatlike!(Self::eval_t(input.datum_type())(
113 self,
114 &*input,
115 values.as_ptr_mut()?,
116 geo.as_ref()
117 ))?;
118 values
119 } else {
120 let mut values =
121 unsafe { Tensor::uninitialized_dt(DatumType::F32, &geo.output_shape.shape)? };
122 let input_f32 = input.cast_to_dt(DatumType::F32)?;
123 self.eval_t::<f32>(input_f32.as_ref(), values.as_ptr_mut()?, geo.as_ref())?;
124 values.cast_to_dt(input.datum_type())?.into_owned()
125 };
126
127 Ok(tvec!(values.into_tvalue()))
128 }
129}
130
131impl TypedOp for OptSumPool {
132 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
133 self.pool_spec.output_facts(inputs)
134 }
135
136 fn declutter(
137 &self,
138 model: &TypedModel,
139 node: &TypedNode,
140 ) -> TractResult<Option<TypedModelPatch>> {
141 let fact = model.outlet_fact(node.inputs[0])?;
142 if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
143 return Ok(Some(TypedModelPatch::replace_single_op(
144 model,
145 node,
146 &node.inputs,
147 Self { pool_spec, ..self.clone() },
148 )?));
149 }
150 Ok(None)
151 }
152
153 as_op!();
154}
155
156impl OptSumPool {
157 fn eval_t<T: Copy + Datum + Sum + num_traits::Float>(
158 &self,
159 input: &Tensor,
160 values_ptr: *mut T,
161 geo: &ConcretePoolGeometry,
162 ) -> TractResult<()>
163 where
164 usize: AsPrimitive<T>,
165 {
166 let input_ptr = input.as_ptr::<T>()?;
167
168 let n = *geo.input_shape.n().unwrap_or(&1);
169 let n_stride_i = geo.input_shape.n_stride().unwrap_or(&0);
170 let n_stride_o = geo.output_shape.n_stride().unwrap_or(&0);
171 unsafe {
172 geo.patch.visit_output(|visitor| {
173 let div: Option<T> = if self.normalize {
174 Some(
175 if self.count_include_pad {
176 geo.patch.standard_layout_data_field.len().as_()
177 } else {
178 visitor.valid_count().as_()
179 }
180 .recip(),
181 )
182 } else {
183 None
184 };
185 for n in 0..n {
186 let input_offset = n * n_stride_i;
187 let output_offset = n * n_stride_o;
188 for c in 0..*geo.input_shape.c() {
189 let input_offset = input_offset + geo.input_shape.c_stride() * c;
190 let output_offset = output_offset + geo.output_shape.c_stride() * c;
191 let sum = visitor
192 .valid_offsets()
193 .map(|v| *input_ptr.offset(v + input_offset as isize))
194 .sum::<T>();
195
196 if let Some(div) = div {
197 *values_ptr.offset(output_offset as isize + visitor.output_offset) =
198 sum * div;
199 }
200 }
201 }
202 });
203 }
204 Ok(())
205 }
206}