1use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::ConstValue;
7
8use crate::Tensor;
9use crate::error::{NdimMinimumSnafu, UOpSnafu};
10use crate::shape_ops::MeshgridIndexing;
11
12use super::{GridSampleMode, GridSamplePaddingMode};
13
// Module-local shorthand for the crate-wide result type returned by every
// function in this file.
type Result<T> = crate::Result<T>;
15
16#[bon]
17impl Tensor {
18 #[builder]
55 pub fn affine_grid(
56 theta: &Tensor,
57 size: &[i64],
58 #[builder(default = false)] align_corners: bool,
59 ) -> Result<Tensor> {
60 snafu::ensure!(size.len() >= 3, NdimMinimumSnafu { op: "affine_grid", min: 3_usize, actual: size.len() });
61 let n = size[0] as usize;
62 let ndim = size.len() - 2; let spatial_dims: Vec<usize> = size[2..].iter().map(|&s| s as usize).collect();
65 let mut grids = Vec::with_capacity(ndim);
66 for &dim_size in &spatial_dims {
67 let g = if align_corners {
68 Tensor::linspace(-1.0, 1.0, dim_size, DType::Float32)?
69 } else {
70 let start = -1.0 + 1.0 / dim_size as f64;
71 let end = 1.0 - 1.0 / dim_size as f64;
72 Tensor::linspace(start, end, dim_size, DType::Float32)?
73 };
74 grids.push(g);
75 }
76
77 let grid_refs: Vec<&Tensor> = grids.iter().collect();
78 let mesh = Tensor::meshgrid(&grid_refs, MeshgridIndexing::Ij)?;
79
80 let total_elements: usize = spatial_dims.iter().product();
81 let flat_shape = [total_elements as isize];
82 let mut components: Vec<Tensor> = Vec::with_capacity(ndim + 1);
83 for g in mesh.iter().rev() {
84 components.push(g.try_reshape(flat_shape)?);
85 }
86 components.push(Tensor::full(&[total_elements], 1.0, DType::Float32)?);
87
88 let comp_refs: Vec<&Tensor> = components.iter().collect();
89 let base_grid = Tensor::cat(&comp_refs, 0)?
90 .try_reshape([(ndim + 1) as isize, total_elements as isize])?
91 .try_transpose(0, 1)?;
92
93 let base_grid =
94 base_grid.try_unsqueeze(0)?.try_expand([n as isize, total_elements as isize, (ndim + 1) as isize])?;
95
96 let theta_t = theta.try_transpose(1, 2)?;
97 let output = base_grid.matmul(&theta_t)?;
98
99 let mut out_shape: Vec<isize> = vec![n as isize];
100 out_shape.extend(spatial_dims.iter().map(|&d| d as isize));
101 out_shape.push(ndim as isize);
102 output.try_reshape(&out_shape)
103 }
104
105 #[builder]
146 pub fn grid_sample(
147 &self,
148 grid: &Tensor,
149 #[builder(default)] mode: GridSampleMode,
150 #[builder(default)] padding_mode: GridSamplePaddingMode,
151 #[builder(default = false)] align_corners: bool,
152 ) -> Result<Tensor> {
153 let x_ndim = self.ndim()?;
154 snafu::ensure!(x_ndim >= 3, NdimMinimumSnafu { op: "grid_sample", min: 3_usize, actual: x_ndim });
155 let x_shape = self.shape()?;
156 let grid_shape = grid.shape()?;
157 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
158 let grid_dims = svod_ir::shape::to_vec_usize(&grid_shape).context(UOpSnafu)?;
159 let n_spatial = x_dims.len() - 2;
160
161 let n = x_dims[0];
162 let c = x_dims[1];
163 let spatial: Vec<usize> = x_dims[2..].to_vec();
164 let out_spatial: Vec<usize> = grid_dims[1..grid_dims.len() - 1].to_vec();
165 let spatial_prod: usize = spatial.iter().product();
166 let out_prod: usize = out_spatial.iter().product();
167 let dtype = self.uop().dtype();
168
169 let x_flat = self.try_reshape([n as isize, c as isize, spatial_prod as isize])?;
171
172 let grid_flat = grid.try_reshape([n as isize, out_prod as isize, n_spatial as isize])?;
174
175 let strides = compute_strides(&spatial);
177
178 let mut coords: Vec<Tensor> = Vec::with_capacity(n_spatial);
181 for (i, &dim_size) in spatial.iter().enumerate() {
182 let grid_idx = n_spatial - 1 - i;
183 let coord = slice_last_dim(&grid_flat, grid_idx, n, out_prod)?;
184 let denorm = gs_denormalize(&coord, dim_size, align_corners, &dtype)?;
185 coords.push(denorm);
186 }
187
188 let coords = match padding_mode {
190 GridSamplePaddingMode::Border => coords
191 .iter()
192 .enumerate()
193 .map(|(i, c)| {
194 let zero = Tensor::const_(0.0, dtype.clone());
195 let max_val = Tensor::const_((spatial[i] - 1) as f64, dtype.clone());
196 c.clamp().min(&zero).max(&max_val).call()
197 })
198 .collect::<Result<Vec<_>>>()?,
199 GridSamplePaddingMode::Reflection => coords
200 .iter()
201 .enumerate()
202 .map(|(i, c)| gs_reflect(c, spatial[i], align_corners, &dtype))
203 .collect::<Result<Vec<_>>>()?,
204 GridSamplePaddingMode::Zeros => coords,
205 };
206
207 let result = match mode {
208 GridSampleMode::Nearest => {
209 interpolate_nearest(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
210 }
211 GridSampleMode::Linear => {
212 interpolate_linear(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
213 }
214 GridSampleMode::Cubic => {
215 interpolate_cubic(&x_flat, &coords, &spatial, &strides, padding_mode, n, c, out_prod, &dtype)?
216 }
217 };
218
219 let mut out_shape: Vec<isize> = vec![n as isize, c as isize];
221 out_shape.extend(out_spatial.iter().map(|&d| d as isize));
222 result.try_reshape(&out_shape)
223 }
224}
225
/// Returns row-major (C-order) element strides for a tensor with extents `dims`.
///
/// The innermost axis has stride 1. Each outer axis's stride is the product of
/// all extents to its right, so `dims[0]` never contributes. An empty `dims`
/// yields an empty vector.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut out = Vec::with_capacity(dims.len());
    let mut running = 1usize;
    // Walk from the innermost axis outwards, recording the product seen so far.
    for (pos, &extent) in dims.iter().enumerate().rev() {
        out.push(running);
        if pos > 0 {
            running *= extent;
        }
    }
    out.reverse();
    out
}
234
235fn slice_last_dim(t: &Tensor, idx: usize, n: usize, out_prod: usize) -> Result<Tensor> {
237 t.try_shrink([(0, n as isize), (0, out_prod as isize), (idx as isize, (idx + 1) as isize)])?.try_squeeze(Some(-1))
238}
239
240fn gs_denormalize(coord: &Tensor, dim_size: usize, align_corners: bool, dtype: &DType) -> Result<Tensor> {
242 if align_corners {
243 coord
245 .try_add(&Tensor::const_(1.0, dtype.clone()))?
246 .try_mul(&Tensor::const_(0.5 * (dim_size - 1) as f64, dtype.clone()))
247 } else {
248 coord
250 .try_add(&Tensor::const_(1.0, dtype.clone()))?
251 .try_mul(&Tensor::const_(dim_size as f64, dtype.clone()))?
252 .try_sub(&Tensor::const_(1.0, dtype.clone()))?
253 .try_mul(&Tensor::const_(0.5, dtype.clone()))
254 }
255}
256
257fn gs_reflect(coord: &Tensor, dim_size: usize, align_corners: bool, dtype: &DType) -> Result<Tensor> {
259 let (lo, hi) = if align_corners { (0.0, (dim_size - 1) as f64) } else { (-0.5, dim_size as f64 - 0.5) };
260 let rng = hi - lo;
261 if rng == 0.0 {
262 return Ok(Tensor::const_(lo, dtype.clone()));
263 }
264 let lo_t = Tensor::const_(lo, dtype.clone());
265 let rng_t = Tensor::const_(rng, dtype.clone());
266 let period_t = Tensor::const_(2.0 * rng, dtype.clone());
267
268 let shifted = coord.try_sub(&lo_t)?;
270 let t = shifted.try_sub(&shifted.try_div(&period_t)?.floor()?.try_mul(&period_t)?)?;
271
272 let two_rng_t = Tensor::const_(2.0 * rng, dtype.clone());
274 let reflected = two_rng_t.try_sub(&t)?;
275 let cond = rng_t.try_lt(&t)?; reflected.where_(&cond, &t)?.try_add(&lo_t)
277}
278
279fn build_flat_index(
281 indices: &[Tensor],
282 spatial: &[usize],
283 strides: &[usize],
284 padding_mode: GridSamplePaddingMode,
285) -> Result<(Tensor, Option<Tensor>)> {
286 let n_spatial = indices.len();
287 let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int32);
288 let mut valid_mask: Option<Tensor> = None;
289
290 for i in 0..n_spatial {
291 let idx = &indices[i];
292
293 if padding_mode == GridSamplePaddingMode::Zeros {
294 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
295 let max_i = Tensor::const_(ConstValue::Int(spatial[i] as i64), DType::Int32);
296 let v = idx.try_ge(&zero_i)?.bitwise_and(&idx.try_lt(&max_i)?)?;
297 valid_mask = Some(match valid_mask {
298 Some(m) => m.bitwise_and(&v)?,
299 None => v,
300 });
301 }
302
303 let zero_i = Tensor::const_(ConstValue::Int(0), DType::Int32);
305 let max_i = Tensor::const_(ConstValue::Int((spatial[i] - 1) as i64), DType::Int32);
306 let safe_idx = idx.clamp().min(&zero_i).max(&max_i).call()?;
307
308 let stride_t = Tensor::const_(ConstValue::Int(strides[i] as i64), DType::Int32);
309 flat_idx = flat_idx.try_add(&safe_idx.try_mul(&stride_t)?)?;
310 }
311
312 Ok((flat_idx, valid_mask))
313}
314
315fn gather_and_mask(
317 x_flat: &Tensor,
318 flat_idx: &Tensor,
319 valid_mask: Option<&Tensor>,
320 n: usize,
321 c: usize,
322 out_prod: usize,
323 dtype: &DType,
324) -> Result<Tensor> {
325 let expanded_idx = flat_idx.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
326 let mut gathered = x_flat.gather(2, &expanded_idx)?;
327 if let Some(mask) = valid_mask {
328 let mask = mask.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
329 gathered = gathered.try_mul(&mask.cast(dtype.clone())?)?;
330 }
331 Ok(gathered)
332}
333
334#[allow(clippy::too_many_arguments)]
335fn interpolate_nearest(
336 x_flat: &Tensor,
337 coords: &[Tensor],
338 spatial: &[usize],
339 strides: &[usize],
340 padding_mode: GridSamplePaddingMode,
341 n: usize,
342 c: usize,
343 out_prod: usize,
344 dtype: &DType,
345) -> Result<Tensor> {
346 let rounded: Vec<Tensor> = coords.iter().map(|c| c.round()?.cast(DType::Int32)).collect::<Result<_>>()?;
348 let (flat_idx, valid_mask) = build_flat_index(&rounded, spatial, strides, padding_mode)?;
349 gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)
350}
351
352#[allow(clippy::too_many_arguments)]
353fn interpolate_linear(
354 x_flat: &Tensor,
355 coords: &[Tensor],
356 spatial: &[usize],
357 strides: &[usize],
358 padding_mode: GridSamplePaddingMode,
359 n: usize,
360 c: usize,
361 out_prod: usize,
362 dtype: &DType,
363) -> Result<Tensor> {
364 let n_spatial = coords.len();
365 let floors: Vec<Tensor> = coords.iter().map(|c| c.floor()).collect::<Result<_>>()?;
366 let fracs: Vec<Tensor> = coords.iter().zip(&floors).map(|(c, f)| c.try_sub(f)).collect::<Result<_>>()?;
367
368 let n_combos = 1usize << n_spatial;
370 let mut result = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
371
372 for combo in 0..n_combos {
373 let mut weight = Tensor::const_(ConstValue::Float(1.0), dtype.clone());
374 let mut corner_indices: Vec<Tensor> = Vec::with_capacity(n_spatial);
375
376 for i in 0..n_spatial {
377 let use_ceil = (combo >> i) & 1 == 1;
378 let idx_f =
379 if use_ceil { floors[i].try_add(&Tensor::const_(1.0, dtype.clone()))? } else { floors[i].clone() };
380 let w = if use_ceil { fracs[i].clone() } else { Tensor::const_(1.0, dtype.clone()).try_sub(&fracs[i])? };
381 weight = weight.try_mul(&w)?;
382 corner_indices.push(idx_f.cast(DType::Int32)?);
383 }
384
385 let (flat_idx, valid_mask) = build_flat_index(&corner_indices, spatial, strides, padding_mode)?;
386 let gathered = gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)?;
387
388 let weight = weight.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
389 result = result.try_add(&gathered.try_mul(&weight)?)?;
390 }
391
392 Ok(result)
393}
394
395#[allow(clippy::too_many_arguments)]
396fn interpolate_cubic(
397 x_flat: &Tensor,
398 coords: &[Tensor],
399 spatial: &[usize],
400 strides: &[usize],
401 padding_mode: GridSamplePaddingMode,
402 n: usize,
403 c: usize,
404 out_prod: usize,
405 dtype: &DType,
406) -> Result<Tensor> {
407 let n_spatial = coords.len();
408 let floors: Vec<Tensor> = coords.iter().map(|c| c.floor()).collect::<Result<_>>()?;
409 let fracs: Vec<Tensor> = coords.iter().zip(&floors).map(|(c, f)| c.try_sub(f)).collect::<Result<_>>()?;
410
411 let coeffs: Vec<[Tensor; 4]> = fracs.iter().map(|s| gs_cubic_coeffs(s, -0.75, dtype)).collect::<Result<_>>()?;
413
414 let n_combos = 4usize.pow(n_spatial as u32);
416 let mut result = Tensor::const_(ConstValue::Float(0.0), dtype.clone());
417
418 for combo in 0..n_combos {
419 let mut weight = Tensor::const_(ConstValue::Float(1.0), dtype.clone());
420 let mut corner_indices: Vec<Tensor> = Vec::with_capacity(n_spatial);
421
422 for i in 0..n_spatial {
423 let offset_idx = (combo / 4usize.pow(i as u32)) % 4;
424 let offset = offset_idx as f64 - 1.0; let idx_f = floors[i].try_add(&Tensor::const_(offset, dtype.clone()))?;
427 weight = weight.try_mul(&coeffs[i][offset_idx])?;
428 corner_indices.push(idx_f.cast(DType::Int32)?);
429 }
430
431 let (flat_idx, valid_mask) = build_flat_index(&corner_indices, spatial, strides, padding_mode)?;
432 let gathered = gather_and_mask(x_flat, &flat_idx, valid_mask.as_ref(), n, c, out_prod, dtype)?;
433
434 let weight = weight.try_unsqueeze(1)?.try_expand([n as isize, c as isize, out_prod as isize])?;
435 result = result.try_add(&gathered.try_mul(&weight)?)?;
436 }
437
438 Ok(result)
439}
440
441fn gs_cubic_coeffs(s: &Tensor, a: f64, dtype: &DType) -> Result<[Tensor; 4]> {
444 let one = Tensor::const_(1.0, dtype.clone());
445 let two = Tensor::const_(2.0, dtype.clone());
446
447 let sp1 = s.try_add(&one)?;
450 let c0 = sp1
451 .try_mul(&Tensor::const_(a, dtype.clone()))?
452 .try_sub(&Tensor::const_(5.0 * a, dtype.clone()))?
453 .try_mul(&sp1)?
454 .try_add(&Tensor::const_(8.0 * a, dtype.clone()))?
455 .try_mul(&sp1)?
456 .try_sub(&Tensor::const_(4.0 * a, dtype.clone()))?;
457
458 let c1 = s
461 .try_mul(&Tensor::const_(a + 2.0, dtype.clone()))?
462 .try_sub(&Tensor::const_(a + 3.0, dtype.clone()))?
463 .try_mul(s)?
464 .try_mul(s)?
465 .try_add(&one)?;
466
467 let sm1 = one.try_sub(s)?;
469 let c2 = sm1
470 .try_mul(&Tensor::const_(a + 2.0, dtype.clone()))?
471 .try_sub(&Tensor::const_(a + 3.0, dtype.clone()))?
472 .try_mul(&sm1)?
473 .try_mul(&sm1)?
474 .try_add(&Tensor::const_(1.0, dtype.clone()))?;
475
476 let sm2 = two.try_sub(s)?;
478 let c3 = sm2
479 .try_mul(&Tensor::const_(a, dtype.clone()))?
480 .try_sub(&Tensor::const_(5.0 * a, dtype.clone()))?
481 .try_mul(&sm2)?
482 .try_add(&Tensor::const_(8.0 * a, dtype.clone()))?
483 .try_mul(&sm2)?
484 .try_sub(&Tensor::const_(4.0 * a, dtype.clone()))?;
485
486 Ok([c0, c1, c2, c3])
487}