Skip to main content

wickra_core/indicators/
atr_trailing_stop.rs

1//! ATR Trailing Stop.
2
3use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::ohlcv::Candle;
6use crate::traits::Indicator;
7
8/// ATR Trailing Stop — a stop level that trails price by a fixed ATR multiple
9/// and ratchets in the direction of the trend.
10///
11/// ```text
12/// loss = multiplier · ATR
13///
14/// stop_t = max(stop_{t−1}, close − loss)   while price holds above the stop
15///        = min(stop_{t−1}, close + loss)   while price holds below the stop
16///        = close − loss                   on a fresh break above the stop
17///        = close + loss                   on a fresh break below the stop
18/// ```
19///
20/// While price stays on one side of the stop the level only ratchets toward
21/// price — up in an uptrend, down in a downtrend — never away from it. When a
22/// close crosses the stop the level snaps to the opposite side, `loss` away
23/// from the new close, flipping the trade. This is the trailing stop used by
24/// the well-known "UT Bot"; the first ATR-ready bar seeds the stop below
25/// price (a long).
26///
27/// # Example
28///
29/// ```
30/// use wickra_core::{Candle, Indicator, AtrTrailingStop};
31///
32/// let mut indicator = AtrTrailingStop::new(14, 3.0).unwrap();
33/// let mut last = None;
34/// for i in 0..80 {
35///     let base = 100.0 + f64::from(i);
36///     let candle =
37///         Candle::new(base, base + 2.0, base - 2.0, base + 1.0, 10.0, i64::from(i)).unwrap();
38///     last = indicator.update(candle);
39/// }
40/// assert!(last.is_some());
41/// ```
42#[derive(Debug, Clone)]
43pub struct AtrTrailingStop {
44    atr: Atr,
45    multiplier: f64,
46    atr_period: usize,
47    prev_close: Option<f64>,
48    prev_stop: Option<f64>,
49}
50
51impl AtrTrailingStop {
52    /// Construct an ATR Trailing Stop with an explicit ATR period and multiple.
53    ///
54    /// # Errors
55    /// Returns [`Error::PeriodZero`] if `atr_period == 0` and
56    /// [`Error::NonPositiveMultiplier`] if `multiplier` is not strictly
57    /// positive and finite.
58    pub fn new(atr_period: usize, multiplier: f64) -> Result<Self> {
59        if !multiplier.is_finite() || multiplier <= 0.0 {
60            return Err(Error::NonPositiveMultiplier);
61        }
62        Ok(Self {
63            atr: Atr::new(atr_period)?,
64            multiplier,
65            atr_period,
66            prev_close: None,
67            prev_stop: None,
68        })
69    }
70
71    /// A common configuration: `ATR(14)` with a `3.0` multiplier.
72    pub fn classic() -> Self {
73        Self::new(14, 3.0).expect("classic ATR Trailing Stop params are valid")
74    }
75
76    /// Configured `(atr_period, multiplier)`.
77    pub const fn params(&self) -> (usize, f64) {
78        (self.atr_period, self.multiplier)
79    }
80}
81
82impl Indicator for AtrTrailingStop {
83    type Input = Candle;
84    type Output = f64;
85
86    fn update(&mut self, candle: Candle) -> Option<f64> {
87        let atr = self.atr.update(candle)?;
88        let loss = self.multiplier * atr;
89        let close = candle.close;
90
91        let stop = match (self.prev_stop, self.prev_close) {
92            (Some(prev_stop), Some(prev_close)) => {
93                if close > prev_stop && prev_close > prev_stop {
94                    // Holding above the stop — ratchet it up only.
95                    (close - loss).max(prev_stop)
96                } else if close < prev_stop && prev_close < prev_stop {
97                    // Holding below the stop — ratchet it down only.
98                    (close + loss).min(prev_stop)
99                } else if close > prev_stop {
100                    // Fresh break above — place the stop below the new close.
101                    close - loss
102                } else {
103                    // Fresh break below — place the stop above the new close.
104                    close + loss
105                }
106            }
107            // First ATR-ready bar: seed the stop below price (a long).
108            _ => close - loss,
109        };
110
111        self.prev_close = Some(close);
112        self.prev_stop = Some(stop);
113        Some(stop)
114    }
115
116    fn reset(&mut self) {
117        self.atr.reset();
118        self.prev_close = None;
119        self.prev_stop = None;
120    }
121
122    fn warmup_period(&self) -> usize {
123        self.atr_period
124    }
125
126    fn is_ready(&self) -> bool {
127        self.prev_stop.is_some()
128    }
129
130    fn name(&self) -> &'static str {
131        "AtrTrailingStop"
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Candle with its open at the midpoint of `[low, high]` and unit volume.
    fn c(high: f64, low: f64, close: f64, ts: i64) -> Candle {
        Candle::new(f64::midpoint(high, low), high, low, close, 1.0, ts).unwrap()
    }

    /// `len` bars of a steady uptrend: close = 100 + i, high/low = close ± 1.
    /// Every bar's true range is exactly 2 (the high sits 2 above the prior
    /// close and the low touches it), so the ATR is 2 under any smoothing.
    fn uptrend(len: i64) -> Vec<Candle> {
        (0..len)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect()
    }

    #[test]
    fn reference_values_flat_market() {
        // Flat candles H=11, L=9, C=10 -> TR=2 -> ATR=2; loss = 3·2 = 6.
        // Seed stop = close - loss = 10 - 6 = 4, and it holds there.
        let candles: Vec<Candle> = (0..20).map(|i| c(11.0, 9.0, 10.0, i)).collect();
        let mut ts = AtrTrailingStop::new(5, 3.0).unwrap();
        let emitted: Vec<f64> = ts.batch(&candles).into_iter().flatten().collect();
        // 20 bars with a 5-bar warmup emit from index 4 onward. Checking the
        // count stops the loop below from passing vacuously on all-`None` output.
        assert_eq!(emitted.len(), 16);
        for v in emitted {
            assert_relative_eq!(v, 4.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn reference_values_uptrend() {
        // ATR = 2, loss = 3·2 = 6. The first ATR(14) value lands on index 13
        // (close 113), seeding the stop at 107. Each later close is 1 higher,
        // so `close - loss` always beats the previous stop: stop_i = 94 + i.
        let candles = uptrend(30);
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        for (i, out) in ts.batch(&candles).into_iter().enumerate() {
            if i < 13 {
                assert!(out.is_none(), "index {i} must be None during warmup");
            } else {
                let stop = out.expect("stop must be emitted after warmup");
                assert_relative_eq!(stop, 94.0 + i as f64, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn uptrend_stop_ratchets_up_and_stays_below_price() {
        let candles = uptrend(50);
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        let emitted: Vec<(f64, f64)> = ts
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        assert!(!emitted.is_empty(), "expected stops after warmup");
        for w in emitted.windows(2) {
            assert!(
                w[1].0 >= w[0].0 - 1e-9,
                "stop must not loosen in an uptrend"
            );
        }
        for &(stop, close) in &emitted {
            assert!(stop < close, "uptrend stop should sit below the close");
        }
    }

    #[test]
    fn stop_flips_to_the_other_side_when_price_reverses() {
        let mut candles = uptrend(40);
        // A steep decline drags price through the trailing stop.
        candles.extend((0..40).map(|i| {
            let base = 140.0 - 3.0 * i as f64;
            c(base + 1.0, base - 1.0, base, 40 + i)
        }));
        let mut ts = AtrTrailingStop::new(14, 3.0).unwrap();
        let paired: Vec<(f64, f64)> = ts
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        assert!(
            paired.iter().any(|&(stop, close)| stop < close),
            "expected a long stretch with the stop below price"
        );
        assert!(
            paired.iter().any(|&(stop, close)| stop > close),
            "expected the stop to flip above price after the reversal"
        );
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let candles = uptrend(20);
        let mut ts = AtrTrailingStop::new(8, 3.0).unwrap();
        let out = ts.batch(&candles);
        assert_eq!(ts.warmup_period(), 8);
        for (i, v) in out.iter().enumerate().take(7) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[7].is_some(), "first value lands at warmup_period - 1");
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(AtrTrailingStop::new(0, 3.0).is_err());
        assert!(AtrTrailingStop::new(14, 0.0).is_err());
        assert!(AtrTrailingStop::new(14, -1.0).is_err());
        assert!(AtrTrailingStop::new(14, f64::NAN).is_err());
        assert!(AtrTrailingStop::new(14, f64::INFINITY).is_err());
    }

    /// Cover the `classic` constructor, the `params` accessor and the
    /// Indicator-impl `name`; `warmup_period` is exercised elsewhere.
    #[test]
    fn accessors_and_metadata() {
        let s = AtrTrailingStop::classic();
        let (atr_p, mult) = s.params();
        assert_eq!(atr_p, 14);
        assert!((mult - 3.0).abs() < 1e-12);
        assert_eq!(s.name(), "AtrTrailingStop");
    }

    #[test]
    fn reset_clears_state() {
        let candles = uptrend(40);
        let mut ts = AtrTrailingStop::classic();
        let first = ts.batch(&candles);
        assert!(ts.is_ready());
        ts.reset();
        assert!(!ts.is_ready());
        // A reset indicator must replay exactly like a freshly built one.
        assert_eq!(ts.batch(&candles), first);
        ts.reset();
        assert_eq!(ts.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.3).sin() * 8.0;
                c(mid + 1.5, mid - 1.5, mid + 0.5, i)
            })
            .collect();
        let mut a = AtrTrailingStop::classic();
        let mut b = AtrTrailingStop::classic();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }
}