train_station/tensor/reductions/mean.rs
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Computes the arithmetic mean of all elements in the tensor
6 ///
7 /// This method calculates the average value across all tensor elements by summing
8 /// all values and dividing by the total number of elements. The result is a scalar
9 /// tensor containing the mean value. This operation supports gradient tracking
10 /// through the GradTrack system.
11 ///
12 /// # Returns
13 ///
14 /// A tensor with shape `[1]` containing the arithmetic mean of all elements.
15 /// For empty tensors, returns `0.0` as a safe default.
16 ///
17 /// # Performance Characteristics
18 ///
19 /// - **Linear Time**: O(n) complexity for computing the sum
20 /// - **Memory Efficient**: Single pass through tensor data with SIMD-optimized accumulation
21 /// - **Numerical Stability**: Uses direct accumulation for typical tensor sizes
22 /// - **Edge Case Handling**: Returns 0.0 for empty tensors
23 ///
24 /// # Examples
25 ///
26 /// ```
27 /// use train_station::Tensor;
28 ///
29 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
30 /// let mean_val = tensor.mean();
31 /// assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
32 /// ```
33 ///
34 /// ```
35 /// use train_station::Tensor;
36 ///
37 /// // Empty tensor case
38 /// let empty_tensor = Tensor::new(vec![0]);
39 /// let mean_val = empty_tensor.mean();
40 /// assert_eq!(mean_val.get(&[0]), 0.0);
41 /// ```
42 ///
43 /// # GradTrack Support
44 ///
45 /// When `requires_grad` is true, this operation is tracked for automatic
46 /// differentiation. The gradient computation distributes the gradient equally
47 /// across all input elements.
48 pub fn mean(&self) -> Tensor {
49 let mut out = Tensor::new(vec![1]);
50 if self.size() == 0 {
51 // Convention: mean over empty returns 0.0 (aligns with safe behavior for now)
52 out.fill(0.0);
53 } else {
54 let mut acc0 = 0.0f32;
55
56 if self.is_contiguous() {
57 // Fast path for contiguous tensors
58 unsafe {
59 let src = self.as_ptr();
60 let size = self.size();
61 let mut i = 0usize;
62 while i + 4 <= size {
63 let x0 = *src.add(i);
64 let x1 = *src.add(i + 1);
65 let x2 = *src.add(i + 2);
66 let x3 = *src.add(i + 3);
67 acc0 += x0 + x1 + x2 + x3;
68 i += 4;
69 }
70 while i < size {
71 acc0 += *src.add(i);
72 i += 1;
73 }
74 }
75 } else {
76 // Stride-aware path for non-contiguous tensors
77 let dims = self.shape().dims.clone();
78 for flat_idx in 0..self.size() {
79 // Convert flat index to multi-dimensional coordinates
80 let mut coords = vec![0; dims.len()];
81 let mut tmp = flat_idx;
82 for k in (0..dims.len()).rev() {
83 coords[k] = tmp % dims[k];
84 tmp /= dims[k];
85 }
86
87 // Get value using stride-aware offset
88 let offset = self.shape().offset(&coords);
89 let value = unsafe { *self.as_ptr().add(offset) };
90 acc0 += value;
91 }
92 }
93
94 unsafe {
95 *out.as_mut_ptr() = acc0 / (self.size() as f32);
96 }
97 }
98
99 if self.requires_grad() {
100 out.set_requires_grad_internal(true);
101 let grad_fn = GradFn::ReduceMean {
102 input_shape: self.shape().dims.clone(),
103 numel: self.size(),
104 };
105 out.set_grad_fn(grad_fn.clone());
106 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
107 }
108 out
109 }
110
111 /// Computes the arithmetic mean over specified dimensions
112 ///
113 /// This method calculates the mean value along the specified dimensions by first
114 /// computing the sum over those dimensions and then dividing by the product of
115 /// the reduced dimension sizes. The `keepdim` parameter determines whether
116 /// reduced dimensions are kept with size 1 or removed entirely.
117 ///
118 /// # Arguments
119 ///
120 /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
121 /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
122 ///
123 /// # Returns
124 ///
125 /// A tensor with the specified dimensions reduced by computing the mean.
126 /// The output shape depends on `keepdim`:
127 /// - If `keepdim` is `true`, reduced dimensions have size 1
128 /// - If `keepdim` is `false`, reduced dimensions are removed
129 ///
130 /// # Performance Characteristics
131 ///
132 /// - **Efficient Implementation**: Uses `sum_dims` followed by scalar multiplication
133 /// - **Memory Optimized**: Leverages existing sum reduction for optimal performance
134 /// - **Shape Computation**: Fast output shape calculation with dimension preservation
135 /// - **Numerical Stability**: Maintains precision through direct computation
136 ///
137 /// # Examples
138 ///
139 /// ```
140 /// use train_station::Tensor;
141 ///
142 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
143 ///
144 /// // Mean over columns (dim 1), keeping dimensions
145 /// let mean_cols = tensor.mean_dims(&[1], true);
146 /// assert_eq!(mean_cols.shape().dims, vec![2, 1]);
147 /// assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
148 /// assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
149 /// ```
150 ///
151 /// ```
152 /// use train_station::Tensor;
153 ///
154 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
155 ///
156 /// // Mean over rows (dim 0), removing dimensions
157 /// let mean_rows = tensor.mean_dims(&[0], false);
158 /// assert_eq!(mean_rows.shape().dims, vec![3]);
159 /// assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
160 /// assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
161 /// assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
162 /// ```
163 ///
164 /// ```
165 /// use train_station::Tensor;
166 ///
167 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
168 ///
169 /// // Mean over multiple dimensions
170 /// let mean_all = tensor.mean_dims(&[0, 1], false);
171 /// assert_eq!(mean_all.shape().dims, vec![1]);
172 /// assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
173 /// ```
174 ///
175 /// # Panics
176 ///
177 /// Panics if:
178 /// * `dims` is empty
179 /// * Any dimension in `dims` is out of bounds for the tensor's rank
180 ///
181 /// # GradTrack Support
182 ///
183 /// When `requires_grad` is true, this operation is tracked for automatic
184 /// differentiation. The gradient computation preserves the original input
185 /// shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
186 pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
187 assert!(
188 !dims.is_empty(),
189 "mean_dims requires at least one dimension"
190 );
191 let rank = self.shape().rank();
192 for &d in dims {
193 assert!(
194 d < rank,
195 "mean_dims dim {} out of bounds for rank {}",
196 d,
197 rank
198 );
199 }
200
201 // Compute sum over dims first, then divide by product of reduced sizes
202 let sum = self.sum_dims(dims, keepdim);
203 let factor: usize = dims.iter().map(|&d| self.shape().dims[d]).product();
204 let scale = if factor > 0 {
205 1.0f32 / (factor as f32)
206 } else {
207 0.0
208 };
209 let out = sum.mul_scalar(scale);
210
211 if self.requires_grad() {
212 // Override autograd of mul to a single ReduceMeanDims node for correctness and clarity
213 // Re-register operation for out to use ReduceMeanDims
214 let mut reg = out.clone();
215 reg.set_requires_grad_internal(true);
216 let mut reduced: Vec<usize> = dims.to_vec();
217 reduced.sort_unstable();
218 reduced.dedup();
219 let grad_fn = GradFn::ReduceMeanDims {
220 dims: reduced,
221 input_shape: self.shape().dims.clone(),
222 keepdim,
223 };
224 reg.set_grad_fn(grad_fn.clone());
225 GradEngine::register_operation(reg.id(), vec![self.id()], grad_fn);
226 return reg;
227 }
228
229 out
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_forward_basic() {
        // Fill a 2x3 tensor with 0..5 and check the scalar mean.
        let mut input = Tensor::zeros(vec![2, 3]);
        unsafe {
            for idx in 0..6 {
                *input.as_mut_ptr().add(idx) = idx as f32;
            }
        }
        let result = input.mean();
        assert_eq!(result.shape().dims, vec![1]);
        let expected = (0.0 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 6.0;
        unsafe {
            assert!((*result.as_ptr() - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_mean_autograd_all_equal() {
        // d(mean)/d(x_i) = 1/n, so every gradient entry must be 1/4.
        let input = Tensor::from_slice(&[1.0, 3.0, 5.0, 7.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let mut result = input.mean();
        result.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        unsafe {
            for idx in 0..4 {
                assert_eq!(*grad.as_ptr().add(idx), 0.25);
            }
        }
    }

    #[test]
    fn test_mean_non_contiguous_transpose() {
        // A transposed view shares storage but is non-contiguous; the mean
        // must nevertheless match the contiguous original.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // base: [[1, 2, 3], [4, 5, 6]]

        let transposed = base.transpose(0, 1);
        // transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!transposed.is_contiguous()); // must be a view

        let from_base = base.mean();
        let from_view = transposed.mean();

        // Both reductions cover the same six values: (1+2+3+4+5+6)/6 = 3.5.
        assert!((from_base.get(&[0]) - 3.5).abs() < 1e-6);
        assert!((from_view.get(&[0]) - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_dims_non_contiguous() {
        // Dimension-wise mean on a non-contiguous (transposed) tensor.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let transposed = base.transpose(0, 1); // shape [3, 2]
        assert!(!transposed.is_contiguous());

        // Reducing dim 0 of the view averages each original row.
        let over_rows = transposed.mean_dims(&[0], false);
        assert_eq!(over_rows.shape().dims, vec![2]);
        // Expected: [(1+2+3)/3, (4+5+6)/3] = [2.0, 5.0]
        assert!((over_rows.get(&[0]) - 2.0).abs() < 1e-6);
        assert!((over_rows.get(&[1]) - 5.0).abs() < 1e-6);

        // Reducing dim 1 of the view averages each original column.
        let over_cols = transposed.mean_dims(&[1], false);
        assert_eq!(over_cols.shape().dims, vec![3]);
        // Expected: [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
        assert!((over_cols.get(&[0]) - 2.5).abs() < 1e-6);
        assert!((over_cols.get(&[1]) - 3.5).abs() < 1e-6);
        assert!((over_cols.get(&[2]) - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_permuted_tensor() {
        // Permuting axes must not change the set of elements being averaged.
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let base = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        // Reverse the axis order: [2, 3, 4] -> [4, 3, 2]
        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let from_base = base.mean();
        let from_perm = permuted.mean();

        // Both layouts reduce the same 24 values.
        assert!((from_base.get(&[0]) - from_perm.get(&[0])).abs() < 1e-6);

        // Expected mean: (0+1+...+23)/24 = 23*24/2/24 = 11.5
        assert!((from_base.get(&[0]) - 11.5).abs() < 1e-6);
    }
}
327}