1use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Average True Range (ATR) indicator over a stream of candles.
///
/// Produces no value until `period` candles have been processed. The first value is the
/// arithmetic mean of the first `period` true ranges. Every later value uses the recursive
/// smoothing `avg = (avg * (period - 1) + tr) / period`.
#[derive(Debug, Clone)]
pub struct Atr {
    /// Smoothing window length; `Atr::new` rejects `0`.
    period: usize,
    /// Close of the most recently processed candle, fed into the next true-range computation.
    prev_close: Option<f64>,
    /// True ranges collected during warmup; averaged into the seed once it holds `period` values.
    seed_buf: Vec<f64>,
    /// Current smoothed ATR; `None` until warmup completes.
    avg: Option<f64>,
}
35
36impl Atr {
37 pub fn new(period: usize) -> Result<Self> {
43 if period == 0 {
44 return Err(Error::PeriodZero);
45 }
46 Ok(Self {
47 period,
48 prev_close: None,
49 seed_buf: Vec::with_capacity(period),
50 avg: None,
51 })
52 }
53
54 pub const fn period(&self) -> usize {
56 self.period
57 }
58
59 pub const fn value(&self) -> Option<f64> {
61 self.avg
62 }
63}
64
65impl Indicator for Atr {
66 type Input = Candle;
67 type Output = f64;
68
69 fn update(&mut self, candle: Candle) -> Option<f64> {
70 let tr = candle.true_range(self.prev_close);
71 self.prev_close = Some(candle.close);
72
73 if let Some(avg) = self.avg {
74 let n = self.period as f64;
75 let new_avg = avg.mul_add(n - 1.0, tr) / n;
76 self.avg = Some(new_avg);
77 return Some(new_avg);
78 }
79
80 self.seed_buf.push(tr);
81 if self.seed_buf.len() == self.period {
82 let seed = self.seed_buf.iter().copied().sum::<f64>() / self.period as f64;
83 self.avg = Some(seed);
84 return Some(seed);
85 }
86 None
87 }
88
89 fn reset(&mut self) {
90 self.prev_close = None;
91 self.seed_buf.clear();
92 self.avg = None;
93 }
94
95 fn warmup_period(&self) -> usize {
96 self.period
97 }
98
99 fn is_ready(&self) -> bool {
100 self.avg.is_some()
101 }
102
103 fn name(&self) -> &'static str {
104 "ATR"
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Test candle from high/low/close. `cl` is passed as both the first and fourth
    /// `Candle::new` argument (presumably open and close), with volume `1.0` and timestamp `0`.
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// Reference ATR written straight from the textbook definition, used as a test oracle.
    ///
    /// Takes `(high, low, close)` triples and returns one entry per bar. Entries are `None`
    /// during warmup, then the mean of the first `period` true ranges, then Wilder smoothing.
    fn atr_naive(hlc: &[(f64, f64, f64)], period: usize) -> Vec<Option<f64>> {
        let n = period as f64;
        let mut out = Vec::with_capacity(hlc.len());
        let mut trs: Vec<f64> = Vec::new();
        let mut avg: Option<f64> = None;
        let mut prev_close: Option<f64> = None;
        for &(h, l, cl) in hlc {
            let tr = match prev_close {
                None => h - l,
                Some(pc) => (h - l).max((h - pc).abs()).max((l - pc).abs()),
            };
            prev_close = Some(cl);
            if let Some(a) = avg {
                let na = (a * (n - 1.0) + tr) / n;
                avg = Some(na);
                out.push(Some(na));
            } else {
                trs.push(tr);
                if trs.len() == period {
                    avg = Some(trs.iter().sum::<f64>() / n);
                    out.push(avg);
                } else {
                    out.push(None);
                }
            }
        }
        out
    }

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(Atr::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut atr = Atr::new(14).unwrap();
        assert_eq!(atr.period(), 14);
        assert_eq!(atr.warmup_period(), 14);
        assert_eq!(atr.name(), "ATR");
        assert_eq!(atr.value(), None);
        assert!(!atr.is_ready());
        for _ in 0..14 {
            atr.update(c(11.0, 9.0, 10.0));
        }
        assert!(atr.is_ready());
        // Every true range is 2.0, so the seed is exactly their mean.
        assert_relative_eq!(atr.value().unwrap(), 2.0, epsilon = 1e-12);
    }

    #[test]
    fn warmup_emits_on_period_th_candle() {
        let candles = vec![
            c(2.0, 1.0, 1.5),
            c(3.0, 2.0, 2.5),
            c(4.0, 3.0, 3.5),
            c(5.0, 4.0, 4.5),
            c(6.0, 5.0, 5.5),
        ];
        let mut atr = Atr::new(3).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        // True ranges are 1.0, 1.5, 1.5, 1.5, 1.5; the seed is the mean of the first three.
        assert_relative_eq!(out[2].unwrap(), 4.0 / 3.0, epsilon = 1e-12);
        // Wilder steps: (4/3 * 2 + 1.5) / 3 = 25/18, then (25/18 * 2 + 1.5) / 3 = 77/54.
        assert_relative_eq!(out[3].unwrap(), 25.0 / 18.0, epsilon = 1e-12);
        assert_relative_eq!(out[4].unwrap(), 77.0 / 54.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_range_yields_constant_atr() {
        let candles: Vec<Candle> = (0..30).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(14).unwrap();
        let out = atr.batch(&candles);
        assert_eq!(out.len(), candles.len());
        // Check readiness explicitly so the value loop below cannot pass vacuously.
        assert!(out[..13].iter().all(Option::is_none));
        for &v in &out[13..] {
            let v = v.expect("ATR must be ready from the 14th candle onward");
            assert_relative_eq!(v, 2.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn gap_up_uses_high_minus_prev_close() {
        // TR of the second bar is high - prev_close = 10 - 5 = 5; mean with 2.0 gives 3.5.
        let candles = vec![c(6.0, 4.0, 5.0), c(10.0, 9.0, 9.5)];
        let mut atr = Atr::new(2).unwrap();
        let out = atr.batch(&candles);
        assert!(out[0].is_none());
        assert_relative_eq!(out[1].unwrap(), 3.5, epsilon = 1e-12);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let mid = f64::from(i) + 10.0;
                c(mid + 0.5, mid - 0.5, mid)
            })
            .collect();
        let mut a = Atr::new(14).unwrap();
        let mut b = Atr::new(14).unwrap();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..20).map(|_| c(11.0, 9.0, 10.0)).collect();
        let mut atr = Atr::new(5).unwrap();
        atr.batch(&candles);
        assert!(atr.is_ready());
        atr.reset();
        assert!(!atr.is_ready());
        assert_eq!(atr.value(), None);
        assert_eq!(atr.update(candles[0]), None);
    }

    #[test]
    fn never_negative() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let base = 100.0 + (f64::from(i) * 0.3).sin() * 5.0;
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut atr = Atr::new(14).unwrap();
        let values: Vec<f64> = atr.batch(&candles).into_iter().flatten().collect();
        // Guard against a vacuous pass: every bar from the 14th on must produce a value.
        assert_eq!(values.len(), candles.len() - 13);
        for v in values {
            assert!(v >= 0.0, "ATR must be non-negative: {v}");
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn atr_matches_naive(
            period in 1usize..15,
            bars in proptest::collection::vec(
                (10.0_f64..1000.0, 0.0_f64..50.0, 0.0_f64..1.0),
                0..120,
            ),
        ) {
            // Each bar is (low, range, frac): high = low + range, close sits at `frac` of the range.
            let hlc: Vec<(f64, f64, f64)> = bars
                .iter()
                .map(|&(low, range, frac)| (low + range, low, low + range * frac))
                .collect();
            let candles: Vec<Candle> = hlc.iter().map(|&(h, l, cl)| c(h, l, cl)).collect();
            let mut atr = Atr::new(period).unwrap();
            let got = atr.batch(&candles);
            let want = atr_naive(&hlc, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}