tract_core/ops/cnn/
patches.rs

1use crate::internal::*;
2use crate::ops::cnn::PaddingSpec;
3use crate::ops::nn::{DataFormat, DataShape};
4use ndarray::prelude::*;
5
6use super::PatchAxis;
7
8use std::fmt::Debug;
9use std::ops::Range;
10
11use tract_itertools::{izip, Itertools};
12
/// Geometric description of a sliding window ("patch") over the spatial axes of an input.
///
/// A spec is created with [`PatchSpec::for_full_shape`] or [`PatchSpec::for_data_shape`],
/// refined with the `with_*` builders, and finally turned into a precomputed [`Patch`]
/// by [`PatchSpec::into_patch`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PatchSpec {
    /// Input size along each spatial axis (the `hw_dims()` of the data shape).
    pub input_shape: TVec<usize>,
    /// Storage stride of the innermost spatial input axis (taken from `DataShape::w_stride`).
    pub input_inner_stride: usize,
    /// Storage stride used for the innermost spatial output axis.
    pub output_inner_stride: usize,
    /// Kernel size along each spatial axis.
    pub kernel_shape: TVec<usize>,
    /// Step, in input positions, between two consecutive output positions on each axis.
    pub strides: TVec<usize>,
    /// Multiplier applied to kernel coordinates on each axis.
    pub dilations: TVec<usize>,
    /// Padding policy, evaluated through `PaddingSpec::compute` in `into_patch`.
    pub padding: PaddingSpec,
}
23
24impl Debug for PatchSpec {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(
27            f,
28            "input: {} kernel: {} strides: {} dil: {} pad: {:?}",
29            self.input_shape.iter().join(","),
30            self.kernel_shape.iter().join(","),
31            self.strides.iter().join(","),
32            self.dilations.iter().join(","),
33            self.padding
34        )
35    }
36}
37
38impl PatchSpec {
39    pub fn for_full_shape(
40        data_format: DataFormat,
41        input_full_shape: &[usize],
42    ) -> TractResult<PatchSpec> {
43        let shape = data_format.shape(input_full_shape.into())?;
44        Ok(Self::for_data_shape(shape))
45    }
46
47    pub fn for_data_shape(data_shape: DataShape) -> PatchSpec {
48        let input_shape: TVec<usize> = data_shape.hw_dims().into();
49        PatchSpec {
50            kernel_shape: tvec!(1; input_shape.len()),
51            input_inner_stride: *data_shape.w_stride(),
52            output_inner_stride: 1,
53            strides: tvec!(1; input_shape.len()),
54            dilations: tvec!(1; input_shape.len()),
55            padding: PaddingSpec::Valid,
56            input_shape,
57        }
58    }
59
60    pub fn with_kernel_shape(self, kernel_shape: TVec<usize>) -> PatchSpec {
61        PatchSpec { kernel_shape, ..self }
62    }
63
64    pub fn with_dilations(self, dilations: TVec<usize>) -> PatchSpec {
65        PatchSpec { dilations, ..self }
66    }
67
68    pub fn with_strides(self, strides: TVec<usize>) -> PatchSpec {
69        PatchSpec { strides, ..self }
70    }
71
72    pub fn with_padding(self, padding: PaddingSpec) -> PatchSpec {
73        PatchSpec { padding, ..self }
74    }
75
76    pub fn with_output_inner_stride(self, output_inner_stride: usize) -> PatchSpec {
77        PatchSpec { output_inner_stride, ..self }
78    }
79
    /// Consumes the spec and precomputes the data a [`Patch`] needs to be scanned:
    /// output shape and padding, per-tap offsets, storage strides, and the zones that
    /// partition the output according to which kernel taps are usable.
    ///
    /// # Panics
    /// Panics when a kernel axis is empty: the per-axis min/max of the data field is
    /// unwrapped, and an empty column has neither.
    pub fn into_patch(self) -> Patch {
        // Per-axis output size and padding, as decided by the padding policy.
        let dims = self.padding.compute(
            &self.input_shape,
            &self.kernel_shape,
            &self.dilations,
            &self.strides,
        );
        let output: TVec<usize> = dims.iter().map(|d| d.convoluted).collect();
        let pad_before: TVec<usize> = dims.iter().map(|d| d.pad_before).collect();
        let pad_after: TVec<usize> = dims.iter().map(|d| d.pad_after).collect();

        // For every kernel tap and every axis: tap coordinate * dilation - pad_before,
        // flattened tap after tap.
        let data_field: Vec<isize> = ::ndarray::indices(&*self.kernel_shape)
            .into_iter()
            .flat_map(|coords| {
                #[allow(clippy::unnecessary_to_owned)] // I think this one is a clippy bug.
                coords
                    .slice()
                    .to_vec()
                    .into_iter()
                    .enumerate()
                    .map(|(ix, c)| (c * self.dilations[ix]) as isize - pad_before[ix] as isize)
            })
            .collect();
        // Reshape as (number of kernel taps, rank): one row per tap, one column per axis.
        let data_field = Array2::from_shape_vec(
            (self.kernel_shape.iter().cloned().product(), self.kernel_shape.len()),
            data_field,
        )
        .unwrap();
        // Smallest and largest tap offset on each axis.
        let data_field_min_max: TVec<_> = data_field
            .columns()
            .into_iter()
            .map(|col| (col.iter().min().cloned().unwrap(), col.iter().max().cloned().unwrap()))
            .collect();

        // Strides for `shape` with the last axis using `inner` and every other axis using
        // the product of the dimensions that follow it times `inner`.
        fn strides(shape: &[usize], inner: usize) -> TVec<isize> {
            let mut strides: TVec<isize> = tvec![inner as isize];
            for dim in shape.iter().skip(1).rev() {
                let previous = *strides.last().unwrap();
                strides.push(*dim as isize * previous);
            }
            strides.reverse();
            strides
        }

        let input_storage_strides = strides(&self.input_shape, self.input_inner_stride);
        let output_storage_strides = strides(&output, self.output_inner_stride);

        // Flat storage offset of each tap: its data field row dotted with the input strides.
        let standard_layout_data_field: Vec<isize> = data_field
            .outer_iter()
            .map(|coords| izip!(coords, &input_storage_strides).map(|(a, b)| a * b).sum::<isize>())
            .collect();

        // regions[axis][range+mask]
        let regions: TVec<TVec<_>> = dims
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                PatchAxis {
                    input_dim: self.input_shape[ix],
                    kernel_dim: self.kernel_shape[ix],
                    pad_before: d.pad_before,
                    pad_after: d.pad_after,
                    output_dim: d.convoluted,
                    stride: self.strides[ix],
                    dilation: self.dilations[ix],
                }
                .regions()
            })
            .collect::<TVec<_>>();

        // Strides over the per-axis region counts, used to turn per-axis region indices
        // into an index in `zones`.
        let zone_strides = strides(&regions.iter().map(|d| d.len()).collect::<TVec<_>>(), 1);

        // One zone per combination of per-axis regions.
        let zones: Vec<Zone> = regions
            .iter()
            .multi_cartesian_product()
            .map(|regions| Zone {
                input_zone_offset: 0,
                output_ranges: regions.iter().map(|reg| reg.range.clone()).collect(),
                output_shape: regions.iter().map(|reg| reg.range.end - reg.range.start).collect(),
                output_zone_offset: izip!(&regions, &output_storage_strides)
                    .map(|(reg, &stride)| reg.range.start as isize * stride)
                    .sum::<isize>(),
                // A zone is valid when none of its regions carries a mask.
                valid: regions.iter().all(|reg| reg.mask.is_none()),
                // Keep the taps that no axis mask flags, as (tap index, flat offset).
                values_offsets: izip!(
                    0..,
                    ndarray::indices(&*self.kernel_shape),
                    &standard_layout_data_field
                )
                .filter(|(_ix, coords, _offset)| {
                    izip!(coords.slice(), &regions)
                        .all(|(&x, axis)| !axis.mask.as_ref().map(|mask| mask[x]).unwrap_or(false))
                })
                .map(|(ix, _coords, &window_offset)| (ix, window_offset))
                .collect(),
            })
            .collect();

        let valid_zone = zones.iter().position(|z| z.valid);

        // Per axis, derive an output range `min..max` from the tap offset extremes and the
        // stride, and record the output boxes lying before and after it on that axis.
        let mut valid_output_zone = tvec!();
        let mut invalid_output_zones = tvec!();
        for ix in 0..self.input_shape.len() {
            let min_max = data_field_min_max[ix];
            let min = (-min_max.0 as usize).divceil(self.strides[ix]);
            let max =
                (self.input_shape[ix].saturating_sub(min_max.1 as usize)).divceil(self.strides[ix]);
            if min != 0 {
                // Previous axes restricted to their valid range, this axis to `0..min`,
                // following axes left unrestricted.
                let mut invalid = valid_output_zone.clone();
                invalid.push(0..min);
                while invalid.len() < output.len() {
                    invalid.push(0..output[invalid.len()])
                }
                invalid_output_zones.push(invalid);
            }
            if max < output[ix] {
                // Same construction for the tail `max..output[ix]` of this axis.
                let mut invalid = valid_output_zone.clone();
                invalid.push(max..output[ix]);
                while invalid.len() < output.len() {
                    invalid.push(0..output[invalid.len()])
                }
                invalid_output_zones.push(invalid);
            }
            valid_output_zone.push(min..max)
        }

        // Input storage displacement caused by one output step on each axis.
        let op_strides_times_input_storage_strides =
            izip!(&self.strides, &input_storage_strides).map(|(a, b)| (*a as isize * b)).collect();

        Patch {
            spec: self,
            padded: pad_before.iter().any(|&p| p != 0) || pad_after.iter().any(|&p| p != 0),
            pad_before,
            pad_after,
            output_shape: output,
            data_field,
            data_field_min_max,
            standard_layout_data_field,
            input_storage_strides,
            output_storage_strides,
            op_strides_times_input_storage_strides,
            valid_output_zone,
            invalid_output_zones,
            zones,
            valid_zone_id: valid_zone,
            zone_strides,
        }
    }
