1use crate::error::{Error, Result};
4use crate::ohlcv::Candle;
5use crate::traits::Indicator;
6
/// Side of the market the SAR is currently trailing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Trend {
    /// Long phase: the SAR is kept at or below price.
    Up,
    /// Short phase: the SAR is kept at or above price.
    Down,
}
12
/// Acceleration-factor schedule for one side (long or short) of the SAR.
#[derive(Clone, Copy, Debug)]
struct Accel {
    /// Factor used when a trend on this side begins.
    init: f64,
    /// Amount added to the factor whenever a new extreme point is made.
    step: f64,
    /// Ceiling the factor is clamped to.
    max: f64,
}
20
21impl Accel {
22 fn validate(self) -> Result<Self> {
23 if !(self.init.is_finite() && self.step.is_finite() && self.max.is_finite()) {
24 return Err(Error::NonPositiveMultiplier);
25 }
26 if self.init <= 0.0 || self.step <= 0.0 || self.max <= 0.0 {
27 return Err(Error::NonPositiveMultiplier);
28 }
29 if self.init > self.max {
30 return Err(Error::InvalidPeriod {
31 message: "acceleration init must be <= max",
32 });
33 }
34 Ok(self)
35 }
36}
37
38#[derive(Debug, Clone)]
71pub struct SarExt {
72 start_value: f64,
73 offset_on_reverse: f64,
74 long: Accel,
75 short: Accel,
76
77 initialised: bool,
78 has_emitted: bool,
79 prev_high: f64,
80 prev_low: f64,
81 trend: Trend,
82 sar: f64,
83 ep: f64,
84 af: f64,
85}
86
87impl SarExt {
88 #[allow(clippy::too_many_arguments)]
100 pub fn new(
101 start_value: f64,
102 offset_on_reverse: f64,
103 accel_init_long: f64,
104 accel_long: f64,
105 accel_max_long: f64,
106 accel_init_short: f64,
107 accel_short: f64,
108 accel_max_short: f64,
109 ) -> Result<Self> {
110 if !start_value.is_finite() || !offset_on_reverse.is_finite() || offset_on_reverse < 0.0 {
111 return Err(Error::NonPositiveMultiplier);
112 }
113 let long = Accel {
114 init: accel_init_long,
115 step: accel_long,
116 max: accel_max_long,
117 }
118 .validate()?;
119 let short = Accel {
120 init: accel_init_short,
121 step: accel_short,
122 max: accel_max_short,
123 }
124 .validate()?;
125 Ok(Self {
126 start_value,
127 offset_on_reverse,
128 long,
129 short,
130 initialised: false,
131 has_emitted: false,
132 prev_high: f64::NAN,
133 prev_low: f64::NAN,
134 trend: Trend::Up,
135 sar: f64::NAN,
136 ep: f64::NAN,
137 af: long.init,
138 })
139 }
140
141 pub fn classic() -> Self {
144 Self::new(0.0, 0.0, 0.02, 0.02, 0.20, 0.02, 0.02, 0.20)
145 .expect("classic SAREXT params are valid")
146 }
147
148 fn signed(&self, sar: f64) -> f64 {
149 match self.trend {
150 Trend::Up => sar,
151 Trend::Down => -sar,
152 }
153 }
154}
155
156impl Indicator for SarExt {
157 type Input = Candle;
158 type Output = f64;
159
160 fn update(&mut self, candle: Candle) -> Option<f64> {
161 if !self.initialised {
162 self.prev_high = candle.high;
163 self.prev_low = candle.low;
164 if self.start_value > 0.0 {
165 self.trend = Trend::Up;
166 self.sar = self.start_value;
167 self.ep = candle.high;
168 self.af = self.long.init;
169 } else if self.start_value < 0.0 {
170 self.trend = Trend::Down;
171 self.sar = -self.start_value;
172 self.ep = candle.low;
173 self.af = self.short.init;
174 } else {
175 self.trend = Trend::Up;
176 self.sar = candle.low;
177 self.ep = candle.high;
178 self.af = self.long.init;
179 }
180 self.initialised = true;
181 return None;
182 }
183
184 let mut new_sar = self.sar + self.af * (self.ep - self.sar);
185 let prev_h = self.prev_high;
186 let prev_l = self.prev_low;
187 new_sar = match self.trend {
188 Trend::Up => new_sar.min(prev_l).min(candle.low),
189 Trend::Down => new_sar.max(prev_h).max(candle.high),
190 };
191
192 let mut output_sar = new_sar;
193 let reversed = match self.trend {
194 Trend::Up => candle.low <= new_sar,
195 Trend::Down => candle.high >= new_sar,
196 };
197
198 if reversed {
199 output_sar = self.ep;
200 self.trend = match self.trend {
201 Trend::Up => Trend::Down,
202 Trend::Down => Trend::Up,
203 };
204 match self.trend {
205 Trend::Up => {
206 output_sar -= output_sar.abs() * self.offset_on_reverse;
207 self.ep = candle.high;
208 self.af = self.long.init;
209 }
210 Trend::Down => {
211 output_sar += output_sar.abs() * self.offset_on_reverse;
212 self.ep = candle.low;
213 self.af = self.short.init;
214 }
215 }
216 } else {
217 match self.trend {
218 Trend::Up => {
219 if candle.high > self.ep {
220 self.ep = candle.high;
221 self.af = (self.af + self.long.step).min(self.long.max);
222 }
223 }
224 Trend::Down => {
225 if candle.low < self.ep {
226 self.ep = candle.low;
227 self.af = (self.af + self.short.step).min(self.short.max);
228 }
229 }
230 }
231 }
232
233 self.sar = output_sar;
234 self.prev_high = candle.high;
235 self.prev_low = candle.low;
236 self.has_emitted = true;
237 Some(self.signed(output_sar))
238 }
239
240 fn reset(&mut self) {
241 self.initialised = false;
242 self.has_emitted = false;
243 self.prev_high = f64::NAN;
244 self.prev_low = f64::NAN;
245 self.trend = Trend::Up;
246 self.sar = f64::NAN;
247 self.ep = f64::NAN;
248 self.af = self.long.init;
249 }
250
251 fn warmup_period(&self) -> usize {
252 2
253 }
254
255 fn is_ready(&self) -> bool {
256 self.has_emitted
257 }
258
259 fn name(&self) -> &'static str {
260 "SAREXT"
261 }
262}
263
264#[cfg(test)]
265mod tests {
266 use super::*;
267 use crate::traits::BatchExt;
268
269 fn c(h: f64, l: f64, cl: f64) -> Candle {
270 Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
271 }
272
273 fn classic() -> SarExt {
274 SarExt::classic()
275 }
276
277 #[test]
278 fn rejects_invalid_params() {
279 assert!(SarExt::new(0.0, 0.0, 0.0, 0.02, 0.2, 0.02, 0.02, 0.2).is_err());
281 assert!(SarExt::new(0.0, 0.0, 0.02, 0.02, 0.2, 0.0, 0.02, 0.2).is_err());
282 assert!(SarExt::new(0.0, 0.0, 0.30, 0.02, 0.2, 0.02, 0.02, 0.2).is_err());
283 assert!(SarExt::new(0.0, 0.0, f64::NAN, 0.02, 0.2, 0.02, 0.02, 0.2).is_err());
286 assert!(SarExt::new(0.0, 0.0, 0.02, 0.02, 0.2, 0.02, f64::INFINITY, 0.2).is_err());
287 assert!(SarExt::new(f64::NAN, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).is_err());
289 assert!(SarExt::new(0.0, -1.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).is_err());
290 }
291
292 #[test]
293 fn accessors_and_metadata() {
294 let s = classic();
295 assert_eq!(s.warmup_period(), 2);
296 assert_eq!(s.name(), "SAREXT");
297 assert!(!s.is_ready());
298 }
299
300 #[test]
301 fn seed_returns_none_then_emits() {
302 let mut s = classic();
303 assert_eq!(s.update(c(11.0, 9.0, 10.0)), None);
304 assert!(!s.is_ready());
305 assert!(s.update(c(12.0, 10.0, 11.0)).is_some());
306 assert!(s.is_ready());
307 }
308
309 #[test]
310 fn uptrend_is_positive_and_below_lows() {
311 let candles: Vec<Candle> = (0..40)
312 .map(|i| {
313 let base = 100.0 + f64::from(i);
314 c(base + 0.5, base - 0.5, base)
315 })
316 .collect();
317 let mut s = classic();
318 let ok = s
319 .batch(&candles)
320 .iter()
321 .enumerate()
322 .all(|(i, v)| v.is_none_or(|x| x > 0.0 && x <= candles[i].low + 1e-9));
323 assert!(ok, "long-phase SAREXT must be positive and below the low");
324 }
325
326 #[test]
327 fn downtrend_is_negative_and_above_highs() {
328 let candles: Vec<Candle> = (0..40)
329 .rev()
330 .map(|i| {
331 let base = 100.0 + f64::from(i);
332 c(base + 0.5, base - 0.5, base)
333 })
334 .collect();
335 let mut s = classic();
336 let ok = s
337 .batch(&candles)
338 .iter()
339 .enumerate()
340 .skip(5)
341 .all(|(i, v)| v.is_none_or(|x| x < 0.0 && -x >= candles[i].high - 1e-9));
342 assert!(ok, "short-phase SAREXT must be negative and above the high");
343 }
344
345 #[test]
346 fn positive_start_value_begins_long() {
347 let mut s = SarExt::new(95.0, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).unwrap();
349 assert_eq!(s.update(c(101.0, 99.0, 100.0)), None);
350 let v = s.update(c(102.0, 100.0, 101.0)).unwrap();
351 assert!(v > 0.0);
352 }
353
354 #[test]
355 fn negative_start_value_begins_short() {
356 let mut s = SarExt::new(-105.0, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2).unwrap();
358 assert_eq!(s.update(c(101.0, 99.0, 100.0)), None);
359 let v = s.update(c(100.0, 98.0, 99.0)).unwrap();
360 assert!(v < 0.0);
361 }
362
363 #[test]
364 fn offset_on_reverse_pushes_sar_further() {
365 let candles: Vec<Candle> = (0..12)
368 .map(|i| {
369 let base = if i < 6 {
370 100.0 - f64::from(i) * 2.0
371 } else {
372 88.0 + f64::from(i - 6) * 2.0
373 };
374 c(base + 1.0, base - 1.0, base)
375 })
376 .collect();
377 let plain = SarExt::new(0.0, 0.0, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2)
378 .unwrap()
379 .batch(&candles);
380 let offset = SarExt::new(0.0, 0.1, 0.02, 0.02, 0.2, 0.02, 0.02, 0.2)
381 .unwrap()
382 .batch(&candles);
383 assert_ne!(plain, offset);
385 }
386
387 #[test]
388 fn batch_equals_streaming() {
389 let candles: Vec<Candle> = (0..60)
390 .map(|i| {
391 let m = 100.0 + (f64::from(i) * 0.3).sin() * 8.0;
392 c(m + 1.0, m - 1.0, m)
393 })
394 .collect();
395 let mut a = classic();
396 let mut b = classic();
397 assert_eq!(
398 a.batch(&candles),
399 candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
400 );
401 }
402
403 #[test]
404 fn reset_allows_clean_reuse() {
405 let candles: Vec<Candle> = (0..40)
406 .map(|i| {
407 let base = 100.0 + f64::from(i);
408 c(base + 0.5, base - 0.5, base)
409 })
410 .collect();
411 let mut s = classic();
412 let first = s.batch(&candles);
413 assert!(s.is_ready());
414 s.reset();
415 assert!(!s.is_ready());
416 assert_eq!(first, s.batch(&candles));
417 }
418}