1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24#[cfg(test)]
25use std::error::Error as StdError;
26use std::mem::ManuallyDrop;
27use thiserror::Error;
28
/// ATR look-back length used when `atr_length` is not supplied.
const DEFAULT_ATR_LENGTH: usize = 10;
/// ATR multiplier for the band offset, used when `multiplier` is not supplied.
const DEFAULT_MULTIPLIER: f64 = 3.0;
/// Recovery blend factor in percent (scaled by 0.01 before use), used when
/// `alpha_percent` is not supplied.
const DEFAULT_ALPHA_PERCENT: f64 = 5.0;
/// Loss threshold expressed in ATR multiples, used when `threshold_atr` is not supplied.
const DEFAULT_THRESHOLD_ATR: f64 = 1.0;
/// Trend a fresh or reset state starts in (`1` follows the lower band, `-1` the upper band).
const DEFAULT_TREND: i8 = 1;
/// Smallest `alpha_percent` accepted by parameter validation.
const MIN_ALPHA_PERCENT: f64 = 0.1;
/// Largest `alpha_percent` accepted by parameter validation.
const MAX_ALPHA_PERCENT: f64 = 100.0;
/// Smallest `multiplier` accepted by parameter validation.
const MIN_MULTIPLIER: f64 = 0.1;
37
38#[inline(always)]
39fn high_source(candles: &Candles) -> &[f64] {
40 &candles.high
41}
42
43#[inline(always)]
44fn low_source(candles: &Candles) -> &[f64] {
45 &candles.low
46}
47
48#[inline(always)]
49fn close_source(candles: &Candles) -> &[f64] {
50 &candles.close
51}
52
/// Midpoint of a bar: the arithmetic mean of `high` and `low`.
#[inline(always)]
fn hl2(high: f64, low: f64) -> f64 {
    let total = high + low;
    total * 0.5
}
57
/// True range of a bar: the largest of the bar's own range and the two
/// absolute gaps between its extremes and the previous close.
#[inline(always)]
fn true_range(high: f64, low: f64, prev_close: f64) -> f64 {
    let bar_range = high - low;
    let gap_from_high = (high - prev_close).abs();
    let gap_from_low = (low - prev_close).abs();
    // Fold in the same left-to-right order as a chained `max` so NaN handling is unchanged.
    [gap_from_high, gap_from_low]
        .into_iter()
        .fold(bar_range, f64::max)
}
64
/// Source of the high/low/close series consumed by the indicator.
#[derive(Debug, Clone)]
pub enum SuperTrendRecoveryData<'a> {
    /// Read the `high`, `low` and `close` fields of a borrowed candle set.
    Candles {
        candles: &'a Candles,
    },
    /// Read three caller-provided slices directly.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
76
/// Result of a single-parameter run; every vector has the input length and
/// holds NaN at bars for which no value was produced.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SuperTrendRecoveryOutput {
    /// Trailing band value per bar.
    pub band: Vec<f64>,
    /// Close recorded at the most recent trend switch (or at the first valid bar).
    pub switch_price: Vec<f64>,
    /// Trend direction per bar, stored as `1.0` or `-1.0`.
    pub trend: Vec<f64>,
    /// `1.0` on bars where the trend flipped, `0.0` otherwise.
    pub changed: Vec<f64>,
}
88
/// Tunable parameters; any field left as `None` falls back to its module default.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SuperTrendRecoveryParams {
    /// ATR look-back length; must be non-zero and no larger than the data length.
    pub atr_length: Option<usize>,
    /// ATR multiplier for the band offset; must be finite and at least `MIN_MULTIPLIER`.
    pub multiplier: Option<f64>,
    /// Recovery blend factor in percent; must lie in
    /// `MIN_ALPHA_PERCENT..=MAX_ALPHA_PERCENT`.
    pub alpha_percent: Option<f64>,
    /// Loss threshold in ATR multiples; must be finite and non-negative.
    pub threshold_atr: Option<f64>,
}
100
impl Default for SuperTrendRecoveryParams {
    /// Populates every field with `Some(<module default>)`.
    fn default() -> Self {
        Self {
            atr_length: Some(DEFAULT_ATR_LENGTH),
            multiplier: Some(DEFAULT_MULTIPLIER),
            alpha_percent: Some(DEFAULT_ALPHA_PERCENT),
            threshold_atr: Some(DEFAULT_THRESHOLD_ATR),
        }
    }
}
111
/// Pairs a price-data source with the parameters to run the indicator with.
#[derive(Debug, Clone)]
pub struct SuperTrendRecoveryInput<'a> {
    /// Where the high/low/close series come from.
    pub data: SuperTrendRecoveryData<'a>,
    /// Indicator parameters; unset fields resolve to defaults via the `get_*` accessors.
    pub params: SuperTrendRecoveryParams,
}
117
118impl<'a> SuperTrendRecoveryInput<'a> {
119 #[inline(always)]
120 pub fn from_candles(candles: &'a Candles, params: SuperTrendRecoveryParams) -> Self {
121 Self {
122 data: SuperTrendRecoveryData::Candles { candles },
123 params,
124 }
125 }
126
127 #[inline(always)]
128 pub fn from_slices(
129 high: &'a [f64],
130 low: &'a [f64],
131 close: &'a [f64],
132 params: SuperTrendRecoveryParams,
133 ) -> Self {
134 Self {
135 data: SuperTrendRecoveryData::Slices { high, low, close },
136 params,
137 }
138 }
139
140 #[inline(always)]
141 pub fn with_default_candles(candles: &'a Candles) -> Self {
142 Self::from_candles(candles, SuperTrendRecoveryParams::default())
143 }
144
145 #[inline(always)]
146 pub fn get_atr_length(&self) -> usize {
147 self.params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH)
148 }
149
150 #[inline(always)]
151 pub fn get_multiplier(&self) -> f64 {
152 self.params.multiplier.unwrap_or(DEFAULT_MULTIPLIER)
153 }
154
155 #[inline(always)]
156 pub fn get_alpha_percent(&self) -> f64 {
157 self.params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT)
158 }
159
160 #[inline(always)]
161 pub fn get_threshold_atr(&self) -> f64 {
162 self.params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR)
163 }
164
165 #[inline(always)]
166 fn as_hlc(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
167 match &self.data {
168 SuperTrendRecoveryData::Candles { candles } => (
169 high_source(candles),
170 low_source(candles),
171 close_source(candles),
172 ),
173 SuperTrendRecoveryData::Slices { high, low, close } => (*high, *low, *close),
174 }
175 }
176}
177
178impl<'a> AsRef<[f64]> for SuperTrendRecoveryInput<'a> {
179 #[inline(always)]
180 fn as_ref(&self) -> &[f64] {
181 self.as_hlc().2
182 }
183}
184
/// Fluent builder for a single-parameter run or a streaming instance.
///
/// Parameters left unset stay `None` and resolve to module defaults when used.
#[derive(Clone, Debug)]
pub struct SuperTrendRecoveryBuilder {
    atr_length: Option<usize>,
    multiplier: Option<f64>,
    alpha_percent: Option<f64>,
    threshold_atr: Option<f64>,
    // Kernel forwarded to `supertrend_recovery_with_kernel` by `apply`/`apply_slices`.
    kernel: Kernel,
}
193
impl Default for SuperTrendRecoveryBuilder {
    /// Starts with every parameter unset and the kernel set to `Kernel::Auto`.
    fn default() -> Self {
        Self {
            atr_length: None,
            multiplier: None,
            alpha_percent: None,
            threshold_atr: None,
            kernel: Kernel::Auto,
        }
    }
}
205
206impl SuperTrendRecoveryBuilder {
207 #[inline(always)]
208 pub fn new() -> Self {
209 Self::default()
210 }
211
212 #[inline(always)]
213 pub fn atr_length(mut self, value: usize) -> Self {
214 self.atr_length = Some(value);
215 self
216 }
217
218 #[inline(always)]
219 pub fn multiplier(mut self, value: f64) -> Self {
220 self.multiplier = Some(value);
221 self
222 }
223
224 #[inline(always)]
225 pub fn alpha_percent(mut self, value: f64) -> Self {
226 self.alpha_percent = Some(value);
227 self
228 }
229
230 #[inline(always)]
231 pub fn threshold_atr(mut self, value: f64) -> Self {
232 self.threshold_atr = Some(value);
233 self
234 }
235
236 #[inline(always)]
237 pub fn kernel(mut self, kernel: Kernel) -> Self {
238 self.kernel = kernel;
239 self
240 }
241
242 #[inline(always)]
243 fn params(self) -> SuperTrendRecoveryParams {
244 SuperTrendRecoveryParams {
245 atr_length: self.atr_length,
246 multiplier: self.multiplier,
247 alpha_percent: self.alpha_percent,
248 threshold_atr: self.threshold_atr,
249 }
250 }
251
252 #[inline(always)]
253 pub fn apply(
254 self,
255 candles: &Candles,
256 ) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
257 let kernel = self.kernel;
258 let params = self.params();
259 supertrend_recovery_with_kernel(
260 &SuperTrendRecoveryInput::from_candles(candles, params),
261 kernel,
262 )
263 }
264
265 #[inline(always)]
266 pub fn apply_slices(
267 self,
268 high: &[f64],
269 low: &[f64],
270 close: &[f64],
271 ) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
272 let kernel = self.kernel;
273 let params = self.params();
274 supertrend_recovery_with_kernel(
275 &SuperTrendRecoveryInput::from_slices(high, low, close, params),
276 kernel,
277 )
278 }
279
280 #[inline(always)]
281 pub fn into_stream(self) -> Result<SuperTrendRecoveryStream, SuperTrendRecoveryError> {
282 SuperTrendRecoveryStream::try_new(self.params())
283 }
284}
285
/// Errors reported by the single, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum SuperTrendRecoveryError {
    /// At least one of the high/low/close series is empty.
    #[error("supertrend_recovery: input data slice is empty.")]
    EmptyInputData,
    /// No bar has finite high, low and close values at the same index.
    #[error("supertrend_recovery: all values are NaN.")]
    AllValuesNaN,
    /// The three input series do not share one length.
    #[error(
        "supertrend_recovery: inconsistent data lengths - high = {high_len}, low = {low_len}, close = {close_len}"
    )]
    DataLengthMismatch {
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    /// `atr_length` is zero or exceeds the data length.
    #[error(
        "supertrend_recovery: invalid period: atr_length = {atr_length}, data length = {data_len}"
    )]
    InvalidPeriod { atr_length: usize, data_len: usize },
    /// `multiplier` is non-finite or below the accepted minimum.
    #[error("supertrend_recovery: invalid multiplier: {multiplier}")]
    InvalidMultiplier { multiplier: f64 },
    /// `alpha_percent` is non-finite or outside the accepted range.
    #[error("supertrend_recovery: invalid alpha_percent: {alpha_percent}")]
    InvalidAlphaPercent { alpha_percent: f64 },
    /// `threshold_atr` is non-finite or negative.
    #[error("supertrend_recovery: invalid threshold_atr: {threshold_atr}")]
    InvalidThresholdAtr { threshold_atr: f64 },
    /// The longest run of consecutive valid bars is shorter than `atr_length`.
    #[error("supertrend_recovery: not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("supertrend_recovery: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep axis could not be expanded into a list of values.
    #[error(
        "supertrend_recovery: invalid range for {axis}: start = {start}, end = {end}, step = {step}"
    )]
    InvalidRange {
        axis: &'static str,
        start: String,
        end: String,
        step: String,
    },
    /// A kernel other than `Kernel::Auto` that is not a batch kernel was passed to a batch call.
    #[error("supertrend_recovery: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