227}
228
/// Precomputed form of a [`PatchSpec`], produced by [`PatchSpec::into_patch`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Patch {
    /// The spec this patch was built from.
    pub spec: PatchSpec,
    /// Per-axis padding before the input, as computed by the padding policy.
    pub pad_before: TVec<usize>,
    /// Per-axis padding after the input, as computed by the padding policy.
    pub pad_after: TVec<usize>,
    /// True when any entry of `pad_before` or `pad_after` is non-zero.
    pub padded: bool,
    /// Output size along each spatial axis.
    pub output_shape: TVec<usize>,
    /// One row per kernel tap, one column per axis: tap coordinate * dilation - pad_before.
    pub data_field: Array2<isize>,
    /// Per-axis (min, max) of the corresponding `data_field` column.
    pub data_field_min_max: TVec<(isize, isize)>,
    /// Per-tap flat offset: the `data_field` row dotted with `input_storage_strides`.
    pub standard_layout_data_field: Vec<isize>,
    /// Per-axis product of the operator stride and the input storage stride.
    pub op_strides_times_input_storage_strides: TVec<isize>,
    /// Per-axis output range derived from `data_field_min_max` and the strides; presumably
    /// the positions whose whole window lies inside the input (TODO confirm).
    pub valid_output_zone: TVec<Range<usize>>,
    /// Output boxes (one range per axis) lying before or after `valid_output_zone` on an axis.
    pub invalid_output_zones: TVec<TVec<Range<usize>>>,
    /// One zone per combination of the per-axis regions.
    pub zones: Vec<Zone>,
    /// Index in `zones` of the first zone flagged `valid`, if any.
    pub valid_zone_id: Option<usize>,
    /// Strides over the per-axis region counts, mapping per-axis region indices to a zone index.
    pub zone_strides: TVec<isize>,
    /// Storage strides of the spatial input axes.
    pub input_storage_strides: TVec<isize>,
    /// Storage strides of the spatial output axes.
    pub output_storage_strides: TVec<isize>,
}
248
249impl Debug for Patch {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        write!(f, "{:?}", self.spec)
252    }
253}
254
255impl Patch {
    /// Number of spatial axes, i.e. the length of the spec's `input_shape`.
    #[inline]
    pub fn rank(&self) -> usize {
        self.spec.input_shape.len()
    }
260
261    unsafe fn is_valid(&self, coords: &[usize]) -> bool {
262        for ix in 0..self.rank() {
263            let c = *coords.get_unchecked(ix) as isize;
264            let strides = *self.spec.strides.get_unchecked(ix) as isize;
265            let pos = c * strides;
266            let min_max = self.data_field_min_max.get_unchecked(ix);
267            if pos + min_max.0 < 0
268                || pos + min_max.1 >= *self.spec.input_shape.get_unchecked(ix) as isize
269            {
270                return false;
271            }
272        }
273        true
274    }
275
276    pub fn valid_zone(&self) -> Option<&Zone> {
277        self.valid_zone_id.map(|id| &self.zones[id])
278    }
279
280    #[inline]
281    pub fn visit_output(&self, mut acceptor: impl FnMut(&Scanner)) {
282        if self.zones.len() == 0 {
283            return;
284        }
285        let mut scanner = Scanner::new(self);
286        while !scanner.done() {
287            acceptor(&scanner);
288            scanner.next();
289        }
290    }
291
292    pub fn centers_offsets(&self) -> Vec<isize> {
293        if self.zones.len() == 0 {
294            return vec![];
295        }
296        let mut scanner = Scanner::new(self);
297        let len = self.output_shape.iter().cloned().product();
298        let mut v = Vec::with_capacity(len);
299        for _ in 0..len {
300            v.push(scanner.input_center_offset);
301            scanner.next()
302        }
303        v
304    }
305
    /// Iterates over the kernel taps of the patch at output position `coords`, without any
    /// validity hint: shorthand for `at_hint(coords, None)`.
    pub fn at<'p>(&'p self, coords: &[usize]) -> PatchIterator<'p> {
        self.at_hint(coords, None)
    }
