Skip to main content

wickra_core/indicators/
atr.rs

1//! Average True Range (Wilder).
2
3use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Average True Range, smoothed with Wilder's recurrence.
///
/// Each candle contributes a true range: the largest of `high - low`,
/// `|high - previous close|` and `|low - previous close|`. The very first
/// candle has no previous close, so its true range is simply `high - low`.
///
/// Nothing is emitted until the `period`-th candle. On that candle the ATR is
/// seeded with the plain mean of the first `period` true ranges. From then on
/// every candle applies `atr = (atr * (period - 1) + tr) / period`.
///
/// # Example
///
/// ```
/// use wickra_core::{Candle, Indicator, Atr};
///
/// let mut atr = Atr::new(5).unwrap();
/// let mut latest = None;
/// for i in 0..80 {
///     let mid = 100.0 + f64::from(i);
///     let bar = Candle::new(mid, mid + 2.0, mid - 2.0, mid + 1.0, 10.0, i64::from(i))
///         .unwrap();
///     latest = atr.update(bar);
/// }
/// assert!(latest.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Atr {
    /// Wilder period; also the number of true ranges averaged for the seed.
    period: usize,
    /// `(period - 1) as f64`, cached so the per-tick smoothing reuses it.
    n_minus_1: f64,
    /// `1 / period`, cached so each smoothing step multiplies rather than
    /// divides.
    inv_period: f64,
    /// Close of the previously consumed candle; `None` before the first one.
    prev_close: Option<f64>,
    /// True ranges collected while warming up; summed once to form the seed.
    seed_buf: Vec<f64>,
    /// Current smoothed ATR. Only meaningful when `seeded` is `true`; kept as a
    /// bare `f64` plus a flag instead of an `Option<f64>` so the hot recurrence
    /// does not read an enum tag every tick.
    avg: f64,
    /// Set once the seed average has been computed.
    seeded: bool,
}
43
44impl Atr {
45    /// Construct an ATR with the given Wilder period.
46    ///
47    /// # Errors
48    ///
49    /// Returns [`Error::PeriodZero`] if `period == 0`.
50    pub fn new(period: usize) -> Result<Self> {
51        if period == 0 {
52            return Err(Error::PeriodZero);
53        }
54        Ok(Self {
55            period,
56            n_minus_1: (period - 1) as f64,
57            inv_period: 1.0 / period as f64,
58            prev_close: None,
59            seed_buf: Vec::with_capacity(period),
60            avg: 0.0,
61            seeded: false,
62        })
63    }
64
65    /// Configured period.
66    pub const fn period(&self) -> usize {
67        self.period
68    }
69
70    /// Current value if available.
71    pub const fn value(&self) -> Option<f64> {
72        if self.seeded {
73            Some(self.avg)
74        } else {
75            None
76        }
77    }
78
79    /// Vectorized batch over raw high/low/close columns: one `f64` per bar
80    /// (`NaN` during warmup). The caller guarantees the three slices are equal
81    /// length and finite with valid OHLC ordering (the binding validates once up
82    /// front); ATR only reads high, low and the previous close.
83    ///
84    /// For a fresh indicator long enough to seed (`n >= period`) it runs the
85    /// true-range seed once and then the bare Wilder recurrence in a tight loop —
86    /// no per-bar `Candle` construction/validation, no `Option`, identical
87    /// division at the seed and `mul_add` afterwards, so the result is
88    /// *bit-for-bit* equal to replaying `update` over the same candles. Shorter
89    /// or non-fresh inputs defer to an exact `update` replay.
90    pub fn batch_atr(&mut self, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
91        let p = self.period;
92        let n = high.len();
93        if self.seeded || !self.seed_buf.is_empty() || self.prev_close.is_some() || n < p {
94            let mut out = vec![f64::NAN; n];
95            for i in 0..n {
96                let candle = Candle::new_unchecked(close[i], high[i], low[i], close[i], 0.0, 0);
97                if let Some(v) = self.update(candle) {
98                    out[i] = v;
99                }
100            }
101            return out;
102        }
103
104        // Warmup `[0, p-1)` is `NaN`; the first ATR is emitted at index `p - 1`.
105        let mut out = vec![f64::NAN; p - 1];
106        out.reserve(n - (p - 1));
107        // Seed: mean of the first `period` true ranges. TRâ‚€ has no previous close.
108        let mut prev_close = close[0];
109        let mut sum_tr = high[0] - low[0];
110        self.seed_buf.push(sum_tr);
111        for i in 1..p {
112            let (h, l) = (high[i], low[i]);
113            let tr = (h - l)
114                .max((h - prev_close).abs())
115                .max((l - prev_close).abs());
116            prev_close = close[i];
117            self.seed_buf.push(tr);
118            sum_tr += tr;
119        }
120        let mut avg = sum_tr / p as f64;
121        out.push(avg);
122        // Steady state: Wilder smoothing, reciprocal hoisted out of the loop.
123        for i in p..n {
124            let (h, l) = (high[i], low[i]);
125            let tr = (h - l)
126                .max((h - prev_close).abs())
127                .max((l - prev_close).abs());
128            prev_close = close[i];
129            avg = avg.mul_add(self.n_minus_1, tr) * self.inv_period;
130            out.push(avg);
131        }
132
133        // Leave state where a full `update` replay would (seeded; seed_buf retained).
134        self.prev_close = Some(prev_close);
135        self.avg = avg;
136        self.seeded = true;
137        out
138    }
139}
140
141impl Indicator for Atr {
142    type Input = Candle;
143    type Output = f64;
144
145    fn update(&mut self, candle: Candle) -> Option<f64> {
146        let tr = candle.true_range(self.prev_close);
147        self.prev_close = Some(candle.close);
148
149        if self.seeded {
150            // Wilder smoothing with the reciprocal hoisted out of the hot path.
151            let new_avg = self.avg.mul_add(self.n_minus_1, tr) * self.inv_period;
152            self.avg = new_avg;
153            return Some(new_avg);
154        }
155
156        self.seed_buf.push(tr);
157        if self.seed_buf.len() == self.period {
158            let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
159            self.avg = seed;
160            self.seeded = true;
161            return Some(seed);
162        }
163        None
164    }
165
166    fn reset(&mut self) {
167        self.prev_close = None;
168        self.seed_buf.clear();
169        self.avg = 0.0;
170        self.seeded = false;
171    }
172
173    fn warmup_period(&self) -> usize {
174        self.period
175    }
176
177    fn is_ready(&self) -> bool {
178        self.seeded
179    }
180
181    fn name(&self) -> &'static str {
182        "ATR"
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    fn c(h: f64, l: f64, cl: f64) -> Candle {
        // ts/open/volume don't affect ATR; use safe placeholders.
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Independent reference: Wilder ATR computed straight from the definition.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessors `period` / `value` and the `Indicator`
    /// metadata (`name`, `warmup_period`, `is_ready`). The numeric tests
    /// inspect ATR output but never query these directly.
    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.warmup_period(), 14);
        assert_eq!(atr.value(), None);
        assert!(!atr.is_ready());
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        assert!(atr.is_ready());
        assert!(atr.value().is_some());
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        // Every candle has H=11, L=9, C=10 -> TR=2 (no gaps).
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        // Assert the warmup split explicitly so an all-`None` output cannot pass.
        assert!(out[..13].iter().all(Option::is_none));
        for v in &out[13..] {
            assert_relative_eq!(
                v.expect("ATR is seeded from the 14th candle"),
                2.0,
                epsilon = 1e-12
            );
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // Previous close 5, current candle H=10 L=9 C=9.5 -> TR = max(1, 5, 4) = 5.
        let candles = vec![
            c(6.0, 4.0, 5.0),  // prev close = 5
            c(10.0, 9.0, 9.5), // TR = 5
        ];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        // Seed window covers TR_1 and TR_2. TR_1 = H1-L1 = 2 (no prev close). TR_2 = 5.
        // Seed = (2+5)/2 = 3.5
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.update(candles[0]), None);
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        for v in atr.batch(&candles).into_iter().flatten() {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    fn atr_replay(period: usize, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
        let mut a = Atr::new(period).unwrap();
        (0..high.len())
            .map(|i| {
                let candle = Candle::new_unchecked(close[i], high[i], low[i], close[i], 0.0, 0);
                a.update(candle).unwrap_or(f64::NAN)
            })
            .collect()
    }

    /// Valid OHLC columns from a wandering base price.
    fn columns(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let base: Vec<f64> = (0..n)
            .map(|i| (f64::from(u32::try_from(i).unwrap()) * 0.3).sin() * 5.0 + 100.0)
            .collect();
        let high = base.iter().map(|b| b + 1.0).collect();
        let low = base.iter().map(|b| b - 1.0).collect();
        (high, low, base)
    }

    #[test]
    fn batch_atr_fast_path_is_bit_identical() {
        let (high, low, close) = columns(300);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        let mut ref_atr = Atr::new(14).unwrap();
        for i in 0..high.len() {
            ref_atr.update(Candle::new_unchecked(
                close[i], high[i], low[i], close[i], 0.0, 0,
            ));
        }
        let next = Candle::new_unchecked(101.0, 102.0, 100.0, 101.0, 0.0, 0);
        assert_eq!(atr.update(next), ref_atr.update(next));
    }

    /// Fast-path boundaries: `period == 1` (no warmup), `n == period` (the
    /// steady-state loop never runs) and an empty input (no state change).
    #[test]
    fn batch_atr_boundaries_are_bit_identical() {
        for period in [1_usize, 2, 14] {
            for n in [period, period + 1] {
                let (high, low, close) = columns(n);
                let mut atr = Atr::new(period).unwrap();
                let got = atr.batch_atr(&high, &low, &close);
                assert!(
                    bits_eq(&got, &atr_replay(period, &high, &low, &close)),
                    "period={period} n={n}"
                );
                assert!(atr.is_ready());
            }
        }
        let mut atr = Atr::new(3).unwrap();
        assert!(atr.batch_atr(&[], &[], &[]).is_empty());
        assert!(!atr.is_ready());
    }

    #[test]
    fn batch_atr_falls_back_when_not_fresh() {
        let (high, low, close) = columns(40);
        let mut atr = Atr::new(14).unwrap();
        atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let mut ref_atr = Atr::new(14).unwrap();
        ref_atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let want: Vec<f64> = (0..high.len())
            .map(|i| {
                ref_atr
                    .update(Candle::new_unchecked(
                        close[i], high[i], low[i], close[i], 0.0, 0,
                    ))
                    .unwrap_or(f64::NAN)
            })
            .collect();
        assert!(bits_eq(&atr.batch_atr(&high, &low, &close), &want));
    }

    #[test]
    fn batch_atr_sub_period_slice_falls_back() {
        let (high, low, close) = columns(5);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        assert!(got.iter().all(|x| x.is_nan()));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // bars: (low, range, close_fraction) -> a valid OHLC candle.
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}