train_station/tensor/reductions/var.rs
1//! Variance reduction operations for tensors
2//!
3//! This module provides variance reduction operations for tensors.
4//! The variance measures the average squared deviation from the mean,
5//! calculated as the mean of squared differences from the mean. This is
6//! commonly used in statistics, data analysis, and machine learning for
7//! understanding data variability and as a component of other statistical
8//! measures like standard deviation.
9//!
10//! # Operations
11//!
12//! * `var()` - Computes variance over all elements, returning a scalar tensor
13//! * `var_dims()` - Computes variance over specified dimensions with optional dimension preservation
14//!
15//! # Statistical Details
16//!
17//! The implementation uses population variance (unbiased=false), which
18//! divides by n rather than n-1. This matches PyTorch's default behavior for
19//! consistency with the reference implementation.
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Compute variance of all elements
27//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
28//! let variance = tensor.var();
29//! assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
30//!
31//! // Compute variance along specific dimensions
32//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
33//! let row_vars = matrix.var_dims(&[1], true);
34//! assert_eq!(row_vars.shape().dims, vec![2, 1]);
35//! ```
36//!
37//! # Performance
38//!
39//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
40//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
41//! correctness while preserving memory layout efficiency.
42//!
43//! # Gradient Tracking
44//!
45//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
46//! The gradient computation follows the mathematical derivative of the variance operation.
47
48use crate::gradtrack::{GradEngine, GradFn};
49use crate::tensor::core::Tensor;
50
51impl Tensor {
52 /// Computes the variance over all elements
53 ///
54 /// The variance is calculated as the mean of squared differences from the mean.
55 /// This operation reduces the tensor to a scalar value \[1\].
56 ///
57 /// The implementation uses population variance (divides by n rather
58 /// than n-1) to match PyTorch's default behavior.
59 ///
60 /// # Returns
61 ///
62 /// A scalar tensor containing the variance value
63 ///
64 /// # Examples
65 ///
66 /// ```
67 /// use train_station::Tensor;
68 ///
69 /// // Basic variance calculation
70 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71 /// let variance = tensor.var();
72 /// assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
73 /// ```
74 ///
75 /// ```
76 /// use train_station::Tensor;
77 ///
78 /// // Variance of a larger dataset
79 /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81 /// let variance = tensor.var();
82 /// // mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
83 /// assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
84 /// ```
85 ///
86 /// ```
87 /// use train_station::Tensor;
88 ///
89 /// // Variance of constant values (should be 0)
90 /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
91 /// let variance = tensor.var();
92 /// assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
93 /// ```
94 ///
95 /// # Performance
96 ///
97 /// Uses optimized contiguous tensor path with manual loop unrolling for better
98 /// performance. Non-contiguous tensors use stride-aware iteration.
99 /// The algorithm performs two passes: first to compute the mean, then to
100 /// compute the variance.
101 pub fn var(&self) -> Tensor {
102 let mut out = Tensor::new(vec![1]);
103 if self.size() == 0 {
104 out.fill(0.0);
105 } else {
106 // mean
107 let mut mean_val = 0.0f32;
108 let n = self.size() as f32;
109
110 if self.is_contiguous() {
111 // Fast path for contiguous tensors
112 unsafe {
113 let src = self.as_ptr();
114 for i in 0..self.size() {
115 mean_val += *src.add(i);
116 }
117 }
118 } else {
119 // Stride-aware path for non-contiguous tensors
120 let dims = self.shape().dims.clone();
121 for flat_idx in 0..self.size() {
122 // Convert flat index to multi-dimensional coordinates
123 let mut coords = vec![0; dims.len()];
124 let mut tmp = flat_idx;
125 for k in (0..dims.len()).rev() {
126 coords[k] = tmp % dims[k];
127 tmp /= dims[k];
128 }
129
130 // Get value using stride-aware offset
131 let offset = self.shape().offset(&coords);
132 let value = unsafe { *self.as_ptr().add(offset) };
133 mean_val += value;
134 }
135 }
136 mean_val /= n;
137
138 // var
139 let mut var_val = 0.0f32;
140
141 if self.is_contiguous() {
142 // Fast path for contiguous tensors
143 unsafe {
144 let src = self.as_ptr();
145 for i in 0..self.size() {
146 let d = *src.add(i) - mean_val;
147 var_val += d * d;
148 }
149 }
150 } else {
151 // Stride-aware path for non-contiguous tensors
152 let dims = self.shape().dims.clone();
153 for flat_idx in 0..self.size() {
154 // Convert flat index to multi-dimensional coordinates
155 let mut coords = vec![0; dims.len()];
156 let mut tmp = flat_idx;
157 for k in (0..dims.len()).rev() {
158 coords[k] = tmp % dims[k];
159 tmp /= dims[k];
160 }
161
162 // Get value using stride-aware offset
163 let offset = self.shape().offset(&coords);
164 let value = unsafe { *self.as_ptr().add(offset) };
165 let d = value - mean_val;
166 var_val += d * d;
167 }
168 }
169 var_val /= n;
170
171 unsafe {
172 *out.as_mut_ptr() = var_val;
173 }
174 }
175
176 if self.requires_grad() {
177 let mut result = out.clone();
178 result.set_requires_grad_internal(true);
179 let mean_tensor = {
180 let mut t = Tensor::new(vec![1]);
181 if self.size() == 0 {
182 t.fill(0.0);
183 } else {
184 let mut acc = 0.0f32;
185
186 if self.is_contiguous() {
187 unsafe {
188 for i in 0..self.size() {
189 acc += *self.as_ptr().add(i);
190 }
191 }
192 } else {
193 let dims = self.shape().dims.clone();
194 for flat_idx in 0..self.size() {
195 // Convert flat index to multi-dimensional coordinates
196 let mut coords = vec![0; dims.len()];
197 let mut tmp = flat_idx;
198 for k in (0..dims.len()).rev() {
199 coords[k] = tmp % dims[k];
200 tmp /= dims[k];
201 }
202
203 // Get value using stride-aware offset
204 let offset = self.shape().offset(&coords);
205 let value = unsafe { *self.as_ptr().add(offset) };
206 acc += value;
207 }
208 }
209
210 unsafe {
211 *t.as_mut_ptr() = acc / (self.size() as f32);
212 }
213 }
214 t
215 };
216 let grad_fn = GradFn::ReduceVar {
217 saved_mean: Box::new(mean_tensor),
218 saved_input: Box::new(self.clone()),
219 input_shape: self.shape().dims.clone(),
220 };
221 result.set_grad_fn(grad_fn.clone());
222 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
223 return result;
224 }
225
226 out
227 }
228
229 /// Computes the variance over specified dimensions
230 ///
231 /// Reduces the tensor along the specified dimensions by computing the variance
232 /// of each slice. The result maintains the original tensor structure with
233 /// reduced dimensions optionally preserved as size-1 dimensions.
234 ///
235 /// Uses population variance (divides by n rather than n-1) to match
236 /// PyTorch's default behavior.
237 ///
238 /// # Arguments
239 ///
240 /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
241 /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
242 ///
243 /// # Returns
244 ///
245 /// A tensor with variance computed over the specified dimensions
246 ///
247 /// # Examples
248 ///
249 /// ```
250 /// use train_station::Tensor;
251 ///
252 /// // Variance along rows (dimension 1) with keepdim=true
253 /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
254 /// let row_vars = matrix.var_dims(&[1], true);
255 /// assert_eq!(row_vars.shape().dims, vec![2, 1]);
256 /// assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
257 /// assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
258 /// ```
259 ///
260 /// ```
261 /// use train_station::Tensor;
262 ///
263 /// // Variance along columns (dimension 0) with keepdim=false
264 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
265 /// let col_vars = matrix.var_dims(&[0], false);
266 /// assert_eq!(col_vars.shape().dims, vec![2]);
267 /// // var([1, 3]) = 1.0, var([2, 4]) = 1.0
268 /// assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
269 /// assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
270 /// ```
271 ///
272 /// ```
273 /// use train_station::Tensor;
274 ///
275 /// // Variance over multiple dimensions
276 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
277 /// let var_all = tensor.var_dims(&[0, 1], false);
278 /// assert_eq!(var_all.shape().dims, vec![1]);
279 /// // var([1, 2, 3, 4]) = 1.25
280 /// assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
281 /// ```
282 ///
283 /// # Panics
284 ///
285 /// * If `dims` is empty
286 /// * If any dimension index is out of bounds for the tensor rank
287 /// * If the reduced size is 0 (invalid for variance calculation)
288 ///
289 /// # Performance
290 ///
291 /// Uses efficient coordinate-based iteration that works correctly with
292 /// both contiguous and non-contiguous tensor layouts. The algorithm performs
293 /// two passes: first to compute means, then to compute variances.
294 pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
295 assert!(!dims.is_empty(), "var_dims requires at least one dimension");
296 let rank = self.shape().rank();
297 for &d in dims {
298 assert!(
299 d < rank,
300 "var_dims dim {} out of bounds for rank {}",
301 d,
302 rank
303 );
304 }
305
306 // Output shape
307 let mut out_dims = self.shape().dims.clone();
308 let mut reduced: Vec<usize> = dims.to_vec();
309 reduced.sort_unstable();
310 reduced.dedup();
311 for &d in reduced.iter() {
312 out_dims[d] = if keepdim { 1 } else { 0 };
313 }
314 if !keepdim {
315 out_dims.retain(|&s| s != 0);
316 }
317 if out_dims.is_empty() {
318 out_dims.push(1);
319 }
320
321 let mut mean = Tensor::zeros(out_dims.clone());
322 let mut var = Tensor::zeros(out_dims.clone());
323
324 let in_shape = self.shape().dims.clone();
325 let out_rank = mean.shape().rank();
326 let mut in_coords = vec![0usize; rank];
327 let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
328 assert!(n_reduced > 0, "reduced size must be > 0");
329 unsafe {
330 let mptr = mean.as_mut_ptr();
331 // sum for mean
332 for lin in 0..self.size() {
333 let mut tmp = lin;
334 for i in (0..rank).rev() {
335 let s = in_shape[i];
336 in_coords[i] = if s == 0 { 0 } else { tmp % s };
337 if s != 0 {
338 tmp /= s;
339 }
340 }
341 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
342 for (i, &ic) in in_coords.iter().enumerate().take(rank) {
343 if reduced.contains(&i) {
344 if keepdim {
345 out_coords.push(0);
346 }
347 } else {
348 out_coords.push(ic);
349 }
350 }
351 let off = if out_coords.is_empty() {
352 0
353 } else {
354 mean.shape().offset(&out_coords)
355 };
356 // Get input value using stride-aware offset
357 let in_offset = self.shape().offset(&in_coords);
358 let value = *self.as_ptr().add(in_offset);
359 *mptr.add(off) += value;
360 }
361 for i in 0..mean.size() {
362 *mptr.add(i) /= n_reduced as f32;
363 }
364 // accumulate squared diffs
365 let vptr = var.as_mut_ptr();
366 for lin in 0..self.size() {
367 let mut tmp = lin;
368 for i in (0..rank).rev() {
369 let s = in_shape[i];
370 in_coords[i] = if s == 0 { 0 } else { tmp % s };
371 if s != 0 {
372 tmp /= s;
373 }
374 }
375 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
376 for (i, &ic) in in_coords.iter().enumerate().take(rank) {
377 if reduced.contains(&i) {
378 if keepdim {
379 out_coords.push(0);
380 }
381 } else {
382 out_coords.push(ic);
383 }
384 }
385 let off = if out_coords.is_empty() {
386 0
387 } else {
388 var.shape().offset(&out_coords)
389 };
390 let mu = *mptr.add(off);
391
392 // Get input value using stride-aware offset
393 let in_offset = self.shape().offset(&in_coords);
394 let x = *self.as_ptr().add(in_offset);
395 *vptr.add(off) += (x - mu) * (x - mu);
396 }
397 for i in 0..var.size() {
398 *vptr.add(i) /= n_reduced as f32;
399 }
400 }
401
402 if self.requires_grad() {
403 let mut result = var.clone();
404 result.set_requires_grad_internal(true);
405 let grad_fn = GradFn::ReduceVarDims {
406 dims: reduced,
407 keepdim,
408 input_shape: self.shape().dims.clone(),
409 saved_mean: Box::new(mean),
410 saved_input: Box::new(self.clone()),
411 };
412 result.set_grad_fn(grad_fn.clone());
413 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
414 return result;
415 }
416
417 var
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_var_forward_basic() {
        // Population variance of [1, 2, 3, 4]: mean = 2.5, var = 1.25.
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = input.var();
        assert!((result.get(&[0]) - 1.25).abs() < 1e-6);
    }

    #[test]
    fn test_var_dims_forward() {
        // Row-wise variance with kept dimension: var([1,3]) = 1, var([2,2]) = 0.
        let input = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let result = input.var_dims(&[1], true);
        assert_eq!(result.shape().dims, vec![2, 1]);
        assert!((result.get(&[0, 0]) - 1.0).abs() < 1e-6);
        assert!(result.get(&[1, 0]).abs() < 1e-6);
    }

    #[test]
    fn test_var_non_contiguous_transpose() {
        // A transposed view covers the same logical elements as its base, so
        // a full reduction must produce an identical variance.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let v_base = base.var();
        let v_view = view.var();
        assert!((v_base.get(&[0]) - v_view.get(&[0])).abs() < 1e-6);

        // var([1..6]) with mean 3.5 is 35/12 ≈ 2.9166667.
        assert!((v_base.get(&[0]) - 2.9166667_f32).abs() < 1e-5);
    }

    #[test]
    fn test_var_dims_non_contiguous() {
        // Reducing dim 0 of the [3, 2] transpose spans the original rows:
        // var([1,2,3]) = var([4,5,6]) = 2/3.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let result = view.var_dims(&[0], false);
        assert_eq!(result.shape().dims, vec![2]);

        let expected = 2.0 / 3.0_f32;
        assert!((result.get(&[0]) - expected).abs() < 1e-5);
        assert!((result.get(&[1]) - expected).abs() < 1e-5);
    }

    #[test]
    fn test_var_permuted_tensor() {
        // Permutation only reorders elements, so the global variance is
        // unchanged: mean = 4.5, var = 5.25 for this data.
        let values = [1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let base = Tensor::from_slice(&values, vec![2, 2, 2]).unwrap();
        let view = base.permute(vec![2, 1, 0]);
        assert!(!view.is_contiguous());

        let v_base = base.var();
        let v_view = view.var();
        assert!((v_base.get(&[0]) - v_view.get(&[0])).abs() < 1e-6);
        assert!((v_base.get(&[0]) - 5.25_f32).abs() < 1e-5);
    }
}
505}