1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
use crate::ops::prelude::*; use ndarray::prelude::*; use super::{DataFormat, PaddingSpec, Patch}; #[derive(Debug, Clone, new, Default)] pub struct MaxPool { data_fmt: DataFormat, kernel_shape: TVec<usize>, padding: PaddingSpec, strides: Option<TVec<usize>>, with_index_outputs: Option<DatumType>, } impl MaxPool { fn patch(&self, input_full_shape: &[usize]) -> Patch { let hw_rank = self.data_fmt.shape(input_full_shape).hw_rank(); Patch::new( self.data_fmt, tvec![1; hw_rank], self.kernel_shape.clone(), &self.padding, self.strides.clone().unwrap_or_else(|| tvec![1; hw_rank]), input_full_shape.into(), ) } } impl Op for MaxPool { fn name(&self) -> Cow<str> { "MaxPool".into() } fn noutputs(&self) -> usize { if self.with_index_outputs.is_some() { 2 } else { 1 } } } impl StatelessOp for MaxPool { fn eval(&self, mut inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>> { let input = args_1!(inputs); let input: ArrayViewD<f32> = input.to_array_view()?; let patch = self.patch(input.shape()); let shape: TVec<usize> = patch.output_full_shape(patch.input_shape.c_dim()); let visitor = patch.wrap(&input); let mut values = unsafe { ArrayD::uninitialized(&*shape) }; let mut indices = if self.with_index_outputs.is_some() { Some(unsafe { ArrayD::uninitialized(&*shape) }) } else { None }; ::ndarray::indices(&*shape).into_iter().for_each(|coords| { let max = visitor .at(&coords.slice()) .enumerate() .filter_map(|(ix, v)| v.map(|v| (ix, v))) .fold( (0, ::std::f32::MIN), |acc, v| if acc.1 < v.1 { v } else { acc }, ); values[&coords] = max.1; if self.with_index_outputs.is_some() { indices.as_mut().unwrap()[coords] = visitor.global_offset_for(&coords.slice(), max.0) as i32; } }); if let Some(dt) = self.with_index_outputs { Ok(tvec!( values.into(), Tensor::from(indices.unwrap()) .cast_to_dt(dt)? .into_owned() .into_tensor() )) } else { Ok(tvec!(values.into())) } } } impl InferenceRulesOp for MaxPool { fn rules<'r, 'p: 'r, 's: 'r>( &'s self, s: &mut Solver<'r>, inputs: &'p SharedTensorsProxy, outputs: &'p SharedTensorsProxy, ) -> InferenceResult { s.equals(&outputs.len, self.noutputs() as i32)?; s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?; s.equals(&outputs[0].rank, &inputs[0].rank)?; if let Some(idt) = self.with_index_outputs { s.equals(&outputs[1].datum_type, idt)?; s.equals(&outputs[1].rank, &inputs[0].rank)?; } s.given(&inputs[0].shape, move |s, ishape| { let ishape = self.data_fmt.shape(ishape); let ones = tvec![1; ishape.hw_rank()]; let computed = self.padding.compute( ishape.hw_dims(), &*self.kernel_shape, &ones, self.strides.as_ref().unwrap_or(&ones), ); for o in 0..self.noutputs() { for (ix, &d) in computed.output.iter().enumerate() { s.equals(&outputs[o].shape[ix + ishape.h_axis()], d)?; } s.equals(&outputs[o].shape[ishape.n_axis()], ishape.n_dim())?; s.equals(&outputs[o].shape[ishape.c_axis()], ishape.c_dim())?; } Ok(()) }) } }