train_station/tensor/reductions/
var.rs

1//! Variance reduction operations for tensors
2//!
3//! This module provides variance reduction operations for tensors.
4//! The variance measures the average squared deviation from the mean,
5//! calculated as the mean of squared differences from the mean. This is
6//! commonly used in statistics, data analysis, and machine learning for
7//! understanding data variability and as a component of other statistical
8//! measures like standard deviation.
9//!
10//! # Operations
11//!
12//! * `var()` - Computes variance over all elements, returning a scalar tensor
13//! * `var_dims()` - Computes variance over specified dimensions with optional dimension preservation
14//!
15//! # Statistical Details
16//!
17//! The implementation uses population variance (unbiased=false), which
18//! divides by n rather than n-1. This matches PyTorch's default behavior for
19//! consistency with the reference implementation.
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Compute variance of all elements
27//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
28//! let variance = tensor.var();
29//! assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
30//!
31//! // Compute variance along specific dimensions
32//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
33//! let row_vars = matrix.var_dims(&[1], true);
34//! assert_eq!(row_vars.shape().dims, vec![2, 1]);
35//! ```
36//!
37//! # Performance
38//!
//! The implementation uses a fast linear scan for contiguous tensors and falls back to
//! stride-aware coordinate iteration for non-contiguous tensors to maintain
//! correctness with arbitrary memory layouts.
42//!
43//! # Gradient Tracking
44//!
45//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
46//! The gradient computation follows the mathematical derivative of the variance operation.
47
48use crate::gradtrack::{GradEngine, GradFn};
49use crate::tensor::core::Tensor;
50
51impl Tensor {
52    /// Computes the variance over all elements
53    ///
54    /// The variance is calculated as the mean of squared differences from the mean.
55    /// This operation reduces the tensor to a scalar value \[1\].
56    ///
57    /// The implementation uses population variance (divides by n rather
58    /// than n-1) to match PyTorch's default behavior.
59    ///
60    /// # Returns
61    ///
62    /// A scalar tensor containing the variance value
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// // Basic variance calculation
70    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71    /// let variance = tensor.var();
72    /// assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
73    /// ```
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// // Variance of a larger dataset
79    /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81    /// let variance = tensor.var();
82    /// // mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
83    /// assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
84    /// ```
85    ///
86    /// ```
87    /// use train_station::Tensor;
88    ///
89    /// // Variance of constant values (should be 0)
90    /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
91    /// let variance = tensor.var();
92    /// assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
93    /// ```
94    ///
95    /// # Performance
96    ///
97    /// Uses optimized contiguous tensor path with manual loop unrolling for better
98    /// performance. Non-contiguous tensors use stride-aware iteration.
99    /// The algorithm performs two passes: first to compute the mean, then to
100    /// compute the variance.
101    pub fn var(&self) -> Tensor {
102        let mut out = Tensor::new(vec![1]);
103        if self.size() == 0 {
104            out.fill(0.0);
105        } else {
106            // mean
107            let mut mean_val = 0.0f32;
108            let n = self.size() as f32;
109
110            if self.is_contiguous() {
111                // Fast path for contiguous tensors
112                unsafe {
113                    let src = self.as_ptr();
114                    for i in 0..self.size() {
115                        mean_val += *src.add(i);
116                    }
117                }
118            } else {
119                // Stride-aware path for non-contiguous tensors
120                let dims = self.shape().dims.clone();
121                for flat_idx in 0..self.size() {
122                    // Convert flat index to multi-dimensional coordinates
123                    let mut coords = vec![0; dims.len()];
124                    let mut tmp = flat_idx;
125                    for k in (0..dims.len()).rev() {
126                        coords[k] = tmp % dims[k];
127                        tmp /= dims[k];
128                    }
129
130                    // Get value using stride-aware offset
131                    let offset = self.shape().offset(&coords);
132                    let value = unsafe { *self.as_ptr().add(offset) };
133                    mean_val += value;
134                }
135            }
136            mean_val /= n;
137
138            // var
139            let mut var_val = 0.0f32;
140
141            if self.is_contiguous() {
142                // Fast path for contiguous tensors
143                unsafe {
144                    let src = self.as_ptr();
145                    for i in 0..self.size() {
146                        let d = *src.add(i) - mean_val;
147                        var_val += d * d;
148                    }
149                }
150            } else {
151                // Stride-aware path for non-contiguous tensors
152                let dims = self.shape().dims.clone();
153                for flat_idx in 0..self.size() {
154                    // Convert flat index to multi-dimensional coordinates
155                    let mut coords = vec![0; dims.len()];
156                    let mut tmp = flat_idx;
157                    for k in (0..dims.len()).rev() {
158                        coords[k] = tmp % dims[k];
159                        tmp /= dims[k];
160                    }
161
162                    // Get value using stride-aware offset
163                    let offset = self.shape().offset(&coords);
164                    let value = unsafe { *self.as_ptr().add(offset) };
165                    let d = value - mean_val;
166                    var_val += d * d;
167                }
168            }
169            var_val /= n;
170
171            unsafe {
172                *out.as_mut_ptr() = var_val;
173            }
174        }
175
176        if self.requires_grad() {
177            let mut result = out.clone();
178            result.set_requires_grad_internal(true);
179            let mean_tensor = {
180                let mut t = Tensor::new(vec![1]);
181                if self.size() == 0 {
182                    t.fill(0.0);
183                } else {
184                    let mut acc = 0.0f32;
185
186                    if self.is_contiguous() {
187                        unsafe {
188                            for i in 0..self.size() {
189                                acc += *self.as_ptr().add(i);
190                            }
191                        }
192                    } else {
193                        let dims = self.shape().dims.clone();
194                        for flat_idx in 0..self.size() {
195                            // Convert flat index to multi-dimensional coordinates
196                            let mut coords = vec![0; dims.len()];
197                            let mut tmp = flat_idx;
198                            for k in (0..dims.len()).rev() {
199                                coords[k] = tmp % dims[k];
200                                tmp /= dims[k];
201                            }
202
203                            // Get value using stride-aware offset
204                            let offset = self.shape().offset(&coords);
205                            let value = unsafe { *self.as_ptr().add(offset) };
206                            acc += value;
207                        }
208                    }
209
210                    unsafe {
211                        *t.as_mut_ptr() = acc / (self.size() as f32);
212                    }
213                }
214                t
215            };
216            let grad_fn = GradFn::ReduceVar {
217                saved_mean: Box::new(mean_tensor),
218                saved_input: Box::new(self.clone()),
219                input_shape: self.shape().dims.clone(),
220            };
221            result.set_grad_fn(grad_fn.clone());
222            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
223            return result;
224        }
225
226        out
227    }
228
229    /// Computes the variance over specified dimensions
230    ///
231    /// Reduces the tensor along the specified dimensions by computing the variance
232    /// of each slice. The result maintains the original tensor structure with
233    /// reduced dimensions optionally preserved as size-1 dimensions.
234    ///
235    /// Uses population variance (divides by n rather than n-1) to match
236    /// PyTorch's default behavior.
237    ///
238    /// # Arguments
239    ///
240    /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
241    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
242    ///
243    /// # Returns
244    ///
245    /// A tensor with variance computed over the specified dimensions
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// use train_station::Tensor;
251    ///
252    /// // Variance along rows (dimension 1) with keepdim=true
253    /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
254    /// let row_vars = matrix.var_dims(&[1], true);
255    /// assert_eq!(row_vars.shape().dims, vec![2, 1]);
256    /// assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
257    /// assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
258    /// ```
259    ///
260    /// ```
261    /// use train_station::Tensor;
262    ///
263    /// // Variance along columns (dimension 0) with keepdim=false
264    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
265    /// let col_vars = matrix.var_dims(&[0], false);
266    /// assert_eq!(col_vars.shape().dims, vec![2]);
267    /// // var([1, 3]) = 1.0, var([2, 4]) = 1.0
268    /// assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
269    /// assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
270    /// ```
271    ///
272    /// ```
273    /// use train_station::Tensor;
274    ///
275    /// // Variance over multiple dimensions
276    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
277    /// let var_all = tensor.var_dims(&[0, 1], false);
278    /// assert_eq!(var_all.shape().dims, vec![1]);
279    /// // var([1, 2, 3, 4]) = 1.25
280    /// assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
281    /// ```
282    ///
283    /// # Panics
284    ///
285    /// * If `dims` is empty
286    /// * If any dimension index is out of bounds for the tensor rank
287    /// * If the reduced size is 0 (invalid for variance calculation)
288    ///
289    /// # Performance
290    ///
291    /// Uses efficient coordinate-based iteration that works correctly with
292    /// both contiguous and non-contiguous tensor layouts. The algorithm performs
293    /// two passes: first to compute means, then to compute variances.
294    pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
295        assert!(!dims.is_empty(), "var_dims requires at least one dimension");
296        let rank = self.shape().rank();
297        for &d in dims {
298            assert!(
299                d < rank,
300                "var_dims dim {} out of bounds for rank {}",
301                d,
302                rank
303            );
304        }
305
306        // Output shape
307        let mut out_dims = self.shape().dims.clone();
308        let mut reduced: Vec<usize> = dims.to_vec();
309        reduced.sort_unstable();
310        reduced.dedup();
311        for &d in reduced.iter() {
312            out_dims[d] = if keepdim { 1 } else { 0 };
313        }
314        if !keepdim {
315            out_dims.retain(|&s| s != 0);
316        }
317        if out_dims.is_empty() {
318            out_dims.push(1);
319        }
320
321        let mut mean = Tensor::zeros(out_dims.clone());
322        let mut var = Tensor::zeros(out_dims.clone());
323
324        let in_shape = self.shape().dims.clone();
325        let out_rank = mean.shape().rank();
326        let mut in_coords = vec![0usize; rank];
327        let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
328        assert!(n_reduced > 0, "reduced size must be > 0");
329        unsafe {
330            let mptr = mean.as_mut_ptr();
331            // sum for mean
332            for lin in 0..self.size() {
333                let mut tmp = lin;
334                for i in (0..rank).rev() {
335                    let s = in_shape[i];
336                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
337                    if s != 0 {
338                        tmp /= s;
339                    }
340                }
341                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
342                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
343                    if reduced.contains(&i) {
344                        if keepdim {
345                            out_coords.push(0);
346                        }
347                    } else {
348                        out_coords.push(ic);
349                    }
350                }
351                let off = if out_coords.is_empty() {
352                    0
353                } else {
354                    mean.shape().offset(&out_coords)
355                };
356                // Get input value using stride-aware offset
357                let in_offset = self.shape().offset(&in_coords);
358                let value = *self.as_ptr().add(in_offset);
359                *mptr.add(off) += value;
360            }
361            for i in 0..mean.size() {
362                *mptr.add(i) /= n_reduced as f32;
363            }
364            // accumulate squared diffs
365            let vptr = var.as_mut_ptr();
366            for lin in 0..self.size() {
367                let mut tmp = lin;
368                for i in (0..rank).rev() {
369                    let s = in_shape[i];
370                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
371                    if s != 0 {
372                        tmp /= s;
373                    }
374                }
375                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
376                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
377                    if reduced.contains(&i) {
378                        if keepdim {
379                            out_coords.push(0);
380                        }
381                    } else {
382                        out_coords.push(ic);
383                    }
384                }
385                let off = if out_coords.is_empty() {
386                    0
387                } else {
388                    var.shape().offset(&out_coords)
389                };
390                let mu = *mptr.add(off);
391
392                // Get input value using stride-aware offset
393                let in_offset = self.shape().offset(&in_coords);
394                let x = *self.as_ptr().add(in_offset);
395                *vptr.add(off) += (x - mu) * (x - mu);
396            }
397            for i in 0..var.size() {
398                *vptr.add(i) /= n_reduced as f32;
399            }
400        }
401
402        if self.requires_grad() {
403            let mut result = var.clone();
404            result.set_requires_grad_internal(true);
405            let grad_fn = GradFn::ReduceVarDims {
406                dims: reduced,
407                keepdim,
408                input_shape: self.shape().dims.clone(),
409                saved_mean: Box::new(mean),
410                saved_input: Box::new(self.clone()),
411            };
412            result.set_grad_fn(grad_fn.clone());
413            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
414            return result;
415        }
416
417        var
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_var_forward_basic() {
        // Population variance of [1, 2, 3, 4]: mean = 2.5, var = 1.25.
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = input.var();
        unsafe {
            assert!((*result.as_ptr() - 1.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_var_dims_forward() {
        // Per-row variance with kept dimension: var([1,3]) = 1, var([2,2]) = 0.
        let input = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let result = input.var_dims(&[1], true);
        assert_eq!(result.shape().dims, vec![2, 1]);
        for &(row, want) in [(0usize, 1.0f32), (1usize, 0.0f32)].iter() {
            assert!((result.get(&[row, 0]) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_var_non_contiguous_transpose() {
        // A transposed view is non-contiguous but must yield the same
        // variance as the original layout, since the element set is identical.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let from_base = base.var().get(&[0]);
        let from_view = view.var().get(&[0]);
        assert!((from_base - from_view).abs() < 1e-6);

        // mean([1..6]) = 3.5; var = 17.5 / 6 ≈ 2.9166667.
        assert!((from_base - 2.9166667_f32).abs() < 1e-5);
    }

    #[test]
    fn test_var_dims_non_contiguous() {
        // var_dims along dim 0 of the transposed [3, 2] view reduces the
        // original rows [1,2,3] and [4,5,6]; each has variance 2/3.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let reduced = view.var_dims(&[0], false);
        assert_eq!(reduced.shape().dims, vec![2]);
        let want = 2.0 / 3.0_f32;
        for col in 0..2usize {
            assert!((reduced.get(&[col]) - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_var_permuted_tensor() {
        // Permuting axes reorders elements but leaves the value set
        // unchanged, so the all-element variance must be identical.
        let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let base = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
        let view = base.permute(vec![2, 1, 0]);
        assert!(!view.is_contiguous());

        let from_base = base.var().get(&[0]);
        let from_view = view.var().get(&[0]);
        assert!((from_base - from_view).abs() < 1e-6);

        // mean = 4.5; var = mean of squared deviations = 5.25.
        assert!((from_base - 5.25_f32).abs() < 1e-5);
    }
}