Skip to main content

tract_core/ops/cnn/
patch_axis.rs

1use crate::internal::*;
2
3use std::ops::Range;
4use tract_itertools::Itertools;
5
6#[derive(Clone, Debug, new, PartialEq, Eq)]
7pub struct Region {
8    pub range: Range<usize>,
9    pub mask: Option<TVec<bool>>,
10}
11
12#[derive(Clone, Debug, new, PartialEq, Eq)]
13pub struct PatchAxis {
14    pub input_dim: usize,
15    pub kernel_dim: usize,
16    pub pad_before: usize,
17    pub pad_after: usize,
18    pub output_dim: usize,
19    pub stride: usize,
20    pub dilation: usize,
21}
22
23impl PatchAxis {
24    fn valid_range(&self) -> Option<Range<usize>> {
25        let field = (self.kernel_dim - 1) * self.dilation + 1;
26        if field > self.input_dim {
27            return None;
28        }
29        let min = self.pad_before.divceil(self.stride);
30        let max = (self.input_dim + self.pad_before).saturating_sub(field) / self.stride;
31        if max >= min { Some(min..(max + 1)) } else { None }
32    }
33
34    fn invalid_at_left(&self, pos: usize) -> usize {
35        let center_pos = pos * self.stride;
36        self.pad_before.saturating_sub(center_pos).divceil(self.dilation).min(self.kernel_dim)
37    }
38
39    fn invalid_at_right(&self, pos: usize) -> usize {
40        let center_pos = pos * self.stride;
41        let last_valid = self.input_dim + self.pad_before;
42        let valid = last_valid.saturating_sub(center_pos).divceil(self.dilation);
43        self.kernel_dim.saturating_sub(valid)
44    }
45
46    fn make_invalid_regions(&self, range: Range<usize>) -> TVec<Region> {
47        range
48            .map(move |ix| (ix, (self.invalid_at_left(ix), self.invalid_at_right(ix))))
49            .group_by(|&pair| pair.1)
50            .into_iter()
51            .map(move |(invalid, pairs)| {
52                let (min, max) = pairs.map(|p| p.0).minmax().into_option().unwrap();
53                let mut mask = tvec!(false; self.kernel_dim);
54                for i in 0..invalid.0 {
55                    mask[i] = true;
56                }
57                for i in 0..invalid.1 {
58                    mask[self.kernel_dim - 1 - i] = true;
59                }
60                Region::new(min..max + 1, Some(mask))
61            })
62            .collect()
63    }
64
65    pub fn regions(&self) -> TVec<Region> {
66        let mut regions = tvec!();
67        if let Some(valid_range) = self.valid_range() {
68            if valid_range.start > 0 {
69                regions.extend(self.make_invalid_regions(0..valid_range.start));
70            }
71            if valid_range.start != valid_range.end {
72                regions.push(Region::new(valid_range.clone(), None));
73            }
74            if valid_range.end < self.output_dim {
75                regions.extend(self.make_invalid_regions(valid_range.end..self.output_dim));
76            }
77        } else {
78            regions.extend(self.make_invalid_regions(0..self.output_dim));
79        }
80        regions
81    }
82}
83
#[cfg(test)]
pub mod test {
    use super::*;

    // Fixtures. `PatchAxis::new` takes, in order: input_dim, kernel_dim, pad_before, pad_after,
    // output_dim, stride, dilation. In the diagrams `•` is a padding cell, the middle number is
    // the kernel size, and parenthesised entries are output positions whose window touches
    // padding.

