Skip to main content

tract_core/ops/nn/
resize.rs

1use crate::internal::*;
2use crate::ops::array::Tile;
3
/// Strategy used to map an output index back onto the input axis.
///
/// These are the four ONNX `coordinate_transformation_mode` values that have a
/// well-defined inverse without an input ROI.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CoordTransformer {
    /// Serialized as `"half_pixel"`.
    HalfPixel,
    /// Serialized as `"align_corners"`.
    AlignCorners,
    /// Serialized as `"asymmetric"`.
    Asymmetric,
    /// Serialized as `"pytorch_half_pixel"`.
    PytorchHalfPixel,
}
13
14impl CoordTransformer {
15    pub fn transform(&self, x_out: usize, scale: f32, len_in: usize, len_out: usize) -> f32 {
16        match self {
17            CoordTransformer::HalfPixel => (x_out as f32 + 0.5) / scale - 0.5,
18            CoordTransformer::AlignCorners => {
19                let output_width = scale * len_in as f32;
20                if output_width == 1.0 {
21                    0.0
22                } else {
23                    (x_out as f32 * (len_in as f32 - 1.0)) / (output_width - 1.0)
24                }
25            }
26            CoordTransformer::Asymmetric => (x_out as f32) / scale,
27            CoordTransformer::PytorchHalfPixel => {
28                if len_out > 1 {
29                    (x_out as f32 + 0.5) / scale - 0.5
30                } else {
31                    0.0
32                }
33            }
34        }
35    }
36
37    pub fn as_str(&self) -> &'static str {
38        match self {
39            CoordTransformer::HalfPixel => "half_pixel",
40            CoordTransformer::AlignCorners => "align_corners",
41            CoordTransformer::Asymmetric => "asymmetric",
42            CoordTransformer::PytorchHalfPixel => "pytorch_half_pixel",
43        }
44    }
45
46    pub fn parse(s: &str) -> TractResult<Self> {
47        Ok(match s {
48            "half_pixel" => CoordTransformer::HalfPixel,
49            "align_corners" => CoordTransformer::AlignCorners,
50            "asymmetric" => CoordTransformer::Asymmetric,
51            "pytorch_half_pixel" => CoordTransformer::PytorchHalfPixel,
52            s => bail!("coordinate_transformation_mode: {s}"),
53        })
54    }
55}
56
/// Interpolation kernel applied along each resized axis.
///
/// `Linear` and `Nearest` share a 2-tap code path; `Cubic` uses a 4-tap kernel
/// with the standard `a = -0.75` coefficient.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Interpolator {
    /// Serialized as `"linear"`.
    Linear,
    /// Serialized as `"nearest"`.
    Nearest,
    /// Serialized as `"cubic"`.
    Cubic,
}
65
66impl Interpolator {
67    pub fn as_str(&self) -> &'static str {
68        match self {
69            Interpolator::Linear => "linear",
70            Interpolator::Nearest => "nearest",
71            Interpolator::Cubic => "cubic",
72        }
73    }
74
75    pub fn parse(s: &str) -> TractResult<Self> {
76        Ok(match s {
77            "linear" => Interpolator::Linear,
78            "nearest" => Interpolator::Nearest,
79            "cubic" => Interpolator::Cubic,
80            s => bail!("mode: {s}"),
81        })
82    }
83}
84
/// Cubic convolution kernel with free coefficient `a`, evaluated at offset `s`.
///
/// The kernel is even in `s`, piecewise cubic on `|s| <= 1` and `1 < |s| <= 2`,
/// and zero everywhere else (including for a NaN offset).
pub fn cubic_kernel(s: f32, a: f32) -> f32 {
    match s.abs() {
        // Inner lobe.
        d if d <= 1.0 => (a + 2.0) * d * d * d - (a + 3.0) * d * d + 1.0,
        // Outer lobe.
        d if d <= 2.0 => a * d * d * d - 5.0 * a * d * d + 8.0 * a * d - 4.0 * a,
        // Outside the 4-tap support.
        _ => 0.0,
    }
}
96
/// Tie-breaking rule for nearest-neighbour resampling.
///
/// Only the two modes tract-core supports are represented: `Floor` (PyTorch
/// `upsample_nearest`) and `RoundPreferCeil` (`_upsample_nearest_exact`). The
/// other ONNX modes stay in the ONNX op.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Nearest {
    /// Serialized as `"floor"`.
    Floor,
    /// Serialized as `"round_prefer_ceil"`.
    RoundPreferCeil,
}
105
106impl Nearest {
107    /// True when the right (ceil) neighbour wins for a fractional offset.
108    pub fn prefers_right(&self, x_ratio: f32) -> bool {
109        match self {
110            Nearest::Floor => false,
111            Nearest::RoundPreferCeil => x_ratio >= 0.5,
112        }
113    }
114
115    pub fn as_str(&self) -> &'static str {
116        match self {
117            Nearest::Floor => "floor",
118            Nearest::RoundPreferCeil => "round_prefer_ceil",
119        }
120    }
121
122    pub fn parse(s: &str) -> TractResult<Self> {
123        Ok(match s {
124            "floor" => Nearest::Floor,
125            "round_prefer_ceil" => Nearest::RoundPreferCeil,
126            s => bail!("nearest_mode: {s}"),
127        })
128    }
129}
130
131/// Resamples `input` along the axes given by `scales`/`sizes`, the clean subset
132/// of ONNX Resize: `interpolator` × `coord_transformer` × `nearest`, fixed
133/// `cubic_coeff_a = -0.75`, no ROI and no `exclude_outside`. The ONNX op carries
134/// the remaining edge cases and decltters into this op when it fits the subset.
135#[derive(Clone, Debug, Hash, PartialEq, Eq)]
136pub struct Resize {
137    pub coord_transformer: CoordTransformer,
138    pub interpolator: Interpolator,
139    pub nearest: Nearest,
140    pub optional_scales_input: Option<usize>,
141    pub optional_sizes_input: Option<usize>,
142}
143
144impl Resize {
145    pub fn compute_output_shape<D: DimLike>(
146        &self,
147        input_shape: &[D],
148        input_scale: Option<&Tensor>,
149        input_sizes: Option<&Tensor>,
150    ) -> TractResult<TVec<D>> {
151        if let Some(scale) = input_scale
152            && scale.len() == input_shape.len()
153        {
154            let mut shape = tvec!();
155            for (i, s) in input_shape
156                .iter()
157                .zip(scale.cast_to::<f32>()?.try_as_plain()?.as_slice::<f32>()?.iter())
158            {
159                if s.round() == *s {
160                    shape.push(i.clone() * (*s as usize));
161                } else if let Ok(i) = i.to_usize() {
162                    shape.push(((i as f32 * s) as usize).into());
163                } else {
164                    bail!(
165                        "Can not compute output shape. inputs are {input_shape:?} and scale {scale:?}"
166                    )
167                }
168            }
169            return Ok(shape);
170        }
171        if let Some(sizes) = input_sizes
172            && sizes.len() == input_shape.len()
173        {
174            return sizes
175                .cast_to::<TDim>()?
176                .try_as_plain()?
177                .as_slice::<TDim>()?
178                .iter()
179                .map(|i| i.try_into())
180                .collect();
181        }
182        bail!(
183            "Neither sizes nor scales makes sense: input_shape: {:?}, scale: {:?}, sizes: {:?}",
184            input_shape,
185            input_scale,
186            input_sizes,
187        );
188    }
189}
190
191impl Op for Resize {
192    fn name(&self) -> StaticName {
193        "Resize".into()
194    }
195
196    op_as_typed_op!();
197}
198
199impl EvalOp for Resize {
200    fn is_stateless(&self) -> bool {
201        true
202    }
203
204    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
205        let input_dt = inputs[0].datum_type();
206        let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
207        let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
208        let output_shape = self.compute_output_shape(
209            inputs[0].shape(),
210            scales.map(|t| &**t),
211            sizes.map(|t| &**t),
212        )?;
213        let scales: TVec<f32> = if let Some(scales) = scales.filter(|s| s.len() == inputs[0].rank())
214        {
215            scales.try_as_plain()?.as_slice::<f32>()?.into()
216        } else {
217            output_shape.iter().zip(inputs[0].shape()).map(|(o, i)| *o as f32 / *i as f32).collect()
218        };
219        let input = inputs.remove(0).into_tensor();
220        let mut data = if input.datum_type() == f32::datum_type() {
221            input.into_plain_array::<f32>()?
222        } else {
223            input.cast_to::<f32>()?.into_owned().into_plain_array::<f32>()?
224        };
225        for (axis, scale) in scales.into_iter().enumerate().filter(|(_, s)| *s != 1.0) {
226            let mut new_shape: TVec<usize> = data.shape().into();
227            new_shape[axis] = output_shape[axis];
228            let input_len = data.shape()[axis];
229            data = match self.interpolator {
230                Interpolator::Cubic => {
231                    let a = -0.75f32;
232                    tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
233                        let x_out = co_o[axis];
234                        let x_in = self.coord_transformer.transform(
235                            x_out,
236                            scale,
237                            input_len,
238                            new_shape[axis],
239                        );
240                        let x_floor = x_in.floor() as isize;
241                        let t = x_in - x_floor as f32;
242                        let mut co_i = co_o;
243                        let mut acc = 0.0f32;
244                        for j in -1..=2isize {
245                            let w = cubic_kernel(t - j as f32, a);
246                            let idx = (x_floor + j).clamp(0, input_len as isize - 1) as usize;
247                            co_i[axis] = idx;
248                            acc += w * data[&co_i];
249                        }
250                        acc
251                    })
252                }
253                _ => tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
254                    let x_out = co_o[axis];
255                    let x_in =
256                        self.coord_transformer.transform(x_out, scale, input_len, new_shape[axis]);
257                    let mut co_i = co_o;
258                    let x_floor = x_in.floor() as isize;
259                    let x_left = x_floor.clamp(0, input_len as isize - 1) as usize;
260                    co_i[axis] = x_left;
261                    let y_left = data[&co_i];
262                    let x_right = (x_floor + 1).clamp(0, input_len as isize - 1) as usize;
263                    co_i[axis] = x_right;
264                    let y_right = data[&co_i];
265                    let x_frac = x_in - x_floor as f32;
266                    match self.interpolator {
267                        Interpolator::Linear => y_left * (1.0 - x_frac) + y_right * x_frac,
268                        Interpolator::Nearest => {
269                            if self.nearest.prefers_right(x_frac) {
270                                y_right
271                            } else {
272                                y_left
273                            }
274                        }
275                        Interpolator::Cubic => unreachable!(),
276                    }
277                }),
278            }
279        }
280        let out = data.into_tensor();
281        let out =
282            if out.datum_type() == input_dt { out } else { out.cast_to_dt(input_dt)?.into_owned() };
283        Ok(tvec!(out.into_tvalue()))
284    }
285}
286
287impl TypedOp for Resize {
288    as_op!();
289
290    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
291        let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
292        let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
293        let output_shape = self.compute_output_shape(
294            &inputs[0].shape,
295            scales.and_then(|f| f.konst.as_deref()),
296            sizes.and_then(|f| f.konst.as_deref()),
297        )?;
298        Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
299    }
300
301    fn declutter(
302        &self,
303        model: &TypedModel,
304        node: &TypedNode,
305    ) -> TractResult<Option<TypedModelPatch>> {
306        rule_if!(matches!(self.interpolator, Interpolator::Nearest));
307        rule_if_some!(scales_input = self.optional_scales_input);
308        let scales_fact = model.outlet_fact(node.inputs[scales_input])?;
309        rule_if_some!(scales_tensor = &scales_fact.konst);
310        let scales: Vec<f32> =
311            scales_tensor.cast_to::<f32>()?.try_as_plain()?.as_slice::<f32>()?.to_vec();
312        let int_scales: Vec<usize> = scales.iter().map(|&s| s.round() as usize).collect();
313        rule_if!(
314            scales.iter().zip(&int_scales).all(|(&s, &i)| (s - i as f32).abs() <= 1e-5 && i != 0)
315        );
316        rule_if!(int_scales.iter().any(|&s| s != 1));
317
318        lower_nearest_integer_upsample(model, node, &int_scales)
319    }
320}
321
322/// Lowers a nearest-neighbour integer upsample to Reshape → Tile → Reshape: each
323/// upsampled axis is split into a size-1 axis, tiled by its scale, then merged
324/// back. Shared by the core and ONNX Resize declutters.
325pub fn lower_nearest_integer_upsample(
326    model: &TypedModel,
327    node: &TypedNode,
328    int_scales: &[usize],
329) -> TractResult<Option<TypedModelPatch>> {
330    let input_fact = model.outlet_fact(node.inputs[0])?;
331    let input_shape = &input_fact.shape;
332
333    let mut patch = TypedModelPatch::default();
334    let mut wire = patch.tap_model(model, node.inputs[0])?;
335
336    let mut from_dims: TVec<TDim> = tvec![];
337    let mut to_dims: TVec<TDim> = tvec![];
338    let mut tile_multipliers: TVec<TDim> = tvec![];
339    let mut first_upsampled = None;
340
341    for (i, &scale) in int_scales.iter().enumerate() {
342        from_dims.push(input_shape[i].clone());
343        to_dims.push(input_shape[i].clone());
344        tile_multipliers.push(1.into());
345        if scale > 1 {
346            if first_upsampled.is_none() {
347                first_upsampled = Some(i);
348            }
349            to_dims.push(1.into());
350            tile_multipliers.push(scale.into());
351        }
352    }
353
354    if to_dims.len() > from_dims.len() {
355        let first = first_upsampled.unwrap();
356        wire = patch.wire_node(
357            format!("{}.reshape_pre", node.name),
358            AxisOp::Reshape(first, from_dims[first..].into(), to_dims[first..].into()),
359            &[wire],
360        )?[0];
361    }
362
363    wire = patch.wire_node(
364        format!("{}.tile", node.name),
365        Tile { multipliers: tile_multipliers },
366        &[wire],
367    )?[0];
368
369    let tiled_shape: TVec<TDim> = to_dims
370        .iter()
371        .zip(int_scales.iter().flat_map(|&s| if s > 1 { vec![1usize, s] } else { vec![1] }))
372        .map(|(d, s)| d.clone() * s)
373        .collect();
374    let mut final_dims: TVec<TDim> = tvec![];
375    let mut idx = 0;
376    for &scale in int_scales {
377        if scale > 1 {
378            final_dims.push(tiled_shape[idx].clone() * tiled_shape[idx + 1].clone());
379            idx += 2;
380        } else {
381            final_dims.push(tiled_shape[idx].clone());
382            idx += 1;
383        }
384    }
385
386    if tiled_shape.len() > final_dims.len() {
387        let first = first_upsampled.unwrap();
388        wire = patch.wire_node(
389            format!("{}.reshape_post", node.name),
390            AxisOp::Reshape(first, tiled_shape[first..].into(), final_dims[first..].into()),
391            &[wire],
392        )?[0];
393    }
394
395    patch.shunt_outside(model, node.id.into(), wire)?;
396    Ok(Some(patch))
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks a few fixed kernel values and that the four tap weights sum to
    /// one across the unit interval.
    #[test]
    fn cubic_kernel_properties() {
        let a = -0.75f32;
        for (offset, expected) in [(0.0f32, 1.0f32), (2.0, 0.0), (3.0, 0.0)] {
            assert!((cubic_kernel(offset, a) - expected).abs() < 1e-6);
        }

        for t in (0..=100).map(|step| step as f32 / 100.0) {
            let sum = cubic_kernel(t + 1.0, a)
                + cubic_kernel(t, a)
                + cubic_kernel(1.0 - t, a)
                + cubic_kernel(2.0 - t, a);
            assert!((sum - 1.0).abs() < 1e-5, "kernel weights must sum to 1.0, got {sum} at t={t}");
        }
    }

    /// Runs a half-pixel cubic `Resize` over `input` with the given per-axis
    /// `scales` and returns its single output.
    fn cubic_resize(input: Tensor, scales: &[f32]) -> Tensor {
        let resize = Resize {
            coord_transformer: CoordTransformer::HalfPixel,
            interpolator: Interpolator::Cubic,
            nearest: Nearest::Floor,
            optional_scales_input: Some(1),
            optional_sizes_input: None,
        };
        let scales = tract_ndarray::Array1::from(scales.to_vec()).into_tensor();
        let mut outputs = resize.eval(tvec!(input.into_tvalue(), scales.into_tvalue())).unwrap();
        outputs.remove(0).into_tensor()
    }

    #[test]
    fn cubic_resize_1d_upsample() {
        let ramp = tract_ndarray::arr1(&[0.0f32, 1.0, 2.0, 3.0]).into_tensor();
        let out = cubic_resize(ramp, &[2.0]);
        let plain = out.try_as_plain().unwrap();
        let output = plain.as_slice::<f32>().unwrap();
        assert_eq!(output.len(), 8);
        assert!((output[0] - (-0.10546875)).abs() < 1e-4, "got {}", output[0]);
    }

    #[test]
    fn cubic_resize_2d_upsample() {
        let grid = tract_ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]).into_tensor();
        let out = cubic_resize(grid, &[2.0, 2.0]);
        assert_eq!(out.shape(), &[4, 4]);
    }
}