1use crate::internal::*;
2use crate::ops::cnn::PaddingSpec;
3use crate::ops::nn::{DataFormat, DataShape};
4use ndarray::prelude::*;
5
6use super::PatchAxis;
7
8use std::fmt::Debug;
9use std::ops::Range;
10
11use tract_itertools::{izip, Itertools};
12
/// Geometry specification of a sliding window (convolution / pooling) over
/// the spatial dims of an input: everything needed to derive output shape,
/// padding and scanning offsets. Resolve it with [`PatchSpec::into_patch`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PatchSpec {
    /// Spatial (h/w) dims of the input, batch and channel axes excluded.
    pub input_shape: TVec<usize>,
    /// Storage stride of the innermost spatial axis of the input.
    pub input_inner_stride: usize,
    /// Storage stride of the innermost spatial axis of the output.
    pub output_inner_stride: usize,
    /// Spatial dims of the kernel.
    pub kernel_shape: TVec<usize>,
    /// Window stride along each spatial axis.
    pub strides: TVec<usize>,
    /// Kernel dilation along each spatial axis.
    pub dilations: TVec<usize>,
    /// Padding policy used to compute pad_before / pad_after.
    pub padding: PaddingSpec,
}
23
24impl Debug for PatchSpec {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(
27 f,
28 "input: {} kernel: {} strides: {} dil: {} pad: {:?}",
29 self.input_shape.iter().join(","),
30 self.kernel_shape.iter().join(","),
31 self.strides.iter().join(","),
32 self.dilations.iter().join(","),
33 self.padding
34 )
35 }
36}
37
38impl PatchSpec {
39 pub fn for_full_shape(
40 data_format: DataFormat,
41 input_full_shape: &[usize],
42 ) -> TractResult<PatchSpec> {
43 let shape = data_format.shape(input_full_shape.into())?;
44 Ok(Self::for_data_shape(shape))
45 }
46
47 pub fn for_data_shape(data_shape: DataShape) -> PatchSpec {
48 let input_shape: TVec<usize> = data_shape.hw_dims().into();
49 PatchSpec {
50 kernel_shape: tvec!(1; input_shape.len()),
51 input_inner_stride: *data_shape.w_stride(),
52 output_inner_stride: 1,
53 strides: tvec!(1; input_shape.len()),
54 dilations: tvec!(1; input_shape.len()),
55 padding: PaddingSpec::Valid,
56 input_shape,
57 }
58 }
59
60 pub fn with_kernel_shape(self, kernel_shape: TVec<usize>) -> PatchSpec {
61 PatchSpec { kernel_shape, ..self }
62 }
63
64 pub fn with_dilations(self, dilations: TVec<usize>) -> PatchSpec {
65 PatchSpec { dilations, ..self }
66 }
67
68 pub fn with_strides(self, strides: TVec<usize>) -> PatchSpec {
69 PatchSpec { strides, ..self }
70 }
71
72 pub fn with_padding(self, padding: PaddingSpec) -> PatchSpec {
73 PatchSpec { padding, ..self }
74 }
75
76 pub fn with_output_inner_stride(self, output_inner_stride: usize) -> PatchSpec {
77 PatchSpec { output_inner_stride, ..self }
78 }
79
80 pub fn into_patch(self) -> Patch {
81 let dims = self.padding.compute(
82 &self.input_shape,
83 &self.kernel_shape,
84 &self.dilations,
85 &self.strides,
86 );
87 let output: TVec<usize> = dims.iter().map(|d| d.convoluted).collect();
88 let pad_before: TVec<usize> = dims.iter().map(|d| d.pad_before).collect();
89 let pad_after: TVec<usize> = dims.iter().map(|d| d.pad_after).collect();
90
91 let data_field: Vec<isize> = ::ndarray::indices(&*self.kernel_shape)
92 .into_iter()
93 .flat_map(|coords| {
94 #[allow(clippy::unnecessary_to_owned)] coords
96 .slice()
97 .to_vec()
98 .into_iter()
99 .enumerate()
100 .map(|(ix, c)| (c * self.dilations[ix]) as isize - pad_before[ix] as isize)
101 })
102 .collect();
103 let data_field = Array2::from_shape_vec(
104 (self.kernel_shape.iter().cloned().product(), self.kernel_shape.len()),
105 data_field,
106 )
107 .unwrap();
108 let data_field_min_max: TVec<_> = data_field
109 .columns()
110 .into_iter()
111 .map(|col| (col.iter().min().cloned().unwrap(), col.iter().max().cloned().unwrap()))
112 .collect();
113
114 fn strides(shape: &[usize], inner: usize) -> TVec<isize> {
115 let mut strides: TVec<isize> = tvec![inner as isize];
116 for dim in shape.iter().skip(1).rev() {
117 let previous = *strides.last().unwrap();
118 strides.push(*dim as isize * previous);
119 }
120 strides.reverse();
121 strides
122 }
123
124 let input_storage_strides = strides(&self.input_shape, self.input_inner_stride);
125 let output_storage_strides = strides(&output, self.output_inner_stride);
126
127 let standard_layout_data_field: Vec<isize> = data_field
128 .outer_iter()
129 .map(|coords| izip!(coords, &input_storage_strides).map(|(a, b)| a * b).sum::<isize>())
130 .collect();
131
132 let regions: TVec<TVec<_>> = dims
134 .iter()
135 .enumerate()
136 .map(|(ix, d)| {
137 PatchAxis {
138 input_dim: self.input_shape[ix],
139 kernel_dim: self.kernel_shape[ix],
140 pad_before: d.pad_before,
141 pad_after: d.pad_after,
142 output_dim: d.convoluted,
143 stride: self.strides[ix],
144 dilation: self.dilations[ix],
145 }
146 .regions()
147 })
148 .collect::<TVec<_>>();
149
150 let zone_strides = strides(®ions.iter().map(|d| d.len()).collect::<TVec<_>>(), 1);
151
152 let zones: Vec<Zone> = regions
153 .iter()
154 .multi_cartesian_product()
155 .map(|regions| Zone {
156 input_zone_offset: 0,
157 output_ranges: regions.iter().map(|reg| reg.range.clone()).collect(),
158 output_shape: regions.iter().map(|reg| reg.range.end - reg.range.start).collect(),
159 output_zone_offset: izip!(®ions, &output_storage_strides)
160 .map(|(reg, &stride)| reg.range.start as isize * stride)
161 .sum::<isize>(),
162 valid: regions.iter().all(|reg| reg.mask.is_none()),
163 values_offsets: izip!(
164 0..,
165 ndarray::indices(&*self.kernel_shape),
166 &standard_layout_data_field
167 )
168 .filter(|(_ix, coords, _offset)| {
169 izip!(coords.slice(), ®ions)
170 .all(|(&x, axis)| !axis.mask.as_ref().map(|mask| mask[x]).unwrap_or(false))
171 })
172 .map(|(ix, _coords, &window_offset)| (ix, window_offset))
173 .collect(),
174 })
175 .collect();
176
177 let valid_zone = zones.iter().position(|z| z.valid);
178
179 let mut valid_output_zone = tvec!();
180 let mut invalid_output_zones = tvec!();
181 for ix in 0..self.input_shape.len() {
182 let min_max = data_field_min_max[ix];
183 let min = (-min_max.0 as usize).divceil(self.strides[ix]);
184 let max =
185 (self.input_shape[ix].saturating_sub(min_max.1 as usize)).divceil(self.strides[ix]);
186 if min != 0 {
187 let mut invalid = valid_output_zone.clone();
188 invalid.push(0..min);
189 while invalid.len() < output.len() {
190 invalid.push(0..output[invalid.len()])
191 }
192 invalid_output_zones.push(invalid);
193 }
194 if max < output[ix] {
195 let mut invalid = valid_output_zone.clone();
196 invalid.push(max..output[ix]);
197 while invalid.len() < output.len() {
198 invalid.push(0..output[invalid.len()])
199 }
200 invalid_output_zones.push(invalid);
201 }
202 valid_output_zone.push(min..max)
203 }
204
205 let op_strides_times_input_storage_strides =
206 izip!(&self.strides, &input_storage_strides).map(|(a, b)| (*a as isize * b)).collect();
207
208 Patch {
209 spec: self,
210 padded: pad_before.iter().any(|&p| p != 0) || pad_after.iter().any(|&p| p != 0),
211 pad_before,
212 pad_after,
213 output_shape: output,
214 data_field,
215 data_field_min_max,
216 standard_layout_data_field,
217 input_storage_strides,
218 output_storage_strides,
219 op_strides_times_input_storage_strides,
220 valid_output_zone,
221 invalid_output_zones,
222 zones,
223 valid_zone_id: valid_zone,
224 zone_strides,
225 }
226 }
227}
228
/// A resolved patch: the originating `PatchSpec` plus every precomputed
/// offset table needed to scan input and output efficiently.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Patch {
    /// Geometry this patch was derived from.
    pub spec: PatchSpec,
    /// Padding added before each spatial axis.
    pub pad_before: TVec<usize>,
    /// Padding added after each spatial axis.
    pub pad_after: TVec<usize>,
    /// True if any axis carries non-zero padding.
    pub padded: bool,
    /// Spatial output dims.
    pub output_shape: TVec<usize>,
    /// Per kernel element, per-axis offsets from the window center
    /// (kernel volume rows, rank columns).
    pub data_field: Array2<isize>,
    /// Per axis, (min, max) of the corresponding data_field column.
    pub data_field_min_max: TVec<(isize, isize)>,
    /// Per kernel element, its linearized input storage offset.
    pub standard_layout_data_field: Vec<isize>,
    /// Per axis, window stride expressed in input storage elements.
    pub op_strides_times_input_storage_strides: TVec<isize>,
    /// Output hyper-rectangle where no window touches padding.
    pub valid_output_zone: TVec<Range<usize>>,
    /// Complementary output slabs where windows do touch padding.
    pub invalid_output_zones: TVec<TVec<Range<usize>>>,
    /// Homogeneous zones (cartesian product of per-axis regions).
    pub zones: Vec<Zone>,
    /// Index in `zones` of the fully-valid zone, if one exists.
    pub valid_zone_id: Option<usize>,
    /// Strides mapping zone coordinates to an index in `zones`.
    pub zone_strides: TVec<isize>,
    /// Storage strides of the input spatial dims.
    pub input_storage_strides: TVec<isize>,
    /// Storage strides of the output spatial dims.
    pub output_storage_strides: TVec<isize>,
}
248
249impl Debug for Patch {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(f, "{:?}", self.spec)
252 }
253}
254
impl Patch {
    /// Number of spatial axes.
    #[inline]
    pub fn rank(&self) -> usize {
        self.spec.input_shape.len()
    }

