Skip to main content

wickra_core/indicators/
atr.rs

1//! Average True Range (Wilder).
2
3use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Average True Range (ATR) smoothed with Wilder's recursive moving average.
///
/// Warm-up: the first `period` true-range values are averaged arithmetically to
/// seed the indicator, so the first reading is emitted on the `period`-th candle
/// ([`Indicator::update`] returns `None` before that). Every later candle applies
/// Wilder's recursion `atr = (atr · (period − 1) + tr) / period`.
///
/// The very first candle has no previous close, so its true range is computed
/// by [`Candle::true_range`] with `None` (i.e. from that candle's own range).
///
/// # Example
///
/// ```
/// use wickra_core::{Candle, Indicator, Atr};
///
/// let mut indicator = Atr::new(5).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     let base = 100.0 + f64::from(i);
///     let candle =
///         Candle::new(base, base + 2.0, base - 2.0, base + 1.0, 10.0, i64::from(i)).unwrap();
///     last = indicator.update(candle);
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Atr {
    /// Wilder smoothing period; guaranteed non-zero by [`Atr::new`].
    period: usize,
    /// Close of the previously seen candle, feeding the gap-aware true range.
    prev_close: Option<f64>,
    /// True-range values buffered during warm-up (never more than `period`).
    seed_buf: Vec<f64>,
    /// Current smoothed ATR; `None` until warm-up has completed.
    avg: Option<f64>,
}
35
36impl Atr {
37    /// Construct an ATR with the given Wilder period.
38    ///
39    /// # Errors
40    ///
41    /// Returns [`Error::PeriodZero`] if `period == 0`.
42    pub fn new(period: usize) -> Result<Self> {
43        if period == 0 {
44            return Err(Error::PeriodZero);
45        }
46        Ok(Self {
47            period,
48            prev_close: None,
49            seed_buf: Vec::with_capacity(period),
50            avg: None,
51        })
52    }
53
54    /// Configured period.
55    pub const fn period(&self) -> usize {
56        self.period
57    }
58
59    /// Current value if available.
60    pub const fn value(&self) -> Option<f64> {
61        self.avg
62    }
63}
64
65impl Indicator for Atr {
66    type Input = Candle;
67    type Output = f64;
68
69    fn update(&mut self, candle: Candle) -> Option<f64> {
70        let tr = candle.true_range(self.prev_close);
71        self.prev_close = Some(candle.close);
72
73        if let Some(avg) = self.avg {
74            let n = self.period as f64;
75            let new_avg = avg.mul_add(n - 1.0, tr) / n;
76            self.avg = Some(new_avg);
77            return Some(new_avg);
78        }
79
80        self.seed_buf.push(tr);
81        if self.seed_buf.len() == self.period {
82            let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
83            self.avg = Some(seed);
84            return Some(seed);
85        }
86        None
87    }
88
89    fn reset(&mut self) {
90        self.prev_close = None;
91        self.seed_buf.clear();
92        self.avg = None;
93    }
94
95    fn warmup_period(&self) -> usize {
96        self.period
97    }
98
99    fn is_ready(&self) -> bool {
100        self.avg.is_some()
101    }
102
103    fn name(&self) -> &'static str {
104        "ATR"
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    fn c(h: f64, l: f64, cl: f64) -> Candle {
        // ts/open/volume don't affect ATR; use safe placeholders.
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Independent reference: Wilder ATR computed straight from the definition.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    /// Exercise the `period` / `value` accessors and the `Indicator` metadata
    /// (`name`, `warmup_period`), which the numeric tests never query.
    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.warmup_period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.value(), None);
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        // Every candle has TR = 2, so the seed mean is exactly 2.
        assert_relative_eq!(atr.value().unwrap(), 2.0, epsilon = 1e-12);
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
        assert!(out[4].is_some());
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        // Every candle has H=11, L=9, C=10 -> TR=2 (no gaps).
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        // Pin the warm-up explicitly so the loop below cannot pass vacuously.
        assert!(out[..13].iter().all(Option::is_none), "warm-up must emit nothing");
        for v in &out[13..] {
            let v = v.expect("ATR must be ready from the 14th candle on");
            assert_relative_eq!(v, 2.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // Previous close 5, current candle H=10 L=9 C=9.5 -> TR = max(1, 5, 4) = 5.
        let candles = vec![
            c(6.0, 4.0, 5.0),  // prev close = 5
            c(10.0, 9.0, 9.5), // TR = 5
        ];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        // Seed window covers TR_1 and TR_2. TR_1 = H1-L1 = 2 (no prev close). TR_2 = 5.
        // Seed = (2+5)/2 = 3.5
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn gap_down_uses_prev_close_minus_low() {
        // Previous close 10, current candle H=6 L=5 C=5.5 -> TR = max(1, 4, 5) = 5.
        let candles = vec![
            c(11.0, 9.0, 10.0), // TR_1 = H1-L1 = 2, prev close = 10
            c(6.0, 5.0, 5.5),   // TR_2 = 5
        ];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn period_one_tracks_true_range() {
        // With period 1 the Wilder update degenerates to ATR == TR on every bar.
        let candles = vec![
            c(2.0, 1.0, 1.5), // TR = 1
            c(3.0, 2.0, 2.5), // TR = max(1, 1.5, 0.5) = 1.5
            c(5.0, 4.0, 4.5), // TR = max(1, 2.5, 1.5) = 2.5
            c(4.0, 3.0, 3.0), // TR = max(1, 0.5, 1.5) = 1.5
        ];
        let mut atr = Atr::new(1).unwrap();
        let out = atr.batch(&candles);
        let expected = [1.0, 1.5, 2.5, 1.5];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected) {
            assert_relative_eq!(got.unwrap(), want, epsilon = 1e-12);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20)
            .map(|i| {
                let mid = 10.0 + f64::from(i % 7);
                c(mid + 1.0, mid - 1.0, mid)
            })
            .collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.value(), None);
        // After a reset the indicator must replay exactly like a fresh instance,
        // including a fresh warm-up (first output is `None`).
        let replay: Vec<Option<f64>> = candles.iter().map(|x| atr.update(*x)).collect();
        let mut fresh = Atr::new(5).unwrap();
        assert_eq!(replay[0], None);
        assert_eq!(replay, fresh.batch(&candles));
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        let values: Vec<f64> = atr.batch(&candles).into_iter().flatten().collect();
        // 200 candles with period 14 emit on candles 14..=200.
        assert_eq!(values.len(), candles.len() - 13);
        for v in values {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // bars: (low, range, close_fraction) -> a valid OHLC candle.
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}