1use crate::internal::*;
2
3use std::ops::Range;
4use tract_itertools::Itertools;
5
/// A contiguous run of output positions along one axis that share the same
/// kernel-tap validity pattern.
#[derive(Clone, Debug, new, PartialEq, Eq)]
pub struct Region {
    /// Output positions covered by this region.
    pub range: Range<usize>,
    /// One flag per kernel tap, `true` where the tap lands outside the input
    /// (in padding). `None` when every tap of every position in `range` is
    /// inside the input.
    pub mask: Option<TVec<bool>>,
}
11
/// Geometry of a sliding-window (patch) operation along a single axis.
///
/// `regions()` splits the output positions into runs according to which
/// kernel taps fall inside the input and which fall in padding.
#[derive(Clone, Debug, new, PartialEq, Eq)]
pub struct PatchAxis {
    /// Length of the unpadded input along this axis.
    pub input_dim: usize,
    /// Number of kernel taps along this axis (assumed >= 1).
    pub kernel_dim: usize,
    /// Amount of padding placed before the input.
    pub pad_before: usize,
    /// Amount of padding placed after the input. Not read by the region
    /// computation: the right edge is derived from `input_dim + pad_before`.
    pub pad_after: usize,
    /// Number of output positions along this axis.
    pub output_dim: usize,
    /// Step, in padded input coordinates, between consecutive output positions.
    pub stride: usize,
    /// Spacing, in input cells, between consecutive kernel taps.
    pub dilation: usize,
}
22
23impl PatchAxis {
24 fn valid_range(&self) -> Option<Range<usize>> {
25 let field = (self.kernel_dim - 1) * self.dilation + 1;
26 if field > self.input_dim {
27 return None;
28 }
29 let min = self.pad_before.divceil(self.stride);
30 let max = (self.input_dim + self.pad_before).saturating_sub(field) / self.stride;
31 if max >= min { Some(min..(max + 1)) } else { None }
32 }
33
34 fn invalid_at_left(&self, pos: usize) -> usize {
35 let center_pos = pos * self.stride;
36 self.pad_before.saturating_sub(center_pos).divceil(self.dilation).min(self.kernel_dim)
37 }
38
39 fn invalid_at_right(&self, pos: usize) -> usize {
40 let center_pos = pos * self.stride;
41 let last_valid = self.input_dim + self.pad_before;
42 let valid = last_valid.saturating_sub(center_pos).divceil(self.dilation);
43 self.kernel_dim.saturating_sub(valid)
44 }
45
46 fn make_invalid_regions(&self, range: Range<usize>) -> TVec<Region> {
47 range
48 .map(move |ix| (ix, (self.invalid_at_left(ix), self.invalid_at_right(ix))))
49 .group_by(|&pair| pair.1)
50 .into_iter()
51 .map(move |(invalid, pairs)| {
52 let (min, max) = pairs.map(|p| p.0).minmax().into_option().unwrap();
53 let mut mask = tvec!(false; self.kernel_dim);
54 for i in 0..invalid.0 {
55 mask[i] = true;
56 }
57 for i in 0..invalid.1 {
58 mask[self.kernel_dim - 1 - i] = true;
59 }
60 Region::new(min..max + 1, Some(mask))
61 })
62 .collect()
63 }
64
65 pub fn regions(&self) -> TVec<Region> {
66 let mut regions = tvec!();
67 if let Some(valid_range) = self.valid_range() {
68 if valid_range.start > 0 {
69 regions.extend(self.make_invalid_regions(0..valid_range.start));
70 }
71 if valid_range.start != valid_range.end {
72 regions.push(Region::new(valid_range.clone(), None));
73 }
74 if valid_range.end < self.output_dim {
75 regions.extend(self.make_invalid_regions(valid_range.end..self.output_dim));
76 }
77 } else {
78 regions.extend(self.make_invalid_regions(0..self.output_dim));
79 }
80 regions
81 }
82}
83
#[cfg(test)]
pub mod test {
    use super::*;

    // Fixture arguments are, in order: input_dim, kernel_dim, pad_before,
    // pad_after, output_dim, stride, dilation.

