1use crate::internal::*;
2use crate::ops::cnn::PaddingSpec;
3use crate::ops::nn::{DataFormat, DataShape};
4use ndarray::prelude::*;
5
6use super::PatchAxis;
7
8use std::fmt::Debug;
9use std::ops::Range;
10
11use tract_itertools::{izip, Itertools};
12
/// Geometric description of a sliding-window (convolution/pooling-like)
/// operation over the spatial (hw) dims of an input. Configure with the
/// `with_*` builder methods, then resolve into a `Patch` with `into_patch`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PatchSpec {
    // spatial (hw) dimensions only — N and C axes are excluded
    pub input_shape: TVec<usize>,
    // storage stride of the innermost (w) input spatial axis
    pub input_inner_stride: usize,
    // storage stride of the innermost output spatial axis
    pub output_inner_stride: usize,
    pub kernel_shape: TVec<usize>,
    pub strides: TVec<usize>,
    pub dilations: TVec<usize>,
    pub padding: PaddingSpec,
}
23
24impl Debug for PatchSpec {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(
27 f,
28 "input: {} kernel: {} strides: {} dil: {} pad: {:?}",
29 self.input_shape.iter().join(","),
30 self.kernel_shape.iter().join(","),
31 self.strides.iter().join(","),
32 self.dilations.iter().join(","),
33 self.padding
34 )
35 }
36}
37
impl PatchSpec {
    /// Builds a default spec (1-sized kernel, unit strides and dilations,
    /// valid padding) from a full input shape in the given data format.
    pub fn for_full_shape(
        data_format: DataFormat,
        input_full_shape: &[usize],
    ) -> TractResult<PatchSpec> {
        let shape = data_format.shape(input_full_shape.into())?;
        Ok(Self::for_data_shape(shape))
    }

    /// Builds a default spec from an already-resolved `DataShape`.
    pub fn for_data_shape(data_shape: DataShape) -> PatchSpec {
        // only the spatial dims take part in the patch geometry
        let input_shape: TVec<usize> = data_shape.hw_dims().into();
        PatchSpec {
            kernel_shape: tvec!(1; input_shape.len()),
            input_inner_stride: *data_shape.w_stride(),
            output_inner_stride: 1,
            strides: tvec!(1; input_shape.len()),
            dilations: tvec!(1; input_shape.len()),
            padding: PaddingSpec::Valid,
            input_shape,
        }
    }

    // Builder-style setters: each consumes self and overrides one field.
    pub fn with_kernel_shape(self, kernel_shape: TVec<usize>) -> PatchSpec {
        PatchSpec { kernel_shape, ..self }
    }

    pub fn with_dilations(self, dilations: TVec<usize>) -> PatchSpec {
        PatchSpec { dilations, ..self }
    }

    pub fn with_strides(self, strides: TVec<usize>) -> PatchSpec {
        PatchSpec { strides, ..self }
    }

    pub fn with_padding(self, padding: PaddingSpec) -> PatchSpec {
        PatchSpec { padding, ..self }
    }

    pub fn with_output_inner_stride(self, output_inner_stride: usize) -> PatchSpec {
        PatchSpec { output_inner_stride, ..self }
    }

