Skip to main content

svod_tensor/
reduce.rs

1//! Reduction operations for tensors.
2//!
3//! This module provides reduction operations like sum, max, min, prod, and mean
4//! with ergonomic APIs that match PyTorch/NumPy conventions.
5
6use bon::bon;
7use snafu::ResultExt;
8use svod_dtype::{DType, ScalarDType};
9use svod_ir::{ConstValue, ReduceOp, SInt, UOp};
10
11use crate::{
12    Error, Result, Tensor,
13    error::{SymbolicShapeUnsupportedSnafu, UOpSnafu},
14};
15
/// Specification for reduction axes.
///
/// Values are normally produced through `Into<AxisSpec>`:
/// - `()` → [`AxisSpec::All`]
/// - `isize` → [`AxisSpec::Single`]
/// - `&[isize]`, `Vec<isize>`, `[isize; N]`, `&[isize; N]` → [`AxisSpec::Multiple`]
///
/// The array conversions exist so that literals such as `[0, 2]` or `&[0, 2]` can be
/// passed directly to functions taking `impl Into<AxisSpec>`; a `&[isize; N]` does not
/// coerce to `&[isize]` through a generic `Into` bound.
#[derive(Debug, Clone)]
pub enum AxisSpec {
    /// Reduce all axes (produces scalar).
    All,
    /// Reduce a single axis (supports negative indexing).
    Single(isize),
    /// Reduce multiple axes (each supports negative indexing).
    Multiple(Vec<isize>),
}

// Ergonomic Into conversions for AxisSpec
impl From<()> for AxisSpec {
    /// `()` selects every axis.
    fn from(_: ()) -> Self {
        Self::All
    }
}

impl From<isize> for AxisSpec {
    /// A bare index selects that one axis.
    fn from(axis: isize) -> Self {
        Self::Single(axis)
    }
}

impl From<&[isize]> for AxisSpec {
    /// A slice selects several axes; the indices are copied.
    fn from(axes: &[isize]) -> Self {
        Self::Multiple(axes.to_vec())
    }
}

impl From<Vec<isize>> for AxisSpec {
    /// A vector selects several axes; the vector is moved in without copying.
    fn from(axes: Vec<isize>) -> Self {
        Self::Multiple(axes)
    }
}

impl<const N: usize> From<[isize; N]> for AxisSpec {
    /// A fixed-size array selects several axes (e.g. the literal `[0, 2]`).
    fn from(axes: [isize; N]) -> Self {
        Self::Multiple(axes.to_vec())
    }
}