    /// True if the window centered at output `coords` (spatial dims only)
    /// lies entirely inside the input on every axis.
    ///
    /// # Safety
    /// `coords` must have at least `rank()` elements (unchecked indexing).
    unsafe fn is_valid(&self, coords: &[usize]) -> bool {
        for ix in 0..self.rank() {
            let c = *coords.get_unchecked(ix) as isize;
            let strides = *self.spec.strides.get_unchecked(ix) as isize;
            let pos = c * strides;
            let min_max = self.data_field_min_max.get_unchecked(ix);
            // Out of bounds if the extreme kernel offsets spill over either
            // edge of the input.
            if pos + min_max.0 < 0
                || pos + min_max.1 >= *self.spec.input_shape.get_unchecked(ix) as isize
            {
                return false;
            }
        }
        true
    }

    /// The fully-valid zone, if the patch has one.
    pub fn valid_zone(&self) -> Option<&Zone> {
        self.valid_zone_id.map(|id| &self.zones[id])
    }

    /// Runs `acceptor` once per output point, in scan order.
    #[inline]
    pub fn visit_output(&self, mut acceptor: impl FnMut(&Scanner)) {
        if self.zones.len() == 0 {
            return;
        }
        let mut scanner = Scanner::new(self);
        while !scanner.done() {
            acceptor(&scanner);
            scanner.next();
        }
    }

