Skip to main content

svod_tensor/nn/
resize.rs

1//! Resize operations (ONNX Resize operator building block).
2
3use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::ConstValue;
7
8use super::{AspectRatioPolicy, CoordinateTransformMode, NearestMode, ResizeMode};
9use crate::Tensor;
10use crate::error::UOpSnafu;
11
// Module-local shorthand for the crate-wide result type.
type Result<T> = crate::Result<T>;
13
14#[bon]
15impl Tensor {
16    /// Resize a tensor using interpolation (ONNX Resize operator).
17    ///
18    /// Supports nearest, linear, and cubic interpolation modes with various
19    /// coordinate transformation modes. Either `scales` or `sizes` must be
20    /// provided to specify the target dimensions.
21    ///
22    /// # Examples
23    ///
24    /// Nearest-mode 2x upscale via `scales`:
25    ///
26    /// ```
27    /// # use svod_tensor::Tensor;
28    /// # use ndarray::Array4;
29    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
30    /// let mut y = x.resize().scales(&[1.0, 1.0, 2.0, 2.0]).call().unwrap();
31    /// y.realize().unwrap();
32    /// let shape: Vec<usize> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
33    /// assert_eq!(shape, vec![1, 1, 4, 4]);
34    /// assert!(y.as_vec::<f32>().unwrap().iter().all(|&v| (v - 1.0).abs() < 1e-5));
35    /// ```
36    ///
37    /// Resize to explicit output `sizes`:
38    ///
39    /// ```
40    /// # use svod_tensor::Tensor;
41    /// # use ndarray::Array4;
42    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
43    /// let mut y = x.resize().sizes(&[1, 1, 6, 6]).call().unwrap();
44    /// y.realize().unwrap();
45    /// let shape: Vec<usize> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
46    /// assert_eq!(shape, vec![1, 1, 6, 6]);
47    /// assert!(y.as_vec::<f32>().unwrap().iter().all(|&v| (v - 1.0).abs() < 1e-5));
48    /// ```
49    ///
50    /// Linear interpolation mode:
51    ///
52    /// ```
53    /// # use svod_tensor::Tensor;
54    /// # use svod_tensor::nn::ResizeMode;
55    /// # use ndarray::Array4;
56    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 2, 2), 1.0f32));
57    /// let mut y = x.resize()
58    ///     .scales(&[1.0, 1.0, 2.0, 2.0])
59    ///     .mode(ResizeMode::Linear)
60    ///     .call()
61    ///     .unwrap();
62    /// y.realize().unwrap();
63    /// let shape: Vec<usize> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
64    /// assert_eq!(shape, vec![1, 1, 4, 4]);
65    /// assert!(y.as_vec::<f32>().unwrap().iter().all(|&v| (v - 1.0).abs() < 1e-5));
66    /// ```
67    // Tinygrad onnx.py:789-890
68    #[builder]
69    #[allow(clippy::too_many_arguments)]
70    pub fn resize(
71        &self,
72        scales: Option<&[f64]>,
73        sizes: Option<&[usize]>,
74        #[builder(default)] mode: ResizeMode,
75        #[builder(default)] coordinate_transformation_mode: CoordinateTransformMode,
76        #[builder(default)] nearest_mode: NearestMode,
77        #[builder(default = -0.75)] cubic_coeff_a: f64,
78        #[builder(default = false)] exclude_outside: bool,
79        #[builder(default = false)] antialias: bool,
80        #[builder(default)] keep_aspect_ratio_policy: AspectRatioPolicy,
81        axes: Option<&[usize]>,
82        roi: Option<&[f64]>,
83        #[builder(default = 0.0)] extrapolation_value: f64,
84    ) -> Result<Tensor> {
85        let ndim = self.ndim()?;
86        let shape = self.shape()?;
87        // TODO(symbolic-batch): this validates *every* dim is concrete even
88        // though only the `axes` dims are read below (line 107) and only the
89        // non-axes dims are used to build expand shapes (lines 257, 274). For
90        // a symbolic batch with concrete spatial dims (e.g. a JIT input shrunk
91        // to a bound `b`), this fails unnecessarily. Narrow this to the axes
92        // dims, and have the expand-shape construction below carry SInt rather
93        // than going through `usize`. The result is the only thing using
94        // `_shape_dims` is its discard binding — drop it once the rest of the
95        // function stops needing fully-concrete shapes.
96        let _shape_dims = svod_ir::shape::to_vec_usize(&shape).context(UOpSnafu)?;
97
98        let axes: Vec<usize> = axes.map(|a| a.to_vec()).unwrap_or_else(|| (0..ndim).collect());
99
100        // Permute: put target axes last
101        let non_axes: Vec<usize> = (0..ndim).filter(|d| !axes.contains(d)).collect();
102        let perm: Vec<isize> = non_axes.iter().chain(axes.iter()).map(|&d| d as isize).collect();
103        let inv_perm = argsort_usize(&perm.iter().map(|&p| p as usize).collect::<Vec<_>>());
104        let inv_perm_i: Vec<isize> = inv_perm.iter().map(|&i| i as isize).collect();
105
106        let mut x = if perm.iter().enumerate().all(|(i, &p)| p == i as isize) {
107            self.clone()
108        } else {
109            self.try_permute(&perm)?
110        };
111
112        // Input spatial dimensions (last len(axes) dims of permuted x)
113        let x_shape = x.shape()?;
114        // TODO(symbolic-batch): same issue as above — only `input_shape` (the
115        // trailing spatial dims) is actually used; the non-axes prefix is
116        // copied into `expand_shape` later and never compared to a `usize`,
117        // so it could stay as `SInt` and admit symbolic dims.
118        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
119        let n_spatial = axes.len();
120        let input_shape: Vec<usize> = x_dims[ndim - n_spatial..].to_vec();
121
122        // Filter scales/sizes to spatial dims only
123        let scales_trimmed: Option<Vec<f64>> = scales.map(|s| s[s.len().saturating_sub(n_spatial)..].to_vec());
124        let sizes_trimmed: Option<Vec<usize>> = sizes.map(|s| s[s.len().saturating_sub(n_spatial)..].to_vec());
125
126        // Compute output sizes and scales
127        let (output_sizes, final_scales) = if let Some(mut sz) = sizes_trimmed {
128            if keep_aspect_ratio_policy == AspectRatioPolicy::NotLarger
129                || keep_aspect_ratio_policy == AspectRatioPolicy::NotSmaller
130            {
131                let scale_fn: fn(f64, f64) -> f64 =
132                    if keep_aspect_ratio_policy == AspectRatioPolicy::NotLarger { f64::min } else { f64::max };
133                let mut scale = f64::NAN;
134                for (s, &inp) in sz.iter().zip(&input_shape) {
135                    let s_val = *s as f64 / inp as f64;
136                    if scale.is_nan() {
137                        scale = s_val;
138                    } else {
139                        scale = scale_fn(scale, s_val);
140                    }
141                }
142                sz = input_shape.iter().map(|&sh| (scale * sh as f64 + 0.5) as usize).collect();
143                let sc = vec![scale; n_spatial];
144                (sz, sc)
145            } else {
146                let sc: Vec<f64> = sz.iter().zip(&input_shape).map(|(&s, &sh)| s as f64 / sh as f64).collect();
147                (sz, sc)
148            }
149        } else if let Some(sc) = scales_trimmed {
150            let sz: Vec<usize> = sc.iter().zip(&input_shape).map(|(&s, &sh)| (s * sh as f64) as usize).collect();
151            (sz, sc)
152        } else {
153            return Err(crate::error::Error::IrConstruction {
154                details: "resize: either scales or sizes must be provided".into(),
155            });
156        };
157
158        // Early exit if no resize needed
159        if output_sizes.iter().zip(&input_shape).all(|(&o, &i)| o == i) {
160            return if perm.iter().enumerate().any(|(i, &p)| p != i as isize) {
161                x.try_permute(&inv_perm_i)
162            } else {
163                Ok(x)
164            };
165        }
166
167        // Extract per-spatial-dim ROI (start, end) pairs
168        let roi_pairs: Vec<(f64, f64)> = if let Some(roi) = roi {
169            let half = roi.len() / 2;
170            let starts = &roi[half - n_spatial..half];
171            let ends = &roi[roi.len() - n_spatial..];
172            starts.iter().zip(ends).map(|(&s, &e)| (s, e)).collect()
173        } else {
174            vec![(0.0, 1.0); n_spatial]
175        };
176
177        // Build coordinate transforms for each spatial dim
178        let dtype = x.uop().dtype();
179        let indexes: Vec<Tensor> = input_shape
180            .iter()
181            .zip(&output_sizes)
182            .zip(&final_scales)
183            .zip(&roi_pairs)
184            .map(|(((&inp_sz, &out_sz), &scale), &(roi_start, roi_end))| {
185                apply_coordinate_transform(
186                    inp_sz,
187                    out_sz,
188                    scale,
189                    coordinate_transformation_mode,
190                    &dtype,
191                    roi_start,
192                    roi_end,
193                )
194            })
195            .collect::<Result<_>>()?;
196
197        // Clip for nearest/linear modes (skip for tf_crop_and_resize — uses extrapolation instead)
198        let is_tf_crop = coordinate_transformation_mode == CoordinateTransformMode::TfCropAndResize;
199        let indexes: Vec<Tensor> = if !is_tf_crop && matches!(mode, ResizeMode::Nearest | ResizeMode::Linear) {
200            indexes
201                .into_iter()
202                .zip(&input_shape)
203                .map(|(idx, &sz)| {
204                    let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
205                    let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
206                    idx.clamp().min(&zero).max(&max_val).call()
207                })
208                .collect::<Result<Vec<_>>>()?
209        } else {
210            indexes
211        };
212
213        // For tf_crop_and_resize, build a validity mask from unclipped indexes
214        let validity_mask: Option<Vec<Tensor>> = if is_tf_crop {
215            Some(
216                indexes
217                    .iter()
218                    .zip(&input_shape)
219                    .map(|(idx, &sz)| {
220                        let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
221                        let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
222                        idx.try_ge(&zero)?.bitwise_and(&idx.try_le(&max_val)?)
223                    })
224                    .collect::<Result<Vec<_>>>()?,
225            )
226        } else {
227            None
228        };
229
230        // For tf_crop_and_resize, clip indexes before gather (to avoid OOB)
231        let indexes: Vec<Tensor> = if is_tf_crop {
232            indexes
233                .into_iter()
234                .zip(&input_shape)
235                .map(|(idx, &sz)| {
236                    let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
237                    let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
238                    idx.clamp().min(&zero).max(&max_val).call()
239                })
240                .collect::<Result<Vec<_>>>()?
241        } else {
242            indexes
243        };
244
245        if mode == ResizeMode::Nearest {
246            let int_indexes: Vec<Tensor> = indexes
247                .into_iter()
248                .map(|idx| {
249                    let rounded = match nearest_mode {
250                        NearestMode::RoundPreferFloor => idx.try_sub(&Tensor::const_(0.5f64, dtype.clone()))?.ceil()?,
251                        NearestMode::RoundPreferCeil => idx.try_add(&Tensor::const_(0.5f64, dtype.clone()))?.floor()?,
252                        NearestMode::Floor => idx.floor()?,
253                        NearestMode::Ceil => idx.ceil()?,
254                    };
255                    rounded.cast(DType::Int32)
256                })
257                .collect::<Result<Vec<_>>>()?;
258
259            // Sequential gather per spatial dim
260            for (i, idx) in int_indexes.iter().enumerate() {
261                let dim = (ndim - n_spatial + i) as isize;
262                let cur_shape = x.shape()?;
263                // TODO(symbolic-batch): `cur_dims` is built only to feed
264                // `expand_shape` below, which forces a `usize → isize` cast on
265                // every dim. For a symbolic prefix this loses information and
266                // aborts here. `try_expand` would need to accept `SInt` (it
267                // already does internally) so we can pass the symbolic dims
268                // through; then we'd substitute `out_sz` at the axis position
269                // and keep the rest of the shape as-is.
270                let cur_dims = svod_ir::shape::to_vec_usize(&cur_shape).context(UOpSnafu)?;
271                let out_sz = output_sizes[i];
272
273                let mut idx_shape = vec![1isize; ndim];
274                idx_shape[ndim - n_spatial + i] = out_sz as isize;
275                let idx_reshaped = idx.try_reshape(&idx_shape)?;
276
277                let mut expand_shape: Vec<isize> = cur_dims.iter().map(|&d| d as isize).collect();
278                expand_shape[ndim - n_spatial + i] = out_sz as isize;
279                let idx_expanded = idx_reshaped.try_expand(&expand_shape)?;
280
281                x = x.gather(dim, &idx_expanded)?;
282            }
283        } else if mode == ResizeMode::Linear {
284            let mut expand = x_dims.clone();
285            for (i, &out_sz) in output_sizes.iter().enumerate() {
286                let dim_pos = ndim - n_spatial + i;
287                let scale = final_scales[i];
288                let input_sz = input_shape[i];
289                let index = &indexes[i];
290
291                let mut reshape = vec![1isize; ndim];
292                reshape[dim_pos] = out_sz as isize;
293                expand[dim_pos] = out_sz;
294                let expand_i: Vec<isize> = expand.iter().map(|&d| d as isize).collect();
295
296                if antialias && scale < 1.0 {
297                    x = interpolate_antialias_linear(&x, index, dim_pos, input_sz, scale, &reshape, &expand_i, &dtype)?;
298                } else {
299                    let low = index.floor()?.cast(DType::Int32)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
300                    let high = index.ceil()?.cast(DType::Int32)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
301                    let perc = index.try_sub(&index.floor()?)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
302
303                    let dim_i = dim_pos as isize;
304                    let gathered_low = x.gather(dim_i, &low)?;
305                    let gathered_high = x.gather(dim_i, &high)?;
306                    x = gathered_low.lerp(&gathered_high, &perc)?;
307                }
308            }
309        } else if mode == ResizeMode::Cubic {
310            let a = cubic_coeff_a;
311            let mut expand = x_dims.clone();
312            for (i, &out_sz) in output_sizes.iter().enumerate() {
313                let dim_pos = ndim - n_spatial + i;
314                let scale = final_scales[i];
315                let input_sz = input_shape[i];
316                let index = &indexes[i];
317
318                let mut reshape = vec![1isize; ndim];
319                reshape[dim_pos] = out_sz as isize;
320                expand[dim_pos] = out_sz;
321                let expand_i: Vec<isize> = expand.iter().map(|&d| d as isize).collect();
322
323                if antialias && scale < 1.0 {
324                    x = interpolate_antialias_cubic(
325                        &x, index, dim_pos, input_sz, scale, a, &reshape, &expand_i, &dtype,
326                    )?;
327                } else {
328                    let p = index.floor()?.cast(DType::Int32)?;
329                    let ratio = index.try_sub(&index.floor()?)?;
330
331                    let one = Tensor::const_(ConstValue::Int(1), DType::Int32);
332                    let two = Tensor::const_(ConstValue::Int(2), DType::Int32);
333                    let idx0 = p.try_sub(&one)?;
334                    let idx1 = p.clone();
335                    let idx2 = p.try_add(&one)?;
336                    let idx3 = p.try_add(&two)?;
337
338                    let r1 = ratio.try_add(&Tensor::const_(1.0f64, dtype.clone()))?;
339                    let c0 = poly_n(&r1, &[a, -5.0 * a, 8.0 * a, -4.0 * a], &dtype)?;
340                    let c1 = poly_n(&ratio, &[a + 2.0, -(a + 3.0), 0.0, 1.0], &dtype)?;
341                    let r_neg1 = Tensor::const_(1.0f64, dtype.clone()).try_sub(&ratio)?;
342                    let c2 = poly_n(&r_neg1, &[a + 2.0, -(a + 3.0), 0.0, 1.0], &dtype)?;
343                    let r_neg2 = Tensor::const_(2.0f64, dtype.clone()).try_sub(&ratio)?;
344                    let c3 = poly_n(&r_neg2, &[a, -5.0 * a, 8.0 * a, -4.0 * a], &dtype)?;
345
346                    let (mut c0, mut c1, mut c2, mut c3) = (c0, c1, c2, c3);
347                    if exclude_outside {
348                        let max_idx = Tensor::const_(ConstValue::Int(input_sz as i64), DType::Int32);
349                        let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
350                        let zero_f = Tensor::const_(0.0f64, dtype.clone());
351                        let valid0 = idx0.try_ge(&zero_i)?.try_mul(&idx0.try_lt(&max_idx)?)?;
352                        let valid1 = idx1.try_ge(&zero_i)?.try_mul(&idx1.try_lt(&max_idx)?)?;
353                        let valid2 = idx2.try_ge(&zero_i)?.try_mul(&idx2.try_lt(&max_idx)?)?;
354                        let valid3 = idx3.try_ge(&zero_i)?.try_mul(&idx3.try_lt(&max_idx)?)?;
355                        c0 = c0.where_(&valid0, &zero_f)?;
356                        c1 = c1.where_(&valid1, &zero_f)?;
357                        c2 = c2.where_(&valid2, &zero_f)?;
358                        c3 = c3.where_(&valid3, &zero_f)?;
359                        let total = c0.try_add(&c1)?.try_add(&c2)?.try_add(&c3)?;
360                        let eps = Tensor::const_(1e-9f64, dtype.clone());
361                        let total_safe = total.try_add(&eps)?;
362                        c0 = c0.try_div(&total_safe)?;
363                        c1 = c1.try_div(&total_safe)?;
364                        c2 = c2.try_div(&total_safe)?;
365                        c3 = c3.try_div(&total_safe)?;
366                    }
367
368                    let max_val = Tensor::const_(ConstValue::Int((input_sz - 1) as i64), DType::Int32);
369                    let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
370                    let clip = |t: &Tensor| -> Result<Tensor> {
371                        t.clamp().min(&zero_i).max(&max_val).call()?.try_reshape(&reshape)?.try_expand(&expand_i)
372                    };
373                    let ei0 = clip(&idx0)?;
374                    let ei1 = clip(&idx1)?;
375                    let ei2 = clip(&idx2)?;
376                    let ei3 = clip(&idx3)?;
377
378                    let ec = |c: Tensor| -> Result<Tensor> { c.try_reshape(&reshape)?.try_expand(&expand_i) };
379                    let ec0 = ec(c0)?;
380                    let ec1 = ec(c1)?;
381                    let ec2 = ec(c2)?;
382                    let ec3 = ec(c3)?;
383
384                    let dim_i = dim_pos as isize;
385                    let v0 = x.gather(dim_i, &ei0)?.try_mul(&ec0)?;
386                    let v1 = x.gather(dim_i, &ei1)?.try_mul(&ec1)?;
387                    let v2 = x.gather(dim_i, &ei2)?.try_mul(&ec2)?;
388                    let v3 = x.gather(dim_i, &ei3)?.try_mul(&ec3)?;
389                    x = v0.try_add(&v1)?.try_add(&v2)?.try_add(&v3)?;
390                }
391            }
392        }
393
394        // Apply extrapolation for tf_crop_and_resize: out-of-bounds → extrapolation_value
395        if let Some(masks) = validity_mask {
396            let extrap = Tensor::const_(ConstValue::Float(extrapolation_value), dtype.clone());
397            let x_shape = x.shape()?;
398            let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
399            let expand_shape: Vec<isize> = x_dims.iter().map(|&d| d as isize).collect();
400
401            // Each mask_i is 1D [out_sz_i]; reshape to [1,..,out_sz_i,..,1] and broadcast
402            let mut combined: Option<Tensor> = None;
403            for (i, mask) in masks.into_iter().enumerate() {
404                let mut shape = vec![1isize; ndim];
405                shape[ndim - n_spatial + i] = output_sizes[i] as isize;
406                let broad = mask.try_reshape(&shape)?.try_expand(&expand_shape)?;
407                combined = Some(match combined {
408                    Some(c) => c.bitwise_and(&broad)?,
409                    None => broad,
410                });
411            }
412            if let Some(valid) = combined {
413                x = x.where_(&valid, &extrap)?;
414            }
415        }
416
417        // Permute back
418        if perm.iter().enumerate().any(|(i, &p)| p != i as isize) { x.try_permute(&inv_perm_i) } else { Ok(x) }
419    }
420}
421
422/// Coordinate transform for resize operations.
423///
424/// Computes in f64 to avoid precision loss from IR constant folding
425/// (which uses mixed f64/f32 arithmetic), then casts to target dtype.
426fn apply_coordinate_transform(
427    input_sz: usize,
428    output_sz: usize,
429    scale: f64,
430    mode: CoordinateTransformMode,
431    dtype: &DType,
432    roi_start: f64,
433    roi_end: f64,
434) -> Result<Tensor> {
435    let f64_dt = DType::Float64;
436    let index = Tensor::arange(0, Some(output_sz as i64), None)?.cast(f64_dt.clone())?;
437    let result = match mode {
438        CoordinateTransformMode::HalfPixel => {
439            let half = Tensor::const_(0.5f64, f64_dt.clone());
440            index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?.try_sub(&half)?
441        }
442        CoordinateTransformMode::AlignCorners => {
443            // ONNX reference uses float output_width = scale * input_sz, not integer output_sz.
444            // This matters when scale * input_sz is non-integer (e.g. 0.8 * 4 = 3.2 vs int 3).
445            let output_width = scale * input_sz as f64;
446            if output_width == 1.0 {
447                Tensor::const_(0.0f64, f64_dt)
448            } else {
449                let ratio = (input_sz as f64 - 1.0) / (output_width - 1.0);
450                index.try_mul(&Tensor::const_(ratio, f64_dt))?
451            }
452        }
453        CoordinateTransformMode::Asymmetric => index.try_div(&Tensor::const_(scale, f64_dt))?,
454        CoordinateTransformMode::PytorchHalfPixel => {
455            let output_width = scale * input_sz as f64;
456            if output_width == 1.0 {
457                Tensor::const_(0.0f64, f64_dt)
458            } else {
459                let half = Tensor::const_(0.5f64, f64_dt.clone());
460                index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?.try_sub(&half)?
461            }
462        }
463        CoordinateTransformMode::HalfPixelSymmetric => {
464            let output_dim_scaled = input_sz as f64 * scale;
465            let offset = (input_sz as f64 / 2.0) * (1.0 - output_sz as f64 / output_dim_scaled);
466            let half = Tensor::const_(0.5f64, f64_dt.clone());
467            let off_t = Tensor::const_(offset, f64_dt.clone());
468            off_t.try_add(&index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?)?.try_sub(&half)?
469        }
470        CoordinateTransformMode::TfCropAndResize => {
471            let len = (input_sz as f64) - 1.0;
472            let output_width = scale * input_sz as f64;
473            if output_width == 1.0 {
474                Tensor::const_((roi_end - roi_start) * len / 2.0 + roi_start * len, f64_dt)
475            } else {
476                let stride = (roi_end - roi_start) * len / (output_width - 1.0);
477                let offset = roi_start * len;
478                index.try_mul(&Tensor::const_(stride, f64_dt.clone()))?.try_add(&Tensor::const_(offset, f64_dt))?
479            }
480        }
481    };
482    result.cast(dtype.clone())
483}
484
485/// Horner's method for polynomial evaluation.
486fn poly_n(x: &Tensor, coeffs: &[f64], dtype: &DType) -> Result<Tensor> {
487    coeffs.iter().try_fold(Tensor::const_(0.0f64, dtype.clone()), |acc, &c| {
488        acc.try_mul(x)?.try_add(&Tensor::const_(c, dtype.clone()))
489    })
490}
491
492/// Antialias cubic interpolation for one spatial dimension.
493/// When downsampling (scale < 1), widens the kernel by 1/scale to prevent aliasing.
494/// ONNX ref: _cubic_coeffs_antialias in op_resize.py
495#[allow(clippy::too_many_arguments)]
496fn interpolate_antialias_cubic(
497    x: &Tensor,
498    index: &Tensor,
499    dim_pos: usize,
500    input_sz: usize,
501    scale: f64,
502    a: f64,
503    reshape: &[isize],
504    expand_i: &[isize],
505    dtype: &DType,
506) -> Result<Tensor> {
507    let i_start = (-2.0_f64 / scale).floor() as i32 + 1;
508    let i_end = 2 - i_start;
509    let n_taps = (i_end - i_start) as usize;
510
511    let floored = index.floor()?;
512    let p = floored.cast(DType::Int32)?;
513    let ratio = index.try_sub(&floored)?;
514
515    let one = Tensor::const_(1.0f64, dtype.clone());
516    let two = Tensor::const_(2.0f64, dtype.clone());
517    let zero_f = Tensor::const_(0.0f64, dtype.clone());
518
519    let mut coeffs = Vec::with_capacity(n_taps);
520    for tap in i_start..i_end {
521        let arg = ratio
522            .try_mul(&Tensor::const_(-scale, dtype.clone()))?
523            .try_add(&Tensor::const_(scale * tap as f64, dtype.clone()))?;
524        let abs_arg = arg.try_abs()?;
525        let c_inner = poly_n(&abs_arg, &[a + 2.0, -(a + 3.0), 0.0, 1.0], dtype)?;
526        let c_outer = poly_n(&abs_arg, &[a, -5.0 * a, 8.0 * a, -4.0 * a], dtype)?;
527        let mask_outer = abs_arg.try_lt(&two)?;
528        let c = c_outer.where_(&mask_outer, &zero_f)?;
529        let mask_inner = abs_arg.try_le(&one)?;
530        let c = c_inner.where_(&mask_inner, &c)?;
531        coeffs.push(c);
532    }
533
534    normalize_and_gather(x, coeffs, &p, i_start, dim_pos, input_sz, reshape, expand_i, dtype)
535}
536
537/// Antialias linear interpolation for one spatial dimension.
538/// ONNX ref: _linear_coeffs_antialias in op_resize.py
539#[allow(clippy::too_many_arguments)]
540fn interpolate_antialias_linear(
541    x: &Tensor,
542    index: &Tensor,
543    dim_pos: usize,
544    input_sz: usize,
545    scale: f64,
546    reshape: &[isize],
547    expand_i: &[isize],
548    dtype: &DType,
549) -> Result<Tensor> {
550    let start = (-1.0_f64 / scale).floor() as i32 + 1;
551    let footprint = (2 - 2 * start) as usize;
552
553    let floored = index.floor()?;
554    let p = floored.cast(DType::Int32)?;
555    let ratio = index.try_sub(&floored)?;
556
557    let one = Tensor::const_(1.0f64, dtype.clone());
558    let zero_f = Tensor::const_(0.0f64, dtype.clone());
559
560    let mut coeffs = Vec::with_capacity(footprint);
561    for j in 0..footprint {
562        let tap = start + j as i32;
563        let arg = ratio
564            .try_mul(&Tensor::const_(-scale, dtype.clone()))?
565            .try_add(&Tensor::const_(scale * tap as f64, dtype.clone()))?;
566        let abs_arg = arg.try_abs()?;
567        let c = one.try_sub(&abs_arg)?;
568        let c = c.clamp().min(&zero_f).max(&one).call()?;
569        coeffs.push(c);
570    }
571
572    normalize_and_gather(x, coeffs, &p, start, dim_pos, input_sz, reshape, expand_i, dtype)
573}
574
575/// Normalize coefficients to sum to 1, then gather and accumulate weighted values.
576/// Shared by antialias cubic and linear interpolation.
577#[allow(clippy::too_many_arguments)]
578fn normalize_and_gather(
579    x: &Tensor,
580    mut coeffs: Vec<Tensor>,
581    p: &Tensor,
582    tap_start: i32,
583    dim_pos: usize,
584    input_sz: usize,
585    reshape: &[isize],
586    expand_i: &[isize],
587    dtype: &DType,
588) -> Result<Tensor> {
589    let mut total = coeffs[0].clone();
590    for c in &coeffs[1..] {
591        total = total.try_add(c)?;
592    }
593    let eps = Tensor::const_(1e-9f64, dtype.clone());
594    let total_safe = total.try_add(&eps)?;
595    for c in &mut coeffs {
596        *c = c.try_div(&total_safe)?;
597    }
598
599    let max_val = Tensor::const_(ConstValue::Int((input_sz - 1) as i64), DType::Int32);
600    let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
601    let dim_i = dim_pos as isize;
602
603    let mut result: Option<Tensor> = None;
604    for (j, c) in coeffs.into_iter().enumerate() {
605        let tap = tap_start + j as i32;
606        let idx = p.try_add(&Tensor::const_(ConstValue::Int(tap as i64), DType::Int32))?;
607        let idx_clipped = idx.clamp().min(&zero_i).max(&max_val).call()?.try_reshape(reshape)?.try_expand(expand_i)?;
608        let c_expanded = c.try_reshape(reshape)?.try_expand(expand_i)?;
609        let val = x.gather(dim_i, &idx_clipped)?.try_mul(&c_expanded)?;
610        result = Some(match result {
611            Some(acc) => acc.try_add(&val)?,
612            None => val,
613        });
614    }
615    Ok(result.unwrap())
616}
617
/// Returns the indices that would sort `slice` in ascending order.
///
/// The sort is stable, so equal values keep their original relative order.
/// For a permutation this yields its inverse.
fn argsort_usize(slice: &[usize]) -> Vec<usize> {
    let mut order = Vec::with_capacity(slice.len());
    order.extend(0..slice.len());
    order.sort_by(|&lhs, &rhs| slice[lhs].cmp(&slice[rhs]));
    order
}