326
/// Validated series and resolved parameters handed to the compute loop.
#[derive(Clone, Copy, Debug)]
struct PreparedInput<'a> {
    high: &'a [f64],
    low: &'a [f64],
    close: &'a [f64],
    atr_length: usize,
    multiplier: f64,
    // Blend factor as a fraction: `alpha_percent * 0.01`.
    alpha: f64,
    threshold_atr: f64,
    // Length of the NaN prefix: `first_valid + atr_length - 1`.
    warmup: usize,
}
338
/// Maps any requested kernel to `Kernel::Scalar`; the argument is ignored.
#[inline(always)]
fn normalize_single_kernel(_kernel: Kernel) -> Kernel {
    Kernel::Scalar
}
343
344#[inline(always)]
345fn validate_params(
346 atr_length: usize,
347 multiplier: f64,
348 alpha_percent: f64,
349 threshold_atr: f64,
350 data_len: usize,
351) -> Result<(), SuperTrendRecoveryError> {
352 if atr_length == 0 || atr_length > data_len {
353 return Err(SuperTrendRecoveryError::InvalidPeriod {
354 atr_length,
355 data_len,
356 });
357 }
358 if !multiplier.is_finite() || multiplier < MIN_MULTIPLIER {
359 return Err(SuperTrendRecoveryError::InvalidMultiplier { multiplier });
360 }
361 if !alpha_percent.is_finite()
362 || !(MIN_ALPHA_PERCENT..=MAX_ALPHA_PERCENT).contains(&alpha_percent)
363 {
364 return Err(SuperTrendRecoveryError::InvalidAlphaPercent { alpha_percent });
365 }
366 if !threshold_atr.is_finite() || threshold_atr < 0.0 {
367 return Err(SuperTrendRecoveryError::InvalidThresholdAtr { threshold_atr });
368 }
369 Ok(())
370}
371
372#[inline(always)]
373fn analyze_valid_segments(
374 high: &[f64],
375 low: &[f64],
376 close: &[f64],
377) -> Result<(usize, usize), SuperTrendRecoveryError> {
378 if high.is_empty() || low.is_empty() || close.is_empty() {
379 return Err(SuperTrendRecoveryError::EmptyInputData);
380 }
381 if high.len() != low.len() || high.len() != close.len() {
382 return Err(SuperTrendRecoveryError::DataLengthMismatch {
383 high_len: high.len(),
384 low_len: low.len(),
385 close_len: close.len(),
386 });
387 }
388
389 let mut first_valid = None;
390 let mut max_run = 0usize;
391 let mut run = 0usize;
392
393 for i in 0..close.len() {
394 let valid = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
395 if valid {
396 if first_valid.is_none() {
397 first_valid = Some(i);
398 }
399 run += 1;
400 if run > max_run {
401 max_run = run;
402 }
403 } else {
404 run = 0;
405 }
406 }
407
408 match first_valid {
409 Some(idx) => Ok((idx, max_run)),
410 None => Err(SuperTrendRecoveryError::AllValuesNaN),
411 }
412}
413
414#[inline(always)]
415fn prepare_input<'a>(
416 input: &'a SuperTrendRecoveryInput<'a>,
417 kernel: Kernel,
418) -> Result<PreparedInput<'a>, SuperTrendRecoveryError> {
419 let _chosen = normalize_single_kernel(kernel);
420 let (high, low, close) = input.as_hlc();
421 let atr_length = input.get_atr_length();
422 let multiplier = input.get_multiplier();
423 let alpha_percent = input.get_alpha_percent();
424 let threshold_atr = input.get_threshold_atr();
425 validate_params(
426 atr_length,
427 multiplier,
428 alpha_percent,
429 threshold_atr,
430 close.len(),
431 )?;
432
433 let (first_valid, max_run) = analyze_valid_segments(high, low, close)?;
434 if max_run < atr_length {
435 return Err(SuperTrendRecoveryError::NotEnoughValidData {
436 needed: atr_length,
437 valid: max_run,
438 });
439 }
440
441 Ok(PreparedInput {
442 high,
443 low,
444 close,
445 atr_length,
446 multiplier,
447 alpha: alpha_percent * 0.01,
448 threshold_atr,
449 warmup: first_valid + atr_length - 1,
450 })
451}
452
/// Incremental ATR: a simple average over the first `length` samples, then
/// the recursive update `(prev * (length - 1) + tr) / length`.
#[derive(Clone, Debug)]
struct AtrState {
    length: usize,
    // Number of samples absorbed so far, capped at `length`.
    count: usize,
    // Running sum used only while seeding.
    sum: f64,
    // Latest ATR value; NaN until seeding completes.
    value: f64,
}

impl AtrState {
    /// Creates an unseeded state for the given look-back length.
    #[inline(always)]
    fn new(length: usize) -> Self {
        Self {
            length,
            count: 0,
            sum: 0.0,
            value: f64::NAN,
        }
    }

    /// Discards all accumulated samples, keeping the configured length.
    #[inline(always)]
    fn reset(&mut self) {
        self.value = f64::NAN;
        self.sum = 0.0;
        self.count = 0;
    }