309
310    pub fn at_hint<'p>(&'p self, coords: &[usize], hint: Option<bool>) -> PatchIterator<'p> {
311        unsafe {
312            assert_eq!(coords.len(), self.spec.kernel_shape.len());
313            let mut center = 0;
314            for i in 0..self.op_strides_times_input_storage_strides.len() {
315                center += *self.op_strides_times_input_storage_strides.get_unchecked(i)
316                    * *coords.get_unchecked(i) as isize;
317            }
318            let valid = hint.unwrap_or_else(|| !self.padded || self.is_valid(coords));
319            if valid {
320                PatchIterator::Fast(FastPatchIterator { patch: self, center, item: 0 })
321            } else {
322                let mut input_patch_center: TVec<_> = coords.into();
323                input_patch_center
324                    .iter_mut()
325                    .zip(self.spec.strides.iter())
326                    .for_each(|(a, &b)| *a *= b);
327                PatchIterator::Safe(SafePatchIterator {
328                    patch: self,
329                    item: 0,
330                    input_patch_center,
331                    center,
332                })
333            }
334        }
335    }
336
337    pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
338        assert_eq!(coords.len(), self.spec.kernel_shape.len());
339        let center = izip!(coords, &self.op_strides_times_input_storage_strides)
340            .map(|(a, b)| *a as isize * *b)
341            .sum::<isize>();
342        (center + self.standard_layout_data_field[patch_index]) as usize
343    }
344}
345
/// A box of output positions that all use the same set of kernel taps.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zone {
    /// True when none of the per-axis regions making up this zone carries a mask.
    pub valid: bool,
    /// Always set to 0 by `PatchSpec::into_patch`.
    pub input_zone_offset: isize,
    /// Output storage offset of the zone's first position.
    pub output_zone_offset: isize,
    /// Range of output coordinates covered on each axis.
    pub output_ranges: Box<[Range<usize>]>,
    /// Length of each range in `output_ranges`.
    pub output_shape: Box<[usize]>,
    /// (index in kernel, offset from center in image)
    pub values_offsets: Box<[(usize, isize)]>,
}
356
357impl Zone {
358    pub fn contains_output(&self, coords: &[usize]) -> bool {
359        self.output_ranges.iter().zip(coords).all(|(range, &x)| x >= range.start && x < range.end)
360    }
361
362    #[inline]
363    pub fn visit_output(&self, patch: &Patch, mut acceptor: impl FnMut(&ZoneScanner)) {
364        let mut scanner = ZoneScanner::new(self, patch);
365        while !scanner.done() {
366            acceptor(&scanner);
367            scanner.next();
368        }
369    }
370}
371
/// Cursor over the output positions of a single [`Zone`].
///
/// The axis with the largest extent in the zone is used as the inner loop, so that moving
/// along it is a plain addition on both offsets.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZoneScanner<'p> {
    /// Patch the zone belongs to.
    pub patch: &'p Patch,
    /// Zone being scanned.
    pub zone: &'p Zone,
    /// Output storage offset of the current position.
    pub output_offset: isize,
    /// Current output coordinates.
    pub output_coords: Box<[usize]>,
    /// Input storage offset of the current patch origin.
    pub input_center_offset: isize,
    /// Axis scanned by the inner loop.
    pub inner_loop_axis: usize,
    /// Number of positions along the inner loop axis.
    pub inner_loop_len: usize,
    /// Output range of the zone along the inner loop axis.
    pub inner_loop_output_range: Range<usize>,
    /// Output storage stride along the inner loop axis.
    pub inner_loop_output_stride: isize,
    /// Input storage displacement for one output step along the inner loop axis.
    pub inner_loop_input_full_stride: isize,
    /// Set once every position of the zone has been visited.
    pub done: bool,
}
386
387impl<'p> ZoneScanner<'p> {
388    pub fn new(zone: &'p Zone, patch: &'p Patch) -> ZoneScanner<'p> {
389        let inner_loop_axis =
390            zone.output_shape.iter().enumerate().max_by_key(|(_, dim)| *dim).unwrap().0;
391        let inner_loop_output_range = zone.output_ranges[inner_loop_axis].clone();
392        let inner_loop_output_stride = patch.output_storage_strides[inner_loop_axis];
393        let inner_loop_input_full_stride =
394            patch.op_strides_times_input_storage_strides[inner_loop_axis];
395        let mut scan = ZoneScanner {
396            patch,
397            zone,
398            output_offset: 0,
399            input_center_offset: 0,
400            inner_loop_axis,
401            inner_loop_len: inner_loop_output_range.len(),
402            inner_loop_output_range,
403            inner_loop_output_stride,
404            inner_loop_input_full_stride,
405            output_coords: zone.output_ranges.iter().map(|r| r.start).collect(),
406            done: false,
407        };
408        scan.refresh_dependent();
409        scan
410    }
411
412    #[inline]
413    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
414        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
415    }
416
    /// Moves to the next line of the zone: advances the coordinates of every axis except
    /// the inner loop one, last axis first, carrying into the previous axis on overflow.
    /// Marks the scanner as done when all of them have overflowed.
    ///
    /// # Safety
    /// Accesses are unchecked: `output_coords` and `zone.output_ranges` must each hold at
    /// least `patch.rank()` elements.
    pub unsafe fn next_non_inner_axis(&mut self) {
        let rank = self.patch.rank();
        let inner_loop_axis = self.inner_loop_axis;
        for axis in (0..rank).rev() {
            if axis == inner_loop_axis {
                continue;
            }
            *self.output_coords.get_unchecked_mut(axis) += 1;
            if *self.output_coords.get_unchecked_mut(axis)
                < self.zone.output_ranges.get_unchecked(axis).end
            {
                // Still inside the zone on this axis: recompute the offsets and stop.
                self.refresh_dependent();
                return;
            }
            // Overflow: rewind this axis and carry into the next one.
            *self.output_coords.get_unchecked_mut(axis) =
                self.zone.output_ranges.get_unchecked(axis).start;
        }
        self.done = true;
    }
