Skip to main content

wickra_core/indicators/
super_trend.rs

//! `SuperTrend`: an ATR-banded trailing-stop trend indicator.
2
3use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::ohlcv::Candle;
6use crate::traits::Indicator;
7
/// One `SuperTrend` reading: where the trailing stop sits on this bar and
/// which way the trend currently points.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SuperTrendOutput {
    /// Level of the `SuperTrend` line, i.e. the trailing stop active on this bar.
    pub value: f64,
    /// `+1.0` while trending up (the line trails beneath price) and `-1.0`
    /// while trending down (the line sits above price).
    pub direction: f64,
}
17
/// Snapshot of the last emitted bar, which the next `SuperTrend` update builds on.
#[derive(Clone, Copy, Debug)]
struct PrevState {
    /// Ratcheted upper band from the previous bar.
    final_upper: f64,
    /// Ratcheted lower band from the previous bar.
    final_lower: f64,
    /// Previous bar's close; decides whether a band is re-seeded.
    close: f64,
    /// Previous trend direction (`+1.0` up, `-1.0` down).
    direction: f64,
}
26
27/// `SuperTrend` — an ATR-banded trailing stop that flips sides on a close
28/// through the band.
29///
30/// ```text
31/// hl2          = (high + low) / 2
32/// basic_upper  = hl2 + multiplier · ATR
33/// basic_lower  = hl2 − multiplier · ATR
34///
35/// final_upper  = basic_upper  if basic_upper < prev_final_upper or prev_close > prev_final_upper
36///                else prev_final_upper
37/// final_lower  = basic_lower  if basic_lower > prev_final_lower or prev_close < prev_final_lower
38///                else prev_final_lower
39///
40/// in a downtrend: stay down while close <= final_upper, else flip up
41/// in an uptrend:  stay up   while close >= final_lower, else flip down
42/// SuperTrend   = final_lower in an uptrend, final_upper in a downtrend
43/// ```
44///
45/// The final bands ratchet — the upper band only moves down (and the lower
46/// band only moves up) until price closes through it, which flips the trend
47/// and hands the role of trailing stop to the opposite band. The first
48/// ATR-ready bar seeds the trend as up. Wilder's classic configuration is
49/// `ATR(10)` with a `3.0` multiplier.
50///
51/// # Example
52///
53/// ```
54/// use wickra_core::{Candle, Indicator, SuperTrend};
55///
56/// let mut indicator = SuperTrend::classic();
57/// let mut last = None;
58/// for i in 0..80 {
59///     let base = 100.0 + f64::from(i);
60///     let candle =
61///         Candle::new(base, base + 2.0, base - 2.0, base + 1.0, 10.0, i64::from(i)).unwrap();
62///     last = indicator.update(candle);
63/// }
64/// assert!(last.is_some());
65/// ```
66#[derive(Debug, Clone)]
67pub struct SuperTrend {
68    atr: Atr,
69    multiplier: f64,
70    atr_period: usize,
71    prev: Option<PrevState>,
72}
73
74impl SuperTrend {
75    /// Construct a `SuperTrend` with an explicit ATR period and band multiplier.
76    ///
77    /// # Errors
78    /// Returns [`Error::PeriodZero`] if `atr_period == 0` and
79    /// [`Error::NonPositiveMultiplier`] if `multiplier` is not strictly
80    /// positive and finite.
81    pub fn new(atr_period: usize, multiplier: f64) -> Result<Self> {
82        if !multiplier.is_finite() || multiplier <= 0.0 {
83            return Err(Error::NonPositiveMultiplier);
84        }
85        Ok(Self {
86            atr: Atr::new(atr_period)?,
87            multiplier,
88            atr_period,
89            prev: None,
90        })
91    }
92
93    /// Wilder's classic configuration: `ATR(10)` with a `3.0` multiplier.
94    pub fn classic() -> Self {
95        Self::new(10, 3.0).expect("classic SuperTrend params are valid")
96    }
97
98    /// Configured `(atr_period, multiplier)`.
99    pub const fn params(&self) -> (usize, f64) {
100        (self.atr_period, self.multiplier)
101    }
102}
103
104impl Indicator for SuperTrend {
105    type Input = Candle;
106    type Output = SuperTrendOutput;
107
108    fn update(&mut self, candle: Candle) -> Option<SuperTrendOutput> {
109        let atr = self.atr.update(candle)?;
110        let hl2 = f64::midpoint(candle.high, candle.low);
111        let basic_upper = hl2 + self.multiplier * atr;
112        let basic_lower = hl2 - self.multiplier * atr;
113
114        let (final_upper, final_lower, direction) = match self.prev {
115            None => {
116                // First ATR-ready bar: no prior bands, seed the trend as up.
117                (basic_upper, basic_lower, 1.0)
118            }
119            Some(p) => {
120                let final_upper = if basic_upper < p.final_upper || p.close > p.final_upper {
121                    basic_upper
122                } else {
123                    p.final_upper
124                };
125                let final_lower = if basic_lower > p.final_lower || p.close < p.final_lower {
126                    basic_lower
127                } else {
128                    p.final_lower
129                };
130                let direction = if p.direction < 0.0 {
131                    // Previous downtrend — the line was the upper band.
132                    if candle.close <= final_upper {
133                        -1.0
134                    } else {
135                        1.0
136                    }
137                } else {
138                    // Previous uptrend — the line was the lower band.
139                    if candle.close >= final_lower {
140                        1.0
141                    } else {
142                        -1.0
143                    }
144                };
145                (final_upper, final_lower, direction)
146            }
147        };
148
149        let value = if direction > 0.0 {
150            final_lower
151        } else {
152            final_upper
153        };
154        self.prev = Some(PrevState {
155            final_upper,
156            final_lower,
157            close: candle.close,
158            direction,
159        });
160        Some(SuperTrendOutput { value, direction })
161    }
162
163    fn reset(&mut self) {
164        self.atr.reset();
165        self.prev = None;
166    }
167
168    fn warmup_period(&self) -> usize {
169        self.atr_period
170    }
171
172    fn is_ready(&self) -> bool {
173        self.prev.is_some()
174    }
175
176    fn name(&self) -> &'static str {
177        "SuperTrend"
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Build a candle with the open at the bar midpoint and unit volume.
    fn c(high: f64, low: f64, close: f64, ts: i64) -> Candle {
        Candle::new(f64::midpoint(high, low), high, low, close, 1.0, ts).unwrap()
    }

    #[test]
    fn uptrend_keeps_line_below_price_and_direction_up() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 100.0 + 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base + 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        for (o, candle) in st.batch(&candles).into_iter().zip(candles.iter()) {
            if let Some(o) = o {
                assert_eq!(o.direction, 1.0, "a pure uptrend stays in direction +1");
                assert!(o.value < candle.close, "the stop line sits below price");
            }
        }
    }

    #[test]
    fn uptrend_stop_line_never_moves_down() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 100.0 + 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base + 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let values: Vec<f64> = st
            .batch(&candles)
            .into_iter()
            .flatten()
            .map(|o| o.value)
            .collect();
        assert!(!values.is_empty(), "the ATR must warm up within 60 bars");
        // While price keeps closing above it, the lower band may only ratchet up.
        for pair in values.windows(2) {
            assert!(pair[1] >= pair[0], "lower band moved down: {pair:?}");
        }
    }

    #[test]
    fn downtrend_keeps_line_above_price_and_direction_down() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 220.0 - 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base - 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let emitted: Vec<(SuperTrendOutput, f64)> = st
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        // The seed bar starts the trend up; a steep decline flips it within a
        // few bars. The settled tail must be a clean downtrend.
        for &(o, close) in emitted.iter().skip(10) {
            assert_eq!(
                o.direction, -1.0,
                "a steep downtrend settles to direction -1"
            );
            assert!(o.value > close, "the stop line sits above price");
        }
    }

    #[test]
    fn seed_bar_starts_an_uptrend_even_in_a_decline() {
        let candles: Vec<Candle> = (0..30)
            .map(|i| {
                let base = 220.0 - 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base - 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let first = st
            .batch(&candles)
            .into_iter()
            .flatten()
            .next()
            .expect("the ATR warms up within 30 bars");
        assert_eq!(first.direction, 1.0, "the first ATR-ready bar seeds an uptrend");
    }

    #[test]
    fn trend_flips_when_price_reverses() {
        let mut candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base + 0.5, i)
            })
            .collect();
        candles.extend((0..40).map(|i| {
            let base = 140.0 - i as f64;
            c(base + 1.0, base - 1.0, base - 0.5, 40 + i)
        }));
        let mut st = SuperTrend::classic();
        let dirs: Vec<f64> = st
            .batch(&candles)
            .into_iter()
            .flatten()
            .map(|o| o.direction)
            .collect();
        assert!(dirs.iter().any(|&d| d > 0.0), "expected an uptrend stretch");
        assert!(
            dirs.iter().any(|&d| d < 0.0),
            "expected a downtrend stretch"
        );
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let candles: Vec<Candle> = (0..30)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let out = st.batch(&candles);
        assert_eq!(st.warmup_period(), 10);
        for (i, v) in out.iter().enumerate().take(9) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[9].is_some(), "first value lands at warmup_period - 1");
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(SuperTrend::new(0, 3.0).is_err());
        assert!(SuperTrend::new(10, 0.0).is_err());
        assert!(SuperTrend::new(10, -1.0).is_err());
        assert!(SuperTrend::new(10, f64::NAN).is_err());
        assert!(SuperTrend::new(10, f64::INFINITY).is_err());
        assert!(SuperTrend::new(10, f64::NEG_INFINITY).is_err());
    }

    /// Cover the `params` accessor and the `Indicator::name` implementation;
    /// `warmup_period` is exercised by `first_emission_matches_warmup_period`.
    #[test]
    fn accessors_and_metadata() {
        let st = SuperTrend::new(10, 3.0).unwrap();
        let (p, m) = st.params();
        assert_eq!(p, 10);
        assert!((m - 3.0).abs() < 1e-12);
        assert_eq!(st.name(), "SuperTrend");
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        st.batch(&candles);
        assert!(st.is_ready());
        st.reset();
        assert!(!st.is_ready());
        assert_eq!(st.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.3).sin() * 8.0;
                c(mid + 1.5, mid - 1.5, mid + 0.5, i)
            })
            .collect();
        let mut a = SuperTrend::classic();
        let mut b = SuperTrend::classic();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }
}
326}