1use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// One reading emitted by [`Adx`] once it has warmed up.
///
/// The directional indicators are `100 * smoothed DM / smoothed TR`, so they
/// are expressed on a percentage-style scale; each is reported as `0.0` when
/// the smoothed true range is exactly zero.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AdxOutput {
    /// Positive directional indicator (+DI).
    pub plus_di: f64,
    /// Negative directional indicator (-DI).
    pub minus_di: f64,
    /// Average directional index: the running average of the DX series
    /// derived from `plus_di` and `minus_di`.
    pub adx: f64,
}
17
18#[allow(clippy::struct_field_names)] #[derive(Debug, Clone)]
42pub struct Adx {
43 period: usize,
44 prev: Option<Candle>,
45
46 tr_seed: f64,
48 plus_dm_seed: f64,
49 minus_dm_seed: f64,
50 seed_count: usize,
51
52 tr_smooth: Option<f64>,
54 plus_dm_smooth: Option<f64>,
55 minus_dm_smooth: Option<f64>,
56
57 dx_buf: Vec<f64>,
59 adx_value: Option<f64>,
60 last_plus_di: f64,
61 last_minus_di: f64,
62}
63
64impl Adx {
65 pub fn new(period: usize) -> Result<Self> {
68 if period == 0 {
69 return Err(Error::PeriodZero);
70 }
71 Ok(Self {
72 period,
73 prev: None,
74 tr_seed: 0.0,
75 plus_dm_seed: 0.0,
76 minus_dm_seed: 0.0,
77 seed_count: 0,
78 tr_smooth: None,
79 plus_dm_smooth: None,
80 minus_dm_smooth: None,
81 dx_buf: Vec::with_capacity(period),
82 adx_value: None,
83 last_plus_di: 0.0,
84 last_minus_di: 0.0,
85 })
86 }
87
88 pub const fn period(&self) -> usize {
90 self.period
91 }
92}
93
94fn directional_movement(prev: &Candle, current: &Candle) -> (f64, f64) {
95 let up = current.high - prev.high;
96 let down = prev.low - current.low;
97 let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
98 let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };
99 (plus_dm, minus_dm)
100}
101
102impl Indicator for Adx {
103 type Input = Candle;
104 type Output = AdxOutput;
105
106 fn update(&mut self, candle: Candle) -> Option<AdxOutput> {
107 let Some(prev) = self.prev else {
108 self.prev = Some(candle);
109 return None;
110 };
111 self.prev = Some(candle);
112
113 let tr = candle.true_range(Some(prev.close));
114 let (plus_dm, minus_dm) = directional_movement(&prev, &candle);
115 let n = self.period as f64;
116
117 let (tr_v, plus_v, minus_v) = if let (Some(t), Some(p), Some(m)) =
118 (self.tr_smooth, self.plus_dm_smooth, self.minus_dm_smooth)
119 {
120 let t_new = t - t / n + tr;
121 let p_new = p - p / n + plus_dm;
122 let m_new = m - m / n + minus_dm;
123 self.tr_smooth = Some(t_new);
124 self.plus_dm_smooth = Some(p_new);
125 self.minus_dm_smooth = Some(m_new);
126 (t_new, p_new, m_new)
127 } else {
128 self.tr_seed += tr;
129 self.plus_dm_seed += plus_dm;
130 self.minus_dm_seed += minus_dm;
131 self.seed_count += 1;
132 if self.seed_count < self.period {
133 return None;
134 }
135 self.tr_smooth = Some(self.tr_seed);
136 self.plus_dm_smooth = Some(self.plus_dm_seed);
137 self.minus_dm_smooth = Some(self.minus_dm_seed);
138 (self.tr_seed, self.plus_dm_seed, self.minus_dm_seed)
139 };
140
141 let plus_di = if tr_v == 0.0 {
142 0.0
143 } else {
144 100.0 * plus_v / tr_v
145 };
146 let minus_di = if tr_v == 0.0 {
147 0.0
148 } else {
149 100.0 * minus_v / tr_v
150 };
151 self.last_plus_di = plus_di;
152 self.last_minus_di = minus_di;
153
154 let dx_den = plus_di + minus_di;
155 let dx = if dx_den == 0.0 {
156 0.0
157 } else {
158 100.0 * (plus_di - minus_di).abs() / dx_den
159 };
160
161 if let Some(prev_adx) = self.adx_value {
162 let new_adx = (prev_adx * (n - 1.0) + dx) / n;
163 self.adx_value = Some(new_adx);
164 return Some(AdxOutput {
165 plus_di,
166 minus_di,
167 adx: new_adx,
168 });
169 }
170
171 self.dx_buf.push(dx);
172 if self.dx_buf.len() == self.period {
173 let seed = self.dx_buf.iter().sum::<f64>() / n;
174 self.adx_value = Some(seed);
175 return Some(AdxOutput {
176 plus_di,
177 minus_di,
178 adx: seed,
179 });
180 }
181 None
182 }
183
184 fn reset(&mut self) {
185 self.prev = None;
186 self.tr_seed = 0.0;
187 self.plus_dm_seed = 0.0;
188 self.minus_dm_seed = 0.0;
189 self.seed_count = 0;
190 self.tr_smooth = None;
191 self.plus_dm_smooth = None;
192 self.minus_dm_smooth = None;
193 self.dx_buf.clear();
194 self.adx_value = None;
195 self.last_plus_di = 0.0;
196 self.last_minus_di = 0.0;
197 }
198
199 fn warmup_period(&self) -> usize {
200 2 * self.period
201 }
202
203 fn is_ready(&self) -> bool {
204 self.adx_value.is_some()
205 }
206
207 fn name(&self) -> &'static str {
208 "ADX"
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a candle from high / low / close. The first constructor
    /// argument mirrors the close and the last two are fixed at `1.0` and `0`
    /// (presumably open, volume and timestamp - confirm against `Candle::new`).
    fn bar(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// One bar per step, centred on `100 + 2 * step`, so ascending steps give
    /// a steady uptrend and descending steps a steady downtrend.
    fn staircase(steps: impl Iterator<Item = i32>) -> Vec<Candle> {
        steps
            .map(|step| {
                let base = 100.0 + f64::from(step) * 2.0;
                bar(base + 1.0, base - 0.5, base + 0.5)
            })
            .collect()
    }

    /// `len` bars oscillating around 100 with amplitude 5 and a fixed
    /// high-low range of 2.
    fn sine_wave(len: i32, freq: f64) -> Vec<Candle> {
        (0..len)
            .map(|i| {
                let mid = 100.0 + (f64::from(i) * freq).sin() * 5.0;
                bar(mid + 1.0, mid - 1.0, mid)
            })
            .collect()
    }

    /// Returns the last emitted reading, panicking with `why` if none exists.
    fn final_reading(
        outputs: impl IntoIterator<Item = Option<AdxOutput>>,
        why: &str,
    ) -> AdxOutput {
        outputs.into_iter().flatten().last().expect(why)
    }

    #[test]
    fn pure_uptrend_yields_plus_di_dominant() {
        let candles = staircase(0..50);
        let mut adx = Adx::new(14).unwrap();
        let last = final_reading(adx.batch(&candles), "emits");
        assert!(
            last.plus_di > last.minus_di,
            "+DI {} should exceed -DI {}",
            last.plus_di,
            last.minus_di
        );
        assert!(last.adx > 0.0);
    }

    #[test]
    fn pure_downtrend_yields_minus_di_dominant() {
        let candles = staircase((0..50).rev());
        let mut adx = Adx::new(14).unwrap();
        let last = final_reading(adx.batch(&candles), "emits");
        assert!(last.minus_di > last.plus_di);
    }

    #[test]
    fn rejects_zero_period() {
        assert!(Adx::new(0).is_err());
    }

    #[test]
    fn accessors_and_metadata() {
        let adx = Adx::new(14).unwrap();
        assert_eq!(adx.period(), 14);
        assert_eq!(adx.warmup_period(), 28);
        assert_eq!(adx.name(), "ADX");
    }

    /// A completely flat series has zero true range, which exercises the
    /// division guards for both DI and DX.
    #[test]
    fn zero_true_range_yields_zero_di_and_zero_adx() {
        let candles = vec![bar(10.0, 10.0, 10.0); 30];
        let mut adx = Adx::new(5).unwrap();
        let last = final_reading(adx.batch(&candles), "ADX emits after 2 * period candles");
        assert_eq!(last.plus_di, 0.0);
        assert_eq!(last.minus_di, 0.0);
        assert_eq!(last.adx, 0.0);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles = sine_wave(60, 0.3);
        let mut batched = Adx::new(14).unwrap();
        let mut streamed = Adx::new(14).unwrap();
        let one_by_one: Vec<_> = candles.iter().map(|x| streamed.update(*x)).collect();
        assert_eq!(batched.batch(&candles), one_by_one);
    }

    #[test]
    fn reset_clears_state() {
        let candles = vec![bar(11.0, 9.0, 10.0); 40];
        let mut adx = Adx::new(14).unwrap();
        adx.batch(&candles);
        adx.reset();
        assert!(!adx.is_ready());
    }

    #[test]
    fn outputs_remain_finite() {
        let candles = sine_wave(200, 0.2);
        let mut adx = Adx::new(14).unwrap();
        for reading in adx.batch(&candles).into_iter().flatten() {
            let AdxOutput {
                plus_di,
                minus_di,
                adx: value,
            } = reading;
            assert!([plus_di, minus_di, value].iter().all(|x| x.is_finite()));
        }
        // NOTE(review): `batch` runs a second time on the same instance with
        // no `reset` in between; whether it restarts the indicator depends on
        // `BatchExt`, which is not visible here.
        let last = adx.batch(&candles).into_iter().flatten().last().unwrap();
        // ADX must stay within [0, 100], allowing for rounding noise.
        assert!(last.adx <= 100.0 + 1e-6);
        assert_relative_eq!(0.0_f64.max(last.adx), last.adx, epsilon = 1e-9);
    }
}