Skip to main content

wickra_core/indicators/
atr.rs

1//! Average True Range (Wilder).
2
3use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
7/// Average True Range with Wilder smoothing.
8///
9/// The first emitted value, by convention, appears after `period` candles: the
10/// first `period − 1` true-range values seed the Wilder average alongside the
11/// `period`-th, then the smoothed update begins.
12///
13/// # Example
14///
15/// ```
16/// use wickra_core::{Candle, Indicator, Atr};
17///
18/// let mut indicator = Atr::new(5).unwrap();
19/// let mut last = None;
20/// for i in 0..80 {
21///     let base = 100.0 + f64::from(i);
22///     let candle =
23///         Candle::new(base, base + 2.0, base - 2.0, base + 1.0, 10.0, i64::from(i)).unwrap();
24///     last = indicator.update(candle);
25/// }
26/// assert!(last.is_some());
27/// ```
28#[derive(Debug, Clone)]
29pub struct Atr {
30    period: usize,
31    /// `period - 1` as `f64`, precomputed for the Wilder smoothing step.
32    n_minus_1: f64,
33    /// `1 / period`, precomputed so the per-tick smoothing multiplies instead of
34    /// divides.
35    inv_period: f64,
36    prev_close: Option<f64>,
37    seed_buf: Vec<f64>,
38    /// Smoothed ATR, valid once `seeded` is set. Bare `f64` + flag rather than
39    /// `Option<f64>` so the hot recurrence avoids an enum-tag read per tick.
40    avg: f64,
41    seeded: bool,
42}
43
44impl Atr {
45    /// Construct an ATR with the given Wilder period.
46    ///
47    /// # Errors
48    ///
49    /// Returns [`Error::PeriodZero`] if `period == 0`.
50    pub fn new(period: usize) -> Result<Self> {
51        if period == 0 {
52            return Err(Error::PeriodZero);
53        }
54        Ok(Self {
55            period,
56            n_minus_1: (period - 1) as f64,
57            inv_period: 1.0 / period as f64,
58            prev_close: None,
59            seed_buf: Vec::with_capacity(period),
60            avg: 0.0,
61            seeded: false,
62        })
63    }
64
65    /// Configured period.
66    pub const fn period(&self) -> usize {
67        self.period
68    }
69
70    /// Current value if available.
71    pub const fn value(&self) -> Option<f64> {
72        if self.seeded {
73            Some(self.avg)
74        } else {
75            None
76        }
77    }
78}
79
80impl Indicator for Atr {
81    type Input = Candle;
82    type Output = f64;
83
84    fn update(&mut self, candle: Candle) -> Option<f64> {
85        let tr = candle.true_range(self.prev_close);
86        self.prev_close = Some(candle.close);
87
88        if self.seeded {
89            // Wilder smoothing with the reciprocal hoisted out of the hot path.
90            let new_avg = self.avg.mul_add(self.n_minus_1, tr) * self.inv_period;
91            self.avg = new_avg;
92            return Some(new_avg);
93        }
94
95        self.seed_buf.push(tr);
96        if self.seed_buf.len() == self.period {
97            let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
98            self.avg = seed;
99            self.seeded = true;
100            return Some(seed);
101        }
102        None
103    }
104
105    fn reset(&mut self) {
106        self.prev_close = None;
107        self.seed_buf.clear();
108        self.avg = 0.0;
109        self.seeded = false;
110    }
111
112    fn warmup_period(&self) -> usize {
113        self.period
114    }
115
116    fn is_ready(&self) -> bool {
117        self.seeded
118    }
119
120    fn name(&self) -> &'static str {
121        "ATR"
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    fn c(h: f64, l: f64, cl: f64) -> Candle {
        // ts/open/volume don't affect ATR; use safe placeholders.
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Independent reference: Wilder ATR computed straight from the definition.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `period` / `value` and the `Indicator`
    /// metadata (`name`, `warmup_period`, `is_ready`), including the exact
    /// candle on which the indicator becomes ready. The numeric tests below
    /// inspect ATR output but never query this metadata.
    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.warmup_period(), 14);
        assert_eq!(atr.value(), None);
        assert!(!atr.is_ready());
        for _ in 0..13 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        // One candle short of the warm-up: still nothing to report.
        assert_eq!(atr.value(), None);
        assert!(!atr.is_ready());
        atr.update(c(11.0, 9.0, 10.0));
        assert!(atr.value().is_some());
        assert!(atr.is_ready());
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
        assert!(out[4].is_some());
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        // Every candle has H=11, L=9, C=10 -> TR=2 (no gaps).
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        assert!(out[..13].iter().all(Option::is_none));
        // Index explicitly rather than `flatten()` so a missing value fails the
        // test instead of being skipped silently.
        for v in &out[13..] {
            let v = v.expect("ATR must be ready from the 14th candle on");
            assert_relative_eq!(v, 2.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // Previous close 5, current candle H=10 L=9 C=9.5 -> TR = max(1, 5, 4) = 5.
        let candles = vec![
            c(6.0, 4.0, 5.0),  // prev close = 5
            c(10.0, 9.0, 9.5), // TR = 5
        ];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        assert!(out[0].is_none());
        // Seed window covers TR_1 and TR_2. TR_1 = H1-L1 = 2 (no prev close). TR_2 = 5.
        // Seed = (2+5)/2 = 3.5
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        // Varying candles so stale state (prev close, partial seed, average)
        // would visibly change the replayed output.
        let candles: Vec<Candle> = (0..20)
            .map(|i| {
                let mid = 10.0 + f64::from(i % 7);
                c(mid + 1.0, mid - 1.0, mid)
            })
            .collect();
        let mut atr = Atr::new(5).unwrap();
        let first = atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.value(), None);
        // A reset indicator must behave exactly like a freshly constructed one.
        let replay: Vec<Option<f64>> = candles.iter().map(|x| atr.update(*x)).collect();
        assert_eq!(replay, first);
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        for v in atr.batch(&candles).into_iter().flatten() {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // bars: (low, range, close_fraction) -> a valid OHLC candle.
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}