    /// Linearized input offset of the window center for every output point,
    /// in scan order.
    pub fn centers_offsets(&self) -> Vec<isize> {
        if self.zones.len() == 0 {
            return vec![];
        }
        let mut scanner = Scanner::new(self);
        let len = self.output_shape.iter().cloned().product();
        let mut v = Vec::with_capacity(len);
        for _ in 0..len {
            v.push(scanner.input_center_offset);
            scanner.next()
        }
        v
    }

    /// Iterator over the input offsets of the window at output `coords`.
    pub fn at<'p>(&'p self, coords: &[usize]) -> PatchIterator<'p> {
        self.at_hint(coords, None)
    }

    /// Same as `at`, with an optional validity hint: `Some(true)` asserts
    /// the window is fully in bounds, `Some(false)` that it may not be.
    pub fn at_hint<'p>(&'p self, coords: &[usize], hint: Option<bool>) -> PatchIterator<'p> {
        unsafe {
            assert_eq!(coords.len(), self.spec.kernel_shape.len());
            let mut center = 0;
            for i in 0..self.op_strides_times_input_storage_strides.len() {
                center += *self.op_strides_times_input_storage_strides.get_unchecked(i)
                    * *coords.get_unchecked(i) as isize;
            }
            let valid = hint.unwrap_or_else(|| !self.padded || self.is_valid(coords));
            if valid {
                // Fast path: no per-element bounds check needed.
                PatchIterator::Fast(FastPatchIterator { patch: self, center, item: 0 })
            } else {
                // Slow path: carry the center in input coordinates so each
                // kernel position can be bounds-checked.
                let mut input_patch_center: TVec<_> = coords.into();
                input_patch_center
                    .iter_mut()
                    .zip(self.spec.strides.iter())
                    .for_each(|(a, &b)| *a *= b);
                PatchIterator::Safe(SafePatchIterator {
                    patch: self,
                    item: 0,
                    input_patch_center,
                    center,
                })
            }
        }
    }

    /// Linearized input offset of kernel element `patch_index` for the
    /// window at output `coords`. Assumes the access is in bounds.
    pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
        assert_eq!(coords.len(), self.spec.kernel_shape.len());
        let center = izip!(coords, &self.op_strides_times_input_storage_strides)
            .map(|(a, b)| *a as isize * *b)
            .sum::<isize>();
        (center + self.standard_layout_data_field[patch_index]) as usize
    }
}
345
/// A maximal region of the output space over which the set of valid kernel
/// positions is constant.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zone {
    /// True if no kernel position hits padding anywhere in the zone.
    pub valid: bool,
    /// Always 0 as built by `into_patch`. NOTE(review): looks reserved —
    /// confirm against users of the field.
    pub input_zone_offset: isize,
    /// Linearized output offset of the zone's first point.
    pub output_zone_offset: isize,
    /// Output coordinate range covered by the zone, per axis.
    pub output_ranges: Box<[Range<usize>]>,
    /// Extent of the zone, per axis.
    pub output_shape: Box<[usize]>,
    /// (kernel element index, input offset) pairs valid inside this zone.
    pub values_offsets: Box<[(usize, isize)]>,
}
356
357impl Zone {
358 pub fn contains_output(&self, coords: &[usize]) -> bool {
359 self.output_ranges.iter().zip(coords).all(|(range, &x)| x >= range.start && x < range.end)
360 }
361
362 #[inline]
363 pub fn visit_output(&self, patch: &Patch, mut acceptor: impl FnMut(&ZoneScanner)) {
364 let mut scanner = ZoneScanner::new(self, patch);
365 while !scanner.done() {
366 acceptor(&scanner);
367 scanner.next();
368 }
369 }
370}
371
/// Cursor over the output points of a single `Zone`, maintaining linearized
/// input/output offsets incrementally. The longest zone axis is chosen as
/// the inner loop so most steps are a single stride addition.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZoneScanner<'p> {
    pub patch: &'p Patch,
    pub zone: &'p Zone,
    /// Linearized output offset of the current point.
    pub output_offset: isize,
    /// Current output coordinates (spatial dims).
    pub output_coords: Box<[usize]>,
    /// Linearized input offset of the current window center.
    pub input_center_offset: isize,
    /// Axis scanned in the innermost loop.
    pub inner_loop_axis: usize,
    /// Extent of the inner loop.
    pub inner_loop_len: usize,
    /// Output range of the inner loop axis within the zone.
    pub inner_loop_output_range: Range<usize>,
    /// Output stride of the inner loop axis.
    pub inner_loop_output_stride: isize,
    /// Input stride (stride x storage stride) of the inner loop axis.
    pub inner_loop_input_full_stride: isize,
    /// True once the zone is exhausted.
    pub done: bool,
}
386
impl<'p> ZoneScanner<'p> {
    /// Creates a scanner positioned at the first output point of `zone`,
    /// picking the longest zone axis as the inner loop.
    pub fn new(zone: &'p Zone, patch: &'p Patch) -> ZoneScanner<'p> {
        // max_by_key keeps the last maximum on ties.
        let inner_loop_axis =
            zone.output_shape.iter().enumerate().max_by_key(|(_, dim)| *dim).unwrap().0;
        let inner_loop_output_range = zone.output_ranges[inner_loop_axis].clone();
        let inner_loop_output_stride = patch.output_storage_strides[inner_loop_axis];
        let inner_loop_input_full_stride =
            patch.op_strides_times_input_storage_strides[inner_loop_axis];
        let mut scan = ZoneScanner {
            patch,
            zone,
            output_offset: 0,
            input_center_offset: 0,
            inner_loop_axis,
            inner_loop_len: inner_loop_output_range.len(),
            inner_loop_output_range,
            inner_loop_output_stride,
            inner_loop_input_full_stride,
            output_coords: zone.output_ranges.iter().map(|r| r.start).collect(),
            done: false,
        };
        scan.refresh_dependent();
        scan
    }

    /// (kernel element index, input offset) pairs valid at the current
    /// output point.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Odometer carry over every axis except the inner-loop one; recomputes
    /// dependent offsets on success, sets `done` when the zone is exhausted.
    ///
    /// # Safety
    /// Relies on `output_coords` and `zone.output_ranges` having the patch
    /// rank (unchecked indexing).
    pub unsafe fn next_non_inner_axis(&mut self) {
        let rank = self.patch.rank();
        let inner_loop_axis = self.inner_loop_axis;
        for axis in (0..rank).rev() {
            if axis == inner_loop_axis {
                continue;
            }
            *self.output_coords.get_unchecked_mut(axis) += 1;
            // NOTE(review): get_unchecked_mut in a read position below —
            // works, but get_unchecked would express intent better.
            if *self.output_coords.get_unchecked_mut(axis)
                < self.zone.output_ranges.get_unchecked(axis).end
            {
                self.refresh_dependent();
                return;
            }
            *self.output_coords.get_unchecked_mut(axis) =
                self.zone.output_ranges.get_unchecked(axis).start;
        }
        self.done = true;
    }

    /// Rewinds to the first output point of the zone.
    ///
    /// # Safety
    /// Same rank invariants as `next_non_inner_axis`.
    pub unsafe fn reset(&mut self) {
        self.output_offset = 0;
        self.input_center_offset = 0;
        for ix in 0..self.output_coords.len() {
            *self.output_coords.get_unchecked_mut(ix) =
                self.zone.output_ranges.get_unchecked(ix).start;
        }
        self.done = false;
        self.refresh_dependent()
    }

    /// Recomputes the linearized offsets from `output_coords` from scratch.
    #[inline(never)]
    fn refresh_dependent(&mut self) {
        self.input_center_offset = self
            .patch
            .op_strides_times_input_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| *a * *b as isize)
            .sum();
        self.output_offset = self
            .patch
            .output_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| a * *b as isize)
            .sum();
    }

    /// Advances one output point: single stride addition on the inner loop,
    /// odometer carry otherwise.
    #[inline]
    pub fn next(&mut self) {
        let inner_loop_axis = self.inner_loop_axis;
        unsafe {
            *self.output_coords.get_unchecked_mut(inner_loop_axis) += 1;
            if *self.output_coords.get_unchecked(inner_loop_axis) < self.inner_loop_output_range.end
            {
                self.input_center_offset += self.inner_loop_input_full_stride;
                self.output_offset += self.inner_loop_output_stride;
            } else {
                *self.output_coords.get_unchecked_mut(inner_loop_axis) =
                    self.inner_loop_output_range.start;
                self.next_non_inner_axis();
            }
        }
    }

    /// True once the zone is exhausted.
    pub fn done(&self) -> bool {
        self.done
    }
}
487
/// Cursor over the whole output space, crossing zone boundaries as it goes.
/// Maintains coordinates, the current zone and linearized offsets
/// incrementally.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scanner<'p> {
    pub patch: &'p Patch,
    /// Index of the current zone in `patch.zones`.
    pub zone_id: usize,
    /// Per-axis region index of the current zone.
    pub zone_coords: TVec<usize>,
    pub zone: &'p Zone,
    /// Linearized output offset of the current point.
    pub output_offset: isize,
    /// Current output coordinates.
    pub output_coords: TVec<usize>,
    /// Current window-center input coordinates (output coords x strides).
    pub input_coords: TVec<usize>,
    /// Linearized input offset of the current window center.
    pub input_center_offset: isize,
    done: bool,
}
500
impl<'p> Scanner<'p> {
    /// Creates a scanner on the first zone, at output coordinate 0..0.
    /// Caller must ensure `patch.zones` is non-empty.
    fn new(patch: &'p Patch) -> Scanner<'p> {
        let rank = patch.rank();
        let zone = &patch.zones[0];
        Scanner {
            patch,
            zone_coords: tvec!(0; rank),
            zone,
            zone_id: 0,
            output_offset: 0,
            input_center_offset: 0,
            input_coords: tvec!(0; rank),
            output_coords: tvec!(0; rank),
            done: false,
        }
    }

    /// Number of valid (non-padding) kernel elements at the current point.
    #[inline]
    pub fn valid_count(&self) -> usize {
        self.zone.values_offsets.len()
    }

    /// Input offsets of the valid kernel elements at the current point.
    #[inline]
    pub fn valid_offsets(&self) -> impl Iterator<Item = isize> + '_ {
        self.zone.values_offsets.iter().map(move |pair| pair.1 + self.input_center_offset)
    }