    // 5 inputs, 3-tap kernel, 1 cell of padding on each side.
    fn axis_5_3() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 5, 1, 1)
    }

    // 5 inputs, 4-tap kernel, asymmetric padding (2 before, 1 after).
    fn axis_5_4() -> PatchAxis {
        PatchAxis::new(5, 4, 2, 1, 5, 1, 1)
    }

    // 5 inputs, 5-tap kernel, 2 cells of padding on each side.
    fn axis_5_5() -> PatchAxis {
        PatchAxis::new(5, 5, 2, 2, 5, 1, 1)
    }

    // 5 inputs, 3-tap kernel, padding 1/1, stride 2.
    fn axis_5_3_s2() -> PatchAxis {
        PatchAxis::new(5, 3, 1, 1, 3, 2, 1)
    }

    // 5 inputs, 3-tap kernel, padding 2/2, dilation 2.
    fn axis_5_3_d2() -> PatchAxis {
        PatchAxis::new(5, 3, 2, 2, 5, 1, 2)
    }

    // 10 inputs, 2-tap kernel, stride 3, no padding.
    fn axis_10_2_s3_valid() -> PatchAxis {
        PatchAxis::new(10, 2, 0, 0, 3, 3, 1)
    }

    #[test]
    fn axis_valid_ranges() {
        assert_eq!(axis_5_3().valid_range(), Some(1..4));
        assert_eq!(axis_5_4().valid_range(), Some(2..4));
        assert_eq!(axis_5_5().valid_range(), Some(2..3));
        assert_eq!(axis_5_3_s2().valid_range(), Some(1..2));
        assert_eq!(axis_5_3_d2().valid_range(), Some(2..3));
    }

    #[test]
    fn axis_valid_range_strided_and_empty() {
        // Starts at 0, 3 and 6 all fit a 2-wide field inside 10 inputs.
        assert_eq!(axis_10_2_s3_valid().valid_range(), Some(0..3));
        // Kernel wider than the input: no fully valid position.
        assert_eq!(PatchAxis::new(1, 2, 0, 1, 1, 1, 1).valid_range(), None);
    }

    #[test]
    fn axis_invalid_at_left() {
        assert_eq!(axis_5_3().invalid_at_left(0), 1);
        assert_eq!(axis_5_3().invalid_at_left(1), 0);
        assert_eq!(axis_5_3().invalid_at_left(2), 0);

        assert_eq!(axis_5_4().invalid_at_left(0), 2);
        assert_eq!(axis_5_4().invalid_at_left(1), 1);
        assert_eq!(axis_5_4().invalid_at_left(2), 0);

        assert_eq!(axis_5_5().invalid_at_left(0), 2);
        assert_eq!(axis_5_5().invalid_at_left(1), 1);
        assert_eq!(axis_5_5().invalid_at_left(2), 0);

        assert_eq!(axis_5_3_d2().invalid_at_left(0), 1);
        assert_eq!(axis_5_3_d2().invalid_at_left(1), 1);
        assert_eq!(axis_5_3_d2().invalid_at_left(2), 0);
    }

    #[test]
    fn axis_invalid_at_right() {
        assert_eq!(axis_5_3().invalid_at_right(0), 0);
        assert_eq!(axis_5_3().invalid_at_right(3), 0);
        assert_eq!(axis_5_3().invalid_at_right(4), 1);

        assert_eq!(axis_5_4().invalid_at_right(0), 0);
        assert_eq!(axis_5_4().invalid_at_right(3), 0);
        assert_eq!(axis_5_4().invalid_at_right(4), 1);

        assert_eq!(axis_5_5().invalid_at_right(0), 0);
        assert_eq!(axis_5_5().invalid_at_right(3), 1);
        assert_eq!(axis_5_5().invalid_at_right(4), 2);
    }

    #[test]
    fn axis_invalid_counts_dilated_and_strided() {
        assert_eq!(axis_5_3_d2().invalid_at_right(2), 0);
        assert_eq!(axis_5_3_d2().invalid_at_right(3), 1);
        assert_eq!(axis_5_3_d2().invalid_at_right(4), 1);

        assert_eq!(axis_5_3_s2().invalid_at_left(0), 1);
        assert_eq!(axis_5_3_s2().invalid_at_left(1), 0);
        assert_eq!(axis_5_3_s2().invalid_at_right(1), 0);
        assert_eq!(axis_5_3_s2().invalid_at_right(2), 1);
    }

    #[test]
    fn axis_5_3_regions() {
        let regions = axis_5_3().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..4, None),
                Region::new(4..5, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_3_s2_regions() {
        let regions = axis_5_3_s2().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..2, None),
                Region::new(2..3, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_3_d2_regions() {
        let regions = axis_5_3_d2().regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..2, Some(tvec!(true, false, false))),
                Region::new(2..3, None),
                Region::new(3..5, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_10_2_s3_valid_regions() {
        let regions = axis_10_2_s3_valid().regions();
        assert_eq!(regions, tvec!(Region::new(0..3, None),));
    }

    #[test]
    fn axis_7_3_s2_regions() {
        let regions = PatchAxis::new(7, 3, 1, 1, 4, 2, 1).regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, false, false))),
                Region::new(1..3, None),
                Region::new(3..4, Some(tvec!(false, false, true)))
            )
        );
    }

    #[test]
    fn axis_5_2_s2_regions() {
        let regions = PatchAxis::new(5, 2, 1, 1, 3, 2, 1).regions();
        assert_eq!(
            regions,
            tvec!(Region::new(0..1, Some(tvec!(true, false))), Region::new(1..3, None),)
        );
    }

    #[test]
    fn axis_28_3_very_padded_regions() {
        let regions = PatchAxis::new(28, 3, 2, 2, 30, 1, 1).regions();
        assert_eq!(
            regions,
            tvec!(
                Region::new(0..1, Some(tvec!(true, true, false))),
                Region::new(1..2, Some(tvec!(true, false, false))),
                Region::new(2..28, None),
                Region::new(28..29, Some(tvec!(false, false, true))),
                Region::new(29..30, Some(tvec!(false, true, true))),
            )
        );
    }

    #[test]
    fn axis_7_1_s2_regions() {
        let regions = PatchAxis::new(7, 1, 0, 0, 4, 2, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..4, None),));
    }

    #[test]
    fn axis_1_2_regions() {
        let regions = PatchAxis::new(1, 2, 0, 1, 1, 1, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..1, Some(tvec!(false, true))),));
    }

    #[test]
    fn axis_1_3_both_edges_masked_regions() {
        // A single position whose 3-tap kernel clips padding on both sides.
        let regions = PatchAxis::new(1, 3, 1, 1, 1, 1, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..1, Some(tvec!(true, false, true))),));
    }

    #[test]
    fn axis_dnn_left_pad() {
        let regions = PatchAxis::new(1, 1, 2, 0, 3, 1, 1).regions();
        assert_eq!(regions, tvec!(Region::new(0..2, Some(tvec!(true))), Region::new(2..3, None)));
    }
}