ta_lib_in_rust/indicators/
add_indicators.rs

1use crate::indicators::{
2    moving_averages::{calculate_ema, calculate_sma},
3    oscillators::{calculate_macd, calculate_rsi},
4    volatility::{
5        calculate_atr, calculate_bb_b, calculate_bollinger_bands, calculate_gk_volatility,
6    },
7};
8use crate::util::dataframe_utils::ensure_f64_column;
9use crate::util::time_utils::create_cyclical_time_features;
10use polars::prelude::*;
11
12/// Adds all technical indicators to the DataFrame
13///
14/// # Arguments
15///
16/// * `df` - DataFrame to add indicators to
17///
18/// # Returns
19///
20/// Returns a PolarsResult containing the enhanced DataFrame
21pub fn add_technical_indicators(df: &mut DataFrame) -> PolarsResult<DataFrame> {
22    // Convert numeric columns to Float64 by mutating in-place via Column
23    let numeric_columns = ["open", "high", "low", "close", "volume"];
24    for col_name in numeric_columns {
25        // Skip if column doesn't exist
26        if !df.schema().contains(col_name) {
27            continue;
28        }
29
30        ensure_f64_column(df, col_name)?;
31    }
32
33    // Calculate moving averages
34    let sma20 = calculate_sma(df, "close", 20)?.with_name("sma_20".into());
35    let sma50 = calculate_sma(df, "close", 50)?.with_name("sma_50".into());
36    let ema20 = calculate_ema(df, "close", 20)?.with_name("ema_20".into());
37
38    // Calculate oscillators
39    let rsi = calculate_rsi(df, 14, "close")?.with_name("rsi_14".into());
40    let (macd, macd_signal) = calculate_macd(df, 12, 26, 9, "close")?;
41    let macd = macd.with_name("macd".into());
42    let macd_signal = macd_signal.with_name("macd_signal".into());
43
44    // Calculate volatility indicators
45    let (bb_middle, bb_upper, bb_lower) = calculate_bollinger_bands(df, 20, 2.0, "close")?;
46    let bb_middle = bb_middle.with_name("bb_middle".into());
47    let bb_upper = bb_upper.with_name("bb_upper".into());
48    let bb_lower = bb_lower.with_name("bb_lower".into());
49    let bb_b = calculate_bb_b(df, 20, 2.0, "close")?.with_name("bb_b".into());
50    let atr = calculate_atr(df, 14)?.with_name("atr_14".into());
51    let gk_vol = calculate_gk_volatility(df, 10)?.with_name("gk_volatility".into());
52
53    // Calculate price dynamics
54    let close = df.column("close")?.f64()?;
55    let prev_close = close.shift(1);
56
57    // Calculate percentage returns
58    let returns = ((close.clone() - prev_close.clone()) / prev_close.clone())
59        .with_name("returns".into())
60        .into_series();
61
62    // Calculate daily price range
63    let high = df.column("high")?.f64()?;
64    let low = df.column("low")?.f64()?;
65    let price_range = ((high.clone() - low.clone()) / close.clone())
66        .with_name("price_range".into())
67        .into_series();
68
69    // Add lag features
70    let close_lag_5 = close.shift(5).with_name("close_lag_5".into());
71    let close_lag_15 = close.shift(15).with_name("close_lag_15".into());
72    let close_lag_30 = close.shift(30).with_name("close_lag_30".into());
73
74    // Returns over different time windows
75    let close_lag_5_clone = close_lag_5.clone();
76    let returns_5min = ((close.clone() - close_lag_5_clone.clone()) / close_lag_5_clone)
77        .with_name("returns_5min".into())
78        .into_series();
79
80    // Shorter-term volatility (15-min window)
81    let mut vol_15min = Vec::with_capacity(df.height());
82    for i in 0..df.height() {
83        if i < 15 {
84            vol_15min.push(0.0);
85            continue;
86        }
87
88        let mut returns = Vec::with_capacity(15);
89        for j in (i - 15)..i {
90            // Safely access the current price value
91            let current_opt = close.get(j);
92            // Safely access the previous price value, checking if j-1 is valid
93            let previous_opt = if j > 0 { close.get(j - 1) } else { None };
94
95            // Only calculate return if both values are valid and previous is not zero
96            if let (Some(current), Some(previous)) = (current_opt, previous_opt) {
97                if previous != 0.0 {
98                    returns.push((current - previous) / previous);
99                }
100            }
101        }
102
103        // Calculate standard deviation of returns
104        if returns.is_empty() {
105            vol_15min.push(0.0);
106            continue;
107        }
108
109        let mean = returns.iter().sum::<f64>() / returns.len() as f64;
110        let variance =
111            returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / returns.len() as f64;
112        vol_15min.push(variance.sqrt());
113    }
114    let volatility_15min = Series::new("volatility_15min".into(), vol_15min);
115
116    // Time-based features
117    let mut time_features = Vec::new();
118    if df.schema().contains("time") {
119        time_features = create_cyclical_time_features(df, "time", "%Y-%m-%d %H:%M:%S UTC")?;
120    }
121
122    // Add all features to the DataFrame
123    let mut features_to_add = vec![
124        sma20,
125        sma50,
126        ema20,
127        rsi,
128        macd,
129        macd_signal,
130        bb_middle,
131        bb_upper,
132        bb_lower,
133        bb_b,
134        atr,
135        gk_vol,
136        returns,
137        price_range,
138        close_lag_5.into_series(),
139        close_lag_15.into_series(),
140        close_lag_30.into_series(),
141        returns_5min,
142        volatility_15min,
143    ];
144
145    // Add time features if available
146    features_to_add.extend(time_features);
147
148    for feature in features_to_add {
149        df.with_column(feature)?;
150    }
151
152    Ok(df.clone())
153}