Skip to main content

svod_tensor/nn/
grid_sample.rs

1//! GridSample: spatial sampling via coordinate grids (ONNX GridSample operator).
2
3use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::ConstValue;
7
8use crate::Tensor;
9use crate::error::{NdimMinimumSnafu, UOpSnafu};
10use crate::shape_ops::MeshgridIndexing;
11
12use super::{GridSampleMode, GridSamplePaddingMode};
13
/// Local shorthand for the crate-wide [`crate::Result`] type used by every op in this module.
type Result<T> = crate::Result<T>;
15
16#[bon]
17impl Tensor {
18    /// Generate an affine sampling grid from transformation parameters.
19    ///
20    /// Produces a grid of normalized coordinates suitable for [`grid_sample`](Tensor::grid_sample).
21    /// `theta` holds affine matrices of shape `[N, spatial_dims, spatial_dims+1]`.
22    /// `size` is the target output shape `[N, C, *spatial_dims]`.
23    ///
24    /// # Examples
25    ///
26    /// Identity transform producing a 4x4 grid:
27    ///
28    /// ```
29    /// # use svod_tensor::Tensor;
30    /// # use ndarray::array;
31    /// let theta = Tensor::from_ndarray(&array![[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0]]]);
32    /// let grid = Tensor::affine_grid().theta(&theta).size(&[1, 1, 4, 4]).call().unwrap();
33    /// let shape: Vec<usize> = grid.shape().unwrap().iter()
34    ///     .map(|d| d.as_const().unwrap()).collect();
35    /// assert_eq!(shape, vec![1, 4, 4, 2]); // [N, H, W, 2]
36    /// ```
37    ///
38    /// With `align_corners`:
39    ///
40    /// ```
41    /// # use svod_tensor::Tensor;
42    /// # use ndarray::array;
43    /// let theta = Tensor::from_ndarray(&array![[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0]]]);
44    /// let grid = Tensor::affine_grid()
45    ///     .theta(&theta)
46    ///     .size(&[1, 1, 4, 4])
47    ///     .align_corners(true)
48    ///     .call()
49    ///     .unwrap();
50    /// let shape: Vec<usize> = grid.shape().unwrap().iter()
51    ///     .map(|d| d.as_const().unwrap()).collect();
52    /// assert_eq!(shape, vec![1, 4, 4, 2]);
53    /// ```
54    #[builder]
55    pub fn affine_grid(
56        theta: &Tensor,
57        size: &[i64],
58        #[builder(default = false)] align_corners: bool,
59    ) -> Result<Tensor> {
60        snafu::ensure!(size.len() >= 3, NdimMinimumSnafu { op: "affine_grid", min: 3_usize, actual: size.len() });
61        let n = size[0] as usize;
62        let ndim = size.len() - 2; // spatial dims
63
64        let spatial_dims: Vec<usize> = size[2..].iter().map(|&s| s as usize).collect();
65        let mut grids = Vec::with_capacity(ndim);
66        for &dim_size in &spatial_dims {
67            let g = if align_corners {
68                Tensor::linspace(-1.0, 1.0, dim_size, DType::Float32)?
69            } else {
70                let start = -1.0 + 1.0 / dim_size as f64;
71                let end = 1.0 - 1.0 / dim_size as f64;
72                Tensor::linspace(start, end, dim_size, DType::Float32)?
73            };
74            grids.push(g);
75        }
76
77        let grid_refs: Vec<&Tensor> = grids.iter().collect();
78        let mesh = Tensor::meshgrid(&grid_refs, MeshgridIndexing::Ij)?;
79
80        let total_elements: usize = spatial_dims.iter().product();
81        let flat_shape = [total_elements as isize];
82        let mut components: Vec<Tensor> = Vec::with_capacity(ndim + 1);
83        for g in mesh.iter().rev() {
84            components.push(g.try_reshape(flat_shape)?);
85        }
86        components.push(Tensor::full(&[total_elements], 1.0, DType::Float32)?);
87
88        let comp_refs: Vec<&Tensor> = components.iter().collect();
89        let base_grid = Tensor::cat(&comp_refs, 0)?
90            .try_reshape([(ndim + 1) as isize, total_elements as isize])?
91            .try_transpose(0, 1)?;
92
93        let base_grid =
94            base_grid.try_unsqueeze(0)?.try_expand([n as isize, total_elements as isize, (ndim + 1) as isize])?;
95
96        let theta_t = theta.try_transpose(1, 2)?;
97        let output = base_grid.matmul(&theta_t)?;
98
99        let mut out_shape: Vec<isize> = vec![n as isize];
100        out_shape.extend(spatial_dims.iter().map(|&d| d as isize));
101        out_shape.push(ndim as isize);
102        output.try_reshape(&out_shape)
103    }
104
105    /// Sample input at positions specified by a coordinate grid.
106    ///
107    /// - `self`: Input tensor `[N, C, *spatial_dims]`
108    /// - `grid`: Coordinate grid `[N, *output_spatial_dims, n_spatial]` with values in `[-1, 1]`
109    /// - Returns: `[N, C, *output_spatial_dims]`
110    ///
111    /// # Examples
112    ///
113    /// Sample with a grid from `affine_grid`:
114    ///
115    /// ```
116    /// # use svod_tensor::Tensor;
117    /// # use ndarray::{array, Array4};
118    /// let theta = Tensor::from_ndarray(&array![[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0]]]);
119    /// let grid = Tensor::affine_grid().theta(&theta).size(&[1, 1, 4, 4]).call().unwrap();
120    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
121    /// let y = x.grid_sample().grid(&grid).call().unwrap();
122    /// let shape: Vec<usize> = y.shape().unwrap().iter()
123    ///     .map(|d| d.as_const().unwrap()).collect();
124    /// assert_eq!(shape, vec![1, 1, 4, 4]);
125    /// ```
126    ///
127    /// With nearest-mode interpolation:
128    ///
129    /// ```
130    /// # use svod_tensor::Tensor;
131    /// # use svod_tensor::nn::GridSampleMode;
132    /// # use ndarray::{array, Array4};
133    /// let theta = Tensor::from_ndarray(&array![[[1.0f32, 0.0, 0.0], [0.0, 1.0, 0.0]]]);
134    /// let grid = Tensor::affine_grid().theta(&theta).size(&[1, 1, 4, 4]).call().unwrap();
135    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 4, 4), 1.0f32));
136    /// let y = x.grid_sample()
137    ///     .grid(&grid)
138    ///     .mode(GridSampleMode::Nearest)
139    ///     .call()
140    ///     .unwrap();
141    /// let shape: Vec<usize> = y.shape().unwrap().iter()
142    ///     .map(|d| d.as_const().unwrap()).collect();
143    /// assert_eq!(shape, vec![1, 1, 4, 4]);
144    /// ```
145    #[builder]
146    pub fn grid_sample(
147        &self,
148        grid: &Tensor,
149        #[builder(default)] mode: GridSampleMode,
150        #[builder(default)] padding_mode: GridSamplePaddingMode,
151        #[builder(default = false)] align_corners: bool,
152    ) -> Result<Tensor> {
153        let x_ndim = self.ndim()?;
154        snafu::ensure!(x_ndim >= 3, NdimMinimumSnafu { op: "grid_sample", min: 3_usize, actual: x_ndim });
155        let x_shape = self.shape()?;
156        let grid_shape = grid.shape()?;
157        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
158        let grid_dims = svod_ir::shape::to_vec_usize(&grid_shape).context(UOpSnafu)?;
159        let n_spatial = x_dims.len() - 2;
160
161        let n = x_dims[0];
162        let c = x_dims[1];
163        let spatial: Vec<usize> = x_dims[2..].to_vec();
164        let out_spatial: Vec<usize> = grid_dims[1..grid_dims.len() - 1].to_vec();
165        let spatial_prod: usize = spatial.iter().product();
166        let out_prod: usize = out_spatial.iter().product();
167        let dtype = self.uop().dtype();
168
169        // Flatten X spatial: [N, C, prod(spatial)]
170        let x_flat = self.try_reshape([n as isize, c as isize, spatial_prod as isize])?;
171
172        // Flatten grid spatial: [N, prod(out_spatial), n_spatial]
173        let grid_flat = grid.try_reshape([n as isize, out_prod as isize, n_spatial as isize])?;
174
175        // Strides for flat index: stride[i] = product of spatial[i+1..]
176        let strides = compute_strides(&spatial);
177
178        // Extract, denormalize coordinates for each spatial dim.
179        // Grid stores coords in reverse order: grid[...,0]=x→last spatial dim, etc.
180        let mut coords: Vec<Tensor> = Vec::with_capacity(n_spatial);
181        for (i, &dim_size) in spatial.iter().enumerate() {
182            let grid_idx = n_spatial - 1 - i;
183            let coord = slice_last_dim(&grid_flat, grid_idx, n, out_prod)?;
184            let denorm = gs_denormalize(&coord, dim_size, align_corners, &dtype)?;
185            coords.push(denorm);
186        }
187
188        // Apply padding mode to float coordinates before interpolation
189        let coords = match padding_mode {
190            GridSamplePaddingMode::Border => coords
191                .iter()
192                .enumerate()
193                .map(|(i, c)| {
194                    let zero = Tensor::const_(0.0, dtype.clone());
195                    let max_val = Tensor::const_((spatial[i] - 1) as f64, dtype.clone());
196                    c.clamp().min(&zero).max(&max_val).call()
197                })
198                .collect::<Result<Vec<_>>>()?,
199            GridSamplePaddingMode::Reflection => coords
200                .iter()
201                .enumerate()
202                .map(|(i, c)| gs_reflect(c, spatial[i], align_corners, &dtype))
203                .collect::<Result<Vec<_>>>()?,
204            GridSamplePaddingMode::Zeros => coords,
205        };
206
207        let result = match mode {
208            GridSampleMode::Nearest => {
209                interpolate_nearest(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
210            }
211            GridSampleMode::Linear => {
212                interpolate_linear(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
213            }
214            GridSampleMode::Cubic => {
215                interpolate_cubic(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
216            }
217        };
218
219        // Reshape to [N, C, *out_spatial]
220        let mut out_shape: Vec<isize> = vec![n as isize, c as isize];
221        out_shape.extend(out_spatial.iter().map(|&d| d as isize));
222        result.try_reshape(&out_shape)
223    }
224}
225
/// Row-major (C-order) strides for a shape.
///
/// `result[i]` is the product of `dims[i + 1..]`, so the innermost axis always has
/// stride 1. An empty shape yields an empty vector.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    if dims.is_empty() {
        return Vec::new();
    }
    // Suffix products from the innermost axis outward. dims[0] never contributes to any
    // stride, so it is skipped entirely.
    let mut suffix: Vec<usize> = dims[1..]
        .iter()
        .rev()
        .scan(1usize, |running, &extent| {
            *running *= extent;
            Some(*running)
        })
        .collect();
    suffix.reverse();
    suffix.push(1);
    suffix
}
234
235/// Extract `t[:, :, idx]` from shape `[N, out_prod, n_spatial]` → `[N, out_prod]`.
236fn slice_last_dim(t: &Tensor, idx: usize, n: usize, out_prod: usize) -> Result<Tensor> {
237    t.try_shrink([(0, n as isize), (0, out_prod as isize), (idx as isize, (idx + 1) as isize)])?.try_squeeze(Some(-1))
238}
239
240/// Denormalize grid coordinate from [-1, 1] to pixel space.
241fn gs_denormalize(coord: &Tensor, dim_size: usize, align_corners: bool, dtype: &DType) -> Result<Tensor> {
242    if align_corners {
243        // x = (n + 1) / 2 * (dim_size - 1)
244        coord
245            .try_add(&Tensor::const_(1.0, dtype.clone()))?
246            .try_mul(&Tensor::const_(0.5 * (dim_size - 1) as f64, dtype.clone()))
247    } else {
248        // x = ((n + 1) * dim_size - 1) / 2
249        coord
250            .try_add(&Tensor::const_(1.0, dtype.clone()))?
251            .try_mul(&Tensor::const_(dim_size as f64, dtype.clone()))?
252            .try_sub(&Tensor::const_(1.0, dtype.clone()))?
253            .try_mul(&Tensor::const_(0.5, dtype.clone()))
254    }
255}
256
257/// Reflect coordinate into [lo, hi] range for reflection padding.
258fn gs_reflect(coord: &Tensor, dim_size: usize, align_corners: bool, dtype: &DType) -> Result<Tensor> {
259    let (lo, hi) = if align_corners { (0.0, (dim_size - 1) as f64) } else { (-0.5, dim_size as f64 - 0.5) };
260    let rng = hi - lo;
261    if rng == 0.0 {
262        return Ok(Tensor::const_(lo, dtype.clone()));
263    }
264    let lo_t = Tensor::const_(lo, dtype.clone());
265    let rng_t = Tensor::const_(rng, dtype.clone());
266    let period_t = Tensor::const_(2.0 * rng, dtype.clone());
267
268    // Shift to [0, 2*rng) via positive modulo
269    let shifted = coord.try_sub(&lo_t)?;
270    let t = shifted.try_sub(&shifted.try_div(&period_t)?.floor()?.try_mul(&period_t)?)?;
271
272    // Reflect: if t > rng → 2*rng - t, else t
273    let two_rng_t = Tensor::const_(2.0 * rng, dtype.clone());
274    let reflected = two_rng_t.try_sub(&t)?;
275    let cond = rng_t.try_lt(&t)?; // t > rng
276    reflected.where_(&cond, &t)?.try_add(&lo_t)
277}
278
279/// Build flat index from per-dim integer indices and accumulate validity mask for zeros padding.
280fn build_flat_index(
281    indices: &[Tensor],
282    spatial: &[usize],
283    strides: &[usize],
284    padding_mode: GridSamplePaddingMode,
285) -> Result<(Tensor, Option<Tensor>)> {
286    let n_spatial = indices.len();
287    let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int32);
288    let mut valid_mask: Option<Tensor> = None;
289
290    for i in 0..n_spatial {
291        let idx = &indices[i];
292
293        if padding_mode == GridSamplePaddingMode::Zeros {
294            let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
295            let max_i = Tensor::const_(ConstValue::Int(spatial[i] as i64), DType::Int32);
296            let v = idx.try_ge(&zero_i)?.bitwise_and(&idx.try_lt(&max_i)?)?;
297            valid_mask = Some(match valid_mask {
298                Some(m) => m.bitwise_and(&v)?,
299                None => v,
300            });
301        }
302
303        // Clamp for safe gather (even out-of-bounds values need a valid index for gather)
304        let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
305        let max_i = Tensor::const_(ConstValue::Int((spatial[i] - 1) as i64), DType::Int32);
306        let safe_idx = idx.clamp().min(&zero_i).max(&max_i).call()?;
307
308        let stride_t = Tensor::const_(ConstValue::Int(strides[i] as i64), DType::Int32);
309        flat_idx = flat_idx.try_add(&safe_idx.try_mul(&stride_t)?)?;
310    }
311
312    Ok((flat_idx, valid_mask))
313}
314
315/// Gather from flat X and apply zeros mask if needed.
316fn gather_and_mask(
317    x_flat: &Tensor,
318    flat_idx: &Tensor,
319    valid_mask: Option<&Tensor>,
320    n: usize,
321    c: usize,
322    out_prod: usize,
323    dtype: &DType,
324) -> Result<Tensor> {
325    let expanded_idx = flat_idx.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
326    let mut gathered = x_flat.gather(2, &expanded_idx)?;
327    if let Some(mask) = valid_mask {
328        let mask = mask.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
329        gathered = gathered.try_mul(&mask.cast(dtype.clone())?)?;
330    }
331    Ok(gathered)
332}
333
334#[allow(clippy::too_many_arguments)]
335fn interpolate_nearest(
336    x_flat: &Tensor,
337    coords: &[Tensor],
338    spatial: &[usize],
339    strides: &[usize],
340    padding_mode: GridSamplePaddingMode,
341    n: usize,
342    c: usize,
343    out_prod: usize,
344    dtype: &DType,
345) -> Result<Tensor> {
346    // ONNX uses np.rint (round to nearest even); Tensor::round() implements this.
347    let rounded: Vec<Tensor> = coords.iter().map(|c| c.round()?.cast(DType::Int32)).collect::<Result<_>>()?;
348    let (flat_idx, valid_mask) = build_flat_index(&rounded, spatial, strides, padding_mode)?;
349    gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)
350}
351
352#[allow(clippy::too_many_arguments)]
353fn interpolate_linear(
354    x_flat: &Tensor,
355    coords: &[Tensor],
356    spatial: &[usize],
357    strides: &[usize],
358    padding_mode: GridSamplePaddingMode,
359    n: usize,
360    c: usize,
361    out_prod: usize,
362    dtype: &DType,
363) -> Result<Tensor> {
364    let n_spatial = coords.len();
365    let floors: Vec<Tensor> = coords.iter().map(|c| c.floor()).collect::<Result<_>>()?;
366    let fracs: Vec<Tensor> = coords.iter().zip(&floors).map(|(c, f)| c.try_sub(f)).collect::<Result<_>>()?;
367
368    // 2^n_spatial corners
369    let n_combos = 1usize << n_spatial;
370    let mut result = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
371
372    for combo in 0..n_combos {
373        let mut weight = Tensor::const_(ConstValue::Float(1.0), dtype.clone());
374        let mut corner_indices: Vec<Tensor> = Vec::with_capacity(n_spatial);
375
376        for i in 0..n_spatial {
377            let use_ceil = (combo >> i) & 1 == 1;
378            let idx_f =
379                if use_ceil { floors[i].try_add(&Tensor::const_(1.0, dtype.clone()))? } else { floors[i].clone() };
380            let w = if use_ceil { fracs[i].clone() } else { Tensor::const_(1.0, dtype.clone()).try_sub(&fracs[i])? };
381            weight = weight.try_mul(&w)?;
382            corner_indices.push(idx_f.cast(DType::Int32)?);
383        }
384
385        let (flat_idx, valid_mask) = build_flat_index(&corner_indices, spatial, strides, padding_mode)?;
386        let gathered = gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)?;
387
388        let weight = weight.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
389        result = result.try_add(&gathered.try_mul(&weight)?)?;
390    }
391
392    Ok(result)
393}
394
395#[allow(clippy::too_many_arguments)]
396fn interpolate_cubic(
397    x_flat: &Tensor,
398    coords: &[Tensor],
399    spatial: &[usize],
400    strides: &[usize],
401    padding_mode: GridSamplePaddingMode,
402    n: usize,
403    c: usize,
404    out_prod: usize,
405    dtype: &DType,
406) -> Result<Tensor> {
407    let n_spatial = coords.len();
408    let floors: Vec<Tensor> = coords.iter().map(|c| c.floor()).collect::<Result<_>>()?;
409    let fracs: Vec<Tensor> = coords.iter().zip(&floors).map(|(c, f)| c.try_sub(f)).collect::<Result<_>>()?;
410
411    // Cubic coefficients for each spatial dim (4 weights per dim)
412    let coeffs: Vec<[Tensor; 4]> = fracs.iter().map(|s| gs_cubic_coeffs(s, -0.75, dtype)).collect::<Result<_>>()?;
413
414    // 4^n_spatial combinations
415    let n_combos = 4usize.pow(n_spatial as u32);
416    let mut result = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
417
418    for combo in 0..n_combos {
419        let mut weight = Tensor::const_(ConstValue::Float(1.0), dtype.clone());
420        let mut corner_indices: Vec<Tensor> = Vec::with_capacity(n_spatial);
421
422        for i in 0..n_spatial {
423            let offset_idx = (combo / 4usize.pow(i as u32)) % 4;
424            let offset = offset_idx as f64 - 1.0; // -1, 0, 1, 2
425
426            let idx_f = floors[i].try_add(&Tensor::const_(offset, dtype.clone()))?;
427            weight = weight.try_mul(&coeffs[i][offset_idx])?;
428            corner_indices.push(idx_f.cast(DType::Int32)?);
429        }
430
431        let (flat_idx, valid_mask) = build_flat_index(&corner_indices, spatial, strides, padding_mode)?;
432        let gathered = gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)?;
433
434        let weight = weight.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
435        result = result.try_add(&gathered.try_mul(&weight)?)?;
436    }
437
438    Ok(result)
439}
440
441/// Cubic interpolation coefficients (Keys convolution, alpha = -0.75).
442/// Returns weights for offsets [-1, 0, 1, 2] relative to floor(x).
443fn gs_cubic_coeffs(s: &Tensor, a: f64, dtype: &DType) -> Result<[Tensor; 4]> {
444    let one = Tensor::const_(1.0, dtype.clone());
445    let two = Tensor::const_(2.0, dtype.clone());
446
447    // c0: |x| = s+1 (far neighbor)
448    // c0 = ((a*(s+1) - 5a)*(s+1) + 8a)*(s+1) - 4a
449    let sp1 = s.try_add(&one)?;
450    let c0 = sp1
451        .try_mul(&Tensor::const_(a, dtype.clone()))?
452        .try_sub(&Tensor::const_(5.0 * a, dtype.clone()))?
453        .try_mul(&sp1)?
454        .try_add(&Tensor::const_(8.0 * a, dtype.clone()))?
455        .try_mul(&sp1)?
456        .try_sub(&Tensor::const_(4.0 * a, dtype.clone()))?;
457
458    // c1: |x| = s (center-left)
459    // c1 = ((a+2)*s - (a+3))*s*s + 1
460    let c1 = s
461        .try_mul(&Tensor::const_(a + 2.0, dtype.clone()))?
462        .try_sub(&Tensor::const_(a + 3.0, dtype.clone()))?
463        .try_mul(s)?
464        .try_mul(s)?
465        .try_add(&one)?;
466
467    // c2: |x| = 1-s (center-right)
468    let sm1 = one.try_sub(s)?;
469    let c2 = sm1
470        .try_mul(&Tensor::const_(a + 2.0, dtype.clone()))?
471        .try_sub(&Tensor::const_(a + 3.0, dtype.clone()))?
472        .try_mul(&sm1)?
473        .try_mul(&sm1)?
474        .try_add(&Tensor::const_(1.0, dtype.clone()))?;
475
476    // c3: |x| = 2-s (far neighbor)
477    let sm2 = two.try_sub(s)?;
478    let c3 = sm2
479        .try_mul(&Tensor::const_(a, dtype.clone()))?
480        .try_sub(&Tensor::const_(5.0 * a, dtype.clone()))?
481        .try_mul(&sm2)?
482        .try_add(&Tensor::const_(8.0 * a, dtype.clone()))?
483        .try_mul(&sm2)?
484        .try_sub(&Tensor::const_(4.0 * a, dtype.clone()))?;
485
486    Ok([c0, c1, c2, c3])
487}