// ta_lib_in_rust/indicators/oscillators/stochastic.rs

1use polars::prelude::*;
2
3/// Calculates the Stochastic Oscillator, which consists of %K and %D lines
4///
5/// The Stochastic Oscillator is a momentum indicator that compares a security's closing price
6/// to its price range over a given time period. It's particularly useful for intraday trading
7/// to identify overbought and oversold conditions.
8///
9/// # Arguments
10///
11/// * `df` - DataFrame containing OHLCV data with "high", "low", and "close" columns
12/// * `k_period` - Lookback period for %K calculation (typically 14)
13/// * `d_period` - Smoothing period for %D calculation (typically 3)
14/// * `slowing` - Slowing period (typically 3)
15///
16/// # Returns
17///
18/// * `PolarsResult<(Series, Series)>` - Tuple containing %K and %D Series
19///
20/// # Formula
21///
22/// %K = 100 * (Close - Lowest Low) / (Highest High - Lowest Low)
23/// %D = SMA of %K over d_period
24///
25/// # Example
26///
27/// ```
28/// use polars::prelude::*;
29/// use ta_lib_in_rust::indicators::oscillators::calculate_stochastic;
30///
31/// // Create or load a DataFrame with OHLCV data
32/// let df = DataFrame::default(); // Replace with actual data
33///
34/// // Calculate Stochastic Oscillator with default parameters
35/// let (stoch_k, stoch_d) = calculate_stochastic(&df, 14, 3, 3).unwrap();
36/// ```
37pub fn calculate_stochastic(
38    df: &DataFrame,
39    k_period: usize,
40    d_period: usize,
41    slowing: usize,
42) -> PolarsResult<(Series, Series)> {
43    // Validate required columns
44    if !df.schema().contains("high")
45        || !df.schema().contains("low")
46        || !df.schema().contains("close")
47    {
48        return Err(PolarsError::ShapeMismatch(
49            "Missing required columns for Stochastic calculation. Required: high, low, close"
50                .to_string()
51                .into(),
52        ));
53    }
54
55    // Extract required columns
56    let high = df.column("high")?.f64()?;
57    let low = df.column("low")?.f64()?;
58    let close = df.column("close")?.f64()?;
59
60    // Calculate raw %K values
61    let mut raw_k_values = Vec::with_capacity(df.height());
62
63    // Fill initial values with NaN
64    for _ in 0..k_period - 1 {
65        raw_k_values.push(f64::NAN);
66    }
67
68    // Calculate raw %K for each data point
69    for i in k_period - 1..df.height() {
70        let mut highest_high = f64::NEG_INFINITY;
71        let mut lowest_low = f64::INFINITY;
72        let mut valid_data = true;
73
74        // Find highest high and lowest low in the period
75        for j in i - (k_period - 1)..=i {
76            let h = high.get(j).unwrap_or(f64::NAN);
77            let l = low.get(j).unwrap_or(f64::NAN);
78
79            if h.is_nan() || l.is_nan() {
80                valid_data = false;
81                break;
82            }
83
84            highest_high = highest_high.max(h);
85            lowest_low = lowest_low.min(l);
86        }
87
88        if !valid_data || (highest_high - lowest_low).abs() < 1e-10 {
89            raw_k_values.push(f64::NAN);
90        } else {
91            let c = close.get(i).unwrap_or(f64::NAN);
92            if c.is_nan() {
93                raw_k_values.push(f64::NAN);
94            } else {
95                let raw_k = 100.0 * (c - lowest_low) / (highest_high - lowest_low);
96                raw_k_values.push(raw_k);
97            }
98        }
99    }
100
101    // Apply slowing to %K (if slowing > 1)
102    let mut k_values = Vec::with_capacity(df.height());
103
104    // Fill initial values with NaN - ensure we have NaN for all values before k_period + slowing - 1
105    let k_offset = k_period + slowing - 1;
106    for _ in 0..k_offset {
107        k_values.push(f64::NAN);
108    }
109
110    // Calculate slowed %K
111    for i in k_offset..df.height() {
112        let mut sum = 0.0;
113        let mut count = 0;
114        let mut has_nan = false;
115
116        for j in 0..slowing {
117            let val = raw_k_values[i - j];
118            if val.is_nan() {
119                has_nan = true;
120                break;
121            }
122            sum += val;
123            count += 1;
124        }
125
126        if has_nan || count == 0 {
127            k_values.push(f64::NAN);
128        } else {
129            k_values.push(sum / count as f64);
130        }
131    }
132
133    // Calculate %D (SMA of %K)
134    let mut d_values = Vec::with_capacity(df.height());
135
136    // Fill initial values with NaN
137    let d_offset = k_offset + d_period - 1;
138    for _ in 0..d_offset {
139        d_values.push(f64::NAN);
140    }
141
142    // Calculate %D
143    for i in d_offset..df.height() {
144        let mut sum = 0.0;
145        let mut count = 0;
146        let mut has_nan = false;
147
148        for j in 0..d_period {
149            let val = k_values[i - j];
150            if val.is_nan() {
151                has_nan = true;
152                break;
153            }
154            sum += val;
155            count += 1;
156        }
157
158        if has_nan || count == 0 {
159            d_values.push(f64::NAN);
160        } else {
161            d_values.push(sum / count as f64);
162        }
163    }
164
165    // Create Series with names that reflect parameters
166    let k_name = format!("stoch_k_{}_{}_{}", k_period, slowing, d_period);
167    let d_name = format!("stoch_d_{}_{}_{}", k_period, slowing, d_period);
168
169    Ok((
170        Series::new(k_name.into(), k_values),
171        Series::new(d_name.into(), d_values),
172    ))
173}