1use bon::bon;
12use snafu::ResultExt;
13use strum::{Display, EnumString};
14use svod_ir::IntoShrinkRange;
15
16use super::*;
17
/// Output-axis ordering convention for [`Tensor::meshgrid`].
///
/// Parsed from and displayed as the strings `"ij"` / `"xy"` via the `strum`
/// derives; defaults to `Ij`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
pub enum MeshgridIndexing {
    /// `"ij"`: output axis `k` corresponds to input tensor `k` (no reordering).
    #[default]
    #[strum(serialize = "ij")]
    Ij,
    /// `"xy"`: the first two output axes are swapped (only when there are at
    /// least two input tensors).
    #[strum(serialize = "xy")]
    Xy,
}
27
28impl Tensor {
29 #[track_caller]
50 pub fn try_reshape(&self, new_shape: impl IntoIterator<Item = impl Into<SInt>>) -> Result<Tensor> {
51 let dims: Vec<SInt> = new_shape.into_iter().map(Into::into).collect();
52
53 let infer_count = dims.iter().filter(|d| d.is_infer()).count();
55 snafu::ensure!(infer_count <= 1, MultipleInferDimensionsSnafu);
56
57 let shape: Shape = if infer_count == 1 {
58 let current_shape = self.shape()?;
59 let total_elements =
60 current_shape.iter().try_fold(1usize, |acc, dim| dim.as_const().map(|v| acc * v)).ok_or_else(|| {
61 Error::SymbolicShapeUnsupported { operation: "reshape with -1 inference".to_string() }
62 })?;
63 let known_product: usize = dims
64 .iter()
65 .filter(|d| !d.is_infer())
66 .map(|d| d.as_const().expect("non-infer dims must be concrete for -1 inference"))
67 .product();
68 snafu::ensure!(
69 known_product > 0 && total_elements % known_product == 0,
70 ReshapeSizeMismatchSnafu { operation: "reshape with inference".to_string() }
71 );
72 let inferred = total_elements / known_product;
73 dims.iter().map(|d| if d.is_infer() { SInt::Const(inferred) } else { d.clone() }).collect()
74 } else {
75 dims.into()
76 };
77
78 self.uop().try_reshape(&shape).map(Self::new).context(UOpSnafu)
79 }
80
81 pub fn try_expand(&self, new_shape: impl IntoIterator<Item = impl Into<SInt>>) -> Result<Tensor> {
83 let requested: Vec<SInt> = new_shape.into_iter().map(Into::into).collect();
84 let current_shape = self.shape()?;
86 let shape: Shape = requested
87 .into_iter()
88 .enumerate()
89 .map(|(i, s)| if s.is_infer() { current_shape[i].clone() } else { s })
90 .collect();
91 self.uop().try_expand(&shape).map(Self::new).context(UOpSnafu)
92 }
93
94 #[track_caller]
114 pub fn try_permute(&self, axes: &[isize]) -> Result<Tensor> {
115 let shape = self.shape()?;
116 let ndim = shape.len();
117
118 let normalized_axes = self.normalize_axes(axes, ndim)?;
120
121 self.uop().try_permute(normalized_axes).map(Self::new).context(UOpSnafu)
122 }
123
124 #[track_caller]
142 pub fn try_transpose(&self, dim0: isize, dim1: isize) -> Result<Tensor> {
143 let shape = self.shape()?;
144 let ndim = shape.len();
145
146 let d0 = Self::normalize_axis(dim0, ndim)?;
148 let d1 = Self::normalize_axis(dim1, ndim)?;
149
150 let mut axes: Vec<usize> = (0..ndim).collect();
152 axes.swap(d0, d1);
153
154 self.uop().try_permute(axes).map(Self::new).context(UOpSnafu)
155 }
156
157 #[track_caller]
191 pub fn try_squeeze(&self, dim: Option<isize>) -> Result<Tensor> {
192 let shape = self.shape()?;
193
194 let new_shape = match dim {
195 None => {
196 shape
198 .iter()
199 .filter_map(|s| s.as_const().and_then(|v| if v != 1 { Some(SInt::Const(v)) } else { None }))
200 .collect()
201 }
202 Some(axis) => {
203 let ndim = shape.len();
204 let normalized_axis = Self::normalize_axis(axis, ndim)?;
205
206 let dim_size = shape[normalized_axis]
208 .as_const()
209 .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "squeeze".to_string() })?;
210
211 snafu::ensure!(dim_size == 1, SqueezeDimensionNotOneSnafu { dim: normalized_axis, size: dim_size });
212
213 shape
215 .iter()
216 .enumerate()
217 .filter_map(|(i, s)| if i != normalized_axis { Some(s.clone()) } else { None })
218 .collect()
219 }
220 };
221
222 self.uop().try_reshape(&new_shape).map(Self::new).context(UOpSnafu)
223 }
224
225 #[track_caller]
244 pub fn try_unsqueeze(&self, dim: isize) -> Result<Tensor> {
245 let shape = self.shape()?;
246 let ndim = shape.len();
247
248 let normalized_dim = if dim < 0 {
251 let positive = (ndim as isize + 1 + dim) as usize;
252 snafu::ensure!(dim >= -(ndim as isize + 1), AxisOutOfRangeSnafu { axis: dim, ndim });
253 positive
254 } else {
255 let pos = dim as usize;
256 snafu::ensure!(pos <= ndim, AxisOutOfRangeSnafu { axis: dim, ndim });
257 pos
258 };
259
260 let mut new_shape = shape.clone();
262 new_shape.insert(normalized_dim, SInt::Const(1));
263
264 self.uop().try_reshape(&new_shape).map(Self::new).context(UOpSnafu)
265 }
266
267 #[track_caller]
277 pub fn flip(&self, axes: &[isize]) -> Result<Tensor> {
278 let shape = self.shape()?;
279 let ndim = shape.len();
280 let flip_spec: Vec<bool> =
281 (0..ndim).map(|d| axes.iter().any(|&a| Self::normalize_axis(a, ndim).is_ok_and(|na| na == d))).collect();
282 self.uop().try_flip(flip_spec).map(Self::new).context(UOpSnafu)
283 }
284
285 #[track_caller]
295 pub fn split(&self, sizes: &[usize], dim: isize) -> Result<Vec<Tensor>> {
296 let shape = self.shape()?;
297 let ndim = shape.len();
298 let dim = Self::normalize_axis(dim, ndim)?;
299 let mut results = Vec::with_capacity(sizes.len());
300 let mut offset = 0usize;
301 for &size in sizes {
302 let ranges: Vec<Option<(isize, isize)>> = (0..ndim)
303 .map(|d| {
304 if d == dim {
305 Some((offset as isize, (offset + size) as isize))
306 } else {
307 None }
309 })
310 .collect();
311 results.push(self.try_shrink(ranges)?);
312 offset += size;
313 }
314 Ok(results)
315 }
316
317 #[track_caller]
329 pub fn repeat(&self, repeats: &[SInt]) -> Result<Tensor> {
330 let shape = self.shape()?;
331 let ndim = shape.len();
332 snafu::ensure!(
333 repeats.len() == ndim,
334 ShapeMismatchSnafu {
335 context: "repeat",
336 expected: format!("{} dimensions", ndim),
337 actual: format!("{} repeats", repeats.len())
338 }
339 );
340 let mut result = self.clone();
341 for (dim, rep) in repeats.iter().enumerate() {
342 if rep.as_const() == Some(1) {
343 continue;
344 }
345 let current_shape = result.shape()?;
346 let dim_size = ¤t_shape[dim];
347 result = result.try_unsqueeze(dim as isize)?;
349 let mut expand_shape: Vec<SInt> = current_shape.iter().cloned().collect();
350 expand_shape.insert(dim, rep.clone());
351 result = result.try_expand(&expand_shape)?;
352 expand_shape[dim] = rep * dim_size;
353 expand_shape.remove(dim + 1);
354 result = result.try_reshape(expand_shape)?;
355 }
356 Ok(result)
357 }
358
359 #[track_caller]
370 pub fn flatten(&self) -> Result<Tensor> {
371 self.try_reshape([-1])
372 }
373
374 #[track_caller]
393 pub fn try_pad(&self, padding: &[(isize, isize)]) -> Result<Tensor> {
394 let shape = self.shape()?;
395
396 if padding.is_empty() {
398 return Ok(self.clone());
399 }
400
401 snafu::ensure!(
403 padding.len() == shape.len(),
404 ShapeMismatchSnafu {
405 context: "pad",
406 expected: format!("{} dimensions", shape.len()),
407 actual: format!("{} padding pairs", padding.len())
408 }
409 );
410
411 let needs_shrink = padding.iter().any(|(b, e)| *b < 0 || *e < 0);
413 let base = if needs_shrink {
414 let shrink_ranges: Vec<(isize, isize)> = padding
415 .iter()
416 .zip(shape.iter())
417 .map(|((b, e), s)| {
418 let dim = s.as_const().expect("pad with negative values requires concrete shape") as isize;
419 let begin = (-*b).max(0);
420 let end = (dim + *e).min(dim);
421 (begin, end)
422 })
423 .collect();
424 self.try_shrink(&shrink_ranges)?
425 } else {
426 self.clone()
427 };
428
429 let pos_padding: Vec<(isize, isize)> = padding.iter().map(|(b, e)| ((*b).max(0), (*e).max(0))).collect();
431 if pos_padding.iter().all(|(b, e)| *b == 0 && *e == 0) {
432 return Ok(base);
433 }
434
435 let padding_sint: Vec<(SInt, SInt)> =
436 pos_padding.iter().map(|(begin, end)| (SInt::Const(*begin as usize), SInt::Const(*end as usize))).collect();
437
438 base.uop().try_pad(&padding_sint).map(Self::new).context(UOpSnafu)
439 }
440
441 #[track_caller]
460 pub fn cat(tensors: &[&Tensor], dim: isize) -> Result<Tensor> {
461 if tensors.is_empty() {
462 return Err(IrConstructionSnafu { details: "cat requires at least one tensor".to_string() }.build());
463 }
464
465 let first = tensors[0];
466 let first_shape = first.shape()?;
467 let ndim = first_shape.len();
468 let dim = Self::normalize_axis(dim, ndim)?;
469
470 for (i, t) in tensors.iter().enumerate().skip(1) {
472 let t_shape = t.shape()?;
473 snafu::ensure!(
474 t_shape.len() == ndim,
475 ShapeMismatchSnafu {
476 context: "cat",
477 expected: format!("{} dimensions", ndim),
478 actual: format!("{} dimensions for tensor {}", t_shape.len(), i)
479 }
480 );
481 for (d, (s1, s2)) in first_shape.iter().zip(t_shape.iter()).enumerate() {
482 if d != dim {
483 snafu::ensure!(
484 s1 == s2,
485 ShapeMismatchSnafu {
486 context: format!("cat dimension {}", d),
487 expected: format!("{:?}", s1),
488 actual: format!("{:?}", s2)
489 }
490 );
491 }
492 }
493 }
494
495 let dim_sizes: Vec<usize> = tensors.iter().map(|t| t.shape().unwrap()[dim].as_const().unwrap_or(0)).collect();
497 let total_dim: usize = dim_sizes.iter().sum();
498
499 let mut cumsum = 0usize;
501 let padded: Vec<Tensor> = tensors
502 .iter()
503 .zip(dim_sizes.iter())
504 .map(|(t, &sz)| {
505 let begin_pad = cumsum;
506 let end_pad = total_dim - cumsum - sz;
507 cumsum += sz;
508
509 let mut padding = vec![(0isize, 0isize); ndim];
510 padding[dim] = (begin_pad as isize, end_pad as isize);
511 t.try_pad(&padding)
512 })
513 .collect::<Result<Vec<_>>>()?;
514
515 let mut result = padded[0].clone();
517 for t in padded.iter().skip(1) {
518 result = result.try_add(t)?;
519 }
520 Ok(result)
521 }
522
523 #[track_caller]
527 pub fn stack(tensors: &[&Tensor], dim: isize) -> Result<Tensor> {
528 let unsqueezed: Vec<Tensor> = tensors.iter().map(|t| t.try_unsqueeze(dim)).collect::<Result<_>>()?;
529 Tensor::cat(&unsqueezed.iter().collect::<Vec<_>>(), dim)
530 }
531
532 #[track_caller]
536 pub fn unflatten(&self, dim: isize, sizes: &[isize]) -> Result<Tensor> {
537 let shape = self.shape()?;
538 let dim = Self::normalize_axis(dim, shape.len())?;
539 let mut new_shape = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
540 new_shape.splice(dim..=dim, sizes.iter().copied());
541 self.try_reshape(&new_shape)
542 }
543
544 #[track_caller]
548 pub fn meshgrid(tensors: &[&Tensor], indexing: MeshgridIndexing) -> Result<Vec<Tensor>> {
549 let n = tensors.len();
550 let sizes: Vec<usize> = tensors.iter().map(|t| t.numel().unwrap()).collect();
551 let swapped: Vec<usize> = if indexing == MeshgridIndexing::Xy && n >= 2 {
553 let mut s: Vec<usize> = (0..n).collect();
554 s.swap(0, 1);
555 s
556 } else {
557 (0..n).collect()
558 };
559 let out_shape: Vec<isize> = swapped.iter().map(|&i| sizes[i] as isize).collect();
561 tensors
562 .iter()
563 .enumerate()
564 .map(|(i, t)| {
565 let pos = swapped.iter().position(|&s| s == i).unwrap();
567 let mut shape = vec![1isize; n];
568 shape[pos] = sizes[i] as isize;
569 t.flatten()?.try_reshape(&shape)?.try_expand(&out_shape)
570 })
571 .collect()
572 }
573
574 #[track_caller]
591 pub fn shape_tensor(&self) -> Result<Tensor> {
592 let shape = self.shape()?;
593
594 if shape.iter().all(|d| d.is_const()) {
596 let dims: Vec<i64> = shape.iter().map(|d| d.as_const().unwrap() as i64).collect();
597 return Ok(Tensor::from_slice(&dims));
598 }
599
600 let shape_sint: smallvec::SmallVec<[SInt; 4]> = smallvec::smallvec![SInt::from(1usize)];
602 let scalars: Result<Vec<Tensor>> = shape
603 .iter()
604 .map(|d| {
605 let uop = d.to_uop(svod_dtype::DType::Int64);
606 uop.try_reshape(&shape_sint).map(Tensor::new).context(UOpSnafu)
607 })
608 .collect();
609 let scalars = scalars?;
610 let refs: Vec<&Tensor> = scalars.iter().collect();
611 Tensor::cat(&refs, 0)
612 }
613
614 #[track_caller]
631 pub fn try_shrink<R: IntoShrinkRange>(&self, ranges: impl IntoIterator<Item = R>) -> Result<Tensor> {
632 use svod_ir::ShrinkRange;
633
634 let shape = self.shape()?;
635 let resolved: Vec<ShrinkRange> = ranges.into_iter().map(|r| r.into_shrink_range()).collect();
636
637 if resolved.is_empty() {
639 return Ok(self.clone());
640 }
641
642 if resolved.iter().all(|r| matches!(r, ShrinkRange::None)) {
644 return Ok(self.clone());
645 }
646
647 let ranges_sint: Vec<(SInt, SInt)> = resolved
650 .into_iter()
651 .enumerate()
652 .map(|(dim_idx, range)| match range {
653 ShrinkRange::None => Ok((SInt::Const(0), shape[dim_idx].clone())),
654 ShrinkRange::Sint(begin, end) => Ok((begin, end)),
655 ShrinkRange::Isize(begin, end) => {
656 let (nb, ne) = if begin < 0 || end < 0 {
657 let dim_size = shape[dim_idx].as_const().ok_or_else(|| Error::SymbolicShapeUnsupported {
658 operation: "shrink with negative indices".to_string(),
659 })? as isize;
660 (if begin < 0 { dim_size + begin } else { begin }, if end < 0 { dim_size + end } else { end })
661 } else {
662 (begin, end)
663 };
664 Ok((SInt::Const(nb as usize), SInt::Const(ne as usize)))
665 }
666 })
667 .collect::<Result<Vec<_>>>()?;
668
669 self.uop().try_shrink(&ranges_sint).map(Self::new).context(UOpSnafu)
670 }
671
672 pub fn center_crop_pad(&self, target_shape: &[usize], axes: Option<&[usize]>) -> Result<Tensor> {
680 let shape = self.shape()?;
681 let ndim = shape.len();
682 let default_axes: Vec<usize> = (0..ndim).collect();
683 let axes = axes.unwrap_or(&default_axes);
684
685 let mut shrink_arg: Vec<(isize, isize)> =
686 (0..ndim).map(|i| (0, shape[i].as_const().unwrap_or(1) as isize)).collect();
687 let mut pad_arg: Vec<(isize, isize)> = vec![(0, 0); ndim];
688
689 for (&s, &ax) in target_shape.iter().zip(axes.iter()) {
690 let s = s as isize;
691 let tx = shape[ax].as_const().unwrap_or(1) as isize;
692 if s < tx {
693 shrink_arg[ax] = (tx / 2 - (s + 1) / 2, tx / 2 + s / 2);
694 } else if s > tx {
695 pad_arg[ax] = ((s - tx) / 2, (s - tx + 1) / 2);
696 }
697 }
698
699 self.try_shrink(&shrink_arg)?.try_pad(&pad_arg)
700 }
701
702 pub fn shape(&self) -> Result<Shape> {
708 self.uop().shape().context(UOpSnafu)?.cloned().ok_or(Error::ShapeUnknown)
709 }
710
711 pub fn ndim(&self) -> Result<usize> {
713 Ok(self.shape()?.len())
714 }
715
716 pub fn numel(&self) -> Result<usize> {
718 self.shape()?.iter().try_fold(1usize, |acc, d| {
719 d.as_const().map(|v| acc * v).ok_or(Error::SymbolicShapeUnsupported { operation: "numel".into() })
720 })
721 }
722
723 pub(crate) fn dim(&self, axis: isize) -> Result<svod_ir::SInt> {
742 let shape = self.shape()?;
743 let idx = Self::normalize_axis(axis, shape.len())?;
744 Ok(shape[idx].clone())
745 }
746
747 pub(crate) fn normalize_axis(axis: isize, ndim: usize) -> Result<usize> {
749 if axis < 0 {
750 let positive = (ndim as isize + axis) as usize;
751 snafu::ensure!(axis >= -(ndim as isize), AxisOutOfRangeSnafu { axis, ndim });
752 Ok(positive)
753 } else {
754 let pos = axis as usize;
755 snafu::ensure!(pos < ndim, AxisOutOfRangeSnafu { axis, ndim });
756 Ok(pos)
757 }
758 }
759
760 fn normalize_axes(&self, axes: &[isize], ndim: usize) -> Result<Vec<usize>> {
762 snafu::ensure!(axes.len() == ndim, PermutationLengthMismatchSnafu { expected: ndim, got: axes.len() });
763
764 let mut normalized = Vec::with_capacity(ndim);
765 for &axis in axes {
766 normalized.push(Self::normalize_axis(axis, ndim)?);
767 }
768
769 let mut seen = vec![false; ndim];
771 for &idx in &normalized {
772 snafu::ensure!(!seen[idx], InvalidPermutationSnafu { axes: axes.to_vec() });
773 seen[idx] = true;
774 }
775
776 Ok(normalized)
777 }
778
779 fn tri(rows: i64, cols: i64, diagonal: i64) -> Result<Tensor> {
781 let row = Tensor::arange(0, Some(rows), None)?.try_unsqueeze(-1)?;
782 let col = Tensor::arange(0, Some(cols), None)?;
783 let diag = Tensor::const_(ConstValue::Int(diagonal), DType::Int32);
784 row.try_add(&diag)?.try_le(&col)
785 }
786
787 pub fn triu(&self, diagonal: i64) -> Result<Tensor> {
789 let shape = self.shape()?;
790 let ndim = shape.len();
791 let r = shape[ndim - 2].as_const().unwrap() as i64;
792 let c = shape[ndim - 1].as_const().unwrap() as i64;
793 let mask = Self::tri(r, c, diagonal)?;
794 let zero = Tensor::new(self.uop().const_like(ConstValue::zero(self.uop().dtype().scalar().unwrap())));
795 self.where_(&mask, &zero)
796 }
797
798 pub fn tril(&self, diagonal: i64) -> Result<Tensor> {
800 let shape = self.shape()?;
801 let ndim = shape.len();
802 let r = shape[ndim - 2].as_const().unwrap() as i64;
803 let c = shape[ndim - 1].as_const().unwrap() as i64;
804 let mask = Self::tri(r, c, diagonal + 1)?;
805 let zero = Tensor::new(self.uop().const_like(ConstValue::zero(self.uop().dtype().scalar().unwrap())));
806 zero.where_(&mask, self)
807 }
808}
809
810#[bon]
811impl Tensor {
812 #[builder]
814 pub fn slice_with(
815 &self,
816 starts: &[i64],
817 ends: &[i64],
818 axes: Option<&[i64]>,
819 steps: Option<&[i64]>,
820 ) -> Result<Tensor> {
821 let shape = self.shape()?;
822 let ndim = shape.len();
823
824 let axes: Vec<usize> = axes
825 .map(|v| v.iter().map(|&a| if a < 0 { (ndim as i64 + a) as usize } else { a as usize }).collect())
826 .unwrap_or_else(|| (0..starts.len()).collect());
827
828 let default_steps;
829 let steps = match steps {
830 Some(s) => s,
831 None => {
832 default_steps = vec![1i64; starts.len()];
833 &default_steps
834 }
835 };
836
837 let mut ranges: Vec<(isize, isize)> =
838 (0..ndim).map(|d| (0isize, shape[d].as_const().unwrap() as isize)).collect();
839 let mut flip_axes: Vec<isize> = Vec::new();
840
841 for (i, &axis) in axes.iter().enumerate() {
842 let d = shape[axis].as_const().unwrap() as i64;
843 let step = steps[i];
844 if step == 0 {
845 return Err(crate::error::Error::IrConstruction { details: "Slice step cannot be 0".into() });
846 }
847
848 let (lower, upper) = if step > 0 { (0i64, d) } else { (-1i64, d - 1) };
849 let mut s = starts[i].clamp(-d, d);
850 if s < 0 {
851 s += d;
852 }
853 let s = s.clamp(lower, upper);
854
855 let mut e = ends[i].clamp(-d - 1, d);
856 if e < 0 {
857 e += d;
858 }
859 let e = e.clamp(lower, upper);
860
861 if step * (e - s) < 0 {
862 ranges[axis] = (0, 0);
863 } else if step < 0 {
864 flip_axes.push(axis as isize);
865 ranges[axis] = ((e + 1) as isize, (s + 1) as isize);
866 } else {
867 ranges[axis] = (s as isize, e as isize);
868 }
869 }
870
871 let mut result = self.try_shrink(&ranges)?;
872 if !flip_axes.is_empty() {
873 result = result.flip(&flip_axes)?;
874 }
875
876 for (i, &axis) in axes.iter().enumerate() {
877 let abs_step = steps[i].unsigned_abs() as usize;
878 if abs_step <= 1 {
879 continue;
880 }
881 let cur = result.shape()?;
882 let size = cur[axis].as_const().unwrap();
883 let padded = size.div_ceil(abs_step) * abs_step;
884 if padded > size {
885 let mut p = vec![(0isize, 0isize); cur.len()];
886 p[axis] = (0, (padded - size) as isize);
887 result = result.try_pad(&p)?;
888 }
889 let n = padded / abs_step;
890 let cs = result.shape()?;
891 let mut rs: Vec<isize> = Vec::new();
892 for (d, dim) in cs.iter().enumerate() {
893 if d == axis {
894 rs.push(n as isize);
895 rs.push(abs_step as isize);
896 } else {
897 rs.push(dim.as_const().unwrap() as isize);
898 }
899 }
900 result = result.try_reshape(&rs)?;
901 let ss = result.shape()?;
902 let sr: Vec<(isize, isize)> = ss
903 .iter()
904 .enumerate()
905 .map(|(d, dim)| if d == axis + 1 { (0, 1) } else { (0, dim.as_const().unwrap() as isize) })
906 .collect();
907 result = result.try_shrink(&sr)?;
908 let fs: Vec<isize> = result
909 .shape()?
910 .iter()
911 .enumerate()
912 .filter(|&(d, _)| d != axis + 1)
913 .map(|(_, dim)| dim.as_const().unwrap() as isize)
914 .collect();
915 result = result.try_reshape(&fs)?;
916 }
917
918 if !flip_axes.is_empty() || steps.iter().any(|&s| s.unsigned_abs() > 1) {
919 result = result.contiguous();
920 }
921
922 Ok(result)
923 }
924}