// ta_lib_in_rust/indicators/add_indicators.rs

use polars::prelude::*;

use crate::indicators::{
    moving_averages::{calculate_ema, calculate_sma},
    oscillators::{calculate_macd, calculate_rsi},
    volatility::{
        calculate_atr, calculate_bb_b, calculate_bollinger_bands, calculate_gk_volatility,
    },
};
use crate::util::dataframe_utils::ensure_f64_column;
use crate::util::time_utils::create_cyclical_time_features;

12pub fn add_technical_indicators(df: &mut DataFrame) -> PolarsResult<DataFrame> {
22 let numeric_columns = ["open", "high", "low", "close", "volume"];
24 for col_name in numeric_columns {
25 if !df.schema().contains(col_name) {
27 continue;
28 }
29
30 ensure_f64_column(df, col_name)?;
31 }
32
33 let sma20 = calculate_sma(df, "close", 20)?.with_name("sma_20".into());
35 let sma50 = calculate_sma(df, "close", 50)?.with_name("sma_50".into());
36 let ema20 = calculate_ema(df, "close", 20)?.with_name("ema_20".into());
37
38 let rsi = calculate_rsi(df, 14, "close")?.with_name("rsi_14".into());
40 let (macd, macd_signal) = calculate_macd(df, 12, 26, 9, "close")?;
41 let macd = macd.with_name("macd".into());
42 let macd_signal = macd_signal.with_name("macd_signal".into());
43
44 let (bb_middle, bb_upper, bb_lower) = calculate_bollinger_bands(df, 20, 2.0, "close")?;
46 let bb_middle = bb_middle.with_name("bb_middle".into());
47 let bb_upper = bb_upper.with_name("bb_upper".into());
48 let bb_lower = bb_lower.with_name("bb_lower".into());
49 let bb_b = calculate_bb_b(df, 20, 2.0, "close")?.with_name("bb_b".into());
50 let atr = calculate_atr(df, 14)?.with_name("atr_14".into());
51 let gk_vol = calculate_gk_volatility(df, 10)?.with_name("gk_volatility".into());
52
53 let close = df.column("close")?.f64()?;
55 let prev_close = close.shift(1);
56
57 let returns = ((close.clone() - prev_close.clone()) / prev_close.clone())
59 .with_name("returns".into())
60 .into_series();
61
62 let high = df.column("high")?.f64()?;
64 let low = df.column("low")?.f64()?;
65 let price_range = ((high.clone() - low.clone()) / close.clone())
66 .with_name("price_range".into())
67 .into_series();
68
69 let close_lag_5 = close.shift(5).with_name("close_lag_5".into());
71 let close_lag_15 = close.shift(15).with_name("close_lag_15".into());
72 let close_lag_30 = close.shift(30).with_name("close_lag_30".into());
73
74 let close_lag_5_clone = close_lag_5.clone();
76 let returns_5min = ((close.clone() - close_lag_5_clone.clone()) / close_lag_5_clone)
77 .with_name("returns_5min".into())
78 .into_series();
79
80 let mut vol_15min = Vec::with_capacity(df.height());
82 for i in 0..df.height() {
83 if i < 15 {
84 vol_15min.push(0.0);
85 continue;
86 }
87
88 let mut returns = Vec::with_capacity(15);
89 for j in (i - 15)..i {
90 let current_opt = close.get(j);
92 let previous_opt = if j > 0 { close.get(j - 1) } else { None };
94
95 if let (Some(current), Some(previous)) = (current_opt, previous_opt) {
97 if previous != 0.0 {
98 returns.push((current - previous) / previous);
99 }
100 }
101 }
102
103 if returns.is_empty() {
105 vol_15min.push(0.0);
106 continue;
107 }
108
109 let mean = returns.iter().sum::<f64>() / returns.len() as f64;
110 let variance =
111 returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / returns.len() as f64;
112 vol_15min.push(variance.sqrt());
113 }
114 let volatility_15min = Series::new("volatility_15min".into(), vol_15min);
115
116 let mut time_features = Vec::new();
118 if df.schema().contains("time") {
119 time_features = create_cyclical_time_features(df, "time", "%Y-%m-%d %H:%M:%S UTC")?;
120 }
121
122 let mut features_to_add = vec![
124 sma20,
125 sma50,
126 ema20,
127 rsi,
128 macd,
129 macd_signal,
130 bb_middle,
131 bb_upper,
132 bb_lower,
133 bb_b,
134 atr,
135 gk_vol,
136 returns,
137 price_range,
138 close_lag_5.into_series(),
139 close_lag_15.into_series(),
140 close_lag_30.into_series(),
141 returns_5min,
142 volatility_15min,
143 ];
144
145 features_to_add.extend(time_features);
147
148 for feature in features_to_add {
149 df.with_column(feature)?;
150 }
151
152 Ok(df.clone())
153}