// tract_core/ops/cnn/patch_axis.rs

1use crate::internal::*;
2
3use std::ops::Range;
4use tract_itertools::Itertools;
5
6#[derive(Clone, Debug, new, PartialEq, Eq)]
7pub struct Region {
8    pub range: Range<usize>,
9    pub mask: Option<TVec<bool>>,
10}
11
12#[derive(Clone, Debug, new, PartialEq, Eq)]
13pub struct PatchAxis {
14    pub input_dim: usize,
15    pub kernel_dim: usize,
16    pub pad_before: usize,
17    pub pad_after: usize,
18    pub output_dim: usize,
19    pub stride: usize,
20    pub dilation: usize,
21}
22
23impl PatchAxis {
24    fn valid_range(&self) -> Option<Range<usize>> {
25        let field = (self.kernel_dim - 1) * self.dilation + 1;
26        if field > self.input_dim {
27            return None;
28        }
29        let min = self.pad_before.divceil(self.stride);
30        let max = (self.input_dim + self.pad_before).saturating_sub(field) / self.stride;
31        if max >= min {
32            Some(min..(max + 1))
33        } else {
34            None
35        }
36    }
37
38    fn invalid_at_left(&self, pos: usize) -> usize {
39        let center_pos = pos * self.stride;
40        self.pad_before.saturating_sub(center_pos).divceil(self.dilation).min(self.kernel_dim)
41    }
42
43    fn invalid_at_right(&self, pos: usize) -> usize {
44        let center_pos = pos * self.stride;
45        let last_valid = self.input_dim + self.pad_before;
46        let valid = last_valid.saturating_sub(center_pos).divceil(self.dilation);
47        self.kernel_dim.saturating_sub(valid)
48    }
49
50    fn make_invalid_regions(&self, range: Range<usize>) -> TVec<Region> {
51        range
52            .map(move |ix| (ix, (self.invalid_at_left(ix), self.invalid_at_right(ix))))
53            .group_by(|&pair| pair.1)
54            .into_iter()
55            .map(move |(invalid, pairs)| {
56                let (min, max) = pairs.map(|p| p.0).minmax().into_option().unwrap();
57                let mut mask = tvec!(false; self.kernel_dim);
58                for i in 0..invalid.0 {
59                    mask[i] = true;
60                }
61                for i in 0..invalid.1 {
62                    mask[self.kernel_dim - 1 - i] = true;
63                }
64                Region::new(min..max + 1, Some(mask))
65            })
66            .collect()
67    }
68
69    pub fn regions(&self) -> TVec<Region> {
70        let mut regions = tvec!();
71        if let Some(valid_range) = self.valid_range() {
72            if valid_range.start > 0 {
73                regions.extend(self.make_invalid_regions(0..valid_range.start));
74            }
75            if valid_range.start != valid_range.end {
76                regions.push(Region::new(valid_range.clone(), None));
77            }
78            if valid_range.end < self.output_dim {
79                regions.extend(self.make_invalid_regions(valid_range.end..self.output_dim));
80            }
81        } else {
82            regions.extend(self.make_invalid_regions(0..self.output_dim));
83        }
84        regions
85    }
86}
87
#[cfg(test)]
pub mod test {
    use super::*;

    // Axis sketches: `•` is a padding element, numbers are input positions.
    // After `-> kernel ->` come the output windows, labelled by the input
    // position each window is anchored on, parenthesised when at least one of
    // its taps reads padding.