436
437    pub unsafe fn reset(&mut self) {
438        self.output_offset = 0;
439        self.input_center_offset = 0;
440        for ix in 0..self.output_coords.len() {
441            *self.output_coords.get_unchecked_mut(ix) =
442                self.zone.output_ranges.get_unchecked(ix).start;
443        }
444        self.done = false;
445        self.refresh_dependent()
446    }
447
448    #[inline(never)]
449    fn refresh_dependent(&mut self) {
450        self.input_center_offset = self
451            .patch
452            .op_strides_times_input_storage_strides
453            .iter()
454            .zip(self.output_coords.iter())
455            .map(|(a, b)| *a * *b as isize)
456            .sum();
457        self.output_offset = self
458            .patch
459            .output_storage_strides
460            .iter()
461            .zip(self.output_coords.iter())
462            .map(|(a, b)| a * *b as isize)
463            .sum();
464    }
465
466    #[inline]
467    pub fn next(&mut self) {
468        let inner_loop_axis = self.inner_loop_axis;
469        unsafe {
470            *self.output_coords.get_unchecked_mut(inner_loop_axis) += 1;
471            if *self.output_coords.get_unchecked(inner_loop_axis) < self.inner_loop_output_range.end
472            {
473                self.input_center_offset += self.inner_loop_input_full_stride;
474                self.output_offset += self.inner_loop_output_stride;
475            } else {
476                *self.output_coords.get_unchecked_mut(inner_loop_axis) =
477                    self.inner_loop_output_range.start;
478                self.next_non_inner_axis();
479            }
480        }
481    }
482
    /// True once every output position of the zone has been visited.
    pub fn done(&self) -> bool {
        self.done
    }