    /// Feeds one true-range sample; returns the ATR once `length` samples were seen.
    #[inline(always)]
    fn update(&mut self, tr: f64) -> Option<f64> {
        let span = self.length as f64;

        // Seeded already: apply the recursive smoothing step.
        if self.count >= self.length {
            self.value = ((self.value * (span - 1.0)) + tr) / span;
            return Some(self.value);
        }

        // Still seeding: accumulate until the window is full.
        self.count += 1;
        self.sum += tr;
        if self.count != self.length {
            return None;
        }
        self.value = self.sum / span;
        Some(self.value)
    }
}
496
/// Mutable per-series state shared by the batch rows and the streaming API.
#[derive(Clone, Debug)]
struct SuperTrendRecoveryState {
    atr: AtrState,
    multiplier: f64,
    // Blend factor as a fraction (percent already scaled by 0.01).
    alpha: f64,
    threshold_atr: f64,
    // Close of the previous valid bar; NaN right after construction or reset.
    prev_close: f64,
    // Last emitted band; NaN until the first value is produced.
    band: f64,
    // Close captured at the latest trend switch, or at the first valid bar.
    switch_price: f64,
    // Current trend: `1` or `-1`.
    trend: i8,
}
508
509impl SuperTrendRecoveryState {
510 #[inline(always)]
511 fn new(atr_length: usize, multiplier: f64, alpha: f64, threshold_atr: f64) -> Self {
512 Self {
513 atr: AtrState::new(atr_length),
514 multiplier,
515 alpha,
516 threshold_atr,
517 prev_close: f64::NAN,
518 band: f64::NAN,
519 switch_price: f64::NAN,
520 trend: DEFAULT_TREND,
521 }
522 }
523
524 #[inline(always)]
525 fn reset(&mut self) {
526 self.atr.reset();
527 self.prev_close = f64::NAN;
528 self.band = f64::NAN;
529 self.switch_price = f64::NAN;
530 self.trend = DEFAULT_TREND;
531 }
532
533 #[inline(always)]
534 fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
535 if !high.is_finite() || !low.is_finite() || !close.is_finite() {
536 self.reset();
537 return None;
538 }
539
540 if !self.switch_price.is_finite() {
541 self.switch_price = close;
542 }
543
544 let tr = if self.prev_close.is_finite() {
545 true_range(high, low, self.prev_close)
546 } else {
547 high - low
548 };
549 self.prev_close = close;
550
551 let atr = self.atr.update(tr)?;
552 let src = hl2(high, low);
553 let upper_base = src + self.multiplier * atr;
554 let lower_base = src - self.multiplier * atr;
555 let deviation = self.threshold_atr * atr;
556 let is_at_loss = (self.trend == 1 && (self.switch_price - close) > deviation)
557 || (self.trend == -1 && (close - self.switch_price) > deviation);
558 let prev_band = if self.band.is_finite() {
559 self.band
560 } else if self.trend == 1 {
561 lower_base
562 } else {
563 upper_base
564 };
565
566 let mut changed = 0.0;
567
568 if self.trend == 1 {
569 let target_band = if is_at_loss {
570 self.alpha.mul_add(close, (1.0 - self.alpha) * prev_band)
571 } else {
572 lower_base
573 };
574 self.band = target_band.max(prev_band);
575 if close < self.band {
576 self.trend = -1;
577 self.band = upper_base;
578 self.switch_price = close;
579 changed = 1.0;
580 }
581 } else {
582 let target_band = if is_at_loss {
583 self.alpha.mul_add(close, (1.0 - self.alpha) * prev_band)
584 } else {
585 upper_base
586 };
587 self.band = target_band.min(prev_band);
588 if close > self.band {
589 self.trend = 1;
590 self.band = lower_base;
591 self.switch_price = close;
592 changed = 1.0;
593 }
594 }
595
596 Some((self.band, self.switch_price, self.trend as f64, changed))
597 }
598}
599
/// Bar-by-bar (streaming) form of the indicator.
#[derive(Clone, Debug)]
pub struct SuperTrendRecoveryStream {
    // Parameters exactly as supplied by the caller (unset fields stay `None`).
    params: SuperTrendRecoveryParams,
    // Running state built from the resolved parameters.
    state: SuperTrendRecoveryState,
}
605
606impl SuperTrendRecoveryStream {
607 #[inline(always)]
608 pub fn try_new(params: SuperTrendRecoveryParams) -> Result<Self, SuperTrendRecoveryError> {
609 let atr_length = params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH);
610 let multiplier = params.multiplier.unwrap_or(DEFAULT_MULTIPLIER);
611 let alpha_percent = params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT);
612 let threshold_atr = params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR);
613 validate_params(
614 atr_length,
615 multiplier,
616 alpha_percent,
617 threshold_atr,
618 usize::MAX,
619 )?;
620 Ok(Self {
621 state: SuperTrendRecoveryState::new(
622 atr_length,
623 multiplier,
624 alpha_percent * 0.01,
625 threshold_atr,
626 ),
627 params,
628 })
629 }
630
631 #[inline(always)]
632 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
633 self.state.update(high, low, close)
634 }
635
636 #[inline(always)]
637 pub fn params(&self) -> &SuperTrendRecoveryParams {
638 &self.params
639 }
640}
641
/// Parameter sweep for batch runs; each axis is a `(start, end, step)` triple.
///
/// An axis with `step == 0` or `start == end` expands to the single value `start`.
#[derive(Clone, Debug)]
pub struct SuperTrendRecoveryBatchRange {
    pub atr_length: (usize, usize, usize),
    pub multiplier: (f64, f64, f64),
    pub alpha_percent: (f64, f64, f64),
    pub threshold_atr: (f64, f64, f64),
}
649
impl Default for SuperTrendRecoveryBatchRange {
    /// A degenerate sweep: every axis holds only its module default (step 0).
    fn default() -> Self {
        Self {
            atr_length: (DEFAULT_ATR_LENGTH, DEFAULT_ATR_LENGTH, 0),
            multiplier: (DEFAULT_MULTIPLIER, DEFAULT_MULTIPLIER, 0.0),
            alpha_percent: (DEFAULT_ALPHA_PERCENT, DEFAULT_ALPHA_PERCENT, 0.0),
            threshold_atr: (DEFAULT_THRESHOLD_ATR, DEFAULT_THRESHOLD_ATR, 0.0),
        }
    }
}
660
/// Fluent builder for batch (parameter-sweep) runs.
#[derive(Clone, Debug, Default)]
pub struct SuperTrendRecoveryBatchBuilder {
    // Sweep definition; starts as the single-combination default range.
    range: SuperTrendRecoveryBatchRange,
    // Kernel forwarded to `supertrend_recovery_batch_with_kernel`.
    kernel: Kernel,
}
666
/// Result of a batch run: four row-major `rows * cols` matrices plus the
/// parameter combination used for each row.
#[derive(Clone, Debug)]
pub struct SuperTrendRecoveryBatchOutput {
    pub band: Vec<f64>,
    pub switch_price: Vec<f64>,
    pub trend: Vec<f64>,
    pub changed: Vec<f64>,
    /// `combos[r]` holds the parameters that produced row `r`.
    pub combos: Vec<SuperTrendRecoveryParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
677
678impl SuperTrendRecoveryBatchBuilder {
679 #[inline(always)]
680 pub fn new() -> Self {
681 Self::default()
682 }
683
684 #[inline(always)]
685 pub fn kernel(mut self, kernel: Kernel) -> Self {
686 self.kernel = kernel;
687 self
688 }
689
690 #[inline(always)]
691 pub fn atr_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
692 self.range.atr_length = (start, end, step);
693 self
694 }
695
696 #[inline(always)]
697 pub fn multiplier_range(mut self, start: f64, end: f64, step: f64) -> Self {
698 self.range.multiplier = (start, end, step);
699 self
700 }
701
702 #[inline(always)]
703 pub fn alpha_percent_range(mut self, start: f64, end: f64, step: f64) -> Self {
704 self.range.alpha_percent = (start, end, step);
705 self
706 }
707
708 #[inline(always)]
709 pub fn threshold_atr_range(mut self, start: f64, end: f64, step: f64) -> Self {
710 self.range.threshold_atr = (start, end, step);
711 self
712 }
713
714 #[inline(always)]
715 pub fn apply_slices(
716 self,
717 high: &[f64],
718 low: &[f64],
719 close: &[f64],
720 ) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
721 supertrend_recovery_batch_with_kernel(high, low, close, &self.range, self.kernel)
722 }
723
724 #[inline(always)]
725 pub fn apply(
726 self,
727 candles: &Candles,
728 ) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
729 self.apply_slices(&candles.high, &candles.low, &candles.close)
730 }
731}
732
733#[inline(always)]
734fn compute_row(
735 high: &[f64],
736 low: &[f64],
737 close: &[f64],
738 atr_length: usize,
739 multiplier: f64,
740 alpha_percent: f64,
741 threshold_atr: f64,
742 band_out: &mut [f64],
743 switch_price_out: &mut [f64],
744 trend_out: &mut [f64],
745 changed_out: &mut [f64],
746) -> Result<(), SuperTrendRecoveryError> {
747 let len = close.len();
748 if band_out.len() != len
749 || switch_price_out.len() != len
750 || trend_out.len() != len
751 || changed_out.len() != len
752 {
753 return Err(SuperTrendRecoveryError::OutputLengthMismatch {
754 expected: len,
755 got: band_out
756 .len()
757 .max(switch_price_out.len())
758 .max(trend_out.len())
759 .max(changed_out.len()),
760 });
761 }
762
763 let mut state =
764 SuperTrendRecoveryState::new(atr_length, multiplier, alpha_percent * 0.01, threshold_atr);
765
766 for i in 0..len {
767 if let Some((band, switch_price, trend, changed)) = state.update(high[i], low[i], close[i])
768 {
769 band_out[i] = band;
770 switch_price_out[i] = switch_price;
771 trend_out[i] = trend;
772 changed_out[i] = changed;
773 } else {
774 band_out[i] = f64::NAN;
775 switch_price_out[i] = f64::NAN;
776 trend_out[i] = f64::NAN;
777 changed_out[i] = f64::NAN;
778 }
779 }
780
781 Ok(())
782}
783
/// Computes the indicator for `input`, delegating to
/// [`supertrend_recovery_with_kernel`] with `Kernel::Auto`.
#[inline]
pub fn supertrend_recovery(
    input: &SuperTrendRecoveryInput,
) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
    supertrend_recovery_with_kernel(input, Kernel::Auto)
}
790
791pub fn supertrend_recovery_with_kernel(
792 input: &SuperTrendRecoveryInput,
793 kernel: Kernel,
794) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
795 let prepared = prepare_input(input, kernel)?;
796 let len = prepared.close.len();
797 let mut band = alloc_with_nan_prefix(len, prepared.warmup);
798 let mut switch_price = alloc_with_nan_prefix(len, prepared.warmup);
799 let mut trend = alloc_with_nan_prefix(len, prepared.warmup);
800 let mut changed = alloc_with_nan_prefix(len, prepared.warmup);
801 compute_row(
802 prepared.high,
803 prepared.low,
804 prepared.close,
805 prepared.atr_length,
806 prepared.multiplier,
807 prepared.alpha / 0.01,
808 prepared.threshold_atr,
809 &mut band,
810 &mut switch_price,
811 &mut trend,
812 &mut changed,
813 )?;
814 Ok(SuperTrendRecoveryOutput {
815 band,
816 switch_price,
817 trend,
818 changed,
819 })
820}
821
/// Computes the indicator into caller-provided buffers, delegating to
/// [`supertrend_recovery_into_slice`] with `Kernel::Auto`.
///
/// Compiled out when targeting `wasm32` with the `wasm` feature enabled.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
pub fn supertrend_recovery_into(
    band_out: &mut [f64],
    switch_price_out: &mut [f64],
    trend_out: &mut [f64],
    changed_out: &mut [f64],
    input: &SuperTrendRecoveryInput,
) -> Result<(), SuperTrendRecoveryError> {
    supertrend_recovery_into_slice(
        band_out,
        switch_price_out,
        trend_out,
        changed_out,
        input,
        Kernel::Auto,
    )
}
839
840pub fn supertrend_recovery_into_slice(
841 band_out: &mut [f64],
842 switch_price_out: &mut [f64],
843 trend_out: &mut [f64],
844 changed_out: &mut [f64],
845 input: &SuperTrendRecoveryInput,
846 kernel: Kernel,
847) -> Result<(), SuperTrendRecoveryError> {
848 let prepared = prepare_input(input, kernel)?;
849 compute_row(
850 prepared.high,
851 prepared.low,
852 prepared.close,
853 prepared.atr_length,
854 prepared.multiplier,
855 prepared.alpha / 0.01,
856 prepared.threshold_atr,
857 band_out,
858 switch_price_out,
859 trend_out,
860 changed_out,
861 )
862}
863
864#[inline(always)]
865pub fn expand_grid(
866 sweep: &SuperTrendRecoveryBatchRange,
867) -> Result<Vec<SuperTrendRecoveryParams>, SuperTrendRecoveryError> {
868 fn axis_usize(
869 axis: &'static str,
870 (start, end, step): (usize, usize, usize),
871 ) -> Result<Vec<usize>, SuperTrendRecoveryError> {
872 if step == 0 || start == end {
873 return Ok(vec![start]);
874 }
875 let mut out = Vec::new();
876 if start < end {
877 let mut value = start;
878 while value <= end {
879 out.push(value);
880 match value.checked_add(step) {
881 Some(next) => value = next,
882 None => break,
883 }
884 }
885 } else {
886 let mut value = start as isize;
887 let stop = end as isize;
888 let stride = step as isize;
889 while value >= stop {
890 out.push(value as usize);
891 value -= stride;
892 }
893 }
894 if out.is_empty() {
895 return Err(SuperTrendRecoveryError::InvalidRange {
896 axis,
897 start: start.to_string(),
898 end: end.to_string(),
899 step: step.to_string(),
900 });
901 }
902 Ok(out)
903 }
904
905 fn axis_float(
906 axis: &'static str,
907 (start, end, step): (f64, f64, f64),
908 ) -> Result<Vec<f64>, SuperTrendRecoveryError> {
909 if !start.is_finite() || !end.is_finite() || !step.is_finite() {
910 return Err(SuperTrendRecoveryError::InvalidRange {
911 axis,
912 start: start.to_string(),
913 end: end.to_string(),
914 step: step.to_string(),
915 });
916 }
917 if step == 0.0 || start == end {
918 return Ok(vec![start]);
919 }
920 if step < 0.0 {
921 return Err(SuperTrendRecoveryError::InvalidRange {
922 axis,
923 start: start.to_string(),
924 end: end.to_string(),
925 step: step.to_string(),
926 });
927 }
928 let mut out = Vec::new();
929 let eps = step.abs() * 1e-9 + 1e-12;
930 if start < end {
931 let mut value = start;
932 while value <= end + eps {
933 out.push(value);
934 value += step;
935 }
936 } else {
937 let mut value = start;
938 while value + eps >= end {
939 out.push(value);
940 value -= step;
941 }
942 }
943 if out.is_empty() {
944 return Err(SuperTrendRecoveryError::InvalidRange {
945 axis,
946 start: start.to_string(),
947 end: end.to_string(),
948 step: step.to_string(),
949 });
950 }
951 Ok(out)
952 }
953
954 let atr_lengths = axis_usize("atr_length", sweep.atr_length)?;
955 let multipliers = axis_float("multiplier", sweep.multiplier)?;
956 let alpha_percents = axis_float("alpha_percent", sweep.alpha_percent)?;
957 let threshold_atrs = axis_float("threshold_atr", sweep.threshold_atr)?;
958
959 let cap = atr_lengths
960 .len()
961 .checked_mul(multipliers.len())
962 .and_then(|v| v.checked_mul(alpha_percents.len()))
963 .and_then(|v| v.checked_mul(threshold_atrs.len()))
964 .ok_or(SuperTrendRecoveryError::InvalidRange {
965 axis: "grid",
966 start: "cap".to_string(),
967 end: "overflow".to_string(),
968 step: "mul".to_string(),
969 })?;
970
971 let mut out = Vec::with_capacity(cap);
972 for &atr_length in &atr_lengths {
973 for &multiplier in &multipliers {
974 for &alpha_percent in &alpha_percents {
975 for &threshold_atr in &threshold_atrs {
976 out.push(SuperTrendRecoveryParams {
977 atr_length: Some(atr_length),
978 multiplier: Some(multiplier),
979 alpha_percent: Some(alpha_percent),
980 threshold_atr: Some(threshold_atr),
981 });
982 }
983 }
984 }
985 }
986 Ok(out)
987}
988
989fn supertrend_recovery_batch_inner_into(
990 high: &[f64],
991 low: &[f64],
992 close: &[f64],
993 sweep: &SuperTrendRecoveryBatchRange,
994 parallel: bool,
995 band_out: &mut [f64],
996 switch_price_out: &mut [f64],
997 trend_out: &mut [f64],
998 changed_out: &mut [f64],
999) -> Result<Vec<SuperTrendRecoveryParams>, SuperTrendRecoveryError> {
1000 let (_, max_run) = analyze_valid_segments(high, low, close)?;
1001 let combos = expand_grid(sweep)?;
1002 let rows = combos.len();
1003 let cols = close.len();
1004 let expected = rows
1005 .checked_mul(cols)
1006 .ok_or(SuperTrendRecoveryError::OutputLengthMismatch {
1007 expected: usize::MAX,
1008 got: band_out.len(),
1009 })?;
1010 if band_out.len() != expected
1011 || switch_price_out.len() != expected
1012 || trend_out.len() != expected
1013 || changed_out.len() != expected
1014 {
1015 return Err(SuperTrendRecoveryError::OutputLengthMismatch {
1016 expected,
1017 got: band_out
1018 .len()
1019 .max(switch_price_out.len())
1020 .max(trend_out.len())
1021 .max(changed_out.len()),
1022 });
1023 }
1024
1025 for params in &combos {
1026 let atr_length = params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH);
1027 let multiplier = params.multiplier.unwrap_or(DEFAULT_MULTIPLIER);
1028 let alpha_percent = params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT);
1029 let threshold_atr = params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR);
1030 validate_params(atr_length, multiplier, alpha_percent, threshold_atr, cols)?;
1031 if max_run < atr_length {
1032 return Err(SuperTrendRecoveryError::NotEnoughValidData {
1033 needed: atr_length,
1034 valid: max_run,
1035 });
1036 }
1037 }
1038
1039 let do_row = |row: usize,
1040 band_row: &mut [f64],
1041 switch_row: &mut [f64],
1042 trend_row: &mut [f64],
1043 changed_row: &mut [f64]| {
1044 let params = &combos[row];
1045 compute_row(
1046 high,
1047 low,
1048 close,
1049 params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH),
1050 params.multiplier.unwrap_or(DEFAULT_MULTIPLIER),
1051 params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT),
1052 params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR),
1053 band_row,
1054 switch_row,
1055 trend_row,
1056 changed_row,
1057 )
1058 };
1059
1060 if parallel {
1061 #[cfg(not(target_arch = "wasm32"))]
1062 {
1063 band_out
1064 .par_chunks_mut(cols)
1065 .zip(switch_price_out.par_chunks_mut(cols))
1066 .zip(trend_out.par_chunks_mut(cols))
1067 .zip(changed_out.par_chunks_mut(cols))
1068 .enumerate()
1069 .try_for_each(
1070 |(row, (((band_row, switch_row), trend_row), changed_row))| {
1071 do_row(row, band_row, switch_row, trend_row, changed_row)
1072 },
1073 )?;
1074 }
1075 #[cfg(target_arch = "wasm32")]
1076 {
1077 for (row, (((band_row, switch_row), trend_row), changed_row)) in band_out
1078 .chunks_mut(cols)
1079 .zip(switch_price_out.chunks_mut(cols))
1080 .zip(trend_out.chunks_mut(cols))
1081 .zip(changed_out.chunks_mut(cols))
1082 .enumerate()
1083 {
1084 do_row(row, band_row, switch_row, trend_row, changed_row)?;
1085 }
1086 }
1087 } else {
1088 for (row, (((band_row, switch_row), trend_row), changed_row)) in band_out
1089 .chunks_mut(cols)
1090 .zip(switch_price_out.chunks_mut(cols))
1091 .zip(trend_out.chunks_mut(cols))
1092 .zip(changed_out.chunks_mut(cols))
1093 .enumerate()
1094 {
1095 do_row(row, band_row, switch_row, trend_row, changed_row)?;
1096 }
1097 }
1098
1099 Ok(combos)
1100}
1101
1102pub fn supertrend_recovery_batch_with_kernel(
1103 high: &[f64],
1104 low: &[f64],
1105 close: &[f64],
1106 sweep: &SuperTrendRecoveryBatchRange,
1107 kernel: Kernel,
1108) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1109 match kernel {
1110 Kernel::Auto => {
1111 let _ = detect_best_batch_kernel();
1112 }
1113 k if !k.is_batch() => return Err(SuperTrendRecoveryError::InvalidKernelForBatch(k)),
1114 _ => {}
1115 }
1116 supertrend_recovery_batch_par_slice(high, low, close, sweep, Kernel::ScalarBatch)
1117}
1118
/// Runs the sweep sequentially (one row after another); `_kernel` is ignored.
pub fn supertrend_recovery_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &SuperTrendRecoveryBatchRange,
    _kernel: Kernel,
) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
    supertrend_recovery_batch_impl(high, low, close, sweep, false)
}
1128
/// Runs the sweep with the parallel flag set; `_kernel` is ignored.
pub fn supertrend_recovery_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &SuperTrendRecoveryBatchRange,
    _kernel: Kernel,
) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
    supertrend_recovery_batch_impl(high, low, close, sweep, true)
}
1138
1139fn supertrend_recovery_batch_impl(
1140 high: &[f64],
1141 low: &[f64],
1142 close: &[f64],
1143 sweep: &SuperTrendRecoveryBatchRange,
1144 parallel: bool,
1145) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1146 let rows = expand_grid(sweep)?.len();
1147 let cols = close.len();
1148
1149 let band_mu = make_uninit_matrix(rows, cols);
1150 let switch_mu = make_uninit_matrix(rows, cols);
1151 let trend_mu = make_uninit_matrix(rows, cols);
1152 let changed_mu = make_uninit_matrix(rows, cols);
1153
1154 let mut band_guard = ManuallyDrop::new(band_mu);
1155 let mut switch_guard = ManuallyDrop::new(switch_mu);
1156 let mut trend_guard = ManuallyDrop::new(trend_mu);
1157 let mut changed_guard = ManuallyDrop::new(changed_mu);
1158
1159 let band_out: &mut [f64] = unsafe {
1160 core::slice::from_raw_parts_mut(band_guard.as_mut_ptr() as *mut f64, band_guard.len())
1161 };
1162 let switch_out: &mut [f64] = unsafe {
1163 core::slice::from_raw_parts_mut(switch_guard.as_mut_ptr() as *mut f64, switch_guard.len())
1164 };
1165 let trend_out: &mut [f64] = unsafe {
1166 core::slice::from_raw_parts_mut(trend_guard.as_mut_ptr() as *mut f64, trend_guard.len())
1167 };
1168 let changed_out: &mut [f64] = unsafe {
1169 core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1170 };
1171
1172 let combos = supertrend_recovery_batch_inner_into(
1173 high,
1174 low,
1175 close,
1176 sweep,
1177 parallel,
1178 band_out,
1179 switch_out,
1180 trend_out,
1181 changed_out,
1182 )?;
1183
1184 let band = unsafe {
1185 Vec::from_raw_parts(
1186 band_guard.as_mut_ptr() as *mut f64,
1187 band_guard.len(),
1188 band_guard.capacity(),
1189 )
1190 };
1191 let switch_price = unsafe {
1192 Vec::from_raw_parts(
1193 switch_guard.as_mut_ptr() as *mut f64,
1194 switch_guard.len(),
1195 switch_guard.capacity(),
1196 )
1197 };
1198 let trend = unsafe {
1199 Vec::from_raw_parts(
1200 trend_guard.as_mut_ptr() as *mut f64,
1201 trend_guard.len(),
1202 trend_guard.capacity(),
1203 )
1204 };
1205 let changed = unsafe {
1206 Vec::from_raw_parts(
1207 changed_guard.as_mut_ptr() as *mut f64,
1208 changed_guard.len(),
1209 changed_guard.capacity(),
1210 )
1211 };
1212
1213 Ok(SuperTrendRecoveryBatchOutput {
1214 band,
1215 switch_price,
1216 trend,
1217 changed,
1218 combos,
1219 rows,
1220 cols,
1221 })
1222}
1223
/// Python entry point for a single SuperTrend Recovery run.
///
/// Borrows the three NumPy inputs as slices, validates the optional kernel name and
/// runs `supertrend_recovery_with_kernel` inside `py.allow_threads`. The result is
/// returned as the tuple `(band, switch_price, trend, changed)` of NumPy arrays.
///
/// Indicator errors are converted to `ValueError`; errors from borrowing the arrays
/// or from `validate_kernel` propagate as they are.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_recovery")]
#[pyo3(signature = (high, low, close, atr_length=DEFAULT_ATR_LENGTH, multiplier=DEFAULT_MULTIPLIER, alpha_percent=DEFAULT_ALPHA_PERCENT, threshold_atr=DEFAULT_THRESHOLD_ATR, kernel=None))]
pub fn supertrend_recovery_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let high_view = high.as_slice()?;
    let low_view = low.as_slice()?;
    let close_view = close.as_slice()?;
    let selected_kernel = validate_kernel(kernel, false)?;

    let params = SuperTrendRecoveryParams {
        atr_length: Some(atr_length),
        multiplier: Some(multiplier),
        alpha_percent: Some(alpha_percent),
        threshold_atr: Some(threshold_atr),
    };
    let input = SuperTrendRecoveryInput::from_slices(high_view, low_view, close_view, params);

    // The numeric work does not touch Python objects, so it runs in `allow_threads`.
    let computed = py.allow_threads(|| supertrend_recovery_with_kernel(&input, selected_kernel));

    match computed {
        Ok(output) => Ok((
            output.band.into_pyarray(py),
            output.switch_price.into_pyarray(py),
            output.trend.into_pyarray(py),
            output.changed.into_pyarray(py),
        )),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1268
/// Python entry point for a parameter sweep of SuperTrend Recovery.
///
/// Each `*_range` argument is forwarded as-is into the matching field of
/// `SuperTrendRecoveryBatchRange`. The four result matrices are allocated as flat
/// NumPy arrays of `rows * cols` values, filled in place by
/// `supertrend_recovery_batch_inner_into`, and reshaped to `(rows, cols)`.
///
/// The returned dict holds `band`, `switch_price`, `trend`, `changed`, the per-row
/// parameter arrays `atr_lengths`, `multipliers`, `alpha_percents`,
/// `threshold_atrs`, plus `rows` and `cols`.
///
/// Grid expansion, size overflow and indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_recovery_batch")]
#[pyo3(signature = (high, low, close, atr_length_range=(DEFAULT_ATR_LENGTH, DEFAULT_ATR_LENGTH, 0), multiplier_range=(DEFAULT_MULTIPLIER, DEFAULT_MULTIPLIER, 0.0), alpha_percent_range=(DEFAULT_ALPHA_PERCENT, DEFAULT_ALPHA_PERCENT, 0.0), threshold_atr_range=(DEFAULT_THRESHOLD_ATR, DEFAULT_THRESHOLD_ATR, 0.0), kernel=None))]
pub fn supertrend_recovery_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    atr_length_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    alpha_percent_range: (f64, f64, f64),
    threshold_atr_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_view = high.as_slice()?;
    let low_view = low.as_slice()?;
    let close_view = close.as_slice()?;
    let selected_kernel = validate_kernel(kernel, true)?;
    // Only the scalar kernels request the sequential path.
    let run_parallel = !matches!(selected_kernel, Kernel::Scalar | Kernel::ScalarBatch);

    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: atr_length_range,
        multiplier: multiplier_range,
        alpha_percent: alpha_percent_range,
        threshold_atr: threshold_atr_range,
    };

    let rows = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };
    let cols = close_view.len();
    let total = match rows.checked_mul(cols) {
        Some(cells) => cells,
        None => {
            return Err(PyValueError::new_err(
                "rows*cols overflow in supertrend_recovery_batch",
            ))
        }
    };

    // Flat, uninitialised NumPy buffers that the batch routine writes into directly.
    let band_np = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let switch_np = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let trend_np = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let changed_np = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let band_buf = unsafe { band_np.as_slice_mut()? };
    let switch_buf = unsafe { switch_np.as_slice_mut()? };
    let trend_buf = unsafe { trend_np.as_slice_mut()? };
    let changed_buf = unsafe { changed_np.as_slice_mut()? };

    let filled = py.allow_threads(|| {
        supertrend_recovery_batch_inner_into(
            high_view,
            low_view,
            close_view,
            &sweep,
            run_parallel,
            band_buf,
            switch_buf,
            trend_buf,
            changed_buf,
        )
    });
    let combos = filled.map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Per-row parameter columns; unset fields report the module defaults.
    let atr_lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.atr_length.unwrap_or(DEFAULT_ATR_LENGTH) as u64)
        .collect();
    let multipliers: Vec<f64> = combos
        .iter()
        .map(|combo| combo.multiplier.unwrap_or(DEFAULT_MULTIPLIER))
        .collect();
    let alpha_percents: Vec<f64> = combos
        .iter()
        .map(|combo| combo.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT))
        .collect();
    let threshold_atrs: Vec<f64> = combos
        .iter()
        .map(|combo| combo.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR))
        .collect();

    let result = PyDict::new(py);
    result.set_item("band", band_np.reshape((rows, cols))?)?;
    result.set_item("switch_price", switch_np.reshape((rows, cols))?)?;
    result.set_item("trend", trend_np.reshape((rows, cols))?)?;
    result.set_item("changed", changed_np.reshape((rows, cols))?)?;
    result.set_item("atr_lengths", atr_lengths.into_pyarray(py))?;
    result.set_item("multipliers", multipliers.into_pyarray(py))?;
    result.set_item("alpha_percents", alpha_percents.into_pyarray(py))?;
    result.set_item("threshold_atrs", threshold_atrs.into_pyarray(py))?;
    result.set_item("rows", rows)?;
    result.set_item("cols", cols)?;
    Ok(result)
}
1369
/// Python-facing wrapper that owns a `SuperTrendRecoveryStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendRecoveryStream")]
pub struct SuperTrendRecoveryStreamPy {
    stream: SuperTrendRecoveryStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendRecoveryStreamPy {
    /// Creates the stream from explicit parameters.
    ///
    /// Any error from `SuperTrendRecoveryStream::try_new` is raised as `ValueError`.
    #[new]
    #[pyo3(signature = (atr_length=DEFAULT_ATR_LENGTH, multiplier=DEFAULT_MULTIPLIER, alpha_percent=DEFAULT_ALPHA_PERCENT, threshold_atr=DEFAULT_THRESHOLD_ATR))]
    fn new(
        atr_length: usize,
        multiplier: f64,
        alpha_percent: f64,
        threshold_atr: f64,
    ) -> PyResult<Self> {
        let params = SuperTrendRecoveryParams {
            atr_length: Some(atr_length),
            multiplier: Some(multiplier),
            alpha_percent: Some(alpha_percent),
            threshold_atr: Some(threshold_atr),
        };
        match SuperTrendRecoveryStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one bar to the wrapped stream and returns whatever it yields.
    ///
    /// The tuple is passed through untouched from the inner `update` call.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
1401
/// Sweep configuration deserialized from the JavaScript `config` object passed to
/// the batch binding.
///
/// Each field is copied verbatim into the same-purpose field of
/// `SuperTrendRecoveryBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct SuperTrendRecoveryBatchConfig {
    /// Range triple for the ATR length.
    pub atr_length_range: (usize, usize, usize),
    /// Range triple for the band multiplier.
    pub multiplier_range: (f64, f64, f64),
    /// Range triple for the alpha percentage.
    pub alpha_percent_range: (f64, f64, f64),
    /// Range triple for the ATR threshold.
    pub threshold_atr_range: (f64, f64, f64),
}
1410
/// Serializable batch result handed back to JavaScript.
///
/// The fields mirror those of the batch output they are copied from.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct SuperTrendRecoveryBatchJsOutput {
    /// Flattened band matrix.
    pub band: Vec<f64>,
    /// Flattened switch-price matrix.
    pub switch_price: Vec<f64>,
    /// Flattened trend matrix.
    pub trend: Vec<f64>,
    /// Flattened changed-flag matrix.
    pub changed: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<SuperTrendRecoveryParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
1422
/// WebAssembly binding: runs SuperTrend Recovery with `Kernel::Auto` and returns the
/// serialized output object.
///
/// Indicator errors are returned as a JS string. A serialization failure is
/// reported with the `Serialization error: ` prefix.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
) -> Result<JsValue, JsValue> {
    let params = SuperTrendRecoveryParams {
        atr_length: Some(atr_length),
        multiplier: Some(multiplier),
        alpha_percent: Some(alpha_percent),
        threshold_atr: Some(threshold_atr),
    };
    let input = SuperTrendRecoveryInput::from_slices(high, low, close, params);

    match supertrend_recovery_with_kernel(&input, Kernel::Auto) {
        Ok(output) => serde_wasm_bindgen::to_value(&output)
            .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1450
/// Reserves room for `len` `f64` values and returns the raw pointer to it.
///
/// The allocation is deliberately not freed here; ownership passes to the caller.
/// NOTE(review): presumably released via `supertrend_recovery_free(ptr, len)` with
/// the same `len` - confirm against the JS callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor so the buffer outlives this call.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1459
/// Releases a buffer by rebuilding a `Vec<f64>` of length and capacity `len` around
/// `ptr` and dropping it. A null `ptr` is ignored.
///
/// NOTE(review): `ptr`/`len` are assumed to describe an allocation obtained from
/// `supertrend_recovery_alloc(len)` that has not been freed yet - nothing here can
/// verify that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: upheld by the caller as described above.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1469
/// WebAssembly binding that writes one SuperTrend Recovery run into caller-provided
/// buffers.
///
/// All seven pointers are treated as arrays of `len` `f64` values. The run uses
/// `Kernel::Auto`.
///
/// When any output region overlaps an input region or another output region, the
/// result is first computed into owned vectors and then copied out, so no
/// overlapping mutable slices are ever created. Overlap is decided on the byte
/// ranges `[ptr, ptr + len * 8)`, not just on equal start addresses, so partially
/// overlapping buffers also take the copying path.
///
/// # Errors
/// Returns `"Null pointer provided"` if any pointer is null, otherwise the
/// indicator's error message as a JS string.
///
/// NOTE(review): the caller must guarantee every pointer is valid for `len`
/// elements; this function cannot check that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    band_ptr: *mut f64,
    switch_price_ptr: *mut f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || band_ptr.is_null()
        || switch_price_ptr.is_null()
        || trend_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let input_addrs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let output_addrs = [
        band_ptr as usize,
        switch_price_ptr as usize,
        trend_ptr as usize,
        changed_ptr as usize,
    ];
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    // Equal start addresses always count as aliased (the original rule); beyond
    // that, two regions alias when their byte ranges intersect.
    let overlaps = |a: usize, b: usize| {
        a == b || (a < b.saturating_add(span) && b < a.saturating_add(span))
    };
    let aliased = input_addrs
        .iter()
        .any(|&inp| output_addrs.iter().any(|&out| overlaps(inp, out)))
        || output_addrs.iter().enumerate().any(|(idx, &first)| {
            output_addrs[idx + 1..]
                .iter()
                .any(|&second| overlaps(first, second))
        });

    // SAFETY: pointers are non-null (checked above) and, per the caller contract,
    // valid for `len` elements. Mutable output slices that coexist with other slices
    // are only formed on the non-aliased path, where all seven regions are disjoint.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = SuperTrendRecoveryInput::from_slices(
            high,
            low,
            close,
            SuperTrendRecoveryParams {
                atr_length: Some(atr_length),
                multiplier: Some(multiplier),
                alpha_percent: Some(alpha_percent),
                threshold_atr: Some(threshold_atr),
            },
        );

        if aliased {
            // Compute into owned storage first; the inputs are not read again after
            // this call, so overwriting them below is fine.
            let output = supertrend_recovery_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(band_ptr, len).copy_from_slice(&output.band);
            std::slice::from_raw_parts_mut(switch_price_ptr, len)
                .copy_from_slice(&output.switch_price);
            std::slice::from_raw_parts_mut(trend_ptr, len).copy_from_slice(&output.trend);
            std::slice::from_raw_parts_mut(changed_ptr, len).copy_from_slice(&output.changed);
        } else {
            let band_out = std::slice::from_raw_parts_mut(band_ptr, len);
            let switch_out = std::slice::from_raw_parts_mut(switch_price_ptr, len);
            let trend_out = std::slice::from_raw_parts_mut(trend_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);
            supertrend_recovery_into_slice(
                band_out,
                switch_out,
                trend_out,
                changed_out,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1562
/// WebAssembly batch binding, exported to JavaScript as `supertrend_recovery_batch`.
///
/// `config` is deserialized into `SuperTrendRecoveryBatchConfig`, turned into a
/// sweep, and run through `supertrend_recovery_batch_with_kernel` with
/// `Kernel::Auto`. The result is repackaged as `SuperTrendRecoveryBatchJsOutput`
/// and serialized.
///
/// # Errors
/// A JS string prefixed with `Invalid config: ` or `Serialization error: ` for
/// (de)serialization failures, or the indicator's own error message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend_recovery_batch)]
pub fn supertrend_recovery_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<SuperTrendRecoveryBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let SuperTrendRecoveryBatchConfig {
        atr_length_range,
        multiplier_range,
        alpha_percent_range,
        threshold_atr_range,
    } = parsed;

    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: atr_length_range,
        multiplier: multiplier_range,
        alpha_percent: alpha_percent_range,
        threshold_atr: threshold_atr_range,
    };

    let batch = match supertrend_recovery_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
    {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = SuperTrendRecoveryBatchJsOutput {
        band: batch.band,
        switch_price: batch.switch_price,
        trend: batch.trend,
        changed: batch.changed,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1593
/// WebAssembly binding that writes a parameter sweep into caller-provided buffers
/// and returns the number of rows produced.
///
/// The inputs are read as `len` values each. Every output pointer is treated as an
/// array of `rows * len` values, where `rows` is the number of combinations from
/// `expand_grid` for the sweep built from the twelve range arguments.
///
/// Like `supertrend_recovery_into`, an output region that overlaps an input region
/// or another output region is handled safely: the sweep is computed into owned
/// buffers and copied out afterwards, instead of building overlapping mutable
/// slices.
///
/// # Errors
/// Returns `"Null pointer provided"` for a null pointer, `"rows*len overflow"` if
/// the output size does not fit in `usize`, or the error message from grid
/// expansion or the indicator.
///
/// NOTE(review): the caller must guarantee the inputs are valid for `len` elements
/// and the outputs for `rows * len` elements; this function cannot check that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    band_ptr: *mut f64,
    switch_price_ptr: *mut f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    atr_length_start: usize,
    atr_length_end: usize,
    atr_length_step: usize,
    multiplier_start: f64,
    multiplier_end: f64,
    multiplier_step: f64,
    alpha_percent_start: f64,
    alpha_percent_end: f64,
    alpha_percent_step: f64,
    threshold_atr_start: f64,
    threshold_atr_end: f64,
    threshold_atr_step: f64,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || band_ptr.is_null()
        || switch_price_ptr.is_null()
        || trend_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: (atr_length_start, atr_length_end, atr_length_step),
        multiplier: (multiplier_start, multiplier_end, multiplier_step),
        alpha_percent: (alpha_percent_start, alpha_percent_end, alpha_percent_step),
        threshold_atr: (threshold_atr_start, threshold_atr_end, threshold_atr_step),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

    // Byte extents of one input series and of one output matrix.
    let elem = std::mem::size_of::<f64>();
    let in_bytes = len.saturating_mul(elem);
    let out_bytes = total.saturating_mul(elem);
    // Two non-empty regions overlap when their byte ranges intersect.
    let overlaps = |a: usize, a_bytes: usize, b: usize, b_bytes: usize| {
        a_bytes != 0
            && b_bytes != 0
            && a < b.saturating_add(b_bytes)
            && b < a.saturating_add(a_bytes)
    };
    let input_addrs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let output_addrs = [
        band_ptr as usize,
        switch_price_ptr as usize,
        trend_ptr as usize,
        changed_ptr as usize,
    ];
    let aliased = input_addrs.iter().any(|&inp| {
        output_addrs
            .iter()
            .any(|&out| overlaps(inp, in_bytes, out, out_bytes))
    }) || output_addrs.iter().enumerate().any(|(idx, &first)| {
        output_addrs[idx + 1..]
            .iter()
            .any(|&second| overlaps(first, out_bytes, second, out_bytes))
    });

    // SAFETY: pointers are non-null (checked above) and, per the caller contract,
    // valid for the stated element counts. Coexisting mutable output slices are only
    // created on the non-aliased path, where all regions are disjoint.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if aliased {
            // Compute into owned storage; the inputs are not read again afterwards,
            // so the copies below may safely overwrite them.
            let output =
                supertrend_recovery_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(band_ptr, total).copy_from_slice(&output.band);
            std::slice::from_raw_parts_mut(switch_price_ptr, total)
                .copy_from_slice(&output.switch_price);
            std::slice::from_raw_parts_mut(trend_ptr, total).copy_from_slice(&output.trend);
            std::slice::from_raw_parts_mut(changed_ptr, total).copy_from_slice(&output.changed);
        } else {
            let band_out = std::slice::from_raw_parts_mut(band_ptr, total);
            let switch_out = std::slice::from_raw_parts_mut(switch_price_ptr, total);
            let trend_out = std::slice::from_raw_parts_mut(trend_ptr, total);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, total);
            supertrend_recovery_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                false,
                band_out,
                switch_out,
                trend_out,
                changed_out,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1666
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Shorthand for a fully specified parameter set.
    fn params_of(
        atr_length: usize,
        multiplier: f64,
        alpha_percent: f64,
        threshold_atr: f64,
    ) -> SuperTrendRecoveryParams {
        SuperTrendRecoveryParams {
            atr_length: Some(atr_length),
            multiplier: Some(multiplier),
            alpha_percent: Some(alpha_percent),
            threshold_atr: Some(threshold_atr),
        }
    }

    /// Steadily rising synthetic series with small periodic wiggles.
    fn trend_data(size: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let base_at = |i: usize| 100.0 + i as f64 * 0.8;
        let high: Vec<f64> = (0..size)
            .map(|i| base_at(i) + 1.2 + (i % 3) as f64 * 0.1)
            .collect();
        let low: Vec<f64> = (0..size)
            .map(|i| base_at(i) - 1.0 - (i % 2) as f64 * 0.1)
            .collect();
        let close: Vec<f64> = (0..size)
            .map(|i| base_at(i) + ((i % 5) as f64 - 2.0) * 0.05)
            .collect();
        (high, low, close)
    }

    /// Derives high/low series at a fixed distance around the given closes.
    fn bracketed(close: Vec<f64>, offset: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let high: Vec<f64> = close.iter().map(|c| c + offset).collect();
        let low: Vec<f64> = close.iter().map(|c| c - offset).collect();
        (high, low, close)
    }

    /// Series that rises, falls, rises and falls again.
    fn reversal_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        bracketed(
            vec![
                100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 104.0, 102.0, 99.0, 96.0, 93.0, 91.0,
                90.0, 91.0, 93.0, 96.0, 100.0, 105.0, 109.0, 112.0, 111.0, 109.0, 107.0, 104.0,
                101.0, 98.0, 96.0, 95.0, 96.0, 98.0,
            ],
            1.0,
        )
    }

    /// Series with a sharp drop, a rebound and a long slow decline.
    fn recovery_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        bracketed(
            vec![
                100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 102.0, 97.0, 92.0, 88.0, 90.0, 93.0,
                96.0, 99.0, 101.0, 102.0, 101.0, 100.0, 99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0,
                92.0, 91.0, 90.0, 89.0, 88.0,
            ],
            0.9,
        )
    }

    /// Element-wise comparison where NaN matches NaN and other values must agree
    /// within `1e-12`.
    fn arrays_eq_nan(a: &[f64], b: &[f64]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter()
            .zip(b)
            .all(|(&x, &y)| match (x.is_nan(), y.is_nan()) {
                (true, true) => true,
                (false, false) => (x - y).abs() <= 1e-12,
                _ => false,
            })
    }

    /// Asserts that the four series pairs match, in band/switch/trend/changed order.
    fn assert_all_match(expected: [&[f64]; 4], actual: [&[f64]; 4]) {
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert!(arrays_eq_nan(want, got));
        }
    }

