tract_core/ops/cnn/
maxpool.rs1use crate::internal::*;
2use ndarray::prelude::*;
3
4use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
5
6#[derive(Debug, Clone, new, Hash)]
7pub struct MaxPool {
8 pub pool_spec: PoolSpec,
9 pub with_index_outputs: Option<DatumType>,
10}
11
12impl Op for MaxPool {
13 fn name(&self) -> Cow<str> {
14 "MaxPool".into()
15 }
16
17 fn info(&self) -> TractResult<Vec<String>> {
18 Ok(self.pool_spec.info())
19 }
20
21 op_as_typed_op!();
22}
23
24impl EvalOp for MaxPool {
25 fn is_stateless(&self) -> bool {
26 true
27 }
28
29 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30 let shape: TVec<TDim> = inputs[0].shape().iter().map(|d| d.to_dim()).collect();
31 self.to_optimized(&shape)?.eval(inputs)
32 }
33}
34
35impl TypedOp for MaxPool {
36 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
37 let mut facts = self.pool_spec.output_facts(inputs)?;
38 if let Some(idt) = self.with_index_outputs {
39 facts.push(facts[0].clone());
40 facts[1].datum_type = idt;
41 }
42 Ok(facts)
43 }
44
45 fn declutter(
46 &self,
47 model: &TypedModel,
48 node: &TypedNode,
49 ) -> TractResult<Option<TypedModelPatch>> {
50 if self.with_index_outputs.is_some()
51 && node.outputs[1].successors.len() == 0
52 && !model.output_outlets()?.contains(&OutletId::new(node.id, 1))
53 {
54 let op = Self { with_index_outputs: None, ..self.clone() };
55 let mut patch = TypedModelPatch::default();
56 let mut wire = patch.tap_model(model, node.inputs[0])?;
57 wire = patch.wire_node(&node.name, op, &[wire])?[0];
58 patch.shunt_outside(model, node.id.into(), wire)?;
59 return Ok(Some(patch));
60 }
61 let fact = model.outlet_fact(node.inputs[0])?;
62 if let Some(pool_spec) = self.pool_spec.declutter(&fact.shape)? {
63 return Ok(Some(TypedModelPatch::replace_single_op(
64 model,
65 node,
66 &node.inputs,
67 Self { pool_spec, ..self.clone() },
68 )?));
69 }
70 Ok(None)
71 }
72
73 as_op!();
74}
75
76impl MaxPool {
77 fn to_optimized(&self, input_shape: &[TDim]) -> TractResult<OptMaxPool> {
78 Ok(OptMaxPool {
79 pool_spec: self.pool_spec.clone(),
80 with_index_outputs: self.with_index_outputs,
81 geometry: self.pool_spec.compute_geo(input_shape)?,
82 })
83 }
84}
85
86#[derive(Debug, Clone, new, Hash)]
87pub struct OptMaxPool {
88 pub pool_spec: PoolSpec,
89 pub with_index_outputs: Option<DatumType>,
90 pub geometry: PoolGeometry,
91}
92
93impl Op for OptMaxPool {
94 fn name(&self) -> Cow<str> {
95 "OptMaxPool".into()
96 }
97
98 fn info(&self) -> TractResult<Vec<String>> {
99 Ok(self.pool_spec.info())
100 }
101
102 op_as_typed_op!();
103}
104
105impl EvalOp for OptMaxPool {
106 fn is_stateless(&self) -> bool {
107 true
108 }
109
110 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
111 let input = args_1!(inputs);
112 let geo = self.geometry.to_concrete(input.shape())?;
113 dispatch_numbers!(Self::eval_t(input.datum_type())(self, &*input, geo.as_ref()))
114 }
115}
116
117impl TypedOp for OptMaxPool {
118 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
119 let mut facts = self.pool_spec.output_facts(inputs)?;
120 if let Some(idt) = self.with_index_outputs {
121 facts.push(facts[0].clone());
122 facts[1].datum_type = idt;
123 }
124 Ok(facts)
125 }
126
127 as_op!();
128}
129
130impl OptMaxPool {
131 fn eval_t<T: Datum + Copy + num_traits::Bounded + PartialOrd>(
132 &self,
133 input: &Tensor,
134 geo: &ConcretePoolGeometry,
135 ) -> TractResult<TVec<TValue>> {
136 let input_dt = input.datum_type();
137 let input: ArrayViewD<T> = input.to_array_view()?;
138 let input_ptr = input.as_ptr();
139
140 let mut values = unsafe { ArrayD::<T>::uninit(&*geo.output_shape.shape).assume_init() };
141 let mut indices = if self.with_index_outputs.is_some() {
142 Some(unsafe { ArrayD::<i32>::uninit(&*geo.output_shape.shape).assume_init() })
143 } else {
144 None
145 };
146 let n = *geo.input_shape.n().unwrap_or(&1);
147 let n_stride_i = geo.input_shape.n_stride().unwrap_or(&0);
148 let n_stride_o = geo.output_shape.n_stride().unwrap_or(&0);
149 unsafe {
150 geo.patch.visit_output(|visitor| {
151 for n in 0..n {
152 let input_offset = n * n_stride_i;
153 let output_offset = n * n_stride_o;
154 for c in 0..*geo.input_shape.c() {
155 let input_offset = input_offset + geo.input_shape.c_stride() * c;
156 let output_offset = output_offset + geo.output_shape.c_stride() * c;
157 let max = visitor
158 .valid_offsets()
159 .map(|v| (v, *input_ptr.offset(v + input_offset as isize)))
160 .fold((0, T::min_value()), |acc, v| if acc.1 < v.1 { v } else { acc });
161 *values
162 .as_mut_ptr()
163 .offset(output_offset as isize + visitor.output_offset) = max.1;
164 if let Some(ref mut indices) = indices {
165 *indices
166 .as_mut_ptr()
167 .offset(output_offset as isize + visitor.output_offset) =
168 max.0 as i32 / geo.patch.spec.output_inner_stride as i32;
169 }
170 }
171 }
172 });
173 }
174 let mut values = values.into_tensor();
175 unsafe {
176 values.set_datum_type(input_dt);
177 }
178 if let Some(dt) = self.with_index_outputs {
179 Ok(tvec!(
180 values.into_tvalue(),
181 indices.unwrap().into_tensor().cast_to_dt(dt)?.into_owned().into_tvalue()
182 ))
183 } else {
184 Ok(tvec!(values.into_tvalue()))
185 }
186 }
187}