1use crate::internal::*;
2
3use std::ops::Range;
4use tract_itertools::Itertools;
5
/// A contiguous run of output positions along one axis that share the same
/// kernel-tap validity pattern.
#[derive(Clone, Debug, new, PartialEq, Eq)]
pub struct Region {
    /// Half-open range of output positions covered by this region.
    pub range: Range<usize>,
    /// `None` for the region produced from `PatchAxis::valid_range` (no tap is
    /// flagged). Otherwise holds one flag per kernel tap, where `true` marks a
    /// tap counted as invalid (reading before or after the real input) for
    /// every position in `range`.
    pub mask: Option<TVec<bool>>,
}
11
/// Geometry of a sliding kernel window along a single axis.
///
/// NOTE(review): the `new` derive is expected to generate a constructor whose
/// arguments follow the field declaration order below, so reordering fields
/// would silently change the meaning of positional `PatchAxis::new(..)` calls.
#[derive(Clone, Debug, new, PartialEq, Eq)]
pub struct PatchAxis {
    /// Number of real (unpadded) input elements along the axis.
    pub input_dim: usize,
    /// Number of kernel taps along the axis.
    pub kernel_dim: usize,
    /// Padding added before the first input element.
    pub pad_before: usize,
    /// Padding added after the last input element (not read by the region
    /// computations visible in this file).
    pub pad_after: usize,
    /// Number of output positions along the axis.
    pub output_dim: usize,
    /// Step, in padded-input elements, between two consecutive output positions.
    pub stride: usize,
    /// Step, in padded-input elements, between two consecutive kernel taps.
    pub dilation: usize,
}
22
23impl PatchAxis {
24 fn valid_range(&self) -> Option<Range<usize>> {
25 let field = (self.kernel_dim - 1) * self.dilation + 1;
26 if field > self.input_dim {
27 return None;
28 }
29 let min = self.pad_before.divceil(self.stride);
30 let max = (self.input_dim + self.pad_before).saturating_sub(field) / self.stride;
31 if max >= min {
32 Some(min..(max + 1))
33 } else {
34 None
35 }
36 }
37
38 fn invalid_at_left(&self, pos: usize) -> usize {
39 let center_pos = pos * self.stride;
40 self.pad_before.saturating_sub(center_pos).divceil(self.dilation).min(self.kernel_dim)
41 }
42
43 fn invalid_at_right(&self, pos: usize) -> usize {
44 let center_pos = pos * self.stride;
45 let last_valid = self.input_dim + self.pad_before;
46 let valid = last_valid.saturating_sub(center_pos).divceil(self.dilation);
47 self.kernel_dim.saturating_sub(valid)
48 }
49
50 fn make_invalid_regions(&self, range: Range<usize>) -> TVec<Region> {
51 range
52 .map(move |ix| (ix, (self.invalid_at_left(ix), self.invalid_at_right(ix))))
53 .group_by(|&pair| pair.1)
54 .into_iter()
55 .map(move |(invalid, pairs)| {
56 let (min, max) = pairs.map(|p| p.0).minmax().into_option().unwrap();
57 let mut mask = tvec!(false; self.kernel_dim);
58 for i in 0..invalid.0 {
59 mask[i] = true;
60 }
61 for i in 0..invalid.1 {
62 mask[self.kernel_dim - 1 - i] = true;
63 }
64 Region::new(min..max + 1, Some(mask))
65 })
66 .collect()
67 }
68
69 pub fn regions(&self) -> TVec<Region> {
70 let mut regions = tvec!();
71 if let Some(valid_range) = self.valid_range() {
72 if valid_range.start > 0 {
73 regions.extend(self.make_invalid_regions(0..valid_range.start));
74 }
75 if valid_range.start != valid_range.end {
76 regions.push(Region::new(valid_range.clone(), None));
77 }
78 if valid_range.end < self.output_dim {
79 regions.extend(self.make_invalid_regions(valid_range.end..self.output_dim));
80 }
81 } else {
82 regions.extend(self.make_invalid_regions(0..self.output_dim));
83 }
84 regions
85 }
86}
87
#[cfg(test)]
pub mod test {
    use super::*;

    // Named fixtures. The suffix reads `<input>_<kernel>`, with `sN` for a
    // stride of N and `dN` for a dilation of N.

    fn axis_5_3() -> PatchAxis {
        PatchAxis {
            input_dim: 5,
            kernel_dim: 3,
            pad_before: 1,
            pad_after: 1,
            output_dim: 5,
            stride: 1,
            dilation: 1,
        }
    }

    fn axis_5_4() -> PatchAxis {
        PatchAxis {
            input_dim: 5,
            kernel_dim: 4,
            pad_before: 2,
            pad_after: 1,
            output_dim: 5,
            stride: 1,
            dilation: 1,
        }
    }

    fn axis_5_5() -> PatchAxis {
        PatchAxis {
            input_dim: 5,
            kernel_dim: 5,
            pad_before: 2,
            pad_after: 2,
            output_dim: 5,
            stride: 1,
            dilation: 1,
        }
    }

