tract_core/ops/cnn/
sumpool.rs

1use crate::internal::*;
2use num_traits::AsPrimitive;
3use std::iter::Sum;
4
5use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
6
7#[derive(Debug, Clone, new, Hash)]
8pub struct SumPool {
9    pub pool_spec: PoolSpec,
10    pub count_include_pad: bool,
11    pub normalize: bool,
12}
13
14impl Op for SumPool {
15    fn name(&self) -> Cow<str> {
16        "SumPool".into()
17    }
18
19    fn info(&self) -> TractResult<Vec<String>> {
20        Ok(self.pool_spec.info())
21    }
22
23    fn validation(&self) -> Validation {
24        Validation::Rounding
25    }
26
27    op_as_typed_op!();
28}
29
30impl EvalOp for SumPool {
31    fn is_stateless(&self) -> bool {
32        true
33    }
34
35    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
36        let shape: TVec<TDim> = inputs[0].shape().iter().map(|d| d.to_dim()).collect();
37        self.to_optimized(&shape)?.eval(inputs)
38    }
39}
40
41impl TypedOp for SumPool {
42    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
43        self.pool_spec.output_facts(inputs)
44    }
45
46    fn declutter(
47        &self,
48        model: &TypedModel,
49        node: &TypedNode,
50    ) -> TractResult<Option<TypedModelPatch>> {
51        let fact = model.outlet_fact(node.inputs[0])?;
52        if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
53            return Ok(Some(TypedModelPatch::replace_single_op(
54                model,
55                node,
56                &node.inputs,
57                Self { pool_spec, ..self.clone() },
58            )?));
59        }
60        Ok(None)
61    }
62
63    as_op!();
64}
65
66impl SumPool {
67    fn to_optimized(&self, input_shape: &[TDim]) -> TractResult<OptSumPool> {
68        Ok(OptSumPool {
69            pool_spec: self.pool_spec.clone(),
70            count_include_pad: self.count_include_pad,
71            normalize: self.normalize,
72            geometry: self.pool_spec.compute_geo(input_shape)?,
73        })
74    }
75}
76
77#[derive(Debug, Clone, new, Hash)]
78pub struct OptSumPool {
79    pub pool_spec: PoolSpec,
80    pub count_include_pad: bool,
81    pub normalize: bool,
82    pub geometry: PoolGeometry,
83}
84
85impl Op for OptSumPool {
86    fn name(&self) -> Cow<str> {
87        "OptSumPool".into()
88    }
89
90    fn info(&self) -> TractResult<Vec<String>> {
91        Ok(self.pool_spec.info())
92    }
93
94    fn validation(&self) -> Validation {
95        Validation::Rounding
96    }
97
98    op_as_typed_op!();
99}
100
101impl EvalOp for OptSumPool {
102    fn is_stateless(&self) -> bool {
103        true
104    }
105
106    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
107        let input = args_1!(inputs);
108        let geo = self.geometry.to_concrete(input.shape())?;
109        let values = if input.datum_type().is_float() {
110            let mut values =
111                unsafe { Tensor::uninitialized_dt(input.datum_type(), &geo.output_shape.shape)? };
112            dispatch_floatlike!(Self::eval_t(input.datum_type())(
113                self,
114                &*input,
115                values.as_ptr_mut()?,
116                geo.as_ref()
117            ))?;
118            values
119        } else {
120            let mut values =
121                unsafe { Tensor::uninitialized_dt(DatumType::F32, &geo.output_shape.shape)? };
122            let input_f32 = input.cast_to_dt(DatumType::F32)?;
123            self.eval_t::<f32>(input_f32.as_ref(), values.as_ptr_mut()?, geo.as_ref())?;
124            values.cast_to_dt(input.datum_type())?.into_owned()
125        };
126
127        Ok(tvec!(values.into_tvalue()))
128    }
129}
130
131impl TypedOp for OptSumPool {
132    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
133        self.pool_spec.output_facts(inputs)
134    }
135
136    fn declutter(
137        &self,
138        model: &TypedModel,
139        node: &TypedNode,
140    ) -> TractResult<Option<TypedModelPatch>> {
141        let fact = model.outlet_fact(node.inputs[0])?;
142        if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
143            return Ok(Some(TypedModelPatch::replace_single_op(
144                model,
145                node,
146                &node.inputs,
147                Self { pool_spec, ..self.clone() },
148            )?));
149        }
150        Ok(None)
151    }
152
153    as_op!();
154}
155
156impl OptSumPool {
157    fn eval_t<T: Copy + Datum + Sum + num_traits::Float>(
158        &self,
159        input: &Tensor,
160        values_ptr: *mut T,
161        geo: &ConcretePoolGeometry,
162    ) -> TractResult<()>
163    where
164        usize: AsPrimitive<T>,
165    {
166        let input_ptr = input.as_ptr::<T>()?;
167
168        let n = *geo.input_shape.n().unwrap_or(&1);
169        let n_stride_i = geo.input_shape.n_stride().unwrap_or(&0);
170        let n_stride_o = geo.output_shape.n_stride().unwrap_or(&0);
171        unsafe {
172            geo.patch.visit_output(|visitor| {
173                let div: Option<T> = if self.normalize {
174                    Some(
175                        if self.count_include_pad {
176                            geo.patch.standard_layout_data_field.len().as_()
177                        } else {
178                            visitor.valid_count().as_()
179                        }
180                        .recip(),
181                    )
182                } else {
183                    None
184                };
185                for n in 0..n {
186                    let input_offset = n * n_stride_i;
187                    let output_offset = n * n_stride_o;
188                    for c in 0..*geo.input_shape.c() {
189                        let input_offset = input_offset + geo.input_shape.c_stride() * c;
190                        let output_offset = output_offset + geo.output_shape.c_stride() * c;
191                        let sum = visitor
192                            .valid_offsets()
193                            .map(|v| *input_ptr.offset(v + input_offset as isize))
194                            .sum::<T>();
195
196                        if let Some(div) = div {
197                            *values_ptr.offset(output_offset as isize + visitor.output_offset) =
198                                sum * div;
199                        }
200                    }
201                }
202            });
203        }
204        Ok(())
205    }
206}