train_station/tensor/reductions/std.rs
1//! Standard deviation reduction operations for tensors
2//!
3//! This module provides standard deviation reduction operations for tensors.
4//! The standard deviation measures the dispersion of values around the mean,
5//! calculated as the square root of the variance. This is commonly used in
6//! statistics, data analysis, and machine learning for understanding data
7//! variability and normalization.
8//!
9//! # Operations
10//!
11//! * `std()` - Computes standard deviation over all elements, returning a scalar tensor
12//! * `std_dims()` - Computes standard deviation over specified dimensions with optional dimension preservation
13//!
14//! # Statistical Details
15//!
//! The implementation uses population standard deviation (unbiased=false),
//! which divides by n rather than n-1. This corresponds to PyTorch's
//! `std(unbiased=False)`; note that PyTorch's *default* applies Bessel's
//! correction (divides by n-1).
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Compute standard deviation of all elements
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
27//! let std_dev = tensor.std();
28//! assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
29//!
30//! // Compute standard deviation along specific dimensions
31//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
32//! let row_stds = matrix.std_dims(&[1], true);
33//! assert_eq!(row_stds.shape().dims, vec![2, 1]);
34//! ```
35//!
36//! # Performance
37//!
38//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
39//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
40//! correctness while preserving memory layout efficiency.
41//!
42//! # Gradient Tracking
43//!
44//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
45//! The gradient computation follows the mathematical derivative of the standard deviation operation.
46
47use crate::gradtrack::{GradEngine, GradFn};
48use crate::tensor::core::Tensor;
49
50impl Tensor {
51 /// Computes the standard deviation over all elements
52 ///
53 /// The standard deviation is calculated as sqrt(variance) where variance is
54 /// the mean of squared differences from the mean. This operation reduces the
55 /// tensor to a scalar value \[1\].
56 ///
57 /// The implementation uses population standard deviation (divides by n rather
58 /// than n-1) to match PyTorch's default behavior.
59 ///
60 /// # Returns
61 ///
62 /// A scalar tensor containing the standard deviation value
63 ///
64 /// # Examples
65 ///
66 /// ```
67 /// use train_station::Tensor;
68 ///
69 /// // Basic standard deviation calculation
70 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71 /// let std_dev = tensor.std();
72 /// assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
73 /// ```
74 ///
75 /// ```
76 /// use train_station::Tensor;
77 ///
78 /// // Standard deviation of a larger dataset
79 /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81 /// let std_dev = tensor.std();
82 /// // mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
83 /// let expected = 5.25_f32.sqrt();
84 /// assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
85 /// ```
86 ///
87 /// ```
88 /// use train_station::Tensor;
89 ///
90 /// // Standard deviation of constant values (should be 0)
91 /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
92 /// let std_dev = tensor.std();
93 /// assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
94 /// ```
95 ///
96 /// # Performance
97 ///
98 /// Uses optimized contiguous tensor path with 4x loop unrolling for better
99 /// performance. Non-contiguous tensors use stride-aware iteration.
100 /// The algorithm performs two passes: first to compute the mean, then to
101 /// compute the variance.
102 pub fn std(&self) -> Tensor {
103 let mut out = Tensor::new(vec![1]);
104 if self.size() == 0 {
105 out.fill(0.0);
106 } else {
107 // First pass: mean
108 let mut mean_val = 0.0f32;
109 let n = self.size() as f32;
110
111 if self.is_contiguous() {
112 // Fast path for contiguous tensors
113 unsafe {
114 let src = self.as_ptr();
115 let mut i = 0usize;
116 while i + 4 <= self.size() {
117 mean_val +=
118 *src.add(i) + *src.add(i + 1) + *src.add(i + 2) + *src.add(i + 3);
119 i += 4;
120 }
121 while i < self.size() {
122 mean_val += *src.add(i);
123 i += 1;
124 }
125 }
126 } else {
127 // Stride-aware path for non-contiguous tensors
128 let dims = self.shape().dims.clone();
129 for flat_idx in 0..self.size() {
130 // Convert flat index to multi-dimensional coordinates
131 let mut coords = vec![0; dims.len()];
132 let mut tmp = flat_idx;
133 for k in (0..dims.len()).rev() {
134 coords[k] = tmp % dims[k];
135 tmp /= dims[k];
136 }
137
138 // Get value using stride-aware offset
139 let offset = self.shape().offset(&coords);
140 let value = unsafe { *self.as_ptr().add(offset) };
141 mean_val += value;
142 }
143 }
144 mean_val /= n;
145
146 // Second pass: variance
147 let mut var_val = 0.0f32;
148
149 if self.is_contiguous() {
150 // Fast path for contiguous tensors
151 unsafe {
152 let src = self.as_ptr();
153 let mut i = 0usize;
154 while i + 4 <= self.size() {
155 let x0 = *src.add(i) - mean_val;
156 let x1 = *src.add(i + 1) - mean_val;
157 let x2 = *src.add(i + 2) - mean_val;
158 let x3 = *src.add(i + 3) - mean_val;
159 var_val += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
160 i += 4;
161 }
162 while i < self.size() {
163 let d = *src.add(i) - mean_val;
164 var_val += d * d;
165 i += 1;
166 }
167 }
168 } else {
169 // Stride-aware path for non-contiguous tensors
170 let dims = self.shape().dims.clone();
171 for flat_idx in 0..self.size() {
172 // Convert flat index to multi-dimensional coordinates
173 let mut coords = vec![0; dims.len()];
174 let mut tmp = flat_idx;
175 for k in (0..dims.len()).rev() {
176 coords[k] = tmp % dims[k];
177 tmp /= dims[k];
178 }
179
180 // Get value using stride-aware offset
181 let offset = self.shape().offset(&coords);
182 let value = unsafe { *self.as_ptr().add(offset) };
183 let d = value - mean_val;
184 var_val += d * d;
185 }
186 }
187 var_val /= n;
188
189 let std_val = var_val.sqrt();
190 unsafe {
191 *out.as_mut_ptr() = std_val;
192 }
193 }
194
195 if self.requires_grad() {
196 // Save tensors for backward
197 let mut result = out.clone();
198 result.set_requires_grad_internal(true);
199 let mean_tensor = {
200 let mut t = Tensor::new(vec![1]);
201 if self.size() == 0 {
202 t.fill(0.0);
203 } else {
204 let mut acc = 0.0f32;
205
206 if self.is_contiguous() {
207 unsafe {
208 for i in 0..self.size() {
209 acc += *self.as_ptr().add(i);
210 }
211 }
212 } else {
213 let dims = self.shape().dims.clone();
214 for flat_idx in 0..self.size() {
215 // Convert flat index to multi-dimensional coordinates
216 let mut coords = vec![0; dims.len()];
217 let mut tmp = flat_idx;
218 for k in (0..dims.len()).rev() {
219 coords[k] = tmp % dims[k];
220 tmp /= dims[k];
221 }
222
223 // Get value using stride-aware offset
224 let offset = self.shape().offset(&coords);
225 let value = unsafe { *self.as_ptr().add(offset) };
226 acc += value;
227 }
228 }
229
230 unsafe {
231 *t.as_mut_ptr() = acc / (self.size() as f32);
232 }
233 }
234 t
235 };
236 let std_saved = out.clone();
237 let grad_fn = GradFn::ReduceStd {
238 saved_mean: Box::new(mean_tensor),
239 saved_std: Box::new(std_saved),
240 saved_input: Box::new(self.clone()),
241 input_shape: self.shape().dims.clone(),
242 };
243 result.set_grad_fn(grad_fn.clone());
244 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
245 return result;
246 }
247
248 out
249 }
250
251 /// Computes the standard deviation over specified dimensions
252 ///
253 /// Reduces the tensor along the specified dimensions by computing the standard
254 /// deviation of each slice. The result maintains the original tensor structure
255 /// with reduced dimensions optionally preserved as size-1 dimensions.
256 ///
257 /// Uses population standard deviation (divides by n rather than n-1) to match
258 /// PyTorch's default behavior.
259 ///
260 /// # Arguments
261 ///
262 /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
263 /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
264 ///
265 /// # Returns
266 ///
267 /// A tensor with standard deviation computed over the specified dimensions
268 ///
269 /// # Examples
270 ///
271 /// ```
272 /// use train_station::Tensor;
273 ///
274 /// // Standard deviation along rows (dimension 1) with keepdim=true
275 /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
276 /// let row_stds = matrix.std_dims(&[1], true);
277 /// assert_eq!(row_stds.shape().dims, vec![2, 1]);
278 /// assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
279 /// assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
280 /// ```
281 ///
282 /// ```
283 /// use train_station::Tensor;
284 ///
285 /// // Standard deviation along columns (dimension 0) with keepdim=false
286 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
287 /// let col_stds = matrix.std_dims(&[0], false);
288 /// assert_eq!(col_stds.shape().dims, vec![2]);
289 /// // std([1, 3]) = 1.0, std([2, 4]) = 1.0
290 /// assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
291 /// assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
292 /// ```
293 ///
294 /// ```
295 /// use train_station::Tensor;
296 ///
297 /// // Standard deviation over multiple dimensions
298 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
299 /// let std_all = tensor.std_dims(&[0, 1], false);
300 /// assert_eq!(std_all.shape().dims, vec![1]);
301 /// // std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
302 /// assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
303 /// ```
304 ///
305 /// # Panics
306 ///
307 /// * If `dims` is empty
308 /// * If any dimension index is out of bounds for the tensor rank
309 /// * If the reduced size is 0 (invalid for standard deviation calculation)
310 ///
311 /// # Performance
312 ///
313 /// Uses efficient coordinate-based iteration that works correctly with
314 /// both contiguous and non-contiguous tensor layouts. The algorithm performs
315 /// two passes: first to compute means, then to compute variances.
316 pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
317 assert!(!dims.is_empty(), "std_dims requires at least one dimension");
318 let rank = self.shape().rank();
319 for &d in dims {
320 assert!(
321 d < rank,
322 "std_dims dim {} out of bounds for rank {}",
323 d,
324 rank
325 );
326 }
327
328 // Output shape
329 let mut out_dims = self.shape().dims.clone();
330 let mut reduced: Vec<usize> = dims.to_vec();
331 reduced.sort_unstable();
332 reduced.dedup();
333 for &d in reduced.iter() {
334 out_dims[d] = if keepdim { 1 } else { 0 };
335 }
336 if !keepdim {
337 out_dims.retain(|&s| s != 0);
338 }
339 if out_dims.is_empty() {
340 out_dims.push(1);
341 }
342 let mut mean = Tensor::zeros(out_dims.clone());
343 let mut var = Tensor::zeros(out_dims.clone());
344
345 let in_shape = self.shape().dims.clone();
346 let out_rank = mean.shape().rank();
347 let mut in_coords = vec![0usize; rank];
348 let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
349 assert!(n_reduced > 0, "reduced size must be > 0");
350 unsafe {
351 let mptr = mean.as_mut_ptr();
352 let vptr = var.as_mut_ptr();
353
354 // First pass: sum to compute means
355 for lin in 0..self.size() {
356 let mut tmp = lin;
357 for i in (0..rank).rev() {
358 let s = in_shape[i];
359 in_coords[i] = if s == 0 { 0 } else { tmp % s };
360 if s != 0 {
361 tmp /= s;
362 }
363 }
364
365 // Get input value using stride-aware offset
366 let in_offset = self.shape().offset(&in_coords);
367 let value = *self.as_ptr().add(in_offset);
368
369 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
370 for (i, &c) in in_coords.iter().enumerate().take(rank) {
371 if reduced.contains(&i) {
372 if keepdim {
373 out_coords.push(0);
374 }
375 } else {
376 out_coords.push(c);
377 }
378 }
379 let off = if out_coords.is_empty() {
380 0
381 } else {
382 mean.shape().offset(&out_coords)
383 };
384 *mptr.add(off) += value;
385 }
386 let msize = mean.size();
387 for i in 0..msize {
388 *mptr.add(i) /= n_reduced as f32;
389 }
390
391 // Second pass: accumulate squared diffs
392 for lin in 0..self.size() {
393 let mut tmp = lin;
394 for i in (0..rank).rev() {
395 let s = in_shape[i];
396 in_coords[i] = if s == 0 { 0 } else { tmp % s };
397 if s != 0 {
398 tmp /= s;
399 }
400 }
401
402 // Get input value using stride-aware offset
403 let in_offset = self.shape().offset(&in_coords);
404 let x = *self.as_ptr().add(in_offset);
405
406 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
407 for (i, &c) in in_coords.iter().enumerate().take(rank) {
408 if reduced.contains(&i) {
409 if keepdim {
410 out_coords.push(0);
411 }
412 } else {
413 out_coords.push(c);
414 }
415 }
416 let off = if out_coords.is_empty() {
417 0
418 } else {
419 var.shape().offset(&out_coords)
420 };
421 let diff = x - *mptr.add(off);
422 *vptr.add(off) += diff * diff;
423 }
424 let vsize = var.size();
425 for i in 0..vsize {
426 *vptr.add(i) /= n_reduced as f32;
427 }
428 }
429 // std = sqrt(var)
430 let mut out = Tensor::zeros(out_dims.clone());
431 unsafe {
432 let d = out.as_mut_ptr();
433 let v = var.as_ptr();
434 for i in 0..out.size() {
435 *d.add(i) = (*v.add(i)).sqrt();
436 }
437 }
438
439 if self.requires_grad() {
440 let mut result = out.clone();
441 result.set_requires_grad_internal(true);
442 let grad_fn = GradFn::ReduceStdDims {
443 dims: reduced,
444 keepdim,
445 input_shape: self.shape().dims.clone(),
446 saved_mean: Box::new(mean),
447 saved_std: Box::new(out.clone()),
448 saved_input: Box::new(self.clone()),
449 };
450 result.set_grad_fn(grad_fn.clone());
451 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
452 return result;
453 }
454
455 out
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_std_forward_basic() {
        // Population std of [1, 2, 3, 4]: mean 2.5, var 1.25, std sqrt(1.25).
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = tensor.std();
        let value = unsafe { *result.as_ptr() };
        assert!((value - 1.118_034).abs() < 1e-5);
    }

    #[test]
    fn test_std_dims_forward() {
        let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let per_row = tensor.std_dims(&[1], true);
        assert_eq!(per_row.shape().dims, vec![2, 1]);
        // Row 0: std([1, 3]) = 1.0; row 1: std([2, 2]) = 0.0.
        assert!((per_row.get(&[0, 0]) - 1.0).abs() < 1e-6);
        assert!((per_row.get(&[1, 0]) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_std_non_contiguous_transpose() {
        // A tensor and its transposed view contain the same elements, so the
        // global standard deviation must be identical for both.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // base: [[1, 2, 3], [4, 5, 6]]; view: [[1, 4], [2, 5], [3, 6]]
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous()); // transpose produces a strided view

        let std_base = base.std();
        let std_view = view.std();
        assert!((std_base.get(&[0]) - std_view.get(&[0])).abs() < 1e-6);

        // mean = 3.5; population var = 2*(2.5^2 + 1.5^2 + 0.5^2)/6 ≈ 2.9167
        let expected = (2.9166667_f32).sqrt();
        assert!((std_base.get(&[0]) - expected).abs() < 1e-5);
    }

    #[test]
    fn test_std_dims_non_contiguous() {
        // Reduce along dim 0 of a transposed (non-contiguous) [3, 2] view.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let reduced = view.std_dims(&[0], false);
        assert_eq!(reduced.shape().dims, vec![2]);

        // Column 0 is [1, 2, 3], column 1 is [4, 5, 6]; both have population
        // variance 2/3, hence std sqrt(2/3) ≈ 0.816.
        let expected = (2.0 / 3.0_f32).sqrt();
        for col in 0..2 {
            assert!((reduced.get(&[col]) - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn test_std_permuted_tensor() {
        // Permuting axes reorders but never changes the element multiset, so
        // the global std must match the original's.
        let values = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let base = Tensor::from_slice(&values, vec![2, 2, 2]).unwrap();
        let permuted = base.permute(vec![2, 1, 0]); // reverse all axes
        assert!(!permuted.is_contiguous());

        let std_base = base.std();
        let std_perm = permuted.std();
        assert!((std_base.get(&[0]) - std_perm.get(&[0])).abs() < 1e-6);

        // mean = 4.5; population var of [1,3,5,7,2,4,6,8] = 5.25
        let expected = 5.25_f32.sqrt();
        assert!((std_base.get(&[0]) - expected).abs() < 1e-5);
    }
}
543}