// train_station/tensor/reductions/std.rs

1//! Standard deviation reduction operations for tensors
2//!
3//! This module provides standard deviation reduction operations for tensors.
4//! The standard deviation measures the dispersion of values around the mean,
5//! calculated as the square root of the variance. This is commonly used in
6//! statistics, data analysis, and machine learning for understanding data
7//! variability and normalization.
8//!
9//! # Operations
10//!
11//! * `std()` - Computes standard deviation over all elements, returning a scalar tensor
12//! * `std_dims()` - Computes standard deviation over specified dimensions with optional dimension preservation
13//!
14//! # Statistical Details
15//!
16//! The implementation uses population standard deviation (unbiased=false), which
17//! divides by n rather than n-1. This matches PyTorch's default behavior for
18//! consistency with the reference implementation.
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Compute standard deviation of all elements
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
27//! let std_dev = tensor.std();
28//! assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
29//!
30//! // Compute standard deviation along specific dimensions
31//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
32//! let row_stds = matrix.std_dims(&[1], true);
33//! assert_eq!(row_stds.shape().dims, vec![2, 1]);
34//! ```
35//!
36//! # Performance
37//!
38//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
39//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
40//! correctness while preserving memory layout efficiency.
41//!
42//! # Gradient Tracking
43//!
44//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
45//! The gradient computation follows the mathematical derivative of the standard deviation operation.
46
47use crate::gradtrack::{GradEngine, GradFn};
48use crate::tensor::core::Tensor;
49
50impl Tensor {
51    /// Computes the standard deviation over all elements
52    ///
53    /// The standard deviation is calculated as sqrt(variance) where variance is
54    /// the mean of squared differences from the mean. This operation reduces the
55    /// tensor to a scalar value \[1\].
56    ///
57    /// The implementation uses population standard deviation (divides by n rather
58    /// than n-1) to match PyTorch's default behavior.
59    ///
60    /// # Returns
61    ///
62    /// A scalar tensor containing the standard deviation value
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use train_station::Tensor;
68    ///
69    /// // Basic standard deviation calculation
70    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71    /// let std_dev = tensor.std();
72    /// assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
73    /// ```
74    ///
75    /// ```
76    /// use train_station::Tensor;
77    ///
78    /// // Standard deviation of a larger dataset
79    /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81    /// let std_dev = tensor.std();
82    /// // mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
83    /// let expected = 5.25_f32.sqrt();
84    /// assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
85    /// ```
86    ///
87    /// ```
88    /// use train_station::Tensor;
89    ///
90    /// // Standard deviation of constant values (should be 0)
91    /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
92    /// let std_dev = tensor.std();
93    /// assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
94    /// ```
95    ///
96    /// # Performance
97    ///
98    /// Uses optimized contiguous tensor path with 4x loop unrolling for better
99    /// performance. Non-contiguous tensors use stride-aware iteration.
100    /// The algorithm performs two passes: first to compute the mean, then to
101    /// compute the variance.
102    #[track_caller]
103    pub fn std(&self) -> Tensor {
104        let mut out = Tensor::new(vec![1]);
105        if self.size() == 0 {
106            out.fill(0.0);
107        } else {
108            // First pass: mean
109            let mut mean_val = 0.0f32;
110            let n = self.size() as f32;
111
112            if self.is_contiguous() {
113                // Fast path for contiguous tensors
114                unsafe {
115                    let src = self.as_ptr();
116                    let mut i = 0usize;
117                    while i + 4 <= self.size() {
118                        mean_val +=
119                            *src.add(i) + *src.add(i + 1) + *src.add(i + 2) + *src.add(i + 3);
120                        i += 4;
121                    }
122                    while i < self.size() {
123                        mean_val += *src.add(i);
124                        i += 1;
125                    }
126                }
127            } else {
128                // Stride-aware path for non-contiguous tensors
129                let dims = self.shape().dims.clone();
130                for flat_idx in 0..self.size() {
131                    // Convert flat index to multi-dimensional coordinates
132                    let mut coords = vec![0; dims.len()];
133                    let mut tmp = flat_idx;
134                    for k in (0..dims.len()).rev() {
135                        coords[k] = tmp % dims[k];
136                        tmp /= dims[k];
137                    }
138
139                    // Get value using stride-aware offset
140                    let offset = self.shape().offset(&coords);
141                    let value = unsafe { *self.as_ptr().add(offset) };
142                    mean_val += value;
143                }
144            }
145            mean_val /= n;
146
147            // Second pass: variance
148            let mut var_val = 0.0f32;
149
150            if self.is_contiguous() {
151                // Fast path for contiguous tensors
152                unsafe {
153                    let src = self.as_ptr();
154                    let mut i = 0usize;
155                    while i + 4 <= self.size() {
156                        let x0 = *src.add(i) - mean_val;
157                        let x1 = *src.add(i + 1) - mean_val;
158                        let x2 = *src.add(i + 2) - mean_val;
159                        let x3 = *src.add(i + 3) - mean_val;
160                        var_val += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
161                        i += 4;
162                    }
163                    while i < self.size() {
164                        let d = *src.add(i) - mean_val;
165                        var_val += d * d;
166                        i += 1;
167                    }
168                }
169            } else {
170                // Stride-aware path for non-contiguous tensors
171                let dims = self.shape().dims.clone();
172                for flat_idx in 0..self.size() {
173                    // Convert flat index to multi-dimensional coordinates
174                    let mut coords = vec![0; dims.len()];
175                    let mut tmp = flat_idx;
176                    for k in (0..dims.len()).rev() {
177                        coords[k] = tmp % dims[k];
178                        tmp /= dims[k];
179                    }
180
181                    // Get value using stride-aware offset
182                    let offset = self.shape().offset(&coords);
183                    let value = unsafe { *self.as_ptr().add(offset) };
184                    let d = value - mean_val;
185                    var_val += d * d;
186                }
187            }
188            var_val /= n;
189
190            let std_val = var_val.sqrt();
191            unsafe {
192                *out.as_mut_ptr() = std_val;
193            }
194        }
195
196        if self.requires_grad() {
197            // Save tensors for backward
198            let mut result = out.clone();
199            result.set_requires_grad_internal(true);
200            let mean_tensor = {
201                let mut t = Tensor::new(vec![1]);
202                if self.size() == 0 {
203                    t.fill(0.0);
204                } else {
205                    let mut acc = 0.0f32;
206
207                    if self.is_contiguous() {
208                        unsafe {
209                            for i in 0..self.size() {
210                                acc += *self.as_ptr().add(i);
211                            }
212                        }
213                    } else {
214                        let dims = self.shape().dims.clone();
215                        for flat_idx in 0..self.size() {
216                            // Convert flat index to multi-dimensional coordinates
217                            let mut coords = vec![0; dims.len()];
218                            let mut tmp = flat_idx;
219                            for k in (0..dims.len()).rev() {
220                                coords[k] = tmp % dims[k];
221                                tmp /= dims[k];
222                            }
223
224                            // Get value using stride-aware offset
225                            let offset = self.shape().offset(&coords);
226                            let value = unsafe { *self.as_ptr().add(offset) };
227                            acc += value;
228                        }
229                    }
230
231                    unsafe {
232                        *t.as_mut_ptr() = acc / (self.size() as f32);
233                    }
234                }
235                t
236            };
237            let std_saved = out.clone();
238            let grad_fn = GradFn::ReduceStd {
239                saved_mean: Box::new(mean_tensor),
240                saved_std: Box::new(std_saved),
241                saved_input: Box::new(self.clone()),
242                input_shape: self.shape().dims.clone(),
243            };
244            result.set_grad_fn(grad_fn.clone());
245            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
246            return result;
247        }
248
249        out
250    }
251
252    /// Computes the standard deviation over specified dimensions
253    ///
254    /// Reduces the tensor along the specified dimensions by computing the standard
255    /// deviation of each slice. The result maintains the original tensor structure
256    /// with reduced dimensions optionally preserved as size-1 dimensions.
257    ///
258    /// Uses population standard deviation (divides by n rather than n-1) to match
259    /// PyTorch's default behavior.
260    ///
261    /// # Arguments
262    ///
263    /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
264    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
265    ///
266    /// # Returns
267    ///
268    /// A tensor with standard deviation computed over the specified dimensions
269    ///
270    /// # Examples
271    ///
272    /// ```
273    /// use train_station::Tensor;
274    ///
275    /// // Standard deviation along rows (dimension 1) with keepdim=true
276    /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
277    /// let row_stds = matrix.std_dims(&[1], true);
278    /// assert_eq!(row_stds.shape().dims, vec![2, 1]);
279    /// assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
280    /// assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
281    /// ```
282    ///
283    /// ```
284    /// use train_station::Tensor;
285    ///
286    /// // Standard deviation along columns (dimension 0) with keepdim=false
287    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
288    /// let col_stds = matrix.std_dims(&[0], false);
289    /// assert_eq!(col_stds.shape().dims, vec![2]);
290    /// // std([1, 3]) = 1.0, std([2, 4]) = 1.0
291    /// assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
292    /// assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
293    /// ```
294    ///
295    /// ```
296    /// use train_station::Tensor;
297    ///
298    /// // Standard deviation over multiple dimensions
299    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
300    /// let std_all = tensor.std_dims(&[0, 1], false);
301    /// assert_eq!(std_all.shape().dims, vec![1]);
302    /// // std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
303    /// assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
304    /// ```
305    ///
306    /// # Panics
307    ///
308    /// * If `dims` is empty
309    /// * If any dimension index is out of bounds for the tensor rank
310    /// * If the reduced size is 0 (invalid for standard deviation calculation)
311    ///
312    /// # Performance
313    ///
314    /// Uses efficient coordinate-based iteration that works correctly with
315    /// both contiguous and non-contiguous tensor layouts. The algorithm performs
316    /// two passes: first to compute means, then to compute variances.
317    #[track_caller]
318    pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
319        assert!(!dims.is_empty(), "std_dims requires at least one dimension");
320        let rank = self.shape().rank();
321        for &d in dims {
322            assert!(
323                d < rank,
324                "std_dims dim {} out of bounds for rank {}",
325                d,
326                rank
327            );
328        }
329
330        // Output shape
331        let mut out_dims = self.shape().dims.clone();
332        let mut reduced: Vec<usize> = dims.to_vec();
333        reduced.sort_unstable();
334        reduced.dedup();
335        for &d in reduced.iter() {
336            out_dims[d] = if keepdim { 1 } else { 0 };
337        }
338        if !keepdim {
339            out_dims.retain(|&s| s != 0);
340        }
341        if out_dims.is_empty() {
342            out_dims.push(1);
343        }
344        let mut mean = Tensor::zeros(out_dims.clone());
345        let mut var = Tensor::zeros(out_dims.clone());
346
347        let in_shape = self.shape().dims.clone();
348        let out_rank = mean.shape().rank();
349        let mut in_coords = vec![0usize; rank];
350        let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
351        assert!(n_reduced > 0, "reduced size must be > 0");
352        unsafe {
353            let mptr = mean.as_mut_ptr();
354            let vptr = var.as_mut_ptr();
355
356            // First pass: sum to compute means
357            for lin in 0..self.size() {
358                let mut tmp = lin;
359                for i in (0..rank).rev() {
360                    let s = in_shape[i];
361                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
362                    if s != 0 {
363                        tmp /= s;
364                    }
365                }
366
367                // Get input value using stride-aware offset
368                let in_offset = self.shape().offset(&in_coords);
369                let value = *self.as_ptr().add(in_offset);
370
371                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
372                for (i, &c) in in_coords.iter().enumerate().take(rank) {
373                    if reduced.contains(&i) {
374                        if keepdim {
375                            out_coords.push(0);
376                        }
377                    } else {
378                        out_coords.push(c);
379                    }
380                }
381                let off = if out_coords.is_empty() {
382                    0
383                } else {
384                    mean.shape().offset(&out_coords)
385                };
386                *mptr.add(off) += value;
387            }
388            let msize = mean.size();
389            for i in 0..msize {
390                *mptr.add(i) /= n_reduced as f32;
391            }
392
393            // Second pass: accumulate squared diffs
394            for lin in 0..self.size() {
395                let mut tmp = lin;
396                for i in (0..rank).rev() {
397                    let s = in_shape[i];
398                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
399                    if s != 0 {
400                        tmp /= s;
401                    }
402                }
403
404                // Get input value using stride-aware offset
405                let in_offset = self.shape().offset(&in_coords);
406                let x = *self.as_ptr().add(in_offset);
407
408                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
409                for (i, &c) in in_coords.iter().enumerate().take(rank) {
410                    if reduced.contains(&i) {
411                        if keepdim {
412                            out_coords.push(0);
413                        }
414                    } else {
415                        out_coords.push(c);
416                    }
417                }
418                let off = if out_coords.is_empty() {
419                    0
420                } else {
421                    var.shape().offset(&out_coords)
422                };
423                let diff = x - *mptr.add(off);
424                *vptr.add(off) += diff * diff;
425            }
426            let vsize = var.size();
427            for i in 0..vsize {
428                *vptr.add(i) /= n_reduced as f32;
429            }
430        }
431        // std = sqrt(var)
432        let mut out = Tensor::zeros(out_dims.clone());
433        unsafe {
434            let d = out.as_mut_ptr();
435            let v = var.as_ptr();
436            for i in 0..out.size() {
437                *d.add(i) = (*v.add(i)).sqrt();
438            }
439        }
440
441        if self.requires_grad() {
442            let mut result = out.clone();
443            result.set_requires_grad_internal(true);
444            let grad_fn = GradFn::ReduceStdDims {
445                dims: reduced,
446                keepdim,
447                input_shape: self.shape().dims.clone(),
448                saved_mean: Box::new(mean),
449                saved_std: Box::new(out.clone()),
450                saved_input: Box::new(self.clone()),
451            };
452            result.set_grad_fn(grad_fn.clone());
453            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
454            return result;
455        }
456
457        out
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_std_forward_basic() {
        // Population std of [1, 2, 3, 4]: mean 2.5, var 1.25, std sqrt(1.25).
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = input.std();
        let observed = unsafe { *result.as_ptr() };
        assert!((observed - 1.118_034).abs() < 1e-5);
    }

    #[test]
    fn test_std_dims_forward() {
        let input = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let per_row = input.std_dims(&[1], true);
        assert_eq!(per_row.shape().dims, vec![2, 1]);
        // Row [1, 3] has std 1.0; the constant row [2, 2] has std 0.0.
        assert!((per_row.get(&[0, 0]) - 1.0).abs() < 1e-6);
        assert!(per_row.get(&[1, 0]).abs() < 1e-6);
    }

    #[test]
    fn test_std_non_contiguous_transpose() {
        // A transposed view is non-contiguous; std must agree with the base
        // tensor, since it reduces over the same element set.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let from_base = base.std().get(&[0]);
        let from_view = view.std().get(&[0]);
        assert!((from_base - from_view).abs() < 1e-6);

        // [1..6]: mean 3.5, population variance 17.5/6 ≈ 2.9167.
        let expected = (2.9166667_f32).sqrt();
        assert!((from_base - expected).abs() < 1e-5);
    }

    #[test]
    fn test_std_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2], non-contiguous
        assert!(!view.is_contiguous());

        let reduced = view.std_dims(&[0], false);
        assert_eq!(reduced.shape().dims, vec![2]);

        // Columns of the view are [1, 2, 3] and [4, 5, 6]; each has
        // population variance 2/3, hence std sqrt(2/3) ≈ 0.816.
        let expected = (2.0 / 3.0_f32).sqrt();
        for col in 0..2 {
            assert!((reduced.get(&[col]) - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn test_std_permuted_tensor() {
        let values = [1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let cube = Tensor::from_slice(&values, vec![2, 2, 2]).unwrap();
        let permuted = cube.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        // A full reduction is permutation-invariant.
        let direct = cube.std().get(&[0]);
        let via_view = permuted.std().get(&[0]);
        assert!((direct - via_view).abs() < 1e-6);

        // mean 4.5, population variance 5.25, std sqrt(5.25) ≈ 2.291.
        assert!((direct - 5.25_f32.sqrt()).abs() < 1e-5);
    }
}
545}