    /// Resolves the spec into a precomputed `Patch`: output shape and padding,
    /// the kernel "data field" (input offset of every kernel tap), per-axis
    /// regions and their cartesian-product zones (areas of the output plane
    /// over which the set of valid taps is constant), plus the valid/invalid
    /// output hyper-rectangles.
    pub fn into_patch(self) -> Patch {
        // Let the padding spec compute output dim and pad_before/pad_after per axis.
        let dims = self.padding.compute(
            &self.input_shape,
            &self.kernel_shape,
            &self.dilations,
            &self.strides,
        );
        let output: TVec<usize> = dims.iter().map(|d| d.convoluted).collect();
        let pad_before: TVec<usize> = dims.iter().map(|d| d.pad_before).collect();
        let pad_after: TVec<usize> = dims.iter().map(|d| d.pad_after).collect();

        // data_field: one row per kernel tap, its (possibly negative) offset
        // in input coordinates relative to the patch center.
        let data_field: Vec<isize> = ::ndarray::indices(&*self.kernel_shape)
            .into_iter()
            .flat_map(|coords| {
                #[allow(clippy::unnecessary_to_owned)] coords
                .slice()
                .to_vec()
                .into_iter()
                .enumerate()
                .map(|(ix, c)| (c * self.dilations[ix]) as isize - pad_before[ix] as isize)
            })
            .collect();
        let data_field = Array2::from_shape_vec(
            (self.kernel_shape.iter().cloned().product(), self.kernel_shape.len()),
            data_field,
        )
        .unwrap();
        // per-axis (min, max) tap offsets, used for cheap validity tests
        let data_field_min_max: TVec<_> = data_field
            .columns()
            .into_iter()
            .map(|col| (col.iter().min().cloned().unwrap(), col.iter().max().cloned().unwrap()))
            .collect();

        // Row-major strides for `shape`, with a custom innermost stride.
        fn strides(shape: &[usize], inner: usize) -> TVec<isize> {
            let mut strides: TVec<isize> = tvec![inner as isize];
            for dim in shape.iter().skip(1).rev() {
                let previous = *strides.last().unwrap();
                strides.push(*dim as isize * previous);
            }
            strides.reverse();
            strides
        }

        let input_storage_strides = strides(&self.input_shape, self.input_inner_stride);
        let output_storage_strides = strides(&output, self.output_inner_stride);

        // data field rows linearized against the input storage strides
        let standard_layout_data_field: Vec<isize> = data_field
            .outer_iter()
            .map(|coords| izip!(coords, &input_storage_strides).map(|(a, b)| a * b).sum::<isize>())
            .collect();

        // per-axis regions: contiguous output ranges along one axis over
        // which the axis' kernel-tap mask is constant
        let regions: TVec<TVec<_>> = dims
            .iter()
            .enumerate()
            .map(|(ix, d)| {
                PatchAxis {
                    input_dim: self.input_shape[ix],
                    kernel_dim: self.kernel_shape[ix],
                    pad_before: d.pad_before,
                    pad_after: d.pad_after,
                    output_dim: d.convoluted,
                    stride: self.strides[ix],
                    dilation: self.dilations[ix],
                }
                .regions()
            })
            .collect::<TVec<_>>();

        let zone_strides = strides(&regions.iter().map(|d| d.len()).collect::<TVec<_>>(), 1);

        // zones: cartesian product of the per-axis regions; each zone keeps
        // the (tap index, linearized offset) pairs valid throughout it
        let zones: Vec<Zone> = regions
            .iter()
            .multi_cartesian_product()
            .map(|regions| Zone {
                input_zone_offset: 0,
                output_ranges: regions.iter().map(|reg| reg.range.clone()).collect(),
                output_shape: regions.iter().map(|reg| reg.range.end - reg.range.start).collect(),
                output_zone_offset: izip!(&regions, &output_storage_strides)
                    .map(|(reg, &stride)| reg.range.start as isize * stride)
                    .sum::<isize>(),
                valid: regions.iter().all(|reg| reg.mask.is_none()),
                // keep only the taps not masked out on any axis
                values_offsets: izip!(
                    0..,
                    ndarray::indices(&*self.kernel_shape),
                    &standard_layout_data_field
                )
                .filter(|(_ix, coords, _offset)| {
                    izip!(coords.slice(), &regions)
                        .all(|(&x, axis)| !axis.mask.as_ref().map(|mask| mask[x]).unwrap_or(false))
                })
                .map(|(ix, _coords, &window_offset)| (ix, window_offset))
                .collect(),
            })
            .collect();

        let valid_zone = zones.iter().position(|z| z.valid);

        // Build the hyper-rectangle of fully-valid output coordinates, and
        // the complementary invalid hyper-rectangles, axis by axis.
        let mut valid_output_zone = tvec!();
        let mut invalid_output_zones = tvec!();
        for ix in 0..self.input_shape.len() {
            let min_max = data_field_min_max[ix];
            // first output index on this axis whose window does not underflow the input
            let min = (-min_max.0 as usize).divceil(self.strides[ix]);
            // first output index on this axis whose window overflows the input
            let max =
                (self.input_shape[ix].saturating_sub(min_max.1 as usize)).divceil(self.strides[ix]);
            if min != 0 {
                let mut invalid = valid_output_zone.clone();
                invalid.push(0..min);
                while invalid.len() < output.len() {
                    invalid.push(0..output[invalid.len()])
                }
                invalid_output_zones.push(invalid);
            }
            if max < output[ix] {
                let mut invalid = valid_output_zone.clone();
                invalid.push(max..output[ix]);
                while invalid.len() < output.len() {
                    invalid.push(0..output[invalid.len()])
                }
                invalid_output_zones.push(invalid);
            }
            valid_output_zone.push(min..max)
        }

        // per-axis patch-center displacement (in storage elements) for one
        // output step: op stride * input storage stride
        let op_strides_times_input_storage_strides =
            izip!(&self.strides, &input_storage_strides).map(|(a, b)| *a as isize * b).collect();

        Patch {
            spec: self,
            padded: pad_before.iter().any(|&p| p != 0) || pad_after.iter().any(|&p| p != 0),
            pad_before,
            pad_after,
            output_shape: output,
            data_field,
            data_field_min_max,
            standard_layout_data_field,
            input_storage_strides,
            output_storage_strides,
            op_strides_times_input_storage_strides,
            valid_output_zone,
            invalid_output_zones,
            zones,
            valid_zone_id: valid_zone,
            zone_strides,
        }
    }
}
228
/// Fully resolved sliding-window geometry, precomputed by
/// `PatchSpec::into_patch`. Offsets are expressed in storage elements.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Patch {
    pub spec: PatchSpec,
    pub pad_before: TVec<usize>,
    pub pad_after: TVec<usize>,
    // true if any axis carries non-zero padding
    pub padded: bool,
    pub output_shape: TVec<usize>,
    // one row per kernel tap: offset in input coordinates (can be negative)
    pub data_field: Array2<isize>,
    // per-axis (min, max) over the data field columns
    pub data_field_min_max: TVec<(isize, isize)>,
    // data field rows linearized against input storage strides
    pub standard_layout_data_field: Vec<isize>,
    // per-axis stride * input storage stride (center step per output step)
    pub op_strides_times_input_storage_strides: TVec<isize>,
    // hyper-rectangle of output coords where all kernel taps are in bounds
    pub valid_output_zone: TVec<Range<usize>>,
    // complementary hyper-rectangles covering the rest of the output
    pub invalid_output_zones: TVec<TVec<Range<usize>>>,
    pub zones: Vec<Zone>,
    // index in `zones` of the fully-valid zone, if any
    pub valid_zone_id: Option<usize>,
    pub zone_strides: TVec<isize>,
    pub input_storage_strides: TVec<isize>,
    pub output_storage_strides: TVec<isize>,
}
248
249impl Debug for Patch {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(f, "{:?}", self.spec)
252 }
253}
254
255impl Patch {
256 #[inline]
257 pub fn rank(&self) -> usize {
258 self.spec.input_shape.len()
259 }
260
261 unsafe fn is_valid(&self, coords: &[usize]) -> bool {
262 unsafe {
263 for ix in 0..self.rank() {
264 let c = *coords.get_unchecked(ix) as isize;
265 let strides = *self.spec.strides.get_unchecked(ix) as isize;
266 let pos = c * strides;
267 let min_max = self.data_field_min_max.get_unchecked(ix);
268 if pos + min_max.0 < 0
269 || pos + min_max.1 >= *self.spec.input_shape.get_unchecked(ix) as isize
270 {
271 return false;
272 }
273 }
274 true
275 }
276 }
277
278 pub fn valid_zone(&self) -> Option<&Zone> {
279 self.valid_zone_id.map(|id| &self.zones[id])
280 }
281
282 #[inline]
283 pub fn visit_output(&self, mut acceptor: impl FnMut(&Scanner)) {
284 if self.zones.len() == 0 {
285 return;
286 }
287 let mut scanner = Scanner::new(self);
288 while !scanner.done() {
289 acceptor(&scanner);
290 scanner.next();
291 }
292 }
293
294 pub fn centers_offsets(&self) -> Vec<isize> {
295 if self.zones.len() == 0 {
296 return vec![];
297 }
298 let mut scanner = Scanner::new(self);
299 let len = self.output_shape.iter().cloned().product();
300 let mut v = Vec::with_capacity(len);
301 for _ in 0..len {
302 v.push(scanner.input_center_offset);
303 scanner.next()
304 }
305 v
306 }
307
308 pub fn at<'p>(&'p self, coords: &[usize]) -> PatchIterator<'p> {
309 self.at_hint(coords, None)
310 }
311
312 pub fn at_hint<'p>(&'p self, coords: &[usize], hint: Option<bool>) -> PatchIterator<'p> {
313 unsafe {
314 assert_eq!(coords.len(), self.spec.kernel_shape.len());
315 let mut center = 0;
316 for i in 0..self.op_strides_times_input_storage_strides.len() {
317 center += *self.op_strides_times_input_storage_strides.get_unchecked(i)
318 * *coords.get_unchecked(i) as isize;
319 }
320 let valid = hint.unwrap_or_else(|| !self.padded || self.is_valid(coords));
321 if valid {
322 PatchIterator::Fast(FastPatchIterator { patch: self, center, item: 0 })
323 } else {
324 let mut input_patch_center: TVec<_> = coords.into();
325 input_patch_center
326 .iter_mut()
327 .zip(self.spec.strides.iter())
328 .for_each(|(a, &b)| *a *= b);
329 PatchIterator::Safe(SafePatchIterator {
330 patch: self,
331 item: 0,
332 input_patch_center,
333 center,
334 })
335 }
336 }
337 }
338
339 pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
340 assert_eq!(coords.len(), self.spec.kernel_shape.len());
341 let center = izip!(coords, &self.op_strides_times_input_storage_strides)
342 .map(|(a, b)| *a as isize * *b)
343 .sum::<isize>();
344 (center + self.standard_layout_data_field[patch_index]) as usize
345 }
346}
347
/// A hyper-rectangular area of the output plane over which the set of valid
/// kernel taps is constant.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zone {
    // true if every kernel tap is valid everywhere in this zone
    pub valid: bool,
    pub input_zone_offset: isize,
    // output storage offset of the zone's first point
    pub output_zone_offset: isize,
    // per-axis output coordinate ranges covered by the zone
    pub output_ranges: Box<[Range<usize>]>,
    pub output_shape: Box<[usize]>,
    // (kernel tap index, linearized input offset) for each valid tap
    pub values_offsets: Box<[(usize, isize)]>,
}
358
359impl Zone {
360 pub fn contains_output(&self, coords: &[usize]) -> bool {
361 self.output_ranges.iter().zip(coords).all(|(range, &x)| x >= range.start && x < range.end)
362 }
363
364 #[inline]
365 pub fn visit_output(&self, patch: &Patch, mut acceptor: impl FnMut(&ZoneScanner)) {
366 let mut scanner = ZoneScanner::new(self, patch);
367 while !scanner.done() {
368 acceptor(&scanner);
369 scanner.next();
370 }
371 }
372}
373
/// Iteration state for scanning all output points of a single `Zone`.
/// The innermost loop runs along the zone's largest axis.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZoneScanner<'p> {
    pub patch: &'p Patch,
    pub zone: &'p Zone,
    // output storage offset of the current point
    pub output_offset: isize,
    pub output_coords: Box<[usize]>,
    // input storage offset of the current patch center
    pub input_center_offset: isize,
    // axis chosen for the inner loop (largest zone dimension)
    pub inner_loop_axis: usize,
    pub inner_loop_len: usize,
    pub inner_loop_output_range: Range<usize>,
    pub inner_loop_output_stride: isize,
    pub inner_loop_input_full_stride: isize,
    pub done: bool,
}
388
impl<'p> ZoneScanner<'p> {
    pub fn new(zone: &'p Zone, patch: &'p Patch) -> ZoneScanner<'p> {
        // pick the longest zone axis as the inner loop to maximize the
        // cheap incremental path in next()
        let inner_loop_axis =
            zone.output_shape.iter().enumerate().max_by_key(|(_, dim)| *dim).unwrap().0;
        let inner_loop_output_range = zone.output_ranges[inner_loop_axis].clone();
        let inner_loop_output_stride = patch.output_storage_strides[inner_loop_axis];
        let inner_loop_input_full_stride =
            patch.op_strides_times_input_storage_strides[inner_loop_axis];
        let mut scan = ZoneScanner {
            patch,
            zone,
            output_offset: 0,
            input_center_offset: 0,
            inner_loop_axis,
            inner_loop_len: inner_loop_output_range.len(),
            inner_loop_output_range,
            inner_loop_output_stride,
            inner_loop_input_full_stride,
            output_coords: zone.output_ranges.iter().map(|r| r.start).collect(),
            done: false,
        };
        scan.refresh_dependent();
        scan
    }

    /// (kernel tap index, input offset) pairs for the current output point.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Advances the first non-inner axis (innermost first), wrapping and
    /// carrying as needed; recomputes offsets from scratch on success, sets
    /// `done` when the zone is exhausted.
    ///
    /// # Safety
    /// Relies on coords/ranges/strides all holding `rank` elements.
    pub unsafe fn next_non_inner_axis(&mut self) {
        unsafe {
            let rank = self.patch.rank();
            let inner_loop_axis = self.inner_loop_axis;
            for axis in (0..rank).rev() {
                if axis == inner_loop_axis {
                    continue;
                }
                *self.output_coords.get_unchecked_mut(axis) += 1;
                if *self.output_coords.get_unchecked_mut(axis)
                    < self.zone.output_ranges.get_unchecked(axis).end
                {
                    self.refresh_dependent();
                    return;
                }
                // wrap this axis back to the zone start and carry outward
                *self.output_coords.get_unchecked_mut(axis) =
                    self.zone.output_ranges.get_unchecked(axis).start;
            }
            self.done = true;
        }
    }

    /// Rewinds the scanner to the zone's first output point.
    pub unsafe fn reset(&mut self) {
        unsafe {
            self.output_offset = 0;
            self.input_center_offset = 0;
            for ix in 0..self.output_coords.len() {
                *self.output_coords.get_unchecked_mut(ix) =
                    self.zone.output_ranges.get_unchecked(ix).start;
            }
            self.done = false;
            self.refresh_dependent()
        }
    }

    /// Recomputes both storage offsets from scratch out of `output_coords`.
    #[inline(never)]
    fn refresh_dependent(&mut self) {
        self.input_center_offset = self
            .patch
            .op_strides_times_input_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| *a * *b as isize)
            .sum();
        self.output_offset = self
            .patch
            .output_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| a * *b as isize)
            .sum();
    }

    /// Steps to the next output point: cheap incremental move along the
    /// inner axis, falling back to `next_non_inner_axis` at the end of a row.
    #[inline]
    pub fn next(&mut self) {
        let inner_loop_axis = self.inner_loop_axis;
        unsafe {
            *self.output_coords.get_unchecked_mut(inner_loop_axis) += 1;
            if *self.output_coords.get_unchecked(inner_loop_axis) < self.inner_loop_output_range.end
            {
                self.input_center_offset += self.inner_loop_input_full_stride;
                self.output_offset += self.inner_loop_output_stride;
            } else {
                *self.output_coords.get_unchecked_mut(inner_loop_axis) =
                    self.inner_loop_output_range.start;
                self.next_non_inner_axis();
            }
        }
    }

    pub fn done(&self) -> bool {
        self.done
    }
}
493
/// Iteration state for scanning the whole output plane in row-major order,
/// crossing zones as it goes. Built by `Patch::visit_output`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scanner<'p> {
    pub patch: &'p Patch,
    // linear index of the current zone in patch.zones
    pub zone_id: usize,
    // per-axis region index of the current zone
    pub zone_coords: TVec<usize>,
    pub zone: &'p Zone,
    // output storage offset of the current point
    pub output_offset: isize,
    pub output_coords: TVec<usize>,
    // patch center in input coordinates (output coords * strides)
    pub input_coords: TVec<usize>,
    // input storage offset of the current patch center
    pub input_center_offset: isize,
    done: bool,
}
506
impl<'p> Scanner<'p> {
    // Starts at output coordinate 0 in every axis, in zone 0.
    fn new(patch: &'p Patch) -> Scanner<'p> {
        let rank = patch.rank();
        let zone = &patch.zones[0];
        Scanner {
            patch,
            zone_coords: tvec!(0; rank),
            zone,
            zone_id: 0,
            output_offset: 0,
            input_center_offset: 0,
            input_coords: tvec!(0; rank),
            output_coords: tvec!(0; rank),
            done: false,
        }
    }

    /// Number of valid kernel taps in the current zone.
    #[inline]
    pub fn valid_count(&self) -> usize {
        self.zone.values_offsets.len()
    }

    /// Input offsets of the valid kernel taps for the current output point.
    #[inline]
    pub fn valid_offsets(&self) -> impl Iterator<Item = isize> + '_ {
        self.zone.values_offsets.iter().map(move |pair| pair.1 + self.input_center_offset)
    }

