1use bon::bon;
7use snafu::ResultExt;
8use svod_dtype::{DType, ScalarDType};
9use svod_ir::{ConstValue, ReduceOp, SInt, UOp};
10
11use crate::{
12 Error, Result, Tensor,
13 error::{SymbolicShapeUnsupportedSnafu, UOpSnafu},
14};
15
/// Selects which axes a reduction runs over.
///
/// Axis values are signed; they are passed through `Tensor::normalize_axis`
/// when the spec is resolved against a concrete rank.
#[derive(Clone, Debug)]
pub enum AxisSpec {
    /// Every axis of the tensor.
    All,
    /// Exactly one axis.
    Single(isize),
    /// An explicit list of axes. When resolved, the list is sorted and
    /// duplicates are removed.
    Multiple(Vec<isize>),
}
31
32impl From<()> for AxisSpec {
34 fn from(_: ()) -> Self {
35 Self::All
36 }
37}
38
39impl From<isize> for AxisSpec {
40 fn from(axis: isize) -> Self {
41 Self::Single(axis)
42 }
43}
44
45impl From<&[isize]> for AxisSpec {
46 fn from(axes: &[isize]) -> Self {
47 Self::Multiple(axes.to_vec())
48 }
49}
50
51impl From<Vec<isize>> for AxisSpec {
52 fn from(axes: Vec<isize>) -> Self {
53 Self::Multiple(axes)
54 }
55}
56
57impl Tensor {
62 pub(crate) fn resolve_axis_spec(spec: &AxisSpec, ndim: usize) -> Result<Vec<usize>> {
70 match spec {
71 AxisSpec::All => Ok((0..ndim).collect()),
72 AxisSpec::Single(axis) => {
73 let normalized = Self::normalize_axis(*axis, ndim)?;
74 Ok(vec![normalized])
75 }
76 AxisSpec::Multiple(axes) => {
77 let mut normalized: Vec<usize> =
78 axes.iter().map(|&axis| Self::normalize_axis(axis, ndim)).collect::<Result<_>>()?;
79
80 normalized.sort_unstable();
82 normalized.dedup();
83
84 Ok(normalized)
85 }
86 }
87 }
88
89 pub(crate) fn sum_acc_dtype(dtype: &DType) -> DType {
102 use ScalarDType::*;
103 let Some(scalar) = dtype.scalar() else {
104 return dtype.clone();
105 };
106
107 match scalar {
108 Bool => DType::Int32,
109 Int8 | Int16 => DType::Int32,
110 Int32 | Int64 => dtype.clone(),
111 UInt8 | UInt16 => DType::UInt32,
112 UInt32 | UInt64 => dtype.clone(),
113 Float16 | BFloat16 | FP8E4M3 | FP8E5M2 => DType::Float32,
114 Float32 | Float64 => dtype.clone(),
115 Void | Index => dtype.clone(),
116 }
117 }
118
119 fn should_cast_back_after_sum(dtype: &DType) -> bool {
126 matches!(
127 dtype.scalar(),
128 Some(ScalarDType::Float16 | ScalarDType::BFloat16 | ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2)
129 )
130 }
131
132 fn is_integer_dtype(dtype: &DType) -> bool {
134 dtype.is_int() || matches!(dtype.scalar(), Some(ScalarDType::Bool))
135 }
136
137 fn remove_singleton_dims(self, reduced_axes: &[usize]) -> Result<Self> {
143 let shape = self.shape()?;
144
145 let new_shape: Vec<SInt> = shape
147 .iter()
148 .enumerate()
149 .filter_map(|(i, dim)| {
150 if reduced_axes.contains(&i) {
152 None } else {
154 Some(dim.clone())
155 }
156 })
157 .collect();
158
159 if new_shape.is_empty() {
161 self.try_reshape(std::iter::empty::<SInt>())
164 } else {
165 self.try_reshape(&new_shape)
166 }
167 }
168}
169
170#[bon]
171impl Tensor {
172 #[track_caller]
177 pub fn sum(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
178 reduce_internal(self, ReduceOp::Add, axes.into(), false, None, true)
179 }
180
181 #[builder]
195 #[track_caller]
196 pub fn sum_with(
197 &self,
198 axes: impl Into<AxisSpec>,
199 #[builder(default = false)] keepdim: bool,
200 dtype: Option<DType>,
201 #[builder(default = false)] promote: bool,
202 ) -> Result<Self> {
203 reduce_internal(self, ReduceOp::Add, axes.into(), keepdim, dtype, promote)
204 }
205
206 #[track_caller]
210 pub fn prod(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
211 reduce_internal(self, ReduceOp::Mul, axes.into(), false, None, false)
212 }
213
214 #[builder]
216 #[track_caller]
217 pub fn prod_with(
218 &self,
219 axes: impl Into<AxisSpec>,
220 #[builder(default = false)] keepdim: bool,
221 dtype: Option<DType>,
222 #[builder(default = false)] promote: bool,
223 ) -> Result<Self> {
224 reduce_internal(self, ReduceOp::Mul, axes.into(), keepdim, dtype, promote)
225 }
226
227 pub fn max(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
231 reduce_internal(self, ReduceOp::Max, axes.into(), false, None, false)
232 }
233
234 #[builder]
236 #[track_caller]
237 pub fn max_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
238 reduce_internal(self, ReduceOp::Max, axes.into(), keepdim, None, false)
239 }
240
241 #[track_caller]
245 pub fn min(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
246 reduce_internal(self, ReduceOp::Min, axes.into(), false, None, false)
247 }
248
249 #[builder]
251 #[track_caller]
252 pub fn min_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
253 reduce_internal(self, ReduceOp::Min, axes.into(), keepdim, None, false)
254 }
255
256 #[track_caller]
261 pub fn mean(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
262 mean_impl(self, axes.into(), false)
263 }
264
265 #[builder]
267 #[track_caller]
268 pub fn mean_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
269 mean_impl(self, axes, keepdim)
270 }
271
272 #[track_caller]
284 pub fn var(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
285 var_impl(self, axes.into(), false)
286 }
287
288 #[builder]
290 #[track_caller]
291 pub fn var_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
292 var_impl(self, axes.into(), keepdim)
293 }
294
295 #[track_caller]
307 pub fn std(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
308 std_impl(self, axes.into(), false)
309 }
310
311 #[builder]
313 #[track_caller]
314 pub fn std_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
315 std_impl(self, axes.into(), keepdim)
316 }
317
318 #[track_caller]
329 pub fn var_mean(&self, axes: impl Into<AxisSpec>) -> Result<(Self, Self)> {
330 var_mean_impl(self, axes.into(), false)
331 }
332
333 #[builder]
335 #[track_caller]
336 pub fn var_mean_with(
337 &self,
338 axes: impl Into<AxisSpec>,
339 #[builder(default = false)] keepdim: bool,
340 ) -> Result<(Self, Self)> {
341 var_mean_impl(self, axes.into(), keepdim)
342 }
343
344 #[track_caller]
355 pub fn std_mean(&self, axes: impl Into<AxisSpec>) -> Result<(Self, Self)> {
356 std_mean_impl(self, axes.into(), false)
357 }
358
359 #[builder]
361 #[track_caller]
362 pub fn std_mean_with(
363 &self,
364 axes: impl Into<AxisSpec>,
365 #[builder(default = false)] keepdim: bool,
366 ) -> Result<(Self, Self)> {
367 std_mean_impl(self, axes.into(), keepdim)
368 }
369
370 fn inverse(&self) -> Result<Self> {
376 let dtype = self.uop().dtype();
377 if dtype.is_float() {
378 self.try_neg()
379 } else if dtype.is_int() {
380 self.bitwise_not()
381 } else if matches!(dtype.scalar(), Some(ScalarDType::Bool)) {
382 self.logical_not()
383 } else {
384 Ok(self.clone()) }
386 }
387}
388
389#[bon]
394impl Tensor {
395 #[track_caller]
411 pub fn argmax(&self, axis: impl Into<Option<isize>>) -> Result<Self> {
412 argmax_impl(self, axis.into(), false)
413 }
414
415 #[builder]
417 #[track_caller]
418 pub fn argmax_with(
419 &self,
420 axis: impl Into<Option<isize>>,
421 #[builder(default = false)] keepdim: bool,
422 ) -> Result<Self> {
423 argmax_impl(self, axis.into(), keepdim)
424 }
425
426 #[track_caller]
431 pub fn hardmax(&self, axis: isize) -> Result<Self> {
432 let shape = self.shape()?;
433 let ndim = shape.len();
434 let norm_axis = Self::normalize_axis(axis, ndim)?;
435 let axis_size = shape[norm_axis].as_const().ok_or_else(|| crate::error::Error::SymbolicShapeUnsupported {
436 operation: format!("hardmax axis {norm_axis}"),
437 })?;
438 self.argmax_with()
439 .axis(Some(axis))
440 .keepdim(false)
441 .call()?
442 .try_unsqueeze(axis)?
443 .one_hot_along_dim(axis_size, axis)?
444 .cast(self.uop().dtype())
445 }
446
447 #[track_caller]
452 pub fn argmin(&self, axis: impl Into<Option<isize>>) -> Result<Self> {
453 argmin_impl(self, axis.into(), false)
454 }
455
456 #[builder]
458 #[track_caller]
459 pub fn argmin_with(
460 &self,
461 axis: impl Into<Option<isize>>,
462 #[builder(default = false)] keepdim: bool,
463 ) -> Result<Self> {
464 argmin_impl(self, axis.into(), keepdim)
465 }
466
467 #[track_caller]
480 pub fn any(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
481 any_impl(self, axes.into(), false)
482 }
483
484 #[builder]
486 #[track_caller]
487 pub fn any_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
488 any_impl(self, axes.into(), keepdim)
489 }
490
491 #[track_caller]
504 pub fn all(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
505 all_impl(self, axes.into(), false)
506 }
507
508 #[builder]
510 #[track_caller]
511 pub fn all_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
512 all_impl(self, axes.into(), keepdim)
513 }
514}
515
516fn argmax_impl(tensor: &Tensor, axis: Option<isize>, keepdim: bool) -> Result<Tensor> {
518 let (working_tensor, working_axis) =
520 if let Some(ax) = axis { (tensor.clone(), ax) } else { (tensor.flatten()?, 0) };
521
522 let shape = working_tensor.shape()?;
523 let normalized_axis = Tensor::normalize_axis(working_axis, shape.len())?;
524 let axis_size = shape[normalized_axis]
525 .as_const()
526 .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "argmax".to_string() })?;
527
528 let shape_vec = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
530
531 let max_vals_keepdim = working_tensor.max_with().axes(working_axis).keepdim(true).call()?;
533
534 let max_vals_broadcast = max_vals_keepdim.try_expand(&shape_vec)?;
537
538 let mask = working_tensor.try_eq(&max_vals_broadcast)?;
539
540 let indices = Tensor::arange(axis_size as i64, Some(0), Some(-1))?;
543
544 let mut idx_shape = vec![1isize; shape.len()];
547 idx_shape[normalized_axis] = axis_size as isize;
548 let indices_reshaped = indices.try_reshape(&idx_shape)?;
549
550 let indices_broadcast = indices_reshaped.try_expand(&shape_vec)?;
552
553 let mask_int = mask.cast(DType::Int32)?;
555 let masked_indices = mask_int.try_mul(&indices_broadcast)?;
556
557 let max_idx = masked_indices.max_with().axes(working_axis).keepdim(keepdim).call()?;
559
560 let n_tensor = Tensor::from_slice([axis_size as i32]);
562
563 let max_idx_shape = max_idx.shape()?;
565 let result = if !max_idx_shape.is_empty() {
566 let max_idx_shape_vec = svod_ir::shape::to_vec_isize(&max_idx_shape).context(UOpSnafu)?;
568 let ones_shape = vec![1isize; max_idx_shape.len()];
569 let n_reshaped = n_tensor.try_reshape(&ones_shape)?;
570 let n_broadcast = n_reshaped.try_expand(&max_idx_shape_vec)?;
571 n_broadcast.try_sub(&max_idx)?
572 } else {
573 let n_scalar = n_tensor.try_reshape(&[] as &[isize])?;
575 n_scalar.try_sub(&max_idx)?
576 };
577
578 result.cast(DType::Int32)
580}
581
582fn argmin_impl(tensor: &Tensor, axis: Option<isize>, keepdim: bool) -> Result<Tensor> {
584 let inverted = tensor.inverse()?;
586 argmax_impl(&inverted, axis, keepdim)
587}
588
589fn any_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
591 let as_bool = tensor.cast(DType::Bool)?;
593
594 reduce_internal(&as_bool, ReduceOp::Max, axes, keepdim, None, false)
596}
597
598fn all_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
600 let negated = tensor.logical_not()?;
602 let any_negated = any_impl(&negated, axes, keepdim)?;
603 any_negated.logical_not()
604}
605
606fn reduction_identity(op: ReduceOp, dtype: &DType) -> ConstValue {
608 let s = dtype.scalar().expect("scalar dtype");
609 match op {
610 ReduceOp::Add => ConstValue::zero(s),
611 ReduceOp::Mul => ConstValue::one(s),
612 ReduceOp::Max => ConstValue::min(s),
613 ReduceOp::Min => ConstValue::max(s),
614 }
615}
616
617#[track_caller]
619fn reduce_internal(
620 tensor: &Tensor,
621 op: ReduceOp,
622 axes: AxisSpec,
623 keepdim: bool,
624 dtype: Option<DType>,
625 promote: bool,
626) -> Result<Tensor> {
627 if dtype.is_some() && promote {
629 return Err(Error::ConflictingReductionOptions);
630 }
631
632 let shape = tensor.shape()?;
633 let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
634
635 let original_dtype = tensor.uop().dtype();
637 let acc_dtype = if let Some(ref dt) = dtype {
638 dt.clone()
640 } else if promote {
641 Tensor::sum_acc_dtype(&original_dtype)
643 } else {
644 original_dtype.clone()
646 };
647
648 let reducing_empty_axis = resolved_axes.iter().any(|&ax| shape[ax].as_const() == Some(0));
651 if reducing_empty_axis {
652 let out_shape: Vec<usize> = shape
654 .iter()
655 .enumerate()
656 .map(|(i, d)| if resolved_axes.contains(&i) { 1 } else { d.as_const().unwrap_or(1) })
657 .collect();
658
659 let identity = reduction_identity(op, &acc_dtype);
660 let result = Tensor::full(&out_shape, identity, acc_dtype)?;
661
662 let result = if !keepdim { result.remove_singleton_dims(&resolved_axes)? } else { result };
663
664 return if promote && dtype.is_none() && Tensor::should_cast_back_after_sum(&original_dtype) {
665 result.cast(original_dtype)
666 } else {
667 Ok(result)
668 };
669 }
670
671 let working_tensor = if acc_dtype != original_dtype { tensor.cast(acc_dtype.clone())? } else { tensor.clone() };
673
674 let reduced = working_tensor.uop().try_reduce_axis(op, resolved_axes.clone()).context(UOpSnafu)?;
676
677 let result = if keepdim {
679 Tensor::new(reduced)
680 } else {
681 let temp = Tensor::new(reduced);
682 temp.remove_singleton_dims(&resolved_axes)?
683 };
684
685 if promote && dtype.is_none() && Tensor::should_cast_back_after_sum(&original_dtype) {
687 result.cast(original_dtype)
688 } else {
689 Ok(result)
690 }
691}
692
693fn mean_impl(tensor: &Tensor, axes: impl Into<AxisSpec>, keepdim: bool) -> Result<Tensor> {
695 let axes = axes.into();
696 let shape = tensor.shape()?;
697 let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
698
699 let mut count = 1i64;
701 for &axis in &resolved_axes {
702 if let Some(dim_size) = shape[axis].as_const() {
703 count *= dim_size as i64;
704 } else {
705 return SymbolicShapeUnsupportedSnafu { operation: "mean" }.fail();
706 }
707 }
708
709 let dtype = tensor.uop().dtype();
711 let output_dtype = if Tensor::is_integer_dtype(&dtype) { DType::Float32 } else { dtype };
712
713 let sum = reduce_internal(tensor, ReduceOp::Add, axes, keepdim, Some(output_dtype.clone()), false)?;
715
716 let count_tensor = Tensor::new(UOp::const_(output_dtype.clone(), svod_ir::ConstValue::Float(count as f64)));
718 Ok(&sum / &count_tensor)
719}
720
721fn var_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
723 let (var, _mean) = var_mean_impl(tensor, axes, keepdim)?;
724 Ok(var)
725}
726
727fn std_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
729 let variance = var_impl(tensor, axes, keepdim)?;
730 variance.try_sqrt()
731}
732
733fn var_mean_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<(Tensor, Tensor)> {
735 let shape = tensor.shape()?;
736 let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
737
738 let mut count = 1i64;
740 for &axis in &resolved_axes {
741 if let Some(dim_size) = shape[axis].as_const() {
742 count *= dim_size as i64;
743 } else {
744 return SymbolicShapeUnsupportedSnafu { operation: "variance" }.fail();
745 }
746 }
747
748 let dtype = tensor.uop().dtype();
750 let output_dtype = if Tensor::is_integer_dtype(&dtype) { DType::Float32 } else { dtype.clone() };
751
752 let mean = mean_impl(tensor, axes.clone(), keepdim)?;
754
755 let deviation = if keepdim {
758 tensor.try_sub(&mean)?
759 } else {
760 let mut expanded_mean = mean.clone();
762 for &axis in &resolved_axes {
763 expanded_mean = expanded_mean.try_unsqueeze(axis as isize)?;
764 }
765 tensor.try_sub(&expanded_mean)?
766 };
767
768 let squared_dev = deviation.square()?;
770
771 let sum_sq_dev = reduce_internal(&squared_dev, ReduceOp::Add, axes, keepdim, Some(output_dtype.clone()), false)?;
773
774 let denom = if count > 1 { count - 1 } else { count };
776 let denom_tensor = Tensor::new(UOp::const_(output_dtype, svod_ir::ConstValue::Float(denom as f64)));
777 let variance = &sum_sq_dev / &denom_tensor;
778
779 Ok((variance, mean))
780}
781
782fn std_mean_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<(Tensor, Tensor)> {
784 let (variance, mean) = var_mean_impl(tensor, axes, keepdim)?;
785 let std = variance.try_sqrt()?;
786 Ok((std, mean))
787}