    // • 0 1 2 3 4 • -> 3 -> (0) 1 2 3 (4)
    fn axis_5_3() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 5, 1, 1)
    }

    // • • 0 1 2 3 4 • -> 4 -> (0) (1) 2 3 (4)
    fn axis_5_4() -> PatchAxis {
        PatchAxis::new(5, 4, 2, 1, 5, 1, 1)
    }

    // • • 0 1 2 3 4 • • -> 5 -> (0) (1) 2 (3) (4)
    fn axis_5_5() -> PatchAxis {
        PatchAxis::new(5, 5, 2, 2, 5, 1, 1)
    }

    // • 0 1 2 3 4 • -> 3, stride 2 -> (0) 2 (4)
    fn axis_5_3_s2() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 3, 2, 1)
    }

    // • • 0 1 2 3 4 • • -> 3, dilation 2 -> (0) (1) 2 (3) (4)
    fn axis_5_3_d2() -> PatchAxis {
        PatchAxis::new(5, 3, 2, 2, 5, 1, 2)
    }

    // 0 1 2 3 4 5 6 7 8 9 -> 2, stride 3 -> 0 3 6
    fn axis_10_2_s3_valid() -> PatchAxis {
        PatchAxis::new(10, 2, 0, 0, 3, 3, 1)
    }

    /// Expected region whose windows read padding at the taps flagged `true` in `mask`.
    fn padded(range: std::ops::Range<usize>, mask: &[bool]) -> Region {
        Region::new(range, Some(mask.iter().copied().collect()))
    }

    /// Expected region whose windows lie entirely inside the input.
    fn inside(range: std::ops::Range<usize>) -> Region {
        Region::new(range, None)
    }

    #[test]
    fn axis_valid_ranges() {
        let cases = [
            (axis_5_3(), Some(1..4)),
            (axis_5_4(), Some(2..4)),
            (axis_5_5(), Some(2..3)),
            (axis_5_3_s2(), Some(1..2)),
            (axis_5_3_d2(), Some(2..3)),
        ];
        for (axis, expected) in cases.iter() {
            assert_eq!(axis.valid_range(), *expected, "{:?}", axis);
        }
    }

    #[test]
    fn axis_invalid_at_left() {
        // Expected counts for output positions 0, 1 and 2.
        let cases = [
            (axis_5_3(), [1, 0, 0]),
            (axis_5_4(), [2, 1, 0]),
            (axis_5_5(), [2, 1, 0]),
            (axis_5_3_d2(), [1, 1, 0]),
        ];
        for (axis, expected) in cases.iter() {
            for (pos, want) in expected.iter().enumerate() {
                assert_eq!(axis.invalid_at_left(pos), *want, "{:?} at {}", axis, pos);
            }
        }
    }

    #[test]
    fn axis_invalid_at_right() {
        // (output position, expected count) pairs.
        let cases = [
            (axis_5_3(), [(0, 0), (3, 0), (4, 1)]),
            (axis_5_4(), [(0, 0), (3, 0), (4, 1)]),
            (axis_5_5(), [(0, 0), (3, 1), (4, 2)]),
        ];
        for (axis, expected) in cases.iter() {
            for &(pos, want) in expected.iter() {
                assert_eq!(axis.invalid_at_right(pos), want, "{:?} at {}", axis, pos);
            }
        }
    }

    #[test]
    fn axis_5_3_regions() {
        assert_eq!(
            axis_5_3().regions(),
            tvec!(
                padded(0..1, &[true, false, false]),
                inside(1..4),
                padded(4..5, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_3_s2_regions() {
        assert_eq!(
            axis_5_3_s2().regions(),
            tvec!(
                padded(0..1, &[true, false, false]),
                inside(1..2),
                padded(2..3, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_3_d2_regions() {
        assert_eq!(
            axis_5_3_d2().regions(),
            tvec!(
                padded(0..2, &[true, false, false]),
                inside(2..3),
                padded(3..5, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_10_2_s3_valid_regions() {
        assert_eq!(axis_10_2_s3_valid().regions(), tvec!(inside(0..3)));
    }

    #[test]
    fn axis_7_3_s2_regions() {
        // • 0 1 2 3 4 5 6 • -> 3, stride 2 -> (0) 2 4 (6)
        let axis = PatchAxis::new(7, 3, 1, 1, 4, 2, 1);
        assert_eq!(
            axis.regions(),
            tvec!(
                padded(0..1, &[true, false, false]),
                inside(1..3),
                padded(3..4, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_2_s2_regions() {
        // • 0 1 2 3 4 • -> 2, stride 2 -> (0) 2 4
        let axis = PatchAxis::new(5, 2, 1, 1, 3, 2, 1);
        assert_eq!(axis.regions(), tvec!(padded(0..1, &[true, false]), inside(1..3)));
    }

    #[test]
    fn axis_28_3_very_padded_regions() {
        // • • 0 1 2 ... 26 27 • • -> 3 -> (0) (1) 2 3 ... 27 (28) (29)
        let axis = PatchAxis::new(28, 3, 2, 2, 30, 1, 1);
        assert_eq!(
            axis.regions(),
            tvec!(
                padded(0..1, &[true, true, false]),
                padded(1..2, &[true, false, false]),
                inside(2..28),
                padded(28..29, &[false, false, true]),
                padded(29..30, &[false, true, true]),
            )
        );
    }

    #[test]
    fn axis_7_1_s2_regions() {
        // 0 1 2 3 4 5 6 -> 1, stride 2 -> 0 2 4 6
        let axis = PatchAxis::new(7, 1, 0, 0, 4, 2, 1);
        assert_eq!(axis.regions(), tvec!(inside(0..4)));
    }

    #[test]
    fn axis_1_2_regions() {
        // 0 • -> 2 -> (0)
        let axis = PatchAxis::new(1, 2, 0, 1, 1, 1, 1);
        assert_eq!(axis.regions(), tvec!(padded(0..1, &[false, true])));
    }

    #[test]
    fn axis_dnn_left_pad() {
        // • • 0 -> 1 -> (0) (1) 2
        let axis = PatchAxis::new(1, 1, 2, 0, 3, 1, 1);
        assert_eq!(axis.regions(), tvec!(padded(0..2, &[true]), inside(2..3)));
    }
}