Skip to main content

vector_ta/indicators/
supertrend_recovery.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24#[cfg(test)]
25use std::error::Error as StdError;
26use std::mem::ManuallyDrop;
27use thiserror::Error;
28
/// ATR lookback used when `atr_length` is not supplied.
const DEFAULT_ATR_LENGTH: usize = 10;
/// Band-width multiplier (in ATR units) used when `multiplier` is not supplied.
const DEFAULT_MULTIPLIER: f64 = 3.0;
/// Recovery blend factor, in percent, used when `alpha_percent` is not supplied.
const DEFAULT_ALPHA_PERCENT: f64 = 5.0;
/// Loss threshold (in ATR units) used when `threshold_atr` is not supplied.
const DEFAULT_THRESHOLD_ATR: f64 = 1.0;
/// Trend value a freshly created or reset state starts with.
const DEFAULT_TREND: i8 = 1;
/// Smallest `alpha_percent` accepted by `validate_params`.
const MIN_ALPHA_PERCENT: f64 = 0.1;
/// Largest `alpha_percent` accepted by `validate_params`.
const MAX_ALPHA_PERCENT: f64 = 100.0;
/// Smallest `multiplier` accepted by `validate_params`.
const MIN_MULTIPLIER: f64 = 0.1;
37
/// Borrows the `high` series of `candles` as a slice.
#[inline(always)]
fn high_source(candles: &Candles) -> &[f64] {
    &candles.high
}
42
/// Borrows the `low` series of `candles` as a slice.
#[inline(always)]
fn low_source(candles: &Candles) -> &[f64] {
    &candles.low
}
47
/// Borrows the `close` series of `candles` as a slice.
#[inline(always)]
fn close_source(candles: &Candles) -> &[f64] {
    &candles.close
}
52
/// Midpoint of a bar: the arithmetic mean of `high` and `low`.
#[inline(always)]
fn hl2(high: f64, low: f64) -> f64 {
    let range_sum = high + low;
    range_sum * 0.5
}
57
/// True range of a bar: the largest of the bar's own `high - low` span and the
/// absolute distances from `prev_close` to `high` and to `low`.
///
/// `f64::max` ignores a NaN operand, so a NaN `prev_close` leaves `high - low`.
#[inline(always)]
fn true_range(high: f64, low: f64, prev_close: f64) -> f64 {
    let bar_span = high - low;
    let gap_to_high = (high - prev_close).abs();
    let gap_to_low = (low - prev_close).abs();
    let widest = bar_span.max(gap_to_high);
    widest.max(gap_to_low)
}
64
/// Borrowed price data the indicator reads from.
#[derive(Debug, Clone)]
pub enum SuperTrendRecoveryData<'a> {
    /// Use the `high`, `low` and `close` series of a `Candles` value.
    Candles {
        candles: &'a Candles,
    },
    /// Use three caller-provided series directly.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
76
/// Result of a single-parameter run: four series, each allocated with the
/// length of the input `close` series.
///
/// Bars for which the internal state produced no value hold NaN in all four
/// series.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SuperTrendRecoveryOutput {
    /// Band level after processing each bar.
    pub band: Vec<f64>,
    /// Close price recorded at the most recent trend switch.
    pub switch_price: Vec<f64>,
    /// Trend after each bar, stored as `1.0` or `-1.0`.
    pub trend: Vec<f64>,
    /// `1.0` on bars where the trend flipped, `0.0` otherwise.
    pub changed: Vec<f64>,
}
88
89#[derive(Debug, Clone, PartialEq)]
90#[cfg_attr(
91    all(target_arch = "wasm32", feature = "wasm"),
92    derive(Serialize, Deserialize)
93)]
94pub struct SuperTrendRecoveryParams {
95    pub atr_length: Option<usize>,
96    pub multiplier: Option<f64>,
97    pub alpha_percent: Option<f64>,
98    pub threshold_atr: Option<f64>,
99}
100
101impl Default for SuperTrendRecoveryParams {
102    fn default() -> Self {
103        Self {
104            atr_length: Some(DEFAULT_ATR_LENGTH),
105            multiplier: Some(DEFAULT_MULTIPLIER),
106            alpha_percent: Some(DEFAULT_ALPHA_PERCENT),
107            threshold_atr: Some(DEFAULT_THRESHOLD_ATR),
108        }
109    }
110}
111
112#[derive(Debug, Clone)]
113pub struct SuperTrendRecoveryInput<'a> {
114    pub data: SuperTrendRecoveryData<'a>,
115    pub params: SuperTrendRecoveryParams,
116}
117
118impl<'a> SuperTrendRecoveryInput<'a> {
119    #[inline(always)]
120    pub fn from_candles(candles: &'a Candles, params: SuperTrendRecoveryParams) -> Self {
121        Self {
122            data: SuperTrendRecoveryData::Candles { candles },
123            params,
124        }
125    }
126
127    #[inline(always)]
128    pub fn from_slices(
129        high: &'a [f64],
130        low: &'a [f64],
131        close: &'a [f64],
132        params: SuperTrendRecoveryParams,
133    ) -> Self {
134        Self {
135            data: SuperTrendRecoveryData::Slices { high, low, close },
136            params,
137        }
138    }
139
140    #[inline(always)]
141    pub fn with_default_candles(candles: &'a Candles) -> Self {
142        Self::from_candles(candles, SuperTrendRecoveryParams::default())
143    }
144
145    #[inline(always)]
146    pub fn get_atr_length(&self) -> usize {
147        self.params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH)
148    }
149
150    #[inline(always)]
151    pub fn get_multiplier(&self) -> f64 {
152        self.params.multiplier.unwrap_or(DEFAULT_MULTIPLIER)
153    }
154
155    #[inline(always)]
156    pub fn get_alpha_percent(&self) -> f64 {
157        self.params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT)
158    }
159
160    #[inline(always)]
161    pub fn get_threshold_atr(&self) -> f64 {
162        self.params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR)
163    }
164
165    #[inline(always)]
166    fn as_hlc(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
167        match &self.data {
168            SuperTrendRecoveryData::Candles { candles } => (
169                high_source(candles),
170                low_source(candles),
171                close_source(candles),
172            ),
173            SuperTrendRecoveryData::Slices { high, low, close } => (*high, *low, *close),
174        }
175    }
176}
177
178impl<'a> AsRef<[f64]> for SuperTrendRecoveryInput<'a> {
179    #[inline(always)]
180    fn as_ref(&self) -> &[f64] {
181        self.as_hlc().2
182    }
183}
184
185#[derive(Clone, Debug)]
186pub struct SuperTrendRecoveryBuilder {
187    atr_length: Option<usize>,
188    multiplier: Option<f64>,
189    alpha_percent: Option<f64>,
190    threshold_atr: Option<f64>,
191    kernel: Kernel,
192}
193
194impl Default for SuperTrendRecoveryBuilder {
195    fn default() -> Self {
196        Self {
197            atr_length: None,
198            multiplier: None,
199            alpha_percent: None,
200            threshold_atr: None,
201            kernel: Kernel::Auto,
202        }
203    }
204}
205
206impl SuperTrendRecoveryBuilder {
207    #[inline(always)]
208    pub fn new() -> Self {
209        Self::default()
210    }
211
212    #[inline(always)]
213    pub fn atr_length(mut self, value: usize) -> Self {
214        self.atr_length = Some(value);
215        self
216    }
217
218    #[inline(always)]
219    pub fn multiplier(mut self, value: f64) -> Self {
220        self.multiplier = Some(value);
221        self
222    }
223
224    #[inline(always)]
225    pub fn alpha_percent(mut self, value: f64) -> Self {
226        self.alpha_percent = Some(value);
227        self
228    }
229
230    #[inline(always)]
231    pub fn threshold_atr(mut self, value: f64) -> Self {
232        self.threshold_atr = Some(value);
233        self
234    }
235
236    #[inline(always)]
237    pub fn kernel(mut self, kernel: Kernel) -> Self {
238        self.kernel = kernel;
239        self
240    }
241
242    #[inline(always)]
243    fn params(self) -> SuperTrendRecoveryParams {
244        SuperTrendRecoveryParams {
245            atr_length: self.atr_length,
246            multiplier: self.multiplier,
247            alpha_percent: self.alpha_percent,
248            threshold_atr: self.threshold_atr,
249        }
250    }
251
252    #[inline(always)]
253    pub fn apply(
254        self,
255        candles: &Candles,
256    ) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
257        let kernel = self.kernel;
258        let params = self.params();
259        supertrend_recovery_with_kernel(
260            &SuperTrendRecoveryInput::from_candles(candles, params),
261            kernel,
262        )
263    }
264
265    #[inline(always)]
266    pub fn apply_slices(
267        self,
268        high: &[f64],
269        low: &[f64],
270        close: &[f64],
271    ) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
272        let kernel = self.kernel;
273        let params = self.params();
274        supertrend_recovery_with_kernel(
275            &SuperTrendRecoveryInput::from_slices(high, low, close, params),
276            kernel,
277        )
278    }
279
280    #[inline(always)]
281    pub fn into_stream(self) -> Result<SuperTrendRecoveryStream, SuperTrendRecoveryError> {
282        SuperTrendRecoveryStream::try_new(self.params())
283    }
284}
285
/// Errors reported by the SuperTrend Recovery single, stream and batch paths.
#[derive(Debug, Error)]
pub enum SuperTrendRecoveryError {
    /// At least one of the high/low/close series is empty.
    #[error("supertrend_recovery: input data slice is empty.")]
    EmptyInputData,
    /// No bar has finite high, low and close at the same time.
    #[error("supertrend_recovery: all values are NaN.")]
    AllValuesNaN,
    /// The high/low/close series do not all have the same length.
    #[error(
        "supertrend_recovery: inconsistent data lengths - high = {high_len}, low = {low_len}, close = {close_len}"
    )]
    DataLengthMismatch {
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    /// `atr_length` is zero or larger than the data length.
    #[error(
        "supertrend_recovery: invalid period: atr_length = {atr_length}, data length = {data_len}"
    )]
    InvalidPeriod { atr_length: usize, data_len: usize },
    /// `multiplier` is non-finite or below `MIN_MULTIPLIER`.
    #[error("supertrend_recovery: invalid multiplier: {multiplier}")]
    InvalidMultiplier { multiplier: f64 },
    /// `alpha_percent` is non-finite or outside
    /// `MIN_ALPHA_PERCENT..=MAX_ALPHA_PERCENT`.
    #[error("supertrend_recovery: invalid alpha_percent: {alpha_percent}")]
    InvalidAlphaPercent { alpha_percent: f64 },
    /// `threshold_atr` is non-finite or negative.
    #[error("supertrend_recovery: invalid threshold_atr: {threshold_atr}")]
    InvalidThresholdAtr { threshold_atr: f64 },
    /// The longest run of consecutive fully-finite bars is shorter than
    /// `atr_length`.
    #[error("supertrend_recovery: not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("supertrend_recovery: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep axis (or the overall grid size) could not be expanded.
    #[error(
        "supertrend_recovery: invalid range for {axis}: start = {start}, end = {end}, step = {step}"
    )]
    InvalidRange {
        axis: &'static str,
        start: String,
        end: String,
        step: String,
    },
    /// NOTE(review): not constructed in the visible part of this file;
    /// presumably raised when a non-batch kernel reaches the batch entry
    /// point — confirm at the call site.
    #[error("supertrend_recovery: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