486}
487
/// Cursor over all output positions of a [`Patch`], keeping track of the zone each
/// position belongs to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scanner<'p> {
    /// Patch being scanned.
    pub patch: &'p Patch,
    /// Index of the current zone in `patch.zones`.
    pub zone_id: usize,
    /// Per-axis region index of the current zone.
    pub zone_coords: TVec<usize>,
    /// Current zone.
    pub zone: &'p Zone,
    /// Output storage offset of the current position.
    pub output_offset: isize,
    /// Current output coordinates.
    pub output_coords: TVec<usize>,
    /// Output coordinates scaled by the operator strides.
    pub input_coords: TVec<usize>,
    /// Input storage offset of the current patch origin.
    pub input_center_offset: isize,
    // Set once the whole output has been visited.
    done: bool,
}
500
501impl<'p> Scanner<'p> {
502    fn new(patch: &'p Patch) -> Scanner<'p> {
503        let rank = patch.rank();
504        let zone = &patch.zones[0];
505        Scanner {
506            patch,
507            zone_coords: tvec!(0; rank),
508            zone,
509            zone_id: 0,
510            output_offset: 0,
511            input_center_offset: 0,
512            input_coords: tvec!(0; rank),
513            output_coords: tvec!(0; rank),
514            done: false,
515        }
516    }
517
    /// Number of kernel taps usable in the current zone.
    #[inline]
    pub fn valid_count(&self) -> usize {
        self.zone.values_offsets.len()
    }
522
523    #[inline]
524    pub fn valid_offsets(&self) -> impl Iterator<Item = isize> + '_ {
525        self.zone.values_offsets.iter().map(move |pair| pair.1 + self.input_center_offset)
526    }
527
528    #[inline]
529    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
530        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
531    }
532
    /// Advances to the next output position, innermost axis first, updating the
    /// coordinates, the offsets and the current zone. Sets `done` once the first axis has
    /// run past the output shape.
    ///
    /// All accesses are unchecked: this relies on every per-axis vector of the scanner and
    /// of the patch holding `rank` entries, and on `rank` being at least 1.
    #[inline]
    pub fn next(&mut self) {
        let rank = self.patch.rank();
        let inner_dim = rank - 1;
        unsafe {
            // One step along the innermost axis.
            *self.output_coords.get_unchecked_mut(inner_dim) += 1;
            *self.input_coords.get_unchecked_mut(inner_dim) +=
                *self.patch.spec.strides.get_unchecked(inner_dim);
            self.output_offset += self.patch.spec.output_inner_stride as isize;
            self.input_center_offset +=
                self.patch.op_strides_times_input_storage_strides.get_unchecked(inner_dim);
            // Still inside the current zone: nothing else to update.
            if *self.output_coords.get_unchecked(inner_dim)
                < self.zone.output_ranges.get_unchecked(inner_dim).end
            {
                return;
            }
            if self.output_coords.get_unchecked(inner_dim)
                < self.patch.output_shape.get_unchecked(inner_dim)
            {
                // Left the zone but not the line: move to the following zone, which
                // assumes zones adjacent on the innermost axis are stored consecutively.
                self.zone_id += 1;
                *self.zone_coords.get_unchecked_mut(inner_dim) += 1;
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            } else {
                // End of line: rewind the overflowed axis and carry into the previous one,
                // repeating while the carry itself overflows.
                for axis in (0..rank - 1).rev() {
                    *self.output_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.input_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.output_coords.get_unchecked_mut(axis) += 1;
                    *self.input_coords.get_unchecked_mut(axis) +=
                        self.patch.spec.strides.get_unchecked(axis);
                    *self.zone_coords.get_unchecked_mut(axis + 1) = 0;
                    // Crossing the end of the current zone on this axis moves to the next
                    // region of that axis.
                    if *self.output_coords.get_unchecked(axis)
                        == self.zone.output_ranges.get_unchecked(axis).end
                    {
                        *self.zone_coords.get_unchecked_mut(axis) += 1;
                    }
                    if *self.output_coords.get_unchecked(axis)
                        < *self.patch.output_shape.get_unchecked(axis)
                    {
                        break;
                    }
                }
                // The first axis ran past the output: the scan is over.
                if self.output_coords.get_unchecked(0) == self.patch.output_shape.get_unchecked(0) {
                    self.done = true;
                    return;
                }
                // Recompute the zone index and the input offset from scratch; the output
                // offset keeps its running value.
                self.zone_id = 0;
                self.input_center_offset = 0;
                for i in 0..rank {
                    self.zone_id += *self.zone_coords.get_unchecked(i)
                        * *self.patch.zone_strides.get_unchecked(i) as usize;
                    self.input_center_offset += *self.input_coords.get_unchecked(i) as isize
                        * *self.patch.input_storage_strides.get_unchecked(i);
                }
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            }
        }
    }
