tract_core/ops/cnn/
maxpool.rs

1use crate::internal::*;
2use ndarray::prelude::*;
3
4use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
5
6#[derive(Debug, Clone, new, Hash)]
7pub struct MaxPool {
8    pub pool_spec: PoolSpec,
9    pub with_index_outputs: Option<DatumType>,
10}
11
12impl Op for MaxPool {
13    fn name(&self) -> Cow<str> {
14        "MaxPool".into()
15    }
16
17    fn info(&self) -> TractResult<Vec<String>> {
18        Ok(self.pool_spec.info())
19    }
20
21    op_as_typed_op!();
22}
23
24impl EvalOp for MaxPool {
25    fn is_stateless(&self) -> bool {
26        true
27    }
28
29    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30        let shape: TVec<TDim> = inputs[0].shape().iter().map(|d| d.to_dim()).collect();
31        self.to_optimized(&shape)?.eval(inputs)
32    }
33}
34
35impl TypedOp for MaxPool {
36    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
37        let mut facts = self.pool_spec.output_facts(inputs)?;
38        if let Some(idt) = self.with_index_outputs {
39            facts.push(facts[0].clone());
40            facts[1].datum_type = idt;
41        }
42        Ok(facts)
43    }
44
45    fn declutter(
46        &self,
47        model: &TypedModel,
48        node: &TypedNode,
49    ) -> TractResult<Option<TypedModelPatch>> {
50        if self.with_index_outputs.is_some()
51            && node.outputs[1].successors.len() == 0
52            && !model.output_outlets()?.contains(&OutletId::new(node.id, 1))
53        {
54            let op = Self { with_index_outputs: None, ..self.clone() };
55            let mut patch = TypedModelPatch::default();
56            let mut wire = patch.tap_model(model, node.inputs[0])?;
57            wire = patch.wire_node(&node.name, op, &[wire])?[0];
58            patch.shunt_outside(model, node.id.into(), wire)?;
59            return Ok(Some(patch));
60        }
61        let fact = model.outlet_fact(node.inputs[0])?;
62        if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
63            return Ok(Some(TypedModelPatch::replace_single_op(
64                model,
65                node,
66                &node.inputs,
67                Self { pool_spec, ..self.clone() },
68            )?));
69        }
70        Ok(None)
71    }
72
73    as_op!();
74}
75
76impl MaxPool {
77    fn to_optimized(&self, input_shape: &[TDim]) -> TractResult<OptMaxPool> {
78        Ok(OptMaxPool {
79            pool_spec: self.pool_spec.clone(),
80            with_index_outputs: self.with_index_outputs,
81            geometry: self.pool_spec.compute_geo(input_shape)?,
82        })
83    }
84}
85
86#[derive(Debug, Clone, new, Hash)]
87pub struct OptMaxPool {
88    pub pool_spec: PoolSpec,
89    pub with_index_outputs: Option<DatumType>,
90    pub geometry: PoolGeometry,
91}
92
93impl Op for OptMaxPool {
94    fn name(&self) -> Cow<str> {
95        "OptMaxPool".into()
96    }
97
98    fn info(&self) -> TractResult<Vec<String>> {
99        Ok(self.pool_spec.info())
100    }
101
102    op_as_typed_op!();
103}
104
105impl EvalOp for OptMaxPool {
106    fn is_stateless(&self) -> bool {
107        true
108    }
109
110    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
111        let input = args_1!(inputs);
112        let geo = self.geometry.to_concrete(input.shape())?;
113        dispatch_numbers!(Self::eval_t(input.datum_type())(self, &*input, geo.as_ref()))
114    }
115}
116
117impl TypedOp for OptMaxPool {
118    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
119        let mut facts = self.pool_spec.output_facts(inputs)?;
120        if let Some(idt) = self.with_index_outputs {
121            facts.push(facts[0].clone());
122            facts[1].datum_type = idt;
123        }
124        Ok(facts)
125    }
126
127    as_op!();
128}
129
130impl OptMaxPool {
131    fn eval_t<T: Datum + Copy + num_traits::Bounded + PartialOrd>(
132        &self,
133        input: &Tensor,
134        geo: &ConcretePoolGeometry,
135    ) -> TractResult<TVec<TValue>> {
136        let input_dt = input.datum_type();
137        let input: ArrayViewD<T> = input.to_array_view()?;
138        let input_ptr = input.as_ptr();
139
140        let mut values = unsafe { ArrayD::<T>::uninit(&*geo.output_shape.shape).assume_init() };
141        let mut indices = if self.with_index_outputs.is_some() {
142            Some(unsafe { ArrayD::<i32>::uninit(&*geo.output_shape.shape).assume_init() })
143        } else {
144            None
145        };
146        let n = *geo.input_shape.n().unwrap_or(&1);
147        let n_stride_i = geo.input_shape.n_stride().unwrap_or(&0);
148        let n_stride_o = geo.output_shape.n_stride().unwrap_or(&0);
149        unsafe {
150            geo.patch.visit_output(|visitor| {
151                for n in 0..n {
152                    let input_offset = n * n_stride_i;
153                    let output_offset = n * n_stride_o;
154                    for c in 0..*geo.input_shape.c() {
155                        let input_offset = input_offset + geo.input_shape.c_stride() * c;
156                        let output_offset = output_offset + geo.output_shape.c_stride() * c;
157                        let max = visitor
158                            .valid_offsets()
159                            .map(|v| (v, *input_ptr.offset(v + input_offset as isize)))
160                            .fold((0, T::min_value()), |acc, v| if acc.1 < v.1 { v } else { acc });
161                        *values
162                            .as_mut_ptr()
163                            .offset(output_offset as isize + visitor.output_offset) = max.1;
164                        if let Some(ref mut indices) = indices {
165                            *indices
166                                .as_mut_ptr()
167                                .offset(output_offset as isize + visitor.output_offset) =
168                                max.0 as i32 / geo.patch.spec.output_inner_stride as i32;
169                        }
170                    }
171                }
172            });
173        }
174        let mut values = values.into_tensor();
175        unsafe {
176            values.set_datum_type(input_dt);
177        }
178        if let Some(dt) = self.with_index_outputs {
179            Ok(tvec!(
180                values.into_tvalue(),
181                indices.unwrap().into_tensor().cast_to_dt(dt)?.into_owned().into_tvalue()
182            ))
183        } else {
184            Ok(tvec!(values.into_tvalue()))
185        }
186    }
187}