1use crate::internal::*;
2use crate::ops::math::round_ties_to_even;
3
/// Interpolation used by [`GridSample`] to compute an output value from the input pixels
/// surrounding a sampling location.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InterpolationMode {
    /// N-linear interpolation over the `2^r` neighbouring pixels (`r` spatial axes).
    Bilinear,
    /// Value of the nearest pixel, with coordinates rounded half to even.
    Nearest,
    /// Bicubic interpolation over a 4x4 neighbourhood; two spatial axes only.
    Bicubic,
}
11
12impl InterpolationMode {
13 pub fn as_str(&self) -> &'static str {
14 match self {
15 InterpolationMode::Bilinear => "bilinear",
16 InterpolationMode::Nearest => "nearest",
17 InterpolationMode::Bicubic => "bicubic",
18 }
19 }
20
21 pub fn parse(s: &str) -> TractResult<Self> {
22 Ok(match s {
23 "bilinear" => InterpolationMode::Bilinear,
24 "nearest" => InterpolationMode::Nearest,
25 "bicubic" => InterpolationMode::Bicubic,
26 _ => bail!("Unsupported GridSample mode: {}", s),
27 })
28 }
29}
30
/// How [`GridSample`] treats sampling locations that fall outside the input.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Out-of-range pixels read as `0.0`.
    Zeros,
    /// Out-of-range locations are clamped to the nearest edge pixel.
    Border,
    /// Out-of-range locations are mirrored back into the valid range.
    Reflection,
}
38
39impl PaddingMode {
40 pub fn as_str(&self) -> &'static str {
41 match self {
42 PaddingMode::Zeros => "zeros",
43 PaddingMode::Border => "border",
44 PaddingMode::Reflection => "reflection",
45 }
46 }
47
48 pub fn parse(s: &str) -> TractResult<Self> {
49 Ok(match s {
50 "zeros" => PaddingMode::Zeros,
51 "border" => PaddingMode::Border,
52 "reflection" => PaddingMode::Reflection,
53 _ => bail!("Unsupported GridSample padding_mode: {}", s),
54 })
55 }
56}
57
58#[derive(Clone, Debug, PartialEq, Eq)]
62pub struct GridSample {
63 pub mode: InterpolationMode,
64 pub padding_mode: PaddingMode,
65 pub align_corners: bool,
66}
67
68impl GridSample {
69 fn denormalize(&self, coord: f32, size: usize) -> f32 {
70 if self.align_corners {
71 (coord + 1.0) / 2.0 * (size as f32 - 1.0)
72 } else {
73 ((coord + 1.0) * size as f32 - 1.0) / 2.0
74 }
75 }
76
77 fn bounds(&self, size: usize) -> (f32, f32) {
78 if self.align_corners { (0.0, size as f32 - 1.0) } else { (-0.5, size as f32 - 0.5) }
79 }
80
81 fn pixel_at_nd(
82 &self,
83 x: &tract_ndarray::ArrayViewD<'_, f32>,
84 batch: usize,
85 channel: usize,
86 coords: &[isize],
87 spatial_sizes: &[usize],
88 ) -> f32 {
89 match self.padding_mode {
90 PaddingMode::Zeros => {
91 for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
92 if c < 0 || c >= s as isize {
93 return 0.0;
94 }
95 }
96 let mut idx = vec![batch, channel];
97 idx.extend(coords.iter().map(|&c| c as usize));
98 x[idx.as_slice()]
99 }
100 PaddingMode::Border => {
101 let mut idx = vec![batch, channel];
102 for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
103 idx.push((c.max(0) as usize).min(s - 1));
104 }
105 x[idx.as_slice()]
106 }
107 PaddingMode::Reflection => {
108 let mut idx = vec![batch, channel];
109 for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
110 let (lo, hi) = self.bounds(s);
111 idx.push(gs_reflect(c as f32, lo, hi) as usize);
112 }
113 x[idx.as_slice()]
114 }
115 }
116 }
117
118 fn apply_padding(&self, coord: f32, lo: f32, hi: f32) -> f32 {
119 match self.padding_mode {
120 PaddingMode::Border => coord.clamp(0.0, hi + lo),
121 PaddingMode::Reflection => gs_reflect(coord, lo, hi),
122 PaddingMode::Zeros => coord,
123 }
124 }
125
126 fn is_oob(&self, coords: &[f32], bounds: &[(f32, f32)]) -> bool {
127 coords.iter().zip(bounds.iter()).any(|(&c, &(lo, hi))| c < lo || c > hi)
128 }
129
130 fn pad_coords(&self, coords: &mut [f32], bounds: &[(f32, f32)]) {
131 for (c, &(lo, hi)) in coords.iter_mut().zip(bounds.iter()) {
132 *c = self.apply_padding(*c, lo, hi);
133 }
134 }
135
136 fn sample_nd(
137 &self,
138 x: &tract_ndarray::ArrayViewD<'_, f32>,
139 batch: usize,
140 channel: usize,
141 pixel_coords: &[f32],
142 spatial_sizes: &[usize],
143 ) -> f32 {
144 let ndim = pixel_coords.len();
145 let bounds: Vec<(f32, f32)> = spatial_sizes.iter().map(|&s| self.bounds(s)).collect();
146
147 match self.mode {
148 InterpolationMode::Nearest => {
149 let mut coords: Vec<f32> =
150 pixel_coords.iter().map(|&c| round_ties_to_even(c)).collect();
151 if self.is_oob(&coords, &bounds) {
152 self.pad_coords(&mut coords, &bounds);
153 }
154 let icoords: Vec<isize> = coords.iter().map(|&c| c as isize).collect();
155 self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes)
156 }
157 InterpolationMode::Bilinear => {
158 let mut coords: Vec<f32> = pixel_coords.to_vec();
159 if self.is_oob(&coords, &bounds) {
160 self.pad_coords(&mut coords, &bounds);
161 }
162 let num_corners = 1 << ndim;
163 let mut result = 0.0f32;
164 for corner in 0..num_corners {
165 let mut weight = 1.0f32;
166 let mut icoords = Vec::with_capacity(ndim);
167 for (d, &c) in coords.iter().enumerate() {
168 let lo = c.floor() as isize;
169 if (corner >> d) & 1 == 0 {
170 icoords.push(lo);
171 weight *= (lo + 1) as f32 - c;
172 } else {
173 icoords.push(lo + 1);
174 weight *= c - lo as f32;
175 }
176 }
177 result += weight * self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes);
178 }
179 result
180 }
181 InterpolationMode::Bicubic => {
182 assert!(ndim == 2, "Bicubic interpolation only supports 2D spatial dimensions");
183 let (mut px, mut py) = (pixel_coords[0], pixel_coords[1]);
184 if self.is_oob(&[px, py], &bounds) {
185 px = self.apply_padding(px, bounds[0].0, bounds[0].1);
186 py = self.apply_padding(py, bounds[1].0, bounds[1].1);
187 }
188 let x0 = px.floor() as isize - 1;
189 let y0 = py.floor() as isize - 1;
190 let dx = px - x0 as f32 - 1.0;
191 let dy = py - y0 as f32 - 1.0;
192
193 let mut p = [[0.0f32; 4]; 4];
194 for (h, row) in p.iter_mut().enumerate() {
195 for (w, val) in row.iter_mut().enumerate() {
196 *val = self.pixel_at_nd(
197 x,
198 batch,
199 channel,
200 &[x0 + w as isize, y0 + h as isize],
201 spatial_sizes,
202 );
203 }
204 }
205 bicubic_interpolate(&p, dx, dy)
206 }
207 }
208 }
209}
210
/// Mirrors `x` back into `[x_min, x_max]`, bouncing between the two edges as many times as
/// needed. Values already inside the range (and NaN) are returned unchanged. A degenerate
/// range (`x_min == x_max`) always yields `x_min`.
fn gs_reflect(x: f32, x_min: f32, x_max: f32) -> f32 {
    let span = x_max - x_min;
    if span == 0.0 {
        return x_min;
    }
    // Distance travelled past the violated edge, and whether that edge is the lower one.
    let (overshoot, below) = if x < x_min {
        (x_min - x, true)
    } else if x > x_max {
        (x - x_max, false)
    } else {
        return x;
    };
    let folds = (overshoot / span) as i32;
    let rest = overshoot - folds as f32 * span;
    // After an even number of whole folds we bounce off the edge we crossed; after an odd
    // number we bounce off the opposite one.
    let measure_from_min = (folds % 2 == 0) == below;
    if measure_from_min {
        x_min + rest
    } else {
        x_max - rest
    }
}
230
231fn bicubic_interpolate(p: &[[f32; 4]; 4], dx: f32, dy: f32) -> f32 {
232 let mut v = [0.0f32; 4];
233 let mut coeffs = [0.0f32; 4];
234 cubic_coeffs(dx, &mut coeffs);
235 for i in 0..4 {
236 v[i] =
237 coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
238 }
239 cubic_coeffs(dy, &mut coeffs);
240 coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]
241}
242
/// Fills `coeffs` with the weights of the four taps surrounding a point at fractional offset
/// `x` (expected in `[0, 1)`), using the Keys cubic convolution kernel with `a = -0.75`.
///
/// The taps lie at distances `x + 1`, `x`, `1 - x` and `2 - x` from the point.
fn cubic_coeffs(x: f32, coeffs: &mut [f32; 4]) {
    const A: f32 = -0.75;
    // Kernel branch for distances in [1, 2].
    let outer = |t: f32| ((A * t - 5.0 * A) * t + 8.0 * A) * t - 4.0 * A;
    // Kernel branch for distances in [0, 1].
    let inner = |t: f32| ((A + 2.0) * t - (A + 3.0)) * t * t + 1.0;
    *coeffs = [outer(x + 1.0), inner(x), inner(1.0 - x), outer(2.0 - x)];
}
253
254impl Op for GridSample {
255 fn name(&self) -> StaticName {
256 "GridSample".into()
257 }
258
259 op_as_typed_op!();
260}
261
262impl EvalOp for GridSample {
263 fn is_stateless(&self) -> bool {
264 true
265 }
266
267 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
268 let (x, grid) = args_2!(inputs);
269 let input_dt = x.datum_type();
270 let x_tensor = x.into_tensor();
271 let x_cow = x_tensor.cast_to::<f32>()?;
272 let x = x_cow.to_plain_array_view::<f32>()?;
273 let grid_tensor = grid.into_tensor();
274 let grid_cow = grid_tensor.cast_to::<f32>()?;
275 let grid = grid_cow.to_plain_array_view::<f32>()?;
276
277 let x_shape = x.shape();
278 let grid_shape = grid.shape();
279 let rank = x_shape.len();
280 let spatial_rank = rank - 2;
281
282 let n_batch = x_shape[0];
283 let n_channel = x_shape[1];
284 let spatial_sizes: Vec<usize> = x_shape[2..].to_vec();
285
286 let mut output_shape = vec![n_batch, n_channel];
287 output_shape.extend_from_slice(&grid_shape[1..rank - 1]);
288
289 let output = tract_ndarray::ArrayD::from_shape_fn(&*output_shape, |idx| -> f32 {
290 let batch = idx[0];
291 let channel = idx[1];
292 let out_spatial: Vec<usize> = (2..rank).map(|d| idx[d]).collect();
293
294 let mut grid_idx = vec![batch];
295 grid_idx.extend_from_slice(&out_spatial);
296 grid_idx.push(0);
297
298 let mut pixel_coords = Vec::with_capacity(spatial_rank);
299 for (d, &size) in spatial_sizes.iter().enumerate() {
300 *grid_idx.last_mut().unwrap() = spatial_rank - 1 - d;
301 let norm_coord = grid[grid_idx.as_slice()];
302 pixel_coords.push(self.denormalize(norm_coord, size));
303 }
304
305 self.sample_nd(&x, batch, channel, &pixel_coords, &spatial_sizes)
306 });
307
308 Ok(tvec!(output.into_tensor().cast_to_dt(input_dt)?.into_owned().into_tvalue()))
309 }
310}
311
312impl TypedOp for GridSample {
313 as_op!();
314
315 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
316 let x_shape = &inputs[0].shape;
317 let grid_shape = &inputs[1].shape;
318 let rank = x_shape.len();
319
320 let mut output_shape: TVec<TDim> = tvec![x_shape[0].clone(), x_shape[1].clone()];
321 for d in 1..rank - 1 {
322 output_shape.push(grid_shape[d].clone());
323 }
324
325 Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
326 }
327}