590
    /// True once the scanner has moved past the last output position.
    pub fn done(&self) -> bool {
        self.done
    }
594}
595
/// Iterator over the kernel taps of one patch, as returned by `Patch::at` and
/// `Patch::at_hint`. Each item is the input storage offset of a tap, or `None` for a tap
/// falling outside the input.
#[derive(Debug)]
pub enum PatchIterator<'p> {
    /// No per-tap bounds check: used when the patch is considered valid.
    Fast(FastPatchIterator<'p>),
    /// Bounds-checks every tap against the input shape.
    Safe(SafePatchIterator<'p>),
}
601
602impl Iterator for PatchIterator<'_> {
603    type Item = Option<isize>;
604    #[inline(always)]
605    fn next(&mut self) -> Option<Option<isize>> {
606        match self {
607            PatchIterator::Fast(ref mut it) => it.next(),
608            PatchIterator::Safe(ref mut it) => it.next(),
609        }
610    }
611}
612
/// Tap iterator that performs no bounds check against the input.
#[derive(Debug)]
pub struct FastPatchIterator<'p> {
    // Patch whose taps are enumerated.
    patch: &'p Patch,
    // Input storage offset of the patch origin.
    center: isize,
    // Index of the next tap to yield.
    item: usize,
}
619
620impl Iterator for FastPatchIterator<'_> {
621    type Item = Option<isize>;
622    #[inline(always)]
623    fn next(&mut self) -> Option<Option<isize>> {
624        if self.item == self.patch.standard_layout_data_field.len() {
625            return None;
626        }
627        unsafe {
628            let position =
629                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
630            self.item += 1;
631            Some(Some(position))
632        }
633    }
634}
635
/// Tap iterator that checks every tap against the input bounds.
#[derive(Debug)]
pub struct SafePatchIterator<'p> {
    // Patch whose taps are enumerated.
    patch: &'p Patch,
    // Index of the next tap to yield.
    item: usize,
    // Input coordinates of the patch origin (output coordinates scaled by the strides).
    input_patch_center: TVec<usize>,
    // Input storage offset of the patch origin.
    center: isize,
}
643
644impl Iterator for SafePatchIterator<'_> {
645    type Item = Option<isize>;
646    fn next(&mut self) -> Option<Option<isize>> {
647        unsafe {
648            if self.item == self.patch.standard_layout_data_field.len() {
649                return None;
650            }
651            let input_shape = &self.patch.spec.input_shape;
652            let img_offset = self.patch.data_field.as_ptr().add(self.item * input_shape.len());
653
654            for ix in 0..input_shape.len() {
655                let pos = *self.input_patch_center.get_unchecked(ix) as isize + *img_offset.add(ix);
656                if pos < 0 || pos as usize >= *input_shape.get_unchecked(ix) {
657                    self.item += 1;
658                    return Some(None);
659                }
660            }
661            let position =
662                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
663            self.item += 1;
664            Some(Some(position))
665        }
666    }
667}
668
669#[cfg(test)]
670pub mod test {
671    use super::*;
672    use crate::ops::nn::DataFormat::*;
673    use proptest::prelude::*;
674    use proptest::*;
675
676    fn compute_output_spatial_dim(
677        input: usize,
678        dilation: usize,
679        kdim: usize,
680        pad_before: usize,
681        bad_after: usize,
682        stride: usize,
683    ) -> usize {
684        let patch = PatchSpec::for_full_shape(NCHW, &[1, 1, input])
685            .unwrap()
686            .with_dilations(tvec!(dilation))
687            .with_kernel_shape(tvec!(kdim))
688            .with_padding(PaddingSpec::ExplicitOnnxPool(tvec![pad_before], tvec![bad_after], true))
689            .with_strides(tvec![stride])
690            .into_patch();
691        patch.output_shape[0]
692    }
693
694    #[test]
695    fn basic() {
696        assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
697    }
698
699    #[test]
700    fn strides() {
701        assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
702    }
703
704    #[test]
705    fn padding() {
706        assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
707    }
708
709    #[test]
710    fn strides_and_padding() {
711        assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
712    }
713
714    fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
715        let patch =
716            PatchSpec::for_data_shape(NCHW.from_n_c_hw(1, 1, tvec![10; kdim.len()]).unwrap())
717                .with_dilations(dilations.into())
718                .with_kernel_shape(kdim.into())
719                .with_strides(tvec![1; kdim.len()])
720                .into_patch();
721        patch.data_field
722    }
723
724    #[test]
725    fn test_field() {
726        assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
727        assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
728        assert_eq!(field(&[2, 2], &[1, 1]), arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]]));
729        assert_eq!(field(&[2, 2], &[2, 1]), arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]]));
730    }
731
732    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
733        let len = shape.iter().product::<usize>();
734        let shape = shape.to_vec();
735        proptest::collection::vec(any::<i8>().prop_map(|i| i as f32), len..=len)
736            .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor())
737            .boxed()
738    }
739
    /// A randomly generated patch together with an input tensor to run it on.
    #[derive(Debug)]
    struct Problem {
        // Patch under test.
        patch: Patch,
        // Input tensor, laid out according to `data_format`.
        input: Tensor,
        // Layout (NCHW or NHWC) of `input`.
        data_format: DataFormat,
    }
