Skip to main content

svod_tensor/
indexing.rs

1//! Indexing operations for Tensors.
2
3use snafu::ResultExt;
4use strum::{Display, EnumString};
5
6use super::*;
7use crate::error::ShapeMismatchSnafu;
8
9/// Reduction mode for scatter operations.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, Display)]
11pub enum ScatterReduction {
12    #[strum(serialize = "sum")]
13    Sum,
14    #[strum(serialize = "prod")]
15    Prod,
16    #[strum(serialize = "amax")]
17    Amax,
18    #[strum(serialize = "amin")]
19    Amin,
20}
21
22impl Tensor {
23    /// Gather values along an axis specified by `dim`, using `index` for element selection.
24    #[track_caller]
25    pub fn gather(&self, dim: isize, index: &Tensor) -> Result<Self> {
26        let self_shape = self.shape()?;
27        let index_shape = index.shape()?;
28        let ndim = self_shape.len();
29        let dim = Self::normalize_axis(dim, ndim)?;
30
31        snafu::ensure!(
32            index_shape.len() == ndim,
33            ShapeMismatchSnafu {
34                context: "gather",
35                expected: format!("{ndim}D"),
36                actual: format!("{}D index", index_shape.len())
37            }
38        );
39
40        // TODO(symbolic-batch): both `to_vec_usize` calls require every dim of
41        // both `self` and `index` to be concrete. The arithmetic that uses
42        // them — the size-comparison loop and the `shrink` bounds — only
43        // needs the dims along which we shrink, not the symbolic prefix
44        // (typically a JIT batch bound to a `BoundVariable`). The symbolic
45        // dim could be passed through as `SInt`, and the comparison could be
46        // restricted to dims that are concrete on both sides. As-is, gather
47        // is unusable on tensors whose shape contains any symbolic dim.
48        let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
49        let index_dims = svod_ir::shape::to_vec_usize(&index_shape).context(UOpSnafu)?;
50
51        snafu::ensure!(
52            self_dims.iter().zip(&index_dims).enumerate().all(|(d, (s, i))| d == dim || s >= i),
53            ShapeMismatchSnafu {
54                context: "gather",
55                expected: "self[d] >= index[d] for d != dim".to_string(),
56                actual: format!("self={self_dims:?}, index={index_dims:?}")
57            }
58        );
59
60        let shrink: Vec<_> =
61            (0..ndim).map(|d| (0, (if d == dim { self_dims[d] } else { index_dims[d] }) as isize)).collect();
62        let x = self.try_shrink(&shrink)?.try_unsqueeze(-1)?.try_transpose(-1, dim as isize)?;
63
64        let arange = Tensor::arange(0, Some(self_dims[dim] as i64), None)?.cast(index.uop().dtype())?;
65        let mask = index.try_unsqueeze(-1)?.try_eq(&arange)?;
66
67        x.where_(&mask, &Self::new(x.uop().const_like(0)))?.sum_with().axes(-1).dtype(self.uop().dtype()).call()
68    }
69
70    /// Select elements along `dim` using a 1D index tensor.
71    ///
72    /// For input shape `[A, B, C]` with `dim=1` and index shape `[K]`,
73    /// returns shape `[A, K, C]`.
74    #[track_caller]
75    pub fn index_select(&self, dim: isize, index: &Tensor) -> Result<Self> {
76        let self_shape = self.shape()?;
77        let ndim = self_shape.len();
78        let dim = Self::normalize_axis(dim, ndim)?;
79        // TODO(symbolic-batch): `self_dims` is consumed only to build
80        // `expand_shape` below (line 90). Forcing every dim through `usize`
81        // makes this unusable when the input has a symbolic dim (e.g. a JIT
82        // batch). The same SInt-aware `try_expand` shape would suffice.
83        let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
84
85        // Reshape 1D index [K] → [1, ..., K, ..., 1] matching input ndim
86        let idx_len = index.shape()?[0].as_const().expect("index_select: index length must be concrete");
87        let mut idx_shape = vec![1isize; ndim];
88        idx_shape[dim] = idx_len as isize;
89        let idx_nd = index.try_reshape(&idx_shape)?;
90
91        // Expand to [self[0], ..., K, ..., self[-1]] (K at dim position)
92        let mut expand_shape: Vec<isize> = self_dims.iter().map(|&d| d as isize).collect();
93        expand_shape[dim] = idx_len as isize;
94        let idx_expanded = idx_nd.try_expand(&expand_shape)?;
95
96        self.gather(dim as isize, &idx_expanded)
97    }
98
99    /// One-hot encoding: self == arange(num_classes) broadcast along dim.
100    /// Returns a boolean tensor with True at the class positions.
101    pub fn one_hot_along_dim(&self, num_classes: usize, dim: isize) -> Result<Tensor> {
102        let ndim = self.ndim()?;
103        let norm_dim = Self::normalize_axis(dim, ndim)?;
104        let offset = ndim - norm_dim - 1;
105        let arange = Tensor::arange(0, Some(num_classes as i64), None)?;
106        let mut ar_shape = vec![1isize; 1 + offset];
107        ar_shape[0] = num_classes as isize;
108        self.try_eq(&arange.try_reshape(&ar_shape)?)
109    }
110
111    /// Normalize negative indices: `indices[i] = indices[i] < 0 ? indices[i] + dim_size : indices[i]`
112    pub fn normalize_negative_indices(&self, dim_size: i64) -> Result<Tensor> {
113        let zero = Tensor::const_(ConstValue::Int(0), self.uop().dtype());
114        let dim_t = Tensor::const_(ConstValue::Int(dim_size), self.uop().dtype());
115        let neg_mask = self.try_lt(&zero)?;
116        self.try_add(&dim_t)?.where_(&neg_mask, self)
117    }
118
119    // =========================================================================
120    // Scatter Operations (Tinygrad tensor.py:2641-2728)
121    // =========================================================================
122
123    /// Internal: prepare src and mask for scatter operations.
124    ///
125    /// Validates shapes, shrinks src to index.shape, then:
126    ///  - src: unsqueeze(-1), expand(self.shape[dim]), transpose(-1, dim)
127    ///  - mask: one_hot_along_dim(self.shape[dim]), transpose(-1, dim)
128    ///
129    /// Both are padded to self.shape on non-dim axes.
130    fn _pre_scatter(&self, dim: isize, index: &Tensor, src: &Tensor) -> Result<(Tensor, Tensor)> {
131        let self_shape = self.shape()?;
132        let index_shape = index.shape()?;
133        let src_shape = src.shape()?;
134        let ndim = self_shape.len();
135        let dim = Self::normalize_axis(dim, ndim)?;
136
137        let self_dims = svod_ir::shape::to_vec_usize(&self_shape).context(UOpSnafu)?;
138        let index_dims = svod_ir::shape::to_vec_usize(&index_shape).context(UOpSnafu)?;
139        let src_dims = svod_ir::shape::to_vec_usize(&src_shape).context(UOpSnafu)?;
140
141        snafu::ensure!(
142            index_shape.len() == ndim && src_shape.len() == ndim,
143            ShapeMismatchSnafu {
144                context: "scatter",
145                expected: format!("{ndim}D"),
146                actual: format!("index={}D, src={}D", index_shape.len(), src_shape.len())
147            }
148        );
149        snafu::ensure!(
150            self_dims
151                .iter()
152                .zip(&index_dims)
153                .zip(&src_dims)
154                .enumerate()
155                .all(|(d, ((s, i), sr))| { (d == dim || s >= i) && sr >= i }),
156            ShapeMismatchSnafu {
157                context: "scatter",
158                expected: "valid scatter shape constraints".to_string(),
159                actual: format!("self={self_dims:?}, index={index_dims:?}, src={src_dims:?}")
160            }
161        );
162
163        // Shrink src to index shape
164        let shrink_ranges: Vec<(isize, isize)> = index_dims.iter().map(|&d| (0, d as isize)).collect();
165        let src = src.try_shrink(&shrink_ranges)?;
166
167        // src: unsqueeze(-1) → expand(... self.shape[dim]) → transpose(-1, dim)
168        let mut expand_shape: Vec<isize> = index_dims.iter().map(|&d| d as isize).collect();
169        expand_shape.push(self_dims[dim] as isize);
170        let src = src.try_unsqueeze(-1)?.try_expand(&expand_shape)?.try_transpose(-1, dim as isize)?;
171
172        // mask: one_hot_along_dim(self.shape[dim]) → transpose(-1, dim)
173        let mask = index.try_unsqueeze(-1)?.one_hot_along_dim(self_dims[dim], -1)?.try_transpose(-1, dim as isize)?;
174
175        // Pad both to self.shape on non-dim axes
176        let src_cur = src.shape()?;
177        let src_cur_dims = svod_ir::shape::to_vec_usize(&src_cur).context(UOpSnafu)?;
178        let padding: Vec<(isize, isize)> =
179            (0..ndim).map(|d| (0, (self_dims[d] as isize - src_cur_dims[d] as isize).max(0))).collect();
180        let needs_pad = padding.iter().any(|&(_, e)| e > 0);
181        let src = if needs_pad { src.try_pad(&padding)? } else { src };
182        let mask = if needs_pad { mask.try_pad(&padding)? } else { mask };
183
184        Ok((src, mask))
185    }
186
187    /// Scatter values along dim using index positions.
188    ///
189    /// For each position in index, places the corresponding src value into self at
190    /// the specified index along dim. When multiple indices map to the same position,
191    /// the last value wins (matching PyTorch/Tinygrad semantics).
192    #[track_caller]
193    pub fn scatter(&self, dim: isize, index: &Tensor, src: &Tensor) -> Result<Tensor> {
194        let (src_p, mask_p) = self._pre_scatter(dim, index, src)?;
195        masked_setitem(self, &src_p, &mask_p, &[-1])
196    }
197
198    /// Scatter with reduction. Applies reduce (sum/prod/amax/amin) at scatter positions.
199    #[track_caller]
200    pub fn scatter_reduce(
201        &self,
202        dim: isize,
203        index: &Tensor,
204        src: &Tensor,
205        reduce: ScatterReduction,
206        include_self: bool,
207    ) -> Result<Tensor> {
208        let (src_p, mask_p) = self._pre_scatter(dim, index, src)?;
209        let dtype = src_p.uop().dtype();
210        let inv_mask = |a: &Tensor, b: &Tensor| -> Result<Tensor> {
211            let no_hit = mask_p.any(-1isize)?.logical_not()?;
212            a.where_(&no_hit, b)
213        };
214        let self_or = |identity_val: ConstValue| -> Result<Tensor> {
215            if include_self { Ok(self.clone()) } else { inv_mask(self, &Tensor::const_(identity_val, dtype.clone())) }
216        };
217
218        match reduce {
219            ScatterReduction::Sum => {
220                let zero = Tensor::const_(ConstValue::Int(0), dtype.clone());
221                let reduced = src_p.where_(&mask_p, &zero)?.sum_with().axes(-1isize).call()?;
222                reduced.try_add(&self_or(ConstValue::Int(0))?)
223            }
224            ScatterReduction::Prod => {
225                let one = Tensor::const_(ConstValue::Int(1), dtype.clone());
226                let reduced = src_p.where_(&mask_p, &one)?.prod_with().axes(-1isize).call()?;
227                reduced.try_mul(&self_or(ConstValue::Int(1))?)
228            }
229            ScatterReduction::Amax => {
230                let min_val =
231                    if dtype.is_float() { ConstValue::Float(f64::NEG_INFINITY) } else { ConstValue::Int(i64::MIN) };
232                let fill = Tensor::const_(min_val, dtype.clone());
233                let reduced = src_p.where_(&mask_p, &fill)?.max(-1isize)?;
234                reduced.maximum(&self_or(min_val)?)
235            }
236            ScatterReduction::Amin => {
237                let max_val =
238                    if dtype.is_float() { ConstValue::Float(f64::INFINITY) } else { ConstValue::Int(i64::MAX) };
239                let fill = Tensor::const_(max_val, dtype.clone());
240                let reduced = src_p.where_(&mask_p, &fill)?.min(-1isize)?;
241                reduced.minimum(&self_or(max_val)?)
242            }
243        }
244    }
245
246    // =========================================================================
247    // Masked Select (Tinygrad tensor.py:1528-1547)
248    // =========================================================================
249
250    /// Select elements where mask is true, returning a flat tensor.
251    ///
252    /// Requires `realize()` internally (data-dependent output size).
253    #[track_caller]
254    pub fn masked_select(&self, mask: &Tensor) -> Result<Tensor> {
255        let x = self.flatten()?;
256        let mask_flat = mask.broadcast_to(&self.shape()?)?.flatten()?;
257        let mask_cumsum = mask_flat.cast(svod_dtype::DType::Int32)?.cumsum(0)?;
258        // Realize to get output size (data-dependent shape)
259        let n = mask_flat.numel()?;
260        let mut count_t = mask_cumsum.try_shrink([((n - 1) as isize, n as isize)])?;
261        count_t.realize()?;
262        let count_t = count_t.as_ndarray::<i32>()?;
263        let count = count_t[[0]] as usize;
264        if count == 0 {
265            return Ok(Tensor::empty_zero(self.uop().dtype()));
266        }
267
268        // Build gather indices: zeros.scatter(0, cumsum, 1).cumsum
269        let zeros = Tensor::full(&[count], ConstValue::Int(0), svod_dtype::DType::Int32)?;
270        let ones = Tensor::full(&[n], ConstValue::Int(1), svod_dtype::DType::Int32)?;
271        let idxs = zeros.scatter_reduce(0, &mask_cumsum, &ones, ScatterReduction::Sum, false)?.cumsum(0)?;
272        x.gather(0, &idxs)
273    }
274
275    /// Select elements along an axis where `condition` is true.
276    ///
277    /// If `axis` is None, the input is flattened first and selection is along axis 0.
278    /// The condition is a 1D boolean/integer tensor; nonzero values select.
279    #[track_caller]
280    pub fn compress(&self, condition: &[bool], axis: Option<isize>) -> Result<Tensor> {
281        let x = if axis.is_none() { self.flatten()? } else { self.clone() };
282        let axis = axis.unwrap_or(0);
283        let indices: Vec<i64> = condition.iter().enumerate().filter(|(_, v)| **v).map(|(i, _)| i as i64).collect();
284        let idx = Tensor::from_slice(&indices);
285        x.index_select(axis, &idx)
286    }
287
288    // =========================================================================
289    // Sort (Bitonic) (Tinygrad tensor.py:2730-2779)
290    // =========================================================================
291
292    /// Bitonic sort along a dimension. Returns (sorted_values, indices).
293    #[track_caller]
294    pub fn sort(&self, dim: isize, descending: bool) -> Result<(Tensor, Tensor)> {
295        let shape = self.shape()?;
296        let ndim = shape.len();
297        let dim = Self::normalize_axis(dim, ndim)?;
298        let orig_len = shape[dim]
299            .as_const()
300            .ok_or_else(|| crate::error::Error::SymbolicShapeUnsupported { operation: "sort".into() })?;
301
302        if orig_len <= 1 {
303            let idx = Tensor::full(
304                &svod_ir::shape::to_vec_usize(&shape).unwrap(),
305                ConstValue::Int(0),
306                svod_dtype::DType::Int32,
307            )?;
308            return Ok((self.clone(), idx));
309        }
310
311        let n_stages = (orig_len as u64 - 1).ilog2() as usize + 1;
312        let padded_len = 1usize << n_stages;
313
314        // Pad to power of 2
315        let sentinel = if descending {
316            if self.uop().dtype().is_float() { f64::NEG_INFINITY } else { i64::MIN as f64 }
317        } else if self.uop().dtype().is_float() {
318            f64::INFINITY
319        } else {
320            i64::MAX as f64
321        };
322        let mut padding = vec![(0isize, 0isize); ndim];
323        padding[dim] = (0, (padded_len - orig_len) as isize);
324        let mut x = self.try_pad_value(&padding, sentinel)?;
325
326        // Unflatten dim into n_stages binary dimensions
327        let unflatten_sizes: Vec<isize> = vec![2; n_stages];
328        x = x.unflatten(dim as isize, &unflatten_sizes)?;
329
330        // Bitonic sort network
331        for stage in 1..=n_stages {
332            if stage != n_stages {
333                // Crossover: flip for green boxes
334                let crossover_dim = (dim + n_stages - stage - 1) as isize;
335                let halves = x.split(&[1, 1], crossover_dim)?;
336                let (blue, green) = (&halves[0], &halves[1]);
337                let flip_dims: Vec<isize> = (1..=(stage + (ndim - dim))).map(|i| -(i as isize)).collect();
338                x = Tensor::cat(&[blue, &green.flip(&flip_dims)?], crossover_dim)?.contiguous();
339            }
340
341            for substage in (0..stage).rev() {
342                let partner_dim = (dim + n_stages - substage - 1) as isize;
343                let parts = x.split(&[1, 1], partner_dim)?;
344                let (x_top, x_bottom) = (&parts[0], &parts[1]);
345                let x_larger = x_top.maximum(x_bottom)?;
346                let x_smaller = x_top.minimum(x_bottom)?;
347                x = if descending {
348                    Tensor::cat(&[&x_larger, &x_smaller], partner_dim)?
349                } else {
350                    Tensor::cat(&[&x_smaller, &x_larger], partner_dim)?
351                }
352                .contiguous();
353            }
354
355            if stage != n_stages {
356                // Undo crossover
357                let crossover_dim = (dim + n_stages - stage - 1) as isize;
358                let halves = x.split(&[1, 1], crossover_dim)?;
359                let (blue, flipped_green) = (&halves[0], &halves[1]);
360                let flip_dims: Vec<isize> = (1..=(stage + (ndim - dim))).map(|i| -(i as isize)).collect();
361                x = Tensor::cat(&[blue, &flipped_green.flip(&flip_dims)?], crossover_dim)?;
362            }
363        }
364
365        // Flatten back and shrink to original size
366        let flatten_end = dim + n_stages - 1;
367        // Flatten dims [dim..dim+n_stages] back to one
368        let cur_shape = x.shape()?;
369        let cur_dims = svod_ir::shape::to_vec_usize(&cur_shape).context(UOpSnafu)?;
370        let mut flat_shape: Vec<isize> = Vec::new();
371        for (i, &d) in cur_dims.iter().enumerate() {
372            if i == dim {
373                flat_shape.push(padded_len as isize);
374            } else if i > dim && i <= flatten_end {
375                continue;
376            } else {
377                flat_shape.push(d as isize);
378            }
379        }
380        x = x.try_reshape(&flat_shape)?;
381
382        // Shrink to original size
383        let x_shape = x.shape()?;
384        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
385        let shrink_ranges: Vec<(isize, isize)> =
386            x_dims.iter().enumerate().map(|(d, &s)| (0, if d == dim { orig_len } else { s } as isize)).collect();
387        x = x.try_shrink(&shrink_ranges)?;
388
389        // Compute indices via count-matching (matches Tinygrad's approach)
390        // Create 2D tril mask first (tril operates on last 2 dims), then reshape
391        // to broadcast shape [1, ..., orig_len, orig_len, 1, ..., 1]
392        // Tinygrad: Tensor.ones(orig_len, orig_len).tril().reshape((None, None) + (1,)*(ndim-dim-1))
393        let tril_2d = Tensor::full(&[orig_len, orig_len], true, svod_dtype::DType::Bool)?.tril(0)?;
394        let mut tril_reshape: Vec<isize> = vec![1; ndim + 1];
395        tril_reshape[dim] = orig_len as isize;
396        tril_reshape[dim + 1] = orig_len as isize;
397        let tril_mask = tril_2d.try_reshape(&tril_reshape)?;
398
399        // Count occurrences of each value up to current position
400        let compute_counts = |t: &Tensor| -> Result<Tensor> {
401            let eq = t.try_unsqueeze(dim as isize)?.try_eq(&t.try_unsqueeze((dim + 1) as isize)?)?;
402            eq.bitwise_and(&tril_mask)?.sum((dim + 1) as isize)
403        };
404
405        let count_orig = compute_counts(self)?;
406        let count_sorted = compute_counts(&x)?;
407
408        // Match: original[unsqueeze(dim+1)] == sorted[unsqueeze(dim)] && counts match
409        let val_match = self.try_unsqueeze((dim + 1) as isize)?.try_eq(&x.try_unsqueeze(dim as isize)?)?;
410        let cnt_match =
411            count_orig.try_unsqueeze((dim + 1) as isize)?.try_eq(&count_sorted.try_unsqueeze(dim as isize)?)?;
412        let cond = val_match.bitwise_and(&cnt_match)?;
413
414        // Build index arange and compute weighted sum
415        let mut idx_shape = vec![1isize; ndim + 1];
416        idx_shape[dim] = orig_len as isize;
417        let idx = (cond
418            .cast(svod_dtype::DType::Int32)?
419            .try_mul(&Tensor::arange(0, Some(orig_len as i64), None)?.try_reshape(&idx_shape)?)?)
420        .sum(dim as isize)?;
421
422        Ok((x, idx))
423    }
424
425    // =========================================================================
426    // TopK (Tinygrad tensor.py:2792-2812)
427    // =========================================================================
428
429    /// Top-k elements along a dimension. Returns (values, indices).
430    #[track_caller]
431    pub fn topk(&self, k: usize, dim: isize, largest: bool) -> Result<(Tensor, Tensor)> {
432        let shape = self.shape()?;
433        let ndim = shape.len();
434        let norm_dim = Self::normalize_axis(dim, ndim)?;
435        let (x, idx) = self.sort(dim, largest)?;
436        // Shrink to first k along dim
437        let x_shape = x.shape()?;
438        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
439        let shrink: Vec<(isize, isize)> =
440            x_dims.iter().enumerate().map(|(d, &s)| (0, if d == norm_dim { k } else { s } as isize)).collect();
441        Ok((x.try_shrink(&shrink)?, idx.try_shrink(&shrink)?))
442    }
443
444    // =========================================================================
445    // NonZero (Tinygrad tensor.py:1549-1573)
446    // =========================================================================
447
448    /// Indices of non-zero elements. Returns [num_nonzero, ndim] tensor.
449    #[track_caller]
450    pub fn nonzero(&self) -> Result<Tensor> {
451        let shape = self.shape()?;
452        let ndim = shape.len();
453        let dims = svod_ir::shape::to_vec_usize(&shape).context(UOpSnafu)?;
454        let numel: usize = dims.iter().product();
455
456        let mask = self.try_ne(&Tensor::const_(ConstValue::Int(0), self.uop().dtype()))?.flatten()?;
457
458        // Build coordinate tensor: for each dim, arange → reshape to broadcast → flatten
459        let coords: Vec<Tensor> = (0..ndim)
460            .map(|i| {
461                let ar = Tensor::arange(0, Some(dims[i] as i64), None)?;
462                let mut rshape = vec![1isize; ndim];
463                rshape[i] = dims[i] as isize;
464                let expand_shape: Vec<isize> = dims.iter().map(|&d| d as isize).collect();
465                ar.try_reshape(&rshape)?.try_expand(&expand_shape)?.flatten()
466            })
467            .collect::<Result<Vec<_>>>()?;
468
469        let coords_refs: Vec<&Tensor> = coords.iter().collect();
470        let indices = Tensor::stack(&coords_refs, -1)?; // [numel, ndim]
471
472        // Select nonzero coordinates
473        let expanded_mask = mask.try_unsqueeze(-1)?.try_expand([numel as isize, ndim as isize])?;
474        let selected = indices.masked_select(&expanded_mask)?;
475        selected.try_reshape([-1, ndim as isize])
476    }
477
478    /// Reverse the first `sequence_lens[i]` elements along `time_axis` for each
479    /// batch element `i` along `batch_axis`, leaving the rest unchanged.
480    #[track_caller]
481    pub fn reverse_sequence(&self, sequence_lens: &Tensor, time_axis: usize, batch_axis: usize) -> Result<Self> {
482        let dims = svod_ir::shape::to_vec_usize(&self.shape()?).context(UOpSnafu)?;
483        let ndim = dims.len();
484        let time_len = dims[time_axis];
485
486        // Transpose so time_axis→0, batch_axis→1
487        let mut perm: Vec<usize> = (0..ndim).collect();
488        perm.swap(0, time_axis);
489        let batch_pos = if batch_axis == 0 {
490            time_axis
491        } else if batch_axis == time_axis {
492            0
493        } else {
494            batch_axis
495        };
496        perm.swap(1, batch_pos);
497        let perm_i: Vec<isize> = perm.iter().map(|&p| p as isize).collect();
498        let work = self.try_permute(&perm_i)?;
499        let work_dims = svod_ir::shape::to_vec_usize(&work.shape()?).context(UOpSnafu)?;
500
501        // t = arange(T) as [T, 1], seq_lens as [1, B]
502        let idx_dt = sequence_lens.uop().dtype();
503        let t = Tensor::arange(0, Some(time_len as i64), None)?.cast(idx_dt.clone())?.try_unsqueeze(1)?;
504        let sl = sequence_lens.try_unsqueeze(0)?;
505
506        // reversed_t = seq_lens - 1 - t; idx = where(t < seq_lens, reversed_t, t)
507        let one = Tensor::const_(ConstValue::Int(1), idx_dt);
508        let reversed_t = sl.try_sub(&one)?.try_sub(&t)?;
509        let mask = t.try_lt(&sl)?;
510        let idx = reversed_t.where_(&mask, &t)?;
511
512        // Expand indices to match work shape [T, B, ...] and gather along axis 0
513        let expand_shape: Vec<isize> = work_dims.iter().map(|&d| d as isize).collect();
514        let idx = idx.try_reshape(&expand_shape[..2])?.try_expand(&expand_shape)?;
515        let result = work.gather(0, &idx)?;
516
517        // Inverse permutation to restore original axis order
518        let mut inv_perm = vec![0usize; ndim];
519        for (i, &p) in perm.iter().enumerate() {
520            inv_perm[p] = i;
521        }
522        let inv_perm_i: Vec<isize> = inv_perm.iter().map(|&p| p as isize).collect();
523        result.try_permute(&inv_perm_i)
524    }
525
526    // =========================================================================
527    // N-dimensional Gather/Scatter (from ONNX GatherND/ScatterND/TensorScatter)
528    // =========================================================================
529
530    /// Gather values using N-dimensional indices.
531    pub fn gather_nd(&self, indices: &Tensor, batch_dims: usize) -> Result<Tensor> {
532        let x_shape = self.shape()?;
533        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
534        let idx_shape = indices.shape()?;
535        let idx_dims = svod_ir::shape::to_vec_usize(&idx_shape).context(UOpSnafu)?;
536        let last_idx_dim = *idx_dims.last().unwrap();
537
538        if batch_dims == 0 {
539            let strides: Vec<i64> =
540                (0..last_idx_dim).map(|k| x_dims[k + 1..last_idx_dim].iter().product::<usize>() as i64).collect();
541            let inner: usize = x_dims[last_idx_dim..].iter().product();
542            let outer = x_dims[..last_idx_dim].iter().product::<usize>();
543
544            let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
545            for (k, stride) in strides.iter().enumerate() {
546                let mut ranges: Vec<(isize, isize)> = idx_dims.iter().map(|&s| (0, s as isize)).collect();
547                ranges[idx_dims.len() - 1] = (k as isize, k as isize + 1);
548                let idx_k = indices.try_shrink(&ranges)?.try_squeeze(Some(-1))?;
549                let stride_t = Tensor::const_(ConstValue::Int(*stride), DType::Int64);
550                flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
551            }
552
553            let x_flat = self.try_reshape([outer as isize, inner as isize])?;
554            let gather_outer: Vec<isize> = idx_dims[..idx_dims.len() - 1].iter().map(|&d| d as isize).collect();
555            let num_gathers: usize = gather_outer.iter().map(|&d| d as usize).product();
556
557            let flat_idx_2d = flat_idx
558                .try_reshape([num_gathers as isize, 1])?
559                .try_expand([num_gathers as isize, inner as isize])?
560                .cast(DType::Int32)?;
561            let result = x_flat.gather(0, &flat_idx_2d)?;
562
563            let mut out_shape = gather_outer;
564            for &d in &x_dims[last_idx_dim..] {
565                out_shape.push(d as isize);
566            }
567            result.try_reshape(&out_shape)
568        } else {
569            let batch_size: usize = x_dims[..batch_dims].iter().product();
570            let inner_x: Vec<usize> = x_dims[batch_dims..].to_vec();
571            let inner_idx: Vec<usize> = idx_dims[batch_dims..].to_vec();
572
573            let x_flat = self.try_reshape(
574                std::iter::once(batch_size as isize).chain(inner_x.iter().map(|&d| d as isize)).collect::<Vec<_>>(),
575            )?;
576            let idx_flat = indices.try_reshape(
577                std::iter::once(batch_size as isize).chain(inner_idx.iter().map(|&d| d as isize)).collect::<Vec<_>>(),
578            )?;
579
580            let last_inner = *inner_idx.last().unwrap();
581            let strides: Vec<i64> =
582                (0..last_inner).map(|k| inner_x[k + 1..last_inner].iter().product::<usize>() as i64).collect();
583
584            let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
585            let idx_flat_shape = idx_flat.shape()?;
586            let idx_flat_dims = svod_ir::shape::to_vec_usize(&idx_flat_shape).context(UOpSnafu)?;
587            for (k, stride) in strides.iter().enumerate() {
588                let mut ranges: Vec<(isize, isize)> = idx_flat_dims.iter().map(|&s| (0, s as isize)).collect();
589                ranges[idx_flat_dims.len() - 1] = (k as isize, k as isize + 1);
590                let idx_k = idx_flat.try_shrink(&ranges)?.try_squeeze(Some(-1))?;
591                let stride_t = Tensor::const_(ConstValue::Int(*stride), DType::Int64);
592                flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
593            }
594
595            let batch_stride = inner_x[..last_inner].iter().product::<usize>();
596            let batch_offset_arr = Tensor::arange(0, Some(batch_size as i64), None)?
597                .try_mul(&Tensor::from_slice([batch_stride as i64]))?;
598            let gather_inner = idx_flat_dims[1..idx_flat_dims.len() - 1].iter().product::<usize>();
599            flat_idx = flat_idx.try_reshape([batch_size as isize, gather_inner as isize])?;
600            let batch_offset = batch_offset_arr
601                .try_reshape([batch_size as isize, 1])?
602                .try_expand([batch_size as isize, gather_inner as isize])?;
603            flat_idx = flat_idx.try_add(&batch_offset)?;
604
605            let remaining: usize = inner_x[last_inner..].iter().product();
606            let x_2d = x_flat.try_reshape([(batch_size * batch_stride) as isize, remaining as isize])?;
607            let fi = flat_idx
608                .try_reshape([(batch_size * gather_inner) as isize, 1])?
609                .try_expand([(batch_size * gather_inner) as isize, remaining as isize])?
610                .cast(DType::Int32)?;
611            let result = x_2d.gather(0, &fi)?;
612
613            let mut out_shape: Vec<isize> = x_dims[..batch_dims].iter().map(|&d| d as isize).collect();
614            out_shape.extend(inner_idx[..inner_idx.len() - 1].iter().map(|&d| d as isize));
615            out_shape.extend(inner_x[last_inner..].iter().map(|&d| d as isize));
616            result.try_reshape(&out_shape)
617        }
618    }
619
620    /// Scatter updates into a tensor using N-dimensional indices.
621    pub fn scatter_nd(&self, indices: &Tensor, updates: &Tensor, reduction: &str) -> Result<Tensor> {
622        let x_shape = self.shape()?;
623        let x_dims = svod_ir::shape::to_vec_usize(&x_shape).context(UOpSnafu)?;
624        let idx_shape = indices.shape()?;
625        let last_idx_dim = idx_shape[idx_shape.len() - 1].as_const().unwrap();
626        let strides: Vec<i64> =
627            (0..last_idx_dim).map(|k| x_dims[k + 1..last_idx_dim].iter().product::<usize>() as i64).collect();
628        let x_numel: usize = x_dims.iter().product();
629        let inner: usize = x_dims[last_idx_dim..].iter().product();
630        let outer = x_numel / inner;
631        let x_flat = self.try_reshape([outer as isize, inner as isize])?;
632        let idx_splits: Vec<Tensor> = (0..last_idx_dim)
633            .map(|k| {
634                let mut ranges: Vec<(isize, isize)> =
635                    idx_shape.iter().map(|s| (0, s.as_const().unwrap() as isize)).collect();
636                ranges[idx_shape.len() - 1] = (k as isize, k as isize + 1);
637                let slice = indices.try_shrink(&ranges)?;
638                slice.try_squeeze(Some(-1))
639            })
640            .collect::<Result<_>>()?;
641        let mut flat_idx = Tensor::const_(ConstValue::Int(0), DType::Int64);
642        for (k, idx_k) in idx_splits.iter().enumerate() {
643            let stride_t = Tensor::const_(ConstValue::Int(strides[k]), DType::Int64);
644            flat_idx = flat_idx.try_add(&idx_k.cast(DType::Int64)?.try_mul(&stride_t)?)?;
645        }
646        let upd_shape = updates.shape()?;
647        let upd_outer: usize = upd_shape[..upd_shape.len() - (x_dims.len() - last_idx_dim)]
648            .iter()
649            .map(|s| s.as_const().unwrap())
650            .product();
651        let upd_flat = updates.try_reshape([upd_outer as isize, inner as isize])?;
652        let flat_idx =
653            flat_idx.try_reshape([upd_outer as isize, 1])?.try_expand([upd_outer as isize, inner as isize])?;
654        let flat_idx_i32 = flat_idx.cast(DType::Int32)?;
655        let mut result = match reduction {
656            "none" => x_flat.scatter(0, &flat_idx_i32, &upd_flat)?,
657            "add" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Sum, true)?,
658            "mul" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Prod, true)?,
659            "max" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Amax, true)?,
660            "min" => x_flat.scatter_reduce(0, &flat_idx_i32, &upd_flat, ScatterReduction::Amin, true)?,
661            _ => {
662                return Err(crate::error::Error::IrConstruction {
663                    details: format!("ScatterND: unsupported reduction '{reduction}'"),
664                });
665            }
666        };
667        let out_shape: Vec<isize> = x_dims.iter().map(|&d| d as isize).collect();
668        result = result.try_reshape(&out_shape)?;
669        Ok(result)
670    }
671
672    /// Batch-aware tensor scatter with write index offsets.
673    pub fn tensor_scatter(
674        &self,
675        update: &Tensor,
676        write_indices: Option<&Tensor>,
677        mode: &str,
678        axis: isize,
679    ) -> Result<Tensor> {
680        let data_shape = self.shape()?;
681        let ndim = data_shape.len();
682        let axis = Self::normalize_axis(axis, ndim)?;
683        let data_dims = svod_ir::shape::to_vec_usize(&data_shape).context(UOpSnafu)?;
684        let update_dims = svod_ir::shape::to_vec_usize(&update.shape()?).context(UOpSnafu)?;
685
686        let batch_size = data_dims[0];
687        let max_seq = data_dims[axis];
688        let seq_len = update_dims[axis];
689
690        let b_total: usize = data_dims[..axis].iter().product();
691        let features: usize = data_dims[axis + 1..].iter().product();
692
693        let write_idx = if let Some(wi) = write_indices {
694            wi.cast(DType::Int32)?
695        } else {
696            Tensor::full(&[batch_size], ConstValue::Int(0), DType::Int32)?
697        };
698
699        let wi_flat = if axis > 1 {
700            let mut wi_reshape: Vec<isize> = vec![batch_size as isize];
701            wi_reshape.extend(std::iter::repeat_n(1, axis - 1));
702            let wi_expand: Vec<isize> = data_dims[..axis].iter().map(|&d| d as isize).collect();
703            write_idx.try_reshape(&wi_reshape)?.try_expand(&wi_expand)?.try_reshape([b_total as isize])?
704        } else {
705            write_idx
706        };
707
708        let data_flat = self.try_reshape([(b_total * max_seq) as isize, features as isize])?;
709        let updates_flat = update.try_reshape([(b_total * seq_len) as isize, features as isize])?;
710
711        let batch_offset = Tensor::arange(0, Some(b_total as i64), None)?
712            .cast(DType::Int32)?
713            .try_mul(&Tensor::const_(ConstValue::Int(max_seq as i64), DType::Int32))?
714            .try_reshape([b_total as isize, 1])?;
715
716        let wi_2d = wi_flat.try_reshape([b_total as isize, 1])?;
717        let seq_arange =
718            Tensor::arange(0, Some(seq_len as i64), None)?.cast(DType::Int32)?.try_reshape([1, seq_len as isize])?;
719        let mut row_idx = wi_2d.try_add(&seq_arange)?;
720
721        if mode == "circular" {
722            let max_seq_t = Tensor::const_(ConstValue::Int(max_seq as i64), DType::Int32);
723            row_idx = row_idx.try_mod(&max_seq_t)?;
724        }
725
726        let flat_idx = batch_offset
727            .try_add(&row_idx)?
728            .try_reshape([(b_total * seq_len) as isize, 1])?
729            .try_expand([(b_total * seq_len) as isize, features as isize])?;
730
731        let result = data_flat.scatter(0, &flat_idx, &updates_flat)?;
732
733        let out_shape: Vec<isize> = data_dims.iter().map(|&d| d as isize).collect();
734        result.try_reshape(&out_shape)
735    }
736}
737
738/// Reduce repeated indices so the last value wins, then apply mask.
739///
740/// Tinygrad's `_masked_setitem`: for each axis, split mask/values into slices,
741/// fold with OR on mask and last-writer-wins on values, squeeze, then
742/// `mask.where(values, target)`.
743fn masked_setitem(target: &Tensor, values: &Tensor, mask: &Tensor, axes: &[isize]) -> Result<Tensor> {
744    let mut mask = mask.clone();
745    let mut values = values.clone();
746
747    // Phase 1: reduce repeated indices — last value wins
748    for &dim in axes.iter().rev() {
749        let shape = mask.shape()?;
750        let ndim = shape.len();
751        let norm_dim = Tensor::normalize_axis(dim, ndim)?;
752        let dim_size = shape[norm_dim].as_const().unwrap();
753        let ones = vec![1usize; dim_size];
754        let mask_slices = mask.split(&ones, dim)?;
755        let val_slices = values.split(&ones, dim)?;
756        let (mut acc_mask, mut acc_vals) = (mask_slices[0].clone(), val_slices[0].clone());
757        for (m, v) in mask_slices[1..].iter().zip(&val_slices[1..]) {
758            // last-writer-wins: where m is true take v, otherwise keep acc
759            acc_vals = v.where_(m, &acc_vals)?;
760            acc_mask = acc_mask.bitwise_or(m)?;
761        }
762        mask = acc_mask;
763        values = acc_vals;
764    }
765
766    // Phase 2: squeeze reduced axes
767    for &dim in axes.iter().rev() {
768        mask = mask.try_squeeze(Some(dim))?;
769        values = values.try_squeeze(Some(dim))?;
770    }
771
772    // Phase 3: select from values where mask is true, else target
773    values.where_(&mask, target)
774}