use crate::internal::*;

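/// Padding policy for convolution, pooling and deconvolution operators.
///
/// `Explicit` and `ExplicitOnnxPool` carry per-axis before and after pads, the
/// latter also recording the ONNX pooling `ceil_mode` flag. `SameUpper` and
/// `SameLower` derive pads so that `output = ceil(input / stride)`, the extra
/// padding going after, respectively before, the data. `Valid` pads nothing.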
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum PaddingSpec {
    Explicit(TVec<usize>, TVec<usize>),
    ExplicitOnnxPool(TVec<usize>, TVec<usize>, bool),
    #[default]
    Valid,
    SameUpper,
    SameLower,
}

use PaddingSpec::*;

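/// Geometry resolved for one spatial axis: the length on the deconvoluted
/// (unpadded input) side and on the convoluted (output) side, plus the pads
/// applied on each side of the data.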
#[derive(Debug, Clone, new, PartialEq, Eq)]
pub struct ComputedPaddedDim<D: DimLike> {
    pub deconvoluted: D,
    pub convoluted: D,
    pub pad_before: D,
    pub pad_after: D,
}

impl PaddingSpec {
    pub fn valid_dim(&self, d: usize, stride_is_one: bool) -> bool {
        match self {
            Valid => true,
            Explicit(bef, aft) => bef[d] == 0 && aft[d] == 0,
            // Ceil mode with a stride greater than one can create an extra output
            // position even with zero pads, so only floor mode (or stride 1)
            // matches Valid.
            ExplicitOnnxPool(bef, aft, ceil_mode) => {
                (!*ceil_mode || stride_is_one) && bef[d] == 0 && aft[d] == 0
            }
            _ => false,
        }
    }

    pub fn change_geo_axes(&self, op: &AxisOp) -> TractResult<PaddingSpec> {
        match self {
            ExplicitOnnxPool(before, after, round) => {
                let mut before: TVec<usize> = before.clone();
                let mut after: TVec<usize> = after.clone();
                op.change_shape_array(&mut before, false)?;
                op.change_shape_array(&mut after, false)?;
                if let AxisOp::Add(add) = op {
                    before[*add] = 0;
                    after[*add] = 0;
                }
                Ok(ExplicitOnnxPool(before, after, *round))
            }
            Explicit(before, after) => {
                let mut before: TVec<usize> = before.clone();
                let mut after: TVec<usize> = after.clone();
                op.change_shape_array(&mut before, false)?;
                op.change_shape_array(&mut after, false)?;
                if let AxisOp::Add(add) = op {
                    before[*add] = 0;
                    after[*add] = 0;
                }
                Ok(Explicit(before, after))
            }
            Valid | SameLower | SameUpper => Ok(self.clone()),
        }
    }

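    /// Resolve padding and output length for every spatial axis at once.
    ///
    /// Each slice carries one entry per spatial axis. As a worked example of the
    /// formulas below: `SameUpper.compute(&[5usize], &[3], &[1], &[1])` resolves
    /// the single axis to `convoluted == 5`, `pad_before == 1`, `pad_after == 1`.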
    pub fn compute<D: DimLike>(
        &self,
        input_spatial_shape: &[D],
        kernel_spatial_shape: &[usize],
        dilations: &[usize],
        strides: &[usize],
    ) -> TVec<ComputedPaddedDim<D>> {
        (0..input_spatial_shape.len())
            .map(|d| {
                self.compute_one(
                    d,
                    &input_spatial_shape[d],
                    kernel_spatial_shape[d],
                    dilations[d],
                    strides[d],
                )
            })
            .collect()
    }

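    /// Resolve padding for every spatial axis of a deconvolution, recovering the
    /// deconvoluted length from the convolved one (see `compute_one_for_deconv`).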
    pub fn compute_for_deconv<D: DimLike>(
        &self,
        conv_spatial_shape: &[D],
        kernel_spatial_shape: &[usize],
        dilations: &[usize],
        strides: &[usize],
        adjustments: &[usize],
    ) -> TractResult<TVec<ComputedPaddedDim<D>>> {
        (0..conv_spatial_shape.len())
            .map(|d| {
                self.compute_one_for_deconv(
                    d,
                    &conv_spatial_shape[d],
                    kernel_spatial_shape[d],
                    dilations[d],
                    strides[d],
                    adjustments[d],
                )
            })
            .collect()
    }

    pub fn compute_one<D: DimLike>(
        &self,
        axis: usize,
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
    ) -> ComputedPaddedDim<D> {
        match self {
            Valid => Self::valid(input, kernel, dilation, stride),
            Explicit(bef, aft) => {
                Self::explicit(input, kernel, dilation, stride, bef[axis], aft[axis])
            }
            ExplicitOnnxPool(bef, aft, ceil_mode) => Self::explicit_onnx_pool(
                input, kernel, dilation, stride, bef[axis], aft[axis], *ceil_mode,
            ),
            SameUpper => Self::same(input, kernel, dilation, stride, true),
            SameLower => Self::same(input, kernel, dilation, stride, false),
        }
    }

    pub fn compute_one_for_deconv<D: DimLike>(
        &self,
        axis: usize,
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        match self {
            Valid => Self::valid_for_deconv(input, kernel, dilation, stride, adjustment),
            SameUpper => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, true),
            SameLower => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, false),
            // ceil_mode is irrelevant for deconvolution, so both explicit forms share one arm.
            Explicit(bef, aft) | ExplicitOnnxPool(bef, aft, _) => Self::explicit_for_deconv(
                input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
            ),
        }
    }

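    /// No padding: `output = ceil((input - dilation * (kernel - 1)) / stride)`,
    /// computed here as `ceil((input + 1 - kernel_field) / stride)` with
    /// `kernel_field = (kernel - 1) * dilation + 1`.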
    fn valid<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
    ) -> ComputedPaddedDim<D> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let output = if let Ok(int) = input.to_usize() {
            D::from((int + 1).saturating_sub(kernel_field).divceil(stride))
        } else {
            (input.clone() + 1 - kernel_field).divceil(stride)
        };
        ComputedPaddedDim::new(input.clone(), output, 0.into(), 0.into())
    }

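    /// Inverse of `valid` for deconvolution:
    /// `deconvoluted = (convoluted - 1) * stride + kernel_field + adjustment`.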
    fn valid_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let deconvoluted = (convoluted.clone() - 1) * stride + kernel_field + adjustment;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), 0.into(), 0.into()))
    }

    fn explicit<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
    ) -> ComputedPaddedDim<D> {
        if let Ok(i) = input.to_dim().to_usize() {
            let ints = Self::explicit_usize(i, kernel, dilation, stride, bef, aft);
            ComputedPaddedDim::new(
                input.clone(),
                ints.convoluted.into(),
                ints.pad_before.into(),
                ints.pad_after.into(),
            )
        } else {
            let kernel_field = (kernel - 1) * dilation + 1;
            let dividend = input.clone() + bef + aft - kernel_field;
            let output = dividend.div(stride) + 1;
            ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
        }
    }

    fn explicit_usize(
        input: usize,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
    ) -> ComputedPaddedDim<usize> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let dividend = (input + bef + aft).saturating_sub(kernel_field);
        let output = dividend / stride + 1;
        ComputedPaddedDim::new(input, output, bef, aft)
    }

    fn explicit_onnx_pool<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        ceil_mode: bool,
    ) -> ComputedPaddedDim<D> {
        if let Ok(i) = input.to_dim().to_usize() {
            let ints =
                Self::explicit_onnx_pool_usize(i, kernel, dilation, stride, bef, aft, ceil_mode);
            ComputedPaddedDim::new(
                input.clone(),
                ints.convoluted.into(),
                ints.pad_before.into(),
                ints.pad_after.into(),
            )
        } else {
            let kernel_field = (kernel - 1) * dilation + 1;
            let dividend = input.clone() + bef + aft - kernel_field;
            let output =
                if ceil_mode { dividend.divceil(stride) } else { dividend.div(stride) } + 1;
            ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
        }
    }

    fn explicit_onnx_pool_usize(
        input: usize,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        ceil_mode: bool,
    ) -> ComputedPaddedDim<usize> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let dividend = (input + bef + aft).saturating_sub(kernel_field);
        let mut output = if ceil_mode { dividend.divceil(stride) } else { dividend / stride } + 1;
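        // ONNX ceil_mode rule: the last window must start inside the input or its
        // before-padding, never entirely in the after-padding; drop the extra
        // output position when it would.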
        if ceil_mode && (output - 1) * stride >= input + bef {
            output -= 1;
        }
        ComputedPaddedDim::new(input, output, bef, aft)
    }

    fn explicit_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        bef: usize,
        aft: usize,
        adjustment: usize,
    ) -> TractResult<ComputedPaddedDim<D>> {
        let kernel_field = (kernel - 1) * dilation + 1;
        let deconvoluted =
            (convoluted.clone() - 1) * stride + kernel_field - bef - aft + adjustment;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), bef.into(), aft.into()))
    }

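    /// SAME padding: `output = ceil(input / stride)`, with total padding
    /// `max(0, (output - 1) * stride + kernel_field - input)` split in halves;
    /// `upper` puts the larger half after the data (SAME_UPPER semantics),
    /// otherwise before it.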
    fn same<D: DimLike>(
        input: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        upper: bool,
    ) -> ComputedPaddedDim<D> {
        let output = input.divceil(stride);
        let kernel_field = (kernel - 1) * dilation + 1;
        let pad = if let Ok(input) = input.to_usize() {
            let pad = (((output.clone() - 1) * stride + kernel_field).to_usize().unwrap())
                .saturating_sub(input);
            pad.into()
        } else {
            (output.clone() - 1) * stride + kernel_field - input
        };
        let lower_pad = pad.clone() / 2;
        let higher_pad = pad - &lower_pad;
        let (before, after) = if upper { (lower_pad, higher_pad) } else { (higher_pad, lower_pad) };
        ComputedPaddedDim::new(input.clone(), output, before, after)
    }

    fn same_for_deconv<D: DimLike>(
        convoluted: &D,
        kernel: usize,
        dilation: usize,
        stride: usize,
        adjustment: usize,
        upper: bool,
    ) -> TractResult<ComputedPaddedDim<D>> {
        if (kernel - 1) * dilation < stride {
            bail!("Invalid axis geometry for SAME padding: expect (kernel_len - 1) * dilation > stride - 1");
        }
        let kernel_field = (kernel - 1) * dilation + 1;
        let crop = kernel_field + adjustment - stride;
        let lower_crop = crop / 2;
        let higher_crop = crop - lower_crop;
        let (before, after) =
            if upper { (lower_crop, higher_crop) } else { (higher_crop, lower_crop) };
        let deconvoluted = (convoluted.clone() - 1) * stride + kernel_field - before - after;
        Ok(ComputedPaddedDim::new(deconvoluted, convoluted.clone(), before.into(), after.into()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use PaddingSpec as PS;

    #[test]
    fn same_stride_1() {
        assert_eq!(PS::same(&1usize, 2usize, 1, 1, true), ComputedPaddedDim::new(1, 1, 0, 1));
        assert_eq!(PS::same(&2usize, 2usize, 1, 1, true), ComputedPaddedDim::new(2, 2, 0, 1));
        assert_eq!(PS::same(&3usize, 2usize, 1, 1, true), ComputedPaddedDim::new(3, 3, 0, 1));
        assert_eq!(PS::same(&4usize, 2usize, 1, 1, true), ComputedPaddedDim::new(4, 4, 0, 1));
    }

    #[test]
    fn same_stride_2() {
        assert_eq!(PS::same(&1usize, 2usize, 1, 2, true), ComputedPaddedDim::new(1, 1, 0, 1));
        assert_eq!(PS::same(&2usize, 2usize, 1, 2, true), ComputedPaddedDim::new(2, 1, 0, 0));
        assert_eq!(PS::same(&3usize, 2usize, 1, 2, true), ComputedPaddedDim::new(3, 2, 0, 1));
        assert_eq!(PS::same(&4usize, 2usize, 1, 2, true), ComputedPaddedDim::new(4, 2, 0, 0));
    }

    #[test]
    fn same_1() {
        assert_eq!(PS::same(&6usize, 1usize, 1, 2, true), ComputedPaddedDim::new(6, 3, 0, 0));
    }

    #[test]
    fn same_lower() {
        assert_eq!(PS::same(&10usize, 2usize, 1, 3, false), ComputedPaddedDim::new(10, 4, 1, 0));
    }

    #[test]
    fn same_ker_3() {
        assert_eq!(PS::same(&1usize, 3usize, 1, 1, true), ComputedPaddedDim::new(1, 1, 1, 1));
        assert_eq!(PS::same(&2usize, 3usize, 1, 1, true), ComputedPaddedDim::new(2, 2, 1, 1));
        assert_eq!(PS::same(&3usize, 3usize, 1, 1, true), ComputedPaddedDim::new(3, 3, 1, 1));
        assert_eq!(PS::same(&4usize, 3usize, 1, 1, true), ComputedPaddedDim::new(4, 4, 1, 1));
    }

    #[test]
    fn same_ker_3_stride_3() {
        assert_eq!(PS::same(&3usize, 3usize, 1, 3, true), ComputedPaddedDim::new(3, 1, 0, 0));
    }
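
    // Added illustration (not part of the original suite): VALID with dilation,
    // i.e. output = ceil((input - dilation * (kernel - 1)) / stride).
    #[test]
    fn valid_dilation() {
        assert_eq!(PS::valid(&10usize, 3usize, 2, 1), ComputedPaddedDim::new(10, 6, 0, 0));
    }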

    #[test]
    fn valid_1() {
        assert_eq!(PS::valid(&10usize, 2usize, 1, 3), ComputedPaddedDim::new(10, 3, 0, 0));
    }

    #[test]
    fn explicit_2() {
        assert_eq!(
            PS::explicit_onnx_pool(&28usize, 3usize, 1, 1, 2, 2, true),
            ComputedPaddedDim::new(28, 30, 2, 2)
        );
    }

    #[test]
    #[ignore = "ONNX weird output computation for explicit"]
    fn explicit_3() {
        assert_eq!(
            PS::explicit_onnx_pool(&2usize, 1usize, 1, 2, 0, 0, true),
            ComputedPaddedDim::new(2, 2, 0, 0)
        );
    }

    #[test]
    fn same_upper() {
        assert_eq!(PS::same(&7usize, 1usize, 1, 2, true), ComputedPaddedDim::new(7, 4, 0, 0));
    }

    #[test]
    fn bug_explicit_stride() {
        assert_eq!(
            PS::explicit_onnx_pool(&12usize, 3usize, 1, 3, 0, 0, false),
            ComputedPaddedDim::new(12, 4, 0, 0)
        );
    }
}