746
    /// Generates 2D problems: a data format, dilations, a channel count, a kernel shape, a
    /// padding policy and strides are drawn first, then an input large enough for the
    /// dilated kernel, and finally the tensor contents.
    impl Arbitrary for Problem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Problem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // Tuple layout: (format, dilations, channels, kernel shape, padding, strides).
            (
                prop_oneof!(Just(NCHW), Just(NHWC)),
                (1usize..3, 1usize..3),
                1usize..3,
                (1usize..3, 1usize..3),
                prop_oneof![
                    Just(PaddingSpec::Valid),
                    Just(PaddingSpec::SameLower),
                    Just(PaddingSpec::SameUpper)
                ],
                (1usize..4, 1usize..4),
            )
                .prop_flat_map(|p| {
                    let dil = p.1;
                    let ks = p.3;
                    let strides = p.5;
                    // Smallest input that holds the dilated kernel on each axis.
                    let min_size: (usize, usize) = (1 + (ks.0 - 1) * dil.0, 1 + (ks.1 - 1) * dil.1);
                    (
                        Just(p),
                        (min_size.0..min_size.0 + strides.0 * 3),
                        (min_size.1..min_size.1 + strides.1 * 3),
                    )
                })
                .prop_flat_map(|(p, h, w)| {
                    // Batch of 1, `p.2` channels, `h` x `w` spatial dimensions.
                    let input_shape = p.0.from_n_c_hw(1, p.2, [h, w]).unwrap();
                    let input = tensor(&input_shape.shape);
                    (Just(p), input)
                })
                .prop_map(|((fmt, dil, c, ks, pad, strides), input)| {
                    // With channels last, consecutive spatial positions are `c` elements apart.
                    let output_inner_stride = if fmt.c_is_last() { c } else { 1 };
                    Problem {
                        patch: PatchSpec::for_full_shape(fmt, input.shape())
                            .unwrap()
                            .with_dilations(tvec!(dil.0, dil.1))
                            .with_kernel_shape(tvec!(ks.0, ks.1))
                            .with_padding(pad)
                            .with_strides(tvec![strides.0, strides.1])
                            .with_output_inner_stride(output_inner_stride)
                            .into_patch(),
                        input,
                        data_format: fmt,
                    }
                })
                .boxed()
        }
    }
797
798    impl Problem {
799        fn input_shape(&self) -> DataShape {
800            self.data_format.shape(self.input.shape().into()).unwrap()
801        }
802
803        fn output_shape(&self) -> DataShape {
804            self.data_format
805                .from_n_c_hw(
806                    self.input_shape().n().cloned().unwrap_or(1),
807                    *self.input_shape().c(),
808                    &*self.patch.output_shape,
809                )
810                .unwrap()
811        }
812
        /// Reference sum pooling, computed directly from the coordinates: every output
        /// element is the sum of the input elements under the kernel, taps falling outside
        /// the input being skipped. Only batch index 0 is processed.
        fn reference_sumpool(&self) -> Tensor {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
                for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
                    // Input coordinate of this tap: out * stride + ker * dilation - pad_before.
                    let geo_in: TVec<isize> = izip!(
                        geo_out.slice(),
                        geo_ker.slice(),
                        &self.patch.spec.strides,
                        &self.patch.spec.dilations,
                        &self.patch.pad_before
                    )
                    .map(|(o, k, s, d, p)| (o * s + k * d) as isize - *p as isize)
                    .collect();
                    // Taps reading outside the input contribute nothing.
                    if izip!(&geo_in, input_shape.hw_dims())
                        .any(|(g, i)| *g >= *i as isize || *g < 0)
                    {
                        continue;
                    }
                    let geo_in: TVec<usize> = geo_in.into_iter().map(|x| x as usize).collect();
                    for c in 0..*output_shape.c() {
                        let ocoords = self.data_format.from_n_c_hw(0, c, geo_out.slice()).unwrap();
                        let icoords = self.data_format.from_n_c_hw(0, c, &geo_in).unwrap();
                        output.to_array_view_mut::<f32>().unwrap()[&*ocoords.shape] +=
                            self.input.to_array_view::<f32>().unwrap()[&*icoords.shape];
                    }
                }
            }
            output
        }
