train_station/tensor/reductions/mean.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Computes the arithmetic mean of all elements in the tensor
6    ///
7    /// This method calculates the average value across all tensor elements by summing
8    /// all values and dividing by the total number of elements. The result is a scalar
9    /// tensor containing the mean value. This operation supports gradient tracking
10    /// through the GradTrack system.
11    ///
12    /// # Returns
13    ///
14    /// A tensor with shape `[1]` containing the arithmetic mean of all elements.
15    /// For empty tensors, returns `0.0` as a safe default.
16    ///
17    /// # Performance Characteristics
18    ///
19    /// - **Linear Time**: O(n) complexity for computing the sum
20    /// - **Memory Efficient**: Single pass through tensor data with SIMD-optimized accumulation
21    /// - **Numerical Stability**: Uses direct accumulation for typical tensor sizes
22    /// - **Edge Case Handling**: Returns 0.0 for empty tensors
23    ///
24    /// # Examples
25    ///
26    /// ```
27    /// use train_station::Tensor;
28    ///
29    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30    /// let mean_val = tensor.mean();
31    /// assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
32    /// ```
33    ///
34    /// ```
35    /// use train_station::Tensor;
36    ///
37    /// // Empty tensor case
38    /// let empty_tensor = Tensor::new(vec![0]);
39    /// let mean_val = empty_tensor.mean();
40    /// assert_eq!(mean_val.get(&[0]), 0.0);
41    /// ```
42    ///
43    /// # GradTrack Support
44    ///
45    /// When `requires_grad` is true, this operation is tracked for automatic
46    /// differentiation. The gradient computation distributes the gradient equally
47    /// across all input elements.
48    pub fn mean(&self) -> Tensor {
49        let mut out = Tensor::new(vec![1]);
50        if self.size() == 0 {
51            // Convention: mean over empty returns 0.0 (aligns with safe behavior for now)
52            out.fill(0.0);
53        } else {
54            let mut acc0 = 0.0f32;
55
56            if self.is_contiguous() {
57                // Fast path for contiguous tensors
58                unsafe {
59                    let src = self.as_ptr();
60                    let size = self.size();
61                    let mut i = 0usize;
62                    while i + 4 <= size {
63                        let x0 = *src.add(i);
64                        let x1 = *src.add(i + 1);
65                        let x2 = *src.add(i + 2);
66                        let x3 = *src.add(i + 3);
67                        acc0 += x0 + x1 + x2 + x3;
68                        i += 4;
69                    }
70                    while i < size {
71                        acc0 += *src.add(i);
72                        i += 1;
73                    }
74                }
75            } else {
76                // Stride-aware path for non-contiguous tensors
77                let dims = self.shape().dims.clone();
78                for flat_idx in 0..self.size() {
79                    // Convert flat index to multi-dimensional coordinates
80                    let mut coords = vec![0; dims.len()];
81                    let mut tmp = flat_idx;
82                    for k in (0..dims.len()).rev() {
83                        coords[k] = tmp % dims[k];
84                        tmp /= dims[k];
85                    }
86
87                    // Get value using stride-aware offset
88                    let offset = self.shape().offset(&coords);
89                    let value = unsafe { *self.as_ptr().add(offset) };
90                    acc0 += value;
91                }
92            }
93
94            unsafe {
95                *out.as_mut_ptr() = acc0 / (self.size() as f32);
96            }
97        }
98
99        if self.requires_grad() {
100            out.set_requires_grad_internal(true);
101            let grad_fn = GradFn::ReduceMean {
102                input_shape: self.shape().dims.clone(),
103                numel: self.size(),
104            };
105            out.set_grad_fn(grad_fn.clone());
106            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
107        }
108        out
109    }
110
111    /// Computes the arithmetic mean over specified dimensions
112    ///
113    /// This method calculates the mean value along the specified dimensions by first
114    /// computing the sum over those dimensions and then dividing by the product of
115    /// the reduced dimension sizes. The `keepdim` parameter determines whether
116    /// reduced dimensions are kept with size 1 or removed entirely.
117    ///
118    /// # Arguments
119    ///
120    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
121    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
122    ///
123    /// # Returns
124    ///
125    /// A tensor with the specified dimensions reduced by computing the mean.
126    /// The output shape depends on `keepdim`:
127    /// - If `keepdim` is `true`, reduced dimensions have size 1
128    /// - If `keepdim` is `false`, reduced dimensions are removed
129    ///
130    /// # Performance Characteristics
131    ///
132    /// - **Efficient Implementation**: Uses `sum_dims` followed by scalar multiplication
133    /// - **Memory Optimized**: Leverages existing sum reduction for optimal performance
134    /// - **Shape Computation**: Fast output shape calculation with dimension preservation
135    /// - **Numerical Stability**: Maintains precision through direct computation
136    ///
137    /// # Examples
138    ///
139    /// ```
140    /// use train_station::Tensor;
141    ///
142    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
143    ///
144    /// // Mean over columns (dim 1), keeping dimensions
145    /// let mean_cols = tensor.mean_dims(&[1], true);
146    /// assert_eq!(mean_cols.shape().dims, vec![2, 1]);
147    /// assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
148    /// assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
149    /// ```
150    ///
151    /// ```
152    /// use train_station::Tensor;
153    ///
154    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
155    ///
156    /// // Mean over rows (dim 0), removing dimensions
157    /// let mean_rows = tensor.mean_dims(&[0], false);
158    /// assert_eq!(mean_rows.shape().dims, vec![3]);
159    /// assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
160    /// assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
161    /// assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
162    /// ```
163    ///
164    /// ```
165    /// use train_station::Tensor;
166    ///
167    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
168    ///
169    /// // Mean over multiple dimensions
170    /// let mean_all = tensor.mean_dims(&[0, 1], false);
171    /// assert_eq!(mean_all.shape().dims, vec![1]);
172    /// assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
173    /// ```
174    ///
175    /// # Panics
176    ///
177    /// Panics if:
178    /// * `dims` is empty
179    /// * Any dimension in `dims` is out of bounds for the tensor's rank
180    ///
181    /// # GradTrack Support
182    ///
183    /// When `requires_grad` is true, this operation is tracked for automatic
184    /// differentiation. The gradient computation preserves the original input
185    /// shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
186    pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
187        assert!(
188            !dims.is_empty(),
189            "mean_dims requires at least one dimension"
190        );
191        let rank = self.shape().rank();
192        for &d in dims {
193            assert!(
194                d < rank,
195                "mean_dims dim {} out of bounds for rank {}",
196                d,
197                rank
198            );
199        }
200
201        // Compute sum over dims first, then divide by product of reduced sizes
202        let sum = self.sum_dims(dims, keepdim);
203        let factor: usize = dims.iter().map(|&d| self.shape().dims[d]).product();
204        let scale = if factor > 0 {
205            1.0f32 / (factor as f32)
206        } else {
207            0.0
208        };
209        let out = sum.mul_scalar(scale);
210
211        if self.requires_grad() {
212            // Override autograd of mul to a single ReduceMeanDims node for correctness and clarity
213            // Re-register operation for out to use ReduceMeanDims
214            let mut reg = out.clone();
215            reg.set_requires_grad_internal(true);
216            let mut reduced: Vec<usize> = dims.to_vec();
217            reduced.sort_unstable();
218            reduced.dedup();
219            let grad_fn = GradFn::ReduceMeanDims {
220                dims: reduced,
221                input_shape: self.shape().dims.clone(),
222                keepdim,
223            };
224            reg.set_grad_fn(grad_fn.clone());
225            GradEngine::register_operation(reg.id(), vec![self.id()], grad_fn);
226            return reg;
227        }
228
229        out
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Fills a 2x3 tensor with 0..=5 and checks the scalar mean.
    #[test]
    fn test_mean_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = i as f32;
            }
        }
        let m = x.mean();
        assert_eq!(m.shape().dims, vec![1]);
        unsafe {
            // Expected mean: (0+1+2+3+4+5)/6 = 2.5
            assert!((*m.as_ptr() - (0.0 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 6.0).abs() < 1e-6);
        }
    }

    /// The gradient of mean() distributes 1/n to every input element.
    #[test]
    fn test_mean_autograd_all_equal() {
        let x = Tensor::from_slice(&[1.0, 3.0, 5.0, 7.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let mut m = x.mean();
        m.backward(None);
        let gx = x.grad_by_value().expect("grad missing");
        // n = 4, so every gradient entry is exactly 1/4 = 0.25.
        for i in 0..4 {
            unsafe {
                assert_eq!(*gx.as_ptr().add(i), 0.25);
            }
        }
    }

    #[test]
    fn test_mean_non_contiguous_transpose() {
        // Test mean on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let mean_orig = x.mean();
        let mean_view = x_t.mean();

        // Both should give the same result: (1+2+3+4+5+6)/6 = 3.5
        assert!((mean_orig.get(&[0]) - 3.5).abs() < 1e-6);
        assert!((mean_view.get(&[0]) - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_dims_non_contiguous() {
        // Test mean_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // shape becomes [3, 2]
        assert!(!x_t.is_contiguous());

        // Mean along dim 0 of transposed tensor
        let mean_dim0 = x_t.mean_dims(&[0], false);
        assert_eq!(mean_dim0.shape().dims, vec![2]);
        // Should be [(1+2+3)/3, (4+5+6)/3] = [2.0, 5.0]
        assert!((mean_dim0.get(&[0]) - 2.0).abs() < 1e-6);
        assert!((mean_dim0.get(&[1]) - 5.0).abs() < 1e-6);

        // Mean along dim 1 of transposed tensor
        let mean_dim1 = x_t.mean_dims(&[1], false);
        assert_eq!(mean_dim1.shape().dims, vec![3]);
        // Should be [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
        assert!((mean_dim1.get(&[0]) - 2.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[1]) - 3.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[2]) - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Reverse the axes: shape [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let mean_orig = x.mean();
        let mean_perm = x_perm.mean();

        // Permutation only reorders elements, so the mean is unchanged.
        assert!((mean_orig.get(&[0]) - mean_perm.get(&[0])).abs() < 1e-6);

        // Expected mean: (0+1+2+...+23)/24 = 23*24/2/24 = 11.5
        assert!((mean_orig.get(&[0]) - 11.5).abs() < 1e-6);
    }
}