ta_lib_in_rust/indicators/volatility/
atr.rs1use crate::util::dataframe_utils::check_window_size;
2use polars::prelude::*;
3
4pub fn calculate_atr(df: &DataFrame, window: usize) -> PolarsResult<Series> {
15 check_window_size(df, window, "ATR")?;
16
17 let high = df.column("high")?.f64()?.clone().into_series();
18 let low = df.column("low")?.f64()?.clone().into_series();
19 let close = df.column("close")?.f64()?.clone().into_series();
20
21 let prev_close = close.shift(1);
22 let mut tr_values = Vec::with_capacity(df.height());
23
24 let first_tr = {
25 let h = high.f64()?.get(0).unwrap_or(0.0);
26 let l = low.f64()?.get(0).unwrap_or(0.0);
27 h - l
28 };
29 tr_values.push(first_tr);
30
31 for i in 1..df.height() {
32 let h = high.f64()?.get(i).unwrap_or(0.0);
33 let l = low.f64()?.get(i).unwrap_or(0.0);
34 let pc = prev_close.f64()?.get(i).unwrap_or(0.0);
35
36 let tr = if pc == 0.0 {
37 h - l
38 } else {
39 (h - l).max((h - pc).abs()).max((l - pc).abs())
40 };
41 tr_values.push(tr);
42 }
43
44 let mut atr_values = Vec::with_capacity(df.height());
46
47 for _ in 0..(window - 1) {
49 atr_values.push(f64::NAN);
50 }
51
52 let mut atr = 0.0;
54 for &tr in tr_values.iter().take(window) {
55 atr += tr;
56 }
57 atr /= window as f64;
58 atr_values.push(atr);
59
60 for &tr in tr_values.iter().skip(window) {
62 atr = ((window as f64 - 1.0) * atr + tr) / window as f64;
63 atr_values.push(atr);
64 }
65
66 Ok(Series::new("atr".into(), atr_values))
67}