wickra_core/indicators/
atr_trailing_stop.rs1use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::ohlcv::Candle;
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
43pub struct AtrTrailingStop {
44 atr: Atr,
45 multiplier: f64,
46 atr_period: usize,
47 prev_close: Option<f64>,
48 prev_stop: Option<f64>,
49}
50
51impl AtrTrailingStop {
52 pub fn new(atr_period: usize, multiplier: f64) -> Result<Self> {
59 if !multiplier.is_finite() || multiplier <= 0.0 {
60 return Err(Error::NonPositiveMultiplier);
61 }
62 Ok(Self {
63 atr: Atr::new(atr_period)?,
64 multiplier,
65 atr_period,
66 prev_close: None,
67 prev_stop: None,
68 })
69 }
70
71 pub fn classic() -> Self {
73 Self::new(14, 3.0).expect("classic ATR Trailing Stop params are valid")
74 }
75
76 pub const fn params(&self) -> (usize, f64) {
78 (self.atr_period, self.multiplier)
79 }
80}
81
82impl Indicator for AtrTrailingStop {
83 type Input = Candle;
84 type Output = f64;
85
86 fn update(&mut self, candle: Candle) -> Option<f64> {
87 let atr = self.atr.update(candle)?;
88 let loss = self.multiplier * atr;
89 let close = candle.close;
90
91 let stop = match (self.prev_stop, self.prev_close) {
92 (Some(prev_stop), Some(prev_close)) => {
93 if close > prev_stop && prev_close > prev_stop {
94 (close - loss).max(prev_stop)
96 } else if close < prev_stop && prev_close < prev_stop {
97 (close + loss).min(prev_stop)
99 } else if close > prev_stop {
100 close - loss
102 } else {
103 close + loss
105 }
106 }
107 _ => close - loss,
109 };
110
111 self.prev_close = Some(close);
112 self.prev_stop = Some(stop);
113 Some(stop)
114 }
115
116 fn reset(&mut self) {
117 self.atr.reset();
118 self.prev_close = None;
119 self.prev_stop = None;
120 }
121
122 fn warmup_period(&self) -> usize {
123 self.atr_period
124 }
125
126 fn is_ready(&self) -> bool {
127 self.prev_stop.is_some()
128 }
129
130 fn name(&self) -> &'static str {
131 "AtrTrailingStop"
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Test candle from high/low/close at timestamp `ts`. The first `Candle::new` argument is
    /// set to the high/low midpoint and the volume argument to 1.0.
    fn c(high: f64, low: f64, close: f64, ts: i64) -> Candle {
        Candle::new(f64::midpoint(high, low), high, low, close, 1.0, ts).unwrap()
    }

    #[test]
    fn reference_values_flat_market() {
        // Identical bars: true range is 2, so ATR is 2 and the stop is 10 - 3 * 2 = 4,
        // which never moves because the close never moves.
        let candles: Vec<Candle> = (0..20).map(|i| c(11.0, 9.0, 10.0, i)).collect();
        let mut ts = AtrTrailingStop::new(5, 3.0).unwrap();
        let stops: Vec<f64> = ts.batch(&candles).into_iter().flatten().collect();
        // First value lands at index warmup_period - 1 = 4, so 16 of the 20 bars emit.
        // Guards against the loop below passing vacuously.
        assert_eq!(stops.len(), 16);
        for v in stops {
            assert_relative_eq!(v, 4.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn uptrend_stop_ratchets_up_and_stays_below_price() {
        let candles: Vec<Candle> = (0..50)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        let emitted: Vec<(f64, f64)> = ts
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        assert!(!emitted.is_empty(), "expected stops after warm-up");
        for w in emitted.windows(2) {
            assert!(
                w[1].0 >= w[0].0 - 1e-9,
                "stop must not loosen in an uptrend"
            );
        }
        for &(stop, close) in &emitted {
            assert!(stop < close, "uptrend stop should sit below the close");
        }
    }

    #[test]
    fn downtrend_stop_ratchets_down_and_stays_above_price() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 200.0 - i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        let emitted: Vec<(f64, f64)> = ts
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        assert!(emitted.len() >= 30, "expected enough stops after warm-up");
        // The first stop is always seeded below the close. Once price falls through it the
        // stop flips above price, so only inspect the tail, which is well past that flip.
        let tail = &emitted[emitted.len() - 30..];
        for w in tail.windows(2) {
            assert!(
                w[1].0 <= w[0].0 + 1e-9,
                "stop must not loosen in a downtrend"
            );
        }
        for &(stop, close) in tail {
            assert!(stop > close, "downtrend stop should sit above the close");
        }
    }

    #[test]
    fn stop_flips_to_the_other_side_when_price_reverses() {
        let mut candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        // Sharp reversal: price falls three points per bar.
        candles.extend((0..40).map(|i| {
            let base = 140.0 - 3.0 * i as f64;
            c(base + 1.0, base - 1.0, base, 40 + i)
        }));
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        let paired: Vec<(f64, f64)> = ts
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        assert!(
            paired.iter().any(|&(stop, close)| stop < close),
            "expected a long stretch with the stop below price"
        );
        assert!(
            paired.iter().any(|&(stop, close)| stop > close),
            "expected the stop to flip above price after the reversal"
        );
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let candles: Vec<Candle> = (0..20)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut ts = AtrTrailingStop::new(8, 3.0).unwrap();
        let out = ts.batch(&candles);
        assert_eq!(ts.warmup_period(), 8);
        for (i, v) in out.iter().enumerate().take(7) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[7].is_some(), "first value lands at warmup_period - 1");
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(AtrTrailingStop::new(0, 3.0).is_err());
        assert!(AtrTrailingStop::new(14, 0.0).is_err());
        assert!(AtrTrailingStop::new(14, -1.0).is_err());
        assert!(AtrTrailingStop::new(14, f64::NAN).is_err());
        // `new` rejects every non-finite multiplier, not only NaN.
        assert!(AtrTrailingStop::new(14, f64::INFINITY).is_err());
        assert!(AtrTrailingStop::new(14, f64::NEG_INFINITY).is_err());
        // Pin the variant callers can match on.
        assert!(matches!(
            AtrTrailingStop::new(14, 0.0),
            Err(Error::NonPositiveMultiplier)
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let s = AtrTrailingStop::classic();
        let (atr_p, mult) = s.params();
        assert_eq!(atr_p, 14);
        assert!((mult - 3.0).abs() < 1e-12);
        assert_eq!(s.name(), "AtrTrailingStop");
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut ts = AtrTrailingStop::classic();
        ts.batch(&candles);
        assert!(ts.is_ready());
        ts.reset();
        assert!(!ts.is_ready());
        // After reset the ATR must warm up again, so a single candle yields nothing.
        assert_eq!(ts.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.3).sin() * 8.0;
                c(mid + 1.5, mid - 1.5, mid + 0.5, i)
            })
            .collect();
        let mut a = AtrTrailingStop::classic();
        let mut b = AtrTrailingStop::classic();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }
}