Skip to main content

wickra_core/indicators/
adx.rs

1//! Average Directional Index (ADX) with +DI / -DI components.
2
3use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// One ADX reading: the two directional indicators plus the ADX line itself.
///
/// All three values are percentages produced by [`Adx`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdxOutput {
    /// +DI: smoothed upward directional movement as a percentage of the
    /// smoothed true range.
    pub plus_di: f64,
    /// -DI: smoothed downward directional movement as a percentage of the
    /// smoothed true range.
    pub minus_di: f64,
    /// ADX: Wilder-smoothed average of DX = 100 * |+DI - -DI| / (+DI + -DI).
    pub adx: f64,
}
17
18/// Wilder's Average Directional Index.
19///
20/// Uses Wilder smoothing throughout. First `period` candles seed the directional
21/// movement / true range sums; the next `period` candles produce DX values that
22/// seed the ADX. The first complete `AdxOutput` is emitted after `2 * period`
23/// candles.
24///
25/// # Example
26///
27/// ```
28/// use wickra_core::{Candle, Indicator, Adx};
29///
30/// let mut indicator = Adx::new(5).unwrap();
31/// let mut last = None;
32/// for i in 0..80 {
33///     let base = 100.0 + f64::from(i);
34///     let candle =
35///         Candle::new(base, base + 2.0, base - 2.0, base + 1.0, 10.0, i64::from(i)).unwrap();
36///     last = indicator.update(candle);
37/// }
38/// assert!(last.is_some());
39/// ```
40#[allow(clippy::struct_field_names)] // adx_value pairs with adx (the output line) — renaming hurts clarity
41#[derive(Debug, Clone)]
42pub struct Adx {
43    period: usize,
44    prev: Option<Candle>,
45
46    // Wilder-smoothed sums during seeding.
47    tr_seed: f64,
48    plus_dm_seed: f64,
49    minus_dm_seed: f64,
50    seed_count: usize,
51
52    // Smoothed running values after seeding.
53    tr_smooth: Option<f64>,
54    plus_dm_smooth: Option<f64>,
55    minus_dm_smooth: Option<f64>,
56
57    // ADX seeding.
58    dx_buf: Vec<f64>,
59    adx_value: Option<f64>,
60    last_plus_di: f64,
61    last_minus_di: f64,
62}
63
64impl Adx {
65    /// # Errors
66    /// Returns [`Error::PeriodZero`] if `period == 0`.
67    pub fn new(period: usize) -> Result<Self> {
68        if period == 0 {
69            return Err(Error::PeriodZero);
70        }
71        Ok(Self {
72            period,
73            prev: None,
74            tr_seed: 0.0,
75            plus_dm_seed: 0.0,
76            minus_dm_seed: 0.0,
77            seed_count: 0,
78            tr_smooth: None,
79            plus_dm_smooth: None,
80            minus_dm_smooth: None,
81            dx_buf: Vec::with_capacity(period),
82            adx_value: None,
83            last_plus_di: 0.0,
84            last_minus_di: 0.0,
85        })
86    }
87
88    /// Configured period.
89    pub const fn period(&self) -> usize {
90        self.period
91    }
92}
93
94fn directional_movement(prev: &Candle, current: &Candle) -> (f64, f64) {
95    let up = current.high - prev.high;
96    let down = prev.low - current.low;
97    let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
98    let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };
99    (plus_dm, minus_dm)
100}
101
102impl Indicator for Adx {
103    type Input = Candle;
104    type Output = AdxOutput;
105
106    fn update(&mut self, candle: Candle) -> Option<AdxOutput> {
107        let Some(prev) = self.prev else {
108            self.prev = Some(candle);
109            return None;
110        };
111        self.prev = Some(candle);
112
113        let tr = candle.true_range(Some(prev.close));
114        let (plus_dm, minus_dm) = directional_movement(&prev, &candle);
115        let n = self.period as f64;
116
117        let (tr_v, plus_v, minus_v) = if let (Some(t), Some(p), Some(m)) =
118            (self.tr_smooth, self.plus_dm_smooth, self.minus_dm_smooth)
119        {
120            let t_new = t - t / n + tr;
121            let p_new = p - p / n + plus_dm;
122            let m_new = m - m / n + minus_dm;
123            self.tr_smooth = Some(t_new);
124            self.plus_dm_smooth = Some(p_new);
125            self.minus_dm_smooth = Some(m_new);
126            (t_new, p_new, m_new)
127        } else {
128            self.tr_seed += tr;
129            self.plus_dm_seed += plus_dm;
130            self.minus_dm_seed += minus_dm;
131            self.seed_count += 1;
132            if self.seed_count < self.period {
133                return None;
134            }
135            self.tr_smooth = Some(self.tr_seed);
136            self.plus_dm_smooth = Some(self.plus_dm_seed);
137            self.minus_dm_smooth = Some(self.minus_dm_seed);
138            (self.tr_seed, self.plus_dm_seed, self.minus_dm_seed)
139        };
140
141        let plus_di = if tr_v == 0.0 {
142            0.0
143        } else {
144            100.0 * plus_v / tr_v
145        };
146        let minus_di = if tr_v == 0.0 {
147            0.0
148        } else {
149            100.0 * minus_v / tr_v
150        };
151        self.last_plus_di = plus_di;
152        self.last_minus_di = minus_di;
153
154        let dx_den = plus_di + minus_di;
155        let dx = if dx_den == 0.0 {
156            0.0
157        } else {
158            100.0 * (plus_di - minus_di).abs() / dx_den
159        };
160
161        if let Some(prev_adx) = self.adx_value {
162            let new_adx = (prev_adx * (n - 1.0) + dx) / n;
163            self.adx_value = Some(new_adx);
164            return Some(AdxOutput {
165                plus_di,
166                minus_di,
167                adx: new_adx,
168            });
169        }
170
171        self.dx_buf.push(dx);
172        if self.dx_buf.len() == self.period {
173            let seed = self.dx_buf.iter().sum::<f64>() / n;
174            self.adx_value = Some(seed);
175            return Some(AdxOutput {
176                plus_di,
177                minus_di,
178                adx: seed,
179            });
180        }
181        None
182    }
183
184    fn reset(&mut self) {
185        self.prev = None;
186        self.tr_seed = 0.0;
187        self.plus_dm_seed = 0.0;
188        self.minus_dm_seed = 0.0;
189        self.seed_count = 0;
190        self.tr_smooth = None;
191        self.plus_dm_smooth = None;
192        self.minus_dm_smooth = None;
193        self.dx_buf.clear();
194        self.adx_value = None;
195        self.last_plus_di = 0.0;
196        self.last_minus_di = 0.0;
197    }
198
199    fn warmup_period(&self) -> usize {
200        2 * self.period
201    }
202
203    fn is_ready(&self) -> bool {
204        self.adx_value.is_some()
205    }
206
207    fn name(&self) -> &'static str {
208        "ADX"
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a candle whose open equals its close, with fixed volume and timestamp.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    #[test]
    fn pure_uptrend_yields_plus_di_dominant() {
        // Strict uptrend: highs and lows both rise, so +DI should dominate -DI
        // and ADX should be positive.
        let candles: Vec<Candle> = (0..50)
            .map(|i| {
                let base = 100.0 + f64::from(i) * 2.0;
                c(base + 1.0, base - 0.5, base + 0.5)
            })
            .collect();
        let mut adx = Adx::new(14).unwrap();
        let last = adx
            .batch(&candles)
            .into_iter()
            .flatten()
            .last()
            .expect("emits");
        assert!(
            last.plus_di > last.minus_di,
            "+DI {} should exceed -DI {}",
            last.plus_di,
            last.minus_di
        );
        assert!(last.adx > 0.0);
    }

    #[test]
    fn pure_downtrend_yields_minus_di_dominant() {
        let candles: Vec<Candle> = (0..50)
            .rev()
            .map(|i| {
                let base = 100.0 + f64::from(i) * 2.0;
                c(base + 1.0, base - 0.5, base + 0.5)
            })
            .collect();
        let mut adx = Adx::new(14).unwrap();
        let last = adx
            .batch(&candles)
            .into_iter()
            .flatten()
            .last()
            .expect("emits");
        assert!(last.minus_di > last.plus_di);
    }

    #[test]
    fn rejects_zero_period() {
        assert!(Adx::new(0).is_err());
    }

    /// Pins exact Wilder arithmetic and emission timing on a tiny series.
    ///
    /// Every previous close lies inside the current bar's range, so the true
    /// range is simply `high - low` for each delta.
    #[test]
    fn matches_hand_computed_wilder_values() {
        let candles = vec![
            c(10.0, 8.0, 9.0),   // primes `prev`
            c(11.0, 9.0, 10.0),  // TR 2, +DM 1
            c(12.0, 10.0, 11.0), // TR 2, +DM 1 -> seed TR 4, +DM 2, -DM 0 -> DX 100
            c(11.5, 9.0, 10.0),  // TR 2.5, -DM 1 -> TR 4.5, +DM 1, -DM 1 -> DX 0, ADX 50
            c(13.0, 10.0, 12.0), // TR 3, +DM 1.5 -> TR 5.25, +DM 2, -DM 0.5 -> DX 60, ADX 55
        ];
        let mut adx = Adx::new(2).unwrap();
        let out = adx.batch(&candles);
        assert_eq!(out.len(), candles.len());
        assert!(out[..3].iter().all(Option::is_none));

        let first = out[3].expect("first output on the 2 * period-th candle");
        assert_relative_eq!(first.plus_di, 100.0 / 4.5, epsilon = 1e-9);
        assert_relative_eq!(first.minus_di, 100.0 / 4.5, epsilon = 1e-9);
        assert_relative_eq!(first.adx, 50.0, epsilon = 1e-9);

        let second = out[4].expect("emits on every candle once warm");
        assert_relative_eq!(second.plus_di, 200.0 / 5.25, epsilon = 1e-9);
        assert_relative_eq!(second.minus_di, 50.0 / 5.25, epsilon = 1e-9);
        assert_relative_eq!(second.adx, 55.0, epsilon = 1e-9);
    }

    /// Covers the `period` accessor and the `warmup_period` / `name` metadata
    /// methods, which the trend tests never inspect.
    #[test]
    fn accessors_and_metadata() {
        let adx = Adx::new(14).unwrap();
        assert_eq!(adx.period(), 14);
        assert_eq!(adx.warmup_period(), 28);
        assert_eq!(adx.name(), "ADX");
    }

    /// Covers the zero-true-range guards in `update`.
    ///
    /// Flat candles (high == low == close on every bar) give a true range of 0,
    /// so the smoothed true range stays exactly 0.0. Without the guards, the
    /// +DI / -DI divisions would be 0/0. The indicator must emit zeros; the DX
    /// denominator is also 0.
    #[test]
    fn zero_true_range_yields_zero_di_and_zero_adx() {
        let candles: Vec<Candle> = (0..30).map(|_| c(10.0, 10.0, 10.0)).collect();
        let mut adx = Adx::new(5).unwrap();
        let last = adx
            .batch(&candles)
            .into_iter()
            .flatten()
            .last()
            .expect("ADX emits after 2 * period candles");
        assert_eq!(last.plus_di, 0.0);
        assert_eq!(last.minus_di, 0.0);
        assert_eq!(last.adx, 0.0);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut a = Adx::new(14).unwrap();
        let mut b = Adx::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let m = 100.0 + (f64::from(i) * 0.5).sin() * 3.0;
                c(m + 1.0, m - 1.0, m)
            })
            .collect();
        let mut adx = Adx::new(14).unwrap();
        let first_pass = adx.batch(&candles);
        assert!(adx.is_ready());
        adx.reset();
        assert!(!adx.is_ready());
        // After a reset the indicator must behave exactly like a fresh instance.
        assert_eq!(adx.batch(&candles), first_pass);
    }

    #[test]
    fn outputs_remain_finite() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let m = 100.0 + (f64::from(i) * 0.2).sin() * 5.0;
                c(m + 1.0, m - 1.0, m)
            })
            .collect();
        let mut adx = Adx::new(14).unwrap();
        // Run the stream exactly once: the indicator is stateful, so a second
        // `batch` call would continue from this run's state.
        let outputs: Vec<AdxOutput> = adx.batch(&candles).into_iter().flatten().collect();
        assert!(!outputs.is_empty());
        for v in &outputs {
            for x in [v.plus_di, v.minus_di, v.adx] {
                assert!(x.is_finite(), "non-finite output {v:?}");
                // DI and ADX are percentages: bounded by [0, 100] up to rounding.
                assert!((0.0..=(100.0 + 1e-6)).contains(&x), "out of range {v:?}");
            }
        }
    }
}