tract_core/ops/cnn/
padding.rs

1use crate::internal::*;
2
/// Padding strategy for the spatial axes of a convolution / pooling op.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum PaddingSpec {
    /// Explicit per-axis padding (`before`, `after`), conv-style output
    /// rounding (floor division by the stride).
    Explicit(TVec<usize>, TVec<usize>),
    /// Explicit per-axis padding with ONNX pooling semantics; the boolean is
    /// `ceil_mode` (round the output length up instead of down).
    ExplicitOnnxPool(TVec<usize>, TVec<usize>, bool),
    /// No padding: every kernel window lies fully inside the input.
    #[default]
    Valid,
    /// "SAME" padding with the extra odd unit of padding placed after.
    SameUpper,
    /// "SAME" padding with the extra odd unit of padding placed before.
    SameLower,
}
12
13use PaddingSpec::*;
14
/// Resolved geometry of one spatial axis, as produced by
/// `PaddingSpec::compute*`.
#[derive(Debug, Clone, new, PartialEq, Eq)]
pub struct ComputedPaddedDim<D: DimLike> {
    /// Input-side (pre-convolution / post-deconvolution) length of the axis.
    pub deconvoluted: D,
    /// Output-side (post-convolution) length of the axis.
    pub convoluted: D,
    /// Padding added before the data on this axis.
    pub pad_before: D,
    /// Padding added after the data on this axis.
    pub pad_after: D,
}
22
impl PaddingSpec {
    /// Returns true when axis `d` is known to need no padding at all.
    ///
    /// `Valid` always qualifies; the explicit variants qualify when both the
    /// before and after padding for that axis are zero. For ONNX pools the
    /// zero-padding check additionally requires `ceil_mode` or a stride of
    /// one — NOTE(review): presumably because floor rounding with stride > 1
    /// may drop trailing input positions; confirm against callers.
    pub fn valid_dim(&self, d: usize, stride_is_one: bool) -> bool {
        match self {
            Valid => true,
            Explicit(bef, aft) => bef[d] == 0 && aft[d] == 0,
            ExplicitOnnxPool(a, b, ceil_mode) => {
                (*ceil_mode || stride_is_one) && a[d] == 0 && b[d] == 0
            }
            _ => false,
        }
    }

    /// Rewrites the per-axis padding arrays to follow an axis insertion,
    /// removal or permutation (`AxisOp`) applied to the geometry.
    ///
    /// A freshly inserted axis gets zero padding on both sides. `Valid` and
    /// the `Same*` variants carry no per-axis data and are returned as-is.
    pub fn change_geo_axes(&self, op: &AxisOp) -> TractResult<PaddingSpec> {
        match &self {
            ExplicitOnnxPool(before, after, round) => {
                let mut before: TVec<usize> = before.clone();
                let mut after: TVec<usize> = after.clone();
                op.change_shape_array(&mut before, false)?;
                op.change_shape_array(&mut after, false)?;
                // A new axis must start unpadded.
                if let AxisOp::Add(add) = op {
                    before[*add] = 0;
                    after[*add] = 0;
                }
                Ok(ExplicitOnnxPool(before, after, *round))
            }
            Explicit(before, after) => {
                let mut before: TVec<usize> = before.clone();
                let mut after: TVec<usize> = after.clone();
                op.change_shape_array(&mut before, false)?;
                op.change_shape_array(&mut after, false)?;
                // A new axis must start unpadded.
                if let AxisOp::Add(add) = op {
                    before[*add] = 0;
                    after[*add] = 0;
                }
                Ok(Explicit(before, after))
            }
            Valid | SameLower | SameUpper => Ok(self.clone()),
        }
    }

