1use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Streaming Average True Range (ATR) indicator.
///
/// The first `period` true-range values are buffered and averaged to produce
/// the initial value; every later candle folds its true range into the running
/// average as `(avg * (period - 1) + tr) / period`.
#[derive(Clone, Debug)]
pub struct Atr {
    /// Number of candles in the smoothing window; never zero once constructed.
    period: usize,
    /// `period - 1` converted to `f64` once at construction, reused on every update.
    n_minus_1: f64,
    /// `1.0 / period`, cached so the recursive step multiplies instead of dividing.
    inv_period: f64,
    /// Close of the most recently seen candle; `None` before the first update
    /// and after a reset.
    prev_close: Option<f64>,
    /// True-range values collected while waiting for the first `period` candles.
    seed_buf: Vec<f64>,
    /// Current smoothed value; only meaningful while `seeded` is `true`.
    avg: f64,
    /// Whether the initial average has been computed.
    seeded: bool,
}
43
44impl Atr {
45 pub fn new(period: usize) -> Result<Self> {
51 if period == 0 {
52 return Err(Error::PeriodZero);
53 }
54 Ok(Self {
55 period,
56 n_minus_1: (period - 1) as f64,
57 inv_period: 1.0 / period as f64,
58 prev_close: None,
59 seed_buf: Vec::with_capacity(period),
60 avg: 0.0,
61 seeded: false,
62 })
63 }
64
65 pub const fn period(&self) -> usize {
67 self.period
68 }
69
70 pub const fn value(&self) -> Option<f64> {
72 if self.seeded {
73 Some(self.avg)
74 } else {
75 None
76 }
77 }
78}
79
80impl Indicator for Atr {
81 type Input = Candle;
82 type Output = f64;
83
84 fn update(&mut self, candle: Candle) -> Option<f64> {
85 let tr = candle.true_range(self.prev_close);
86 self.prev_close = Some(candle.close);
87
88 if self.seeded {
89 let new_avg = self.avg.mul_add(self.n_minus_1, tr) * self.inv_period;
91 self.avg = new_avg;
92 return Some(new_avg);
93 }
94
95 self.seed_buf.push(tr);
96 if self.seed_buf.len() == self.period {
97 let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
98 self.avg = seed;
99 self.seeded = true;
100 return Some(seed);
101 }
102 None
103 }
104
105 fn reset(&mut self) {
106 self.prev_close = None;
107 self.seed_buf.clear();
108 self.avg = 0.0;
109 self.seeded = false;
110 }
111
112 fn warmup_period(&self) -> usize {
113 self.period
114 }
115
116 fn is_ready(&self) -> bool {
117 self.seeded
118 }
119
120 fn name(&self) -> &'static str {
121 "ATR"
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Test shorthand: builds a `Candle` from a high, low and close.
    ///
    /// The close is also passed as the first argument (presumably the open;
    /// confirm against `Candle::new`), and the last two arguments are fixed
    /// at `1.0` and `0`.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Straightforward reference ATR over `(high, low, close)` triples.
    ///
    /// Emits `None` until `period` true ranges have been collected, then the
    /// mean of those, then `(avg * (period - 1) + tr) / period` for each
    /// following bar. Used as the oracle for the property test below.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            // The first bar has no previous close, so its true range is just high - low.
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.warmup_period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.value(), None);
        assert!(!atr.is_ready());
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        assert!(atr.value().is_some());
        assert!(atr.is_ready());
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        // One output per input candle.
        assert_eq!(out.len(), candles.len());
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert!(out[2].is_some());
        assert!(out[3].is_some());
        // Every candle after the seed must keep producing a value.
        assert!(out[4].is_some());
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        for v in out.iter().skip(13).flatten() {
            assert_relative_eq!(*v, 2.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // First bar: TR = 6 - 4 = 2.
        // Second bar gaps above the previous close of 5: TR = 10 - 5 = 5.
        let candles = vec![c(6.0, 4.0, 5.0), c(10.0, 9.0, 9.5)];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        // Seed is the mean of the two true ranges: (2 + 5) / 2.
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        // The accessor must agree with `is_ready` after a reset.
        assert_eq!(atr.value(), None);
        assert_eq!(atr.update(candles[0]), None);
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        for v in atr.batch(&candles).into_iter().flatten() {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // Each generated tuple is (low, range, fraction): derive a bar with
            // high = low + range and the close placed inside [low, high).
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    // Relative tolerance: the implementation multiplies by a cached
                    // reciprocal while the reference divides, so results may differ
                    // in the last bits.
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}