train_station/tensor/reductions/var.rs
//! Variance reduction operations for tensors
//!
//! This module provides variance reductions over tensors. The variance
//! measures the average squared deviation from the mean and is commonly
//! used in statistics, data analysis, and machine learning for
//! understanding data variability and as a component of other statistical
//! measures such as the standard deviation.
//!
//! # Operations
//!
//! * `var()` - Computes variance over all elements, returning a scalar tensor
//! * `var_dims()` - Computes variance over specified dimensions with optional dimension preservation
//!
//! # Statistical Details
//!
//! The implementation uses population variance, which divides by n rather
//! than n-1. This corresponds to PyTorch's `unbiased = false` mode, which is
//! the reference behavior these operations follow.
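//!
//! For example, for the data `[1.0, 2.0, 3.0, 4.0]` the mean is 2.5 and the
//! squared deviations sum to 5.0, so the population variance is
//! `5.0 / 4 = 1.25` (the sample variance would instead be `5.0 / 3 ≈ 1.667`).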
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Compute variance of all elements
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
//! let variance = tensor.var();
//! assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
//!
//! // Compute variance along specific dimensions
//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
//! let row_vars = matrix.var_dims(&[1], true);
//! assert_eq!(row_vars.shape().dims(), vec![2, 1]);
//! ```
//!
//! # Performance
//!
//! The implementation uses an optimized linear scan for contiguous tensors.
//! Non-contiguous tensors use stride-aware iteration to maintain correctness
//! for arbitrary memory layouts.
//!
//! # Gradient Tracking
//!
//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
//! The gradient computation follows the mathematical derivative of the variance operation.
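//! For the population variance of n elements with mean mu, that derivative
//! with respect to each input element x_i is `2 * (x_i - mu) / n`, scaled by
//! the incoming gradient during backpropagation.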

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the variance over all elements
    ///
    /// The variance is calculated as the mean of squared differences from the mean.
    /// This operation reduces the tensor to a single-element tensor of shape `[1]`.
    ///
    /// The implementation uses population variance (divides by n rather
    /// than n-1), matching PyTorch's `unbiased = false` mode.
    ///
    /// # Returns
    ///
    /// A scalar tensor containing the variance value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic variance calculation
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
    /// let variance = tensor.var();
    /// assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance of a larger dataset
    /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
    /// let variance = tensor.var();
    /// // mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
    /// assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance of constant values (should be 0)
    /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
    /// let variance = tensor.var();
    /// assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
    /// ```
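    ///
    /// Non-contiguous views are supported as well; a short sketch mirroring
    /// the transpose test in this module:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance is layout-independent: a transposed view sees the same elements
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    /// let xt = x.transpose(0, 1); // non-contiguous view of shape [3, 2]
    /// assert!((x.var().get(&[0]) - xt.var().get(&[0])).abs() < 1e-6);
    /// ```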
    ///
    /// # Performance
    ///
    /// Uses an optimized linear scan for contiguous tensors; non-contiguous
    /// tensors use stride-aware iteration. The algorithm performs two passes:
    /// first to compute the mean, then to accumulate squared deviations.
    #[track_caller]
    pub fn var(&self) -> Tensor {
        let mut out = Tensor::new(vec![1]);
        // Mean of all elements; stays 0.0 for empty tensors
        let mut mean_val = 0.0f32;
        if self.size() == 0 {
            out.fill(0.0);
        } else {
            let n = self.size() as f32;

            // First pass: accumulate the sum for the mean
            if self.is_contiguous() {
                // Fast path for contiguous tensors
                unsafe {
                    let src = self.as_ptr();
                    for i in 0..self.size() {
                        mean_val += *src.add(i);
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims().to_vec();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    mean_val += value;
                }
            }
            mean_val /= n;

            // Second pass: accumulate squared deviations from the mean
            let mut var_val = 0.0f32;
            if self.is_contiguous() {
                // Fast path for contiguous tensors
                unsafe {
                    let src = self.as_ptr();
                    for i in 0..self.size() {
                        let d = *src.add(i) - mean_val;
                        var_val += d * d;
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims().to_vec();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    let d = value - mean_val;
                    var_val += d * d;
                }
            }
            var_val /= n;

            unsafe {
                *out.as_mut_ptr() = var_val;
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            // Reuse the mean computed above for the backward pass instead of
            // recomputing it from the input
            let mean_tensor = {
                let mut t = Tensor::new(vec![1]);
                unsafe {
                    *t.as_mut_ptr() = mean_val;
                }
                t
            };
            let grad_fn = GradFn::ReduceVar {
                saved_mean: Box::new(mean_tensor),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the variance over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the variance
    /// of each slice. The result maintains the original tensor structure with
    /// reduced dimensions optionally preserved as size-1 dimensions.
    ///
    /// Uses population variance (divides by n rather than n-1), matching
    /// PyTorch's `unbiased = false` mode.
    ///
    /// # Arguments
    ///
    /// * `dims` - Slice of dimension indices to reduce over (each must be valid for the tensor rank)
    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
    ///
    /// # Returns
    ///
    /// A tensor with variance computed over the specified dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance along rows (dimension 1) with keepdim=true
    /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
    /// let row_vars = matrix.var_dims(&[1], true);
    /// assert_eq!(row_vars.shape().dims(), vec![2, 1]);
    /// assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
    /// assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance along columns (dimension 0) with keepdim=false
    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let col_vars = matrix.var_dims(&[0], false);
    /// assert_eq!(col_vars.shape().dims(), vec![2]);
    /// // var([1, 3]) = 1.0, var([2, 4]) = 1.0
    /// assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
    /// assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Variance over multiple dimensions
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let var_all = tensor.var_dims(&[0, 1], false);
    /// assert_eq!(var_all.shape().dims(), vec![1]);
    /// // var([1, 2, 3, 4]) = 1.25
    /// assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
    /// ```
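    ///
    /// Non-contiguous inputs work too; a short sketch mirroring the transpose
    /// test in this module:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    /// let xt = x.transpose(0, 1); // non-contiguous view of shape [3, 2]
    /// let vars = xt.var_dims(&[0], false);
    /// assert_eq!(vars.shape().dims(), vec![2]);
    /// // Columns of xt are [1, 2, 3] and [4, 5, 6]; population variance = 2/3
    /// assert!((vars.get(&[0]) - 2.0 / 3.0).abs() < 1e-5);
    /// assert!((vars.get(&[1]) - 2.0 / 3.0).abs() < 1e-5);
    /// ```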
    ///
    /// # Panics
    ///
    /// * If `dims` is empty
    /// * If any dimension index is out of bounds for the tensor rank
    /// * If the reduced size is 0 (invalid for variance calculation)
    ///
    /// # Performance
    ///
    /// Uses efficient coordinate-based iteration that works correctly with
    /// both contiguous and non-contiguous tensor layouts. The algorithm performs
    /// two passes: first to compute means, then to compute variances.
    #[track_caller]
    pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(!dims.is_empty(), "var_dims requires at least one dimension");
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "var_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Output shape
        let mut out_dims = self.shape().dims().to_vec();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
        // Mark reduced dimensions with a sentinel so that genuine size-0
        // dimensions are not accidentally dropped when keepdim is false
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { usize::MAX };
        }
        if !keepdim {
            out_dims.retain(|&s| s != usize::MAX);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }

        let mut mean = Tensor::zeros(out_dims.clone());
        let mut var = Tensor::zeros(out_dims.clone());

        let in_shape = self.shape().dims().to_vec();
        let out_rank = mean.shape().rank();
        let mut in_coords = vec![0usize; rank];
        let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
        assert!(n_reduced > 0, "reduced size must be > 0");
        unsafe {
            let mptr = mean.as_mut_ptr();
            // First pass: sum each element into its output cell for the mean
            for lin in 0..self.size() {
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
                // Map input coordinates to output coordinates by dropping (or
                // zeroing, when keepdim) the reduced dimensions
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(ic);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    mean.shape().offset(&out_coords)
                };
                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&in_coords);
                let value = *self.as_ptr().add(in_offset);
                *mptr.add(off) += value;
            }
            for i in 0..mean.size() {
                *mptr.add(i) /= n_reduced as f32;
            }
            // Second pass: accumulate squared diffs against each slice mean
            let vptr = var.as_mut_ptr();
            for lin in 0..self.size() {
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(ic);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    var.shape().offset(&out_coords)
                };
                let mu = *mptr.add(off);

                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&in_coords);
                let x = *self.as_ptr().add(in_offset);
                *vptr.add(off) += (x - mu) * (x - mu);
            }
            for i in 0..var.size() {
                *vptr.add(i) /= n_reduced as f32;
            }
        }

        if self.requires_grad() {
            let mut result = var.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceVarDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims().to_vec(),
                saved_mean: Box::new(mean),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        var
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_var_forward_basic() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let v = x.var();
        unsafe {
            let val = *v.as_ptr();
            assert!((val - 1.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_var_dims_forward() {
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let v = x.var_dims(&[1], true);
        assert_eq!(v.shape().dims(), vec![2, 1]);
        assert!((v.get(&[0, 0]) - 1.0).abs() < 1e-6);
        assert!((v.get(&[1, 0]) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_var_non_contiguous_transpose() {
        // Test var on a transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let var_orig = x.var();
        let var_view = x_t.var();

        // Both should give the same result
        assert!((var_orig.get(&[0]) - var_view.get(&[0])).abs() < 1e-6);

        // Expected var of [1,2,3,4,5,6]: mean=3.5, var=mean([2.5^2,1.5^2,0.5^2,0.5^2,1.5^2,2.5^2])=2.9167
        let expected_var = 2.9166667_f32;
        assert!((var_orig.get(&[0]) - expected_var).abs() < 1e-5);
    }

    #[test]
    fn test_var_dims_non_contiguous() {
        // Test var_dims on a non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Var along dim 0 of the transposed tensor
        let var_dim0 = x_t.var_dims(&[0], false);
        assert_eq!(var_dim0.shape().dims(), vec![2]);

        // For dim 0: [1,2,3] and [4,5,6]
        // [1,2,3]: mean=2, var=((1-2)^2 + (2-2)^2 + (3-2)^2)/3 = 2/3 ≈ 0.6667
        // [4,5,6]: mean=5, var=((4-5)^2 + (5-5)^2 + (6-5)^2)/3 = 2/3 ≈ 0.6667
        let expected_var = 2.0 / 3.0_f32;
        assert!((var_dim0.get(&[0]) - expected_var).abs() < 1e-5);
        assert!((var_dim0.get(&[1]) - expected_var).abs() < 1e-5);
    }

    #[test]
    fn test_var_permuted_tensor() {
        // Test with a permuted tensor - simple case with a known variance
        let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute dimensions [2, 2, 2] -> [2, 2, 2] (swap first and last)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let var_orig = x.var();
        let var_perm = x_perm.var();

        // Should give the same result
        assert!((var_orig.get(&[0]) - var_perm.get(&[0])).abs() < 1e-6);

        // Data is [1,3,5,7,2,4,6,8], mean=4.5
        // var = mean([3.5^2, 1.5^2, 0.5^2, 2.5^2, 2.5^2, 0.5^2, 1.5^2, 3.5^2]) = 5.25
        let expected_var = 5.25_f32;
        assert!((var_orig.get(&[0]) - expected_var).abs() < 1e-5);
    }
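
    #[test]
    fn test_var_dims_all_dims_matches_var() {
        // Added sanity check: reducing over every dimension with keepdim=false
        // should agree with the full-tensor var(), since both use population
        // variance over the same elements.
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let v_all = x.var();
        let v_dims = x.var_dims(&[0, 1], false);
        assert_eq!(v_dims.shape().dims(), vec![1]);
        assert!((v_all.get(&[0]) - v_dims.get(&[0])).abs() < 1e-6);
    }

    #[test]
    fn test_var_dims_duplicate_dims() {
        // Added sanity check: duplicate entries in `dims` are deduplicated
        // internally, so reducing over [1, 1] behaves like reducing over [1].
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let a = x.var_dims(&[1], true);
        let b = x.var_dims(&[1, 1], true);
        assert_eq!(a.shape().dims(), b.shape().dims());
        assert!((a.get(&[0, 0]) - b.get(&[0, 0])).abs() < 1e-6);
        assert!((a.get(&[1, 0]) - b.get(&[1, 0])).abs() < 1e-6);
    }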
}