    /// Computes the padded geometry of every spatial axis for a
    /// convolution-like op.
    ///
    /// All parameter slices must have one entry per spatial axis (indexing
    /// panics otherwise).
    pub fn compute<D: DimLike>(
        &self,
        input_spatial_shape: &[D],
        kernel_spatial_shape: &[usize],
        dilations: &[usize],
        strides: &[usize],
    ) -> TVec<ComputedPaddedDim<D>> {
        (0..input_spatial_shape.len())
            .map(|d| {
                self.compute_one(
                    d,
                    &input_spatial_shape[d],
                    kernel_spatial_shape[d],
                    dilations[d],
                    strides[d],
                )
            })
            .collect()
    }

    /// Computes the padded geometry of every spatial axis for a
    /// transposed convolution, starting from the convoluted (output-side)
    /// shape.
    ///
    /// All parameter slices must have one entry per spatial axis (indexing
    /// panics otherwise).
    pub fn compute_for_deconv<D: DimLike>(
        &self,
        conv_spatial_shape: &[D],
        kernel_spatial_shape: &[usize],
        dilations: &[usize],
        strides: &[usize],
        adjustments: &[usize],
    ) -> TractResult<TVec<ComputedPaddedDim<D>>> {
        (0..conv_spatial_shape.len())
            .map(|d| {
                self.compute_one_for_deconv(
                    d,
                    &conv_spatial_shape[d],
                    kernel_spatial_shape[d],
                    dilations[d],
                    strides[d],
                    adjustments[d],
                )
            })
            .collect()
    }

    /// Computes the padded geometry of a single axis for a convolution-like
    /// op, dispatching on the padding strategy.
    pub fn compute_one<D: DimLike>(
        &self,
        axis: usize,
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
    ) -> ComputedPaddedDim<D> {
        match self {
            Valid => Self::valid(input, kernel, dilation, stride),
            Explicit(ref bef, ref aft) => {
                Self::explicit(input, kernel, dilation, stride, bef[axis], aft[axis])
            }
            ExplicitOnnxPool(ref bef, ref aft, ceil_mode) => Self::explicit_onnx_pool(
                input, kernel, dilation, stride, bef[axis], aft[axis], *ceil_mode,
            ),
            SameUpper => Self::same(input, kernel, dilation, stride, true),
            SameLower => Self::same(input, kernel, dilation, stride, false),
        }
    }