    /// (kernel element index, input offset) pairs valid at the current point.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Advances one output point in row-major order, switching zone when a
    /// zone edge is crossed and recomputing offsets from scratch when a
    /// non-inner axis carries.
    #[inline]
    pub fn next(&mut self) {
        let rank = self.patch.rank();
        let inner_dim = rank - 1;
        unsafe {
            // Step the innermost axis.
            *self.output_coords.get_unchecked_mut(inner_dim) += 1;
            *self.input_coords.get_unchecked_mut(inner_dim) +=
                *self.patch.spec.strides.get_unchecked(inner_dim);
            self.output_offset += self.patch.spec.output_inner_stride as isize;
            self.input_center_offset +=
                self.patch.op_strides_times_input_storage_strides.get_unchecked(inner_dim);
            // Still inside the current zone: done.
            if *self.output_coords.get_unchecked(inner_dim)
                < self.zone.output_ranges.get_unchecked(inner_dim).end
            {
                return;
            }
            // Crossed a zone edge but not the output edge: move to the next
            // zone along the inner axis.
            if self.output_coords.get_unchecked(inner_dim)
                < self.patch.output_shape.get_unchecked(inner_dim)
            {
                self.zone_id += 1;
                *self.zone_coords.get_unchecked_mut(inner_dim) += 1;
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            } else {
                // End of an output row: odometer carry over the outer axes.
                for axis in (0..rank - 1).rev() {
                    *self.output_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.input_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.output_coords.get_unchecked_mut(axis) += 1;
                    *self.input_coords.get_unchecked_mut(axis) +=
                        self.patch.spec.strides.get_unchecked(axis);
                    *self.zone_coords.get_unchecked_mut(axis + 1) = 0;
                    if *self.output_coords.get_unchecked(axis)
                        == self.zone.output_ranges.get_unchecked(axis).end
                    {
                        *self.zone_coords.get_unchecked_mut(axis) += 1;
                    }
                    if *self.output_coords.get_unchecked(axis)
                        < *self.patch.output_shape.get_unchecked(axis)
                    {
                        break;
                    }
                }
                if self.output_coords.get_unchecked(0) == self.patch.output_shape.get_unchecked(0) {
                    self.done = true;
                    return;
                }
                // Rebuild zone id and input center offset from coordinates.
                self.zone_id = 0;
                self.input_center_offset = 0;
                for i in 0..rank {
                    self.zone_id += *self.zone_coords.get_unchecked(i)
                        * *self.patch.zone_strides.get_unchecked(i) as usize;
                    self.input_center_offset += *self.input_coords.get_unchecked(i) as isize
                        * *self.patch.input_storage_strides.get_unchecked(i);
                }
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            }
        }
    }