326
/// Validated inputs produced by `prepare_input`: the three borrowed series
/// plus resolved parameter values.
#[derive(Clone, Copy, Debug)]
struct PreparedInput<'a> {
    high: &'a [f64],
    low: &'a [f64],
    close: &'a [f64],
    atr_length: usize,
    multiplier: f64,
    /// Blend factor as a fraction (`alpha_percent * 0.01`).
    alpha: f64,
    threshold_atr: f64,
    /// `first_valid + atr_length - 1`, passed as the NaN-prefix length when
    /// output vectors are allocated.
    warmup: usize,
}
338
/// Maps any requested kernel to `Kernel::Scalar`; the argument is ignored.
#[inline(always)]
fn normalize_single_kernel(_kernel: Kernel) -> Kernel {
    Kernel::Scalar
}
343
344#[inline(always)]
345fn validate_params(
346    atr_length: usize,
347    multiplier: f64,
348    alpha_percent: f64,
349    threshold_atr: f64,
350    data_len: usize,
351) -> Result<(), SuperTrendRecoveryError> {
352    if atr_length == 0 || atr_length > data_len {
353        return Err(SuperTrendRecoveryError::InvalidPeriod {
354            atr_length,
355            data_len,
356        });
357    }
358    if !multiplier.is_finite() || multiplier < MIN_MULTIPLIER {
359        return Err(SuperTrendRecoveryError::InvalidMultiplier { multiplier });
360    }
361    if !alpha_percent.is_finite()
362        || !(MIN_ALPHA_PERCENT..=MAX_ALPHA_PERCENT).contains(&alpha_percent)
363    {
364        return Err(SuperTrendRecoveryError::InvalidAlphaPercent { alpha_percent });
365    }
366    if !threshold_atr.is_finite() || threshold_atr < 0.0 {
367        return Err(SuperTrendRecoveryError::InvalidThresholdAtr { threshold_atr });
368    }
369    Ok(())
370}
371
372#[inline(always)]
373fn analyze_valid_segments(
374    high: &[f64],
375    low: &[f64],
376    close: &[f64],
377) -> Result<(usize, usize), SuperTrendRecoveryError> {
378    if high.is_empty() || low.is_empty() || close.is_empty() {
379        return Err(SuperTrendRecoveryError::EmptyInputData);
380    }
381    if high.len() != low.len() || high.len() != close.len() {
382        return Err(SuperTrendRecoveryError::DataLengthMismatch {
383            high_len: high.len(),
384            low_len: low.len(),
385            close_len: close.len(),
386        });
387    }
388
389    let mut first_valid = None;
390    let mut max_run = 0usize;
391    let mut run = 0usize;
392
393    for i in 0..close.len() {
394        let valid = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
395        if valid {
396            if first_valid.is_none() {
397                first_valid = Some(i);
398            }
399            run += 1;
400            if run > max_run {
401                max_run = run;
402            }
403        } else {
404            run = 0;
405        }
406    }
407
408    match first_valid {
409        Some(idx) => Ok((idx, max_run)),
410        None => Err(SuperTrendRecoveryError::AllValuesNaN),
411    }
412}
413
414#[inline(always)]
415fn prepare_input<'a>(
416    input: &'a SuperTrendRecoveryInput<'a>,
417    kernel: Kernel,
418) -> Result<PreparedInput<'a>, SuperTrendRecoveryError> {
419    let _chosen = normalize_single_kernel(kernel);
420    let (high, low, close) = input.as_hlc();
421    let atr_length = input.get_atr_length();
422    let multiplier = input.get_multiplier();
423    let alpha_percent = input.get_alpha_percent();
424    let threshold_atr = input.get_threshold_atr();
425    validate_params(
426        atr_length,
427        multiplier,
428        alpha_percent,
429        threshold_atr,
430        close.len(),
431    )?;
432
433    let (first_valid, max_run) = analyze_valid_segments(high, low, close)?;
434    if max_run < atr_length {
435        return Err(SuperTrendRecoveryError::NotEnoughValidData {
436            needed: atr_length,
437            valid: max_run,
438        });
439    }
440
441    Ok(PreparedInput {
442        high,
443        low,
444        close,
445        atr_length,
446        multiplier,
447        alpha: alpha_percent * 0.01,
448        threshold_atr,
449        warmup: first_valid + atr_length - 1,
450    })
451}
452
/// Incremental average-true-range accumulator.
///
/// The first `length` samples are averaged arithmetically to seed the value;
/// later samples are folded in with `(prev * (length - 1) + tr) / length`.
#[derive(Clone, Debug)]
struct AtrState {
    length: usize,
    count: usize,
    sum: f64,
    value: f64,
}

impl AtrState {
    /// Creates an empty accumulator for the given lookback.
    #[inline(always)]
    fn new(length: usize) -> Self {
        Self {
            length,
            value: f64::NAN,
            sum: 0.0,
            count: 0,
        }
    }

    /// Discards all accumulated samples; the lookback is kept.
    #[inline(always)]
    fn reset(&mut self) {
        self.value = f64::NAN;
        self.sum = 0.0;
        self.count = 0;
    }

