1use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Average True Range (ATR) indicator with Wilder smoothing.
///
/// The first value is the arithmetic mean of the first `period` true ranges;
/// every later value follows `avg = (avg * (period - 1) + tr) / period`.
#[derive(Debug, Clone)]
pub struct Atr {
    /// Smoothing period; `Atr::new` rejects zero, so this is always >= 1.
    period: usize,
    /// `period - 1` as `f64`, cached for the per-bar smoothing step.
    n_minus_1: f64,
    /// `1 / period`, cached so the smoothing step multiplies instead of dividing.
    inv_period: f64,
    /// Close of the most recently processed bar; `None` before the first bar.
    prev_close: Option<f64>,
    /// True ranges collected while warming up; their mean seeds `avg`.
    seed_buf: Vec<f64>,
    /// Current smoothed ATR; only meaningful once `seeded` is `true`.
    avg: f64,
    /// Set once `period` true ranges have been seen and `avg` holds a real value.
    seeded: bool,
}
43
44impl Atr {
45 pub fn new(period: usize) -> Result<Self> {
51 if period == 0 {
52 return Err(Error::PeriodZero);
53 }
54 Ok(Self {
55 period,
56 n_minus_1: (period - 1) as f64,
57 inv_period: 1.0 / period as f64,
58 prev_close: None,
59 seed_buf: Vec::with_capacity(period),
60 avg: 0.0,
61 seeded: false,
62 })
63 }
64
65 pub const fn period(&self) -> usize {
67 self.period
68 }
69
70 pub const fn value(&self) -> Option<f64> {
72 if self.seeded {
73 Some(self.avg)
74 } else {
75 None
76 }
77 }
78
79 pub fn batch_atr(&mut self, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
91 let p = self.period;
92 let n = high.len();
93 if self.seeded || !self.seed_buf.is_empty() || self.prev_close.is_some() || n < p {
94 let mut out = vec![f64::NAN; n];
95 for i in 0..n {
96 let candle = Candle::new_unchecked(close[i], high[i], low[i], close[i], 0.0, 0);
97 if let Some(v) = self.update(candle) {
98 out[i] = v;
99 }
100 }
101 return out;
102 }
103
104 let mut out = vec![f64::NAN; p - 1];
106 out.reserve(n - (p - 1));
107 let mut prev_close = close[0];
109 let mut sum_tr = high[0] - low[0];
110 self.seed_buf.push(sum_tr);
111 for i in 1..p {
112 let (h, l) = (high[i], low[i]);
113 let tr = (h - l)
114 .max((h - prev_close).abs())
115 .max((l - prev_close).abs());
116 prev_close = close[i];
117 self.seed_buf.push(tr);
118 sum_tr += tr;
119 }
120 let mut avg = sum_tr / p as f64;
121 out.push(avg);
122 for i in p..n {
124 let (h, l) = (high[i], low[i]);
125 let tr = (h - l)
126 .max((h - prev_close).abs())
127 .max((l - prev_close).abs());
128 prev_close = close[i];
129 avg = avg.mul_add(self.n_minus_1, tr) * self.inv_period;
130 out.push(avg);
131 }
132
133 self.prev_close = Some(prev_close);
135 self.avg = avg;
136 self.seeded = true;
137 out
138 }
139}
140
141impl Indicator for Atr {
142 type Input = Candle;
143 type Output = f64;
144
145 fn update(&mut self, candle: Candle) -> Option<f64> {
146 let tr = candle.true_range(self.prev_close);
147 self.prev_close = Some(candle.close);
148
149 if self.seeded {
150 let new_avg = self.avg.mul_add(self.n_minus_1, tr) * self.inv_period;
152 self.avg = new_avg;
153 return Some(new_avg);
154 }
155
156 self.seed_buf.push(tr);
157 if self.seed_buf.len() == self.period {
158 let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
159 self.avg = seed;
160 self.seeded = true;
161 return Some(seed);
162 }
163 None
164 }
165
166 fn reset(&mut self) {
167 self.prev_close = None;
168 self.seed_buf.clear();
169 self.avg = 0.0;
170 self.seeded = false;
171 }
172
173 fn warmup_period(&self) -> usize {
174 self.period
175 }
176
177 fn is_ready(&self) -> bool {
178 self.seeded
179 }
180
181 fn name(&self) -> &'static str {
182 "ATR"
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a validated candle with `open == close` and unit volume.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Reference ATR written the textbook way (plain division, no cached
    /// reciprocal or fused multiply-add) to cross-check the optimised code.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.value(), None);
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        assert!(atr.value().is_some());
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        for v in out.iter().skip(13).flatten() {
            assert_relative_eq!(*v, 2.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // TR0 = 6 - 4 = 2; TR1 = max(10 - 9, |10 - 5|, |9 - 5|) = 5; mean = 3.5.
        let candles = vec![c(6.0, 4.0, 5.0), c(10.0, 9.0, 9.5)];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.update(candles[0]), None);
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        for v in atr.batch(&candles).into_iter().flatten() {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    /// Bitwise slice equality. Unlike `==` this distinguishes `0.0` from
    /// `-0.0`, which is what the "bit-identical" tests actually promise; NaNs
    /// match only when their bit patterns do (both sides here use `f64::NAN`).
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.to_bits() == y.to_bits())
    }

    /// Streams the columns through a fresh indicator, mapping warm-up to NaN
    /// so the result is directly comparable with `batch_atr`.
    fn atr_replay(period: usize, high: &[f64], low: &[f64], close: &[f64]) -> Vec<f64> {
        let mut a = Atr::new(period).unwrap();
        (0..high.len())
            .map(|i| {
                let candle = Candle::new_unchecked(close[i], high[i], low[i], close[i], 0.0, 0);
                a.update(candle).unwrap_or(f64::NAN)
            })
            .collect()
    }

    /// Deterministic sinusoidal price columns (high, low, close) of length `n`.
    fn columns(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let base: Vec<f64> = (0..n)
            .map(|i| (f64::from(u32::try_from(i).unwrap()) * 0.3).sin() * 5.0 + 100.0)
            .collect();
        let high = base.iter().map(|b| b + 1.0).collect();
        let low = base.iter().map(|b| b - 1.0).collect();
        (high, low, base)
    }

    #[test]
    fn batch_atr_fast_path_is_bit_identical() {
        let (high, low, close) = columns(300);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        // The state left behind must also match, so streaming can resume.
        let mut ref_atr = Atr::new(14).unwrap();
        for i in 0..high.len() {
            ref_atr.update(Candle::new_unchecked(
                close[i], high[i], low[i], close[i], 0.0, 0,
            ));
        }
        let next = Candle::new_unchecked(101.0, 102.0, 100.0, 101.0, 0.0, 0);
        assert_eq!(atr.update(next), ref_atr.update(next));
    }

    #[test]
    fn batch_atr_period_one_is_bit_identical() {
        let (high, low, close) = columns(25);
        let mut atr = Atr::new(1).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(1, &high, &low, &close)));
        assert!(got.iter().all(|x| !x.is_nan()));
    }

    #[test]
    fn batch_atr_empty_input_is_empty_and_not_ready() {
        let mut atr = Atr::new(3).unwrap();
        let got = atr.batch_atr(&[], &[], &[]);
        assert!(got.is_empty());
        assert!(!atr.is_ready());
        assert_eq!(atr.value(), None);
    }

    #[test]
    fn batch_atr_falls_back_when_not_fresh() {
        let (high, low, close) = columns(40);
        let mut atr = Atr::new(14).unwrap();
        atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let mut ref_atr = Atr::new(14).unwrap();
        ref_atr.update(Candle::new_unchecked(
            close[0], high[0], low[0], close[0], 0.0, 0,
        ));
        let want: Vec<f64> = (0..high.len())
            .map(|i| {
                ref_atr
                    .update(Candle::new_unchecked(
                        close[i], high[i], low[i], close[i], 0.0, 0,
                    ))
                    .unwrap_or(f64::NAN)
            })
            .collect();
        assert!(bits_eq(&atr.batch_atr(&high, &low, &close), &want));
    }

    #[test]
    fn batch_atr_sub_period_slice_falls_back() {
        let (high, low, close) = columns(5);
        let mut atr = Atr::new(14).unwrap();
        let got = atr.batch_atr(&high, &low, &close);
        assert!(bits_eq(&got, &atr_replay(14, &high, &low, &close)));
        assert!(got.iter().all(|x| x.is_nan()));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // Each bar is (low, range, frac): high = low + range and the close
            // sits `frac` of the way up the range, so it lies within [low, high].
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}