train_station/tensor/reductions/std.rs
1//! Standard deviation reduction operations for tensors
2//!
3//! This module provides standard deviation reduction operations for tensors.
4//! The standard deviation measures the dispersion of values around the mean,
5//! calculated as the square root of the variance. This is commonly used in
6//! statistics, data analysis, and machine learning for understanding data
7//! variability and normalization.
8//!
9//! # Operations
10//!
11//! * `std()` - Computes standard deviation over all elements, returning a scalar tensor
12//! * `std_dims()` - Computes standard deviation over specified dimensions with optional dimension preservation
13//!
14//! # Statistical Details
15//!
16//! The implementation uses population standard deviation (unbiased=false), which
17//! divides by n rather than n-1. This matches PyTorch's default behavior for
18//! consistency with the reference implementation.
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Compute standard deviation of all elements
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
27//! let std_dev = tensor.std();
28//! assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
29//!
30//! // Compute standard deviation along specific dimensions
31//! let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
32//! let row_stds = matrix.std_dims(&[1], true);
33//! assert_eq!(row_stds.shape().dims(), vec![2, 1]);
34//! ```
35//!
36//! # Performance
37//!
38//! The implementation uses optimized paths for contiguous tensors with manual loop unrolling
39//! for better performance. Non-contiguous tensors use stride-aware iteration to maintain
40//! correctness while preserving memory layout efficiency.
41//!
42//! # Gradient Tracking
43//!
44//! Both operations support automatic gradient tracking when `requires_grad` is enabled.
45//! The gradient computation follows the mathematical derivative of the standard deviation operation.
46
47use crate::gradtrack::{GradEngine, GradFn};
48use crate::tensor::core::Tensor;
49
50impl Tensor {
51 /// Computes the standard deviation over all elements
52 ///
53 /// The standard deviation is calculated as sqrt(variance) where variance is
54 /// the mean of squared differences from the mean. This operation reduces the
55 /// tensor to a scalar value \[1\].
56 ///
57 /// The implementation uses population standard deviation (divides by n rather
58 /// than n-1) to match PyTorch's default behavior.
59 ///
60 /// # Returns
61 ///
62 /// A scalar tensor containing the standard deviation value
63 ///
64 /// # Examples
65 ///
66 /// ```
67 /// use train_station::Tensor;
68 ///
69 /// // Basic standard deviation calculation
70 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
71 /// let std_dev = tensor.std();
72 /// assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
73 /// ```
74 ///
75 /// ```
76 /// use train_station::Tensor;
77 ///
78 /// // Standard deviation of a larger dataset
79 /// let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
80 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
81 /// let std_dev = tensor.std();
82 /// // mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
83 /// let expected = 5.25_f32.sqrt();
84 /// assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
85 /// ```
86 ///
87 /// ```
88 /// use train_station::Tensor;
89 ///
90 /// // Standard deviation of constant values (should be 0)
91 /// let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
92 /// let std_dev = tensor.std();
93 /// assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
94 /// ```
95 ///
96 /// # Performance
97 ///
98 /// Uses optimized contiguous tensor path with 4x loop unrolling for better
99 /// performance. Non-contiguous tensors use stride-aware iteration.
100 /// The algorithm performs two passes: first to compute the mean, then to
101 /// compute the variance.
102 #[track_caller]
103 pub fn std(&self) -> Tensor {
104 let mut out = Tensor::new(vec![1]);
105 if self.size() == 0 {
106 out.fill(0.0);
107 } else {
108 // First pass: mean
109 let mut mean_val = 0.0f32;
110 let n = self.size() as f32;
111
112 if self.is_contiguous() {
113 // Fast path for contiguous tensors
114 unsafe {
115 let src = self.as_ptr();
116 let mut i = 0usize;
117 while i + 4 <= self.size() {
118 mean_val +=
119 *src.add(i) + *src.add(i + 1) + *src.add(i + 2) + *src.add(i + 3);
120 i += 4;
121 }
122 while i < self.size() {
123 mean_val += *src.add(i);
124 i += 1;
125 }
126 }
127 } else {
128 // Stride-aware path for non-contiguous tensors
129 let dims = self.shape().dims().to_vec();
130 for flat_idx in 0..self.size() {
131 // Convert flat index to multi-dimensional coordinates
132 let mut coords = vec![0; dims.len()];
133 let mut tmp = flat_idx;
134 for k in (0..dims.len()).rev() {
135 coords[k] = tmp % dims[k];
136 tmp /= dims[k];
137 }
138
139 // Get value using stride-aware offset
140 let offset = self.shape().offset(&coords);
141 let value = unsafe { *self.as_ptr().add(offset) };
142 mean_val += value;
143 }
144 }
145 mean_val /= n;
146
147 // Second pass: variance
148 let mut var_val = 0.0f32;
149
150 if self.is_contiguous() {
151 // Fast path for contiguous tensors
152 unsafe {
153 let src = self.as_ptr();
154 let mut i = 0usize;
155 while i + 4 <= self.size() {
156 let x0 = *src.add(i) - mean_val;
157 let x1 = *src.add(i + 1) - mean_val;
158 let x2 = *src.add(i + 2) - mean_val;
159 let x3 = *src.add(i + 3) - mean_val;
160 var_val += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3;
161 i += 4;
162 }
163 while i < self.size() {
164 let d = *src.add(i) - mean_val;
165 var_val += d * d;
166 i += 1;
167 }
168 }
169 } else {
170 // Stride-aware path for non-contiguous tensors
171 let dims = self.shape().dims().to_vec();
172 for flat_idx in 0..self.size() {
173 // Convert flat index to multi-dimensional coordinates
174 let mut coords = vec![0; dims.len()];
175 let mut tmp = flat_idx;
176 for k in (0..dims.len()).rev() {
177 coords[k] = tmp % dims[k];
178 tmp /= dims[k];
179 }
180
181 // Get value using stride-aware offset
182 let offset = self.shape().offset(&coords);
183 let value = unsafe { *self.as_ptr().add(offset) };
184 let d = value - mean_val;
185 var_val += d * d;
186 }
187 }
188 var_val /= n;
189
190 let std_val = var_val.sqrt();
191 unsafe {
192 *out.as_mut_ptr() = std_val;
193 }
194 }
195
196 if self.requires_grad() {
197 // Save tensors for backward
198 let mut result = out.clone();
199 result.set_requires_grad_internal(true);
200 let mean_tensor = {
201 let mut t = Tensor::new(vec![1]);
202 if self.size() == 0 {
203 t.fill(0.0);
204 } else {
205 let mut acc = 0.0f32;
206
207 if self.is_contiguous() {
208 unsafe {
209 for i in 0..self.size() {
210 acc += *self.as_ptr().add(i);
211 }
212 }
213 } else {
214 let dims = self.shape().dims().to_vec();
215 for flat_idx in 0..self.size() {
216 // Convert flat index to multi-dimensional coordinates
217 let mut coords = vec![0; dims.len()];
218 let mut tmp = flat_idx;
219 for k in (0..dims.len()).rev() {
220 coords[k] = tmp % dims[k];
221 tmp /= dims[k];
222 }
223
224 // Get value using stride-aware offset
225 let offset = self.shape().offset(&coords);
226 let value = unsafe { *self.as_ptr().add(offset) };
227 acc += value;
228 }
229 }
230
231 unsafe {
232 *t.as_mut_ptr() = acc / (self.size() as f32);
233 }
234 }
235 t
236 };
237 let std_saved = out.clone();
238 let grad_fn = GradFn::ReduceStd {
239 saved_mean: Box::new(mean_tensor),
240 saved_std: Box::new(std_saved),
241 saved_input: Box::new(self.clone()),
242 input_shape: self.shape().dims().to_vec(),
243 };
244 result.set_grad_fn(grad_fn.clone());
245 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
246 return result;
247 }
248
249 out
250 }
251
252 /// Computes the standard deviation over specified dimensions
253 ///
254 /// Reduces the tensor along the specified dimensions by computing the standard
255 /// deviation of each slice. The result maintains the original tensor structure
256 /// with reduced dimensions optionally preserved as size-1 dimensions.
257 ///
258 /// Uses population standard deviation (divides by n rather than n-1) to match
259 /// PyTorch's default behavior.
260 ///
261 /// # Arguments
262 ///
263 /// * `dims` - Vector of dimension indices to reduce over (must be valid for tensor rank)
264 /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
265 ///
266 /// # Returns
267 ///
268 /// A tensor with standard deviation computed over the specified dimensions
269 ///
270 /// # Examples
271 ///
272 /// ```
273 /// use train_station::Tensor;
274 ///
275 /// // Standard deviation along rows (dimension 1) with keepdim=true
276 /// let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
277 /// let row_stds = matrix.std_dims(&[1], true);
278 /// assert_eq!(row_stds.shape().dims(), vec![2, 1]);
279 /// assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
280 /// assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
281 /// ```
282 ///
283 /// ```
284 /// use train_station::Tensor;
285 ///
286 /// // Standard deviation along columns (dimension 0) with keepdim=false
287 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
288 /// let col_stds = matrix.std_dims(&[0], false);
289 /// assert_eq!(col_stds.shape().dims(), vec![2]);
290 /// // std([1, 3]) = 1.0, std([2, 4]) = 1.0
291 /// assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
292 /// assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
293 /// ```
294 ///
295 /// ```
296 /// use train_station::Tensor;
297 ///
298 /// // Standard deviation over multiple dimensions
299 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
300 /// let std_all = tensor.std_dims(&[0, 1], false);
301 /// assert_eq!(std_all.shape().dims(), vec![1]);
302 /// // std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
303 /// assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
304 /// ```
305 ///
306 /// # Panics
307 ///
308 /// * If `dims` is empty
309 /// * If any dimension index is out of bounds for the tensor rank
310 /// * If the reduced size is 0 (invalid for standard deviation calculation)
311 ///
312 /// # Performance
313 ///
314 /// Uses efficient coordinate-based iteration that works correctly with
315 /// both contiguous and non-contiguous tensor layouts. The algorithm performs
316 /// two passes: first to compute means, then to compute variances.
317 #[track_caller]
318 pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
319 assert!(!dims.is_empty(), "std_dims requires at least one dimension");
320 let rank = self.shape().rank();
321 for &d in dims {
322 assert!(
323 d < rank,
324 "std_dims dim {} out of bounds for rank {}",
325 d,
326 rank
327 );
328 }
329
330 // Output shape
331 let mut out_dims = self.shape().dims().to_vec();
332 let mut reduced: Vec<usize> = dims.to_vec();
333 reduced.sort_unstable();
334 reduced.dedup();
335 for &d in reduced.iter() {
336 out_dims[d] = if keepdim { 1 } else { 0 };
337 }
338 if !keepdim {
339 out_dims.retain(|&s| s != 0);
340 }
341 if out_dims.is_empty() {
342 out_dims.push(1);
343 }
344 let mut mean = Tensor::zeros(out_dims.clone());
345 let mut var = Tensor::zeros(out_dims.clone());
346
347 let in_shape = self.shape().dims().to_vec();
348 let out_rank = mean.shape().rank();
349 let mut in_coords = vec![0usize; rank];
350 let n_reduced: usize = reduced.iter().map(|&d| in_shape[d]).product();
351 assert!(n_reduced > 0, "reduced size must be > 0");
352 unsafe {
353 let mptr = mean.as_mut_ptr();
354 let vptr = var.as_mut_ptr();
355
356 // First pass: sum to compute means
357 for lin in 0..self.size() {
358 let mut tmp = lin;
359 for i in (0..rank).rev() {
360 let s = in_shape[i];
361 in_coords[i] = if s == 0 { 0 } else { tmp % s };
362 if s != 0 {
363 tmp /= s;
364 }
365 }
366
367 // Get input value using stride-aware offset
368 let in_offset = self.shape().offset(&in_coords);
369 let value = *self.as_ptr().add(in_offset);
370
371 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
372 for (i, &c) in in_coords.iter().enumerate().take(rank) {
373 if reduced.contains(&i) {
374 if keepdim {
375 out_coords.push(0);
376 }
377 } else {
378 out_coords.push(c);
379 }
380 }
381 let off = if out_coords.is_empty() {
382 0
383 } else {
384 mean.shape().offset(&out_coords)
385 };
386 *mptr.add(off) += value;
387 }
388 let msize = mean.size();
389 for i in 0..msize {
390 *mptr.add(i) /= n_reduced as f32;
391 }
392
393 // Second pass: accumulate squared diffs
394 for lin in 0..self.size() {
395 let mut tmp = lin;
396 for i in (0..rank).rev() {
397 let s = in_shape[i];
398 in_coords[i] = if s == 0 { 0 } else { tmp % s };
399 if s != 0 {
400 tmp /= s;
401 }
402 }
403
404 // Get input value using stride-aware offset
405 let in_offset = self.shape().offset(&in_coords);
406 let x = *self.as_ptr().add(in_offset);
407
408 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
409 for (i, &c) in in_coords.iter().enumerate().take(rank) {
410 if reduced.contains(&i) {
411 if keepdim {
412 out_coords.push(0);
413 }
414 } else {
415 out_coords.push(c);
416 }
417 }
418 let off = if out_coords.is_empty() {
419 0
420 } else {
421 var.shape().offset(&out_coords)
422 };
423 let diff = x - *mptr.add(off);
424 *vptr.add(off) += diff * diff;
425 }
426 let vsize = var.size();
427 for i in 0..vsize {
428 *vptr.add(i) /= n_reduced as f32;
429 }
430 }
431 // std = sqrt(var)
432 let mut out = Tensor::zeros(out_dims.clone());
433 unsafe {
434 let d = out.as_mut_ptr();
435 let v = var.as_ptr();
436 for i in 0..out.size() {
437 *d.add(i) = (*v.add(i)).sqrt();
438 }
439 }
440
441 if self.requires_grad() {
442 let mut result = out.clone();
443 result.set_requires_grad_internal(true);
444 let grad_fn = GradFn::ReduceStdDims {
445 dims: reduced,
446 keepdim,
447 input_shape: self.shape().dims().to_vec(),
448 saved_mean: Box::new(mean),
449 saved_std: Box::new(out.clone()),
450 saved_input: Box::new(self.clone()),
451 };
452 result.set_grad_fn(grad_fn.clone());
453 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
454 return result;
455 }
456
457 out
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Full-tensor std matches the hand-computed value for a flat vector.
    #[test]
    fn test_std_forward_basic() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = input.std();
        // mean = 2.5, var = 1.25, std = sqrt(1.25) ≈ 1.118034
        assert!((result.get(&[0]) - 1.118_034).abs() < 1e-5);
    }

    /// Per-row std with keepdim preserves a size-1 reduced dimension.
    #[test]
    fn test_std_dims_forward() {
        let input = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let result = input.std_dims(&[1], true);
        assert_eq!(result.shape().dims(), vec![2, 1]);
        // std([1, 3]) = 1.0, std([2, 2]) = 0.0
        assert!((result.get(&[0, 0]) - 1.0).abs() < 1e-6);
        assert!((result.get(&[1, 0]) - 0.0).abs() < 1e-6);
    }

    /// std over a transposed (non-contiguous) view agrees with the original.
    #[test]
    fn test_std_non_contiguous_transpose() {
        // Original layout: [[1, 2, 3], [4, 5, 6]]
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        // Transposed view: [[1, 4], [2, 5], [3, 6]] — must not be contiguous.
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let from_base = base.std();
        let from_view = view.std();

        // The set of elements is identical, so the stds must agree.
        assert!((from_base.get(&[0]) - from_view.get(&[0])).abs() < 1e-6);

        // [1..6]: mean = 3.5, var = mean([2.5², 1.5², 0.5², 0.5², 1.5², 2.5²]) ≈ 2.9167
        let expected = (2.9166667_f32).sqrt(); // ≈ 1.708
        assert!((from_base.get(&[0]) - expected).abs() < 1e-5);
    }

    /// Dimension-wise std works on a non-contiguous transposed view.
    #[test]
    fn test_std_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reduce dim 0 of the view: columns are [1,2,3] and [4,5,6].
        let result = view.std_dims(&[0], false);
        assert_eq!(result.shape().dims(), vec![2]);

        // Both columns have variance 2/3, so std = sqrt(2/3) ≈ 0.816.
        let expected = (2.0 / 3.0_f32).sqrt();
        assert!((result.get(&[0]) - expected).abs() < 1e-5);
        assert!((result.get(&[1]) - expected).abs() < 1e-5);
    }

    /// std over a permuted view agrees with the contiguous original.
    #[test]
    fn test_std_permuted_tensor() {
        let values = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
        let base = Tensor::from_slice(&values, vec![2, 2, 2]).unwrap();

        // Reverse the axis order: [2, 2, 2] with first and last swapped.
        let view = base.permute(vec![2, 1, 0]);
        assert!(!view.is_contiguous());

        let from_base = base.std();
        let from_view = view.std();

        // Same multiset of elements → same std.
        assert!((from_base.get(&[0]) - from_view.get(&[0])).abs() < 1e-6);

        // mean = 4.5, var = mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
        let expected = 5.25_f32.sqrt(); // ≈ 2.291
        assert!((from_base.get(&[0]) - expected).abs() < 1e-5);
    }
}
545}