    /// (kernel tap index, input offset) pairs for the current output point.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Steps to the next output point in row-major order, switching zones
    /// (and recomputing offsets) when a zone boundary is crossed.
    #[inline]
    pub fn next(&mut self) {
        let rank = self.patch.rank();
        let inner_dim = rank - 1;
        unsafe {
            // advance along the innermost axis, updating both offsets
            *self.output_coords.get_unchecked_mut(inner_dim) += 1;
            *self.input_coords.get_unchecked_mut(inner_dim) +=
                *self.patch.spec.strides.get_unchecked(inner_dim);
            self.output_offset += self.patch.spec.output_inner_stride as isize;
            self.input_center_offset +=
                self.patch.op_strides_times_input_storage_strides.get_unchecked(inner_dim);
            if *self.output_coords.get_unchecked(inner_dim)
                < self.zone.output_ranges.get_unchecked(inner_dim).end
            {
                // still inside the current zone: fast path done
                return;
            }
            if self.output_coords.get_unchecked(inner_dim)
                < self.patch.output_shape.get_unchecked(inner_dim)
            {
                // same row, but crossed into the next zone along the inner axis
                self.zone_id += 1;
                *self.zone_coords.get_unchecked_mut(inner_dim) += 1;
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            } else {
                // end of row: carry into the outer axes, tracking zone coords
                for axis in (0..rank - 1).rev() {
                    *self.output_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.input_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.output_coords.get_unchecked_mut(axis) += 1;
                    *self.input_coords.get_unchecked_mut(axis) +=
                        self.patch.spec.strides.get_unchecked(axis);
                    *self.zone_coords.get_unchecked_mut(axis + 1) = 0;
                    if *self.output_coords.get_unchecked(axis)
                        == self.zone.output_ranges.get_unchecked(axis).end
                    {
                        *self.zone_coords.get_unchecked_mut(axis) += 1;
                    }
                    if *self.output_coords.get_unchecked(axis)
                        < *self.patch.output_shape.get_unchecked(axis)
                    {
                        break;
                    }
                }
                if self.output_coords.get_unchecked(0) == self.patch.output_shape.get_unchecked(0) {
                    self.done = true;
                    return;
                }
                // recompute zone id and input center offset from coordinates
                self.zone_id = 0;
                self.input_center_offset = 0;
                for i in 0..rank {
                    self.zone_id += *self.zone_coords.get_unchecked(i)
                        * *self.patch.zone_strides.get_unchecked(i) as usize;
                    self.input_center_offset += *self.input_coords.get_unchecked(i) as isize
                        * *self.patch.input_storage_strides.get_unchecked(i);
                }
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            }
        }
    }

