ta_lib_in_rust/indicators/volatility/
atr.rs

1use crate::util::dataframe_utils::check_window_size;
2use polars::prelude::*;
3
4/// Calculates Average True Range (ATR)
5///
6/// # Arguments
7///
8/// * `df` - DataFrame containing the price data
9/// * `window` - Window size for ATR (typically 14)
10///
11/// # Returns
12///
13/// Returns a PolarsResult containing the ATR Series
14pub fn calculate_atr(df: &DataFrame, window: usize) -> PolarsResult<Series> {
15    check_window_size(df, window, "ATR")?;
16
17    let high = df.column("high")?.f64()?.clone().into_series();
18    let low = df.column("low")?.f64()?.clone().into_series();
19    let close = df.column("close")?.f64()?.clone().into_series();
20
21    let prev_close = close.shift(1);
22    let mut tr_values = Vec::with_capacity(df.height());
23
24    let first_tr = {
25        let h = high.f64()?.get(0).unwrap_or(0.0);
26        let l = low.f64()?.get(0).unwrap_or(0.0);
27        h - l
28    };
29    tr_values.push(first_tr);
30
31    for i in 1..df.height() {
32        let h = high.f64()?.get(i).unwrap_or(0.0);
33        let l = low.f64()?.get(i).unwrap_or(0.0);
34        let pc = prev_close.f64()?.get(i).unwrap_or(0.0);
35
36        let tr = if pc == 0.0 {
37            h - l
38        } else {
39            (h - l).max((h - pc).abs()).max((l - pc).abs())
40        };
41        tr_values.push(tr);
42    }
43
44    // Implement Wilder's smoothing for ATR
45    let mut atr_values = Vec::with_capacity(df.height());
46
47    // Fill with NaN for the first window-1 elements
48    for _ in 0..(window - 1) {
49        atr_values.push(f64::NAN);
50    }
51
52    // Initialize ATR with simple average of first window TR values
53    let mut atr = 0.0;
54    for &tr in tr_values.iter().take(window) {
55        atr += tr;
56    }
57    atr /= window as f64;
58    atr_values.push(atr);
59
60    // Apply Wilder's smoothing formula: ATR(t) = ((window-1) * ATR(t-1) + TR(t)) / window
61    for &tr in tr_values.iter().skip(window) {
62        atr = ((window as f64 - 1.0) * atr + tr) / window as f64;
63        atr_values.push(atr);
64    }
65
66    Ok(Series::new("atr".into(), atr_values))
67}