train_station/tensor/reductions/var.rs

1//! Variance reduction operations for tensors
2//!
3//! This module provides variance reduction operations for tensors.
4//! The variance measures the average squared deviation from the mean,
5//! calculated as the mean of squared differences from the mean. This is
6//! commonly used in statistics, data analysis, and machine learning for
7//! understanding data variability and as a component of other statistical
8//! measures like standard deviation.
9//!
10//! # Operations
11//!
12//! * `var()` - Computes variance over all elements, returning a scalar tensor
13//! * `var_dims()` - Computes variance over specified dimensions with optional dimension preservation
14//!
15//! # Statistical Details
16//!
17//! The implementation uses population variance (unbiased=false), which
18//! divides by n rather than n-1. This matches PyTorch's default behavior for
19//! consistency with the reference implementation.
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Compute variance of all elements
27//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
28//! let variance = tensor.var();
29//! assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
30//!
31//! // Compute variance along specific dimensions
32//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
33//! let row_vars = matrix.var_dims(&[1], true);
34//! assert_eq!(row_vars.shape().dims(), vec![2, 1]);
35//! ```
36//!
37//! # Performance
38//!
39//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
40//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
41//! correctness while preserving memory layout efficiency.
42//!
43//! # Gradient Tracking
44//!
45//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
46//! The gradient computation follows the mathematical derivative of the variance operation.
47
48use crate::gradtrack::{GradEngine, GradFn};
49use crate::tensor::core::Tensor;
50
51impl Tensor {
52    /// Computes the variance over all elements
53    ///
54    /// The variance is calculated as the mean of squared differences from the mean.
55    /// This operation reduces the tensor to a scalar value \[1\].
56    ///
57    /// The implementation uses population variance (divides by n rather
58    /// than n-1) to match PyTorch's default behavior.
59    ///
60    /// # Returns
61    ///
62    /// A scalar tensor containing the variance value
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// // Basic variance calculation
70    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71    /// let variance = tensor.var();
72    /// assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
73    /// ```
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// // Variance of a larger dataset
79    /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81    /// let variance = tensor.var();
82    /// // mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
83    /// assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
84    /// ```
85    ///
86    /// ```
87    /// use train_station::Tensor;
88    ///
89    /// // Variance of constant values (should be 0)
90    /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
91    /// let variance = tensor.var();
92    /// assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
93    /// ```
94    ///
95    /// # Performance
96    ///
97    /// Uses optimized contiguous tensor path with manual loop unrolling for better
98    /// performance. Non-contiguous tensors use stride-aware iteration.
99    /// The algorithm performs two passes: first to compute the mean, then to
100    /// compute the variance.
101    #[track_caller]
102    pub fn var(&self) -> Tensor {
103        let mut out = Tensor::new(vec![1]);
104        if self.size() == 0 {
105            out.fill(0.0);
106        } else {
107            // mean
108            let mut mean_val = 0.0f32;
109            let n = self.size() as f32;
110
111            if self.is_contiguous() {
112                // Fast path for contiguous tensors
113                unsafe {
114                    let src = self.as_ptr();
115                    for i in 0..self.size() {
116                        mean_val += *src.add(i);
117                    }
118                }
119            } else {
120                // Stride-aware path for non-contiguous tensors
121                let dims = self.shape().dims().to_vec();
122                for flat_idx in 0..self.size() {
123                    // Convert flat index to multi-dimensional coordinates
124                    let mut coords = vec![0; dims.len()];
125                    let mut tmp = flat_idx;
126                    for k in (0..dims.len()).rev() {
127                        coords[k] = tmp % dims[k];
128                        tmp /= dims[k];
129                    }
130
131                    // Get value using stride-aware offset
132                    let offset = self.shape().offset(&coords);
133                    let value = unsafe { *self.as_ptr().add(offset) };
134                    mean_val += value;
135                }
136            }
137            mean_val /= n;
138
139            // var
140            let mut var_val = 0.0f32;
141
142            if self.is_contiguous() {
143                // Fast path for contiguous tensors
144                unsafe {
145                    let src = self.as_ptr();
146                    for i in 0..self.size() {
147                        let d = *src.add(i) - mean_val;
148                        var_val += d * d;
149                    }
150                }
151            } else {
152                // Stride-aware path for non-contiguous tensors
153                let dims = self.shape().dims().to_vec();
154                for flat_idx in 0..self.size() {
155                    // Convert flat index to multi-dimensional coordinates
156                    let mut coords = vec![0; dims.len()];
157                    let mut tmp = flat_idx;
158                    for k in (0..dims.len()).rev() {
159                        coords[k] = tmp % dims[k];
160                        tmp /= dims[k];
161                    }
162
163                    // Get value using stride-aware offset
164                    let offset = self.shape().offset(&coords);
165                    let value = unsafe { *self.as_ptr().add(offset) };
166                    let d = value - mean_val;
167                    var_val += d * d;
168                }
169            }
170            var_val /= n;
171
172            unsafe {
173                *out.as_mut_ptr() = var_val;
174            }
175        }
176
177        if self.requires_grad() {
178            let mut result = out.clone();
179            result.set_requires_grad_internal(true);
180            let mean_tensor = {
181                let mut t = Tensor::new(vec![1]);
182                if self.size() == 0 {
183                    t.fill(0.0);
184                } else {
185                    let mut acc = 0.0f32;
186
187                    if self.is_contiguous() {
188                        unsafe {
189                            for i in 0..self.size() {
190                                acc += *self.as_ptr().add(i);
191                            }
192                        }
193                    } else {
194                        let dims = self.shape().dims().to_vec();
195                        for flat_idx in 0..self.size() {
196                            // Convert flat index to multi-dimensional coordinates
197                            let mut coords = vec![0; dims.len()];
198                            let mut tmp = flat_idx;
199                            for k in (0..dims.len()).rev() {
200                                coords[k] = tmp % dims[k];
201                                tmp /= dims[k];
202                            }
203
204                            // Get value using stride-aware offset
205                            let offset = self.shape().offset(&coords);
206                            let value = unsafe { *self.as_ptr().add(offset) };
207                            acc += value;
208                        }
209                    }
210
211                    unsafe {
212                        *t.as_mut_ptr() = acc / (self.size() as f32);
213                    }
214                }
215                t
216            };
217            let grad_fn = GradFn::ReduceVar {
218                saved_mean: Box::new(mean_tensor),
219                saved_input: Box::new(self.clone()),
220                input_shape: self.shape().dims().to_vec(),
221            };
222            result.set_grad_fn(grad_fn.clone());
223            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
224            return result;
225        }
226
227        out
228    }
229
230    /// Computes the variance over specified dimensions
231    ///
232    /// Reduces the tensor along the specified dimensions by computing the variance
233    /// of each slice. The result maintains the original tensor structure with
234    /// reduced dimensions optionally preserved as size-1 dimensions.
235    ///
236    /// Uses population variance (divides by n rather than n-1) to match
237    /// PyTorch's default behavior.
238    ///
239    /// # Arguments
240    ///
241    /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
242    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
243    ///
244    /// # Returns
245    ///
246    /// A tensor with variance computed over the specified dimensions
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use train_station::Tensor;
252    ///
253    /// // Variance along rows (dimension 1) with keepdim=true
254    /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
255    /// let row_vars = matrix.var_dims(&[1], true);
256    /// assert_eq!(row_vars.shape().dims(), vec![2, 1]);
257    /// assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
258    /// assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
259    /// ```
260    ///
261    /// ```
262    /// use train_station::Tensor;
263    ///
264    /// // Variance along columns (dimension 0) with keepdim=false
265    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
266    /// let col_vars = matrix.var_dims(&[0], false);
267    /// assert_eq!(col_vars.shape().dims(), vec![2]);
268    /// // var([1, 3]) = 1.0, var([2, 4]) = 1.0
269    /// assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
270    /// assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
271    /// ```
272    ///
273    /// ```
274    /// use train_station::Tensor;
275    ///
276    /// // Variance over multiple dimensions
277    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
278    /// let var_all = tensor.var_dims(&[0, 1], false);
279    /// assert_eq!(var_all.shape().dims(), vec![1]);
280    /// // var([1, 2, 3, 4]) = 1.25
281    /// assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
282    /// ```
283    ///
284    /// # Panics
285    ///
286    /// * If `dims` is empty
287    /// * If any dimension index is out of bounds for the tensor rank
288    /// * If the reduced size is 0 (invalid for variance calculation)
289    ///
290    /// # Performance
291    ///
292    /// Uses efficient coordinate-based iteration that works correctly with
293    /// both contiguous and non-contiguous tensor layouts. The algorithm performs
294    /// two passes: first to compute means, then to compute variances.
295    #[track_caller]
296    pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
297        assert!(!dims.is_empty(), "var_dims requires at least one dimension");
298        let rank = self.shape().rank();
299        for &d in dims {
300            assert!(
301                d < rank,
302                "var_dims dim {} out of bounds for rank {}",
303                d,
304                rank
305            );
306        }
307
308        // Output shape
309        let mut out_dims = self.shape().dims().to_vec();
310        let mut reduced: Vec<usize> = dims.to_vec();
311        reduced.sort_unstable();
312        reduced.dedup();
313        for &d in reduced.iter() {
314            out_dims[d] = if keepdim { 1 } else { 0 };
315        }
316        if !keepdim {
317            out_dims.retain(|&s| s != 0);
318        }
319        if out_dims.is_empty() {
320            out_dims.push(1);
321        }
322
323        let mut mean = Tensor::zeros(out_dims.clone());
324        let mut var = Tensor::zeros(out_dims.clone());
325
326        let in_shape = self.shape().dims().to_vec();
327        let out_rank = mean.shape().rank();
328        let mut in_coords = vec![0usize; rank];
329        let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
330        assert!(n_reduced > 0, "reduced size must be > 0");
331        unsafe {
332            let mptr = mean.as_mut_ptr();
333            // sum for mean
334            for lin in 0..self.size() {
335                let mut tmp = lin;
336                for i in (0..rank).rev() {
337                    let s = in_shape[i];
338                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
339                    if s != 0 {
340                        tmp /= s;
341                    }
342                }
343                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
344                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
345                    if reduced.contains(&i) {
346                        if keepdim {
347                            out_coords.push(0);
348                        }
349                    } else {
350                        out_coords.push(ic);
351                    }
352                }
353                let off = if out_coords.is_empty() {
354                    0
355                } else {
356                    mean.shape().offset(&out_coords)
357                };
358                // Get input value using stride-aware offset
359                let in_offset = self.shape().offset(&in_coords);
360                let value = *self.as_ptr().add(in_offset);
361                *mptr.add(off) += value;
362            }
363            for i in 0..mean.size() {
364                *mptr.add(i) /= n_reduced as f32;
365            }
366            // accumulate squared diffs
367            let vptr = var.as_mut_ptr();
368            for lin in 0..self.size() {
369                let mut tmp = lin;
370                for i in (0..rank).rev() {
371                    let s = in_shape[i];
372                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
373                    if s != 0 {
374                        tmp /= s;
375                    }
376                }
377                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
378                for (i, &ic) in in_coords.iter().enumerate().take(rank) {
379                    if reduced.contains(&i) {
380                        if keepdim {
381                            out_coords.push(0);
382                        }
383                    } else {
384                        out_coords.push(ic);
385                    }
386                }
387                let off = if out_coords.is_empty() {
388                    0
389                } else {
390                    var.shape().offset(&out_coords)
391                };
392                let mu = *mptr.add(off);
393
394                // Get input value using stride-aware offset
395                let in_offset = self.shape().offset(&in_coords);
396                let x = *self.as_ptr().add(in_offset);
397                *vptr.add(off) += (x - mu) * (x - mu);
398            }
399            for i in 0..var.size() {
400                *vptr.add(i) /= n_reduced as f32;
401            }
402        }
403
404        if self.requires_grad() {
405            let mut result = var.clone();
406            result.set_requires_grad_internal(true);
407            let grad_fn = GradFn::ReduceVarDims {
408                dims: reduced,
409                keepdim,
410                input_shape: self.shape().dims().to_vec(),
411                saved_mean: Box::new(mean),
412                saved_input: Box::new(self.clone()),
413            };
414            result.set_grad_fn(grad_fn.clone());
415            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
416            return result;
417        }
418
419        var
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_var_forward_basic() {
        // var([1, 2, 3, 4]): mean = 2.5, population variance = 1.25
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let v = x.var();
        // Use the safe accessor (as the other tests do) instead of an
        // unnecessary unsafe raw-pointer read.
        assert!((v.get(&[0]) - 1.25).abs() < 1e-6);
    }

    #[test]
    fn test_var_dims_forward() {
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let v = x.var_dims(&[1], true);
        assert_eq!(v.shape().dims(), vec![2, 1]);
        assert!((v.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
        assert!((v.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
    }

    #[test]
    fn test_var_non_contiguous_transpose() {
        // Test var on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let var_orig = x.var();
        let var_view = x_t.var();

        // Variance is layout-invariant: both must agree.
        assert!((var_orig.get(&[0]) - var_view.get(&[0])).abs() < 1e-6);

        // Expected var of [1,2,3,4,5,6]: mean=3.5, var=mean([2.5^2,1.5^2,0.5^2,0.5^2,1.5^2,2.5^2])=2.9167
        let expected_var = 2.9166667_f32;
        assert!((var_orig.get(&[0]) - expected_var).abs() < 1e-5);
    }

    #[test]
    fn test_var_dims_non_contiguous() {
        // Test var_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Var along dim 0 of transposed tensor
        let var_dim0 = x_t.var_dims(&[0], false);
        assert_eq!(var_dim0.shape().dims(), vec![2]);

        // For dim 0: [1,2,3] and [4,5,6]
        // [1,2,3]: mean=2, var=((1-2)^2 + (2-2)^2 + (3-2)^2)/3 = 2/3 ≈ 0.6667
        // [4,5,6]: mean=5, var=((4-5)^2 + (5-5)^2 + (6-5)^2)/3 = 2/3 ≈ 0.6667
        let expected_var = 2.0 / 3.0_f32;
        assert!((var_dim0.get(&[0]) - expected_var).abs() < 1e-5);
        assert!((var_dim0.get(&[1]) - expected_var).abs() < 1e-5);
    }

    #[test]
    fn test_var_permuted_tensor() {
        // Test with permuted tensor - simple case with known var
        let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Permute dimensions [2, 2, 2] -> [2, 2, 2] (swap first and last)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let var_orig = x.var();
        let var_perm = x_perm.var();

        // Permutation only reorders elements, so the variance is unchanged.
        assert!((var_orig.get(&[0]) - var_perm.get(&[0])).abs() < 1e-6);

        // Data is [1,3,5,7,2,4,6,8], mean=4.5
        // var = mean([3.5^2, 1.5^2, 0.5^2, 2.5^2, 2.5^2, 0.5^2, 1.5^2, 3.5^2]) = 5.25
        let expected_var = 5.25_f32;
        assert!((var_orig.get(&[0]) - expected_var).abs() < 1e-5);
    }
}
507}