train_station/tensor/reductions/
std.rs

1//! Standard deviation reduction operations for tensors
2//!
3//! This module provides standard deviation reduction operations for tensors.
4//! The standard deviation measures the dispersion of values around the mean,
5//! calculated as the square root of the variance. This is commonly used in
6//! statistics, data analysis, and machine learning for understanding data
7//! variability and normalization.
8//!
9//! # Operations
10//!
11//! * `std()` - Computes standard deviation over all elements, returning a scalar tensor
12//! * `std_dims()` - Computes standard deviation over specified dimensions with optional dimension preservation
13//!
14//! # Statistical Details
15//!
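//! The implementation uses the population standard deviation, which divides by
//! `n` rather than `n - 1` (no Bessel correction). This is equivalent to
//! PyTorch's `torch.std(x, unbiased=False)` (`correction=0`). Note that PyTorch's
//! *default* is the unbiased estimator (divisor `n - 1`), so results differ from
//! a plain `torch.std(x)` call.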
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Compute standard deviation of all elements
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
27//! let std_dev = tensor.std();
28//! assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
29//!
30//! // Compute standard deviation along specific dimensions
31//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
32//! let row_stds = matrix.std_dims(&[1], true);
33//! assert_eq!(row_stds.shape().dims, vec![2, 1]);
34//! ```
35//!
36//! # Performance
37//!
//! `std()` reads contiguous tensors as a flat buffer with 4-way unrolled
//! accumulation. Non-contiguous views (e.g. transposes and permutations) are read
//! in place through their strides, so no contiguous copy is materialised.
//! `std_dims()` always uses the stride-aware path.
41//!
42//! # Gradient Tracking
43//!
44//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
45//! The gradient computation follows the mathematical derivative of the standard deviation operation.
46
47use crate::gradtrack::{GradEngine, GradFn};
48use crate::tensor::core::Tensor;
49
50impl Tensor {
51    /// Computes the standard deviation over all elements
52    ///
53    /// The standard deviation is calculated as sqrt(variance) where variance is
54    /// the mean of squared differences from the mean. This operation reduces the
55    /// tensor to a scalar value \[1\].
56    ///
57    /// The implementation uses population standard deviation (divides by n rather
58    /// than n-1) to match PyTorch's default behavior.
59    ///
60    /// # Returns
61    ///
62    /// A scalar tensor containing the standard deviation value
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// // Basic standard deviation calculation
70    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71    /// let std_dev = tensor.std();
72    /// assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
73    /// ```
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// // Standard deviation of a larger dataset
79    /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81    /// let std_dev = tensor.std();
82    /// // mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
83    /// let expected = 5.25_f32.sqrt();
84    /// assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
85    /// ```
86    ///
87    /// ```
88    /// use train_station::Tensor;
89    ///
90    /// // Standard deviation of constant values (should be 0)
91    /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
92    /// let std_dev = tensor.std();
93    /// assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
94    /// ```
95    ///
96    /// # Performance
97    ///
98    /// Uses optimized contiguous tensor path with 4x loop unrolling for better
99    /// performance. Non-contiguous tensors use stride-aware iteration.
100    /// The algorithm performs two passes: first to compute the mean, then to
101    /// compute the variance.
102    pub fn std(&self) -> Tensor {
103        let mut out = Tensor::new(vec![1]);
104        if self.size() == 0 {
105            out.fill(0.0);
106        } else {
107            // First pass: mean
108            let mut mean_val = 0.0f32;
109            let n = self.size() as f32;
110
111            if self.is_contiguous() {
112                // Fast path for contiguous tensors
113                unsafe {
114                    let src = self.as_ptr();
115                    let mut i = 0usize;
116                    while i + 4 <= self.size() {
117                        mean_val +=
118                            *src.add(i) + *src.add(i + 1) + *src.add(i + 2) + *src.add(i + 3);
119                        i += 4;
120                    }
121                    while i < self.size() {
122                        mean_val += *src.add(i);
123                        i += 1;
124                    }
125                }
126            } else {
127                // Stride-aware path for non-contiguous tensors
128                let dims = self.shape().dims.clone();
129                for flat_idx in 0..self.size() {
130                    // Convert flat index to multi-dimensional coordinates
131                    let mut coords = vec![0; dims.len()];
132                    let mut tmp = flat_idx;
133                    for k in (0..dims.len()).rev() {
134                        coords[k] = tmp % dims[k];
135                        tmp /= dims[k];
136                    }
137
138                    // Get value using stride-aware offset
139                    let offset = self.shape().offset(&coords);
140                    let value = unsafe { *self.as_ptr().add(offset) };
141                    mean_val += value;
142                }
143            }
144            mean_val /= n;
145
146            // Second pass: variance
147            let mut var_val = 0.0f32;
148
149            if self.is_contiguous() {
150                // Fast path for contiguous tensors
151                unsafe {
152                    let src = self.as_ptr();
153                    let mut i = 0usize;
154                    while i + 4 <= self.size() {
155                        let x0 = *src.add(i) - mean_val;
156                        let x1 = *src.add(i + 1) - mean_val;
157                        let x2 = *src.add(i + 2) - mean_val;
158                        let x3 = *src.add(i + 3) - mean_val;
159                        var_val += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
160                        i += 4;
161                    }
162                    while i < self.size() {
163                        let d = *src.add(i) - mean_val;
164                        var_val += d * d;
165                        i += 1;
166                    }
167                }
168            } else {
169                // Stride-aware path for non-contiguous tensors
170                let dims = self.shape().dims.clone();
171                for flat_idx in 0..self.size() {
172                    // Convert flat index to multi-dimensional coordinates
173                    let mut coords = vec![0; dims.len()];
174                    let mut tmp = flat_idx;
175                    for k in (0..dims.len()).rev() {
176                        coords[k] = tmp % dims[k];
177                        tmp /= dims[k];
178                    }
179
180                    // Get value using stride-aware offset
181                    let offset = self.shape().offset(&coords);
182                    let value = unsafe { *self.as_ptr().add(offset) };
183                    let d = value - mean_val;
184                    var_val += d * d;
185                }
186            }
187            var_val /= n;
188
189            let std_val = var_val.sqrt();
190            unsafe {
191                *out.as_mut_ptr() = std_val;
192            }
193        }
194
195        if self.requires_grad() {
196            // Save tensors for backward
197            let mut result = out.clone();
198            result.set_requires_grad_internal(true);
199            let mean_tensor = {
200                let mut t = Tensor::new(vec![1]);
201                if self.size() == 0 {
202                    t.fill(0.0);
203                } else {
204                    let mut acc = 0.0f32;
205
206                    if self.is_contiguous() {
207                        unsafe {
208                            for i in 0..self.size() {
209                                acc += *self.as_ptr().add(i);
210                            }
211                        }
212                    } else {
213                        let dims = self.shape().dims.clone();
214                        for flat_idx in 0..self.size() {
215                            // Convert flat index to multi-dimensional coordinates
216                            let mut coords = vec![0; dims.len()];
217                            let mut tmp = flat_idx;
218                            for k in (0..dims.len()).rev() {
219                                coords[k] = tmp % dims[k];
220                                tmp /= dims[k];
221                            }
222
223                            // Get value using stride-aware offset
224                            let offset = self.shape().offset(&coords);
225                            let value = unsafe { *self.as_ptr().add(offset) };
226                            acc += value;
227                        }
228                    }
229
230                    unsafe {
231                        *t.as_mut_ptr() = acc / (self.size() as f32);
232                    }
233                }
234                t
235            };
236            let std_saved = out.clone();
237            let grad_fn = GradFn::ReduceStd {
238                saved_mean: Box::new(mean_tensor),
239                saved_std: Box::new(std_saved),
240                saved_input: Box::new(self.clone()),
241                input_shape: self.shape().dims.clone(),
242            };
243            result.set_grad_fn(grad_fn.clone());
244            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
245            return result;
246        }
247
248        out
249    }
250
251    /// Computes the standard deviation over specified dimensions
252    ///
253    /// Reduces the tensor along the specified dimensions by computing the standard
254    /// deviation of each slice. The result maintains the original tensor structure
255    /// with reduced dimensions optionally preserved as size-1 dimensions.
256    ///
257    /// Uses population standard deviation (divides by n rather than n-1) to match
258    /// PyTorch's default behavior.
259    ///
260    /// # Arguments
261    ///
262    /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
263    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
264    ///
265    /// # Returns
266    ///
267    /// A tensor with standard deviation computed over the specified dimensions
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use train_station::Tensor;
273    ///
274    /// // Standard deviation along rows (dimension 1) with keepdim=true
275    /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
276    /// let row_stds = matrix.std_dims(&[1], true);
277    /// assert_eq!(row_stds.shape().dims, vec![2, 1]);
278    /// assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
279    /// assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
280    /// ```
281    ///
282    /// ```
283    /// use train_station::Tensor;
284    ///
285    /// // Standard deviation along columns (dimension 0) with keepdim=false
286    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
287    /// let col_stds = matrix.std_dims(&[0], false);
288    /// assert_eq!(col_stds.shape().dims, vec![2]);
289    /// // std([1, 3]) = 1.0, std([2, 4]) = 1.0
290    /// assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
291    /// assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
292    /// ```
293    ///
294    /// ```
295    /// use train_station::Tensor;
296    ///
297    /// // Standard deviation over multiple dimensions
298    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
299    /// let std_all = tensor.std_dims(&[0, 1], false);
300    /// assert_eq!(std_all.shape().dims, vec![1]);
301    /// // std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
302    /// assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
303    /// ```
304    ///
305    /// # Panics
306    ///
307    /// * If `dims` is empty
308    /// * If any dimension index is out of bounds for the tensor rank
309    /// * If the reduced size is 0 (invalid for standard deviation calculation)
310    ///
311    /// # Performance
312    ///
313    /// Uses efficient coordinate-based iteration that works correctly with
314    /// both contiguous and non-contiguous tensor layouts. The algorithm performs
315    /// two passes: first to compute means, then to compute variances.
316    pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
317        assert!(!dims.is_empty(), "std_dims requires at least one dimension");
318        let rank = self.shape().rank();
319        for &d in dims {
320            assert!(
321                d < rank,
322                "std_dims dim {} out of bounds for rank {}",
323                d,
324                rank
325            );
326        }
327
328        // Output shape
329        let mut out_dims = self.shape().dims.clone();
330        let mut reduced: Vec<usize> = dims.to_vec();
331        reduced.sort_unstable();
332        reduced.dedup();
333        for &d in reduced.iter() {
334            out_dims[d] = if keepdim { 1 } else { 0 };
335        }
336        if !keepdim {
337            out_dims.retain(|&s| s != 0);
338        }
339        if out_dims.is_empty() {
340            out_dims.push(1);
341        }
342        let mut mean = Tensor::zeros(out_dims.clone());
343        let mut var = Tensor::zeros(out_dims.clone());
344
345        let in_shape = self.shape().dims.clone();
346        let out_rank = mean.shape().rank();
347        let mut in_coords = vec![0usize; rank];
348        let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
349        assert!(n_reduced > 0, "reduced size must be > 0");
350        unsafe {
351            let mptr = mean.as_mut_ptr();
352            let vptr = var.as_mut_ptr();
353
354            // First pass: sum to compute means
355            for lin in 0..self.size() {
356                let mut tmp = lin;
357                for i in (0..rank).rev() {
358                    let s = in_shape[i];
359                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
360                    if s != 0 {
361                        tmp /= s;
362                    }
363                }
364
365                // Get input value using stride-aware offset
366                let in_offset = self.shape().offset(&in_coords);
367                let value = *self.as_ptr().add(in_offset);
368
369                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
370                for (i, &c) in in_coords.iter().enumerate().take(rank) {
371                    if reduced.contains(&i) {
372                        if keepdim {
373                            out_coords.push(0);
374                        }
375                    } else {
376                        out_coords.push(c);
377                    }
378                }
379                let off = if out_coords.is_empty() {
380                    0
381                } else {
382                    mean.shape().offset(&out_coords)
383                };
384                *mptr.add(off) += value;
385            }
386            let msize = mean.size();
387            for i in 0..msize {
388                *mptr.add(i) /= n_reduced as f32;
389            }
390
391            // Second pass: accumulate squared diffs
392            for lin in 0..self.size() {
393                let mut tmp = lin;
394                for i in (0..rank).rev() {
395                    let s = in_shape[i];
396                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
397                    if s != 0 {
398                        tmp /= s;
399                    }
400                }
401
402                // Get input value using stride-aware offset
403                let in_offset = self.shape().offset(&in_coords);
404                let x = *self.as_ptr().add(in_offset);
405
406                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
407                for (i, &c) in in_coords.iter().enumerate().take(rank) {
408                    if reduced.contains(&i) {
409                        if keepdim {
410                            out_coords.push(0);
411                        }
412                    } else {
413                        out_coords.push(c);
414                    }
415                }
416                let off = if out_coords.is_empty() {
417                    0
418                } else {
419                    var.shape().offset(&out_coords)
420                };
421                let diff = x - *mptr.add(off);
422                *vptr.add(off) += diff * diff;
423            }
424            let vsize = var.size();
425            for i in 0..vsize {
426                *vptr.add(i) /= n_reduced as f32;
427            }
428        }
429        // std = sqrt(var)
430        let mut out = Tensor::zeros(out_dims.clone());
431        unsafe {
432            let d = out.as_mut_ptr();
433            let v = var.as_ptr();
434            for i in 0..out.size() {
435                *d.add(i) = (*v.add(i)).sqrt();
436            }
437        }
438
439        if self.requires_grad() {
440            let mut result = out.clone();
441            result.set_requires_grad_internal(true);
442            let grad_fn = GradFn::ReduceStdDims {
443                dims: reduced,
444                keepdim,
445                input_shape: self.shape().dims.clone(),
446                saved_mean: Box::new(mean),
447                saved_std: Box::new(out.clone()),
448                saved_input: Box::new(self.clone()),
449            };
450            result.set_grad_fn(grad_fn.clone());
451            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
452            return result;
453        }
454
455        out
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn test_std_forward_basic() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        // Population std of [1, 2, 3, 4] is sqrt(1.25).
        assert_close(input.std().get(&[0]), 1.118_034, 1e-5);
    }

    #[test]
    fn test_std_dims_forward() {
        let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let per_row = matrix.std_dims(&[1], true);

        assert_eq!(per_row.shape().dims, vec![2, 1]);
        assert_close(per_row.get(&[0, 0]), 1.0, 1e-6);
        assert_close(per_row.get(&[1, 0]), 0.0, 1e-6);
    }

    #[test]
    fn test_std_non_contiguous_transpose() {
        // Rows [[1, 2, 3], [4, 5, 6]]; the transpose is a strided view.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let from_base = base.std().get(&[0]);
        let from_view = view.std().get(&[0]);

        // Layout must not affect a full reduction.
        assert_close(from_view, from_base, 1e-6);

        // mean = 3.5, var = mean of squared deviations = 2.9166667.
        let expected = (2.9166667_f32).sqrt();
        assert_close(from_base, expected, 1e-5);
    }

    #[test]
    fn test_std_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        let reduced = view.std_dims(&[0], false);
        assert_eq!(reduced.shape().dims, vec![2]);

        // Columns of the view are [1, 2, 3] and [4, 5, 6]; each has var = 2/3.
        let expected = (2.0 / 3.0_f32).sqrt();
        for col in 0..2 {
            assert_close(reduced.get(&[col]), expected, 1e-5);
        }
    }

    #[test]
    fn test_std_permuted_tensor() {
        let values = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let base = Tensor::from_slice(&values, vec![2, 2, 2]).unwrap();

        // Reverse the axis order: same shape, different strides.
        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let from_base = base.std().get(&[0]);
        let from_permuted = permuted.std().get(&[0]);
        assert_close(from_permuted, from_base, 1e-6);

        // mean = 4.5, var = 5.25.
        assert_close(from_base, 5.25_f32.sqrt(), 1e-5);
    }
}