    /// True once the whole output space is exhausted.
    pub fn done(&self) -> bool {
        self.done
    }
}
595
/// Iterator over the input offsets of a single window position.
#[derive(Debug)]
pub enum PatchIterator<'p> {
    /// Fast path: window fully in bounds, offsets yielded unconditionally.
    Fast(FastPatchIterator<'p>),
    /// Slow path: each kernel position is bounds-checked; padding yields
    /// `None`.
    Safe(SafePatchIterator<'p>),
}
601
602impl Iterator for PatchIterator<'_> {
603 type Item = Option<isize>;
604 #[inline(always)]
605 fn next(&mut self) -> Option<Option<isize>> {
606 match self {
607 PatchIterator::Fast(ref mut it) => it.next(),
608 PatchIterator::Safe(ref mut it) => it.next(),
609 }
610 }
611}
612
/// Offset iterator for a window known to be fully inside the input.
#[derive(Debug)]
pub struct FastPatchIterator<'p> {
    patch: &'p Patch,
    /// Linearized input offset of the window center.
    center: isize,
    /// Next kernel element index.
    item: usize,
}
619
impl Iterator for FastPatchIterator<'_> {
    type Item = Option<isize>;
    /// Yields the input offset of each kernel element. The window is known
    /// valid, so the inner `Option` is always `Some`.
    #[inline(always)]
    fn next(&mut self) -> Option<Option<isize>> {
        if self.item == self.patch.standard_layout_data_field.len() {
            return None;
        }
        // SAFETY: `item` was just checked against the field length.
        unsafe {
            let position =
                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
            self.item += 1;
            Some(Some(position))
        }
    }
}
635
/// Offset iterator for a window that may overlap padding: each kernel
/// position is bounds-checked against the input shape.
#[derive(Debug)]
pub struct SafePatchIterator<'p> {
    patch: &'p Patch,
    /// Next kernel element index.
    item: usize,
    /// Window center in input coordinates (output coords x strides).
    input_patch_center: TVec<usize>,
    /// Linearized input offset of the window center.
    center: isize,
}
643
impl Iterator for SafePatchIterator<'_> {
    type Item = Option<isize>;
    /// Yields `Some(offset)` for in-bounds kernel positions and `None` for
    /// positions falling into the padding area.
    fn next(&mut self) -> Option<Option<isize>> {
        unsafe {
            if self.item == self.patch.standard_layout_data_field.len() {
                return None;
            }
            let input_shape = &self.patch.spec.input_shape;
            // Row `item` of data_field: per-axis offsets of this kernel
            // element. NOTE(review): assumes data_field is standard layout —
            // true as built by into_patch (from_shape_vec).
            let img_offset = self.patch.data_field.as_ptr().add(self.item * input_shape.len());

            for ix in 0..input_shape.len() {
                let pos = *self.input_patch_center.get_unchecked(ix) as isize + *img_offset.add(ix);
                if pos < 0 || pos as usize >= *input_shape.get_unchecked(ix) {
                    self.item += 1;
                    return Some(None);
                }
            }
            let position =
                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
            self.item += 1;
            Some(Some(position))
        }
    }
}
668
669#[cfg(test)]
670pub mod test {
671 use super::*;
672 use crate::ops::nn::DataFormat::*;
673 use proptest::prelude::*;
674 use proptest::*;
675
676 fn compute_output_spatial_dim(
677 input: usize,
678 dilation: usize,
679 kdim: usize,
680 pad_before: usize,
681 bad_after: usize,
682 stride: usize,
683 ) -> usize {
684 let patch = PatchSpec::for_full_shape(NCHW, &[1, 1, input])
685 .unwrap()
686 .with_dilations(tvec!(dilation))
687 .with_kernel_shape(tvec!(kdim))
688 .with_padding(PaddingSpec::ExplicitOnnxPool(tvec![pad_before], tvec![bad_after], true))
689 .with_strides(tvec![stride])
690 .into_patch();
691 patch.output_shape[0]
692 }
693
    // 5-wide input, kernel 3, stride 1, no padding -> 3 output positions.
    #[test]
    fn basic() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
    }
