1use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::ohlcv::Candle;
6use crate::traits::Indicator;
7
/// One reading produced by the SuperTrend indicator.
///
/// `value` is the band that is currently active and `direction` is the sign
/// of the trend it belongs to.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SuperTrendOutput {
    /// The active band: the final lower band while `direction` is `1.0`,
    /// the final upper band while `direction` is `-1.0`.
    pub value: f64,
    /// Trend sign: `1.0` for an uptrend, `-1.0` for a downtrend.
    pub direction: f64,
}
17
/// What the indicator remembers about the previous emitted bar; the
/// band-ratcheting and trend-flip rules read these values on the next bar.
#[derive(Clone, Copy, Debug)]
struct PrevState {
    /// Final upper band of the previous bar.
    final_upper: f64,
    /// Final lower band of the previous bar.
    final_lower: f64,
    /// Close of the previous bar.
    close: f64,
    /// Trend sign of the previous bar (`1.0` up, `-1.0` down).
    direction: f64,
}
26
27#[derive(Debug, Clone)]
67pub struct SuperTrend {
68 atr: Atr,
69 multiplier: f64,
70 atr_period: usize,
71 prev: Option<PrevState>,
72}
73
74impl SuperTrend {
75 pub fn new(atr_period: usize, multiplier: f64) -> Result<Self> {
82 if !multiplier.is_finite() || multiplier <= 0.0 {
83 return Err(Error::NonPositiveMultiplier);
84 }
85 Ok(Self {
86 atr: Atr::new(atr_period)?,
87 multiplier,
88 atr_period,
89 prev: None,
90 })
91 }
92
93 pub fn classic() -> Self {
95 Self::new(10, 3.0).expect("classic SuperTrend params are valid")
96 }
97
98 pub const fn params(&self) -> (usize, f64) {
100 (self.atr_period, self.multiplier)
101 }
102}
103
104impl Indicator for SuperTrend {
105 type Input = Candle;
106 type Output = SuperTrendOutput;
107
108 fn update(&mut self, candle: Candle) -> Option<SuperTrendOutput> {
109 let atr = self.atr.update(candle)?;
110 let hl2 = f64::midpoint(candle.high, candle.low);
111 let basic_upper = hl2 + self.multiplier * atr;
112 let basic_lower = hl2 - self.multiplier * atr;
113
114 let (final_upper, final_lower, direction) = match self.prev {
115 None => {
116 (basic_upper, basic_lower, 1.0)
118 }
119 Some(p) => {
120 let final_upper = if basic_upper < p.final_upper || p.close > p.final_upper {
121 basic_upper
122 } else {
123 p.final_upper
124 };
125 let final_lower = if basic_lower > p.final_lower || p.close < p.final_lower {
126 basic_lower
127 } else {
128 p.final_lower
129 };
130 let direction = if p.direction < 0.0 {
131 if candle.close <= final_upper {
133 -1.0
134 } else {
135 1.0
136 }
137 } else {
138 if candle.close >= final_lower {
140 1.0
141 } else {
142 -1.0
143 }
144 };
145 (final_upper, final_lower, direction)
146 }
147 };
148
149 let value = if direction > 0.0 {
150 final_lower
151 } else {
152 final_upper
153 };
154 self.prev = Some(PrevState {
155 final_upper,
156 final_lower,
157 close: candle.close,
158 direction,
159 });
160 Some(SuperTrendOutput { value, direction })
161 }
162
163 fn reset(&mut self) {
164 self.atr.reset();
165 self.prev = None;
166 }
167
168 fn warmup_period(&self) -> usize {
169 self.atr_period
170 }
171
172 fn is_ready(&self) -> bool {
173 self.prev.is_some()
174 }
175
176 fn name(&self) -> &'static str {
177 "SuperTrend"
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Builds a test candle whose open is the high/low midpoint and volume is 1.
    fn c(high: f64, low: f64, close: f64, ts: i64) -> Candle {
        Candle::new(f64::midpoint(high, low), high, low, close, 1.0, ts).unwrap()
    }

    #[test]
    fn uptrend_keeps_line_below_price_and_direction_up() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 100.0 + 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base + 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        for (o, candle) in st.batch(&candles).into_iter().zip(candles.iter()) {
            if let Some(o) = o {
                assert_eq!(o.direction, 1.0, "a pure uptrend stays in direction +1");
                assert!(o.value < candle.close, "the stop line sits below price");
            }
        }
    }

    #[test]
    fn downtrend_keeps_line_above_price_and_direction_down() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 220.0 - 2.0 * i as f64;
                c(base + 1.0, base - 1.0, base - 0.5, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let emitted: Vec<(SuperTrendOutput, f64)> = st
            .batch(&candles)
            .into_iter()
            .zip(candles.iter())
            .filter_map(|(o, c)| o.map(|v| (v, c.close)))
            .collect();
        // The indicator seeds its first emission in direction +1, so give the
        // decline a few bars to flip the trend before asserting.
        for &(o, close) in emitted.iter().skip(10) {
            assert_eq!(
                o.direction, -1.0,
                "a steep downtrend settles to direction -1"
            );
            assert!(o.value > close, "the stop line sits above price");
        }
    }

    #[test]
    fn trend_flips_when_price_reverses() {
        let mut candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base + 0.5, i)
            })
            .collect();
        candles.extend((0..40).map(|i| {
            let base = 140.0 - i as f64;
            c(base + 1.0, base - 1.0, base - 0.5, 40 + i)
        }));
        let mut st = SuperTrend::classic();
        let dirs: Vec<f64> = st
            .batch(&candles)
            .into_iter()
            .flatten()
            .map(|o| o.direction)
            .collect();
        assert_eq!(dirs.first(), Some(&1.0), "expected the rally to start in direction +1");
        assert_eq!(dirs.last(), Some(&-1.0), "expected the decline to end in direction -1");
        // Every true range on this path is 2, so the bands sit well clear of
        // price: the trend must flip once, from +1 to -1, with no whipsaw back up.
        assert!(
            dirs.windows(2).all(|w| w[0] >= w[1]),
            "direction must only move from +1 to -1"
        );
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let candles: Vec<Candle> = (0..30)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        let out = st.batch(&candles);
        assert_eq!(st.warmup_period(), 10);
        for (i, v) in out.iter().enumerate().take(9) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[9].is_some(), "first value lands at warmup_period - 1");
    }

    #[test]
    fn rejects_invalid_params() {
        assert!(SuperTrend::new(0, 3.0).is_err());
        assert!(SuperTrend::new(10, 0.0).is_err());
        assert!(SuperTrend::new(10, -1.0).is_err());
        assert!(SuperTrend::new(10, f64::NAN).is_err());
    }

    #[test]
    fn accessors_and_metadata() {
        let st = SuperTrend::new(10, 3.0).unwrap();
        let (p, m) = st.params();
        assert_eq!(p, 10);
        assert!((m - 3.0).abs() < 1e-12);
        assert_eq!(st.name(), "SuperTrend");
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..40)
            .map(|i| {
                let base = 100.0 + i as f64;
                c(base + 1.0, base - 1.0, base, i)
            })
            .collect();
        let mut st = SuperTrend::classic();
        st.batch(&candles);
        assert!(st.is_ready());
        st.reset();
        assert!(!st.is_ready());
        assert_eq!(st.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.3).sin() * 8.0;
                c(mid + 1.5, mid - 1.5, mid + 0.5, i)
            })
            .collect();
        let mut a = SuperTrend::classic();
        let mut b = SuperTrend::classic();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }
}