train_station/tensor/reductions/mean.rs

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the arithmetic mean of all elements in the tensor
    ///
    /// This method calculates the average value across all tensor elements by summing
    /// all values and dividing by the total number of elements. The result is a scalar
    /// tensor containing the mean value. This operation supports gradient tracking
    /// through the GradTrack system.
    ///
    /// # Returns
    ///
    /// A tensor with shape `[1]` containing the arithmetic mean of all elements.
    /// For empty tensors, returns `0.0` as a safe default.
    ///
    /// # Performance Characteristics
    ///
    /// - **Linear Time**: O(n) complexity for computing the sum
    /// - **Memory Efficient**: Single pass through tensor data with a 4-way unrolled
    ///   accumulation loop that the compiler can auto-vectorize
    /// - **Numerical Stability**: Direct f32 accumulation, adequate for typical tensor sizes
    /// - **Edge Case Handling**: Returns 0.0 for empty tensors
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let mean_val = tensor.mean();
    /// assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Empty tensor case
    /// let empty_tensor = Tensor::new(vec![0]);
    /// let mean_val = empty_tensor.mean();
    /// assert_eq!(mean_val.get(&[0]), 0.0);
    /// ```
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation distributes the gradient equally
    /// across all input elements.
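    ///
    /// A sketch of the gradient flow (the `with_requires_grad`, `backward`, and
    /// `grad_owned` helpers below mirror their use in the unit tests in this file):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
    ///     .unwrap()
    ///     .with_requires_grad();
    /// let mut m = x.mean();
    /// m.backward(None);
    /// let grad = x.grad_owned().expect("grad missing");
    /// // d(mean)/dx_i = 1/n, so each of the 4 elements receives 0.25
    /// assert_eq!(grad.get(&[0]), 0.25);
    /// ```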
    #[track_caller]
    pub fn mean(&self) -> Tensor {
        let mut out = Tensor::new(vec![1]);
        if self.size() == 0 {
            // Convention: the mean of an empty tensor is 0.0 (safe default, see doc comment)
            out.fill(0.0);
        } else {
            let mut acc0 = 0.0f32;

            if self.is_contiguous() {
                // Fast path: 4-way unrolled accumulation over contiguous data
                unsafe {
                    let src = self.as_ptr();
                    let size = self.size();
                    let mut i = 0usize;
                    while i + 4 <= size {
                        let x0 = *src.add(i);
                        let x1 = *src.add(i + 1);
                        let x2 = *src.add(i + 2);
                        let x3 = *src.add(i + 3);
                        acc0 += x0 + x1 + x2 + x3;
                        i += 4;
                    }
                    while i < size {
                        acc0 += *src.add(i);
                        i += 1;
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims().to_vec();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }
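                    // e.g. for dims [2, 3], flat_idx 4 decomposes to coords [1, 1]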

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    acc0 += value;
                }
            }

            unsafe {
                *out.as_mut_ptr() = acc0 / (self.size() as f32);
            }
        }

        if self.requires_grad() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMean {
                input_shape: self.shape().dims().to_vec(),
                numel: self.size(),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }
        out
    }

    /// Computes the arithmetic mean over specified dimensions
    ///
    /// This method calculates the mean value along the specified dimensions by first
    /// computing the sum over those dimensions and then dividing by the product of
    /// the reduced dimension sizes. The `keepdim` parameter determines whether
    /// reduced dimensions are kept with size 1 or removed entirely.
    ///
    /// # Arguments
    ///
    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimensions reduced by computing the mean.
    /// The output shape depends on `keepdim`:
    /// - If `keepdim` is `true`, reduced dimensions have size 1
    /// - If `keepdim` is `false`, reduced dimensions are removed
    ///
    /// # Performance Characteristics
    ///
    /// - **Efficient Implementation**: Uses `sum_dims` followed by a single scalar multiplication
    /// - **Memory Optimized**: Leverages the existing sum reduction for optimal performance
    /// - **Shape Computation**: Fast output shape calculation with dimension preservation
    /// - **Numerical Stability**: Inherits the accumulation precision of `sum_dims`
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Mean over columns (dim 1), keeping dimensions
    /// let mean_cols = tensor.mean_dims(&[1], true);
    /// assert_eq!(mean_cols.shape().dims(), vec![2, 1]);
    /// assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
    /// assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Mean over rows (dim 0), removing dimensions
    /// let mean_rows = tensor.mean_dims(&[0], false);
    /// assert_eq!(mean_rows.shape().dims(), vec![3]);
    /// assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
    /// assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
    /// assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    ///
    /// // Mean over multiple dimensions
    /// let mean_all = tensor.mean_dims(&[0, 1], false);
    /// assert_eq!(mean_all.shape().dims(), vec![1]);
    /// assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `dims` is empty
    /// * Any dimension in `dims` is out of bounds for the tensor's rank
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation preserves the original input
    /// shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
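    ///
    /// A sketch of the gradient flow (the `with_requires_grad`, `backward`, and
    /// `grad_owned` helpers below mirror their use in the unit tests in this file;
    /// reducing over all dimensions keeps the output scalar, so `backward(None)`
    /// applies directly):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
    ///     .unwrap()
    ///     .with_requires_grad();
    /// let mut m = x.mean_dims(&[0, 1], false);
    /// m.backward(None);
    /// let grad = x.grad_owned().expect("grad missing");
    /// // Each of the 6 elements receives 1/6 of the upstream gradient
    /// assert!((grad.get(&[0, 0]) - 1.0 / 6.0).abs() < 1e-6);
    /// ```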
    #[track_caller]
    pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "mean_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "mean_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Compute sum over dims first, then divide by product of reduced sizes
        let sum = self.sum_dims(dims, keepdim);
        let factor: usize = dims.iter().map(|&d| self.shape().dims()[d]).product();
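        // e.g. reducing dim 1 of a [2, 3] tensor gives factor = 3, scale = 1/3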
        let scale = if factor > 0 {
            1.0f32 / (factor as f32)
        } else {
            0.0
        };
        let out = sum.mul_scalar(scale);

        if self.requires_grad() {
            // Replace the autograd node recorded by `mul_scalar` with a single
            // ReduceMeanDims node, so the backward pass sees one reduction operation
            let mut reg = out.clone();
            reg.set_requires_grad_internal(true);
            let mut reduced: Vec<usize> = dims.to_vec();
            reduced.sort_unstable();
            reduced.dedup();
            let grad_fn = GradFn::ReduceMeanDims {
                dims: reduced,
                input_shape: self.shape().dims().to_vec(),
                keepdim,
            };
            reg.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(reg.id(), vec![self.id()], grad_fn);
            return reg;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = i as f32;
            }
        }
        let m = x.mean();
        assert_eq!(m.shape().dims(), vec![1]);
        unsafe {
            assert!((*m.as_ptr() - (0.0 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 6.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_mean_autograd_all_equal() {
        let x = Tensor::from_slice(&[1.0, 3.0, 5.0, 7.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let mut m = x.mean();
        m.backward(None);
        let gx = x.grad_owned().expect("grad missing");
        for i in 0..4 {
            unsafe {
                assert_eq!(*gx.as_ptr().add(i), 0.25);
            }
        }
    }

    #[test]
    fn test_mean_non_contiguous_transpose() {
        // Test mean on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Transpose returns a non-contiguous view

        let mean_orig = x.mean();
        let mean_view = x_t.mean();

        // Both should give the same result: (1+2+3+4+5+6)/6 = 3.5
        assert!((mean_orig.get(&[0]) - 3.5).abs() < 1e-6);
        assert!((mean_view.get(&[0]) - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_dims_non_contiguous() {
        // Test mean_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Mean along dim 0 of transposed tensor
        let mean_dim0 = x_t.mean_dims(&[0], false);
        assert_eq!(mean_dim0.shape().dims(), vec![2]);
        // Should be [(1+2+3)/3, (4+5+6)/3] = [2.0, 5.0]
        assert!((mean_dim0.get(&[0]) - 2.0).abs() < 1e-6);
        assert!((mean_dim0.get(&[1]) - 5.0).abs() < 1e-6);

        // Mean along dim 1 of transposed tensor
        let mean_dim1 = x_t.mean_dims(&[1], false);
        assert_eq!(mean_dim1.shape().dims(), vec![3]);
        // Should be [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
        assert!((mean_dim1.get(&[0]) - 2.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[1]) - 3.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[2]) - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Permute dimensions: shape [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let mean_orig = x.mean();
        let mean_perm = x_perm.mean();

        // Permutation must not change the mean
        assert!((mean_orig.get(&[0]) - mean_perm.get(&[0])).abs() < 1e-6);

        // Expected mean: (0+1+2+...+23)/24 = 276/24 = 11.5
        assert!((mean_orig.get(&[0]) - 11.5).abs() < 1e-6);
    }
}