    fn axis_5_3_s2() -> PatchAxis {
        PatchAxis {
            input_dim: 5,
            kernel_dim: 3,
            pad_before: 1,
            pad_after: 1,
            output_dim: 3,
            stride: 2,
            dilation: 1,
        }
    }

    fn axis_5_3_d2() -> PatchAxis {
        PatchAxis {
            input_dim: 5,
            kernel_dim: 3,
            pad_before: 2,
            pad_after: 2,
            output_dim: 5,
            stride: 1,
            dilation: 2,
        }
    }

    fn axis_10_2_s3_valid() -> PatchAxis {
        PatchAxis {
            input_dim: 10,
            kernel_dim: 2,
            pad_before: 0,
            pad_after: 0,
            output_dim: 3,
            stride: 3,
            dilation: 1,
        }
    }

    /// Expected region carrying no mask.
    fn unmasked(range: std::ops::Range<usize>) -> Region {
        Region::new(range, None)
    }

    /// Expected region carrying the given per-tap mask.
    fn masked(range: std::ops::Range<usize>, mask: &[bool]) -> Region {
        Region::new(range, Some(mask.iter().cloned().collect()))
    }

    #[test]
    fn axis_valid_ranges() {
        let cases = [
            (axis_5_3(), 1..4),
            (axis_5_4(), 2..4),
            (axis_5_5(), 2..3),
            (axis_5_3_s2(), 1..2),
            (axis_5_3_d2(), 2..3),
        ];
        for (axis, expected) in cases.iter() {
            assert_eq!(axis.valid_range(), Some(expected.clone()));
        }
    }

    #[test]
    fn axis_invalid_at_left() {
        // Expected counts for positions 0, 1 and 2 of each fixture.
        let cases = [
            (axis_5_3(), [1usize, 0, 0]),
            (axis_5_4(), [2, 1, 0]),
            (axis_5_5(), [2, 1, 0]),
            (axis_5_3_d2(), [1, 1, 0]),
        ];
        for (axis, expected) in cases.iter() {
            for (pos, &count) in expected.iter().enumerate() {
                assert_eq!(axis.invalid_at_left(pos), count);
            }
        }
    }

    #[test]
    fn axis_invalid_at_right() {
        // Each entry is (position, expected count).
        let cases = [
            (axis_5_3(), [(0usize, 0usize), (3, 0), (4, 1)]),
            (axis_5_4(), [(0, 0), (3, 0), (4, 1)]),
            (axis_5_5(), [(0, 0), (3, 1), (4, 2)]),
        ];
        for (axis, expected) in cases.iter() {
            for &(pos, count) in expected.iter() {
                assert_eq!(axis.invalid_at_right(pos), count);
            }
        }
    }

    #[test]
    fn axis_5_3_regions() {
        assert_eq!(
            axis_5_3().regions(),
            tvec!(
                masked(0..1, &[true, false, false]),
                unmasked(1..4),
                masked(4..5, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_3_s2_regions() {
        assert_eq!(
            axis_5_3_s2().regions(),
            tvec!(
                masked(0..1, &[true, false, false]),
                unmasked(1..2),
                masked(2..3, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_3_d2_regions() {
        assert_eq!(
            axis_5_3_d2().regions(),
            tvec!(
                masked(0..2, &[true, false, false]),
                unmasked(2..3),
                masked(3..5, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_10_2_s3_valid_regions() {
        assert_eq!(axis_10_2_s3_valid().regions(), tvec!(unmasked(0..3),));
    }

    #[test]
    fn axis_7_3_s2_regions() {
        let axis = PatchAxis::new(7, 3, 1, 1, 4, 2, 1);
        assert_eq!(
            axis.regions(),
            tvec!(
                masked(0..1, &[true, false, false]),
                unmasked(1..3),
                masked(3..4, &[false, false, true]),
            )
        );
    }

    #[test]
    fn axis_5_2_s2_regions() {
        let axis = PatchAxis::new(5, 2, 1, 1, 3, 2, 1);
        assert_eq!(axis.regions(), tvec!(masked(0..1, &[true, false]), unmasked(1..3),));
    }

    #[test]
    fn axis_28_3_very_padded_regions() {
        let axis = PatchAxis::new(28, 3, 2, 2, 30, 1, 1);
        assert_eq!(
            axis.regions(),
            tvec!(
                masked(0..1, &[true, true, false]),
                masked(1..2, &[true, false, false]),
                unmasked(2..28),
                masked(28..29, &[false, false, true]),
                masked(29..30, &[false, true, true]),
            )
        );
    }

    #[test]
    fn axis_7_1_s2_regions() {
        let axis = PatchAxis::new(7, 1, 0, 0, 4, 2, 1);
        assert_eq!(axis.regions(), tvec!(unmasked(0..4),));
    }

    #[test]
    fn axis_1_2_regions() {
        let axis = PatchAxis::new(1, 2, 0, 1, 1, 1, 1);
        assert_eq!(axis.regions(), tvec!(masked(0..1, &[false, true]),));
    }

    #[test]
    fn axis_dnn_left_pad() {
        let axis = PatchAxis::new(1, 1, 2, 0, 3, 1, 1);
        assert_eq!(axis.regions(), tvec!(masked(0..2, &[true]), unmasked(2..3)));
    }
}