Skip to main content

tract_core/ops/nn/
grid_sample.rs

1use crate::internal::*;
2use crate::ops::math::round_ties_to_even;
3
/// How [`GridSample`] turns fractional sample positions into values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InterpolationMode {
    /// Linear blend of the `2^k` pixels surrounding the sample point
    /// (bilinear for 2-D inputs).
    Bilinear,
    /// Value of the closest pixel; half-way positions are rounded to even.
    Nearest,
    /// Cubic convolution over a 4×4 neighbourhood; only 2-D inputs are supported.
    Bicubic,
}
11
12impl InterpolationMode {
13    pub fn as_str(&self) -> &'static str {
14        match self {
15            InterpolationMode::Bilinear => "bilinear",
16            InterpolationMode::Nearest => "nearest",
17            InterpolationMode::Bicubic => "bicubic",
18        }
19    }
20
21    pub fn parse(s: &str) -> TractResult<Self> {
22        Ok(match s {
23            "bilinear" => InterpolationMode::Bilinear,
24            "nearest" => InterpolationMode::Nearest,
25            "bicubic" => InterpolationMode::Bicubic,
26            _ => bail!("Unsupported GridSample mode: {}", s),
27        })
28    }
29}
30
/// What [`GridSample`] does with sample positions that fall outside the input.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PaddingMode {
    /// Pixels outside the input read as `0`.
    Zeros,
    /// Positions are clamped onto the nearest edge pixel.
    Border,
    /// Positions are mirrored back inside the sampling bounds.
    Reflection,
}
38
39impl PaddingMode {
40    pub fn as_str(&self) -> &'static str {
41        match self {
42            PaddingMode::Zeros => "zeros",
43            PaddingMode::Border => "border",
44            PaddingMode::Reflection => "reflection",
45        }
46    }
47
48    pub fn parse(s: &str) -> TractResult<Self> {
49        Ok(match s {
50            "zeros" => PaddingMode::Zeros,
51            "border" => PaddingMode::Border,
52            "reflection" => PaddingMode::Reflection,
53            _ => bail!("Unsupported GridSample padding_mode: {}", s),
54        })
55    }
56}
57
58/// Samples `input` (N, C, D1..Dk) at the normalized coordinates carried by
59/// `grid` (N, O1..Ok, k), following the ONNX/PyTorch GridSample contract:
60/// k spatial dims, `mode` × `padding_mode` × `align_corners`.
61#[derive(Clone, Debug, PartialEq, Eq)]
62pub struct GridSample {
63    pub mode: InterpolationMode,
64    pub padding_mode: PaddingMode,
65    pub align_corners: bool,
66}
67
68impl GridSample {
69    fn denormalize(&self, coord: f32, size: usize) -> f32 {
70        if self.align_corners {
71            (coord + 1.0) / 2.0 * (size as f32 - 1.0)
72        } else {
73            ((coord + 1.0) * size as f32 - 1.0) / 2.0
74        }
75    }
76
77    fn bounds(&self, size: usize) -> (f32, f32) {
78        if self.align_corners { (0.0, size as f32 - 1.0) } else { (-0.5, size as f32 - 0.5) }
79    }
80
81    fn pixel_at_nd(
82        &self,
83        x: &tract_ndarray::ArrayViewD<'_, f32>,
84        batch: usize,
85        channel: usize,
86        coords: &[isize],
87        spatial_sizes: &[usize],
88    ) -> f32 {
89        match self.padding_mode {
90            PaddingMode::Zeros => {
91                for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
92                    if c < 0 || c >= s as isize {
93                        return 0.0;
94                    }
95                }
96                let mut idx = vec![batch, channel];
97                idx.extend(coords.iter().map(|&c| c as usize));
98                x[idx.as_slice()]
99            }
100            PaddingMode::Border => {
101                let mut idx = vec![batch, channel];
102                for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
103                    idx.push((c.max(0) as usize).min(s - 1));
104                }
105                x[idx.as_slice()]
106            }
107            PaddingMode::Reflection => {
108                let mut idx = vec![batch, channel];
109                for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
110                    let (lo, hi) = self.bounds(s);
111                    idx.push(gs_reflect(c as f32, lo, hi) as usize);
112                }
113                x[idx.as_slice()]
114            }
115        }
116    }
117
118    fn apply_padding(&self, coord: f32, lo: f32, hi: f32) -> f32 {
119        match self.padding_mode {
120            PaddingMode::Border => coord.clamp(0.0, hi + lo),
121            PaddingMode::Reflection => gs_reflect(coord, lo, hi),
122            PaddingMode::Zeros => coord,
123        }
124    }
125
126    fn is_oob(&self, coords: &[f32], bounds: &[(f32, f32)]) -> bool {
127        coords.iter().zip(bounds.iter()).any(|(&c, &(lo, hi))| c < lo || c > hi)
128    }
129
130    fn pad_coords(&self, coords: &mut [f32], bounds: &[(f32, f32)]) {
131        for (c, &(lo, hi)) in coords.iter_mut().zip(bounds.iter()) {
132            *c = self.apply_padding(*c, lo, hi);
133        }
134    }
135
136    fn sample_nd(
137        &self,
138        x: &tract_ndarray::ArrayViewD<'_, f32>,
139        batch: usize,
140        channel: usize,
141        pixel_coords: &[f32],
142        spatial_sizes: &[usize],
143    ) -> f32 {
144        let ndim = pixel_coords.len();
145        let bounds: Vec<(f32, f32)> = spatial_sizes.iter().map(|&s| self.bounds(s)).collect();
146
147        match self.mode {
148            InterpolationMode::Nearest => {
149                let mut coords: Vec<f32> =
150                    pixel_coords.iter().map(|&c| round_ties_to_even(c)).collect();
151                if self.is_oob(&coords, &bounds) {
152                    self.pad_coords(&mut coords, &bounds);
153                }
154                let icoords: Vec<isize> = coords.iter().map(|&c| c as isize).collect();
155                self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes)
156            }
157            InterpolationMode::Bilinear => {
158                let mut coords: Vec<f32> = pixel_coords.to_vec();
159                if self.is_oob(&coords, &bounds) {
160                    self.pad_coords(&mut coords, &bounds);
161                }
162                let num_corners = 1 << ndim;
163                let mut result = 0.0f32;
164                for corner in 0..num_corners {
165                    let mut weight = 1.0f32;
166                    let mut icoords = Vec::with_capacity(ndim);
167                    for (d, &c) in coords.iter().enumerate() {
168                        let lo = c.floor() as isize;
169                        if (corner >> d) & 1 == 0 {
170                            icoords.push(lo);
171                            weight *= (lo + 1) as f32 - c;
172                        } else {
173                            icoords.push(lo + 1);
174                            weight *= c - lo as f32;
175                        }
176                    }
177                    result += weight * self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes);
178                }
179                result
180            }
181            InterpolationMode::Bicubic => {
182                assert!(ndim == 2, "Bicubic interpolation only supports 2D spatial dimensions");
183                let (mut px, mut py) = (pixel_coords[0], pixel_coords[1]);
184                if self.is_oob(&[px, py], &bounds) {
185                    px = self.apply_padding(px, bounds[0].0, bounds[0].1);
186                    py = self.apply_padding(py, bounds[1].0, bounds[1].1);
187                }
188                let x0 = px.floor() as isize - 1;
189                let y0 = py.floor() as isize - 1;
190                let dx = px - x0 as f32 - 1.0;
191                let dy = py - y0 as f32 - 1.0;
192
193                let mut p = [[0.0f32; 4]; 4];
194                for (h, row) in p.iter_mut().enumerate() {
195                    for (w, val) in row.iter_mut().enumerate() {
196                        *val = self.pixel_at_nd(
197                            x,
198                            batch,
199                            channel,
200                            &[x0 + w as isize, y0 + h as isize],
201                            spatial_sizes,
202                        );
203                    }
204                }
205                bicubic_interpolate(&p, dx, dy)
206            }
207        }
208    }
209}
210
/// Folds `x` back into `[x_min, x_max]` by mirroring it at the borders, as the
/// ONNX GridSample reference does for `padding_mode = "reflection"`.
///
/// The distance past a border is reduced with an exact floating-point
/// remainder over one mirror period (`2 * (x_max - x_min)`). The result
/// therefore stays within the borders (up to rounding) even for huge `|x|`.
/// The previous `as i32` quotient saturated and could return far out-of-range
/// values, which then caused out-of-bounds indexing.
///
/// A degenerate range (`x_min == x_max`) yields `x_min`. Non-finite input
/// yields NaN rather than an infinite coordinate.
fn gs_reflect(x: f32, x_min: f32, x_max: f32) -> f32 {
    let rng = x_max - x_min;
    if rng == 0.0 {
        return x_min;
    }
    let period = 2.0 * rng;
    if x < x_min {
        // Distance below the lower border, folded into a single period.
        let r = (x_min - x) % period;
        if r <= rng { x_min + r } else { x_max - (r - rng) }
    } else if x > x_max {
        // Distance above the upper border, folded into a single period.
        let r = (x - x_max) % period;
        if r <= rng { x_max - r } else { x_min + (r - rng) }
    } else {
        x
    }
}
230
231fn bicubic_interpolate(p: &[[f32; 4]; 4], dx: f32, dy: f32) -> f32 {
232    let mut v = [0.0f32; 4];
233    let mut coeffs = [0.0f32; 4];
234    cubic_coeffs(dx, &mut coeffs);
235    for i in 0..4 {
236        v[i] =
237            coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
238    }
239    cubic_coeffs(dy, &mut coeffs);
240    coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]
241}
242
/// Writes the four cubic-convolution weights (Keys kernel, `A = -0.75`) for a
/// fractional offset `x` into `coeffs`.
///
/// Taps are ordered at distances `x + 1`, `x`, `1 - x` and `2 - x` from the
/// sample point.
fn cubic_coeffs(x: f32, coeffs: &mut [f32; 4]) {
    const A: f32 = -0.75;
    // Kernel piece for taps at distance in (1, 2].
    let outer = |d: f32| ((A * d - 5.0 * A) * d + 8.0 * A) * d - 4.0 * A;
    // Kernel piece for taps at distance in [0, 1].
    let inner = |d: f32| ((A + 2.0) * d - (A + 3.0)) * d * d + 1.0;
    *coeffs = [outer(x + 1.0), inner(x), inner(1.0 - x), outer(2.0 - x)];
}
253
254impl Op for GridSample {
255    fn name(&self) -> StaticName {
256        "GridSample".into()
257    }
258
259    op_as_typed_op!();
260}
261
262impl EvalOp for GridSample {
263    fn is_stateless(&self) -> bool {
264        true
265    }
266
267    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
268        let (x, grid) = args_2!(inputs);
269        let input_dt = x.datum_type();
270        let x_tensor = x.into_tensor();
271        let x_cow = x_tensor.cast_to::<f32>()?;
272        let x = x_cow.to_plain_array_view::<f32>()?;
273        let grid_tensor = grid.into_tensor();
274        let grid_cow = grid_tensor.cast_to::<f32>()?;
275        let grid = grid_cow.to_plain_array_view::<f32>()?;
276
277        let x_shape = x.shape();
278        let grid_shape = grid.shape();
279        let rank = x_shape.len();
280        let spatial_rank = rank - 2;
281
282        let n_batch = x_shape[0];
283        let n_channel = x_shape[1];
284        let spatial_sizes: Vec<usize> = x_shape[2..].to_vec();
285
286        let mut output_shape = vec![n_batch, n_channel];
287        output_shape.extend_from_slice(&grid_shape[1..rank - 1]);
288
289        let output = tract_ndarray::ArrayD::from_shape_fn(&*output_shape, |idx| -> f32 {
290            let batch = idx[0];
291            let channel = idx[1];
292            let out_spatial: Vec<usize> = (2..rank).map(|d| idx[d]).collect();
293
294            let mut grid_idx = vec![batch];
295            grid_idx.extend_from_slice(&out_spatial);
296            grid_idx.push(0);
297
298            let mut pixel_coords = Vec::with_capacity(spatial_rank);
299            for (d, &size) in spatial_sizes.iter().enumerate() {
300                *grid_idx.last_mut().unwrap() = spatial_rank - 1 - d;
301                let norm_coord = grid[grid_idx.as_slice()];
302                pixel_coords.push(self.denormalize(norm_coord, size));
303            }
304
305            self.sample_nd(&x, batch, channel, &pixel_coords, &spatial_sizes)
306        });
307
308        Ok(tvec!(output.into_tensor().cast_to_dt(input_dt)?.into_owned().into_tvalue()))
309    }
310}
311
312impl TypedOp for GridSample {
313    as_op!();
314
315    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
316        let x_shape = &inputs[0].shape;
317        let grid_shape = &inputs[1].shape;
318        let rank = x_shape.len();
319
320        let mut output_shape: TVec<TDim> = tvec![x_shape[0].clone(), x_shape[1].clone()];
321        for d in 1..rank - 1 {
322            output_shape.push(grid_shape[d].clone());
323        }
324
325        Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
326    }
327}