    /// Feeds one true-range sample.
    ///
    /// Returns `None` until `length` samples have been seen, then the current
    /// ATR on every call.
    #[inline(always)]
    fn update(&mut self, tr: f64) -> Option<f64> {
        let window = self.length as f64;

        if self.count >= self.length {
            // Already seeded: recursive smoothing step.
            self.value = ((self.value * (window - 1.0)) + tr) / window;
            return Some(self.value);
        }

        // Still seeding: accumulate towards the initial simple average.
        self.count += 1;
        self.sum += tr;
        if self.count < self.length {
            return None;
        }
        self.value = self.sum / window;
        Some(self.value)
    }
}
496
497#[derive(Clone, Debug)]
498struct SuperTrendRecoveryState {
499    atr: AtrState,
500    multiplier: f64,
501    alpha: f64,
502    threshold_atr: f64,
503    prev_close: f64,
504    band: f64,
505    switch_price: f64,
506    trend: i8,
507}
508
509impl SuperTrendRecoveryState {
510    #[inline(always)]
511    fn new(atr_length: usize, multiplier: f64, alpha: f64, threshold_atr: f64) -> Self {
512        Self {
513            atr: AtrState::new(atr_length),
514            multiplier,
515            alpha,
516            threshold_atr,
517            prev_close: f64::NAN,
518            band: f64::NAN,
519            switch_price: f64::NAN,
520            trend: DEFAULT_TREND,
521        }
522    }
523
524    #[inline(always)]
525    fn reset(&mut self) {
526        self.atr.reset();
527        self.prev_close = f64::NAN;
528        self.band = f64::NAN;
529        self.switch_price = f64::NAN;
530        self.trend = DEFAULT_TREND;
531    }
532
533    #[inline(always)]
534    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
535        if !high.is_finite() || !low.is_finite() || !close.is_finite() {
536            self.reset();
537            return None;
538        }
539
540        if !self.switch_price.is_finite() {
541            self.switch_price = close;
542        }
543
544        let tr = if self.prev_close.is_finite() {
545            true_range(high, low, self.prev_close)
546        } else {
547            high - low
548        };
549        self.prev_close = close;
550
551        let atr = self.atr.update(tr)?;
552        let src = hl2(high, low);
553        let upper_base = src + self.multiplier * atr;
554        let lower_base = src - self.multiplier * atr;
555        let deviation = self.threshold_atr * atr;
556        let is_at_loss = (self.trend == 1 && (self.switch_price - close) > deviation)
557            || (self.trend == -1 && (close - self.switch_price) > deviation);
558        let prev_band = if self.band.is_finite() {
559            self.band
560        } else if self.trend == 1 {
561            lower_base
562        } else {
563            upper_base
564        };
565
566        let mut changed = 0.0;
567
568        if self.trend == 1 {
569            let target_band = if is_at_loss {
570                self.alpha.mul_add(close, (1.0 - self.alpha) * prev_band)
571            } else {
572                lower_base
573            };
574            self.band = target_band.max(prev_band);
575            if close < self.band {
576                self.trend = -1;
577                self.band = upper_base;
578                self.switch_price = close;
579                changed = 1.0;
580            }
581        } else {
582            let target_band = if is_at_loss {
583                self.alpha.mul_add(close, (1.0 - self.alpha) * prev_band)
584            } else {
585                upper_base
586            };
587            self.band = target_band.min(prev_band);
588            if close > self.band {
589                self.trend = 1;
590                self.band = lower_base;
591                self.switch_price = close;
592                changed = 1.0;
593            }
594        }
595
596        Some((self.band, self.switch_price, self.trend as f64, changed))
597    }
598}
599
600#[derive(Clone, Debug)]
601pub struct SuperTrendRecoveryStream {
602    params: SuperTrendRecoveryParams,
603    state: SuperTrendRecoveryState,
604}
605
606impl SuperTrendRecoveryStream {
607    #[inline(always)]
608    pub fn try_new(params: SuperTrendRecoveryParams) -> Result<Self, SuperTrendRecoveryError> {
609        let atr_length = params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH);
610        let multiplier = params.multiplier.unwrap_or(DEFAULT_MULTIPLIER);
611        let alpha_percent = params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT);
612        let threshold_atr = params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR);
613        validate_params(
614            atr_length,
615            multiplier,
616            alpha_percent,
617            threshold_atr,
618            usize::MAX,
619        )?;
620        Ok(Self {
621            state: SuperTrendRecoveryState::new(
622                atr_length,
623                multiplier,
624                alpha_percent * 0.01,
625                threshold_atr,
626            ),
627            params,
628        })
629    }
630
631    #[inline(always)]
632    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
633        self.state.update(high, low, close)
634    }
635
636    #[inline(always)]
637    pub fn params(&self) -> &SuperTrendRecoveryParams {
638        &self.params
639    }
640}
641
642#[derive(Clone, Debug)]
643pub struct SuperTrendRecoveryBatchRange {
644    pub atr_length: (usize, usize, usize),
645    pub multiplier: (f64, f64, f64),
646    pub alpha_percent: (f64, f64, f64),
647    pub threshold_atr: (f64, f64, f64),
648}
649
650impl Default for SuperTrendRecoveryBatchRange {
651    fn default() -> Self {
652        Self {
653            atr_length: (DEFAULT_ATR_LENGTH, DEFAULT_ATR_LENGTH, 0),
654            multiplier: (DEFAULT_MULTIPLIER, DEFAULT_MULTIPLIER, 0.0),
655            alpha_percent: (DEFAULT_ALPHA_PERCENT, DEFAULT_ALPHA_PERCENT, 0.0),
656            threshold_atr: (DEFAULT_THRESHOLD_ATR, DEFAULT_THRESHOLD_ATR, 0.0),
657        }
658    }
659}
660
/// Fluent builder for a batch (parameter-sweep) run: a sweep range plus the
/// kernel handed to `supertrend_recovery_batch_with_kernel`.
#[derive(Clone, Debug, Default)]
pub struct SuperTrendRecoveryBatchBuilder {
    range: SuperTrendRecoveryBatchRange,
    kernel: Kernel,
}
666
/// Result of a batch run.
///
/// NOTE(review): the code that fills this struct is outside the visible part
/// of the file. The four buffers presumably hold `rows * cols` values in
/// row-major order with row `r` produced by `combos[r]` — confirm against the
/// batch entry point.
#[derive(Clone, Debug)]
pub struct SuperTrendRecoveryBatchOutput {
    pub band: Vec<f64>,
    pub switch_price: Vec<f64>,
    pub trend: Vec<f64>,
    pub changed: Vec<f64>,
    /// Parameter combinations of the sweep.
    pub combos: Vec<SuperTrendRecoveryParams>,
    pub rows: usize,
    pub cols: usize,
}
677
678impl SuperTrendRecoveryBatchBuilder {
679    #[inline(always)]
680    pub fn new() -> Self {
681        Self::default()
682    }
683
684    #[inline(always)]
685    pub fn kernel(mut self, kernel: Kernel) -> Self {
686        self.kernel = kernel;
687        self
688    }
689
690    #[inline(always)]
691    pub fn atr_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
692        self.range.atr_length = (start, end, step);
693        self
694    }
695
696    #[inline(always)]
697    pub fn multiplier_range(mut self, start: f64, end: f64, step: f64) -> Self {
698        self.range.multiplier = (start, end, step);
699        self
700    }
701
702    #[inline(always)]
703    pub fn alpha_percent_range(mut self, start: f64, end: f64, step: f64) -> Self {
704        self.range.alpha_percent = (start, end, step);
705        self
706    }
707
708    #[inline(always)]
709    pub fn threshold_atr_range(mut self, start: f64, end: f64, step: f64) -> Self {
710        self.range.threshold_atr = (start, end, step);
711        self
712    }
713
714    #[inline(always)]
715    pub fn apply_slices(
716        self,
717        high: &[f64],
718        low: &[f64],
719        close: &[f64],
720    ) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
721        supertrend_recovery_batch_with_kernel(high, low, close, &self.range, self.kernel)
722    }
723
724    #[inline(always)]
725    pub fn apply(
726        self,
727        candles: &Candles,
728    ) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
729        self.apply_slices(&candles.high, &candles.low, &candles.close)
730    }
731}
732
733#[inline(always)]
734fn compute_row(
735    high: &[f64],
736    low: &[f64],
737    close: &[f64],
738    atr_length: usize,
739    multiplier: f64,
740    alpha_percent: f64,
741    threshold_atr: f64,
742    band_out: &mut [f64],
743    switch_price_out: &mut [f64],
744    trend_out: &mut [f64],
745    changed_out: &mut [f64],
746) -> Result<(), SuperTrendRecoveryError> {
747    let len = close.len();
748    if band_out.len() != len
749        || switch_price_out.len() != len
750        || trend_out.len() != len
751        || changed_out.len() != len
752    {
753        return Err(SuperTrendRecoveryError::OutputLengthMismatch {
754            expected: len,
755            got: band_out
756                .len()
757                .max(switch_price_out.len())
758                .max(trend_out.len())
759                .max(changed_out.len()),
760        });
761    }
762
763    let mut state =
764        SuperTrendRecoveryState::new(atr_length, multiplier, alpha_percent * 0.01, threshold_atr);
765
766    for i in 0..len {
767        if let Some((band, switch_price, trend, changed)) = state.update(high[i], low[i], close[i])
768        {
769            band_out[i] = band;
770            switch_price_out[i] = switch_price;
771            trend_out[i] = trend;
772            changed_out[i] = changed;
773        } else {
774            band_out[i] = f64::NAN;
775            switch_price_out[i] = f64::NAN;
776            trend_out[i] = f64::NAN;
777            changed_out[i] = f64::NAN;
778        }
779    }
780
781    Ok(())
782}
783
/// Computes the indicator for `input` by delegating to
/// `supertrend_recovery_with_kernel` with `Kernel::Auto`.
///
/// # Errors
/// Whatever `supertrend_recovery_with_kernel` returns.
#[inline]
pub fn supertrend_recovery(
    input: &SuperTrendRecoveryInput,
) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
    supertrend_recovery_with_kernel(input, Kernel::Auto)
}
790
791pub fn supertrend_recovery_with_kernel(
792    input: &SuperTrendRecoveryInput,
793    kernel: Kernel,
794) -> Result<SuperTrendRecoveryOutput, SuperTrendRecoveryError> {
795    let prepared = prepare_input(input, kernel)?;
796    let len = prepared.close.len();
797    let mut band = alloc_with_nan_prefix(len, prepared.warmup);
798    let mut switch_price = alloc_with_nan_prefix(len, prepared.warmup);
799    let mut trend = alloc_with_nan_prefix(len, prepared.warmup);
800    let mut changed = alloc_with_nan_prefix(len, prepared.warmup);
801    compute_row(
802        prepared.high,
803        prepared.low,
804        prepared.close,
805        prepared.atr_length,
806        prepared.multiplier,
807        prepared.alpha / 0.01,
808        prepared.threshold_atr,
809        &mut band,
810        &mut switch_price,
811        &mut trend,
812        &mut changed,
813    )?;
814    Ok(SuperTrendRecoveryOutput {
815        band,
816        switch_price,
817        trend,
818        changed,
819    })
820}
821
/// Writes the indicator into caller-provided buffers by delegating to
/// `supertrend_recovery_into_slice` with `Kernel::Auto`.
///
/// Compiled out when targeting `wasm32` with the `wasm` feature enabled.
///
/// # Errors
/// Whatever `supertrend_recovery_into_slice` returns.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
pub fn supertrend_recovery_into(
    band_out: &mut [f64],
    switch_price_out: &mut [f64],
    trend_out: &mut [f64],
    changed_out: &mut [f64],
    input: &SuperTrendRecoveryInput,
) -> Result<(), SuperTrendRecoveryError> {
    supertrend_recovery_into_slice(
        band_out,
        switch_price_out,
        trend_out,
        changed_out,
        input,
        Kernel::Auto,
    )
}
839
840pub fn supertrend_recovery_into_slice(
841    band_out: &mut [f64],
842    switch_price_out: &mut [f64],
843    trend_out: &mut [f64],
844    changed_out: &mut [f64],
845    input: &SuperTrendRecoveryInput,
846    kernel: Kernel,
847) -> Result<(), SuperTrendRecoveryError> {
848    let prepared = prepare_input(input, kernel)?;
849    compute_row(
850        prepared.high,
851        prepared.low,
852        prepared.close,
853        prepared.atr_length,
854        prepared.multiplier,
855        prepared.alpha / 0.01,
856        prepared.threshold_atr,
857        band_out,
858        switch_price_out,
859        trend_out,
860        changed_out,
861    )
862}
863
864#[inline(always)]
865pub fn expand_grid(
866    sweep: &SuperTrendRecoveryBatchRange,
867) -> Result<Vec<SuperTrendRecoveryParams>, SuperTrendRecoveryError> {
868    fn axis_usize(
869        axis: &'static str,
870        (start, end, step): (usize, usize, usize),
871    ) -> Result<Vec<usize>, SuperTrendRecoveryError> {
872        if step == 0 || start == end {
873            return Ok(vec![start]);
874        }
875        let mut out = Vec::new();
876        if start < end {
877            let mut value = start;
878            while value <= end {
879                out.push(value);
880                match value.checked_add(step) {
881                    Some(next) => value = next,
882                    None => break,
883                }
884            }
885        } else {
886            let mut value = start as isize;
887            let stop = end as isize;
888            let stride = step as isize;
889            while value >= stop {
890                out.push(value as usize);
891                value -= stride;
892            }
893        }
894        if out.is_empty() {
895            return Err(SuperTrendRecoveryError::InvalidRange {
896                axis,
897                start: start.to_string(),
898                end: end.to_string(),
899                step: step.to_string(),
900            });
901        }
902        Ok(out)
903    }
904
905    fn axis_float(
906        axis: &'static str,
907        (start, end, step): (f64, f64, f64),
908    ) -> Result<Vec<f64>, SuperTrendRecoveryError> {
909        if !start.is_finite() || !end.is_finite() || !step.is_finite() {
910            return Err(SuperTrendRecoveryError::InvalidRange {
911                axis,
912                start: start.to_string(),
913                end: end.to_string(),
914                step: step.to_string(),
915            });
916        }
917        if step == 0.0 || start == end {
918            return Ok(vec![start]);
919        }
920        if step < 0.0 {
921            return Err(SuperTrendRecoveryError::InvalidRange {
922                axis,
923                start: start.to_string(),
924                end: end.to_string(),
925                step: step.to_string(),
926            });
927        }
928        let mut out = Vec::new();
929        let eps = step.abs() * 1e-9 + 1e-12;
930        if start < end {
931            let mut value = start;
932            while value <= end + eps {
933                out.push(value);
934                value += step;
935            }
936        } else {
937            let mut value = start;
938            while value + eps >= end {
939                out.push(value);
940                value -= step;
941            }
942        }
943        if out.is_empty() {
944            return Err(SuperTrendRecoveryError::InvalidRange {
945                axis,
946                start: start.to_string(),
947                end: end.to_string(),
948                step: step.to_string(),
949            });
950        }
951        Ok(out)
952    }
953
954    let atr_lengths = axis_usize("atr_length", sweep.atr_length)?;
955    let multipliers = axis_float("multiplier", sweep.multiplier)?;
956    let alpha_percents = axis_float("alpha_percent", sweep.alpha_percent)?;
957    let threshold_atrs = axis_float("threshold_atr", sweep.threshold_atr)?;
958
959    let cap = atr_lengths
960        .len()
961        .checked_mul(multipliers.len())
962        .and_then(|v| v.checked_mul(alpha_percents.len()))
963        .and_then(|v| v.checked_mul(threshold_atrs.len()))
964        .ok_or(SuperTrendRecoveryError::InvalidRange {
965            axis: "grid",
966            start: "cap".to_string(),
967            end: "overflow".to_string(),
968            step: "mul".to_string(),
969        })?;
970
971    let mut out = Vec::with_capacity(cap);
972    for &atr_length in &atr_lengths {
973        for &multiplier in &multipliers {
974            for &alpha_percent in &alpha_percents {
975                for &threshold_atr in &threshold_atrs {
976                    out.push(SuperTrendRecoveryParams {
977                        atr_length: Some(atr_length),
978                        multiplier: Some(multiplier),
979                        alpha_percent: Some(alpha_percent),
980                        threshold_atr: Some(threshold_atr),
981                    });
982                }
983            }
984        }
985    }
986    Ok(out)
987}
988
989fn supertrend_recovery_batch_inner_into(
990    high: &[f64],
991    low: &[f64],
992    close: &[f64],
993    sweep: &SuperTrendRecoveryBatchRange,
994    parallel: bool,
995    band_out: &mut [f64],
996    switch_price_out: &mut [f64],
997    trend_out: &mut [f64],
998    changed_out: &mut [f64],
999) -> Result<Vec<SuperTrendRecoveryParams>, SuperTrendRecoveryError> {
1000    let (_, max_run) = analyze_valid_segments(high, low, close)?;
1001    let combos = expand_grid(sweep)?;
1002    let rows = combos.len();
1003    let cols = close.len();
1004    let expected = rows
1005        .checked_mul(cols)
1006        .ok_or(SuperTrendRecoveryError::OutputLengthMismatch {
1007            expected: usize::MAX,
1008            got: band_out.len(),
1009        })?;
1010    if band_out.len() != expected
1011        || switch_price_out.len() != expected
1012        || trend_out.len() != expected
1013        || changed_out.len() != expected
1014    {
1015        return Err(SuperTrendRecoveryError::OutputLengthMismatch {
1016            expected,
1017            got: band_out
1018                .len()
1019                .max(switch_price_out.len())
1020                .max(trend_out.len())
1021                .max(changed_out.len()),
1022        });
1023    }
1024
1025    for params in &combos {
1026        let atr_length = params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH);
1027        let multiplier = params.multiplier.unwrap_or(DEFAULT_MULTIPLIER);
1028        let alpha_percent = params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT);
1029        let threshold_atr = params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR);
1030        validate_params(atr_length, multiplier, alpha_percent, threshold_atr, cols)?;
1031        if max_run < atr_length {
1032            return Err(SuperTrendRecoveryError::NotEnoughValidData {
1033                needed: atr_length,
1034                valid: max_run,
1035            });
1036        }
1037    }
1038
1039    let do_row = |row: usize,
1040                  band_row: &mut [f64],
1041                  switch_row: &mut [f64],
1042                  trend_row: &mut [f64],
1043                  changed_row: &mut [f64]| {
1044        let params = &combos[row];
1045        compute_row(
1046            high,
1047            low,
1048            close,
1049            params.atr_length.unwrap_or(DEFAULT_ATR_LENGTH),
1050            params.multiplier.unwrap_or(DEFAULT_MULTIPLIER),
1051            params.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT),
1052            params.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR),
1053            band_row,
1054            switch_row,
1055            trend_row,
1056            changed_row,
1057        )
1058    };
1059
1060    if parallel {
1061        #[cfg(not(target_arch = "wasm32"))]
1062        {
1063            band_out
1064                .par_chunks_mut(cols)
1065                .zip(switch_price_out.par_chunks_mut(cols))
1066                .zip(trend_out.par_chunks_mut(cols))
1067                .zip(changed_out.par_chunks_mut(cols))
1068                .enumerate()
1069                .try_for_each(
1070                    |(row, (((band_row, switch_row), trend_row), changed_row))| {
1071                        do_row(row, band_row, switch_row, trend_row, changed_row)
1072                    },
1073                )?;
1074        }
1075        #[cfg(target_arch = "wasm32")]
1076        {
1077            for (row, (((band_row, switch_row), trend_row), changed_row)) in band_out
1078                .chunks_mut(cols)
1079                .zip(switch_price_out.chunks_mut(cols))
1080                .zip(trend_out.chunks_mut(cols))
1081                .zip(changed_out.chunks_mut(cols))
1082                .enumerate()
1083            {
1084                do_row(row, band_row, switch_row, trend_row, changed_row)?;
1085            }
1086        }
1087    } else {
1088        for (row, (((band_row, switch_row), trend_row), changed_row)) in band_out
1089            .chunks_mut(cols)
1090            .zip(switch_price_out.chunks_mut(cols))
1091            .zip(trend_out.chunks_mut(cols))
1092            .zip(changed_out.chunks_mut(cols))
1093            .enumerate()
1094        {
1095            do_row(row, band_row, switch_row, trend_row, changed_row)?;
1096        }
1097    }
1098
1099    Ok(combos)
1100}
1101
1102pub fn supertrend_recovery_batch_with_kernel(
1103    high: &[f64],
1104    low: &[f64],
1105    close: &[f64],
1106    sweep: &SuperTrendRecoveryBatchRange,
1107    kernel: Kernel,
1108) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1109    match kernel {
1110        Kernel::Auto => {
1111            let _ = detect_best_batch_kernel();
1112        }
1113        k if !k.is_batch() => return Err(SuperTrendRecoveryError::InvalidKernelForBatch(k)),
1114        _ => {}
1115    }
1116    supertrend_recovery_batch_par_slice(high, low, close, sweep, Kernel::ScalarBatch)
1117}
1118
1119pub fn supertrend_recovery_batch_slice(
1120    high: &[f64],
1121    low: &[f64],
1122    close: &[f64],
1123    sweep: &SuperTrendRecoveryBatchRange,
1124    _kernel: Kernel,
1125) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1126    supertrend_recovery_batch_impl(high, low, close, sweep, false)
1127}
1128
1129pub fn supertrend_recovery_batch_par_slice(
1130    high: &[f64],
1131    low: &[f64],
1132    close: &[f64],
1133    sweep: &SuperTrendRecoveryBatchRange,
1134    _kernel: Kernel,
1135) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1136    supertrend_recovery_batch_impl(high, low, close, sweep, true)
1137}
1138
1139fn supertrend_recovery_batch_impl(
1140    high: &[f64],
1141    low: &[f64],
1142    close: &[f64],
1143    sweep: &SuperTrendRecoveryBatchRange,
1144    parallel: bool,
1145) -> Result<SuperTrendRecoveryBatchOutput, SuperTrendRecoveryError> {
1146    let rows = expand_grid(sweep)?.len();
1147    let cols = close.len();
1148
1149    let band_mu = make_uninit_matrix(rows, cols);
1150    let switch_mu = make_uninit_matrix(rows, cols);
1151    let trend_mu = make_uninit_matrix(rows, cols);
1152    let changed_mu = make_uninit_matrix(rows, cols);
1153
1154    let mut band_guard = ManuallyDrop::new(band_mu);
1155    let mut switch_guard = ManuallyDrop::new(switch_mu);
1156    let mut trend_guard = ManuallyDrop::new(trend_mu);
1157    let mut changed_guard = ManuallyDrop::new(changed_mu);
1158
1159    let band_out: &mut [f64] = unsafe {
1160        core::slice::from_raw_parts_mut(band_guard.as_mut_ptr() as *mut f64, band_guard.len())
1161    };
1162    let switch_out: &mut [f64] = unsafe {
1163        core::slice::from_raw_parts_mut(switch_guard.as_mut_ptr() as *mut f64, switch_guard.len())
1164    };
1165    let trend_out: &mut [f64] = unsafe {
1166        core::slice::from_raw_parts_mut(trend_guard.as_mut_ptr() as *mut f64, trend_guard.len())
1167    };
1168    let changed_out: &mut [f64] = unsafe {
1169        core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1170    };
1171
1172    let combos = supertrend_recovery_batch_inner_into(
1173        high,
1174        low,
1175        close,
1176        sweep,
1177        parallel,
1178        band_out,
1179        switch_out,
1180        trend_out,
1181        changed_out,
1182    )?;
1183
1184    let band = unsafe {
1185        Vec::from_raw_parts(
1186            band_guard.as_mut_ptr() as *mut f64,
1187            band_guard.len(),
1188            band_guard.capacity(),
1189        )
1190    };
1191    let switch_price = unsafe {
1192        Vec::from_raw_parts(
1193            switch_guard.as_mut_ptr() as *mut f64,
1194            switch_guard.len(),
1195            switch_guard.capacity(),
1196        )
1197    };
1198    let trend = unsafe {
1199        Vec::from_raw_parts(
1200            trend_guard.as_mut_ptr() as *mut f64,
1201            trend_guard.len(),
1202            trend_guard.capacity(),
1203        )
1204    };
1205    let changed = unsafe {
1206        Vec::from_raw_parts(
1207            changed_guard.as_mut_ptr() as *mut f64,
1208            changed_guard.len(),
1209            changed_guard.capacity(),
1210        )
1211    };
1212
1213    Ok(SuperTrendRecoveryBatchOutput {
1214        band,
1215        switch_price,
1216        trend,
1217        changed,
1218        combos,
1219        rows,
1220        cols,
1221    })
1222}
1223
/// Python binding for the single-parameter SuperTrend Recovery computation.
///
/// Takes contiguous `high`, `low` and `close` arrays plus the four indicator parameters and
/// returns the `(band, switch_price, trend, changed)` arrays. The computation runs with the
/// GIL released. Kernel-name and computation errors are surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_recovery")]
#[pyo3(signature = (high, low, close, atr_length=DEFAULT_ATR_LENGTH, multiplier=DEFAULT_MULTIPLIER, alpha_percent=DEFAULT_ALPHA_PERCENT, threshold_atr=DEFAULT_THRESHOLD_ATR, kernel=None))]
pub fn supertrend_recovery_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let params = SuperTrendRecoveryParams {
        atr_length: Some(atr_length),
        multiplier: Some(multiplier),
        alpha_percent: Some(alpha_percent),
        threshold_atr: Some(threshold_atr),
    };

    // Borrow the NumPy buffers first, then validate the kernel name (same order as errors
    // would be reported to the caller).
    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kernel = validate_kernel(kernel, false)?;
    let input = SuperTrendRecoveryInput::from_slices(high_slice, low_slice, close_slice, params);

    let computed = py.allow_threads(|| supertrend_recovery_with_kernel(&input, kernel));
    let result = match computed {
        Ok(result) => result,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    let band = result.band.into_pyarray(py);
    let switch_price = result.switch_price.into_pyarray(py);
    let trend = result.trend.into_pyarray(py);
    let changed = result.changed.into_pyarray(py);
    Ok((band, switch_price, trend, changed))
}
1268
/// Python binding for the SuperTrend Recovery parameter sweep.
///
/// Each `*_range` argument is forwarded unchanged into the corresponding field of
/// `SuperTrendRecoveryBatchRange`. The result is a dict with the keys `band`, `switch_price`,
/// `trend` and `changed` (each reshaped to `(rows, cols)`), the per-row parameter arrays
/// `atr_lengths`, `multipliers`, `alpha_percents` and `threshold_atrs`, and the scalars `rows`
/// and `cols`. The computation writes directly into the NumPy buffers with the GIL released.
/// Rows are processed in parallel unless the validated kernel is `Scalar` or `ScalarBatch`.
/// Errors are surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_recovery_batch")]
#[pyo3(signature = (high, low, close, atr_length_range=(DEFAULT_ATR_LENGTH, DEFAULT_ATR_LENGTH, 0), multiplier_range=(DEFAULT_MULTIPLIER, DEFAULT_MULTIPLIER, 0.0), alpha_percent_range=(DEFAULT_ALPHA_PERCENT, DEFAULT_ALPHA_PERCENT, 0.0), threshold_atr_range=(DEFAULT_THRESHOLD_ATR, DEFAULT_THRESHOLD_ATR, 0.0), kernel=None))]
pub fn supertrend_recovery_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    atr_length_range: (usize, usize, usize),
    multiplier_range: (f64, f64, f64),
    alpha_percent_range: (f64, f64, f64),
    threshold_atr_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kernel = validate_kernel(kernel, true)?;
    let parallel = !matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch);
    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: atr_length_range,
        multiplier: multiplier_range,
        alpha_percent: alpha_percent_range,
        threshold_atr: threshold_atr_range,
    };

    // The grid is expanded once here only to size the output arrays.
    let rows = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };
    let cols = close_slice.len();
    let total = match rows.checked_mul(cols) {
        Some(total) => total,
        None => {
            return Err(PyValueError::new_err(
                "rows*cols overflow in supertrend_recovery_batch",
            ))
        }
    };

    // Flat, uninitialised arrays; they are reshaped to (rows, cols) once filled.
    let band_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let switch_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let trend_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let changed_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let band_out = unsafe { band_arr.as_slice_mut()? };
    let switch_out = unsafe { switch_arr.as_slice_mut()? };
    let trend_out = unsafe { trend_arr.as_slice_mut()? };
    let changed_out = unsafe { changed_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        supertrend_recovery_batch_inner_into(
            high_slice,
            low_slice,
            close_slice,
            &sweep,
            parallel,
            band_out,
            switch_out,
            trend_out,
            changed_out,
        )
    });
    let combos = match computed {
        Ok(combos) => combos,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    // One entry per row, with missing fields replaced by the module defaults.
    let atr_lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.atr_length.unwrap_or(DEFAULT_ATR_LENGTH) as u64)
        .collect();
    let multipliers: Vec<f64> = combos
        .iter()
        .map(|combo| combo.multiplier.unwrap_or(DEFAULT_MULTIPLIER))
        .collect();
    let alpha_percents: Vec<f64> = combos
        .iter()
        .map(|combo| combo.alpha_percent.unwrap_or(DEFAULT_ALPHA_PERCENT))
        .collect();
    let threshold_atrs: Vec<f64> = combos
        .iter()
        .map(|combo| combo.threshold_atr.unwrap_or(DEFAULT_THRESHOLD_ATR))
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("band", band_arr.reshape((rows, cols))?)?;
    dict.set_item("switch_price", switch_arr.reshape((rows, cols))?)?;
    dict.set_item("trend", trend_arr.reshape((rows, cols))?)?;
    dict.set_item("changed", changed_arr.reshape((rows, cols))?)?;
    dict.set_item("atr_lengths", atr_lengths.into_pyarray(py))?;
    dict.set_item("multipliers", multipliers.into_pyarray(py))?;
    dict.set_item("alpha_percents", alpha_percents.into_pyarray(py))?;
    dict.set_item("threshold_atrs", threshold_atrs.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1369
/// Python-visible wrapper around [`SuperTrendRecoveryStream`], exposed as
/// `SuperTrendRecoveryStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendRecoveryStream")]
pub struct SuperTrendRecoveryStreamPy {
    stream: SuperTrendRecoveryStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendRecoveryStreamPy {
    /// Builds a stream from the four indicator parameters.
    ///
    /// Raises `ValueError` when `SuperTrendRecoveryStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (atr_length=DEFAULT_ATR_LENGTH, multiplier=DEFAULT_MULTIPLIER, alpha_percent=DEFAULT_ALPHA_PERCENT, threshold_atr=DEFAULT_THRESHOLD_ATR))]
    fn new(
        atr_length: usize,
        multiplier: f64,
        alpha_percent: f64,
        threshold_atr: f64,
    ) -> PyResult<Self> {
        let params = SuperTrendRecoveryParams {
            atr_length: Some(atr_length),
            multiplier: Some(multiplier),
            alpha_percent: Some(alpha_percent),
            threshold_atr: Some(threshold_atr),
        };
        match SuperTrendRecoveryStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one bar to the wrapped stream and returns whatever it yields: a 4-tuple of
    /// values, or `None` when the stream produces no output for this bar.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
1401
/// Sweep configuration deserialized from the JavaScript `config` object of the wasm batch API.
///
/// Each tuple is copied verbatim into the matching field of `SuperTrendRecoveryBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendRecoveryBatchConfig {
    /// Range tuple for the ATR length axis.
    pub atr_length_range: (usize, usize, usize),
    /// Range tuple for the multiplier axis.
    pub multiplier_range: (f64, f64, f64),
    /// Range tuple for the alpha-percent axis.
    pub alpha_percent_range: (f64, f64, f64),
    /// Range tuple for the ATR-threshold axis.
    pub threshold_atr_range: (f64, f64, f64),
}
1410
/// Batch result in the shape serialized back to JavaScript.
///
/// The four value vectors are flat, row-major `rows * cols` matrices with one row per entry
/// of `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendRecoveryBatchJsOutput {
    /// Flattened band matrix.
    pub band: Vec<f64>,
    /// Flattened switch-price matrix.
    pub switch_price: Vec<f64>,
    /// Flattened trend matrix.
    pub trend: Vec<f64>,
    /// Flattened changed-flag matrix.
    pub changed: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<SuperTrendRecoveryParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of input bars per row.
    pub cols: usize,
}
1422
/// wasm binding: computes the indicator for one parameter set and returns the output struct
/// serialized with `serde_wasm_bindgen`.
///
/// Computation errors are returned as their display string; serialization failures are
/// returned prefixed with `Serialization error: `.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
) -> Result<JsValue, JsValue> {
    let params = SuperTrendRecoveryParams {
        atr_length: Some(atr_length),
        multiplier: Some(multiplier),
        alpha_percent: Some(alpha_percent),
        threshold_atr: Some(threshold_atr),
    };
    let input = SuperTrendRecoveryInput::from_slices(high, low, close, params);

    let output = match supertrend_recovery_with_kernel(&input, Kernel::Auto) {
        Ok(output) => output,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    match serde_wasm_bindgen::to_value(&output) {
        Ok(value) => Ok(value),
        Err(err) => Err(JsValue::from_str(&format!("Serialization error: {}", err))),
    }
}
1450
/// wasm binding: reserves room for `len` `f64` values and hands the raw pointer to the caller.
///
/// The backing vector is deliberately never dropped; the caller is expected to release the
/// buffer with `supertrend_recovery_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1459
/// wasm binding: releases a buffer of `len` `f64` values; a null `ptr` is ignored.
///
/// NOTE(review): `ptr` and `len` are presumably the pair obtained from
/// `supertrend_recovery_alloc(len)` - nothing here verifies that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer/length pair describing a live allocation whose
    // capacity is `len`. Rebuilding the vector and dropping it frees that allocation.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1469
/// wasm binding: computes the indicator for one parameter set into caller-owned buffers.
///
/// All seven pointers must be non-null and address `len` `f64` values. When any input pointer
/// equals an output pointer, or two output pointers are equal, the result is first computed
/// into temporary vectors and then copied out; otherwise the outputs are written in place.
/// Only exact pointer equality is detected, not partial overlap.
///
/// Returns `Err("Null pointer provided")` for a null pointer, or the computation error's
/// display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    band_ptr: *mut f64,
    switch_price_ptr: *mut f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    atr_length: usize,
    multiplier: f64,
    alpha_percent: f64,
    threshold_atr: f64,
) -> Result<(), JsValue> {
    let inputs = [high_ptr, low_ptr, close_ptr];
    let outputs = [band_ptr, switch_price_ptr, trend_ptr, changed_ptr];

    let has_null = inputs.iter().any(|p| p.is_null()) || outputs.iter().any(|p| p.is_null());
    if has_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Address equality between any input and any output, or between any two outputs.
    let input_is_output = inputs
        .iter()
        .any(|&inp| outputs.iter().any(|&out| inp == out as *const f64));
    let outputs_coincide = outputs
        .iter()
        .enumerate()
        .any(|(idx, &first)| outputs[idx + 1..].iter().any(|&second| first == second));
    let aliased = input_is_output || outputs_coincide;

    // SAFETY: the pointers were checked for null above; the caller must guarantee that each
    // one addresses `len` readable (inputs) or writable (outputs) `f64` values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let params = SuperTrendRecoveryParams {
            atr_length: Some(atr_length),
            multiplier: Some(multiplier),
            alpha_percent: Some(alpha_percent),
            threshold_atr: Some(threshold_atr),
        };
        let input = SuperTrendRecoveryInput::from_slices(high, low, close, params);

        if !aliased {
            // Distinct buffers: write the results straight into the caller's memory.
            let band_out = std::slice::from_raw_parts_mut(band_ptr, len);
            let switch_out = std::slice::from_raw_parts_mut(switch_price_ptr, len);
            let trend_out = std::slice::from_raw_parts_mut(trend_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);
            supertrend_recovery_into_slice(
                band_out,
                switch_out,
                trend_out,
                changed_out,
                &input,
                Kernel::Auto,
            )
            .map_err(|err| JsValue::from_str(&err.to_string()))?;
        } else {
            // Shared buffers: finish the computation before touching any output memory.
            let output = supertrend_recovery_with_kernel(&input, Kernel::Auto)
                .map_err(|err| JsValue::from_str(&err.to_string()))?;
            std::slice::from_raw_parts_mut(band_ptr, len).copy_from_slice(&output.band);
            std::slice::from_raw_parts_mut(switch_price_ptr, len)
                .copy_from_slice(&output.switch_price);
            std::slice::from_raw_parts_mut(trend_ptr, len).copy_from_slice(&output.trend);
            std::slice::from_raw_parts_mut(changed_ptr, len).copy_from_slice(&output.changed);
        }
    }

    Ok(())
}
1562
/// wasm binding (exported as `supertrend_recovery_batch`): runs a parameter sweep described by
/// a JavaScript config object and returns the serialized batch result.
///
/// `config` is deserialized into [`SuperTrendRecoveryBatchConfig`]; a failure is reported as
/// `Invalid config: ...`. Computation errors are returned as their display string and
/// serialization failures as `Serialization error: ...`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend_recovery_batch)]
pub fn supertrend_recovery_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let SuperTrendRecoveryBatchConfig {
        atr_length_range,
        multiplier_range,
        alpha_percent_range,
        threshold_atr_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|err| JsValue::from_str(&format!("Invalid config: {}", err)))?;

    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: atr_length_range,
        multiplier: multiplier_range,
        alpha_percent: alpha_percent_range,
        threshold_atr: threshold_atr_range,
    };

    let SuperTrendRecoveryBatchOutput {
        band,
        switch_price,
        trend,
        changed,
        combos,
        rows,
        cols,
    } = supertrend_recovery_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|err| JsValue::from_str(&err.to_string()))?;

    let js_output = SuperTrendRecoveryBatchJsOutput {
        band,
        switch_price,
        trend,
        changed,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
1593
/// wasm binding: runs a parameter sweep into caller-owned buffers and returns the row count.
///
/// The three input pointers must address `len` values each; the four output pointers must
/// address `rows * len` values each, where `rows` is the number of combinations produced by
/// `expand_grid` for the given `(start, end, step)` triples. Rows are computed sequentially.
///
/// NOTE(review): unlike `supertrend_recovery_into`, no aliasing check is performed here, so
/// the caller presumably must pass non-overlapping buffers - confirm against the JS callers.
///
/// Returns `Err("Null pointer provided")` for a null pointer, `Err("rows*len overflow")` when
/// the output size overflows, or the display string of a grid/computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_recovery_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    band_ptr: *mut f64,
    switch_price_ptr: *mut f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    atr_length_start: usize,
    atr_length_end: usize,
    atr_length_step: usize,
    multiplier_start: f64,
    multiplier_end: f64,
    multiplier_step: f64,
    alpha_percent_start: f64,
    alpha_percent_end: f64,
    alpha_percent_step: f64,
    threshold_atr_start: f64,
    threshold_atr_end: f64,
    threshold_atr_step: f64,
) -> Result<usize, JsValue> {
    let any_null = [high_ptr, low_ptr, close_ptr].iter().any(|p| p.is_null())
        || [band_ptr, switch_price_ptr, trend_ptr, changed_ptr]
            .iter()
            .any(|p| p.is_null());
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: (atr_length_start, atr_length_end, atr_length_step),
        multiplier: (multiplier_start, multiplier_end, multiplier_step),
        alpha_percent: (alpha_percent_start, alpha_percent_end, alpha_percent_step),
        threshold_atr: (threshold_atr_start, threshold_atr_end, threshold_atr_step),
    };

    // The grid is expanded here only to size the output slices.
    let rows = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };
    let total = match rows.checked_mul(len) {
        Some(total) => total,
        None => return Err(JsValue::from_str("rows*len overflow")),
    };

    // SAFETY: the pointers were checked for null above; the caller must guarantee the inputs
    // address `len` values and the outputs address `total` values.
    let outcome = unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let band_out = std::slice::from_raw_parts_mut(band_ptr, total);
        let switch_out = std::slice::from_raw_parts_mut(switch_price_ptr, total);
        let trend_out = std::slice::from_raw_parts_mut(trend_ptr, total);
        let changed_out = std::slice::from_raw_parts_mut(changed_ptr, total);
        supertrend_recovery_batch_inner_into(
            high,
            low,
            close,
            &sweep,
            false,
            band_out,
            switch_out,
            trend_out,
            changed_out,
        )
    };

    match outcome {
        Ok(_) => Ok(rows),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1666
1667#[cfg(test)]
1668mod tests {
1669    use super::*;
1670    use crate::utilities::data_loader::read_candles_from_csv;
1671
1672    fn trend_data(size: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
1673        let mut high = Vec::with_capacity(size);
1674        let mut low = Vec::with_capacity(size);
1675        let mut close = Vec::with_capacity(size);
1676        for i in 0..size {
1677            let base = 100.0 + i as f64 * 0.8;
1678            high.push(base + 1.2 + (i % 3) as f64 * 0.1);
1679            low.push(base - 1.0 - (i % 2) as f64 * 0.1);
1680            close.push(base + ((i % 5) as f64 - 2.0) * 0.05);
1681        }
1682        (high, low, close)
1683    }
1684
1685    fn reversal_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
1686        let close = vec![
1687            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 104.0, 102.0, 99.0, 96.0, 93.0, 91.0, 90.0,
1688            91.0, 93.0, 96.0, 100.0, 105.0, 109.0, 112.0, 111.0, 109.0, 107.0, 104.0, 101.0, 98.0,
1689            96.0, 95.0, 96.0, 98.0,
1690        ];
1691        let high = close.iter().map(|v| v + 1.0).collect::<Vec<_>>();
1692        let low = close.iter().map(|v| v - 1.0).collect::<Vec<_>>();
1693        (high, low, close)
1694    }
1695
1696    fn recovery_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
1697        let close = vec![
1698            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 102.0, 97.0, 92.0, 88.0, 90.0, 93.0, 96.0,
1699            99.0, 101.0, 102.0, 101.0, 100.0, 99.0, 98.0, 97.0, 96.0, 95.0, 94.0, 93.0, 92.0, 91.0,
1700            90.0, 89.0, 88.0,
1701        ];
1702        let high = close.iter().map(|v| v + 0.9).collect::<Vec<_>>();
1703        let low = close.iter().map(|v| v - 0.9).collect::<Vec<_>>();
1704        (high, low, close)
1705    }
1706
1707    fn arrays_eq_nan(a: &[f64], b: &[f64]) -> bool {
1708        a.len() == b.len()
1709            && a.iter().zip(b.iter()).all(|(x, y)| {
1710                (x.is_nan() && y.is_nan())
1711                    || (!x.is_nan() && !y.is_nan() && (*x - *y).abs() <= 1e-12)
1712            })
1713    }
1714
1715    #[test]
1716    fn supertrend_recovery_into_matches_single() -> Result<(), Box<dyn StdError>> {
1717        let (high, low, close) = trend_data(160);
1718        let input = SuperTrendRecoveryInput::from_slices(
1719            &high,
1720            &low,
1721            &close,
1722            SuperTrendRecoveryParams::default(),
1723        );
1724        let single = supertrend_recovery(&input)?;
1725
1726        let mut band = vec![0.0; close.len()];
1727        let mut switch_price = vec![0.0; close.len()];
1728        let mut trend = vec![0.0; close.len()];
1729        let mut changed = vec![0.0; close.len()];
1730        supertrend_recovery_into_slice(
1731            &mut band,
1732            &mut switch_price,
1733            &mut trend,
1734            &mut changed,
1735            &input,
1736            Kernel::Auto,
1737        )?;
1738
1739        assert!(arrays_eq_nan(&single.band, &band));
1740        assert!(arrays_eq_nan(&single.switch_price, &switch_price));
1741        assert!(arrays_eq_nan(&single.trend, &trend));
1742        assert!(arrays_eq_nan(&single.changed, &changed));
1743        Ok(())
1744    }
1745
1746    #[test]
1747    fn supertrend_recovery_stream_matches_batch() -> Result<(), Box<dyn StdError>> {
1748        let (high, low, close) = trend_data(170);
1749        let params = SuperTrendRecoveryParams::default();
1750        let input = SuperTrendRecoveryInput::from_slices(&high, &low, &close, params.clone());
1751        let batch = supertrend_recovery(&input)?;
1752
1753        let mut stream = SuperTrendRecoveryStream::try_new(params)?;
1754        let mut band = Vec::with_capacity(close.len());
1755        let mut switch_price = Vec::with_capacity(close.len());
1756        let mut trend = Vec::with_capacity(close.len());
1757        let mut changed = Vec::with_capacity(close.len());
1758
1759        for i in 0..close.len() {
1760            if let Some((b, s, t, c)) = stream.update(high[i], low[i], close[i]) {
1761                band.push(b);
1762                switch_price.push(s);
1763                trend.push(t);
1764                changed.push(c);
1765            } else {
1766                band.push(f64::NAN);
1767                switch_price.push(f64::NAN);
1768                trend.push(f64::NAN);
1769                changed.push(f64::NAN);
1770            }
1771        }
1772
1773        assert!(arrays_eq_nan(&batch.band, &band));
1774        assert!(arrays_eq_nan(&batch.switch_price, &switch_price));
1775        assert!(arrays_eq_nan(&batch.trend, &trend));
1776        assert!(arrays_eq_nan(&batch.changed, &changed));
1777        Ok(())
1778    }
1779
1780    #[test]
1781    fn supertrend_recovery_reversal_behavior() -> Result<(), Box<dyn StdError>> {
1782        let (high, low, close) = reversal_data();
1783        let output = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
1784            &high,
1785            &low,
1786            &close,
1787            SuperTrendRecoveryParams {
1788                atr_length: Some(4),
1789                multiplier: Some(1.5),
1790                alpha_percent: Some(5.0),
1791                threshold_atr: Some(1.0),
1792            },
1793        ))?;
1794
1795        let changes = output
1796            .changed
1797            .iter()
1798            .enumerate()
1799            .filter_map(|(i, v)| if *v == 1.0 { Some(i) } else { None })
1800            .collect::<Vec<_>>();
1801        assert!(!changes.is_empty());
1802        let first = changes[0];
1803        assert!(output.band[first].is_finite());
1804        assert!(output.switch_price[first].is_finite());
1805        assert!(output.trend[first] == 1.0 || output.trend[first] == -1.0);
1806        Ok(())
1807    }
1808
1809    #[test]
1810    fn supertrend_recovery_recovery_behavior() -> Result<(), Box<dyn StdError>> {
1811        let (high, low, close) = recovery_data();
1812        let recovered = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
1813            &high,
1814            &low,
1815            &close,
1816            SuperTrendRecoveryParams {
1817                atr_length: Some(4),
1818                multiplier: Some(3.0),
1819                alpha_percent: Some(100.0),
1820                threshold_atr: Some(0.0),
1821            },
1822        ))?;
1823        let baseline = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
1824            &high,
1825            &low,
1826            &close,
1827            SuperTrendRecoveryParams {
1828                atr_length: Some(4),
1829                multiplier: Some(3.0),
1830                alpha_percent: Some(0.1),
1831                threshold_atr: Some(1000.0),
1832            },
1833        ))?;
1834
1835        let mut found = false;
1836        for i in 0..close.len() {
1837            if recovered.trend[i] == baseline.trend[i]
1838                && recovered.band[i].is_finite()
1839                && baseline.band[i].is_finite()
1840            {
1841                if recovered.trend[i] == 1.0 && recovered.band[i] > baseline.band[i] {
1842                    found = true;
1843                    break;
1844                }
1845                if recovered.trend[i] == -1.0 && recovered.band[i] < baseline.band[i] {
1846                    found = true;
1847                    break;
1848                }
1849            }
1850        }
1851        assert!(found);
1852        Ok(())
1853    }
1854
/// Replaces bar 120 of all three input series with NaN and checks that `band`, `trend`
/// and `changed` are NaN from that bar for `DEFAULT_ATR_LENGTH` bars (clamped to the
/// output length) when run with default parameters.
#[test]
fn supertrend_recovery_nan_gap_restarts() -> Result<(), Box<dyn StdError>> {
    // Index of the single bar whose high/low/close are all set to NaN.
    const GAP_INDEX: usize = 120;

    let (mut high, mut low, mut close) = trend_data(170);
    for series in [&mut high, &mut low, &mut close] {
        series[GAP_INDEX] = f64::NAN;
    }

    let input = SuperTrendRecoveryInput::from_slices(
        &high,
        &low,
        &close,
        SuperTrendRecoveryParams::default(),
    );
    let output = supertrend_recovery(&input)?;

    // Never look past the end of the output, even if the window would overrun it.
    let window_end = output.band.len().min(GAP_INDEX + DEFAULT_ATR_LENGTH);
    for idx in GAP_INDEX..window_end {
        assert!(output.band[idx].is_nan());
        assert!(output.trend[idx].is_nan());
        assert!(output.changed[idx].is_nan());
    }
    Ok(())
}
1876
/// Runs a batch sweep with `Kernel::ScalarBatch` and checks that every row of the
/// flattened batch output equals (NaN-aware, via `arrays_eq_nan`) the single-run output
/// computed with that row's parameter combination.
#[test]
fn supertrend_recovery_batch_matches_single() -> Result<(), Box<dyn StdError>> {
    let (high, low, close) = trend_data(170);

    // The row-count assertion below expects this sweep to expand to 16 combinations.
    let sweep = SuperTrendRecoveryBatchRange {
        atr_length: (4, 5, 1),
        multiplier: (1.5, 2.0, 0.5),
        alpha_percent: (5.0, 10.0, 5.0),
        threshold_atr: (0.5, 1.0, 0.5),
    };
    let batch =
        supertrend_recovery_batch_with_kernel(&high, &low, &close, &sweep, Kernel::ScalarBatch)?;

    assert_eq!(batch.rows, 16);
    assert_eq!(batch.cols, close.len());

    for row in 0..batch.rows {
        let params = batch.combos[row].clone();
        let single = supertrend_recovery(&SuperTrendRecoveryInput::from_slices(
            &high, &low, &close, params,
        ))?;

        // Each row occupies `cols` consecutive entries of the flattened batch buffers.
        let first = row * batch.cols;
        let span = first..first + batch.cols;

        assert!(arrays_eq_nan(&batch.band[span.clone()], single.band.as_slice()));
        assert!(arrays_eq_nan(
            &batch.switch_price[span.clone()],
            single.switch_price.as_slice()
        ));
        assert!(arrays_eq_nan(&batch.trend[span.clone()], single.trend.as_slice()));
        assert!(arrays_eq_nan(&batch.changed[span], single.changed.as_slice()));
    }
    Ok(())
}
1925
/// Expects `SuperTrendRecoveryError::InvalidAlphaPercent` when `alpha_percent` is `0.0`
/// (a value below `MIN_ALPHA_PERCENT`) while the remaining parameters are set to 10, 3.0
/// and 1.0.
#[test]
fn supertrend_recovery_invalid_alpha_errors() {
    let (high, low, close) = trend_data(160);

    let params = SuperTrendRecoveryParams {
        atr_length: Some(10),
        multiplier: Some(3.0),
        alpha_percent: Some(0.0),
        threshold_atr: Some(1.0),
    };
    let input = SuperTrendRecoveryInput::from_slices(&high, &low, &close, params);

    let result = supertrend_recovery(&input);
    assert!(matches!(
        result,
        Err(SuperTrendRecoveryError::InvalidAlphaPercent { .. })
    ));
}
1945
/// Expects `SuperTrendRecoveryError::AllValuesNaN` when high, low and close each consist
/// of 160 NaN values and default parameters are used.
#[test]
fn supertrend_recovery_all_nan_errors() {
    // High, low and close have identical contents, so one buffer serves all three.
    let nan_series = vec![f64::NAN; 160];
    let input = SuperTrendRecoveryInput::from_slices(
        &nan_series,
        &nan_series,
        &nan_series,
        SuperTrendRecoveryParams::default(),
    );

    let outcome = supertrend_recovery(&input);
    assert!(matches!(
        outcome,
        Err(SuperTrendRecoveryError::AllValuesNaN)
    ));
}
1962
/// Smoke test: loads candles from the bundled CSV, runs the indicator through
/// `with_default_candles`, and checks that each of the four output series has one entry
/// per close value.
#[test]
fn supertrend_recovery_default_candles_smoke() -> Result<(), Box<dyn StdError>> {
    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let input = SuperTrendRecoveryInput::with_default_candles(&candles);
    let output = supertrend_recovery(&input)?;

    let expected_len = candles.close.len();
    let series_lens = [
        output.band.len(),
        output.switch_price.len(),
        output.trend.len(),
        output.changed.len(),
    ];
    for len in series_lens {
        assert_eq!(len, expected_len);
    }
    Ok(())
}
1973}