698
    // 7-wide input, kernel 3, stride 2, no padding -> 3 output positions.
    #[test]
    fn strides() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
    }
703
    // 5-wide input, kernel 3, pad 1 each side -> output size preserved.
    #[test]
    fn padding() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
    }
708
    // 7-wide input, kernel 3, pad 1 each side, stride 2 -> 4 positions.
    #[test]
    fn strides_and_padding() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
    }
713
    /// Builds the data field of a kernel/dilation pair on a 10^rank input
    /// (large enough that padding never kicks in with PaddingSpec::Valid).
    fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
        let patch =
            PatchSpec::for_data_shape(NCHW.from_n_c_hw(1, 1, tvec![10; kdim.len()]).unwrap())
                .with_dilations(dilations.into())
                .with_kernel_shape(kdim.into())
                .with_strides(tvec![1; kdim.len()])
                .into_patch();
        patch.data_field
    }
723
    // Data field rows are kernel element offsets, scaled by dilation.
    #[test]
    fn test_field() {
        assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
        assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
        assert_eq!(field(&[2, 2], &[1, 1]), arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]]));
        assert_eq!(field(&[2, 2], &[2, 1]), arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]]));
    }
731
    /// Strategy producing a random f32 tensor of the given shape, with
    /// i8-valued entries (exactly representable, so sums compare exactly).
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let len = shape.iter().product::<usize>();
        let shape = shape.to_vec();
        proptest::collection::vec(any::<i8>().prop_map(|i| i as f32), len..=len)
            .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor())
            .boxed()
    }
