ta_lib_in_rust/indicators/volatility/stddev.rs
1use crate::util::dataframe_utils::check_window_size;
2use polars::prelude::*;
3
4/// Calculates Standard Deviation (StdDev) of a series over a window
5///
6/// # Arguments
7///
8/// * `df` - DataFrame containing the price data
9/// * `window` - Window size for StdDev calculation (typically 14 or 20)
10/// * `column` - Column to calculate StdDev on (usually "close")
11///
12/// # Returns
13///
14/// Returns a PolarsResult containing the StdDev Series
15///
16/// # Example
17///
18/// ```
19/// use polars::prelude::*;
20/// use ta_lib_in_rust::indicators::volatility::calculate_stddev;
21///
22/// let close = Series::new("close".into(), &[10.0, 11.0, 12.0, 9.0, 8.0, 10.0]);
23/// let df = DataFrame::new(vec![close.into()]).unwrap();
24///
25/// let stddev = calculate_stddev(&df, 3, "close").unwrap();
26/// ```
27pub fn calculate_stddev(df: &DataFrame, window: usize, column: &str) -> PolarsResult<Series> {
28 // Check window size
29 check_window_size(df, window, "StdDev")?;
30
31 // Check if the specified column exists
32 if !df.schema().contains(column) {
33 return Err(PolarsError::ShapeMismatch(
34 format!(
35 "DataFrame must contain '{}' column for StdDev calculation",
36 column
37 )
38 .into(),
39 ));
40 }
41
42 // Get the column to calculate StdDev on
43 let col = df.column(column)?.f64()?;
44
45 // Calculate rolling standard deviation
46 let mut stddev_values = Vec::with_capacity(df.height());
47
48 // Fill NaN for the first window-1 elements
49 for _ in 0..(window - 1) {
50 stddev_values.push(f64::NAN);
51 }
52
53 // Calculate StdDev for each window
54 for i in (window - 1)..df.height() {
55 let mut sum = 0.0;
56 let mut sum_sq = 0.0;
57 let mut count = 0;
58
59 for j in (i - window + 1)..=i {
60 let val = col.get(j).unwrap_or(f64::NAN);
61 if !val.is_nan() {
62 sum += val;
63 sum_sq += val * val;
64 count += 1;
65 }
66 }
67
68 if count > 1 {
69 let mean = sum / count as f64;
70 let variance = sum_sq / count as f64 - mean * mean;
71 let stddev = if variance > 0.0 { variance.sqrt() } else { 0.0 };
72 stddev_values.push(stddev);
73 } else {
74 stddev_values.push(f64::NAN);
75 }
76 }
77
78 Ok(Series::new("stddev".into(), stddev_values))
79}