1use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::ohlcv::Candle;
6use crate::traits::Indicator;
7
/// One reading of the ratchet: the stop level together with the side it is
/// trailing.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AtrRatchetOutput {
    /// Current stop level.
    pub value: f64,
    /// Side being trailed: `1.0` while the stop sits below price (long),
    /// `-1.0` while it sits above price (short).
    pub direction: f64,
}
16
17#[derive(Debug, Clone)]
51pub struct AtrRatchet {
52 atr: Atr,
53 atr_period: usize,
54 start_mult: f64,
55 increment: f64,
56 direction: f64,
57 stop: f64,
58 last: Option<AtrRatchetOutput>,
59}
60
61impl AtrRatchet {
62 pub fn new(atr_period: usize, start_mult: f64, increment: f64) -> Result<Self> {
70 if !start_mult.is_finite()
71 || start_mult <= 0.0
72 || !increment.is_finite()
73 || increment <= 0.0
74 {
75 return Err(Error::NonPositiveMultiplier);
76 }
77 Ok(Self {
78 atr: Atr::new(atr_period)?,
79 atr_period,
80 start_mult,
81 increment,
82 direction: 0.0,
83 stop: 0.0,
84 last: None,
85 })
86 }
87
88 pub const fn params(&self) -> (usize, f64, f64) {
90 (self.atr_period, self.start_mult, self.increment)
91 }
92
93 pub const fn value(&self) -> Option<AtrRatchetOutput> {
95 self.last
96 }
97}
98
99impl Indicator for AtrRatchet {
100 type Input = Candle;
101 type Output = AtrRatchetOutput;
102
103 fn update(&mut self, candle: Candle) -> Option<AtrRatchetOutput> {
104 let atr = self.atr.update(candle)?;
105 let close = candle.close;
106
107 if self.direction == 0.0 {
108 self.direction = 1.0;
109 self.stop = close - self.start_mult * atr;
110 } else if self.direction > 0.0 {
111 self.stop += self.increment * atr;
112 if close < self.stop {
113 self.direction = -1.0;
114 self.stop = close + self.start_mult * atr;
115 }
116 } else {
117 self.stop -= self.increment * atr;
118 if close > self.stop {
119 self.direction = 1.0;
120 self.stop = close - self.start_mult * atr;
121 }
122 }
123
124 let out = AtrRatchetOutput {
125 value: self.stop,
126 direction: self.direction,
127 };
128 self.last = Some(out);
129 Some(out)
130 }
131
132 fn reset(&mut self) {
133 self.atr.reset();
134 self.direction = 0.0;
135 self.stop = 0.0;
136 self.last = None;
137 }
138
139 fn warmup_period(&self) -> usize {
140 self.atr_period
141 }
142
143 fn is_ready(&self) -> bool {
144 self.last.is_some()
145 }
146
147 fn name(&self) -> &'static str {
148 "AtrRatchet"
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Shorthand candle builder: the first argument (presumably the open) is
    /// the high/low midpoint, and the trailing two arguments are fixed at
    /// `1_000.0` and `0`.
    fn c(high: f64, low: f64, close: f64) -> Candle {
        let mid = f64::midpoint(high, low);
        Candle::new_unchecked(mid, high, low, close, 1_000.0, 0)
    }

    /// `count` candles whose two-point high/low range is centred on
    /// `100 + slope * i` and whose close sits `close_offset` above that centre.
    fn ramp(count: u32, slope: f64, close_offset: f64) -> Vec<Candle> {
        (0..count)
            .map(|i| {
                let centre = 100.0 + slope * f64::from(i);
                c(centre + 1.0, centre - 1.0, centre + close_offset)
            })
            .collect()
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(matches!(
            AtrRatchet::new(0, 4.0, 0.1),
            Err(Error::PeriodZero)
        ));

        // Each pair has exactly one unusable multiplier.
        let bad_multipliers = [(0.0, 0.1), (4.0, 0.0), (4.0, f64::NAN)];
        for (start_mult, increment) in bad_multipliers {
            assert!(matches!(
                AtrRatchet::new(14, start_mult, increment),
                Err(Error::NonPositiveMultiplier)
            ));
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let ratchet = AtrRatchet::new(14, 4.0, 0.1).unwrap();
        assert_eq!(ratchet.params(), (14, 4.0, 0.1));
        assert_eq!(ratchet.warmup_period(), 14);
        assert_eq!(ratchet.name(), "AtrRatchet");
        assert!(!ratchet.is_ready());
        assert_eq!(ratchet.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut ratchet = AtrRatchet::new(5, 4.0, 0.1).unwrap();
        let outputs = ratchet.batch(&ramp(12, 1.0, 0.0));
        // Nothing for the first four candles, a value on the fifth.
        assert!(outputs[..4].iter().all(Option::is_none));
        assert!(outputs[4].is_some());
    }

    #[test]
    fn uptrend_keeps_stop_below_price() {
        let mut ratchet = AtrRatchet::new(5, 4.0, 0.05).unwrap();
        let candles = ramp(60, 2.0, 0.5);
        let outputs = ratchet.batch(&candles);
        for (out, candle) in outputs.iter().zip(&candles) {
            let Some(out) = out else { continue };
            assert_eq!(out.direction, 1.0);
            assert!(out.value < candle.close);
        }
    }

    #[test]
    fn stall_eventually_triggers_flip() {
        let mut ratchet = AtrRatchet::new(5, 2.0, 0.5).unwrap();
        // A steady climb followed by forty identical candles.
        let mut candles = ramp(20, 1.0, 0.5);
        candles.extend(std::iter::repeat_n(c(120.6, 118.6, 119.5), 40));

        let flipped_short = ratchet
            .batch(&candles)
            .into_iter()
            .flatten()
            .any(|out| out.direction < 0.0);
        assert!(flipped_short, "the ratchet should eventually flip short");
    }

    #[test]
    fn reset_clears_state() {
        let mut ratchet = AtrRatchet::new(5, 4.0, 0.1).unwrap();
        let candles = ramp(40, 1.0, 0.5);
        ratchet.batch(&candles);
        assert!(ratchet.is_ready());

        ratchet.reset();
        assert!(!ratchet.is_ready());
        assert_eq!(ratchet.value(), None);
        // After a reset the warm-up starts over, so one candle yields nothing.
        assert_eq!(ratchet.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..120)
            .map(|i| {
                let wave = (f64::from(i) * 0.25).sin() * 9.0;
                let centre = 100.0 + wave;
                c(centre + 2.0, centre - 1.5, centre + 0.5)
            })
            .collect();

        let batched = AtrRatchet::new(14, 4.0, 0.1).unwrap().batch(&candles);

        let mut streaming = AtrRatchet::new(14, 4.0, 0.1).unwrap();
        let streamed: Vec<_> = candles
            .iter()
            .copied()
            .map(|candle| streaming.update(candle))
            .collect();

        assert_eq!(batched, streamed);
    }
}