    pub fn done(&self) -> bool {
        self.done
    }
}
601
/// Iterator over the input offsets of one window's kernel taps; yields `None`
/// for taps falling in the padding. Built by `Patch::at` / `Patch::at_hint`.
#[derive(Debug)]
pub enum PatchIterator<'p> {
    // window known fully inside the input: no per-tap bound checks
    Fast(FastPatchIterator<'p>),
    // window may overlap the padding: each tap is bound-checked
    Safe(SafePatchIterator<'p>),
}
607
608impl Iterator for PatchIterator<'_> {
609 type Item = Option<isize>;
610 #[inline(always)]
611 fn next(&mut self) -> Option<Option<isize>> {
612 match self {
613 PatchIterator::Fast(it) => it.next(),
614 PatchIterator::Safe(it) => it.next(),
615 }
616 }
617}
618
/// Fast path for `PatchIterator`: the window is known valid, so every tap
/// offset is emitted without bound checks.
#[derive(Debug)]
pub struct FastPatchIterator<'p> {
    patch: &'p Patch,
    // input storage offset of the patch center
    center: isize,
    // index of the next kernel tap to emit
    item: usize,
}
625
626impl Iterator for FastPatchIterator<'_> {
627 type Item = Option<isize>;
628 #[inline(always)]
629 fn next(&mut self) -> Option<Option<isize>> {
630 if self.item == self.patch.standard_layout_data_field.len() {
631 return None;
632 }
633 unsafe {
634 let position =
635 self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
636 self.item += 1;
637 Some(Some(position))
638 }
639 }
640}
641
/// Checked path for `PatchIterator`: each tap is tested against the input
/// bounds, yielding `None` for taps landing in the padding.
#[derive(Debug)]
pub struct SafePatchIterator<'p> {
    patch: &'p Patch,
    // index of the next kernel tap to emit
    item: usize,
    // patch center in input coordinates (output coords * strides)
    input_patch_center: TVec<usize>,
    // input storage offset of the patch center
    center: isize,
}
649
impl Iterator for SafePatchIterator<'_> {
    type Item = Option<isize>;
    // Emits Some(offset) for in-bounds taps, Some(None) for padding taps,
    // None once the field is exhausted.
    fn next(&mut self) -> Option<Option<isize>> {
        unsafe {
            if self.item == self.patch.standard_layout_data_field.len() {
                return None;
            }
            let input_shape = &self.patch.spec.input_shape;
            // raw pointer to the data field row for the current kernel tap
            let img_offset = self.patch.data_field.as_ptr().add(self.item * input_shape.len());

            for ix in 0..input_shape.len() {
                let pos = *self.input_patch_center.get_unchecked(ix) as isize + *img_offset.add(ix);
                if pos < 0 || pos as usize >= *input_shape.get_unchecked(ix) {
                    // this tap falls in the padding on axis ix
                    self.item += 1;
                    return Some(None);
                }
            }
            let position =
                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
            self.item += 1;
            Some(Some(position))
        }
    }
}
674
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::ops::nn::DataFormat::*;
    use proptest::prelude::*;
    use proptest::*;

    /// Builds a 1-d patch and returns its output spatial dimension.
    /// (note: `bad_after` is presumably a typo for `pad_after` — TODO confirm)
    fn compute_output_spatial_dim(
        input: usize,
        dilation: usize,
        kdim: usize,
        pad_before: usize,
        bad_after: usize,
        stride: usize,
    ) -> usize {
        let patch = PatchSpec::for_full_shape(NCHW, &[1, 1, input])
            .unwrap()
            .with_dilations(tvec!(dilation))
            .with_kernel_shape(tvec!(kdim))
            .with_padding(PaddingSpec::ExplicitOnnxPool(tvec![pad_before], tvec![bad_after], true))
            .with_strides(tvec![stride])
            .into_patch();
        patch.output_shape[0]
    }

    #[test]
    fn basic() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
    }

    #[test]
    fn strides() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
    }

    #[test]
    fn padding() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
    }

    #[test]
    fn strides_and_padding() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
    }

    /// Data field (kernel tap offsets) for the given kernel and dilations.
    fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
        let patch =
            PatchSpec::for_data_shape(NCHW.from_n_c_hw(1, 1, tvec![10; kdim.len()]).unwrap())
                .with_dilations(dilations.into())
                .with_kernel_shape(kdim.into())
                .with_strides(tvec![1; kdim.len()])
                .into_patch();
        patch.data_field
    }

    #[test]
    fn test_field() {
        assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
        assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
        assert_eq!(field(&[2, 2], &[1, 1]), arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]]));
        assert_eq!(field(&[2, 2], &[2, 1]), arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]]));
    }

    /// proptest strategy: a tensor of the given shape filled with small f32
    /// values (i8 range, so sums stay exact).
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let shape = shape.to_vec();
        proptest::collection::vec(any::<i8>().prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor())
            .boxed()
    }

    /// A random 2-d pooling problem: a patch plus an input tensor to scan.
    #[derive(Debug)]
    struct Problem {
        patch: Patch,
        input: Tensor,
        data_format: DataFormat,
    }

    impl Arbitrary for Problem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Problem>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // tuple: (format, dilations, channels, kernel, padding, strides)
            (
                prop_oneof!(Just(NCHW), Just(NHWC)),
                (1usize..3, 1usize..3),
                1usize..3,
                (1usize..3, 1usize..3),
                prop_oneof![
                    Just(PaddingSpec::Valid),
                    Just(PaddingSpec::SameLower),
                    Just(PaddingSpec::SameUpper)
                ],
                (1usize..4, 1usize..4),
            )
                .prop_flat_map(|p| {
                    let dil = p.1;
                    let ks = p.3;
                    let strides = p.5;
                    // input must hold at least one dilated kernel
                    let min_size: (usize, usize) = (1 + (ks.0 - 1) * dil.0, 1 + (ks.1 - 1) * dil.1);
                    (
                        Just(p),
                        (min_size.0..min_size.0 + strides.0 * 3),
                        (min_size.1..min_size.1 + strides.1 * 3),
                    )
                })
                .prop_flat_map(|(p, h, w)| {
                    let input_shape = p.0.from_n_c_hw(1, p.2, [h, w]).unwrap();
                    let input = tensor(&input_shape.shape);
                    (Just(p), input)
                })
                .prop_map(|((fmt, dil, c, ks, pad, strides), input)| {
                    // NHWC interleaves channels, so the output inner stride is c
                    let output_inner_stride = if fmt.c_is_last() { c } else { 1 };
                    Problem {
                        patch: PatchSpec::for_full_shape(fmt, input.shape())
                            .unwrap()
                            .with_dilations(tvec!(dil.0, dil.1))
                            .with_kernel_shape(tvec!(ks.0, ks.1))
                            .with_padding(pad)
                            .with_strides(tvec![strides.0, strides.1])
                            .with_output_inner_stride(output_inner_stride)
                            .into_patch(),
                        input,
                        data_format: fmt,
                    }
                })
                .boxed()
        }
    }

    impl Problem {
        fn input_shape(&self) -> DataShape {
            self.data_format.shape(self.input.shape().into()).unwrap()
        }

        fn output_shape(&self) -> DataShape {
            self.data_format
                .from_n_c_hw(
                    self.input_shape().n().cloned().unwrap_or(1),
                    *self.input_shape().c(),
                    &*self.patch.output_shape,
                )
                .unwrap()
        }

        /// Naive coordinate-space reference implementation of sum-pooling.
        fn reference_sumpool(&self) -> Tensor {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
                for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
                    // input coordinate sampled by this (output, kernel tap) pair
                    let geo_in: TVec<isize> = izip!(
                        geo_out.slice(),
                        geo_ker.slice(),
                        &self.patch.spec.strides,
                        &self.patch.spec.dilations,
                        &self.patch.pad_before
                    )
                    .map(|(o, k, s, d, p)| (o * s + k * d) as isize - *p as isize)
                    .collect();
                    // skip taps landing in the padding
                    if izip!(&geo_in, input_shape.hw_dims())
                        .any(|(g, i)| *g >= *i as isize || *g < 0)
                    {
                        continue;
                    }
                    let geo_in: TVec<usize> = geo_in.into_iter().map(|x| x as usize).collect();
                    for c in 0..*output_shape.c() {
                        let ocoords = self.data_format.from_n_c_hw(0, c, geo_out.slice()).unwrap();
                        let icoords = self.data_format.from_n_c_hw(0, c, &geo_in).unwrap();
                        output.to_array_view_mut::<f32>().unwrap()[&*ocoords.shape] +=
                            self.input.to_array_view::<f32>().unwrap()[&*icoords.shape];
                    }
                }
            }
            output
        }

        /// Sum-pools through `Patch::visit_output` and checks it against the
        /// reference implementation.
        fn check_visitor(&self) {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            self.patch.visit_output(|visitor| {
                for (_k, offset_in) in visitor.valid_offsets_ker_in() {
                    for c in 0..*output_shape.c() {
                        output.as_slice_mut::<f32>().unwrap()
                            [visitor.output_offset as usize + c * output_shape.c_stride()] +=
                            self.input.as_slice::<f32>().unwrap()
                                [offset_in as usize + c * input_shape.c_stride()];
                    }
                }
            });
            assert_eq!(output, self.reference_sumpool());
        }

        /// Sum-pools zone by zone through `Zone::visit_output` and checks it
        /// against the reference implementation.
        fn check_zone_visitor(&self) {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for zone in &self.patch.zones {
                zone.visit_output(&self.patch, |visitor| {
                    for (_k, offset_in) in visitor.valid_offsets_ker_in() {
                        for c in 0..*output_shape.c() {
                            output.as_slice_mut::<f32>().unwrap()
                                [visitor.output_offset as usize + c * output_shape.c_stride()] +=
                                self.input.as_slice::<f32>().unwrap()
                                    [offset_in as usize + c * input_shape.c_stride()];
                        }
                    }
                });
            }
            assert_eq!(output, self.reference_sumpool());
        }

        /// Checks that valid/invalid output zones agree with `Patch::is_valid`
        /// and that the invalid zones tile the non-valid output exactly once.
        fn check_zoning(&self) {
            fn in_zone(full_coords: &[usize], h_axis: usize, zone: &[Range<usize>]) -> bool {
                for a in 0..zone.len() {
                    if full_coords[h_axis + a] < zone[a].start
                        || full_coords[h_axis + a] >= zone[a].end
                    {
                        return false;
                    }
                }
                true
            }

            let valid_zone = &self.patch.valid_output_zone;
            let invalid_zones = &self.patch.invalid_output_zones;
            let output_full_shape = self.output_shape();
            let h_axis = self.input_shape().h_axis();
            for coords in ndarray::indices(&*output_full_shape.shape) {
                let inside_valid = in_zone(coords.slice(), h_axis, valid_zone);
                let invalid_count =
                    invalid_zones.iter().filter(|z| in_zone(coords.slice(), h_axis, z)).count();
                unsafe {
                    assert_eq!(
                        inside_valid,
                        self.patch.is_valid(&coords.slice()[self.input_shape().hw_axes()]),
                        "coords {:?}, valid_zone: {:?} inside_valid: {:?}",
                        coords.slice(),
                        valid_zone,
                        inside_valid
                    );
                }
                if inside_valid {
                    assert_eq!(invalid_count, 0);
                } else {
                    // every invalid point must be covered by exactly one invalid zone
                    assert_eq!(
                        invalid_count,
                        1,
                        "coords {:?}, valid_zone: {:?} inside_valid: {:?} invalid_zones: {:?}",
                        coords.slice(),
                        valid_zone,
                        inside_valid,
                        invalid_zones
                    );
                }
            }
        }
    }

    proptest! {
        #[test]
        fn test_visitor(pb in any::<Problem>()) {
            pb.check_visitor();
        }

        #[test]
        fn test_zone_visitor(pb in any::<Problem>()) {
            pb.check_zone_visitor();
        }

        #[test]
        fn test_zoning(pb in any::<Problem>()) {
            pb.check_zoning();
        }
    }

    // Hand-written fixed cases below.
    #[test]
    fn test_visitor_1() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
        let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
        let patch = PatchSpec::for_data_shape(input_shape.clone())
            .with_kernel_shape(tvec![2, 1])
            .with_padding(PaddingSpec::SameLower)
            .with_strides(tvec![1, 2])
            .into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }

    #[test]
    fn test_visitor_2() {
        let input_shape = NCHW.from_n_c_hw(1, 2, [1, 1]).unwrap();
        let input = tensor4(&[[[[0.]], [[1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch =
            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }

    #[test]
    fn test_visitor_3() {
        let input_shape = NHWC.from_n_c_hw(1, 2, [2, 1]).unwrap();
        let input = tensor4(&[[[[0., 0.]], [[1., 0f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch =
            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }

    #[test]
    fn test_visitor_4() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
        let input = tensor4(&[[[[0., 1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone())
            .with_kernel_shape(tvec!(1, 2))
            .with_output_inner_stride(1)
            .with_padding(PaddingSpec::SameLower)
            .into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }

    #[test]
    fn test_zone_visitor_1() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 1]).unwrap();
        let input = tensor4(&[[[[0.], [1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
    }

    #[test]
    fn test_zone_visitor_2() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
        let input = tensor4(&[[[[0., 1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
    }
}