impl<const N: usize> From<&[isize; N]> for AxisSpec {
    /// A borrowed fixed-size array selects several axes (e.g. the literal `&[0, 2]`).
    fn from(axes: &[isize; N]) -> Self {
        Self::Multiple(axes.to_vec())
    }
}
56
57// =========================================================================
58// Tensor Reduction Methods
59// =========================================================================
60
61impl Tensor {
62    /// Resolve axis specification to normalized axis indices.
63    ///
64    /// Handles:
65    /// - `AxisSpec::All` → all axes (0..ndim)
66    /// - Single/multiple axes → normalize negative indices
67    /// - Deduplication
68    /// - Bounds checking
69    pub(crate) fn resolve_axis_spec(spec: &AxisSpec, ndim: usize) -> Result<Vec<usize>> {
70        match spec {
71            AxisSpec::All => Ok((0..ndim).collect()),
72            AxisSpec::Single(axis) => {
73                let normalized = Self::normalize_axis(*axis, ndim)?;
74                Ok(vec![normalized])
75            }
76            AxisSpec::Multiple(axes) => {
77                let mut normalized: Vec<usize> =
78                    axes.iter().map(|&axis| Self::normalize_axis(axis, ndim)).collect::<Result<_>>()?;
79
80                // Deduplicate axes (keep first occurrence)
81                normalized.sort_unstable();
82                normalized.dedup();
83
84                Ok(normalized)
85            }
86        }
87    }
88
89    /// Get accumulation dtype for sum operations (Tinygrad-compatible).
90    ///
91    /// Used when `promote=true` in reduction builders.
92    ///
93    /// Promotion rules:
94    /// - int8, int16 → int32
95    /// - int32, int64 → preserve
96    /// - uint8, uint16 → uint32
97    /// - uint32, uint64 → preserve
98    /// - float16, bfloat16 → float32 (for accumulation)
99    /// - float32, float64 → preserve
100    /// - bool → int32
101    pub(crate) fn sum_acc_dtype(dtype: &DType) -> DType {
102        use ScalarDType::*;
103        let Some(scalar) = dtype.scalar() else {
104            return dtype.clone();
105        };
106
107        match scalar {
108            Bool => DType::Int32,
109            Int8 | Int16 => DType::Int32,
110            Int32 | Int64 => dtype.clone(),
111            UInt8 | UInt16 => DType::UInt32,
112            UInt32 | UInt64 => dtype.clone(),
113            Float16 | BFloat16 | FP8E4M3 | FP8E5M2 => DType::Float32,
114            Float32 | Float64 => dtype.clone(),
115            Void | Index => dtype.clone(),
116        }
117    }
118
119    /// Check if dtype should be cast back after sum accumulation.
120    ///
121    /// Tinygrad casts back to original dtype for:
122    /// - float16
123    /// - bfloat16
124    /// - fp8 variants
125    fn should_cast_back_after_sum(dtype: &DType) -> bool {
126        matches!(
127            dtype.scalar(),
128            Some(ScalarDType::Float16 | ScalarDType::BFloat16 | ScalarDType::FP8E4M3 | ScalarDType::FP8E5M2)
129        )
130    }
131
132    /// Check if dtype is an integer or bool type.
133    fn is_integer_dtype(dtype: &DType) -> bool {
134        dtype.is_int() || matches!(dtype.scalar(), Some(ScalarDType::Bool))
135    }
136
137    /// Remove singleton dimensions from reduced axes when keepdim=false.
138    ///
139    /// Example:
140    /// - shape [2, 3, 4], reduced axes [0, 2] → shape [2, 1, 4]
141    /// - keepdim=false → reshape to [3]
142    fn remove_singleton_dims(self, reduced_axes: &[usize]) -> Result<Self> {
143        let shape = self.shape()?;
144
145        // Build new shape by filtering out size-1 dimensions that were reduced
146        let new_shape: Vec<SInt> = shape
147            .iter()
148            .enumerate()
149            .filter_map(|(i, dim)| {
150                // Only keep non-reduced axes, or reduced axes that aren't size 1
151                if reduced_axes.contains(&i) {
152                    None // Remove this dimension
153                } else {
154                    Some(dim.clone())
155                }
156            })
157            .collect();
158
159        // If all dimensions were reduced, result is scalar (shape [])
160        if new_shape.is_empty() {
161            // For scalar result, reshape to shape [] (0-d tensor)
162            // IR reshape expects same product, so [] → [] is valid
163            self.try_reshape(std::iter::empty::<SInt>())
164        } else {
165            self.try_reshape(&new_shape)
166        }
167    }
168}
169
170#[bon]
171impl Tensor {
172    /// Sum of tensor elements over given axes.
173    ///
174    /// Auto-promotes accumulation dtype (bool→int32, float16→float32) like Tinygrad.
175    /// Use `sum_with().promote(false)` to preserve input dtype.
176    #[track_caller]
177    pub fn sum(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
178        reduce_internal(self, ReduceOp::Add, axes.into(), false, None, true)
179    }
180
181    /// Sum with additional options (keepdim, dtype, promote).
182    ///
183    /// # Examples
184    /// ```ignore
185    /// // Explicit dtype
186    /// tensor.sum_with(0).dtype(DType::Float32).call()?;
187    ///
188    /// // Auto-promote (int8→int32, etc.)
189    /// tensor.sum_with(0).promote(true).call()?;
190    ///
191    /// // With keepdim
192    /// tensor.sum_with(0).keepdim(true).call()?;
193    /// ```
194    #[builder]
195    #[track_caller]
196    pub fn sum_with(
197        &self,
198        axes: impl Into<AxisSpec>,
199        #[builder(default = false)] keepdim: bool,
200        dtype: Option<DType>,
201        #[builder(default = false)] promote: bool,
202    ) -> Result<Self> {
203        reduce_internal(self, ReduceOp::Add, axes.into(), keepdim, dtype, promote)
204    }
205
206    /// Product of tensor elements over given axes.
207    ///
208    /// Preserves input dtype. Use `prod_with().promote(true)` or `.dtype(...)` for different accumulation.
209    #[track_caller]
210    pub fn prod(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
211        reduce_internal(self, ReduceOp::Mul, axes.into(), false, None, false)
212    }
213
214    /// Product with additional options (keepdim, dtype, promote).
215    #[builder]
216    #[track_caller]
217    pub fn prod_with(
218        &self,
219        axes: impl Into<AxisSpec>,
220        #[builder(default = false)] keepdim: bool,
221        dtype: Option<DType>,
222        #[builder(default = false)] promote: bool,
223    ) -> Result<Self> {
224        reduce_internal(self, ReduceOp::Mul, axes.into(), keepdim, dtype, promote)
225    }
226
227    /// Maximum of tensor elements over given axes.
228    ///
229    /// Always preserves input dtype.
230    pub fn max(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
231        reduce_internal(self, ReduceOp::Max, axes.into(), false, None, false)
232    }
233
234    /// Maximum with keepdim option.
235    #[builder]
236    #[track_caller]
237    pub fn max_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
238        reduce_internal(self, ReduceOp::Max, axes.into(), keepdim, None, false)
239    }
240
241    /// Minimum of tensor elements over given axes.
242    ///
243    /// Always preserves input dtype.
244    #[track_caller]
245    pub fn min(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
246        reduce_internal(self, ReduceOp::Min, axes.into(), false, None, false)
247    }
248
249    /// Minimum with keepdim option.
250    #[builder]
251    #[track_caller]
252    pub fn min_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
253        reduce_internal(self, ReduceOp::Min, axes.into(), keepdim, None, false)
254    }
255
256    /// Mean of tensor elements over given axes.
257    ///
258    /// For integer inputs, automatically uses float32 accumulation.
259    /// For float inputs, preserves input dtype.
260    #[track_caller]
261    pub fn mean(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
262        mean_impl(self, axes.into(), false)
263    }
264
265    /// Mean with keepdim option.
266    #[builder]
267    #[track_caller]
268    pub fn mean_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
269        mean_impl(self, axes, keepdim)
270    }
271
272    /// Variance of tensor elements over given axes.
273    ///
274    /// Computes unbiased sample variance (divides by N-1).
275    /// For integer inputs, automatically uses float32 accumulation.
276    /// For float inputs, preserves input dtype.
277    ///
278    /// # Examples
279    /// ```ignore
280    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
281    /// let v = t.var(())?;  // Variance over all elements
282    /// ```
283    #[track_caller]
284    pub fn var(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
285        var_impl(self, axes.into(), false)
286    }
287
288    /// Variance with keepdim option.
289    #[builder]
290    #[track_caller]
291    pub fn var_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
292        var_impl(self, axes.into(), keepdim)
293    }
294
295    /// Standard deviation of tensor elements over given axes.
296    ///
297    /// Computes unbiased sample standard deviation (divides by N-1).
298    /// For integer inputs, automatically uses float32 accumulation.
299    /// For float inputs, preserves input dtype.
300    ///
301    /// # Examples
302    /// ```ignore
303    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
304    /// let s = t.std(())?;  // Std dev over all elements
305    /// ```
306    #[track_caller]
307    pub fn std(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
308        std_impl(self, axes.into(), false)
309    }
310
311    /// Standard deviation with keepdim option.
312    #[builder]
313    #[track_caller]
314    pub fn std_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
315        std_impl(self, axes.into(), keepdim)
316    }
317
318    /// Variance and mean of tensor elements over given axes.
319    ///
320    /// Returns (variance, mean) tuple. More efficient than computing separately.
321    /// Computes unbiased sample variance (divides by N-1).
322    ///
323    /// # Examples
324    /// ```ignore
325    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
326    /// let (v, m) = t.var_mean(())?;
327    /// ```
328    #[track_caller]
329    pub fn var_mean(&self, axes: impl Into<AxisSpec>) -> Result<(Self, Self)> {
330        var_mean_impl(self, axes.into(), false)
331    }
332
333    /// Variance and mean with keepdim option.
334    #[builder]
335    #[track_caller]
336    pub fn var_mean_with(
337        &self,
338        axes: impl Into<AxisSpec>,
339        #[builder(default = false)] keepdim: bool,
340    ) -> Result<(Self, Self)> {
341        var_mean_impl(self, axes.into(), keepdim)
342    }
343
344    /// Standard deviation and mean of tensor elements over given axes.
345    ///
346    /// Returns (std, mean) tuple. More efficient than computing separately.
347    /// Computes unbiased sample standard deviation (divides by N-1).
348    ///
349    /// # Examples
350    /// ```ignore
351    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
352    /// let (s, m) = t.std_mean(())?;
353    /// ```
354    #[track_caller]
355    pub fn std_mean(&self, axes: impl Into<AxisSpec>) -> Result<(Self, Self)> {
356        std_mean_impl(self, axes.into(), false)
357    }
358
359    /// Standard deviation and mean with keepdim option.
360    #[builder]
361    #[track_caller]
362    pub fn std_mean_with(
363        &self,
364        axes: impl Into<AxisSpec>,
365        #[builder(default = false)] keepdim: bool,
366    ) -> Result<(Self, Self)> {
367        std_mean_impl(self, axes.into(), keepdim)
368    }
369
370    /// Internal helper: inverse of tensor for argmin.
371    ///
372    /// - Float dtypes: -self
373    /// - Integer dtypes: ~self (bitwise NOT)
374    /// - Bool dtype: logical_not(self)
375    fn inverse(&self) -> Result<Self> {
376        let dtype = self.uop().dtype();
377        if dtype.is_float() {
378            self.try_neg()
379        } else if dtype.is_int() {
380            self.bitwise_not()
381        } else if matches!(dtype.scalar(), Some(ScalarDType::Bool)) {
382            self.logical_not()
383        } else {
384            Ok(self.clone()) // Fallback for other types
385        }
386    }
387}
388
389// =========================================================================
390// Argmax / Argmin Operations
391// =========================================================================
392
393#[bon]
394impl Tensor {
395    /// Index of maximum value along axis.
396    ///
397    /// Returns int32 tensor with indices of maximum values.
398    /// For ties, returns the index of the first occurrence.
399    ///
400    /// # Arguments
401    /// * `axis` - Axis to reduce (None = flatten first)
402    ///
403    /// # Examples
404    /// ```ignore
405    /// let t = Tensor::from_slice(&[[1.0, 3.0, 2.0], [4.0, 2.0, 5.0]]);
406    /// t.argmax(None)?;      // 5 (flattened: max is at index 5)
407    /// t.argmax(Some(0))?;   // [1, 0, 1] (row indices of max per column)
408    /// t.argmax(Some(1))?;   // [1, 2] (column indices of max per row)
409    /// ```
410    #[track_caller]
411    pub fn argmax(&self, axis: impl Into<Option<isize>>) -> Result<Self> {
412        argmax_impl(self, axis.into(), false)
413    }
414
415    /// Argmax with keepdim option.
416    #[builder]
417    #[track_caller]
418    pub fn argmax_with(
419        &self,
420        axis: impl Into<Option<isize>>,
421        #[builder(default = false)] keepdim: bool,
422    ) -> Result<Self> {
423        argmax_impl(self, axis.into(), keepdim)
424    }
425
426    /// Hard maximum: one-hot encoding of the argmax along an axis.
427    ///
428    /// Returns a tensor of the same shape with 1.0 at the position of the
429    /// maximum value along `axis` and 0.0 elsewhere, cast to the input dtype.
430    #[track_caller]
431    pub fn hardmax(&self, axis: isize) -> Result<Self> {
432        let shape = self.shape()?;
433        let ndim = shape.len();
434        let norm_axis = Self::normalize_axis(axis, ndim)?;
435        let axis_size = shape[norm_axis].as_const().ok_or_else(|| crate::error::Error::SymbolicShapeUnsupported {
436            operation: format!("hardmax axis {norm_axis}"),
437        })?;
438        self.argmax_with()
439            .axis(Some(axis))
440            .keepdim(false)
441            .call()?
442            .try_unsqueeze(axis)?
443            .one_hot_along_dim(axis_size, axis)?
444            .cast(self.uop().dtype())
445    }
446
447    /// Index of minimum value along axis.
448    ///
449    /// Returns int32 tensor with indices of minimum values.
450    /// For ties, returns the index of the first occurrence.
451    #[track_caller]
452    pub fn argmin(&self, axis: impl Into<Option<isize>>) -> Result<Self> {
453        argmin_impl(self, axis.into(), false)
454    }
455
456    /// Argmin with keepdim option.
457    #[builder]
458    #[track_caller]
459    pub fn argmin_with(
460        &self,
461        axis: impl Into<Option<isize>>,
462        #[builder(default = false)] keepdim: bool,
463    ) -> Result<Self> {
464        argmin_impl(self, axis.into(), keepdim)
465    }
466
467    /// Test if any element is true along axes.
468    ///
469    /// Logical OR reduction. Returns bool dtype.
470    /// Non-zero values are treated as true.
471    ///
472    /// # Examples
473    /// ```ignore
474    /// let t = Tensor::from_slice(&[[true, false], [false, false]]);
475    /// t.any(())?;           // true (any element is true)
476    /// t.any(0)?;            // [true, false] (any true per column)
477    /// t.any(1)?;            // [true, false] (any true per row)
478    /// ```
479    #[track_caller]
480    pub fn any(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
481        any_impl(self, axes.into(), false)
482    }
483
484    /// Any with keepdim option.
485    #[builder]
486    #[track_caller]
487    pub fn any_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
488        any_impl(self, axes.into(), keepdim)
489    }
490
491    /// Test if all elements are true along axes.
492    ///
493    /// Logical AND reduction. Returns bool dtype.
494    /// Non-zero values are treated as true.
495    ///
496    /// # Examples
497    /// ```ignore
498    /// let t = Tensor::from_slice(&[[true, true], [true, false]]);
499    /// t.all(())?;           // false (not all elements are true)
500    /// t.all(0)?;            // [true, false] (all true per column)
501    /// t.all(1)?;            // [true, false] (all true per row)
502    /// ```
503    #[track_caller]
504    pub fn all(&self, axes: impl Into<AxisSpec>) -> Result<Self> {
505        all_impl(self, axes.into(), false)
506    }
507
508    /// All with keepdim option.
509    #[builder]
510    #[track_caller]
511    pub fn all_with(&self, axes: impl Into<AxisSpec>, #[builder(default = false)] keepdim: bool) -> Result<Self> {
512        all_impl(self, axes.into(), keepdim)
513    }
514}
515
516/// Internal argmax implementation.
517fn argmax_impl(tensor: &Tensor, axis: Option<isize>, keepdim: bool) -> Result<Tensor> {
518    // Handle None axis: flatten and call argmax on axis 0
519    let (working_tensor, working_axis) =
520        if let Some(ax) = axis { (tensor.clone(), ax) } else { (tensor.flatten()?, 0) };
521
522    let shape = working_tensor.shape()?;
523    let normalized_axis = Tensor::normalize_axis(working_axis, shape.len())?;
524    let axis_size = shape[normalized_axis]
525        .as_const()
526        .ok_or_else(|| Error::SymbolicShapeUnsupported { operation: "argmax".to_string() })?;
527
528    // Convert shape to isize vec once for reuse in expand operations
529    let shape_vec = svod_ir::shape::to_vec_isize(&shape).context(UOpSnafu)?;
530
531    // Step 1: Find maximum values along axis (with keepdim for broadcasting)
532    let max_vals_keepdim = working_tensor.max_with().axes(working_axis).keepdim(true).call()?;
533
534    // Step 2: Create mask where values equal the max
535    // Need to broadcast max_vals to match working_tensor shape
536    let max_vals_broadcast = max_vals_keepdim.try_expand(&shape_vec)?;
537
538    let mask = working_tensor.try_eq(&max_vals_broadcast)?;
539
540    // Step 3: Create descending index tensor [N, N-1, ..., 1]
541    // This ensures ties go to first occurrence
542    let indices = Tensor::arange(axis_size as i64, Some(0), Some(-1))?;
543
544    // Step 4: Reshape indices to broadcast along the target axis
545    // E.g., for axis=1 with 3D tensor: [axis_size] -> [1, axis_size, 1]
546    let mut idx_shape = vec![1isize; shape.len()];
547    idx_shape[normalized_axis] = axis_size as isize;
548    let indices_reshaped = indices.try_reshape(&idx_shape)?;
549
550    // Expand indices to match working_tensor shape
551    let indices_broadcast = indices_reshaped.try_expand(&shape_vec)?;
552
553    // Step 5: Multiply mask by indices (0 where not max, index where max)
554    let mask_int = mask.cast(DType::Int32)?;
555    let masked_indices = mask_int.try_mul(&indices_broadcast)?;
556
557    // Step 6: Take max of masked indices (gives highest index, which is first occurrence)
558    let max_idx = masked_indices.max_with().axes(working_axis).keepdim(keepdim).call()?;
559
560    // Step 7: Invert: N - max_idx gives actual index
561    let n_tensor = Tensor::from_slice([axis_size as i32]);
562
563    // Broadcast n_tensor to match max_idx shape if needed
564    let max_idx_shape = max_idx.shape()?;
565    let result = if !max_idx_shape.is_empty() {
566        // Non-scalar result: broadcast n_tensor
567        let max_idx_shape_vec = svod_ir::shape::to_vec_isize(&max_idx_shape).context(UOpSnafu)?;
568        let ones_shape = vec![1isize; max_idx_shape.len()];
569        let n_reshaped = n_tensor.try_reshape(&ones_shape)?;
570        let n_broadcast = n_reshaped.try_expand(&max_idx_shape_vec)?;
571        n_broadcast.try_sub(&max_idx)?
572    } else {
573        // Scalar result: reshape n_tensor to scalar too
574        let n_scalar = n_tensor.try_reshape(&[] as &[isize])?;
575        n_scalar.try_sub(&max_idx)?
576    };
577
578    // Cast final result to Int32 (like Tinygrad)
579    result.cast(DType::Int32)
580}
581
582/// Internal argmin implementation.
583fn argmin_impl(tensor: &Tensor, axis: Option<isize>, keepdim: bool) -> Result<Tensor> {
584    // Argmin is just argmax of inverted values
585    let inverted = tensor.inverse()?;
586    argmax_impl(&inverted, axis, keepdim)
587}
588
589/// Internal any implementation.
590fn any_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
591    // Cast to bool (non-zero becomes true)
592    let as_bool = tensor.cast(DType::Bool)?;
593
594    // Max reduction on bool is logical OR
595    reduce_internal(&as_bool, ReduceOp::Max, axes, keepdim, None, false)
596}
597
598/// Internal all implementation.
599fn all_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
600    // De Morgan's law: all(x) = !any(!x)
601    let negated = tensor.logical_not()?;
602    let any_negated = any_impl(&negated, axes, keepdim)?;
603    any_negated.logical_not()
604}
605
606/// Identity element for a reduction over an empty set (matching Tinygrad's `identity_element`).
607fn reduction_identity(op: ReduceOp, dtype: &DType) -> ConstValue {
608    let s = dtype.scalar().expect("scalar dtype");
609    match op {
610        ReduceOp::Add => ConstValue::zero(s),
611        ReduceOp::Mul => ConstValue::one(s),
612        ReduceOp::Max => ConstValue::min(s),
613        ReduceOp::Min => ConstValue::max(s),
614    }
615}
616
617/// Internal reduction implementation.
618#[track_caller]
619fn reduce_internal(
620    tensor: &Tensor,
621    op: ReduceOp,
622    axes: AxisSpec,
623    keepdim: bool,
624    dtype: Option<DType>,
625    promote: bool,
626) -> Result<Tensor> {
627    // Validate conflicting options
628    if dtype.is_some() && promote {
629        return Err(Error::ConflictingReductionOptions);
630    }
631
632    let shape = tensor.shape()?;
633    let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
634
635    // Determine accumulation dtype
636    let original_dtype = tensor.uop().dtype();
637    let acc_dtype = if let Some(ref dt) = dtype {
638        // Explicit dtype takes precedence
639        dt.clone()
640    } else if promote {
641        // Auto-promote using sum_acc_dtype
642        Tensor::sum_acc_dtype(&original_dtype)
643    } else {
644        // Preserve input dtype
645        original_dtype.clone()
646    };
647
648    // Handle zero-sized dimensions: short-circuit to identity element to avoid
649    // DivisionByZero in indexing (matching Tinygrad rangeify.py:115-120).
650    let reducing_empty_axis = resolved_axes.iter().any(|&ax| shape[ax].as_const() == Some(0));
651    if reducing_empty_axis {
652        // Compute output shape: reduced axes become 1
653        let out_shape: Vec<usize> = shape
654            .iter()
655            .enumerate()
656            .map(|(i, d)| if resolved_axes.contains(&i) { 1 } else { d.as_const().unwrap_or(1) })
657            .collect();
658
659        let identity = reduction_identity(op, &acc_dtype);
660        let result = Tensor::full(&out_shape, identity, acc_dtype)?;
661
662        let result = if !keepdim { result.remove_singleton_dims(&resolved_axes)? } else { result };
663
664        return if promote && dtype.is_none() && Tensor::should_cast_back_after_sum(&original_dtype) {
665            result.cast(original_dtype)
666        } else {
667            Ok(result)
668        };
669    }
670
671    // Cast to accumulation dtype if needed
672    let working_tensor = if acc_dtype != original_dtype { tensor.cast(acc_dtype.clone())? } else { tensor.clone() };
673
674    // Perform reduction
675    let reduced = working_tensor.uop().try_reduce_axis(op, resolved_axes.clone()).context(UOpSnafu)?;
676
677    // Handle keepdim
678    let result = if keepdim {
679        Tensor::new(reduced)
680    } else {
681        let temp = Tensor::new(reduced);
682        temp.remove_singleton_dims(&resolved_axes)?
683    };
684
685    // Cast back if promoted and should cast back (fp16/bf16)
686    if promote && dtype.is_none() && Tensor::should_cast_back_after_sum(&original_dtype) {
687        result.cast(original_dtype)
688    } else {
689        Ok(result)
690    }
691}
692
693/// Mean implementation (shared by mean and mean_with).
694fn mean_impl(tensor: &Tensor, axes: impl Into<AxisSpec>, keepdim: bool) -> Result<Tensor> {
695    let axes = axes.into();
696    let shape = tensor.shape()?;
697    let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
698
699    // Calculate count of reduced elements
700    let mut count = 1i64;
701    for &axis in &resolved_axes {
702        if let Some(dim_size) = shape[axis].as_const() {
703            count *= dim_size as i64;
704        } else {
705            return SymbolicShapeUnsupportedSnafu { operation: "mean" }.fail();
706        }
707    }
708
709    // Determine output dtype (integers → float32, floats preserve)
710    let dtype = tensor.uop().dtype();
711    let output_dtype = if Tensor::is_integer_dtype(&dtype) { DType::Float32 } else { dtype };
712
713    // Sum with explicit accumulation dtype (no promotion needed, dtype is explicit)
714    let sum = reduce_internal(tensor, ReduceOp::Add, axes, keepdim, Some(output_dtype.clone()), false)?;
715
716    // Divide by count
717    let count_tensor = Tensor::new(UOp::const_(output_dtype.clone(), svod_ir::ConstValue::Float(count as f64)));
718    Ok(&sum / &count_tensor)
719}
720
721/// Variance implementation using E[X²] - E[X]² formula.
722fn var_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
723    let (var, _mean) = var_mean_impl(tensor, axes, keepdim)?;
724    Ok(var)
725}
726
727/// Standard deviation implementation.
728fn std_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<Tensor> {
729    let variance = var_impl(tensor, axes, keepdim)?;
730    variance.try_sqrt()
731}
732
733/// Variance and mean implementation using single-pass algorithm.
734fn var_mean_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<(Tensor, Tensor)> {
735    let shape = tensor.shape()?;
736    let resolved_axes = Tensor::resolve_axis_spec(&axes, shape.len())?;
737
738    // Calculate count of reduced elements
739    let mut count = 1i64;
740    for &axis in &resolved_axes {
741        if let Some(dim_size) = shape[axis].as_const() {
742            count *= dim_size as i64;
743        } else {
744            return SymbolicShapeUnsupportedSnafu { operation: "variance" }.fail();
745        }
746    }
747
748    // Determine output dtype (integers → float32, floats preserve)
749    let dtype = tensor.uop().dtype();
750    let output_dtype = if Tensor::is_integer_dtype(&dtype) { DType::Float32 } else { dtype.clone() };
751
752    // Compute mean: E[X]
753    let mean = mean_impl(tensor, axes.clone(), keepdim)?;
754
755    // Compute deviation from mean: X - E[X]
756    // Need to broadcast mean if keepdim=false
757    let deviation = if keepdim {
758        tensor.try_sub(&mean)?
759    } else {
760        // Expand mean back to original shape for subtraction
761        let mut expanded_mean = mean.clone();
762        for &axis in &resolved_axes {
763            expanded_mean = expanded_mean.try_unsqueeze(axis as isize)?;
764        }
765        tensor.try_sub(&expanded_mean)?
766    };
767
768    // Square the deviations: (X - E[X])²
769    let squared_dev = deviation.square()?;
770
771    // Sum squared deviations with explicit dtype
772    let sum_sq_dev = reduce_internal(&squared_dev, ReduceOp::Add, axes, keepdim, Some(output_dtype.clone()), false)?;
773
774    // Divide by N-1 for unbiased estimate (Bessel's correction)
775    let denom = if count > 1 { count - 1 } else { count };
776    let denom_tensor = Tensor::new(UOp::const_(output_dtype, svod_ir::ConstValue::Float(denom as f64)));
777    let variance = &sum_sq_dev / &denom_tensor;
778
779    Ok((variance, mean))
780}
781
782/// Standard deviation and mean implementation.
783fn std_mean_impl(tensor: &Tensor, axes: AxisSpec, keepdim: bool) -> Result<(Tensor, Tensor)> {
784    let (variance, mean) = var_mean_impl(tensor, axes, keepdim)?;
785    let std = variance.try_sqrt()?;
786    Ok((std, mean))
787}