1use snafu::ResultExt;
4use strum::{Display, EnumString};
5
6use super::*;
7use crate::error::ShapeMismatchSnafu;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
11pub enum ScatterReduction {
12 #[strum(serialize = "sum")]
13 Sum,
14 #[strum(serialize = "prod")]
15 Prod,
16 #[strum(serialize = "amax")]
17 Amax,
18 #[strum(serialize = "amin")]
19 Amin,
20}
21
22impl Tensor {
23 #[track_caller]
25 pub fn gather(&self, dim: isize, index: &Tensor) -> Result<Self> {
26 let self_shape = self.shape()?;
27 let index_shape = index.shape()?;
28 let ndim = self_shape.len();
29 let dim = Self::normalize_axis(dim, ndim)?;
30
31 snafu::ensure!(
32 index_shape.len() == ndim,
33 ShapeMismatchSnafu {
34 context: "gather",
35 expected: format!("{ndim}D"),
36 actual: format!("{}D index", index_shape.len())
37 }
38 );
39
40 let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
49 let index_dims = svod_ir::shape::to_vec_usize(&index_shape).context(UOpSnafu)?;
50
51 snafu::ensure!(
52 self_dims.iter().zip(&index_dims).enumerate().all(|(d, (s, i))| d == dim || s >= i),
53 ShapeMismatchSnafu {
54 context: "gather",
55 expected: "self[d] >= index[d] for d != dim".to_string(),
56 actual: format!("self={self_dims:?}, index={index_dims:?}")
57 }
58 );
59
60 let shrink: Vec<_> =
61 (0..ndim).map(|d| (0, (if d == dim { self_dims[d] } else { index_dims[d] }) as isize)).collect();
62 let x = self.try_shrink(&shrink)?.try_unsqueeze(-1)?.try_transpose(-1, dim as isize)?;
63
64 let arange = Tensor::arange(0, Some(self_dims[dim] as i64), None)?.cast(index.uop().dtype())?;
65 let mask = index.try_unsqueeze(-1)?.try_eq(&arange)?;
66
67 x.where_(&mask, &Self::new(x.uop().const_like(0)))?.sum_with().axes(-1).dtype(self.uop().dtype()).call()
68 }
69
70 #[track_caller]
75 pub fn index_select(&self, dim: isize, index: &Tensor) -> Result<Self> {
76 let self_shape = self.shape()?;
77 let ndim = self_shape.len();
78 let dim = Self::normalize_axis(dim, ndim)?;
79 let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
84
85 let idx_len = index.shape()?[0].as_const().expect("index_select: index length must be concrete");
87 let mut idx_shape = vec![1isize; ndim];
88 idx_shape[dim] = idx_len as isize;
89 let idx_nd = index.try_reshape(&idx_shape)?;
90
91 let mut expand_shape: Vec<isize> = self_dims.iter().map(|&d| d as isize).collect();
93 expand_shape[dim] = idx_len as isize;
94 let idx_expanded = idx_nd.try_expand(&expand_shape)?;
95
96 self.gather(dim as isize, &idx_expanded)
97 }
98
99 pub fn one_hot_along_dim(&self, num_classes: usize, dim: isize) -> Result<Tensor> {
102 let ndim = self.ndim()?;
103 let norm_dim = Self::normalize_axis(dim, ndim)?;
104 let offset = ndim - norm_dim - 1;
105 let arange = Tensor::arange(0, Some(num_classes as i64), None)?;
106 let mut ar_shape = vec![1isize; 1 + offset];
107 ar_shape[0] = num_classes as isize;
108 self.try_eq(&arange.try_reshape(&ar_shape)?)
109 }
110
111 pub fn normalize_negative_indices(&self, dim_size: i64) -> Result<Tensor> {
113 let zero = Tensor::const_(ConstValue::Int(0), self.uop().dtype());
114 let dim_t = Tensor::const_(ConstValue::Int(dim_size), self.uop().dtype());
115 let neg_mask = self.try_lt(&zero)?;
116 self.try_add(&dim_t)?.where_(&neg_mask, self)
117 }
118
119 fn _pre_scatter(&self, dim: isize, index: &Tensor, src: &Tensor) -> Result<(Tensor, Tensor)> {
131 let self_shape = self.shape()?;
132 let index_shape = index.shape()?;
133 let src_shape = src.shape()?;
134 let ndim = self_shape.len();
135 let dim = Self::normalize_axis(dim, ndim)?;
136
137 let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
138 let index_dims = svod_ir::shape::to_vec_usize(&index_shape).context(UOpSnafu)?;
139 let src_dims = svod_ir::shape::to_vec_usize(&src_shape).context(UOpSnafu)?;
140
141 snafu::ensure!(
142 index_shape.len() == ndim && src_shape.len() == ndim,
143 ShapeMismatchSnafu {
144 context: "scatter",
145 expected: format!("{ndim}D"),
146 actual: format!("index={}D, src={}D", index_shape.len(), src_shape.len())
147 }
148 );
149 snafu::ensure!(
150 self_dims
151 .iter()
152 .zip(&index_dims)
153 .zip(&src_dims)
154 .enumerate()
155 .all(|(d, ((s, i), sr))| { (d == dim || s >= i) && sr >= i }),
156 ShapeMismatchSnafu {
157 context: "scatter",
158 expected: "valid scatter shape constraints".to_string(),
159 actual: format!("self={self_dims:?}, index={index_dims:?}, src={src_dims:?}")
160 }
161 );
162
163 let shrink_ranges: Vec<(isize, isize)> = index_dims.iter().map(|&d| (0, d as isize)).collect();
165 let src = src.try_shrink(&shrink_ranges)?;
166
167 let mut expand_shape: Vec<isize> = index_dims.iter().map(|&d| d as isize).collect();
169 expand_shape.push(self_dims[dim] as isize);
170 let src = src.try_unsqueeze(-1)?.try_expand(&expand_shape)?.try_transpose(-1, dim as isize)?;
171
172 let mask = index.try_unsqueeze(-1)?.one_hot_along_dim(self_dims[dim], -1)?.try_transpose(-1, dim as isize)?;
174
175 let src_cur = src.shape()?;
177 let src_cur_dims = svod_ir::shape::to_vec_usize(&src_cur).context(UOpSnafu)?;
178 let padding: Vec<(isize, isize)> =
179 (0..ndim).map(|d| (0, (self_dims[d] as isize - src_cur_dims[d] as isize).max(0))).collect();
180 let needs_pad = padding.iter().any(|&(_, e)| e > 0);
181 let src = if needs_pad { src.try_pad(&padding)? } else { src };
182 let mask = if needs_pad { mask.try_pad(&padding)? } else { mask };
183
184 Ok((src, mask))
185 }
186
187 #[track_caller]
193 pub fn scatter(&self, dim: isize, index: &Tensor, src: &Tensor) -> Result<Tensor> {
194 let (src_p, mask_p) = self._pre_scatter(dim, index, src)?;
195 masked_setitem(self, &src_p, &mask_p, &[-1])
196 }
197
198 #[track_caller]
200 pub fn scatter_reduce(
201 &self,
202 dim: isize,
203 index: &Tensor,
204 src: &Tensor,
205 reduce: ScatterReduction,
206 include_self: bool,
207 ) -> Result<Tensor> {
208 let (src_p, mask_p) = self._pre_scatter(dim, index, src)?;
209 let dtype = src_p.uop().dtype();
210 let inv_mask = |a: &Tensor, b: &Tensor| -> Result<Tensor> {
211 let no_hit = mask_p.any(-1isize)?.logical_not()?;
212 a.where_(&no_hit, b)
213 };
214 let self_or = |identity_val: ConstValue| -> Result<Tensor> {
215 if include_self { Ok(self.clone()) } else { inv_mask(self, &Tensor::const_(identity_val, dtype.clone())) }
216 };
217
218 match reduce {
219 ScatterReduction::Sum => {
220 let zero = Tensor::const_(ConstValue::Int(0), dtype.clone());
221 let reduced = src_p.where_(&mask_p, &zero)?.sum_with().axes(-1isize).call()?;
222 reduced.try_add(&self_or(ConstValue::Int(0))?)
223 }
224 ScatterReduction::Prod => {
225 let one = Tensor::const_(ConstValue::Int(1), dtype.clone());
226 let reduced = src_p.where_(&mask_p, &one)?.prod_with().axes(-1isize).call()?;
227 reduced.try_mul(&self_or(ConstValue::Int(1))?)
228 }
229 ScatterReduction::Amax => {
230 let min_val =
231 if dtype.is_float() { ConstValue::Float(f64::NEG_INFINITY) } else { ConstValue::Int(i64::MIN) };
232 let fill = Tensor::const_(min_val, dtype.clone());
233 let reduced = src_p.where_(&mask_p, &fill)?.max(-1isize)?;
234 reduced.maximum(&self_or(min_val)?)
235 }
236 ScatterReduction::Amin => {
237 let max_val =
238 if dtype.is_float() { ConstValue::Float(f64::INFINITY) } else { ConstValue::Int(i64::MAX) };
239 let fill = Tensor::const_(max_val, dtype.clone());
240 let reduced = src_p.where_(&mask_p, &fill)?.min(-1isize)?;
241 reduced.minimum(&self_or(max_val)?)
242 }
243 }
244 }
245
246 #[track_caller]
254 pub fn masked_select(&self, mask: &Tensor) -> Result<Tensor> {
255 let x = self.flatten()?;
256 let mask_flat = mask.broadcast_to(&self.shape()?)?.flatten()?;
257 let mask_cumsum = mask_flat.cast(svod_dtype::DType::Int32)?.cumsum(0)?;
258 let n = mask_flat.numel()?;
260 let mut count_t = mask_cumsum.try_shrink([((n - 1) as isize, n as isize)])?;
261 count_t.realize()?;
262 let count_t = count_t.as_ndarray::<i32>()?;
263 let count = count_t[[0]] as usize;
264 if count == 0 {
265 return Ok(Tensor::empty_zero(self.uop().dtype()));
266 }
267
268 let zeros = Tensor::full(&[count], ConstValue::Int(0), svod_dtype::DType::Int32)?;
270 let ones = Tensor::full(&[n], ConstValue::Int(1), svod_dtype::DType::Int32)?;
271 let idxs = zeros.scatter_reduce(0, &mask_cumsum, &ones, ScatterReduction::Sum, false)?.cumsum(0)?;
272 x.gather(0, &idxs)
273 }
274
275 #[track_caller]
280 pub fn compress(&self, condition: &[bool], axis: Option<isize>) -> Result<Tensor> {
281 let x = if axis.is_none() { self.flatten()? } else { self.clone() };
282 let axis = axis.unwrap_or(0);
283 let indices: Vec<i64> = condition.iter().enumerate().filter(|(_, v)| **v).map(|(i, _)| i as i64).collect();
284 let idx = Tensor::from_slice(&indices);
285 x.index_select(axis, &idx)
286 }
287
288 #[track_caller]
294 pub fn sort(&self, dim: isize, descending: bool) -> Result<(Tensor, Tensor)> {
295 let shape = self.shape()?;
296 let ndim = shape.len();
297 let dim = Self::normalize_axis(dim, ndim)?;
298 let orig_len = shape[dim]
299 .as_const()
300 .ok_or_else(|| crate::error::Error::SymbolicShapeUnsupported { operation: "sort".into() })?;
301
302 if orig_len <= 1 {
303 let idx = Tensor::full(
304 &svod_ir::shape::to_vec_usize(&shape).unwrap(),
305 ConstValue::Int(0),
306 svod_dtype::DType::Int32,
307 )?;
308 return Ok((self.clone(), idx));
309 }
310
311 let n_stages = (orig_len as u64 - 1).ilog2() as usize + 1;
312 let padded_len = 1usize << n_stages;
313
314 let sentinel = if descending {
316 if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 }
317 } else if self.uop().dtype().is_float() {
318 f64::INFINITY
319 } else {
320 i64::MAX as f64
321 };
322 let mut padding = vec![(0isize, 0isize); ndim];
323 padding[dim] = (0, (padded_len - orig_len) as isize);
324 let mut x = self.try_pad_value(&padding, sentinel)?;
325
326 let unflatten_sizes: Vec<isize> = vec![2; n_stages];
328 x = x.unflatten(dim as isize, &unflatten_sizes)?;
329
330 for stage in 1..=n_stages {
332 if stage != n_stages {
333 let crossover_dim = (dim + n_stages - stage - 1) as isize;
335 let halves = x.split(&[1, 1], crossover_dim)?;
336 let (blue, green) = (&halves[0], &halves[1]);
337 let flip_dims: Vec<isize> = (1..=(stage + (ndim - dim))).map(|i| -(i as isize)).collect();
338 x = Tensor::cat(&[blue, &green.flip(&flip_dims)?], crossover_dim)?.contiguous();
339 }
340
341 for substage in (0..stage).rev() {
342 let partner_dim = (dim + n_stages - substage - 1) as isize;
343 let parts = x.split(&[1, 1], partner_dim)?;
344 let (x_top, x_bottom) = (&parts[0], &parts[1]);
345 let x_larger = x_top.maximum(x_bottom)?;
346 let x_smaller = x_top.minimum(x_bottom)?;
347 x = if descending {
348 Tensor::cat(&[&x_larger, &x_smaller], partner_dim)?
349 } else {
350 Tensor::cat(&[&x_smaller, &x_larger], partner_dim)?
351 }
352 .contiguous();
353 }
354
355 if stage != n_stages {
356 let crossover_dim = (dim + n_stages - stage - 1) as isize;
358 let halves = x.split(&[1, 1], crossover_dim)?;
359 let (blue, flipped_green) = (&halves[0], &halves[1]);
360 let flip_dims: Vec<isize> = (1..=(stage + (ndim - dim))).map(|i| -(i as isize)).collect();
361 x = Tensor::cat(&[blue, &flipped_green.flip(&flip_dims)?], crossover_dim)?;
362 }
363 }
364
365 let flatten_end = dim + n_stages - 1;
367 let cur_shape = x.shape()?;
369 let cur_dims = svod_ir::shape::to_vec_usize(&cur_shape).context(UOpSnafu)?;
370 let mut flat_shape: Vec<isize> = Vec::new();
371 for (i, &d) in cur_dims.iter().enumerate() {
372 if i == dim {
373 flat_shape.push(padded_len as isize);
374 } else if i > dim && i <= flatten_end {
375 continue;
376 } else {
377 flat_shape.push(d as isize);
378 }
379 }
380 x = x.try_reshape(&flat_shape)?;
381
382 let x_shape = x.shape()?;
384 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
385 let shrink_ranges: Vec<(isize, isize)> =
386 x_dims.iter().enumerate().map(|(d, &s)| (0, if d == dim { orig_len } else { s } as isize)).collect();
387 x = x.try_shrink(&shrink_ranges)?;
388
389 let tril_2d = Tensor::full(&[orig_len, orig_len], true, svod_dtype::DType::Bool)?.tril(0)?;
394 let mut tril_reshape: Vec<isize> = vec![1; ndim + 1];
395 tril_reshape[dim] = orig_len as isize;
396 tril_reshape[dim + 1] = orig_len as isize;
397 let tril_mask = tril_2d.try_reshape(&tril_reshape)?;
398
399 let compute_counts = |t: &Tensor| -> Result<Tensor> {
401 let eq = t.try_unsqueeze(dim as isize)?.try_eq(&t.try_unsqueeze((dim + 1) as isize)?)?;
402 eq.bitwise_and(&tril_mask)?.sum((dim + 1) as isize)
403 };
404
405 let count_orig = compute_counts(self)?;
406 let count_sorted = compute_counts(&x)?;
407
408 let val_match = self.try_unsqueeze((dim + 1) as isize)?.try_eq(&x.try_unsqueeze(dim as isize)?)?;
410 let cnt_match =
411 count_orig.try_unsqueeze((dim + 1) as isize)?.try_eq(&count_sorted.try_unsqueeze(dim as isize)?)?;
412 let cond = val_match.bitwise_and(&cnt_match)?;
413
414 let mut idx_shape = vec![1isize; ndim + 1];
416 idx_shape[dim] = orig_len as isize;
417 let idx = (cond
418 .cast(svod_dtype::DType::Int32)?
419 .try_mul(&Tensor::arange(0, Some(orig_len as i64), None)?.try_reshape(&idx_shape)?)?)
420 .sum(dim as isize)?;
421
422 Ok((x, idx))
423 }
424
425 #[track_caller]
431 pub fn topk(&self, k: usize, dim: isize, largest: bool) -> Result<(Tensor, Tensor)> {
432 let shape = self.shape()?;
433 let ndim = shape.len();
434 let norm_dim = Self::normalize_axis(dim, ndim)?;
435 let (x, idx) = self.sort(dim, largest)?;
436 let x_shape = x.shape()?;
438 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
439 let shrink: Vec<(isize, isize)> =
440 x_dims.iter().enumerate().map(|(d, &s)| (0, if d == norm_dim { k } else { s } as isize)).collect();
441 Ok((x.try_shrink(&shrink)?, idx.try_shrink(&shrink)?))
442 }
443
444 #[track_caller]
450 pub fn nonzero(&self) -> Result<Tensor> {
451 let shape = self.shape()?;
452 let ndim = shape.len();
453 let dims = svod_ir::shape::to_vec_usize(&shape).context(UOpSnafu)?;
454 let numel: usize = dims.iter().product();
455
456 let mask = self.try_ne(&Tensor::const_(ConstValue::Int(0), self.uop().dtype()))?.flatten()?;
457
458 let coords: Vec<Tensor> = (0..ndim)
460 .map(|i| {
461 let ar = Tensor::arange(0, Some(dims[i] as i64), None)?;
462 let mut rshape = vec![1isize; ndim];
463 rshape[i] = dims[i] as isize;
464 let expand_shape: Vec<isize> = dims.iter().map(|&d| d as isize).collect();
465 ar.try_reshape(&rshape)?.try_expand(&expand_shape)?.flatten()
466 })
467 .collect::<Result<Vec<_>>>()?;
468
469 let coords_refs: Vec<&Tensor> = coords.iter().collect();
470 let indices = Tensor::stack(&coords_refs, -1)?; let expanded_mask = mask.try_unsqueeze(-1)?.try_expand([numel as isize, ndim as isize])?;
474 let selected = indices.masked_select(&expanded_mask)?;
475 selected.try_reshape([-1, ndim as isize])
476 }
477
478 #[track_caller]
481 pub fn reverse_sequence(&self, sequence_lens: &Tensor, time_axis: usize, batch_axis: usize) -> Result<Self> {
482 let dims = svod_ir::shape::to_vec_usize(&self.shape()?).context(UOpSnafu)?;
483 let ndim = dims.len();
484 let time_len = dims[time_axis];
485
486 let mut perm: Vec<usize> = (0..ndim).collect();
488 perm.swap(0, time_axis);
489 let batch_pos = if batch_axis == 0 {
490 time_axis
491 } else if batch_axis == time_axis {
492 0
493 } else {
494 batch_axis
495 };
496 perm.swap(1, batch_pos);
497 let perm_i: Vec<isize> = perm.iter().map(|&p| p as isize).collect();
498 let work = self.try_permute(&perm_i)?;
499 let work_dims = svod_ir::shape::to_vec_usize(&work.shape()?).context(UOpSnafu)?;
500
501 let idx_dt = sequence_lens.uop().dtype();
503 let t = Tensor::arange(0, Some(time_len as i64), None)?.cast(idx_dt.clone())?.try_unsqueeze(1)?;
504 let sl = sequence_lens.try_unsqueeze(0)?;
505
506 let one = Tensor::const_(ConstValue::Int(1), idx_dt);
508 let reversed_t = sl.try_sub(&one)?.try_sub(&t)?;
509 let mask = t.try_lt(&sl)?;
510 let idx = reversed_t.where_(&mask, &t)?;
511
512 let expand_shape: Vec<isize> = work_dims.iter().map(|&d| d as isize).collect();
514 let idx = idx.try_reshape(&expand_shape[..2])?.try_expand(&expand_shape)?;
515 let result = work.gather(0, &idx)?;
516
517 let mut inv_perm = vec![0usize; ndim];
519 for (i, &p) in perm.iter().enumerate() {
520 inv_perm[p] = i;
521 }
522 let inv_perm_i: Vec<isize> = inv_perm.iter().map(|&p| p as isize).collect();
523 result.try_permute(&inv_perm_i)
524 }
525
526 pub fn gather_nd(&self, indices: &Tensor, batch_dims: usize) -> Result<Tensor> {
532 let x_shape = self.shape()?;
533 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
534 let idx_shape = indices.shape()?;
535 let idx_dims = svod_ir::shape::to_vec_usize(&idx_shape).context(UOpSnafu)?;
536 let last_idx_dim = *idx_dims.last().unwrap();
537
538 if batch_dims == 0 {
539 let strides: Vec<i64> =
540 (0..last_idx_dim).map(|k| x_dims[k + 1..last_idx_dim].iter().product::<usize>() as i64).collect();
541 let inner: usize = x_dims[last_idx_dim..].iter().product();
542 let outer = x_dims[..last_idx_dim].iter().product::<usize>();
543
544 let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
545 for (k, stride) in strides.iter().enumerate() {
546 let mut ranges: Vec<(isize, isize)> = idx_dims.iter().map(|&s| (0, s as isize)).collect();
547 ranges[idx_dims.len() - 1] = (k as isize, k as isize + 1);
548 let idx_k = indices.try_shrink(&ranges)?.try_squeeze(Some(-1))?;
549 let stride_t = Tensor::const_(ConstValue::Int(*stride), DType::Int64);
550 flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
551 }
552
553 let x_flat = self.try_reshape([outer as isize, inner as isize])?;
554 let gather_outer: Vec<isize> = idx_dims[..idx_dims.len() - 1].iter().map(|&d| d as isize).collect();
555 let num_gathers: usize = gather_outer.iter().map(|&d| d as usize).product();
556
557 let flat_idx_2d = flat_idx
558 .try_reshape([num_gathers as isize, 1])?
559 .try_expand([num_gathers as isize, inner as isize])?
560 .cast(DType::Int32)?;
561 let result = x_flat.gather(0, &flat_idx_2d)?;
562
563 let mut out_shape = gather_outer;
564 for &d in &x_dims[last_idx_dim..] {
565 out_shape.push(d as isize);
566 }
567 result.try_reshape(&out_shape)
568 } else {
569 let batch_size: usize = x_dims[..batch_dims].iter().product();
570 let inner_x: Vec<usize> = x_dims[batch_dims..].to_vec();
571 let inner_idx: Vec<usize> = idx_dims[batch_dims..].to_vec();
572
573 let x_flat = self.try_reshape(
574 std::iter::once(batch_size as isize).chain(inner_x.iter().map(|&d| d as isize)).collect::<Vec<_>>(),
575 )?;
576 let idx_flat = indices.try_reshape(
577 std::iter::once(batch_size as isize).chain(inner_idx.iter().map(|&d| d as isize)).collect::<Vec<_>>(),
578 )?;
579
580 let last_inner = *inner_idx.last().unwrap();
581 let strides: Vec<i64> =
582 (0..last_inner).map(|k| inner_x[k + 1..last_inner].iter().product::<usize>() as i64).collect();
583
584 let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
585 let idx_flat_shape = idx_flat.shape()?;
586 let idx_flat_dims = svod_ir::shape::to_vec_usize(&idx_flat_shape).context(UOpSnafu)?;
587 for (k, stride) in strides.iter().enumerate() {
588 let mut ranges: Vec<(isize, isize)> = idx_flat_dims.iter().map(|&s| (0, s as isize)).collect();
589 ranges[idx_flat_dims.len() - 1] = (k as isize, k as isize + 1);
590 let idx_k = idx_flat.try_shrink(&ranges)?.try_squeeze(Some(-1))?;
591 let stride_t = Tensor::const_(ConstValue::Int(*stride), DType::Int64);
592 flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
593 }
594
595 let batch_stride = inner_x[..last_inner].iter().product::<usize>();
596 let batch_offset_arr = Tensor::arange(0, Some(batch_size as i64), None)?
597 .try_mul(&Tensor::from_slice([batch_stride as i64]))?;
598 let gather_inner = idx_flat_dims[1..idx_flat_dims.len() - 1].iter().product::<usize>();
599 flat_idx = flat_idx.try_reshape([batch_size as isize, gather_inner as isize])?;
600 let batch_offset = batch_offset_arr
601 .try_reshape([batch_size as isize, 1])?
602 .try_expand([batch_size as isize, gather_inner as isize])?;
603 flat_idx = flat_idx.try_add(&batch_offset)?;
604
605 let remaining: usize = inner_x[last_inner..].iter().product();
606 let x_2d = x_flat.try_reshape([(batch_size * batch_stride) as isize, remaining as isize])?;
607 let fi = flat_idx
608 .try_reshape([(batch_size * gather_inner) as isize, 1])?
609 .try_expand([(batch_size * gather_inner) as isize, remaining as isize])?
610 .cast(DType::Int32)?;
611 let result = x_2d.gather(0, &fi)?;
612
613 let mut out_shape: Vec<isize> = x_dims[..batch_dims].iter().map(|&d| d as isize).collect();
614 out_shape.extend(inner_idx[..inner_idx.len() - 1].iter().map(|&d| d as isize));
615 out_shape.extend(inner_x[last_inner..].iter().map(|&d| d as isize));
616 result.try_reshape(&out_shape)
617 }
618 }
619
620 pub fn scatter_nd(&self, indices: &Tensor, updates: &Tensor, reduction: &str) -> Result<Tensor> {
622 let x_shape = self.shape()?;
623 let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
624 let idx_shape = indices.shape()?;
625 let last_idx_dim = idx_shape[idx_shape.len() - 1].as_const().unwrap();
626 let strides: Vec<i64> =
627 (0..last_idx_dim).map(|k| x_dims[k + 1..last_idx_dim].iter().product::<usize>() as i64).collect();
628 let x_numel: usize = x_dims.iter().product();
629 let inner: usize = x_dims[last_idx_dim..].iter().product();
630 let outer = x_numel / inner;
631 let x_flat = self.try_reshape([outer as isize, inner as isize])?;
632 let idx_splits: Vec<Tensor> = (0..last_idx_dim)
633 .map(|k| {
634 let mut ranges: Vec<(isize, isize)> =
635 idx_shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
636 ranges[idx_shape.len() - 1] = (k as isize, k as isize + 1);
637 let slice = indices.try_shrink(&ranges)?;
638 slice.try_squeeze(Some(-1))
639 })
640 .collect::<Result<_>>()?;
641 let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
642 for (k, idx_k) in idx_splits.iter().enumerate() {
643 let stride_t = Tensor::const_(ConstValue::Int(strides[k]), DType::Int64);
644 flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
645 }
646 let upd_shape = updates.shape()?;
647 let upd_outer: usize = upd_shape[..upd_shape.len() - (x_dims.len() - last_idx_dim)]
648 .iter()
649 .map(|s| s.as_const().unwrap())
650 .product();
651 let upd_flat = updates.try_reshape([upd_outer as isize, inner as isize])?;
652 let flat_idx =
653 flat_idx.try_reshape([upd_outer as isize, 1])?.try_expand([upd_outer as isize, inner as isize])?;
654 let flat_idx_i32 = flat_idx.cast(DType::Int32)?;
655 let mut result = match reduction {
656 "none" => x_flat.scatter(0, &flat_idx_i32, &upd_flat)?,
657 "add" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Sum, true)?,
658 "mul" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Prod, true)?,
659 "max" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Amax, true)?,
660 "min" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Amin, true)?,
661 _ => {
662 return Err(crate::error::Error::IrConstruction {
663 details: format!("ScatterND: unsupported reduction '{reduction}'"),
664 });
665 }
666 };
667 let out_shape: Vec<isize> = x_dims.iter().map(|&d| d as isize).collect();
668 result = result.try_reshape(&out_shape)?;
669 Ok(result)
670 }
671
672 pub fn tensor_scatter(
674 &self,
675 update: &Tensor,
676 write_indices: Option<&Tensor>,
677 mode: &str,
678 axis: isize,
679 ) -> Result<Tensor> {
680 let data_shape = self.shape()?;
681 let ndim = data_shape.len();
682 let axis = Self::normalize_axis(axis, ndim)?;
683 let data_dims = svod_ir::shape::to_vec_usize(&data_shape).context(UOpSnafu)?;
684 let update_dims = svod_ir::shape::to_vec_usize(&update.shape()?).context(UOpSnafu)?;
685
686 let batch_size = data_dims[0];
687 let max_seq = data_dims[axis];
688 let seq_len = update_dims[axis];
689
690 let b_total: usize = data_dims[..axis].iter().product();
691 let features: usize = data_dims[axis + 1..].iter().product();
692
693 let write_idx = if let Some(wi) = write_indices {
694 wi.cast(DType::Int32)?
695 } else {
696 Tensor::full(&[batch_size], ConstValue::Int(0), DType::Int32)?
697 };
698
699 let wi_flat = if axis > 1 {
700 let mut wi_reshape: Vec<isize> = vec![batch_size as isize];
701 wi_reshape.extend(std::iter::repeat_n(1, axis - 1));
702 let wi_expand: Vec<isize> = data_dims[..axis].iter().map(|&d| d as isize).collect();
703 write_idx.try_reshape(&wi_reshape)?.try_expand(&wi_expand)?.try_reshape([b_total as isize])?
704 } else {
705 write_idx
706 };
707
708 let data_flat = self.try_reshape([(b_total * max_seq) as isize, features as isize])?;
709 let updates_flat = update.try_reshape([(b_total * seq_len) as isize, features as isize])?;
710
711 let batch_offset = Tensor::arange(0, Some(b_total as i64), None)?
712 .cast(DType::Int32)?
713 .try_mul(&Tensor::const_(ConstValue::Int(max_seq as i64), DType::Int32))?
714 .try_reshape([b_total as isize, 1])?;
715
716 let wi_2d = wi_flat.try_reshape([b_total as isize, 1])?;
717 let seq_arange =
718 Tensor::arange(0, Some(seq_len as i64), None)?.cast(DType::Int32)?.try_reshape([1, seq_len as isize])?;
719 let mut row_idx = wi_2d.try_add(&seq_arange)?;
720
721 if mode == "circular" {
722 let max_seq_t = Tensor::const_(ConstValue::Int(max_seq as i64), DType::Int32);
723 row_idx = row_idx.try_mod(&max_seq_t)?;
724 }
725
726 let flat_idx = batch_offset
727 .try_add(&row_idx)?
728 .try_reshape([(b_total * seq_len) as isize, 1])?
729 .try_expand([(b_total * seq_len) as isize, features as isize])?;
730
731 let result = data_flat.scatter(0, &flat_idx, &updates_flat)?;
732
733 let out_shape: Vec<isize> = data_dims.iter().map(|&d| d as isize).collect();
734 result.try_reshape(&out_shape)
735 }
736}
737
738fn masked_setitem(target: &Tensor, values: &Tensor, mask: &Tensor, axes: &[isize]) -> Result<Tensor> {
744 let mut mask = mask.clone();
745 let mut values = values.clone();
746
747 for &dim in axes.iter().rev() {
749 let shape = mask.shape()?;
750 let ndim = shape.len();
751 let norm_dim = Tensor::normalize_axis(dim, ndim)?;
752 let dim_size = shape[norm_dim].as_const().unwrap();
753 let ones = vec![1usize; dim_size];
754 let mask_slices = mask.split(&ones, dim)?;
755 let val_slices = values.split(&ones, dim)?;
756 let (mut acc_mask, mut acc_vals) = (mask_slices[0].clone(), val_slices[0].clone());
757 for (m, v) in mask_slices[1..].iter().zip(&val_slices[1..]) {
758 acc_vals = v.where_(m, &acc_vals)?;
760 acc_mask = acc_mask.bitwise_or(m)?;
761 }
762 mask = acc_mask;
763 values = acc_vals;
764 }
765
766 for &dim in axes.iter().rev() {
768 mask = mask.try_squeeze(Some(dim))?;
769 values = values.try_squeeze(Some(dim))?;
770 }
771
772 values.where_(&mask, target)
774}