844
845        fn check_visitor(&self) {
846            let input_shape = self.input_shape();
847            let output_shape = self.output_shape();
848            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
849            self.patch.visit_output(|visitor| {
850                for (_k, offset_in) in visitor.valid_offsets_ker_in() {
851                    for c in 0..*output_shape.c() {
852                        output.as_slice_mut::<f32>().unwrap()
853                            [visitor.output_offset as usize + c * output_shape.c_stride()] +=
854                            self.input.as_slice::<f32>().unwrap()
855                                [offset_in as usize + c * input_shape.c_stride()];
856                    }
857                }
858            });
859            assert_eq!(output, self.reference_sumpool());
860        }
861
862        fn check_zone_visitor(&self) {
863            let input_shape = self.input_shape();
864            let output_shape = self.output_shape();
865            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
866            for zone in &self.patch.zones {
867                zone.visit_output(&self.patch, |visitor| {
868                    for (_k, offset_in) in visitor.valid_offsets_ker_in() {
869                        for c in 0..*output_shape.c() {
870                            output.as_slice_mut::<f32>().unwrap()
871                                [visitor.output_offset as usize + c * output_shape.c_stride()] +=
872                                self.input.as_slice::<f32>().unwrap()
873                                    [offset_in as usize + c * input_shape.c_stride()];
874                        }
875                    }
876                });
877            }
878            assert_eq!(output, self.reference_sumpool());
879        }
880
881        fn check_zoning(&self) {
882            fn in_zone(full_coords: &[usize], h_axis: usize, zone: &[Range<usize>]) -> bool {
883                for a in 0..zone.len() {
884                    if full_coords[h_axis + a] < zone[a].start
885                        || full_coords[h_axis + a] >= zone[a].end
886                    {
887                        return false;
888                    }
889                }
890                true
891            }
892
893            let valid_zone = &self.patch.valid_output_zone;
894            let invalid_zones = &self.patch.invalid_output_zones;
895            let output_full_shape = self.output_shape();
896            let h_axis = self.input_shape().h_axis();
897            for coords in ndarray::indices(&*output_full_shape.shape) {
898                let inside_valid = in_zone(coords.slice(), h_axis, valid_zone);
899                let invalid_count =
900                    invalid_zones.iter().filter(|z| in_zone(coords.slice(), h_axis, z)).count();
901                unsafe {
902                    assert_eq!(
903                        inside_valid,
904                        self.patch.is_valid(&coords.slice()[self.input_shape().hw_axes()]),
905                        "coords {:?}, valid_zone: {:?} inside_valid: {:?}",
906                        coords.slice(),
907                        valid_zone,
908                        inside_valid
909                    );
910                }
911                if inside_valid {
912                    assert_eq!(invalid_count, 0);
913                } else {
914                    assert_eq!(
915                        invalid_count,
916                        1,
917                        "coords {:?}, valid_zone: {:?} inside_valid: {:?} invalid_zones: {:?}",
918                        coords.slice(),
919                        valid_zone,
920                        inside_valid,
921                        invalid_zones
922                    );
923                }
924            }
925        }
926    }
927
928    proptest! {
929        #[test]
930        fn test_visitor(pb in any::<Problem>()) {
931            pb.check_visitor();
932        }
933
934        #[test]
935        fn test_zone_visitor(pb in any::<Problem>()) {
936            pb.check_zone_visitor();
937        }
938
939        #[test]
940        fn test_zoning(pb in any::<Problem>()) {
941            pb.check_zoning();
942        }
943    }
944
945    #[test]
946    fn test_visitor_1() {
947        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
948        let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
949        let patch = PatchSpec::for_data_shape(input_shape.clone())
950            .with_kernel_shape(tvec![2, 1])
951            .with_padding(PaddingSpec::SameLower)
952            .with_strides(tvec![1, 2])
953            .into_patch();
954        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
955    }
956
957    #[test]
958    fn test_visitor_2() {
959        let input_shape = NCHW.from_n_c_hw(1, 2, [1, 1]).unwrap();
960        let input = tensor4(&[[[[0.]], [[1f32]]]]);
961        assert_eq!(input.shape(), &*input_shape.shape);
962        let patch =
963            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
964        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
965    }
966
967    #[test]
968    fn test_visitor_3() {
969        let input_shape = NHWC.from_n_c_hw(1, 2, [2, 1]).unwrap();
970        let input = tensor4(&[[[[0., 0.]], [[1., 0f32]]]]);
971        assert_eq!(input.shape(), &*input_shape.shape);
972        let patch =
973            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
974        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
975    }
976
977    #[test]
978    fn test_visitor_4() {
979        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
980        let input = tensor4(&[[[[0., 1f32]]]]);
981        assert_eq!(input.shape(), &*input_shape.shape);
982        let patch = PatchSpec::for_data_shape(input_shape.clone())
983            .with_kernel_shape(tvec!(1, 2))
984            .with_output_inner_stride(1)
985            .with_padding(PaddingSpec::SameLower)
986            .into_patch();
987        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
988    }
989
990    #[test]
991    fn test_zone_visitor_1() {
992        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 1]).unwrap();
993        let input = tensor4(&[[[[0.], [1f32]]]]);
994        assert_eq!(input.shape(), &*input_shape.shape);
995        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
996        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
997    }
998
999    #[test]
1000    fn test_zone_visitor_2() {
1001        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
1002        let input = tensor4(&[[[[0., 1f32]]]]);
1003        assert_eq!(input.shape(), &*input_shape.shape);
1004        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
1005        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
1006    }
1007}