739
    /// A randomized sum-pooling problem: a resolved patch, a matching input
    /// tensor and the data format the shapes are expressed in.
    #[derive(Debug)]
    struct Problem {
        patch: Patch,
        input: Tensor,
        data_format: DataFormat,
    }
746
    impl Arbitrary for Problem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Problem>;
        // Strategy: draw (format, dilations, channels, kernel, padding,
        // strides), then pick an input h/w large enough for the dilated
        // kernel, then a random tensor of the resulting full shape.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                prop_oneof!(Just(NCHW), Just(NHWC)),
                (1usize..3, 1usize..3),
                1usize..3,
                (1usize..3, 1usize..3),
                prop_oneof![
                    Just(PaddingSpec::Valid),
                    Just(PaddingSpec::SameLower),
                    Just(PaddingSpec::SameUpper)
                ],
                (1usize..4, 1usize..4),
            )
                .prop_flat_map(|p| {
                    let dil = p.1;
                    let ks = p.3;
                    let strides = p.5;
                    // Smallest input fitting one dilated kernel window.
                    let min_size: (usize, usize) = (1 + (ks.0 - 1) * dil.0, 1 + (ks.1 - 1) * dil.1);
                    (
                        Just(p),
                        (min_size.0..min_size.0 + strides.0 * 3),
                        (min_size.1..min_size.1 + strides.1 * 3),
                    )
                })
                .prop_flat_map(|(p, h, w)| {
                    let input_shape = p.0.from_n_c_hw(1, p.2, [h, w]).unwrap();
                    let input = tensor(&input_shape.shape);
                    (Just(p), input)
                })
                .prop_map(|((fmt, dil, c, ks, pad, strides), input)| {
                    // NHWC interleaves channels: output inner stride is c.
                    let output_inner_stride = if fmt.c_is_last() { c } else { 1 };
                    Problem {
                        patch: PatchSpec::for_full_shape(fmt, input.shape())
                            .unwrap()
                            .with_dilations(tvec!(dil.0, dil.1))
                            .with_kernel_shape(tvec!(ks.0, ks.1))
                            .with_padding(pad)
                            .with_strides(tvec![strides.0, strides.1])
                            .with_output_inner_stride(output_inner_stride)
                            .into_patch(),
                        input,
                        data_format: fmt,
                    }
                })
                .boxed()
        }
    }
