1use crate::internal::*;
2use crate::ops::array::Tile;
3
4#[derive(Clone, Debug, Hash, PartialEq, Eq)]
7pub enum CoordTransformer {
8 HalfPixel,
9 AlignCorners,
10 Asymmetric,
11 PytorchHalfPixel,
12}
13
14impl CoordTransformer {
15 pub fn transform(&self, x_out: usize, scale: f32, len_in: usize, len_out: usize) -> f32 {
16 match self {
17 CoordTransformer::HalfPixel => (x_out as f32 + 0.5) / scale - 0.5,
18 CoordTransformer::AlignCorners => {
19 let output_width = scale * len_in as f32;
20 if output_width == 1.0 {
21 0.0
22 } else {
23 (x_out as f32 * (len_in as f32 - 1.0)) / (output_width - 1.0)
24 }
25 }
26 CoordTransformer::Asymmetric => (x_out as f32) / scale,
27 CoordTransformer::PytorchHalfPixel => {
28 if len_out > 1 {
29 (x_out as f32 + 0.5) / scale - 0.5
30 } else {
31 0.0
32 }
33 }
34 }
35 }
36
37 pub fn as_str(&self) -> &'static str {
38 match self {
39 CoordTransformer::HalfPixel => "half_pixel",
40 CoordTransformer::AlignCorners => "align_corners",
41 CoordTransformer::Asymmetric => "asymmetric",
42 CoordTransformer::PytorchHalfPixel => "pytorch_half_pixel",
43 }
44 }
45
46 pub fn parse(s: &str) -> TractResult<Self> {
47 Ok(match s {
48 "half_pixel" => CoordTransformer::HalfPixel,
49 "align_corners" => CoordTransformer::AlignCorners,
50 "asymmetric" => CoordTransformer::Asymmetric,
51 "pytorch_half_pixel" => CoordTransformer::PytorchHalfPixel,
52 s => bail!("coordinate_transformation_mode: {s}"),
53 })
54 }
55}
56
57#[derive(Clone, Debug, Hash, PartialEq, Eq)]
60pub enum Interpolator {
61 Linear,
62 Nearest,
63 Cubic,
64}
65
66impl Interpolator {
67 pub fn as_str(&self) -> &'static str {
68 match self {
69 Interpolator::Linear => "linear",
70 Interpolator::Nearest => "nearest",
71 Interpolator::Cubic => "cubic",
72 }
73 }
74
75 pub fn parse(s: &str) -> TractResult<Self> {
76 Ok(match s {
77 "linear" => Interpolator::Linear,
78 "nearest" => Interpolator::Nearest,
79 "cubic" => Interpolator::Cubic,
80 s => bail!("mode: {s}"),
81 })
82 }
83}
84
/// Keys cubic convolution kernel evaluated at offset `s`, with free coefficient `a`.
///
/// The kernel is a piecewise cubic in `|s|`: one polynomial on `[0, 1]`, another on `(1, 2]`,
/// and zero further away (including for NaN input). `Resize` calls it with `a = -0.75`.
pub fn cubic_kernel(s: f32, a: f32) -> f32 {
    let dist = s.abs();
    match dist {
        d if d <= 1.0 => (a + 2.0) * d * d * d - (a + 3.0) * d * d + 1.0,
        d if d <= 2.0 => a * d * d * d - 5.0 * a * d * d + 8.0 * a * d - 4.0 * a,
        _ => 0.0,
    }
}
96
97#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
101pub enum Nearest {
102 Floor,
103 RoundPreferCeil,
104}
105
106impl Nearest {
107 pub fn prefers_right(&self, x_ratio: f32) -> bool {
109 match self {
110 Nearest::Floor => false,
111 Nearest::RoundPreferCeil => x_ratio >= 0.5,
112 }
113 }
114
115 pub fn as_str(&self) -> &'static str {
116 match self {
117 Nearest::Floor => "floor",
118 Nearest::RoundPreferCeil => "round_prefer_ceil",
119 }
120 }
121
122 pub fn parse(s: &str) -> TractResult<Self> {
123 Ok(match s {
124 "floor" => Nearest::Floor,
125 "round_prefer_ceil" => Nearest::RoundPreferCeil,
126 s => bail!("nearest_mode: {s}"),
127 })
128 }
129}
130
131#[derive(Clone, Debug, Hash, PartialEq, Eq)]
136pub struct Resize {
137 pub coord_transformer: CoordTransformer,
138 pub interpolator: Interpolator,
139 pub nearest: Nearest,
140 pub optional_scales_input: Option<usize>,
141 pub optional_sizes_input: Option<usize>,
142}
143
144impl Resize {
145 pub fn compute_output_shape<D: DimLike>(
146 &self,
147 input_shape: &[D],
148 input_scale: Option<&Tensor>,
149 input_sizes: Option<&Tensor>,
150 ) -> TractResult<TVec<D>> {
151 if let Some(scale) = input_scale
152 && scale.len() == input_shape.len()
153 {
154 let mut shape = tvec!();
155 for (i, s) in input_shape
156 .iter()
157 .zip(scale.cast_to::<f32>()?.try_as_plain()?.as_slice::<f32>()?.iter())
158 {
159 if s.round() == *s {
160 shape.push(i.clone() * (*s as usize));
161 } else if let Ok(i) = i.to_usize() {
162 shape.push(((i as f32 * s) as usize).into());
163 } else {
164 bail!(
165 "Can not compute output shape. inputs are {input_shape:?} and scale {scale:?}"
166 )
167 }
168 }
169 return Ok(shape);
170 }
171 if let Some(sizes) = input_sizes
172 && sizes.len() == input_shape.len()
173 {
174 return sizes
175 .cast_to::<TDim>()?
176 .try_as_plain()?
177 .as_slice::<TDim>()?
178 .iter()
179 .map(|i| i.try_into())
180 .collect();
181 }
182 bail!(
183 "Neither sizes nor scales makes sense: input_shape: {:?}, scale: {:?}, sizes: {:?}",
184 input_shape,
185 input_scale,
186 input_sizes,
187 );
188 }
189}
190
// Generic operator identity for `Resize`.
impl Op for Resize {
    /// Operator name, always "Resize".
    fn name(&self) -> StaticName {
        "Resize".into()
    }

