1use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::ConstValue;
7
8use super::{AspectRatioPolicy, CoordinateTransformMode, NearestMode, ResizeMode};
9use crate::Tensor;
10use crate::error::UOpSnafu;
11
12type Result<T> = crate::Result<T>;
13
14#[bon]
15impl Tensor {
16 #[builder]
69 #[allow(clippy::too_many_arguments)]
70 pub fn resize(
71 &self,
72 scales: Option<&[f64]>,
73 sizes: Option<&[usize]>,
74 #[builder(default)] mode: ResizeMode,
75 #[builder(default)] coordinate_transformation_mode: CoordinateTransformMode,
76 #[builder(default)] nearest_mode: NearestMode,
77 #[builder(default = -0.75)] cubic_coeff_a: f64,
78 #[builder(default = false)] exclude_outside: bool,
79 #[builder(default = false)] antialias: bool,
80 #[builder(default)] keep_aspect_ratio_policy: AspectRatioPolicy,
81 axes: Option<&[usize]>,
82 roi: Option<&[f64]>,
83 #[builder(default = 0.0)] extrapolation_value: f64,
84 ) -> Result<Tensor> {
85 let ndim = self.ndim()?;
86 let shape = self.shape()?;
87 let _shape_dims = svod_ir::shape::to_vec_usize(&shape).context(UOpSnafu)?;
97
98 let axes: Vec<usize> = axes.map(|a| a.to_vec()).unwrap_or_else(|| (0..ndim).collect());
99
100 let non_axes: Vec<usize> = (0..ndim).filter(|d| !axes.contains(d)).collect();
102 let perm: Vec<isize> = non_axes.iter().chain(axes.iter()).map(|&d| d as isize).collect();
103 let inv_perm = argsort_usize(&perm.iter().map(|&p| p as usize).collect::<Vec<_>>());
104 let inv_perm_i: Vec<isize> = inv_perm.iter().map(|&i| i as isize).collect();
105
106 let mut x = if perm.iter().enumerate().all(|(i, &p)| p == i as isize) {
107 self.clone()
108 } else {
109 self.try_permute(&perm)?
110 };
111
112 let x_shape = x.shape()?;
114 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
119 let n_spatial = axes.len();
120 let input_shape: Vec<usize> = x_dims[ndim - n_spatial..].to_vec();
121
122 let scales_trimmed: Option<Vec<f64>> = scales.map(|s| s[s.len().saturating_sub(n_spatial)..].to_vec());
124 let sizes_trimmed: Option<Vec<usize>> = sizes.map(|s| s[s.len().saturating_sub(n_spatial)..].to_vec());
125
126 let (output_sizes, final_scales) = if let Some(mut sz) = sizes_trimmed {
128 if keep_aspect_ratio_policy == AspectRatioPolicy::NotLarger
129 || keep_aspect_ratio_policy == AspectRatioPolicy::NotSmaller
130 {
131 let scale_fn: fn(f64, f64) -> f64 =
132 if keep_aspect_ratio_policy == AspectRatioPolicy::NotLarger { f64::min } else { f64::max };
133 let mut scale = f64::NAN;
134 for (s, &inp) in sz.iter().zip(&input_shape) {
135 let s_val = *s as f64 / inp as f64;
136 if scale.is_nan() {
137 scale = s_val;
138 } else {
139 scale = scale_fn(scale, s_val);
140 }
141 }
142 sz = input_shape.iter().map(|&sh| (scale * sh as f64 + 0.5) as usize).collect();
143 let sc = vec![scale; n_spatial];
144 (sz, sc)
145 } else {
146 let sc: Vec<f64> = sz.iter().zip(&input_shape).map(|(&s, &sh)| s as f64 / sh as f64).collect();
147 (sz, sc)
148 }
149 } else if let Some(sc) = scales_trimmed {
150 let sz: Vec<usize> = sc.iter().zip(&input_shape).map(|(&s, &sh)| (s * sh as f64) as usize).collect();
151 (sz, sc)
152 } else {
153 return Err(crate::error::Error::IrConstruction {
154 details: "resize: either scales or sizes must be provided".into(),
155 });
156 };
157
158 if output_sizes.iter().zip(&input_shape).all(|(&o, &i)| o == i) {
160 return if perm.iter().enumerate().any(|(i, &p)| p != i as isize) {
161 x.try_permute(&inv_perm_i)
162 } else {
163 Ok(x)
164 };
165 }
166
167 let roi_pairs: Vec<(f64, f64)> = if let Some(roi) = roi {
169 let half = roi.len() / 2;
170 let starts = &roi[half - n_spatial..half];
171 let ends = &roi[roi.len() - n_spatial..];
172 starts.iter().zip(ends).map(|(&s, &e)| (s, e)).collect()
173 } else {
174 vec![(0.0, 1.0); n_spatial]
175 };
176
177 let dtype = x.uop().dtype();
179 let indexes: Vec<Tensor> = input_shape
180 .iter()
181 .zip(&output_sizes)
182 .zip(&final_scales)
183 .zip(&roi_pairs)
184 .map(|(((&inp_sz, &out_sz), &scale), &(roi_start, roi_end))| {
185 apply_coordinate_transform(
186 inp_sz,
187 out_sz,
188 scale,
189 coordinate_transformation_mode,
190 &dtype,
191 roi_start,
192 roi_end,
193 )
194 })
195 .collect::<Result<_>>()?;
196
197 let is_tf_crop = coordinate_transformation_mode == CoordinateTransformMode::TfCropAndResize;
199 let indexes: Vec<Tensor> = if !is_tf_crop && matches!(mode, ResizeMode::Nearest | ResizeMode::Linear) {
200 indexes
201 .into_iter()
202 .zip(&input_shape)
203 .map(|(idx, &sz)| {
204 let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
205 let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
206 idx.clamp().min(&zero).max(&max_val).call()
207 })
208 .collect::<Result<Vec<_>>>()?
209 } else {
210 indexes
211 };
212
213 let validity_mask: Option<Vec<Tensor>> = if is_tf_crop {
215 Some(
216 indexes
217 .iter()
218 .zip(&input_shape)
219 .map(|(idx, &sz)| {
220 let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
221 let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
222 idx.try_ge(&zero)?.bitwise_and(&idx.try_le(&max_val)?)
223 })
224 .collect::<Result<Vec<_>>>()?,
225 )
226 } else {
227 None
228 };
229
230 let indexes: Vec<Tensor> = if is_tf_crop {
232 indexes
233 .into_iter()
234 .zip(&input_shape)
235 .map(|(idx, &sz)| {
236 let zero = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
237 let max_val = Tensor::const_(ConstValue::Float((sz - 1) as f64), dtype.clone());
238 idx.clamp().min(&zero).max(&max_val).call()
239 })
240 .collect::<Result<Vec<_>>>()?
241 } else {
242 indexes
243 };
244
245 if mode == ResizeMode::Nearest {
246 let int_indexes: Vec<Tensor> = indexes
247 .into_iter()
248 .map(|idx| {
249 let rounded = match nearest_mode {
250 NearestMode::RoundPreferFloor => idx.try_sub(&Tensor::const_(0.5f64, dtype.clone()))?.ceil()?,
251 NearestMode::RoundPreferCeil => idx.try_add(&Tensor::const_(0.5f64, dtype.clone()))?.floor()?,
252 NearestMode::Floor => idx.floor()?,
253 NearestMode::Ceil => idx.ceil()?,
254 };
255 rounded.cast(DType::Int32)
256 })
257 .collect::<Result<Vec<_>>>()?;
258
259 for (i, idx) in int_indexes.iter().enumerate() {
261 let dim = (ndim - n_spatial + i) as isize;
262 let cur_shape = x.shape()?;
263 let cur_dims = svod_ir::shape::to_vec_usize(&cur_shape).context(UOpSnafu)?;
271 let out_sz = output_sizes[i];
272
273 let mut idx_shape = vec![1isize; ndim];
274 idx_shape[ndim - n_spatial + i] = out_sz as isize;
275 let idx_reshaped = idx.try_reshape(&idx_shape)?;
276
277 let mut expand_shape: Vec<isize> = cur_dims.iter().map(|&d| d as isize).collect();
278 expand_shape[ndim - n_spatial + i] = out_sz as isize;
279 let idx_expanded = idx_reshaped.try_expand(&expand_shape)?;
280
281 x = x.gather(dim, &idx_expanded)?;
282 }
283 } else if mode == ResizeMode::Linear {
284 let mut expand = x_dims.clone();
285 for (i, &out_sz) in output_sizes.iter().enumerate() {
286 let dim_pos = ndim - n_spatial + i;
287 let scale = final_scales[i];
288 let input_sz = input_shape[i];
289 let index = &indexes[i];
290
291 let mut reshape = vec![1isize; ndim];
292 reshape[dim_pos] = out_sz as isize;
293 expand[dim_pos] = out_sz;
294 let expand_i: Vec<isize> = expand.iter().map(|&d| d as isize).collect();
295
296 if antialias && scale < 1.0 {
297 x = interpolate_antialias_linear(&x, index, dim_pos, input_sz, scale, &reshape, &expand_i, &dtype)?;
298 } else {
299 let low = index.floor()?.cast(DType::Int32)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
300 let high = index.ceil()?.cast(DType::Int32)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
301 let perc = index.try_sub(&index.floor()?)?.try_reshape(&reshape)?.try_expand(&expand_i)?;
302
303 let dim_i = dim_pos as isize;
304 let gathered_low = x.gather(dim_i, &low)?;
305 let gathered_high = x.gather(dim_i, &high)?;
306 x = gathered_low.lerp(&gathered_high, &perc)?;
307 }
308 }
309 } else if mode == ResizeMode::Cubic {
310 let a = cubic_coeff_a;
311 let mut expand = x_dims.clone();
312 for (i, &out_sz) in output_sizes.iter().enumerate() {
313 let dim_pos = ndim - n_spatial + i;
314 let scale = final_scales[i];
315 let input_sz = input_shape[i];
316 let index = &indexes[i];
317
318 let mut reshape = vec![1isize; ndim];
319 reshape[dim_pos] = out_sz as isize;
320 expand[dim_pos] = out_sz;
321 let expand_i: Vec<isize> = expand.iter().map(|&d| d as isize).collect();
322
323 if antialias && scale < 1.0 {
324 x = interpolate_antialias_cubic(
325 &x, index, dim_pos, input_sz, scale, a, &reshape, &expand_i, &dtype,
326 )?;
327 } else {
328 let p = index.floor()?.cast(DType::Int32)?;
329 let ratio = index.try_sub(&index.floor()?)?;
330
331 let one = Tensor::const_(ConstValue::Int(1), DType::Int32);
332 let two = Tensor::const_(ConstValue::Int(2), DType::Int32);
333 let idx0 = p.try_sub(&one)?;
334 let idx1 = p.clone();
335 let idx2 = p.try_add(&one)?;
336 let idx3 = p.try_add(&two)?;
337
338 let r1 = ratio.try_add(&Tensor::const_(1.0f64, dtype.clone()))?;
339 let c0 = poly_n(&r1, &[a, -5.0 * a, 8.0 * a, -4.0 * a], &dtype)?;
340 let c1 = poly_n(&ratio, &[a + 2.0, -(a + 3.0), 0.0, 1.0], &dtype)?;
341 let r_neg1 = Tensor::const_(1.0f64, dtype.clone()).try_sub(&ratio)?;
342 let c2 = poly_n(&r_neg1, &[a + 2.0, -(a + 3.0), 0.0, 1.0], &dtype)?;
343 let r_neg2 = Tensor::const_(2.0f64, dtype.clone()).try_sub(&ratio)?;
344 let c3 = poly_n(&r_neg2, &[a, -5.0 * a, 8.0 * a, -4.0 * a], &dtype)?;
345
346 let (mut c0, mut c1, mut c2, mut c3) = (c0, c1, c2, c3);
347 if exclude_outside {
348 let max_idx = Tensor::const_(ConstValue::Int(input_sz as i64), DType::Int32);
349 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
350 let zero_f = Tensor::const_(0.0f64, dtype.clone());
351 let valid0 = idx0.try_ge(&zero_i)?.try_mul(&idx0.try_lt(&max_idx)?)?;
352 let valid1 = idx1.try_ge(&zero_i)?.try_mul(&idx1.try_lt(&max_idx)?)?;
353 let valid2 = idx2.try_ge(&zero_i)?.try_mul(&idx2.try_lt(&max_idx)?)?;
354 let valid3 = idx3.try_ge(&zero_i)?.try_mul(&idx3.try_lt(&max_idx)?)?;
355 c0 = c0.where_(&valid0, &zero_f)?;
356 c1 = c1.where_(&valid1, &zero_f)?;
357 c2 = c2.where_(&valid2, &zero_f)?;
358 c3 = c3.where_(&valid3, &zero_f)?;
359 let total = c0.try_add(&c1)?.try_add(&c2)?.try_add(&c3)?;
360 let eps = Tensor::const_(1e-9f64, dtype.clone());
361 let total_safe = total.try_add(&eps)?;
362 c0 = c0.try_div(&total_safe)?;
363 c1 = c1.try_div(&total_safe)?;
364 c2 = c2.try_div(&total_safe)?;
365 c3 = c3.try_div(&total_safe)?;
366 }
367
368 let max_val = Tensor::const_(ConstValue::Int((input_sz - 1) as i64), DType::Int32);
369 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
370 let clip = |t: &Tensor| -> Result<Tensor> {
371 t.clamp().min(&zero_i).max(&max_val).call()?.try_reshape(&reshape)?.try_expand(&expand_i)
372 };
373 let ei0 = clip(&idx0)?;
374 let ei1 = clip(&idx1)?;
375 let ei2 = clip(&idx2)?;
376 let ei3 = clip(&idx3)?;
377
378 let ec = |c: Tensor| -> Result<Tensor> { c.try_reshape(&reshape)?.try_expand(&expand_i) };
379 let ec0 = ec(c0)?;
380 let ec1 = ec(c1)?;
381 let ec2 = ec(c2)?;
382 let ec3 = ec(c3)?;
383
384 let dim_i = dim_pos as isize;
385 let v0 = x.gather(dim_i, &ei0)?.try_mul(&ec0)?;
386 let v1 = x.gather(dim_i, &ei1)?.try_mul(&ec1)?;
387 let v2 = x.gather(dim_i, &ei2)?.try_mul(&ec2)?;
388 let v3 = x.gather(dim_i, &ei3)?.try_mul(&ec3)?;
389 x = v0.try_add(&v1)?.try_add(&v2)?.try_add(&v3)?;
390 }
391 }
392 }
393
394 if let Some(masks) = validity_mask {
396 let extrap = Tensor::const_(ConstValue::Float(extrapolation_value), dtype.clone());
397 let x_shape = x.shape()?;
398 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
399 let expand_shape: Vec<isize> = x_dims.iter().map(|&d| d as isize).collect();
400
401 let mut combined: Option<Tensor> = None;
403 for (i, mask) in masks.into_iter().enumerate() {
404 let mut shape = vec![1isize; ndim];
405 shape[ndim - n_spatial + i] = output_sizes[i] as isize;
406 let broad = mask.try_reshape(&shape)?.try_expand(&expand_shape)?;
407 combined = Some(match combined {
408 Some(c) => c.bitwise_and(&broad)?,
409 None => broad,
410 });
411 }
412 if let Some(valid) = combined {
413 x = x.where_(&valid, &extrap)?;
414 }
415 }
416
417 if perm.iter().enumerate().any(|(i, &p)| p != i as isize) { x.try_permute(&inv_perm_i) } else { Ok(x) }
419 }
420}
421
422fn apply_coordinate_transform(
427 input_sz: usize,
428 output_sz: usize,
429 scale: f64,
430 mode: CoordinateTransformMode,
431 dtype: &DType,
432 roi_start: f64,
433 roi_end: f64,
434) -> Result<Tensor> {
435 let f64_dt = DType::Float64;
436 let index = Tensor::arange(0, Some(output_sz as i64), None)?.cast(f64_dt.clone())?;
437 let result = match mode {
438 CoordinateTransformMode::HalfPixel => {
439 let half = Tensor::const_(0.5f64, f64_dt.clone());
440 index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?.try_sub(&half)?
441 }
442 CoordinateTransformMode::AlignCorners => {
443 let output_width = scale * input_sz as f64;
446 if output_width == 1.0 {
447 Tensor::const_(0.0f64, f64_dt)
448 } else {
449 let ratio = (input_sz as f64 - 1.0) / (output_width - 1.0);
450 index.try_mul(&Tensor::const_(ratio, f64_dt))?
451 }
452 }
453 CoordinateTransformMode::Asymmetric => index.try_div(&Tensor::const_(scale, f64_dt))?,
454 CoordinateTransformMode::PytorchHalfPixel => {
455 let output_width = scale * input_sz as f64;
456 if output_width == 1.0 {
457 Tensor::const_(0.0f64, f64_dt)
458 } else {
459 let half = Tensor::const_(0.5f64, f64_dt.clone());
460 index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?.try_sub(&half)?
461 }
462 }
463 CoordinateTransformMode::HalfPixelSymmetric => {
464 let output_dim_scaled = input_sz as f64 * scale;
465 let offset = (input_sz as f64 / 2.0) * (1.0 - output_sz as f64 / output_dim_scaled);
466 let half = Tensor::const_(0.5f64, f64_dt.clone());
467 let off_t = Tensor::const_(offset, f64_dt.clone());
468 off_t.try_add(&index.try_add(&half)?.try_div(&Tensor::const_(scale, f64_dt))?)?.try_sub(&half)?
469 }
470 CoordinateTransformMode::TfCropAndResize => {
471 let len = (input_sz as f64) - 1.0;
472 let output_width = scale * input_sz as f64;
473 if output_width == 1.0 {
474 Tensor::const_((roi_end - roi_start) * len / 2.0 + roi_start * len, f64_dt)
475 } else {
476 let stride = (roi_end - roi_start) * len / (output_width - 1.0);
477 let offset = roi_start * len;
478 index.try_mul(&Tensor::const_(stride, f64_dt.clone()))?.try_add(&Tensor::const_(offset, f64_dt))?
479 }
480 }
481 };
482 result.cast(dtype.clone())
483}
484
485fn poly_n(x: &Tensor, coeffs: &[f64], dtype: &DType) -> Result<Tensor> {
487 coeffs.iter().try_fold(Tensor::const_(0.0f64, dtype.clone()), |acc, &c| {
488 acc.try_mul(x)?.try_add(&Tensor::const_(c, dtype.clone()))
489 })
490}
491
492#[allow(clippy::too_many_arguments)]
496fn interpolate_antialias_cubic(
497 x: &Tensor,
498 index: &Tensor,
499 dim_pos: usize,
500 input_sz: usize,
501 scale: f64,
502 a: f64,
503 reshape: &[isize],
504 expand_i: &[isize],
505 dtype: &DType,
506) -> Result<Tensor> {
507 let i_start = (-2.0_f64 / scale).floor() as i32 + 1;
508 let i_end = 2 - i_start;
509 let n_taps = (i_end - i_start) as usize;
510
511 let floored = index.floor()?;
512 let p = floored.cast(DType::Int32)?;
513 let ratio = index.try_sub(&floored)?;
514
515 let one = Tensor::const_(1.0f64, dtype.clone());
516 let two = Tensor::const_(2.0f64, dtype.clone());
517 let zero_f = Tensor::const_(0.0f64, dtype.clone());
518
519 let mut coeffs = Vec::with_capacity(n_taps);
520 for tap in i_start..i_end {
521 let arg = ratio
522 .try_mul(&Tensor::const_(-scale, dtype.clone()))?
523 .try_add(&Tensor::const_(scale * tap as f64, dtype.clone()))?;
524 let abs_arg = arg.try_abs()?;
525 let c_inner = poly_n(&abs_arg, &[a + 2.0, -(a + 3.0), 0.0, 1.0], dtype)?;
526 let c_outer = poly_n(&abs_arg, &[a, -5.0 * a, 8.0 * a, -4.0 * a], dtype)?;
527 let mask_outer = abs_arg.try_lt(&two)?;
528 let c = c_outer.where_(&mask_outer, &zero_f)?;
529 let mask_inner = abs_arg.try_le(&one)?;
530 let c = c_inner.where_(&mask_inner, &c)?;
531 coeffs.push(c);
532 }
533
534 normalize_and_gather(x, coeffs, &p, i_start, dim_pos, input_sz, reshape, expand_i, dtype)
535}
536
537#[allow(clippy::too_many_arguments)]
540fn interpolate_antialias_linear(
541 x: &Tensor,
542 index: &Tensor,
543 dim_pos: usize,
544 input_sz: usize,
545 scale: f64,
546 reshape: &[isize],
547 expand_i: &[isize],
548 dtype: &DType,
549) -> Result<Tensor> {
550 let start = (-1.0_f64 / scale).floor() as i32 + 1;
551 let footprint = (2 - 2 * start) as usize;
552
553 let floored = index.floor()?;
554 let p = floored.cast(DType::Int32)?;
555 let ratio = index.try_sub(&floored)?;
556
557 let one = Tensor::const_(1.0f64, dtype.clone());
558 let zero_f = Tensor::const_(0.0f64, dtype.clone());
559
560 let mut coeffs = Vec::with_capacity(footprint);
561 for j in 0..footprint {
562 let tap = start + j as i32;
563 let arg = ratio
564 .try_mul(&Tensor::const_(-scale, dtype.clone()))?
565 .try_add(&Tensor::const_(scale * tap as f64, dtype.clone()))?;
566 let abs_arg = arg.try_abs()?;
567 let c = one.try_sub(&abs_arg)?;
568 let c = c.clamp().min(&zero_f).max(&one).call()?;
569 coeffs.push(c);
570 }
571
572 normalize_and_gather(x, coeffs, &p, start, dim_pos, input_sz, reshape, expand_i, dtype)
573}
574
575#[allow(clippy::too_many_arguments)]
578fn normalize_and_gather(
579 x: &Tensor,
580 mut coeffs: Vec<Tensor>,
581 p: &Tensor,
582 tap_start: i32,
583 dim_pos: usize,
584 input_sz: usize,
585 reshape: &[isize],
586 expand_i: &[isize],
587 dtype: &DType,
588) -> Result<Tensor> {
589 let mut total = coeffs[0].clone();
590 for c in &coeffs[1..] {
591 total = total.try_add(c)?;
592 }
593 let eps = Tensor::const_(1e-9f64, dtype.clone());
594 let total_safe = total.try_add(&eps)?;
595 for c in &mut coeffs {
596 *c = c.try_div(&total_safe)?;
597 }
598
599 let max_val = Tensor::const_(ConstValue::Int((input_sz - 1) as i64), DType::Int32);
600 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
601 let dim_i = dim_pos as isize;
602
603 let mut result: Option<Tensor> = None;
604 for (j, c) in coeffs.into_iter().enumerate() {
605 let tap = tap_start + j as i32;
606 let idx = p.try_add(&Tensor::const_(ConstValue::Int(tap as i64), DType::Int32))?;
607 let idx_clipped = idx.clamp().min(&zero_i).max(&max_val).call()?.try_reshape(reshape)?.try_expand(expand_i)?;
608 let c_expanded = c.try_reshape(reshape)?.try_expand(expand_i)?;
609 let val = x.gather(dim_i, &idx_clipped)?.try_mul(&c_expanded)?;
610 result = Some(match result {
611 Some(acc) => acc.try_add(&val)?,
612 None => val,
613 });
614 }
615 Ok(result.unwrap())
616}
617
/// Returns the indices that would sort `slice` in ascending order.
///
/// `result[k]` is the position of the `k`-th smallest element. Equal elements keep
/// their original relative order, because the sort is stable. Applied to a
/// permutation, this yields its inverse.
fn argsort_usize(slice: &[usize]) -> Vec<usize> {
    let mut order: Vec<usize> = Vec::with_capacity(slice.len());
    order.extend(0..slice.len());
    order.sort_by(|&lhs, &rhs| slice[lhs].cmp(&slice[rhs]));
    order
}