    // • 0 1 2 3 4 • -> 3 -> (0) 1 2 3 (4)
    fn axis_5_3() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 5, 1, 1)
    }

    // • • 0 1 2 3 4 • -> 4 -> (0) (1) 2 3 (4)
    fn axis_5_4() -> PatchAxis {
        PatchAxis::new(5, 4, 2, 1, 5, 1, 1)
    }

    // • • 0 1 2 3 4 • • -> 5 -> (0) (1) 2 (3) (4)
    fn axis_5_5() -> PatchAxis {
        PatchAxis::new(5, 5, 2, 2, 5, 1, 1)
    }

    // • 0 1 2 3 4 • -> 3, stride 2 -> (0) 2 (4)
    fn axis_5_3_s2() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 3, 2, 1)
    }

    // • • 0 1 2 3 4 • • -> 3, dilation 2 -> (0) (1) 2 (3) (4)
    fn axis_5_3_d2() -> PatchAxis {
        PatchAxis::new(5, 3, 2, 2, 5, 1, 2)
    }

    // 0 1 2 3 4 5 6 7 8 9 -> 2, stride 3 -> 0 3 6
    fn axis_10_2_s3_valid() -> PatchAxis {
        PatchAxis::new(10, 2, 0, 0, 3, 3, 1)
    }

    #[test]
    fn axis_valid_ranges() {
        assert_eq!(axis_5_3().valid_range(), Some(1..4));
        assert_eq!(axis_5_4().valid_range(), Some(2..4));
        assert_eq!(axis_5_5().valid_range(), Some(2..3));
        assert_eq!(axis_5_3_s2().valid_range(), Some(1..2));
        assert_eq!(axis_5_3_d2().valid_range(), Some(2..3));
        assert_eq!(axis_10_2_s3_valid().valid_range(), Some(0..3));
    }

    #[test]
    fn axis_invalid_at_left() {
        assert_eq!(axis_5_3().invalid_at_left(0), 1);
        assert_eq!(axis_5_3().invalid_at_left(1), 0);
        assert_eq!(axis_5_3().invalid_at_left(2), 0);

        assert_eq!(axis_5_4().invalid_at_left(0), 2);
        assert_eq!(axis_5_4().invalid_at_left(1), 1);
        assert_eq!(axis_5_4().invalid_at_left(2), 0);

        assert_eq!(axis_5_5().invalid_at_left(0), 2);
        assert_eq!(axis_5_5().invalid_at_left(1), 1);
        assert_eq!(axis_5_5().invalid_at_left(2), 0);

        assert_eq!(axis_5_3_s2().invalid_at_left(0), 1);
        assert_eq!(axis_5_3_s2().invalid_at_left(1), 0);

        assert_eq!(axis_5_3_d2().invalid_at_left(0), 1);
        assert_eq!(axis_5_3_d2().invalid_at_left(1), 1);
        assert_eq!(axis_5_3_d2().invalid_at_left(2), 0);
    }

    #[test]
    fn axis_invalid_at_right() {
        assert_eq!(axis_5_3().invalid_at_right(0), 0);
        assert_eq!(axis_5_3().invalid_at_right(3), 0);
        assert_eq!(axis_5_3().invalid_at_right(4), 1);

        assert_eq!(axis_5_4().invalid_at_right(0), 0);
        assert_eq!(axis_5_4().invalid_at_right(3), 0);
        assert_eq!(axis_5_4().invalid_at_right(4), 1);

        assert_eq!(axis_5_5().invalid_at_right(0), 0);
        assert_eq!(axis_5_5().invalid_at_right(3), 1);
        assert_eq!(axis_5_5().invalid_at_right(4), 2);

        assert_eq!(axis_5_3_s2().invalid_at_right(1), 0);
        assert_eq!(axis_5_3_s2().invalid_at_right(2), 1);

        assert_eq!(axis_5_3_d2().invalid_at_right(0), 0);
        assert_eq!(axis_5_3_d2().invalid_at_right(2), 0);
        assert_eq!(axis_5_3_d2().invalid_at_right(3), 1);
        assert_eq!(axis_5_3_d2().invalid_at_right(4), 1);
    }

    #[test]
    fn axis_5_3_regions() {
        let regions = axis_5_3().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..4, None),
                Region::new(4..5, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_4_regions() {
        let regions = axis_5_4().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, true, false, false))),
                Region::new(1..2, Some(tvec!(true, false, false, false))),
                Region::new(2..4, None),
                Region::new(4..5, Some(tvec!(false, false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_5_regions() {
        let regions = axis_5_5().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, true, false, false, false))),
                Region::new(1..2, Some(tvec!(true, false, false, false, false))),
                Region::new(2..3, None),
                Region::new(3..4, Some(tvec!(false, false, false, false, true))),
                Region::new(4..5, Some(tvec!(false, false, false, true, true)))
            )
        );
    }

    #[test]
    fn axis_5_3_s2_regions() {
        let regions = axis_5_3_s2().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..2, None),
                Region::new(2..3, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_3_d2_regions() {
        let regions = axis_5_3_d2().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..2, Some(tvec!(true, false, false))),
                Region::new(2..3, None),
                Region::new(3..5, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_10_2_s3_valid_regions() {
        let regions = axis_10_2_s3_valid().regions();
        assert_eq!(regions, tvec!(Region::new(0..3, None)));
    }

    #[test]
    fn axis_7_3_s2_regions() {
        // • 0 1 2 3 4 5 6 • -> 3, stride 2 -> (0) 2 4 (6)
        let regions = PatchAxis::new(7, 3, 1, 1, 4, 2, 1).regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..3, None),
                Region::new(3..4, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_2_s2_regions() {
        // • 0 1 2 3 4 • -> 2, stride 2 -> (0) 2 4
        let regions = PatchAxis::new(5, 2, 1, 1, 3, 2, 1).regions();
        assert_eq!(
            regions,
            tvec!(Region::new(0..1, Some(tvec!(true, false))), Region::new(1..3, None))
        );
    }

    #[test]
    fn axis_28_3_very_padded_regions() {
        // • • 0 1 2 3 ... 26 27 • • -> 3 -> (-1) (0) 1 2 ... 26 (27) (28)
        // i.e. 30 outputs: the first two and last two windows touch padding.
        let regions = PatchAxis::new(28, 3, 2, 2, 30, 1, 1).regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, true, false))),
                Region::new(1..2, Some(tvec!(true, false, false))),
                Region::new(2..28, None),
                Region::new(28..29, Some(tvec!(false, false, true))),
                Region::new(29..30, Some(tvec!(false, true, true))),
            )
        );
    }

    #[test]
    fn axis_7_1_s2_regions() {
        // 0 1 2 3 4 5 6 -> 1, stride 2 -> 0 2 4 6
        let regions = PatchAxis::new(7, 1, 0, 0, 4, 2, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..4, None)));
    }

    #[test]
    fn axis_1_2_regions() {
        // 0 • -> 2 -> (0)
        let regions = PatchAxis::new(1, 2, 0, 1, 1, 1, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..1, Some(tvec!(false, true)))));
    }

    #[test]
    fn axis_dnn_left_pad() {
        // • • 0 -> 1 -> (•) (•) 0 : the first two windows read only padding.
        let regions = PatchAxis::new(1, 1, 2, 0, 3, 1, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..2, Some(tvec!(true))), Region::new(2..3, None)));
    }
}