    // Macro from the crate prelude; presumably exposes this op through its `TypedOp`
    // implementation (TODO confirm against the macro definition).
    op_as_typed_op!();
}
198
199impl EvalOp for Resize {
200 fn is_stateless(&self) -> bool {
201 true
202 }
203
204 fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
205 let input_dt = inputs[0].datum_type();
206 let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
207 let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
208 let output_shape = self.compute_output_shape(
209 inputs[0].shape(),
210 scales.map(|t| &**t),
211 sizes.map(|t| &**t),
212 )?;
213 let scales: TVec<f32> = if let Some(scales) = scales.filter(|s| s.len() == inputs[0].rank())
214 {
215 scales.try_as_plain()?.as_slice::<f32>()?.into()
216 } else {
217 output_shape.iter().zip(inputs[0].shape()).map(|(o, i)| *o as f32 / *i as f32).collect()
218 };
219 let input = inputs.remove(0).into_tensor();
220 let mut data = if input.datum_type() == f32::datum_type() {
221 input.into_plain_array::<f32>()?
222 } else {
223 input.cast_to::<f32>()?.into_owned().into_plain_array::<f32>()?
224 };
225 for (axis, scale) in scales.into_iter().enumerate().filter(|(_, s)| *s != 1.0) {
226 let mut new_shape: TVec<usize> = data.shape().into();
227 new_shape[axis] = output_shape[axis];
228 let input_len = data.shape()[axis];
229 data = match self.interpolator {
230 Interpolator::Cubic => {
231 let a = -0.75f32;
232 tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
233 let x_out = co_o[axis];
234 let x_in = self.coord_transformer.transform(
235 x_out,
236 scale,
237 input_len,
238 new_shape[axis],
239 );
240 let x_floor = x_in.floor() as isize;
241 let t = x_in - x_floor as f32;
242 let mut co_i = co_o;
243 let mut acc = 0.0f32;
244 for j in -1..=2isize {
245 let w = cubic_kernel(t - j as f32, a);
246 let idx = (x_floor + j).clamp(0, input_len as isize - 1) as usize;
247 co_i[axis] = idx;
248 acc += w * data[&co_i];
249 }
250 acc
251 })
252 }
253 _ => tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
254 let x_out = co_o[axis];
255 let x_in =
256 self.coord_transformer.transform(x_out, scale, input_len, new_shape[axis]);
257 let mut co_i = co_o;
258 let x_floor = x_in.floor() as isize;
259 let x_left = x_floor.clamp(0, input_len as isize - 1) as usize;
260 co_i[axis] = x_left;
261 let y_left = data[&co_i];
262 let x_right = (x_floor + 1).clamp(0, input_len as isize - 1) as usize;
263 co_i[axis] = x_right;
264 let y_right = data[&co_i];
265 let x_frac = x_in - x_floor as f32;
266 match self.interpolator {
267 Interpolator::Linear => y_left * (1.0 - x_frac) + y_right * x_frac,
268 Interpolator::Nearest => {
269 if self.nearest.prefers_right(x_frac) {
270 y_right
271 } else {
272 y_left
273 }
274 }
275 Interpolator::Cubic => unreachable!(),
276 }
277 }),
278 }
279 }
280 let out = data.into_tensor();
281 let out =
282 if out.datum_type() == input_dt { out } else { out.cast_to_dt(input_dt)?.into_owned() };
283 Ok(tvec!(out.into_tvalue()))
284 }
285}
286
287impl TypedOp for Resize {
288 as_op!();
289
290 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
291 let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
292 let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
293 let output_shape = self.compute_output_shape(
294 &inputs[0].shape,
295 scales.and_then(|f| f.konst.as_deref()),
296 sizes.and_then(|f| f.konst.as_deref()),
297 )?;
298 Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
299 }
300
301 fn declutter(
302 &self,
303 model: &TypedModel,
304 node: &TypedNode,
305 ) -> TractResult<Option<TypedModelPatch>> {
306 rule_if!(matches!(self.interpolator, Interpolator::Nearest));
307 rule_if_some!(scales_input = self.optional_scales_input);
308 let scales_fact = model.outlet_fact(node.inputs[scales_input])?;
309 rule_if_some!(scales_tensor = &scales_fact.konst);
310 let scales: Vec<f32> =
311 scales_tensor.cast_to::<f32>()?.try_as_plain()?.as_slice::<f32>()?.to_vec();
312 let int_scales: Vec<usize> = scales.iter().map(|&s| s.round() as usize).collect();
313 rule_if!(
314 scales.iter().zip(&int_scales).all(|(&s, &i)| (s - i as f32).abs() <= 1e-5 && i != 0)
315 );
316 rule_if!(int_scales.iter().any(|&s| s != 1));
317
318 lower_nearest_integer_upsample(model, node, &int_scales)
319 }
320}
321
322pub fn lower_nearest_integer_upsample(
326 model: &TypedModel,
327 node: &TypedNode,
328 int_scales: &[usize],
329) -> TractResult<Option<TypedModelPatch>> {
330 let input_fact = model.outlet_fact(node.inputs[0])?;
331 let input_shape = &input_fact.shape;
332
333 let mut patch = TypedModelPatch::default();
334 let mut wire = patch.tap_model(model, node.inputs[0])?;
335
336 let mut from_dims: TVec<TDim> = tvec![];
337 let mut to_dims: TVec<TDim> = tvec![];
338 let mut tile_multipliers: TVec<TDim> = tvec![];
339 let mut first_upsampled = None;
340
341 for (i, &scale) in int_scales.iter().enumerate() {
342 from_dims.push(input_shape[i].clone());
343 to_dims.push(input_shape[i].clone());
344 tile_multipliers.push(1.into());
345 if scale > 1 {
346 if first_upsampled.is_none() {
347 first_upsampled = Some(i);
348 }
349 to_dims.push(1.into());
350 tile_multipliers.push(scale.into());
351 }
352 }
353
354 if to_dims.len() > from_dims.len() {
355 let first = first_upsampled.unwrap();
356 wire = patch.wire_node(
357 format!("{}.reshape_pre", node.name),
358 AxisOp::Reshape(first, from_dims[first..].into(), to_dims[first..].into()),
359 &[wire],
360 )?[0];
361 }
362
363 wire = patch.wire_node(
364 format!("{}.tile", node.name),
365 Tile { multipliers: tile_multipliers },
366 &[wire],
367 )?[0];
368
369 let tiled_shape: TVec<TDim> = to_dims
370 .iter()
371 .zip(int_scales.iter().flat_map(|&s| if s > 1 { vec![1usize, s] } else { vec![1] }))
372 .map(|(d, s)| d.clone() * s)
373 .collect();
374 let mut final_dims: TVec<TDim> = tvec![];
375 let mut idx = 0;
376 for &scale in int_scales {
377 if scale > 1 {
378 final_dims.push(tiled_shape[idx].clone() * tiled_shape[idx + 1].clone());
379 idx += 2;
380 } else {
381 final_dims.push(tiled_shape[idx].clone());
382 idx += 1;
383 }
384 }
385
386 if tiled_shape.len() > final_dims.len() {
387 let first = first_upsampled.unwrap();
388 wire = patch.wire_node(
389 format!("{}.reshape_post", node.name),
390 AxisOp::Reshape(first, tiled_shape[first..].into(), final_dims[first..].into()),
391 &[wire],
392 )?[0];
393 }
394
395 patch.shunt_outside(model, node.id.into(), wire)?;
396 Ok(Some(patch))
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cubic_kernel_properties() {
        let a = -0.75f32;
        // Interpolating kernel: 1 at offset 0, 0 at the other integer offsets.
        assert!((cubic_kernel(0.0, a) - 1.0).abs() < 1e-6);
        assert!(cubic_kernel(2.0, a).abs() < 1e-6);
        assert!(cubic_kernel(3.0, a).abs() < 1e-6);

        // Partition of unity: the four taps around any fractional offset sum to one.
        for t in (0..=100).map(|step| step as f32 / 100.0) {
            let sum: f32 = [t + 1.0, t, 1.0 - t, 2.0 - t]
                .into_iter()
                .map(|offset| cubic_kernel(offset, a))
                .sum();
            assert!((sum - 1.0).abs() < 1e-5, "kernel weights must sum to 1.0, got {sum} at t={t}");
        }
    }

    /// Runs a half-pixel cubic `Resize` on `input` with the given per-axis scale factors.
    fn cubic_resize(input: Tensor, scales: &[f32]) -> Tensor {
        let op = Resize {
            coord_transformer: CoordTransformer::HalfPixel,
            interpolator: Interpolator::Cubic,
            nearest: Nearest::Floor,
            optional_scales_input: Some(1),
            optional_sizes_input: None,
        };
        let scales = tract_ndarray::arr1(scales).into_tensor();
        let mut outputs = op.eval(tvec!(input.into_tvalue(), scales.into_tvalue())).unwrap();
        outputs.remove(0).into_tensor()
    }

    #[test]
    fn cubic_resize_1d_upsample() {
        let input = tract_ndarray::arr1(&[0.0f32, 1.0, 2.0, 3.0]).into_tensor();
        let resized = cubic_resize(input, &[2.0]);
        let view = resized.try_as_plain().unwrap();
        let values = view.as_slice::<f32>().unwrap();
        assert_eq!(values.len(), 8);
        let first = values[0];
        assert!((first - (-0.10546875)).abs() < 1e-4, "got {first}");
    }

    #[test]
    fn cubic_resize_2d_upsample() {
        let input = tract_ndarray::arr2(&[[1.0f32, 2.0], [3.0, 4.0]]).into_tensor();
        assert_eq!(cubic_resize(input, &[2.0, 2.0]).shape(), &[4, 4]);
    }
}