    #[test]
    fn supertrend_recovery_into_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(160);
        let input = SuperTrendRecoveryInput::from_slices(
            &high,
            &low,
            &close,
            SuperTrendRecoveryParams::default(),
        );
        let reference = supertrend_recovery(&input)?;

        let n = close.len();
        let mut band = vec![0.0; n];
        let mut switch_price = vec![0.0; n];
        let mut trend = vec![0.0; n];
        let mut changed = vec![0.0; n];
        supertrend_recovery_into_slice(
            &mut band,
            &mut switch_price,
            &mut trend,
            &mut changed,
            &input,
            Kernel::Auto,
        )?;

        assert_all_match(
            [
                reference.band.as_slice(),
                reference.switch_price.as_slice(),
                reference.trend.as_slice(),
                reference.changed.as_slice(),
            ],
            [
                band.as_slice(),
                switch_price.as_slice(),
                trend.as_slice(),
                changed.as_slice(),
            ],
        );
        Ok(())
    }

    #[test]
    fn supertrend_recovery_stream_matches_batch() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let params = SuperTrendRecoveryParams::default();
        let input = SuperTrendRecoveryInput::from_slices(&high, &low, &close, params.clone());
        let reference = supertrend_recovery(&input)?;

        let mut stream = SuperTrendRecoveryStream::try_new(params)?;
        let n = close.len();
        let mut band = Vec::with_capacity(n);
        let mut switch_price = Vec::with_capacity(n);
        let mut trend = Vec::with_capacity(n);
        let mut changed = Vec::with_capacity(n);

        // A bar for which the stream yields nothing is recorded as NaN in every series.
        for ((&h, &l), &c) in high.iter().zip(&low).zip(&close) {
            let (b, s, t, flag) = stream
                .update(h, l, c)
                .unwrap_or((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
            band.push(b);
            switch_price.push(s);
            trend.push(t);
            changed.push(flag);
        }

        assert_all_match(
            [
                reference.band.as_slice(),
                reference.switch_price.as_slice(),
                reference.trend.as_slice(),
                reference.changed.as_slice(),
            ],
            [
                band.as_slice(),
                switch_price.as_slice(),
                trend.as_slice(),
                changed.as_slice(),
            ],
        );
        Ok(())
    }

    #[test]
    fn supertrend_recovery_reversal_behavior() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = reversal_data();
        let input =
            SuperTrendRecoveryInput::from_slices(&high, &low, &close, params_of(4, 1.5, 5.0, 1.0));
        let output = supertrend_recovery(&input)?;

        // Locate the first bar flagged as a trend change.
        let first_change = output.changed.iter().position(|flag| *flag == 1.0);
        assert!(first_change.is_some());
        let idx = first_change.expect("presence asserted just above");

        assert!(output.band[idx].is_finite());
        assert!(output.switch_price[idx].is_finite());
        assert!(output.trend[idx].abs() == 1.0);
        Ok(())
    }

    #[test]
    fn supertrend_recovery_recovery_behavior() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = recovery_data();
        let recovered = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
            &high,
            &low,
            &close,
            params_of(4, 3.0, 100.0, 0.0),
        ))?;
        let baseline = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
            &high,
            &low,
            &close,
            params_of(4, 3.0, 0.1, 1000.0),
        ))?;

        // Look for a bar where both runs agree on the trend but the first run's band
        // sits further in the trend's direction than the second run's band.
        let found = (0..close.len()).any(|i| {
            recovered.trend[i] == baseline.trend[i]
                && recovered.band[i].is_finite()
                && baseline.band[i].is_finite()
                && ((recovered.trend[i] == 1.0 && recovered.band[i] > baseline.band[i])
                    || (recovered.trend[i] == -1.0 && recovered.band[i] < baseline.band[i]))
        });
        assert!(found);
        Ok(())
    }

    #[test]
    fn supertrend_recovery_nan_gap_restarts() -> Result<(), Box<dyn StdError>> {
        const GAP: usize = 120;
        let (mut high, mut low, mut close) = trend_data(170);
        high[GAP] = f64::NAN;
        low[GAP] = f64::NAN;
        close[GAP] = f64::NAN;

        let input = SuperTrendRecoveryInput::from_slices(
            &high,
            &low,
            &close,
            SuperTrendRecoveryParams::default(),
        );
        let output = supertrend_recovery(&input)?;

        let restart_end = (GAP + DEFAULT_ATR_LENGTH).min(output.band.len());
        for idx in GAP..restart_end {
            let row = [output.band[idx], output.trend[idx], output.changed[idx]];
            assert!(row.iter().all(|v| v.is_nan()));
        }
        Ok(())
    }

    #[test]
    fn supertrend_recovery_batch_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let sweep = SuperTrendRecoveryBatchRange {
            atr_length: (4, 5, 1),
            multiplier: (1.5, 2.0, 0.5),
            alpha_percent: (5.0, 10.0, 5.0),
            threshold_atr: (0.5, 1.0, 0.5),
        };
        let batch = supertrend_recovery_batch_with_kernel(
            &high,
            &low,
            &close,
            &sweep,
            Kernel::ScalarBatch,
        )?;

        assert_eq!(batch.rows, 16);
        assert_eq!(batch.cols, close.len());

        for row in 0..batch.rows {
            let params = batch.combos[row].clone();
            let single = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
                &high, &low, &close, params,
            ))?;

            // Row `row` of each flattened matrix.
            let span = row * batch.cols..(row + 1) * batch.cols;
            assert_all_match(
                [
                    &batch.band[span.clone()],
                    &batch.switch_price[span.clone()],
                    &batch.trend[span.clone()],
                    &batch.changed[span],
                ],
                [
                    single.band.as_slice(),
                    single.switch_price.as_slice(),
                    single.trend.as_slice(),
                    single.changed.as_slice(),
                ],
            );
        }
        Ok(())
    }

    #[test]
    fn supertrend_recovery_invalid_alpha_errors() {
        let (high, low, close) = trend_data(160);
        let input =
            SuperTrendRecoveryInput::from_slices(&high, &low, &close, params_of(10, 3.0, 0.0, 1.0));
        let result = supertrend_recovery(&input);
        assert!(matches!(
            result,
            Err(SuperTrendRecoveryError::InvalidAlphaPercent { .. })
        ));
    }

    #[test]
    fn supertrend_recovery_all_nan_errors() {
        let blank = vec![f64::NAN; 160];
        let (high, low, close) = (blank.clone(), blank.clone(), blank);
        let input = SuperTrendRecoveryInput::from_slices(
            &high,
            &low,
            &close,
            SuperTrendRecoveryParams::default(),
        );
        let result = supertrend_recovery(&input);
        assert!(matches!(result, Err(SuperTrendRecoveryError::AllValuesNaN)));
    }

    #[test]
    fn supertrend_recovery_default_candles_smoke() -> Result<(), Box<dyn StdError>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let output = supertrend_recovery(&SuperTrendRecoveryInput::with_default_candles(&candles))?;

        let expected = candles.close.len();
        let lengths = [
            output.band.len(),
            output.switch_price.len(),
            output.trend.len(),
            output.changed.len(),
        ];
        for actual in lengths.iter() {
            assert_eq!(*actual, expected);
        }
        Ok(())
    }
}