train_station/tensor/reductions/mean.rs
use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the arithmetic mean of all elements in the tensor
    ///
    /// This method calculates the average value across all tensor elements by summing
    /// all values and dividing by the total number of elements. The result is a scalar
    /// tensor containing the mean value. This operation supports gradient tracking
    /// through the GradTrack system.
    ///
    /// # Returns
    ///
    /// A tensor with shape `[1]` containing the arithmetic mean of all elements.
    /// For empty tensors, returns `0.0` as a safe default.
    ///
    /// # Performance Characteristics
    ///
    /// - **Linear Time**: O(n) complexity for computing the sum
    /// - **Memory Efficient**: Single pass through tensor data with 4-way unrolled accumulation
    /// - **Numerical Stability**: Direct accumulation into a single `f32`, adequate for typical tensor sizes
    /// - **Edge Case Handling**: Returns 0.0 for empty tensors
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let mean_val = tensor.mean();
    /// assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Empty tensor case
    /// let empty_tensor = Tensor::new(vec![0]);
    /// let mean_val = empty_tensor.mean();
    /// assert_eq!(mean_val.get(&[0]), 0.0);
    /// ```
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation distributes the gradient equally
    /// across all input elements.
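    ///
    /// A minimal sketch of the backward behavior, using the `with_requires_grad`,
    /// `backward`, and `grad_by_value` APIs exercised by this module's tests:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 3.0, 5.0, 7.0], vec![4])
    ///     .unwrap()
    ///     .with_requires_grad();
    /// let mut m = x.mean();
    /// m.backward(None);
    /// // Each of the n = 4 elements receives gradient 1/n = 0.25
    /// let gx = x.grad_by_value().expect("grad missing");
    /// assert_eq!(gx.get(&[0]), 0.25);
    /// ```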
    #[track_caller]
    pub fn mean(&self) -> Tensor {
        let mut out = Tensor::new(vec![1]);
        if self.size() == 0 {
            // Convention: the mean of an empty tensor is 0.0 (a safe default)
            out.fill(0.0);
        } else {
            let mut acc0 = 0.0f32;

            if self.is_contiguous() {
                // Fast path for contiguous tensors: 4-way unrolled accumulation
                unsafe {
                    let src = self.as_ptr();
                    let size = self.size();
                    let mut i = 0usize;
                    while i + 4 <= size {
                        let x0 = *src.add(i);
                        let x1 = *src.add(i + 1);
                        let x2 = *src.add(i + 2);
                        let x3 = *src.add(i + 3);
                        acc0 += x0 + x1 + x2 + x3;
                        i += 4;
                    }
                    // Scalar tail for the remaining 0-3 elements
                    while i < size {
                        acc0 += *src.add(i);
                        i += 1;
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims.clone();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }
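                    // Worked example: with dims [2, 3], flat_idx 4 decomposes as
                    // coords[1] = 4 % 3 = 1, tmp = 4 / 3 = 1, coords[0] = 1 % 2 = 1,
                    // giving coords = [1, 1]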

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    acc0 += value;
                }
            }

            unsafe {
                *out.as_mut_ptr() = acc0 / (self.size() as f32);
            }
        }

        if self.requires_grad() {
            out.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMean {
                input_shape: self.shape().dims.clone(),
                numel: self.size(),
            };
            out.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
        }
        out
    }

    /// Computes the arithmetic mean over specified dimensions
    ///
    /// This method calculates the mean value along the specified dimensions by first
    /// computing the sum over those dimensions and then dividing by the product of
    /// the reduced dimension sizes. The `keepdim` parameter determines whether
    /// reduced dimensions are kept with size 1 or removed entirely.
    ///
    /// # Arguments
    ///
    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimensions reduced by computing the mean.
    /// The output shape depends on `keepdim`:
    /// - If `keepdim` is `true`, reduced dimensions have size 1
    /// - If `keepdim` is `false`, reduced dimensions are removed
    ///
    /// # Performance Characteristics
    ///
    /// - **Efficient Implementation**: Uses `sum_dims` followed by a single scalar multiplication
    /// - **Memory Optimized**: Reuses the existing sum reduction rather than a separate accumulation pass
    /// - **Shape Computation**: The output shape follows directly from `keepdim` and the reduced dimensions
    /// - **Numerical Stability**: The reciprocal scale is computed once and applied to the reduced sums
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Mean over columns (dim 1), keeping dimensions
    /// let mean_cols = tensor.mean_dims(&[1], true);
    /// assert_eq!(mean_cols.shape().dims, vec![2, 1]);
    /// assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
    /// assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Mean over rows (dim 0), removing dimensions
    /// let mean_rows = tensor.mean_dims(&[0], false);
    /// assert_eq!(mean_rows.shape().dims, vec![3]);
    /// assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
    /// assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
    /// assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    ///
    /// // Mean over multiple dimensions
    /// let mean_all = tensor.mean_dims(&[0, 1], false);
    /// assert_eq!(mean_all.shape().dims, vec![1]);
    /// assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `dims` is empty
    /// * Any dimension in `dims` is out of bounds for the tensor's rank
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation preserves the original input
    /// shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
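    ///
    /// A minimal sketch of the backward behavior for a full reduction, assuming
    /// `backward(None)` accepts the single-element output here just as it does
    /// for `mean`:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
    ///     .unwrap()
    ///     .with_requires_grad();
    /// let mut m = x.mean_dims(&[0, 1], false);
    /// m.backward(None);
    /// // The gradient 1/(2*2) = 0.25 is broadcast back to the input shape [2, 2]
    /// let gx = x.grad_by_value().expect("grad missing");
    /// assert!((gx.get(&[0, 0]) - 0.25).abs() < 1e-6);
    /// ```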
    #[track_caller]
    pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(
            !dims.is_empty(),
            "mean_dims requires at least one dimension"
        );
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "mean_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Compute sum over dims first, then divide by the product of reduced sizes
        let sum = self.sum_dims(dims, keepdim);
        let factor: usize = dims.iter().map(|&d| self.shape().dims[d]).product();
        // Guard against a reduced dimension of size 0: scale by 0.0 instead of dividing by zero
        let scale = if factor > 0 {
            1.0f32 / (factor as f32)
        } else {
            0.0
        };
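        // Worked example: reducing dim 1 of a [2, 3] tensor gives factor = 3,
        // so each of the two row sums is scaled by 1/3 to produce the row means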
        let out = sum.mul_scalar(scale);

        if self.requires_grad() {
            // Replace the autograd node recorded by `mul_scalar` with a single
            // ReduceMeanDims node, so the backward pass sees one mean reduction
            let mut reg = out.clone();
            reg.set_requires_grad_internal(true);
            let mut reduced: Vec<usize> = dims.to_vec();
            reduced.sort_unstable();
            reduced.dedup();
            let grad_fn = GradFn::ReduceMeanDims {
                dims: reduced,
                input_shape: self.shape().dims.clone(),
                keepdim,
            };
            reg.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(reg.id(), vec![self.id()], grad_fn);
            return reg;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = i as f32;
            }
        }
        let m = x.mean();
        assert_eq!(m.shape().dims, vec![1]);
        unsafe {
            assert!((*m.as_ptr() - (0.0 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 6.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_mean_autograd_all_equal() {
        let x = Tensor::from_slice(&[1.0, 3.0, 5.0, 7.0], vec![4])
            .unwrap()
            .with_requires_grad();
        let mut m = x.mean();
        m.backward(None);
        let gx = x.grad_by_value().expect("grad missing");
        for i in 0..4 {
            unsafe {
                assert_eq!(*gx.as_ptr().add(i), 0.25);
            }
        }
    }

    #[test]
    fn test_mean_non_contiguous_transpose() {
        // Test mean on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let mean_orig = x.mean();
        let mean_view = x_t.mean();

        // Both should give the same result: (1+2+3+4+5+6)/6 = 3.5
        assert!((mean_orig.get(&[0]) - 3.5).abs() < 1e-6);
        assert!((mean_view.get(&[0]) - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_dims_non_contiguous() {
        // Test mean_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Mean along dim 0 of transposed tensor
        let mean_dim0 = x_t.mean_dims(&[0], false);
        assert_eq!(mean_dim0.shape().dims, vec![2]);
        // Should be [(1+2+3)/3, (4+5+6)/3] = [2.0, 5.0]
        assert!((mean_dim0.get(&[0]) - 2.0).abs() < 1e-6);
        assert!((mean_dim0.get(&[1]) - 5.0).abs() < 1e-6);

        // Mean along dim 1 of transposed tensor
        let mean_dim1 = x_t.mean_dims(&[1], false);
        assert_eq!(mean_dim1.shape().dims, vec![3]);
        // Should be [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
        assert!((mean_dim1.get(&[0]) - 2.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[1]) - 3.5).abs() < 1e-6);
        assert!((mean_dim1.get(&[2]) - 4.5).abs() < 1e-6);
    }

    #[test]
    fn test_mean_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Permute dimensions [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let mean_orig = x.mean();
        let mean_perm = x_perm.mean();

        // Should give same result
        assert!((mean_orig.get(&[0]) - mean_perm.get(&[0])).abs() < 1e-6);

        // Expected mean: (0+1+2+...+23)/24 = (23*24/2)/24 = 11.5
        assert!((mean_orig.get(&[0]) - 11.5).abs() < 1e-6);
    }
}
329}