Skip to main content

svod_tensor/
shape_ops.rs

//! Shape manipulation operations for Tensors.
//!
//! This module provides operations that change tensor shapes without copying data:
//! - Reshape: Change shape while preserving total elements
//! - Permute: Reorder dimensions
//! - Transpose: Swap two dimensions (convenience wrapper for permute)
//! - Expand: Broadcast dimensions from size 1
//! - Squeeze: Remove dimensions of size 1
//! - Unsqueeze: Add dimensions of size 1
//!
//! It also provides composite operations built on these and related primitives:
//! flip, shrink, pad, split, repeat, flatten/unflatten, cat/stack, meshgrid,
//! shape_tensor, center_crop_pad, and triu/tril masking.
10
11use bon::bon;
12use snafu::ResultExt;
13use strum::{Display, EnumString};
14use svod_ir::IntoShrinkRange;
15
16use super::*;
17
18/// Indexing convention for meshgrid.
19#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
20pub enum MeshgridIndexing {
21    #[default]
22    #[strum(serialize = "ij")]
23    Ij,
24    #[strum(serialize = "xy")]
25    Xy,
26}
27
28impl Tensor {
29    /// Reshape tensor to a new shape.
30    ///
31    /// The total number of elements must remain the same.
32    /// Supports negative indices: -1 means "infer this dimension".
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// # use svod_tensor::Tensor;
38    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
39    /// let reshaped = t.try_reshape(&[2, 3]).unwrap();  // [6] -> [2, 3]
40    /// let inferred = t.try_reshape(&[-1, 2]).unwrap(); // [6] -> [3, 2]
41    /// ```
42    ///
43    /// # Errors
44    ///
45    /// Returns error if:
46    /// - Shape contains negative values other than -1
47    /// - Multiple -1 dimensions specified
48    /// - Total elements don't match
49    #[track_caller]
50    pub fn try_reshape(&self, new_shape: impl IntoIterator<Item = impl Into<SInt>>) -> Result<Tensor> {
51        let dims: Vec<SInt> = new_shape.into_iter().map(Into::into).collect();
52
53        // Handle Infer (-1) if present
54        let infer_count = dims.iter().filter(|d| d.is_infer()).count();
55        snafu::ensure!(infer_count <= 1, MultipleInferDimensionsSnafu);
56
57        let shape: Shape = if infer_count == 1 {
58            let current_shape = self.shape()?;
59            let total_elements =
60                current_shape.iter().try_fold(1usize, |acc, dim| dim.as_const().map(|v| acc * v)).ok_or_else(|| {
61                    Error::SymbolicShapeUnsupported { operation: "reshape with -1 inference".to_string() }
62                })?;
63            let known_product: usize = dims
64                .iter()
65                .filter(|d| !d.is_infer())
66                .map(|d| d.as_const().expect("non-infer dims must be concrete for -1 inference"))
67                .product();
68            snafu::ensure!(
69                known_product > 0 && total_elements % known_product == 0,
70                ReshapeSizeMismatchSnafu { operation: "reshape with inference".to_string() }
71            );
72            let inferred = total_elements / known_product;
73            dims.iter().map(|d| if d.is_infer() { SInt::Const(inferred) } else { d.clone() }).collect()
74        } else {
75            dims.into()
76        };
77
78        self.uop().try_reshape(&shape).map(Self::new).context(UOpSnafu)
79    }
80
81    /// Expand tensor to a new shape with mixed concrete/symbolic dimensions.
82    pub fn try_expand(&self, new_shape: impl IntoIterator<Item = impl Into<SInt>>) -> Result<Tensor> {
83        let requested: Vec<SInt> = new_shape.into_iter().map(Into::into).collect();
84        // Resolve Infer (-1) to current dimension (expand's "keep" semantics)
85        let current_shape = self.shape()?;
86        let shape: Shape = requested
87            .into_iter()
88            .enumerate()
89            .map(|(i, s)| if s.is_infer() { current_shape[i].clone() } else { s })
90            .collect();
91        self.uop().try_expand(&shape).map(Self::new).context(UOpSnafu)
92    }
93
94    /// Permute (reorder) tensor dimensions.
95    ///
96    /// The axes parameter specifies the new order of dimensions.
97    /// Each axis index 0..ndim must appear exactly once.
98    ///
99    /// # Examples
100    ///
101    /// ```
102    /// # use svod_tensor::Tensor;
103    /// // Tensor with shape [2, 3, 4]
104    /// // t.try_permute(&[2, 0, 1]) -> shape [4, 2, 3]
105    /// // t.try_permute(&[1, 0, 2]) -> shape [3, 2, 4]
106    /// ```
107    ///
108    /// # Errors
109    ///
110    /// Returns error if:
111    /// - Axes is not a valid permutation
112    /// - Axis indices out of range
113    #[track_caller]
114    pub fn try_permute(&self, axes: &[isize]) -> Result<Tensor> {
115        let shape = self.shape()?;
116        let ndim = shape.len();
117
118        // Normalize negative indices and validate
119        let normalized_axes = self.normalize_axes(axes, ndim)?;
120
121        self.uop().try_permute(normalized_axes).map(Self::new).context(UOpSnafu)
122    }
123
124    /// Transpose two dimensions.
125    ///
126    /// Convenience method for swapping two dimensions.
127    /// Equivalent to permute with the two dimensions swapped.
128    ///
129    /// # Examples
130    ///
131    /// ```
132    /// # use svod_tensor::Tensor;
133    /// // Tensor with shape [2, 3, 4]
134    /// // t.try_transpose(0, 1) -> shape [3, 2, 4]
135    /// // t.try_transpose(-1, 0) -> shape [4, 3, 2]  (negative indices supported)
136    /// ```
137    ///
138    /// # Errors
139    ///
140    /// Returns error if axis indices are out of range.
141    #[track_caller]
142    pub fn try_transpose(&self, dim0: isize, dim1: isize) -> Result<Tensor> {
143        let shape = self.shape()?;
144        let ndim = shape.len();
145
146        // Normalize negative indices
147        let d0 = Self::normalize_axis(dim0, ndim)?;
148        let d1 = Self::normalize_axis(dim1, ndim)?;
149
150        // Build permutation with swapped dimensions
151        let mut axes: Vec<usize> = (0..ndim).collect();
152        axes.swap(d0, d1);
153
154        self.uop().try_permute(axes).map(Self::new).context(UOpSnafu)
155    }
156
157    /// Expand (broadcast) dimensions.
158    ///
159    /// Dimensions of size 1 can be expanded to larger sizes.
160    /// Use -1 to keep the current dimension size.
161    ///
162    /// # Examples
163    ///
164    /// ```
165    /// # use svod_tensor::Tensor;
166    /// // Tensor with shape [1, 3, 1]
167    /// // t.try_expand(&[4, -1, 5]) -> shape [4, 3, 5]
168    /// ```
169    ///
170    /// Squeeze dimensions of size 1.
171    ///
172    /// If dim is None, removes all dimensions of size 1.
173    /// If dim is Some(axis), removes only that dimension if it's size 1.
174    ///
175    /// # Examples
176    ///
177    /// ```
178    /// # use svod_tensor::Tensor;
179    /// // Tensor with shape [1, 3, 1, 4]
180    /// // t.try_squeeze(None) -> shape [3, 4]
181    /// // t.try_squeeze(Some(0)) -> shape [3, 1, 4]
182    /// // t.try_squeeze(Some(2)) -> shape [1, 3, 4]
183    /// ```
184    ///
185    /// # Errors
186    ///
187    /// Returns error if:
188    /// - Specified dimension is not size 1
189    /// - Axis index out of range
190    #[track_caller]
191    pub fn try_squeeze(&self, dim: Option<isize>) -> Result<Tensor> {
192        let shape = self.shape()?;
193
194        let new_shape = match dim {
195            None => {
196                // Remove all dimensions of size 1
197                shape
198                    .iter()
199                    .filter_map(|s| s.as_const().and_then(|v| if v != 1 { Some(SInt::Const(v)) } else { None }))
200                    .collect()
201            }
202            Some(axis) => {
203                let ndim = shape.len();
204                let normalized_axis = Self::normalize_axis(axis, ndim)?;
205
206                // Check if dimension is size 1
207                let dim_size = shape[normalized_axis]
208                    .as_const()
209                    .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "squeeze".to_string() })?;
210
211                snafu::ensure!(dim_size == 1, SqueezeDimensionNotOneSnafu { dim: normalized_axis, size: dim_size });
212
213                // Remove the specified dimension
214                shape
215                    .iter()
216                    .enumerate()
217                    .filter_map(|(i, s)| if i != normalized_axis { Some(s.clone()) } else { None })
218                    .collect()
219            }
220        };
221
222        self.uop().try_reshape(&new_shape).map(Self::new).context(UOpSnafu)
223    }
224
225    /// Add a dimension of size 1.
226    ///
227    /// Inserts a new dimension at the specified position.
228    /// Supports negative indices: -1 means after the last dimension.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// # use svod_tensor::Tensor;
234    /// // Tensor with shape [3, 4]
235    /// // t.try_unsqueeze(0) -> shape [1, 3, 4]
236    /// // t.try_unsqueeze(1) -> shape [3, 1, 4]
237    /// // t.try_unsqueeze(-1) -> shape [3, 4, 1]
238    /// ```
239    ///
240    /// # Errors
241    ///
242    /// Returns error if axis index is out of range.
243    #[track_caller]
244    pub fn try_unsqueeze(&self, dim: isize) -> Result<Tensor> {
245        let shape = self.shape()?;
246        let ndim = shape.len();
247
248        // For unsqueeze, valid range is [0, ndim] (can insert at end)
249        // Normalize negative indices: -1 means ndim (after last), -2 means ndim-1, etc.
250        let normalized_dim = if dim < 0 {
251            let positive = (ndim as isize + 1 + dim) as usize;
252            snafu::ensure!(dim >= -(ndim as isize + 1), AxisOutOfRangeSnafu { axis: dim, ndim });
253            positive
254        } else {
255            let pos = dim as usize;
256            snafu::ensure!(pos <= ndim, AxisOutOfRangeSnafu { axis: dim, ndim });
257            pos
258        };
259
260        // Insert dimension of size 1
261        let mut new_shape = shape.clone();
262        new_shape.insert(normalized_dim, SInt::Const(1));
263
264        self.uop().try_reshape(&new_shape).map(Self::new).context(UOpSnafu)
265    }
266
267    /// Reverse elements along specified axes.
268    ///
269    /// Each axis in the list is flipped (reversed). Supports negative indexing.
270    ///
271    /// # Examples
272    /// ```ignore
273    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]).try_reshape(&[2, 2])?;
274    /// let flipped = t.flip(&[0])?;  // Flip along axis 0
275    /// ```
276    #[track_caller]
277    pub fn flip(&self, axes: &[isize]) -> Result<Tensor> {
278        let shape = self.shape()?;
279        let ndim = shape.len();
280        let flip_spec: Vec<bool> =
281            (0..ndim).map(|d| axes.iter().any(|&a| Self::normalize_axis(a, ndim).is_ok_and(|na| na == d))).collect();
282        self.uop().try_flip(flip_spec).map(Self::new).context(UOpSnafu)
283    }
284
285    /// Split tensor into chunks along a dimension.
286    ///
287    /// Returns a vector of tensors, each with the specified size along the split dimension.
288    ///
289    /// # Examples
290    /// ```ignore
291    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0]);
292    /// let parts = t.split(&[2, 3], 0)?;  // [2] and [3]
293    /// ```
294    #[track_caller]
295    pub fn split(&self, sizes: &[usize], dim: isize) -> Result<Vec<Tensor>> {
296        let shape = self.shape()?;
297        let ndim = shape.len();
298        let dim = Self::normalize_axis(dim, ndim)?;
299        let mut results = Vec::with_capacity(sizes.len());
300        let mut offset = 0usize;
301        for &size in sizes {
302            let ranges: Vec<Option<(isize, isize)>> = (0..ndim)
303                .map(|d| {
304                    if d == dim {
305                        Some((offset as isize, (offset + size) as isize))
306                    } else {
307                        None // keep entire dim (supports symbolic)
308                    }
309                })
310                .collect();
311            results.push(self.try_shrink(ranges)?);
312            offset += size;
313        }
314        Ok(results)
315    }
316
317    /// Repeat tensor along each dimension.
318    ///
319    /// `repeats[i]` is the number of times to repeat along dimension `i`.
320    /// Accepts `&[SInt]` — supports both concrete and symbolic repeat counts.
321    ///
322    /// # Examples
323    /// ```ignore
324    /// use svod_ir::SInt;
325    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]).try_reshape(&[1, 3])?;
326    /// let tiled = t.repeat(&[SInt::from(3), SInt::from(2)])?;  // Shape [3, 6]
327    /// ```
328    #[track_caller]
329    pub fn repeat(&self, repeats: &[SInt]) -> Result<Tensor> {
330        let shape = self.shape()?;
331        let ndim = shape.len();
332        snafu::ensure!(
333            repeats.len() == ndim,
334            ShapeMismatchSnafu {
335                context: "repeat",
336                expected: format!("{} dimensions", ndim),
337                actual: format!("{} repeats", repeats.len())
338            }
339        );
340        let mut result = self.clone();
341        for (dim, rep) in repeats.iter().enumerate() {
342            if rep.as_const() == Some(1) {
343                continue;
344            }
345            let current_shape = result.shape()?;
346            let dim_size = &current_shape[dim];
347            // Unsqueeze at dim, expand rep times, then reshape to merge.
348            result = result.try_unsqueeze(dim as isize)?;
349            let mut expand_shape: Vec<SInt> = current_shape.iter().cloned().collect();
350            expand_shape.insert(dim, rep.clone());
351            result = result.try_expand(&expand_shape)?;
352            expand_shape[dim] = rep * dim_size;
353            expand_shape.remove(dim + 1);
354            result = result.try_reshape(expand_shape)?;
355        }
356        Ok(result)
357    }
358
359    /// Flatten tensor to 1D.
360    ///
361    /// Reshapes tensor to have a single dimension containing all elements.
362    /// Equivalent to `try_reshape(&[-1])`.
363    ///
364    /// # Examples
365    /// ```ignore
366    /// let t = Tensor::from_slice(&[[1, 2], [3, 4]]);  // Shape [2, 2]
367    /// let flattened = t.flatten()?;  // Shape [4]
368    /// ```
369    #[track_caller]
370    pub fn flatten(&self) -> Result<Tensor> {
371        self.try_reshape([-1])
372    }
373
374    /// Pad tensor with zeros (or other padding value).
375    ///
376    /// Each tuple in `padding` specifies (begin, end) padding for a dimension.
377    /// Use 0 for no padding on that side.
378    ///
379    /// # Examples
380    ///
381    /// ```
382    /// # use svod_tensor::Tensor;
383    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);  // Shape [3]
384    /// let padded = t.try_pad(&[(1, 2)]).unwrap();  // Shape [6]: [0, 1, 2, 3, 0, 0]
385    /// ```
386    ///
387    /// # Errors
388    ///
389    /// Returns error if:
390    /// - Padding values are symbolic (not concrete)
391    /// - Number of padding pairs doesn't match dimensions
392    #[track_caller]
393    pub fn try_pad(&self, padding: &[(isize, isize)]) -> Result<Tensor> {
394        let shape = self.shape()?;
395
396        // Empty padding (scalar) → identity
397        if padding.is_empty() {
398            return Ok(self.clone());
399        }
400
401        // Convert to SInt and validate
402        snafu::ensure!(
403            padding.len() == shape.len(),
404            ShapeMismatchSnafu {
405                context: "pad",
406                expected: format!("{} dimensions", shape.len()),
407                actual: format!("{} padding pairs", padding.len())
408            }
409        );
410
411        // Phase 1: shrink for negative padding (negative padding = cropping)
412        let needs_shrink = padding.iter().any(|(b, e)| *b < 0 || *e < 0);
413        let base = if needs_shrink {
414            let shrink_ranges: Vec<(isize, isize)> = padding
415                .iter()
416                .zip(shape.iter())
417                .map(|((b, e), s)| {
418                    let dim = s.as_const().expect("pad with negative values requires concrete shape") as isize;
419                    let begin = (-*b).max(0);
420                    let end = (dim + *e).min(dim);
421                    (begin, end)
422                })
423                .collect();
424            self.try_shrink(&shrink_ranges)?
425        } else {
426            self.clone()
427        };
428
429        // Phase 2: pad with positive-only values
430        let pos_padding: Vec<(isize, isize)> = padding.iter().map(|(b, e)| ((*b).max(0), (*e).max(0))).collect();
431        if pos_padding.iter().all(|(b, e)| *b == 0 && *e == 0) {
432            return Ok(base);
433        }
434
435        let padding_sint: Vec<(SInt, SInt)> =
436            pos_padding.iter().map(|(begin, end)| (SInt::Const(*begin as usize), SInt::Const(*end as usize))).collect();
437
438        base.uop().try_pad(&padding_sint).map(Self::new).context(UOpSnafu)
439    }
440
441    /// Concatenate tensors along an axis.
442    ///
443    /// All tensors must have the same shape except in the concatenating dimension.
444    ///
445    /// # Examples
446    ///
447    /// ```
448    /// # use svod_tensor::Tensor;
449    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]).try_reshape(&[3]).unwrap();
450    /// let b = Tensor::from_slice(&[4.0f32, 5.0]).try_reshape(&[2]).unwrap();
451    /// let c = Tensor::cat(&[&a, &b], 0).unwrap();  // Shape [5]: [1, 2, 3, 4, 5]
452    /// ```
453    ///
454    /// # Errors
455    ///
456    /// Returns error if:
457    /// - Tensors have different number of dimensions
458    /// - Non-concat dimensions don't match
459    #[track_caller]
460    pub fn cat(tensors: &[&Tensor], dim: isize) -> Result<Tensor> {
461        if tensors.is_empty() {
462            return Err(IrConstructionSnafu { details: "cat requires at least one tensor".to_string() }.build());
463        }
464
465        let first = tensors[0];
466        let first_shape = first.shape()?;
467        let ndim = first_shape.len();
468        let dim = Self::normalize_axis(dim, ndim)?;
469
470        // Validate all tensors have compatible shapes
471        for (i, t) in tensors.iter().enumerate().skip(1) {
472            let t_shape = t.shape()?;
473            snafu::ensure!(
474                t_shape.len() == ndim,
475                ShapeMismatchSnafu {
476                    context: "cat",
477                    expected: format!("{} dimensions", ndim),
478                    actual: format!("{} dimensions for tensor {}", t_shape.len(), i)
479                }
480            );
481            for (d, (s1, s2)) in first_shape.iter().zip(t_shape.iter()).enumerate() {
482                if d != dim {
483                    snafu::ensure!(
484                        s1 == s2,
485                        ShapeMismatchSnafu {
486                            context: format!("cat dimension {}", d),
487                            expected: format!("{:?}", s1),
488                            actual: format!("{:?}", s2)
489                        }
490                    );
491                }
492            }
493        }
494
495        // Compute cumulative sizes along concat dimension
496        let dim_sizes: Vec<usize> = tensors.iter().map(|t| t.shape().unwrap()[dim].as_const().unwrap_or(0)).collect();
497        let total_dim: usize = dim_sizes.iter().sum();
498
499        // Pad each tensor to final size and add
500        let mut cumsum = 0usize;
501        let padded: Vec<Tensor> = tensors
502            .iter()
503            .zip(dim_sizes.iter())
504            .map(|(t, &sz)| {
505                let begin_pad = cumsum;
506                let end_pad = total_dim - cumsum - sz;
507                cumsum += sz;
508
509                let mut padding = vec![(0isize, 0isize); ndim];
510                padding[dim] = (begin_pad as isize, end_pad as isize);
511                t.try_pad(&padding)
512            })
513            .collect::<Result<Vec<_>>>()?;
514
515        // Sum all padded tensors
516        let mut result = padded[0].clone();
517        for t in padded.iter().skip(1) {
518            result = result.try_add(t)?;
519        }
520        Ok(result)
521    }
522
523    /// Stack tensors along a new dimension.
524    ///
525    /// Creates a new axis at `dim` by unsqueezing each tensor, then concatenating.
526    #[track_caller]
527    pub fn stack(tensors: &[&Tensor], dim: isize) -> Result<Tensor> {
528        let unsqueezed: Vec<Tensor> = tensors.iter().map(|t| t.try_unsqueeze(dim)).collect::<Result<_>>()?;
529        Tensor::cat(&unsqueezed.iter().collect::<Vec<_>>(), dim)
530    }
531
532    /// Replace a single dimension with multiple dimensions.
533    ///
534    /// Inverse of flatten: splits dimension `dim` into the shape given by `sizes`.
535    #[track_caller]
536    pub fn unflatten(&self, dim: isize, sizes: &[isize]) -> Result<Tensor> {
537        let shape = self.shape()?;
538        let dim = Self::normalize_axis(dim, shape.len())?;
539        let mut new_shape = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
540        new_shape.splice(dim..=dim, sizes.iter().copied());
541        self.try_reshape(&new_shape)
542    }
543
544    /// Create coordinate grids from 1D tensors.
545    ///
546    /// `indexing`: `Ij` (matrix/default) or `Xy` (Cartesian, swaps first two inputs).
547    #[track_caller]
548    pub fn meshgrid(tensors: &[&Tensor], indexing: MeshgridIndexing) -> Result<Vec<Tensor>> {
549        let n = tensors.len();
550        let sizes: Vec<usize> = tensors.iter().map(|t| t.numel().unwrap()).collect();
551        // For "xy" indexing, swap the first two inputs
552        let swapped: Vec<usize> = if indexing == MeshgridIndexing::Xy && n >= 2 {
553            let mut s: Vec<usize> = (0..n).collect();
554            s.swap(0, 1);
555            s
556        } else {
557            (0..n).collect()
558        };
559        // Output shape is [sizes[swapped[0]], sizes[swapped[1]], ...]
560        let out_shape: Vec<isize> = swapped.iter().map(|&i| sizes[i] as isize).collect();
561        tensors
562            .iter()
563            .enumerate()
564            .map(|(i, t)| {
565                // Position of this tensor's dimension in the output
566                let pos = swapped.iter().position(|&s| s == i).unwrap();
567                let mut shape = vec![1isize; n];
568                shape[pos] = sizes[i] as isize;
569                t.flatten()?.try_reshape(&shape)?.try_expand(&out_shape)
570            })
571            .collect()
572    }
573
574    /// Get the shape of this tensor as a new tensor.
575    ///
576    /// Returns a 1D tensor of int64 containing the shape dimensions.
577    /// This is useful for ONNX Shape operator compatibility.
578    ///
579    /// # Examples
580    ///
581    /// ```
582    /// # use svod_tensor::Tensor;
583    /// let t = Tensor::from_slice(&[1.0f32; 6]).try_reshape(&[2, 3]).unwrap();
584    /// let shape_tensor = t.shape_tensor().unwrap();  // Tensor([2, 3]) with dtype int64
585    /// ```
586    ///
587    /// # Errors
588    ///
589    /// Supports symbolic dimensions — symbolic dims produce scalar UOp tensors.
590    #[track_caller]
591    pub fn shape_tensor(&self) -> Result<Tensor> {
592        let shape = self.shape()?;
593
594        // If all concrete, fast path
595        if shape.iter().all(|d| d.is_const()) {
596            let dims: Vec<i64> = shape.iter().map(|d| d.as_const().unwrap() as i64).collect();
597            return Ok(Tensor::from_slice(&dims));
598        }
599
600        // Mixed concrete/symbolic: create scalar tensors and cat
601        let shape_sint: smallvec::SmallVec<[SInt; 4]> = smallvec::smallvec![SInt::from(1usize)];
602        let scalars: Result<Vec<Tensor>> = shape
603            .iter()
604            .map(|d| {
605                let uop = d.to_uop(svod_dtype::DType::Int64);
606                uop.try_reshape(&shape_sint).map(Tensor::new).context(UOpSnafu)
607            })
608            .collect();
609        let scalars = scalars?;
610        let refs: Vec<&Tensor> = scalars.iter().collect();
611        Tensor::cat(&refs, 0)
612    }
613
614    /// Shrink (slice) tensor along each dimension.
615    ///
616    /// Each tuple in `ranges` specifies (begin, end) for a dimension.
617    /// Use (0, size) to keep full dimension.
618    ///
619    /// # Examples
620    ///
621    /// ```
622    /// # use svod_tensor::Tensor;
623    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0]);
624    /// let sliced = t.try_shrink(&[(1, 4)]).unwrap();  // Elements [2, 3, 4]
625    /// ```
626    ///
627    /// # Errors
628    ///
629    /// Returns error if negative indices are used with symbolic shape dimensions.
630    #[track_caller]
631    pub fn try_shrink<R: IntoShrinkRange>(&self, ranges: impl IntoIterator<Item = R>) -> Result<Tensor> {
632        use svod_ir::ShrinkRange;
633
634        let shape = self.shape()?;
635        let resolved: Vec<ShrinkRange> = ranges.into_iter().map(|r| r.into_shrink_range()).collect();
636
637        // Empty ranges (scalar) → identity
638        if resolved.is_empty() {
639            return Ok(self.clone());
640        }
641
642        // Check if all ranges are None (no-op)
643        if resolved.iter().all(|r| matches!(r, ShrinkRange::None)) {
644            return Ok(self.clone());
645        }
646
647        // Convert to (SInt, SInt), resolving negative isize indices.
648        // None means "keep entire dim" (matches Tinygrad's shrink(None) semantics).
649        let ranges_sint: Vec<(SInt, SInt)> = resolved
650            .into_iter()
651            .enumerate()
652            .map(|(dim_idx, range)| match range {
653                ShrinkRange::None => Ok((SInt::Const(0), shape[dim_idx].clone())),
654                ShrinkRange::Sint(begin, end) => Ok((begin, end)),
655                ShrinkRange::Isize(begin, end) => {
656                    let (nb, ne) = if begin < 0 || end < 0 {
657                        let dim_size = shape[dim_idx].as_const().ok_or_else(|| Error::SymbolicShapeUnsupported {
658                            operation: "shrink with negative indices".to_string(),
659                        })? as isize;
660                        (if begin < 0 { dim_size + begin } else { begin }, if end < 0 { dim_size + end } else { end })
661                    } else {
662                        (begin, end)
663                    };
664                    Ok((SInt::Const(nb as usize), SInt::Const(ne as usize)))
665                }
666            })
667            .collect::<Result<Vec<_>>>()?;
668
669        self.uop().try_shrink(&ranges_sint).map(Self::new).context(UOpSnafu)
670    }
671
672    /// Center-crop or center-pad each specified axis to the target size.
673    ///
674    /// For axes where `target < current`, crops from the center.
675    /// For axes where `target > current`, pads symmetrically around the center.
676    /// Axes where `target == current` are unchanged.
677    ///
678    /// `axes` specifies which dimensions to apply (default: all).
679    pub fn center_crop_pad(&self, target_shape: &[usize], axes: Option<&[usize]>) -> Result<Tensor> {
680        let shape = self.shape()?;
681        let ndim = shape.len();
682        let default_axes: Vec<usize> = (0..ndim).collect();
683        let axes = axes.unwrap_or(&default_axes);
684
685        let mut shrink_arg: Vec<(isize, isize)> =
686            (0..ndim).map(|i| (0, shape[i].as_const().unwrap_or(1) as isize)).collect();
687        let mut pad_arg: Vec<(isize, isize)> = vec![(0, 0); ndim];
688
689        for (&s, &ax) in target_shape.iter().zip(axes.iter()) {
690            let s = s as isize;
691            let tx = shape[ax].as_const().unwrap_or(1) as isize;
692            if s < tx {
693                shrink_arg[ax] = (tx / 2 - (s + 1) / 2, tx / 2 + s / 2);
694            } else if s > tx {
695                pad_arg[ax] = ((s - tx) / 2, (s - tx + 1) / 2);
696            }
697        }
698
699        self.try_shrink(&shrink_arg)?.try_pad(&pad_arg)
700    }
701
702    // =========================================================================
703    // Helper Methods
704    // =========================================================================
705
706    /// Get the concrete shape of this tensor.
707    pub fn shape(&self) -> Result<Shape> {
708        self.uop().shape().context(UOpSnafu)?.cloned().ok_or(Error::ShapeUnknown)
709    }
710
711    /// Get the number of dimensions (rank).
712    pub fn ndim(&self) -> Result<usize> {
713        Ok(self.shape()?.len())
714    }
715
716    /// Total number of elements. Fails if any dimension is symbolic.
717    pub fn numel(&self) -> Result<usize> {
718        self.shape()?.iter().try_fold(1usize, |acc, d| {
719            d.as_const().map(|v| acc * v).ok_or(Error::SymbolicShapeUnsupported { operation: "numel".into() })
720        })
721    }
722
723    /// Get the size of a specific dimension.
724    ///
725    /// Supports negative indexing (e.g., -1 for last dimension).
726    /// Returns a SInt which can be either concrete (Const) or symbolic.
727    ///
728    /// # Examples
729    ///
730    /// ```ignore
731    /// let t = Tensor::from_slice([1.0f32; 6]).try_reshape(&[2, 3])?;
732    /// assert_eq!(t.dim(0)?.as_const(), Some(2));   // First dimension
733    /// assert_eq!(t.dim(1)?.as_const(), Some(3));   // Second dimension
734    /// assert_eq!(t.dim(-1)?.as_const(), Some(3));  // Last dimension (negative indexing)
735    /// assert_eq!(t.dim(-2)?.as_const(), Some(2));  // Second-to-last dimension
736    /// ```
737    ///
738    /// # Errors
739    ///
740    /// Returns error if axis is out of range.
741    pub(crate) fn dim(&self, axis: isize) -> Result<svod_ir::SInt> {
742        let shape = self.shape()?;
743        let idx = Self::normalize_axis(axis, shape.len())?;
744        Ok(shape[idx].clone())
745    }
746
747    /// Normalize a single axis index (handle negative indices).
748    pub(crate) fn normalize_axis(axis: isize, ndim: usize) -> Result<usize> {
749        if axis < 0 {
750            let positive = (ndim as isize + axis) as usize;
751            snafu::ensure!(axis >= -(ndim as isize), AxisOutOfRangeSnafu { axis, ndim });
752            Ok(positive)
753        } else {
754            let pos = axis as usize;
755            snafu::ensure!(pos < ndim, AxisOutOfRangeSnafu { axis, ndim });
756            Ok(pos)
757        }
758    }
759
760    /// Normalize axes list and validate it's a valid permutation.
761    fn normalize_axes(&self, axes: &[isize], ndim: usize) -> Result<Vec<usize>> {
762        snafu::ensure!(axes.len() == ndim, PermutationLengthMismatchSnafu { expected: ndim, got: axes.len() });
763
764        let mut normalized = Vec::with_capacity(ndim);
765        for &axis in axes {
766            normalized.push(Self::normalize_axis(axis, ndim)?);
767        }
768
769        // Validate it's a permutation (each index appears exactly once)
770        let mut seen = vec![false; ndim];
771        for &idx in &normalized {
772            snafu::ensure!(!seen[idx], InvalidPermutationSnafu { axes: axes.to_vec() });
773            seen[idx] = true;
774        }
775
776        Ok(normalized)
777    }
778
779    /// Upper triangular mask: row + diagonal <= col.
780    fn tri(rows: i64, cols: i64, diagonal: i64) -> Result<Tensor> {
781        let row = Tensor::arange(0, Some(rows), None)?.try_unsqueeze(-1)?;
782        let col = Tensor::arange(0, Some(cols), None)?;
783        let diag = Tensor::const_(ConstValue::Int(diagonal), DType::Int32);
784        row.try_add(&diag)?.try_le(&col)
785    }
786
787    /// Keep upper triangle, zero below. Matches Tinygrad `Tensor.triu(diagonal)`.
788    pub fn triu(&self, diagonal: i64) -> Result<Tensor> {
789        let shape = self.shape()?;
790        let ndim = shape.len();
791        let r = shape[ndim - 2].as_const().unwrap() as i64;
792        let c = shape[ndim - 1].as_const().unwrap() as i64;
793        let mask = Self::tri(r, c, diagonal)?;
794        let zero = Tensor::new(self.uop().const_like(ConstValue::zero(self.uop().dtype().scalar().unwrap())));
795        self.where_(&mask, &zero)
796    }
797
798    /// Keep lower triangle, zero above. Matches Tinygrad `Tensor.tril(diagonal)`.
799    pub fn tril(&self, diagonal: i64) -> Result<Tensor> {
800        let shape = self.shape()?;
801        let ndim = shape.len();
802        let r = shape[ndim - 2].as_const().unwrap() as i64;
803        let c = shape[ndim - 1].as_const().unwrap() as i64;
804        let mask = Self::tri(r, c, diagonal + 1)?;
805        let zero = Tensor::new(self.uop().const_like(ConstValue::zero(self.uop().dtype().scalar().unwrap())));
806        zero.where_(&mask, self)
807    }
808}
809
810#[bon]
811impl Tensor {
812    /// Slice tensor with Python-style indexing: negative indices, steps, and axis selection.
813    #[builder]
814    pub fn slice_with(
815        &self,
816        starts: &[i64],
817        ends: &[i64],
818        axes: Option<&[i64]>,
819        steps: Option<&[i64]>,
820    ) -> Result<Tensor> {
821        let shape = self.shape()?;
822        let ndim = shape.len();
823
824        let axes: Vec<usize> = axes
825            .map(|v| v.iter().map(|&a| if a < 0 { (ndim as i64 + a) as usize } else { a as usize }).collect())
826            .unwrap_or_else(|| (0..starts.len()).collect());
827
828        let default_steps;
829        let steps = match steps {
830            Some(s) => s,
831            None => {
832                default_steps = vec![1i64; starts.len()];
833                &default_steps
834            }
835        };
836
837        let mut ranges: Vec<(isize, isize)> =
838            (0..ndim).map(|d| (0isize, shape[d].as_const().unwrap() as isize)).collect();
839        let mut flip_axes: Vec<isize> = Vec::new();
840
841        for (i, &axis) in axes.iter().enumerate() {
842            let d = shape[axis].as_const().unwrap() as i64;
843            let step = steps[i];
844            if step == 0 {
845                return Err(crate::error::Error::IrConstruction { details: "Slice step cannot be 0".into() });
846            }
847
848            let (lower, upper) = if step > 0 { (0i64, d) } else { (-1i64, d - 1) };
849            let mut s = starts[i].clamp(-d, d);
850            if s < 0 {
851                s += d;
852            }
853            let s = s.clamp(lower, upper);
854
855            let mut e = ends[i].clamp(-d - 1, d);
856            if e < 0 {
857                e += d;
858            }
859            let e = e.clamp(lower, upper);
860
861            if step * (e - s) < 0 {
862                ranges[axis] = (0, 0);
863            } else if step < 0 {
864                flip_axes.push(axis as isize);
865                ranges[axis] = ((e + 1) as isize, (s + 1) as isize);
866            } else {
867                ranges[axis] = (s as isize, e as isize);
868            }
869        }
870
871        let mut result = self.try_shrink(&ranges)?;
872        if !flip_axes.is_empty() {
873            result = result.flip(&flip_axes)?;
874        }
875
876        for (i, &axis) in axes.iter().enumerate() {
877            let abs_step = steps[i].unsigned_abs() as usize;
878            if abs_step <= 1 {
879                continue;
880            }
881            let cur = result.shape()?;
882            let size = cur[axis].as_const().unwrap();
883            let padded = size.div_ceil(abs_step) * abs_step;
884            if padded > size {
885                let mut p = vec![(0isize, 0isize); cur.len()];
886                p[axis] = (0, (padded - size) as isize);
887                result = result.try_pad(&p)?;
888            }
889            let n = padded / abs_step;
890            let cs = result.shape()?;
891            let mut rs: Vec<isize> = Vec::new();
892            for (d, dim) in cs.iter().enumerate() {
893                if d == axis {
894                    rs.push(n as isize);
895                    rs.push(abs_step as isize);
896                } else {
897                    rs.push(dim.as_const().unwrap() as isize);
898                }
899            }
900            result = result.try_reshape(&rs)?;
901            let ss = result.shape()?;
902            let sr: Vec<(isize, isize)> = ss
903                .iter()
904                .enumerate()
905                .map(|(d, dim)| if d == axis + 1 { (0, 1) } else { (0, dim.as_const().unwrap() as isize) })
906                .collect();
907            result = result.try_shrink(&sr)?;
908            let fs: Vec<isize> = result
909                .shape()?
910                .iter()
911                .enumerate()
912                .filter(|&(d, _)| d != axis + 1)
913                .map(|(_, dim)| dim.as_const().unwrap() as isize)
914                .collect();
915            result = result.try_reshape(&fs)?;
916        }
917
918        if !flip_axes.is_empty() || steps.iter().any(|&s| s.unsigned_abs() > 1) {
919            result = result.contiguous();
920        }
921
922        Ok(result)
923    }
924}