    /// Computes the geometry of a single axis for a transposed convolution,
    /// dispatching on the padding strategy.
    pub fn compute_one_for_deconv<D: DimLike>(
        &self,
        axis: usize,
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        match self {
            Valid => Self::valid_for_deconv(input, kernel, dilation, stride, adjustment),
            SameUpper => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, true),
            SameLower => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, false),
            Explicit(ref bef, ref aft) => Self::explicit_for_deconv(
                input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
            ),
            // unreachable ?
            ExplicitOnnxPool(ref bef, ref aft, _ceil_mode) => Self::explicit_for_deconv(
                input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
            ),
        }
    }

    /// Valid (unpadded) geometry:
    /// output = ceil((input - kernel_field + 1) / stride),
    /// with kernel_field = (kernel - 1) * dilation + 1, the span of the
    /// dilated kernel.
    fn valid<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
    ) -> ComputedPaddedDim<D> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let output = if let Ok(int) = input.to_usize() {
            // Concrete dims saturate: a kernel wider than the input yields 0.
            D::from((int + 1).saturating_sub(kernel_field).divceil(stride))
        } else {
            // Symbolic dims cannot saturate; the expression stays symbolic.
            (input.clone() + 1 - kernel_field).divceil(stride)
        };
        ComputedPaddedDim::new(input.clone(), output, 0.into(), 0.into())
    }

    /// Inverse of `valid`: recovers the input-side length from the
    /// convoluted one, plus the ONNX `output_padding`-style `adjustment`.
    fn valid_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let deconvoluted = (convoluted.clone() - 1) * stride + kernel_field + adjustment;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), 0.into(), 0.into()))
    }

    /// Explicit padding with conv-style floor output rounding. Concrete dims
    /// are routed through `explicit_usize` to get saturating arithmetic.
    fn explicit<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
    ) -> ComputedPaddedDim<D> {
        if let Ok(i) = input.to_dim().to_usize() {
            let ints = Self::explicit_usize(i, kernel, dilation, stride, bef, aft);
            ComputedPaddedDim::new(
                input.clone(),
                ints.convoluted.into(),
                ints.pad_before.into(),
                ints.pad_after.into(),
            )
        } else {
            let kernel_field = (kernel - 1) * dilation + 1;
            let dividend = input.clone() + bef + aft - kernel_field;
            let output = dividend.div(stride) + 1;
            ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
        }
    }

    /// Concrete-integer case of `explicit`:
    /// output = floor((input + bef + aft - kernel_field) / stride) + 1,
    /// saturating so an oversized kernel cannot underflow.
    fn explicit_usize(
        input: usize,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
    ) -> ComputedPaddedDim<usize> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let dividend = (input + bef + aft).saturating_sub(kernel_field);
        let output = dividend / stride + 1;
        ComputedPaddedDim::new(input, output, bef, aft)
    }

    /// Explicit padding with ONNX pooling semantics: the output length is
    /// rounded up when `ceil_mode` is set, down otherwise. Concrete dims are
    /// routed through `explicit_onnx_pool_usize`.
    fn explicit_onnx_pool<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        ceil_mode: bool,
    ) -> ComputedPaddedDim<D> {
        if let Ok(i) = input.to_dim().to_usize() {
            let ints =
                Self::explicit_onnx_pool_usize(i, kernel, dilation, stride, bef, aft, ceil_mode);
            ComputedPaddedDim::new(
                input.clone(),
                ints.convoluted.into(),
                ints.pad_before.into(),
                ints.pad_after.into(),
            )
        } else {
            // output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
            let kernel_field = (kernel - 1) * dilation + 1;
            let dividend = input.clone() + bef + aft - kernel_field;
            let output =
                if ceil_mode { dividend.divceil(stride) } else { dividend.div(stride) } + 1;
            ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
        }
    }

    /// Concrete-integer case of `explicit_onnx_pool`, including the ceil-mode
    /// correction that drops a trailing window starting outside the image.
    fn explicit_onnx_pool_usize(
        input: usize,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        ceil_mode: bool,
    ) -> ComputedPaddedDim<usize> {
        // output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
        let kernel_field = (kernel - 1) * dilation + 1;
        let dividend = (input + bef + aft).saturating_sub(kernel_field);
        let mut output = if ceil_mode { dividend.divceil(stride) } else { dividend / stride } + 1;
        if ceil_mode {
            // ensure that the last pooling starts inside the image
            // needed to avoid problems in ceil mode
            if (output - 1) * stride >= input + bef {
                output -= 1;
            }
        }
        ComputedPaddedDim::new(input, output, bef, aft)
    }

    /// Inverse of `explicit`: recovers the input-side length from the
    /// convoluted one, the explicit paddings, and the `adjustment`.
    fn explicit_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let deconvoluted =
            (convoluted.clone() - 1) * stride + kernel_field - bef - aft + adjustment;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), bef.into(), aft.into()))
    }

    /// "SAME" geometry: output = ceil(input / stride), and the total padding
    /// is (output - 1) * stride + kernel_field - input, split in two with the
    /// extra odd unit going after for `SameUpper` (`upper == true`) and
    /// before for `SameLower`.
    fn same<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        upper: bool,
    ) -> ComputedPaddedDim<D> {
        let output = input.divceil(stride);
        let kernel_field = (kernel - 1) * dilation + 1;
        let pad = if let Ok(input) = input.to_usize() {
            // Concrete case saturates: no negative padding.
            let pad = (((output.clone() - 1) * stride + kernel_field).to_usize().unwrap())
                .saturating_sub(input);
            pad.into()
        } else {
            (output.clone() - 1) * stride + kernel_field - input
        };
        let lower_pad = pad.clone() / 2;
        let higher_pad = pad - &lower_pad;
        let (before, after) = if upper { (lower_pad, higher_pad) } else { (higher_pad, lower_pad) };
        ComputedPaddedDim::new(input.clone(), output, before, after) // TODO input is wrong for stride != 1
    }

    /// Inverse of `same` for transposed convolution: derives the symmetric
    /// crop from the kernel span, stride and adjustment, then recovers the
    /// input-side length. Fails when the dilated kernel is narrower than the
    /// stride, which would make the crop computation underflow.
    fn same_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
        upper: bool,
    ) -> TractResult<ComputedPaddedDim<D>> {
        if (kernel - 1) * dilation < stride {
            bail!("Invalid axis geometry for SAME padding: expect (kernel_len - 1) * dilation > stride - 1");
        }
        let kernel_field = (kernel - 1) * dilation + 1;
        let crop = kernel_field + adjustment - stride;
        let lower_crop = crop / 2;
        let higher_crop = crop - lower_crop;
        let (before, after) =
            if upper { (lower_crop, higher_crop) } else { (higher_crop, lower_crop) };
        let deconvoluted = (convoluted.clone() - 1) * stride + kernel_field - before - after;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), before.into(), after.into()))
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use PaddingSpec as PS;

    /// Shorthand for the expected (deconvoluted, convoluted, pad_before,
    /// pad_after) quadruplet.
    fn dim(i: usize, o: usize, b: usize, a: usize) -> ComputedPaddedDim<usize> {
        ComputedPaddedDim::new(i, o, b, a)
    }

    #[test]
    fn same_stride_1() {
        // Kernel 2, stride 1: output tracks input, one unit of padding after.
        for input in 1usize..=4 {
            assert_eq!(PS::same(&input, 2usize, 1, 1, true), dim(input, input, 0, 1));
        }
    }

    #[test]
    fn same_stride_2() {
        // (input, expected output, expected pad_after) for kernel 2, stride 2.
        let cases = [(1usize, 1usize, 1usize), (2, 1, 0), (3, 2, 1), (4, 2, 0)];
        for (input, output, aft) in cases {
            assert_eq!(PS::same(&input, 2usize, 1, 2, true), dim(input, output, 0, aft));
        }
    }

    #[test]
    fn same_1() {
        assert_eq!(PS::same(&6usize, 1usize, 1, 2, true), dim(6, 3, 0, 0));
    }

    #[test]
    fn same_lower() {
        // SameLower puts the extra odd unit of padding before the data.
        assert_eq!(PS::same(&10usize, 2usize, 1, 3, false), dim(10, 4, 1, 0));
    }

    #[test]
    fn same_ker_3() {
        // Kernel 3, stride 1: symmetric padding of one on both sides.
        for input in 1usize..=4 {
            assert_eq!(PS::same(&input, 3usize, 1, 1, true), dim(input, input, 1, 1));
        }
    }

    #[test]
    fn same_ker_3_stride_3() {
        assert_eq!(PS::same(&3usize, 3usize, 1, 3, true), dim(3, 1, 0, 0));
    }

    #[test]
    fn valid_1() {
        assert_eq!(PS::valid(&10usize, 2usize, 1, 3), dim(10, 3, 0, 0));
    }

    #[test]
    fn explicit_2() {
        assert_eq!(PS::explicit_onnx_pool(&28usize, 3usize, 1, 1, 2, 2, true), dim(28, 30, 2, 2));
    }

    #[test]
    #[ignore = "ONNX weird output computation for explicit"]
    fn explicit_3() {
        assert_eq!(PS::explicit_onnx_pool(&2usize, 1usize, 1, 2, 0, 0, true), dim(2, 2, 0, 0));
    }

    #[test]
    fn same_upper() {
        assert_eq!(PS::same(&7usize, 1usize, 1, 2, true), dim(7, 4, 0, 0));
    }

    #[test]
    fn bug_explicit_stride() {
        // Kernel 3, stride 3, no padding over a length-12 axis:
        // windows 012 345 678 9ab -> exactly 4 outputs.
        assert_eq!(
            PS::explicit_onnx_pool(&12usize, 3usize, 1, 3, 0, 0, false),
            dim(12, 4, 0, 0)
        );
    }
}