797
    impl Problem {
        /// Input shape viewed through the problem's data format.
        fn input_shape(&self) -> DataShape {
            self.data_format.shape(self.input.shape().into()).unwrap()
        }

        /// Expected output shape: same n and c, spatial dims from the patch.
        fn output_shape(&self) -> DataShape {
            self.data_format
                .from_n_c_hw(
                    self.input_shape().n().cloned().unwrap_or(1),
                    *self.input_shape().c(),
                    &*self.patch.output_shape,
                )
                .unwrap()
        }

        /// Naive sum-pooling ground truth: iterates every output point and
        /// kernel position, skipping out-of-bounds (padding) reads.
        fn reference_sumpool(&self) -> Tensor {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
                for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
                    // Input coordinate read by this (output, kernel) pair.
                    let geo_in: TVec<isize> = izip!(
                        geo_out.slice(),
                        geo_ker.slice(),
                        &self.patch.spec.strides,
                        &self.patch.spec.dilations,
                        &self.patch.pad_before
                    )
                    .map(|(o, k, s, d, p)| (o * s + k * d) as isize - *p as isize)
                    .collect();
                    // Skip reads that fall into the padding area.
                    if izip!(&geo_in, input_shape.hw_dims())
                        .any(|(g, i)| *g >= *i as isize || *g < 0)
                    {
                        continue;
                    }
                    let geo_in: TVec<usize> = geo_in.into_iter().map(|x| x as usize).collect();
                    for c in 0..*output_shape.c() {
                        let ocoords = self.data_format.from_n_c_hw(0, c, geo_out.slice()).unwrap();
                        let icoords = self.data_format.from_n_c_hw(0, c, &geo_in).unwrap();
                        output.to_array_view_mut::<f32>().unwrap()[&*ocoords.shape] +=
                            self.input.to_array_view::<f32>().unwrap()[&*icoords.shape];
                    }
                }
            }
            output
        }

        /// Sum-pools through `Patch::visit_output` and compares against the
        /// naive reference.
        fn check_visitor(&self) {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            self.patch.visit_output(|visitor| {
                for (_k, offset_in) in visitor.valid_offsets_ker_in() {
                    for c in 0..*output_shape.c() {
                        output.as_slice_mut::<f32>().unwrap()
                            [visitor.output_offset as usize + c * output_shape.c_stride()] +=
                            self.input.as_slice::<f32>().unwrap()
                                [offset_in as usize + c * input_shape.c_stride()];
                    }
                }
            });
            assert_eq!(output, self.reference_sumpool());
        }

        /// Sum-pools through the per-zone visitors and compares against the
        /// naive reference.
        fn check_zone_visitor(&self) {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for zone in &self.patch.zones {
                zone.visit_output(&self.patch, |visitor| {
                    for (_k, offset_in) in visitor.valid_offsets_ker_in() {
                        for c in 0..*output_shape.c() {
                            output.as_slice_mut::<f32>().unwrap()
                                [visitor.output_offset as usize + c * output_shape.c_stride()] +=
                                self.input.as_slice::<f32>().unwrap()
                                    [offset_in as usize + c * input_shape.c_stride()];
                        }
                    }
                });
            }
            assert_eq!(output, self.reference_sumpool());
        }

        /// Checks that the valid/invalid output zones partition the output:
        /// a point is inside the valid zone iff `is_valid` agrees, and lies
        /// in exactly one invalid zone otherwise.
        fn check_zoning(&self) {
            // Zone membership test on the spatial dims of full coords.
            fn in_zone(full_coords: &[usize], h_axis: usize, zone: &[Range<usize>]) -> bool {
                for a in 0..zone.len() {
                    if full_coords[h_axis + a] < zone[a].start
                        || full_coords[h_axis + a] >= zone[a].end
                    {
                        return false;
                    }
                }
                true
            }

            let valid_zone = &self.patch.valid_output_zone;
            let invalid_zones = &self.patch.invalid_output_zones;
            let output_full_shape = self.output_shape();
            let h_axis = self.input_shape().h_axis();
            for coords in ndarray::indices(&*output_full_shape.shape) {
                let inside_valid = in_zone(coords.slice(), h_axis, valid_zone);
                let invalid_count =
                    invalid_zones.iter().filter(|z| in_zone(coords.slice(), h_axis, z)).count();
                // SAFETY: the sliced coords have exactly the patch rank.
                unsafe {
                    assert_eq!(
                        inside_valid,
                        self.patch.is_valid(&coords.slice()[self.input_shape().hw_axes()]),
                        "coords {:?}, valid_zone: {:?} inside_valid: {:?}",
                        coords.slice(),
                        valid_zone,
                        inside_valid
                    );
                }
                if inside_valid {
                    assert_eq!(invalid_count, 0);
                } else {
                    assert_eq!(
                        invalid_count,
                        1,
                        "coords {:?}, valid_zone: {:?} inside_valid: {:?} invalid_zones: {:?}",
                        coords.slice(),
                        valid_zone,
                        inside_valid,
                        invalid_zones
                    );
                }
            }
        }
    }
927
    proptest! {
        // Property tests: the scanner-based visitors and the zone
        // decomposition must all agree with the naive reference sum-pool.
        #[test]
        fn test_visitor(pb in any::<Problem>()) {
            pb.check_visitor();
        }

        #[test]
        fn test_zone_visitor(pb in any::<Problem>()) {
            pb.check_zone_visitor();
        }

        #[test]
        fn test_zoning(pb in any::<Problem>()) {
            pb.check_zoning();
        }
    }
944
    // Regression: asymmetric kernel with SameLower padding and mixed strides.
    #[test]
    fn test_visitor_1() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
        let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
        let patch = PatchSpec::for_data_shape(input_shape.clone())
            .with_kernel_shape(tvec![2, 1])
            .with_padding(PaddingSpec::SameLower)
            .with_strides(tvec![1, 2])
            .into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }
956
    // Regression: NCHW with 2 channels and a non-trivial output inner stride.
    #[test]
    fn test_visitor_2() {
        let input_shape = NCHW.from_n_c_hw(1, 2, [1, 1]).unwrap();
        let input = tensor4(&[[[[0.]], [[1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch =
            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }
966
    // Regression: NHWC (channels last) with output inner stride of 2.
    #[test]
    fn test_visitor_3() {
        let input_shape = NHWC.from_n_c_hw(1, 2, [2, 1]).unwrap();
        let input = tensor4(&[[[[0., 0.]], [[1., 0f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch =
            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }
976
    // Regression: 1x2 kernel with SameLower padding on a 1x2 input.
    #[test]
    fn test_visitor_4() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
        let input = tensor4(&[[[[0., 1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone())
            .with_kernel_shape(tvec!(1, 2))
            .with_output_inner_stride(1)
            .with_padding(PaddingSpec::SameLower)
            .into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
    }
989
    // Regression: trivial 1x1 kernel on a 2x1 input, per-zone visitor.
    #[test]
    fn test_zone_visitor_1() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 1]).unwrap();
        let input = tensor4(&[[[[0.], [1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
    }
998
    // Regression: trivial 1x1 kernel on a 1x2 input, per-zone visitor.
    #[test]
    fn test_zone_visitor_2() {
        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
        let input = tensor4(&[[[[0., 1f32]]]]);
        assert_eq!(input.shape(), &*input_shape.shape);
        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
    }
1007}