Skip to main content

vector_ta/indicators/
supertrend.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::{cuda_available, CudaSupertrend};
5use crate::indicators::atr::{atr, AtrData, AtrError, AtrInput, AtrOutput, AtrParams};
6use crate::utilities::data_loader::{source_type, Candles};
7#[cfg(all(feature = "python", feature = "cuda"))]
8use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
9use crate::utilities::enums::Kernel;
10use crate::utilities::helpers::{
11    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
12    make_uninit_matrix,
13};
14#[cfg(feature = "python")]
15use crate::utilities::kernel_validation::validate_kernel;
16use aligned_vec::{AVec, CACHELINE_ALIGN};
17#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
18use core::arch::x86_64::*;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use cust::context::Context;
21#[cfg(feature = "python")]
22use pyo3::exceptions::{PyBufferError, PyValueError};
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::convert::AsRef;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use std::sync::Arc;
33use thiserror::Error;
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use wasm_bindgen::prelude::*;
36
/// Price source for the SuperTrend indicator.
///
/// Either borrows a candle set (from which the `"high"`, `"low"` and
/// `"close"` series are selected with `source_type`) or three
/// caller-supplied slices.
#[derive(Debug, Clone)]
pub enum SuperTrendData<'a> {
    /// Read high/low/close out of a candle set.
    Candles {
        candles: &'a Candles,
    },
    /// Use explicit high, low and close series.
    ///
    /// NOTE(review): the three slices are expected to have the same length;
    /// nothing checks this when the variant is constructed.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
48
/// Tunable SuperTrend parameters; `None` means "use the default".
#[derive(Debug, Clone)]
pub struct SuperTrendParams {
    /// ATR lookback length (default 10).
    pub period: Option<usize>,
    /// Multiplier applied to the ATR to place the bands (default 3.0).
    pub factor: Option<f64>,
}

impl Default for SuperTrendParams {
    /// `period = Some(10)`, `factor = Some(3.0)`.
    fn default() -> Self {
        let period = Some(10);
        let factor = Some(3.0);
        SuperTrendParams { period, factor }
    }
}
62
/// Complete input for one SuperTrend run: a price source plus parameters.
#[derive(Debug, Clone)]
pub struct SuperTrendInput<'a> {
    /// Where the high/low/close series come from.
    pub data: SuperTrendData<'a>,
    /// Period/factor overrides; unset values fall back to 10 / 3.0 in the getters.
    pub params: SuperTrendParams,
}
68
69impl<'a> SuperTrendInput<'a> {
70    #[inline]
71    pub fn from_candles(candles: &'a Candles, params: SuperTrendParams) -> Self {
72        Self {
73            data: SuperTrendData::Candles { candles },
74            params,
75        }
76    }
77    #[inline]
78    pub fn from_slices(
79        high: &'a [f64],
80        low: &'a [f64],
81        close: &'a [f64],
82        params: SuperTrendParams,
83    ) -> Self {
84        Self {
85            data: SuperTrendData::Slices { high, low, close },
86            params,
87        }
88    }
89    #[inline]
90    pub fn with_default_candles(candles: &'a Candles) -> Self {
91        Self {
92            data: SuperTrendData::Candles { candles },
93            params: SuperTrendParams::default(),
94        }
95    }
96    #[inline]
97    pub fn get_period(&self) -> usize {
98        self.params.period.unwrap_or(10)
99    }
100    #[inline]
101    pub fn get_factor(&self) -> f64 {
102        self.params.factor.unwrap_or(3.0)
103    }
104    #[inline(always)]
105    fn as_hlc(&self) -> (&[f64], &[f64], &[f64]) {
106        match &self.data {
107            SuperTrendData::Candles { candles } => (
108                source_type(candles, "high"),
109                source_type(candles, "low"),
110                source_type(candles, "close"),
111            ),
112            SuperTrendData::Slices { high, low, close } => (*high, *low, *close),
113        }
114    }
115}
116
/// Result of a single SuperTrend run; both vectors match the input length.
#[derive(Debug, Clone)]
pub struct SuperTrendOutput {
    /// SuperTrend line: the active (upper or lower) band per bar; NaN during warmup.
    pub trend: Vec<f64>,
    /// 1.0 on bars where the trend switched bands, 0.0 otherwise; NaN during warmup.
    pub changed: Vec<f64>,
}
122
123#[derive(Copy, Clone, Debug)]
124pub struct SuperTrendBuilder {
125    period: Option<usize>,
126    factor: Option<f64>,
127    kernel: Kernel,
128}
129impl Default for SuperTrendBuilder {
130    fn default() -> Self {
131        Self {
132            period: None,
133            factor: None,
134            kernel: Kernel::Auto,
135        }
136    }
137}
138impl SuperTrendBuilder {
139    #[inline]
140    pub fn new() -> Self {
141        Self::default()
142    }
143    #[inline]
144    pub fn period(mut self, n: usize) -> Self {
145        self.period = Some(n);
146        self
147    }
148    #[inline]
149    pub fn factor(mut self, x: f64) -> Self {
150        self.factor = Some(x);
151        self
152    }
153    #[inline]
154    pub fn kernel(mut self, k: Kernel) -> Self {
155        self.kernel = k;
156        self
157    }
158    #[inline]
159    pub fn apply(self, c: &Candles) -> Result<SuperTrendOutput, SuperTrendError> {
160        let p = SuperTrendParams {
161            period: self.period,
162            factor: self.factor,
163        };
164        let i = SuperTrendInput::from_candles(c, p);
165        supertrend_with_kernel(&i, self.kernel)
166    }
167    #[inline]
168    pub fn apply_slices(
169        self,
170        high: &[f64],
171        low: &[f64],
172        close: &[f64],
173    ) -> Result<SuperTrendOutput, SuperTrendError> {
174        let p = SuperTrendParams {
175            period: self.period,
176            factor: self.factor,
177        };
178        let i = SuperTrendInput::from_slices(high, low, close, p);
179        supertrend_with_kernel(&i, self.kernel)
180    }
181    #[inline]
182    pub fn into_stream(self) -> Result<SuperTrendStream, SuperTrendError> {
183        let p = SuperTrendParams {
184            period: self.period,
185            factor: self.factor,
186        };
187        SuperTrendStream::try_new(p)
188    }
189}
190
/// Errors produced by the SuperTrend indicator.
#[derive(Debug, Error)]
pub enum SuperTrendError {
    /// One of the high/low/close inputs is empty.
    #[error("supertrend: Empty data provided.")]
    EmptyInputData,
    /// No bar has all of high/low/close present.
    #[error("supertrend: All values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or larger than the input length.
    #[error("supertrend: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `period` bars remain after the first valid bar.
    #[error("supertrend: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer does not match the input length.
    #[error("supertrend: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// NOTE(review): presumably raised when a batch period sweep
    /// `(start, end, step)` is invalid; not raised in the code shown here.
    #[error("supertrend: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// NOTE(review): presumably raised when a batch factor sweep is invalid.
    #[error("supertrend: Invalid factor range: start={start}, end={end}, step={step}")]
    InvalidFactorRange { start: f64, end: f64, step: f64 },
    /// NOTE(review): presumably raised when a non-batch kernel reaches a batch entry point.
    #[error("supertrend: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Error forwarded unchanged from the underlying ATR computation.
    #[error(transparent)]
    AtrError(#[from] AtrError),
}
216
217#[inline]
218pub fn supertrend(input: &SuperTrendInput) -> Result<SuperTrendOutput, SuperTrendError> {
219    supertrend_with_kernel(input, Kernel::Auto)
220}
221
222#[inline(always)]
223fn supertrend_prepare<'a>(
224    input: &'a SuperTrendInput,
225    kernel: Kernel,
226) -> Result<
227    (
228        &'a [f64],
229        &'a [f64],
230        &'a [f64],
231        usize,
232        f64,
233        usize,
234        Vec<f64>,
235        Kernel,
236    ),
237    SuperTrendError,
238> {
239    let (high, low, close) = input.as_hlc();
240
241    if high.is_empty() || low.is_empty() || close.is_empty() {
242        return Err(SuperTrendError::EmptyInputData);
243    }
244
245    let period = input.get_period();
246    if period == 0 || period > high.len() {
247        return Err(SuperTrendError::InvalidPeriod {
248            period,
249            data_len: high.len(),
250        });
251    }
252
253    let factor = input.get_factor();
254    let len = high.len();
255
256    let mut first_valid_idx = None;
257    for i in 0..len {
258        if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
259            first_valid_idx = Some(i);
260            break;
261        }
262    }
263
264    let first_valid_idx = match first_valid_idx {
265        Some(idx) => idx,
266        None => return Err(SuperTrendError::AllValuesNaN),
267    };
268
269    if (len - first_valid_idx) < period {
270        return Err(SuperTrendError::NotEnoughValidData {
271            needed: period,
272            valid: len - first_valid_idx,
273        });
274    }
275
276    let atr_input = AtrInput::from_slices(
277        &high[first_valid_idx..],
278        &low[first_valid_idx..],
279        &close[first_valid_idx..],
280        AtrParams {
281            length: Some(period),
282        },
283    );
284    let AtrOutput { values: atr_values } = atr(&atr_input)?;
285
286    let chosen = match kernel {
287        Kernel::Auto => Kernel::Scalar,
288        other => other,
289    };
290
291    Ok((
292        high,
293        low,
294        close,
295        period,
296        factor,
297        first_valid_idx,
298        atr_values,
299        chosen,
300    ))
301}
302
303#[inline(always)]
304fn supertrend_compute_into(
305    high: &[f64],
306    low: &[f64],
307    close: &[f64],
308    period: usize,
309    factor: f64,
310    first_valid_idx: usize,
311    atr_values: &[f64],
312    kernel: Kernel,
313    trend_out: &mut [f64],
314    changed_out: &mut [f64],
315) {
316    unsafe {
317        match kernel {
318            Kernel::Scalar | Kernel::ScalarBatch => {
319                supertrend_scalar(
320                    high,
321                    low,
322                    close,
323                    period,
324                    factor,
325                    first_valid_idx,
326                    &atr_values,
327                    trend_out,
328                    changed_out,
329                );
330            }
331            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
332            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
333                supertrend_scalar(
334                    high,
335                    low,
336                    close,
337                    period,
338                    factor,
339                    first_valid_idx,
340                    &atr_values,
341                    trend_out,
342                    changed_out,
343                );
344            }
345            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
346            Kernel::Avx2 | Kernel::Avx2Batch => {
347                supertrend_avx2(
348                    high,
349                    low,
350                    close,
351                    period,
352                    factor,
353                    first_valid_idx,
354                    &atr_values,
355                    trend_out,
356                    changed_out,
357                );
358            }
359            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
360            Kernel::Avx512 | Kernel::Avx512Batch => {
361                supertrend_avx512(
362                    high,
363                    low,
364                    close,
365                    period,
366                    factor,
367                    first_valid_idx,
368                    &atr_values,
369                    trend_out,
370                    changed_out,
371                );
372            }
373            _ => unreachable!(),
374        }
375    }
376}
377
378pub fn supertrend_with_kernel(
379    input: &SuperTrendInput,
380    kernel: Kernel,
381) -> Result<SuperTrendOutput, SuperTrendError> {
382    let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
383        supertrend_prepare(input, kernel)?;
384
385    let len = high.len();
386    let mut trend = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
387    let mut changed = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
388
389    supertrend_compute_into(
390        high,
391        low,
392        close,
393        period,
394        factor,
395        first_valid_idx,
396        &atr_values,
397        chosen,
398        &mut trend,
399        &mut changed,
400    );
401
402    Ok(SuperTrendOutput { trend, changed })
403}
404
405#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
406#[inline]
407pub fn supertrend_into(
408    input: &SuperTrendInput,
409    trend_out: &mut [f64],
410    changed_out: &mut [f64],
411) -> Result<(), SuperTrendError> {
412    let (high, _low, _close) = input.as_hlc();
413    let len = high.len();
414
415    if trend_out.len() != len {
416        return Err(SuperTrendError::OutputLengthMismatch {
417            expected: len,
418            got: trend_out.len(),
419        });
420    }
421    if changed_out.len() != len {
422        return Err(SuperTrendError::OutputLengthMismatch {
423            expected: len,
424            got: changed_out.len(),
425        });
426    }
427
428    let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
429        supertrend_prepare(input, Kernel::Auto)?;
430
431    let warmup_end = first_valid_idx + period - 1;
432    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
433    for v in &mut trend_out[..warmup_end.min(len)] {
434        *v = qnan;
435    }
436    for v in &mut changed_out[..warmup_end.min(len)] {
437        *v = qnan;
438    }
439
440    supertrend_compute_into(
441        high,
442        low,
443        close,
444        period,
445        factor,
446        first_valid_idx,
447        &atr_values,
448        chosen,
449        trend_out,
450        changed_out,
451    );
452
453    Ok(())
454}
455
/// Scalar SuperTrend recurrence over precomputed ATR values.
///
/// Writes `trend[i]` / `changed[i]` for every `i` in
/// `first_valid_idx + period - 1 .. high.len()`. Earlier indices are left
/// untouched so the caller can own the warmup prefix. `atr_values[j]` is the
/// ATR of bar `first_valid_idx + j`, and the first value used is
/// `j = period - 1`.
///
/// At the warmup bar the bands are `hl2 ± factor * atr`. The trend starts on
/// the upper band if the close is at or below it. On later bars a band may
/// only tighten while the previous close stayed on its side of it.
/// `changed` is 1.0 exactly when the trend switches bands.
///
/// Does nothing when `period == 0` or fewer than `period` bars follow
/// `first_valid_idx`.
///
/// # Panics
/// If `low`, `close`, `trend` or `changed` is shorter than `high`, or if
/// `atr_values` is shorter than `high.len() - first_valid_idx`. Previously
/// these cases were unchecked raw-pointer accesses, i.e. undefined behavior
/// reachable from safe code.
#[inline(always)]
pub fn supertrend_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let len = high.len();
    if period == 0 {
        return;
    }
    let start = match first_valid_idx.checked_add(period) {
        Some(start) if start <= len => start,
        _ => return,
    };
    let warmup = start - 1;

    // Reslice every buffer to exactly the range the recurrence touches. This
    // panics up front on undersized buffers and lets the compiler drop the
    // per-iteration bounds checks below.
    let low = &low[..len];
    let close = &close[..len];
    let atr = &atr_values[..len - first_valid_idx];
    let trend = &mut trend[..len];
    let changed = &mut changed[..len];

    let hl2_w = (high[warmup] + low[warmup]) * 0.5;
    let atr_w = atr[period - 1];
    let mut prev_upper_band = hl2_w + factor * atr_w;
    let mut prev_lower_band = hl2_w - factor * atr_w;

    let mut last_close = close[warmup];
    let mut upper_state = last_close <= prev_upper_band;
    trend[warmup] = if upper_state {
        prev_upper_band
    } else {
        prev_lower_band
    };
    changed[warmup] = 0.0;

    let neg_factor = -factor;
    for i in start..len {
        let atr_i = atr[i - first_valid_idx];
        let hl2 = (high[i] + low[i]) * 0.5;
        let upper_basic = factor.mul_add(atr_i, hl2);
        let lower_basic = neg_factor.mul_add(atr_i, hl2);

        // A band may only tighten while the previous close stayed on its side.
        let prev_close = last_close;
        let mut curr_upper_band = upper_basic;
        if prev_close <= prev_upper_band {
            curr_upper_band = curr_upper_band.min(prev_upper_band);
        }
        let mut curr_lower_band = lower_basic;
        if prev_close >= prev_lower_band {
            curr_lower_band = curr_lower_band.max(prev_lower_band);
        }

        // Keep the exact comparison structure so NaN closes behave as before:
        // a NaN close fails both `<=` and `>=` and therefore flips the trend.
        let curr_close = close[i];
        if upper_state {
            if curr_close <= curr_upper_band {
                trend[i] = curr_upper_band;
                changed[i] = 0.0;
            } else {
                trend[i] = curr_lower_band;
                changed[i] = 1.0;
                upper_state = false;
            }
        } else if curr_close >= curr_lower_band {
            trend[i] = curr_lower_band;
            changed[i] = 0.0;
        } else {
            trend[i] = curr_upper_band;
            changed[i] = 1.0;
            upper_state = true;
        }

        prev_upper_band = curr_upper_band;
        prev_lower_band = curr_lower_band;
        last_close = curr_close;
    }
}
550
/// AVX2 entry point for SuperTrend.
///
/// Each bar's bands depend on the previous bar's bands, so this forwards
/// directly to [`supertrend_scalar`].
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD entry points. It performs
/// no unsafe operations itself and adds no requirements beyond
/// [`supertrend_scalar`]'s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (h, l, c) = (high, low, close);
    supertrend_scalar(h, l, c, period, factor, first_valid_idx, atr_values, trend, changed)
}
576
/// AVX-512 entry point for SuperTrend.
///
/// Routes periods up to 32 to [`supertrend_avx512_short`] and longer periods
/// to [`supertrend_avx512_long`].
///
/// # Safety
/// Inherits the (currently empty) requirements of the two variants.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    match period {
        0..=32 => supertrend_avx512_short(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        ),
        _ => supertrend_avx512_long(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        ),
    }
}
616
/// AVX-512 variant used for periods up to 32; forwards to [`supertrend_scalar`].
///
/// # Safety
/// Performs no unsafe operations itself; no requirements beyond
/// [`supertrend_scalar`]'s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (h, l, c) = (high, low, close);
    supertrend_scalar(h, l, c, period, factor, first_valid_idx, atr_values, trend, changed)
}
642
/// AVX-512 variant used for periods above 32; forwards to [`supertrend_scalar`].
///
/// # Safety
/// Performs no unsafe operations itself; no requirements beyond
/// [`supertrend_scalar`]'s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (h, l, c) = (high, low, close);
    supertrend_scalar(h, l, c, period, factor, first_valid_idx, atr_values, trend, changed)
}
668
669#[inline]
670pub unsafe fn supertrend_scalar_classic(
671    high: &[f64],
672    low: &[f64],
673    close: &[f64],
674    period: usize,
675    factor: f64,
676    trend_out: &mut [f64],
677    changed_out: &mut [f64],
678) -> Result<(), SuperTrendError> {
679    let n = high.len();
680
681    let mut first_valid = None;
682    for i in 0..n {
683        if high[i].is_finite() && low[i].is_finite() && close[i].is_finite() {
684            first_valid = Some(i);
685            break;
686        }
687    }
688
689    let first_valid = first_valid.ok_or(SuperTrendError::AllValuesNaN)?;
690
691    if n - first_valid < period {
692        return Err(SuperTrendError::NotEnoughValidData {
693            needed: period,
694            valid: n - first_valid,
695        });
696    }
697
698    let warmup = first_valid + period - 1;
699    for i in 0..warmup.min(n) {
700        trend_out[i] = f64::NAN;
701        changed_out[i] = f64::NAN;
702    }
703
704    let mut tr_values = vec![0.0; n];
705
706    if first_valid < n {
707        tr_values[first_valid] = high[first_valid] - low[first_valid];
708    }
709
710    for i in (first_valid + 1)..n {
711        let high_low = high[i] - low[i];
712        let high_close = (high[i] - close[i - 1]).abs();
713        let low_close = (low[i] - close[i - 1]).abs();
714        tr_values[i] = high_low.max(high_close).max(low_close);
715    }
716
717    let mut atr_values = vec![f64::NAN; n];
718
719    let mut atr_sum = 0.0;
720    for i in first_valid..(first_valid + period).min(n) {
721        atr_sum += tr_values[i];
722    }
723
724    if first_valid + period <= n {
725        atr_values[first_valid + period - 1] = atr_sum / period as f64;
726
727        let alpha = 1.0 / period as f64;
728        let alpha_1minus = 1.0 - alpha;
729
730        for i in (first_valid + period)..n {
731            atr_values[i] = alpha * tr_values[i] + alpha_1minus * atr_values[i - 1];
732        }
733    }
734
735    if warmup >= n {
736        return Ok(());
737    }
738
739    let half_range = (high[warmup] + low[warmup]) / 2.0;
740    let mut prev_upper_band = factor.mul_add(atr_values[warmup], half_range);
741    let mut prev_lower_band = (-factor).mul_add(atr_values[warmup], half_range);
742
743    let mut last_close = close[warmup];
744    let mut upper_state = if last_close <= prev_upper_band {
745        trend_out[warmup] = prev_upper_band;
746        true
747    } else {
748        trend_out[warmup] = prev_lower_band;
749        false
750    };
751    changed_out[warmup] = 0.0;
752
753    for i in (warmup + 1)..n {
754        let half_range = (high[i] + low[i]) / 2.0;
755        let upper_basic = factor.mul_add(atr_values[i], half_range);
756        let lower_basic = (-factor).mul_add(atr_values[i], half_range);
757
758        let prev_close = last_close;
759        let mut curr_upper_band = upper_basic;
760        let mut curr_lower_band = lower_basic;
761        if prev_close <= prev_upper_band {
762            curr_upper_band = curr_upper_band.min(prev_upper_band);
763        }
764        if prev_close >= prev_lower_band {
765            curr_lower_band = curr_lower_band.max(prev_lower_band);
766        }
767
768        let curr_close = close[i];
769        if upper_state {
770            if curr_close <= curr_upper_band {
771                trend_out[i] = curr_upper_band;
772                changed_out[i] = 0.0;
773            } else {
774                trend_out[i] = curr_lower_band;
775                changed_out[i] = 1.0;
776                upper_state = false;
777            }
778        } else {
779            if curr_close >= curr_lower_band {
780                trend_out[i] = curr_lower_band;
781                changed_out[i] = 0.0;
782            } else {
783                trend_out[i] = curr_upper_band;
784                changed_out[i] = 1.0;
785                upper_state = true;
786            }
787        }
788
789        prev_upper_band = curr_upper_band;
790        prev_lower_band = curr_lower_band;
791        last_close = curr_close;
792    }
793
794    Ok(())
795}
796
/// Incremental (bar-by-bar) SuperTrend calculator.
#[derive(Debug, Clone)]
pub struct SuperTrendStream {
    /// ATR period.
    pub period: usize,
    /// Multiplier applied to the ATR to place the bands.
    pub factor: f64,
    /// Streaming ATR that supplies the band width for each bar.
    atr_stream: crate::indicators::atr::AtrStream,

    /// Final upper band of the previous emitted bar (NaN before the first output).
    prev_upper_band: f64,
    /// Final lower band of the previous emitted bar (NaN before the first output).
    prev_lower_band: f64,
    /// Close of the previous bar that produced an output.
    prev_close: f64,
    /// `true` while the trend follows the upper band.
    upper_state: bool,
    /// Set once the first SuperTrend value has been emitted.
    warmed: bool,
}
809
810impl SuperTrendStream {
811    #[inline]
812    pub fn try_new(params: SuperTrendParams) -> Result<Self, SuperTrendError> {
813        let period = params.period.unwrap_or(10);
814        let factor = params.factor.unwrap_or(3.0);
815        let atr_stream = crate::indicators::atr::AtrStream::try_new(AtrParams {
816            length: Some(period),
817        })?;
818        Ok(Self {
819            period,
820            factor,
821            atr_stream,
822            prev_upper_band: f64::NAN,
823            prev_lower_band: f64::NAN,
824            prev_close: f64::NAN,
825            upper_state: false,
826            warmed: false,
827        })
828    }
829
830    #[inline(always)]
831    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
832        let atr = match self.atr_stream.update(high, low, close) {
833            Some(v) => v,
834            None => return None,
835        };
836
837        let hl2 = (high + low) * 0.5;
838        let upper_basic = self.factor.mul_add(atr, hl2);
839        let lower_basic = (-self.factor).mul_add(atr, hl2);
840
841        if !self.warmed {
842            self.prev_upper_band = upper_basic;
843            self.prev_lower_band = lower_basic;
844            self.upper_state = close <= self.prev_upper_band;
845            let trend = if self.upper_state {
846                self.prev_upper_band
847            } else {
848                self.prev_lower_band
849            };
850            self.prev_close = close;
851            self.warmed = true;
852            return Some((trend, 0.0));
853        }
854
855        let mut curr_upper_band = upper_basic;
856        if self.prev_close <= self.prev_upper_band {
857            curr_upper_band = curr_upper_band.min(self.prev_upper_band);
858        }
859        let mut curr_lower_band = lower_basic;
860        if self.prev_close >= self.prev_lower_band {
861            curr_lower_band = curr_lower_band.max(self.prev_lower_band);
862        }
863
864        let mut changed = 0.0;
865        let trend = if self.upper_state {
866            if close <= curr_upper_band {
867                curr_upper_band
868            } else {
869                changed = 1.0;
870                self.upper_state = false;
871                curr_lower_band
872            }
873        } else {
874            if close >= curr_lower_band {
875                curr_lower_band
876            } else {
877                changed = 1.0;
878                self.upper_state = true;
879                curr_upper_band
880            }
881        };
882
883        self.prev_upper_band = curr_upper_band;
884        self.prev_lower_band = curr_lower_band;
885        self.prev_close = close;
886
887        Some((trend, changed))
888    }
889}
890
/// `(start, end, step)` parameter sweeps for the SuperTrend batch API.
///
/// The `*_static` builder helpers encode a single fixed value as step 0.
/// NOTE(review): whether `end` is inclusive depends on the range expansion
/// code, which is not shown here — confirm.
#[derive(Clone, Debug)]
pub struct SuperTrendBatchRange {
    /// Period sweep `(start, end, step)`.
    pub period: (usize, usize, usize),
    /// Factor sweep `(start, end, step)`.
    pub factor: (f64, f64, f64),
}

impl Default for SuperTrendBatchRange {
    /// Periods `(10, 259, 1)` with a fixed factor `(3.0, 3.0, 0.0)`.
    fn default() -> Self {
        let period = (10, 259, 1);
        let factor = (3.0, 3.0, 0.0);
        SuperTrendBatchRange { period, factor }
    }
}
904
/// Builder for sweeping SuperTrend over a grid of periods and factors.
#[derive(Clone, Debug, Default)]
pub struct SuperTrendBatchBuilder {
    /// Parameter grid to evaluate.
    range: SuperTrendBatchRange,
    /// Kernel requested for the batch run.
    /// NOTE(review): the derived `Default` uses `Kernel::default()`, which may
    /// not be `Kernel::Auto` — confirm.
    kernel: Kernel,
}
910impl SuperTrendBatchBuilder {
911    pub fn new() -> Self {
912        Self::default()
913    }
914    pub fn kernel(mut self, k: Kernel) -> Self {
915        self.kernel = k;
916        self
917    }
918    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
919        self.range.period = (start, end, step);
920        self
921    }
922    pub fn period_static(mut self, p: usize) -> Self {
923        self.range.period = (p, p, 0);
924        self
925    }
926    pub fn factor_range(mut self, start: f64, end: f64, step: f64) -> Self {
927        self.range.factor = (start, end, step);
928        self
929    }
930    pub fn factor_static(mut self, x: f64) -> Self {
931        self.range.factor = (x, x, 0.0);
932        self
933    }
934    pub fn apply_slices(
935        self,
936        high: &[f64],
937        low: &[f64],
938        close: &[f64],
939    ) -> Result<SuperTrendBatchOutput, SuperTrendError> {
940        supertrend_batch_with_kernel(high, low, close, &self.range, self.kernel)
941    }
942    pub fn apply_candles(self, c: &Candles) -> Result<SuperTrendBatchOutput, SuperTrendError> {
943        let high = source_type(c, "high");
944        let low = source_type(c, "low");
945        let close = source_type(c, "close");
946        self.apply_slices(high, low, close)
947    }
948    pub fn with_default_candles(
949        c: &Candles,
950        k: Kernel,
951    ) -> Result<SuperTrendBatchOutput, SuperTrendError> {
952        SuperTrendBatchBuilder::new().kernel(k).apply_candles(c)
953    }
954}
955
956pub struct SuperTrendBatchOutput {
957    pub trend: Vec<f64>,
958    pub changed: Vec<f64>,
959    pub combos: Vec<SuperTrendParams>,
960    pub rows: usize,
961    pub cols: usize,
962}
963impl SuperTrendBatchOutput {
964    pub fn row_for_params(&self, p: &SuperTrendParams) -> Option<usize> {
965        self.combos.iter().position(|c| {
966            c.period.unwrap_or(10) == p.period.unwrap_or(10)
967                && (c.factor.unwrap_or(3.0) - p.factor.unwrap_or(3.0)).abs() < 1e-12
968        })
969    }
970    pub fn trend_for(&self, p: &SuperTrendParams) -> Option<&[f64]> {
971        self.row_for_params(p).map(|row| {
972            let start = row * self.cols;
973            &self.trend[start..start + self.cols]
974        })
975    }
976    pub fn changed_for(&self, p: &SuperTrendParams) -> Option<&[f64]> {
977        self.row_for_params(p).map(|row| {
978            let start = row * self.cols;
979            &self.changed[start..start + self.cols]
980        })
981    }
982}
983
/// Python-visible handle (module `ta_indicators.cuda`) owning a SuperTrend
/// result held in CUDA device memory as 32-bit floats.
///
/// Declared `unsendable`. NOTE(review): per PyO3, this restricts access to
/// the creating thread.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct SupertrendDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// NOTE(review): presumably kept to hold the CUDA context alive as long
    /// as `inner` exists — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
991
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SupertrendDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) dict describing the device buffer as
    /// a C-contiguous 2-D `<f4` array of shape `(rows, cols)`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one full row, then one f32 element.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read-only flag = false)
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`.
    /// NOTE(review): device type 2 is presumably `kDLCUDA` in the DLPack enum.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving the device memory out
    /// of this object.
    ///
    /// Afterwards this handle holds an empty 0x0 placeholder buffer, so a
    /// later export would hand out that placeholder. Raises `BufferError` if
    /// `dl_device` names a different device or `copy` is true. `stream` is
    /// accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        // Refuse exports to a device other than the one holding the buffer.
        // A `dl_device` that is not an `(int, int)` tuple is silently ignored.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyBufferError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyBufferError::new_err(
                            "__dlpack__: requested device does not match producer buffer",
                        ));
                    }
                }
            }
        }
        // NOTE(review): no stream synchronization is performed before export.
        let _ = stream;

        // Only zero-copy export is supported.
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract(py)?;
            if do_copy {
                return Err(PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for supertrend CUDA buffers",
                ));
            }
        }

        // Swap in an empty placeholder so the real buffer can be moved into the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let rows = self.inner.rows;
        let cols = self.inner.cols;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1080
/// Runs the SuperTrend `(period, factor)` sweep on a CUDA device and keeps the
/// results device-resident.
///
/// Inputs are narrowed to `f32` before upload. The returned dict holds:
/// * `"trend"` / `"changed"`: [`SupertrendDeviceArrayF32Py`] handles sharing one
///   CUDA context;
/// * `"periods"` / `"factors"`: NumPy arrays listing the parameters of each
///   combination reported by the CUDA wrapper.
///
/// # Errors
/// `ValueError` when CUDA is unavailable, when `high`, `low` and `close` differ
/// in length, or when the CUDA wrapper reports a failure.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "supertrend_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, period_range, factor_range, device_id=0))]
pub fn supertrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::IntoPyArray;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    // Same boundary check as the many-series entry point: mismatched inputs
    // would otherwise pair bars from different positions (or read past the end
    // of the shorter series) on the device.
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err("length mismatch"));
    }
    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };
    let (trend, changed, combos, ctx_arc, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaSupertrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let h32: Vec<f32> = h.iter().map(|&v| v as f32).collect();
        let l32: Vec<f32> = l.iter().map(|&v| v as f32).collect();
        let c32: Vec<f32> = c.iter().map(|&v| v as f32).collect();
        let (trend, changed, combos) = cuda
            .supertrend_batch_dev(&h32, &l32, &c32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx_arc = cuda.context_arc();
        let dev_id = cuda.device_id();
        Ok((trend, changed, combos, ctx_arc, dev_id))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "trend",
        Py::new(
            py,
            SupertrendDeviceArrayF32Py {
                inner: trend,
                _ctx: ctx_arc.clone(),
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "changed",
        Py::new(
            py,
            SupertrendDeviceArrayF32Py {
                inner: changed,
                _ctx: ctx_arc,
                device_id: dev_id,
            },
        )?,
    )?;
    let periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap()).collect();
    let factors: Vec<f64> = combos.iter().map(|p| p.factor.unwrap()).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("factors", factors.into_pyarray(py))?;
    Ok(dict)
}
1147
/// Runs SuperTrend with one `(period, factor)` over many series stored
/// time-major (`cols` series × `rows` bars) on a CUDA device.
///
/// Inputs are narrowed to `f32` before upload. The returned dict holds the
/// device handles `"trend"` and `"changed"` plus the echoed `"cols"`/`"rows"`.
/// NOTE(review): `"trend"`/`"changed"` are taken from the wrapper output's
/// `plus`/`minus` fields — confirm that mapping against `CudaSupertrend`.
///
/// # Errors
/// `ValueError` when CUDA is unavailable, when the three inputs differ in
/// length, or when the CUDA wrapper reports a failure.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "supertrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, cols, rows, period, factor, device_id=0))]
pub fn supertrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm: numpy::PyReadonlyArray1<'py, f64>,
    low_tm: numpy::PyReadonlyArray1<'py, f64>,
    close_tm: numpy::PyReadonlyArray1<'py, f64>,
    cols: usize,
    rows: usize,
    period: usize,
    factor: f64,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let high = high_tm.as_slice()?;
    let low = low_tm.as_slice()?;
    let close = close_tm.as_slice()?;
    if !(high.len() == low.len() && low.len() == close.len()) {
        return Err(PyValueError::new_err("length mismatch"));
    }

    let narrow = |src: &[f64]| -> Vec<f32> { src.iter().map(|&x| x as f32).collect() };
    let high_f32 = narrow(high);
    let low_f32 = narrow(low);
    let close_f32 = narrow(close);

    let (pair, context, device) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaSupertrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let pair = engine
            .supertrend_many_series_one_param_time_major_dev(
                &high_f32,
                &low_f32,
                &close_f32,
                cols,
                rows,
                period,
                factor as f32,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, engine.context_arc(), engine.device_id()))
    })?;

    let wrap = |inner: DeviceArrayF32, ctx: Arc<Context>| SupertrendDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: device,
    };

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("trend", Py::new(py, wrap(pair.plus, context.clone()))?)?;
    dict.set_item("changed", Py::new(py, wrap(pair.minus, context))?)?;
    dict.set_item("cols", cols)?;
    dict.set_item("rows", rows)?;
    Ok(dict)
}
1221
1222#[inline(always)]
1223fn expand_grid(r: &SuperTrendBatchRange) -> Result<Vec<SuperTrendParams>, SuperTrendError> {
1224    fn axis_usize(
1225        (start, end, step): (usize, usize, usize),
1226    ) -> Result<Vec<usize>, SuperTrendError> {
1227        if step == 0 || start == end {
1228            return Ok(vec![start]);
1229        }
1230        if start < end {
1231            let v: Vec<usize> = (start..=end).step_by(step.max(1)).collect();
1232            if v.is_empty() {
1233                return Err(SuperTrendError::InvalidRange { start, end, step });
1234            }
1235            return Ok(v);
1236        }
1237        let mut v = Vec::new();
1238        let mut cur = start;
1239        let st = step.max(1);
1240        while cur >= end {
1241            v.push(cur);
1242            let next = cur.saturating_sub(st);
1243            if next == cur {
1244                break;
1245            }
1246            cur = next;
1247        }
1248        if v.is_empty() {
1249            return Err(SuperTrendError::InvalidRange { start, end, step });
1250        }
1251        Ok(v)
1252    }
1253    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, SuperTrendError> {
1254        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1255            return Ok(vec![start]);
1256        }
1257        let st = step.abs();
1258        if start < end {
1259            let mut v = Vec::new();
1260            let mut x = start;
1261            while x <= end + 1e-12 {
1262                v.push(x);
1263                x += st;
1264            }
1265            if v.is_empty() {
1266                return Err(SuperTrendError::InvalidFactorRange { start, end, step });
1267            }
1268            return Ok(v);
1269        }
1270        let mut v = Vec::new();
1271        let mut x = start;
1272        while x + 1e-12 >= end {
1273            v.push(x);
1274            x -= st;
1275        }
1276        if v.is_empty() {
1277            return Err(SuperTrendError::InvalidFactorRange { start, end, step });
1278        }
1279        Ok(v)
1280    }
1281    let periods = axis_usize(r.period)?;
1282    let factors = axis_f64(r.factor)?;
1283    let cap = periods
1284        .len()
1285        .checked_mul(factors.len())
1286        .ok_or(SuperTrendError::InvalidRange {
1287            start: r.period.0,
1288            end: r.period.1,
1289            step: r.period.2,
1290        })?;
1291    let mut out = Vec::with_capacity(cap);
1292    for &p in &periods {
1293        for &f in &factors {
1294            out.push(SuperTrendParams {
1295                period: Some(p),
1296                factor: Some(f),
1297            });
1298        }
1299    }
1300    Ok(out)
1301}
1302
1303pub fn supertrend_batch_with_kernel(
1304    high: &[f64],
1305    low: &[f64],
1306    close: &[f64],
1307    sweep: &SuperTrendBatchRange,
1308    k: Kernel,
1309) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1310    let kernel = match k {
1311        Kernel::Auto => detect_best_batch_kernel(),
1312        other if other.is_batch() => other,
1313        _ => {
1314            return Err(SuperTrendError::InvalidKernelForBatch(k));
1315        }
1316    };
1317    let simd = match kernel {
1318        Kernel::Avx512Batch => Kernel::Avx512,
1319        Kernel::Avx2Batch => Kernel::Avx2,
1320        Kernel::ScalarBatch => Kernel::Scalar,
1321        _ => unreachable!(),
1322    };
1323    supertrend_batch_par_slice(high, low, close, sweep, simd)
1324}
1325
1326#[inline(always)]
1327pub fn supertrend_batch_slice(
1328    high: &[f64],
1329    low: &[f64],
1330    close: &[f64],
1331    sweep: &SuperTrendBatchRange,
1332    kern: Kernel,
1333) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1334    supertrend_batch_inner(high, low, close, sweep, kern, false)
1335}
1336
1337#[inline(always)]
1338pub fn supertrend_batch_par_slice(
1339    high: &[f64],
1340    low: &[f64],
1341    close: &[f64],
1342    sweep: &SuperTrendBatchRange,
1343    kern: Kernel,
1344) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1345    supertrend_batch_inner(high, low, close, sweep, kern, true)
1346}
1347
1348#[inline(always)]
1349fn supertrend_batch_inner(
1350    high: &[f64],
1351    low: &[f64],
1352    close: &[f64],
1353    sweep: &SuperTrendBatchRange,
1354    kern: Kernel,
1355    parallel: bool,
1356) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1357    let combos = expand_grid(sweep)?;
1358    if combos.is_empty() {
1359        return Err(SuperTrendError::InvalidRange {
1360            start: sweep.period.0,
1361            end: sweep.period.1,
1362            step: sweep.period.2,
1363        });
1364    }
1365    let len = high.len();
1366    let mut first_valid_idx = None;
1367    for i in 0..len {
1368        if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
1369            first_valid_idx = Some(i);
1370            break;
1371        }
1372    }
1373    let first_valid_idx = match first_valid_idx {
1374        Some(idx) => idx,
1375        None => return Err(SuperTrendError::AllValuesNaN),
1376    };
1377    let max_p = combos.iter().map(|c| c.period.unwrap_or(10)).max().unwrap();
1378    if len - first_valid_idx < max_p {
1379        return Err(SuperTrendError::NotEnoughValidData {
1380            needed: max_p,
1381            valid: len - first_valid_idx,
1382        });
1383    }
1384    let rows = combos.len();
1385    let cols = len;
1386
1387    rows.checked_mul(cols)
1388        .ok_or(SuperTrendError::InvalidRange {
1389            start: sweep.period.0,
1390            end: sweep.period.1,
1391            step: sweep.period.2,
1392        })?;
1393
1394    let mut trend_mu = make_uninit_matrix(rows, cols);
1395    let mut changed_mu = make_uninit_matrix(rows, cols);
1396
1397    let warm: Vec<usize> = combos
1398        .iter()
1399        .map(|c| first_valid_idx + c.period.unwrap_or(10) - 1)
1400        .collect();
1401
1402    init_matrix_prefixes(&mut trend_mu, cols, &warm);
1403    init_matrix_prefixes(&mut changed_mu, cols, &warm);
1404
1405    let mut trend_guard = core::mem::ManuallyDrop::new(trend_mu);
1406    let mut changed_guard = core::mem::ManuallyDrop::new(changed_mu);
1407
1408    let trend: &mut [f64] = unsafe {
1409        core::slice::from_raw_parts_mut(trend_guard.as_mut_ptr() as *mut f64, trend_guard.len())
1410    };
1411    let changed: &mut [f64] = unsafe {
1412        core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1413    };
1414
1415    let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1416    {
1417        let mut periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
1418        periods.sort_unstable();
1419        periods.dedup();
1420        for &p in &periods {
1421            let atr_input = AtrInput::from_slices(
1422                &high[first_valid_idx..],
1423                &low[first_valid_idx..],
1424                &close[first_valid_idx..],
1425                AtrParams { length: Some(p) },
1426            );
1427            let AtrOutput { values } = atr(&atr_input)?;
1428            atr_cache.insert(p, values);
1429        }
1430    }
1431
1432    let hl2: Vec<f64> = (0..len).map(|i| 0.5 * (high[i] + low[i])).collect();
1433
1434    let do_row = |row: usize, trend_row: &mut [f64], changed_row: &mut [f64]| unsafe {
1435        let period = combos[row].period.unwrap();
1436        let factor = combos[row].factor.unwrap();
1437        let atr_values = atr_cache.get(&period).unwrap().as_slice();
1438        match kern {
1439            Kernel::Scalar => supertrend_row_scalar_from_hl(
1440                &hl2,
1441                close,
1442                period,
1443                factor,
1444                first_valid_idx,
1445                atr_values,
1446                trend_row,
1447                changed_row,
1448            ),
1449            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1450            Kernel::Avx2 | Kernel::Avx512 => supertrend_row_scalar_from_hl(
1451                &hl2,
1452                close,
1453                period,
1454                factor,
1455                first_valid_idx,
1456                atr_values,
1457                trend_row,
1458                changed_row,
1459            ),
1460            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1461            Kernel::Avx2 => supertrend_row_avx2(
1462                high,
1463                low,
1464                close,
1465                period,
1466                factor,
1467                first_valid_idx,
1468                atr_values,
1469                trend_row,
1470                changed_row,
1471            ),
1472            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1473            Kernel::Avx512 => supertrend_row_avx512(
1474                high,
1475                low,
1476                close,
1477                period,
1478                factor,
1479                first_valid_idx,
1480                atr_values,
1481                trend_row,
1482                changed_row,
1483            ),
1484            _ => unreachable!(),
1485        }
1486    };
1487    if parallel {
1488        #[cfg(not(target_arch = "wasm32"))]
1489        {
1490            trend
1491                .par_chunks_mut(cols)
1492                .zip(changed.par_chunks_mut(cols))
1493                .enumerate()
1494                .for_each(|(row, (tr, ch))| do_row(row, tr, ch));
1495        }
1496
1497        #[cfg(target_arch = "wasm32")]
1498        {
1499            for (row, (tr, ch)) in trend
1500                .chunks_mut(cols)
1501                .zip(changed.chunks_mut(cols))
1502                .enumerate()
1503            {
1504                do_row(row, tr, ch);
1505            }
1506        }
1507    } else {
1508        for (row, (tr, ch)) in trend
1509            .chunks_mut(cols)
1510            .zip(changed.chunks_mut(cols))
1511            .enumerate()
1512        {
1513            do_row(row, tr, ch);
1514        }
1515    }
1516
1517    let trend_vec = unsafe {
1518        Vec::from_raw_parts(
1519            trend_guard.as_mut_ptr() as *mut f64,
1520            trend_guard.len(),
1521            trend_guard.capacity(),
1522        )
1523    };
1524    let changed_vec = unsafe {
1525        Vec::from_raw_parts(
1526            changed_guard.as_mut_ptr() as *mut f64,
1527            changed_guard.len(),
1528            changed_guard.capacity(),
1529        )
1530    };
1531
1532    Ok(SuperTrendBatchOutput {
1533        trend: trend_vec,
1534        changed: changed_vec,
1535        combos,
1536        rows,
1537        cols,
1538    })
1539}
1540
1541#[inline(always)]
1542unsafe fn supertrend_row_scalar(
1543    high: &[f64],
1544    low: &[f64],
1545    close: &[f64],
1546    period: usize,
1547    factor: f64,
1548    first_valid_idx: usize,
1549    atr_values: &[f64],
1550    trend: &mut [f64],
1551    changed: &mut [f64],
1552) {
1553    supertrend_scalar(
1554        high,
1555        low,
1556        close,
1557        period,
1558        factor,
1559        first_valid_idx,
1560        atr_values,
1561        trend,
1562        changed,
1563    );
1564}
1565
/// Computes one SuperTrend row from a precomputed `hl2 = (high + low) / 2` series.
///
/// Let `warmup = first_valid_idx + period - 1`. For every `i >= warmup` it writes:
/// * `trend[i]`: the active band. This is the upper band while in the "upper"
///   state and the lower band otherwise.
/// * `changed[i]`: `1.0` on bars where the state flips, else `0.0`.
///   `changed[warmup]` is always `0.0`.
///
/// Band construction:
/// * The basic bands are `hl2 ± factor * atr`.
/// * The upper band may only move down (`min` with the previous band) while the
///   previous close was at or below it.
/// * The lower band may only move up (`max`) while the previous close was at or
///   above it.
///
/// Cells before `warmup` are left untouched. `atr_values[j]` is the ATR for bar
/// `first_valid_idx + j`.
///
/// Returns without writing anything if `period == 0` (which previously
/// underflowed `period - 1`) or if `first_valid_idx + period > hl2.len()`.
///
/// # Panics
/// If `close`, `trend` or `changed` is shorter than `hl2`, or `atr_values` is
/// shorter than `hl2.len() - first_valid_idx`. These are checked once up front
/// so every raw-pointer access below is in bounds.
///
/// # Safety
/// The length preconditions are enforced by the assertions above; the function
/// stays `unsafe` only for signature compatibility with the other row kernels.
#[inline(always)]
unsafe fn supertrend_row_scalar_from_hl(
    hl2: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let len = hl2.len();
    if period == 0 {
        return;
    }
    let start = first_valid_idx + period;
    if start > len {
        return;
    }

    // Hoisted bounds checks for the pointer arithmetic below.
    assert!(
        close.len() >= len && trend.len() >= len && changed.len() >= len,
        "supertrend row: close/trend/changed shorter than hl2"
    );
    assert!(
        atr_values.len() >= len - first_valid_idx,
        "supertrend row: ATR series shorter than the valid range"
    );

    let hl_ptr = hl2.as_ptr();
    let c_ptr = close.as_ptr();
    let atr_ptr = atr_values.as_ptr();
    let tr_ptr = trend.as_mut_ptr();
    let ch_ptr = changed.as_mut_ptr();

    // SAFETY (all reads/writes below): `warmup < len` because `start <= len`;
    // `period - 1 < len - first_valid_idx`; loop indices satisfy `i < len` and
    // `i - first_valid_idx < len - first_valid_idx`; lengths asserted above.
    let warmup = start - 1;
    let hl2_w = *hl_ptr.add(warmup);
    let atr_w = *atr_ptr.add(period - 1);
    let mut prev_upper_band = factor.mul_add(atr_w, hl2_w);
    let mut prev_lower_band = (-factor).mul_add(atr_w, hl2_w);

    // Initial state: "upper" when the warm-up close is at or below the upper band
    // (a NaN close therefore starts in the lower state).
    let mut last_close = *c_ptr.add(warmup);
    let mut upper_state = last_close <= prev_upper_band;
    *tr_ptr.add(warmup) = if upper_state {
        prev_upper_band
    } else {
        prev_lower_band
    };
    *ch_ptr.add(warmup) = 0.0;

    for i in (warmup + 1)..len {
        let atr_i = *atr_ptr.add(i - first_valid_idx);
        let hl = *hl_ptr.add(i);
        let upper_basic = factor.mul_add(atr_i, hl);
        let lower_basic = (-factor).mul_add(atr_i, hl);

        let prev_close = last_close;
        let mut curr_upper_band = upper_basic;
        if prev_close <= prev_upper_band {
            curr_upper_band = curr_upper_band.min(prev_upper_band);
        }
        let mut curr_lower_band = lower_basic;
        if prev_close >= prev_lower_band {
            curr_lower_band = curr_lower_band.max(prev_lower_band);
        }

        let curr_close = *c_ptr.add(i);
        if upper_state {
            if curr_close <= curr_upper_band {
                *tr_ptr.add(i) = curr_upper_band;
                *ch_ptr.add(i) = 0.0;
            } else {
                // Close broke above the upper band: switch to the lower band.
                *tr_ptr.add(i) = curr_lower_band;
                *ch_ptr.add(i) = 1.0;
                upper_state = false;
            }
        } else if curr_close >= curr_lower_band {
            *tr_ptr.add(i) = curr_lower_band;
            *ch_ptr.add(i) = 0.0;
        } else {
            // Close broke below the lower band: switch to the upper band.
            *tr_ptr.add(i) = curr_upper_band;
            *ch_ptr.add(i) = 1.0;
            upper_state = true;
        }

        prev_upper_band = curr_upper_band;
        prev_lower_band = curr_lower_band;
        last_close = curr_close;
    }
}
1649
/// AVX2 row entry point for the batch drivers.
///
/// No dedicated AVX2 kernel exists yet: the row is computed by forwarding all
/// arguments unchanged to `supertrend_scalar`.
///
/// # Safety
/// Inherits the caller contract of `supertrend_scalar` (defined elsewhere in
/// this file) — TODO confirm the exact slice-length requirements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high, low, close, period, factor, first_valid_idx, atr_values, trend, changed,
    );
}
1675
/// AVX-512 row entry point: dispatches on the period length.
///
/// Periods up to 32 go to the "short" helper; longer periods go to the "long"
/// helper. All arguments are forwarded unchanged.
///
/// # Safety
/// Inherits the caller contracts of the two helpers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    match period {
        0..=32 => supertrend_row_avx512_short(
            high, low, close, period, factor, first_valid_idx, atr_values, trend, changed,
        ),
        _ => supertrend_row_avx512_long(
            high, low, close, period, factor, first_valid_idx, atr_values, trend, changed,
        ),
    }
}
1715
/// AVX-512 row helper for short periods (`period <= 32` at the call site).
///
/// Currently a forwarder to `supertrend_scalar` with unchanged arguments.
///
/// # Safety
/// Inherits the caller contract of `supertrend_scalar` — TODO confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high, low, close, period, factor, first_valid_idx, atr_values, trend, changed,
    );
}
1741
/// AVX-512 row helper for long periods (`period > 32` at the call site).
///
/// Currently a forwarder to `supertrend_scalar` with unchanged arguments.
///
/// # Safety
/// Inherits the caller contract of `supertrend_scalar` — TODO confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high, low, close, period, factor, first_valid_idx, atr_values, trend, changed,
    );
}
1767
/// Batch SuperTrend that writes into caller-provided row-major buffers.
///
/// `trend_out` and `changed_out` must each hold exactly `rows * high.len()`
/// values, where `rows` is the number of combinations produced by
/// [`expand_grid`]. Each row's first `first_valid + period - 1` cells are set to
/// NaN; the rest is filled by the selected row kernel.
///
/// ATR is computed once per distinct period, before any output is written.
/// Failures therefore come back as errors instead of panicking inside a worker
/// thread, and rows that share a period reuse one series instead of
/// recomputing it.
///
/// `simd` may be a single-series kernel, its batch counterpart, or `Auto`; batch
/// kernels and `Auto` are mapped to the matching single-series kernel.
///
/// Returns the parameter combinations in row order.
///
/// # Errors
/// * Range errors from [`expand_grid`].
/// * `NotEnoughValidData`, including when `low`/`close` are shorter than `high`.
/// * `AllValuesNaN`.
/// * `InvalidRange` on size overflow.
/// * `OutputLengthMismatch`.
/// * Any ATR error.
#[cfg(feature = "python")]
#[inline(always)]
pub fn supertrend_batch_inner_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &SuperTrendBatchRange,
    simd: Kernel,
    parallel: bool,
    trend_out: &mut [f64],
    changed_out: &mut [f64],
) -> Result<Vec<SuperTrendParams>, SuperTrendError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(SuperTrendError::InvalidRange {
            start: sweep.period.0,
            end: sweep.period.1,
            step: sweep.period.2,
        });
    }
    let len = high.len();

    // Row kernels read `close` (and hl2 reads `low`) for every index below `len`.
    let shortest = low.len().min(close.len());
    if shortest < len {
        return Err(SuperTrendError::NotEnoughValidData {
            needed: len,
            valid: shortest,
        });
    }

    let first_valid_idx = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
        .ok_or(SuperTrendError::AllValuesNaN)?;
    let max_p = combos.iter().map(|c| c.period.unwrap_or(10)).max().unwrap();
    if len - first_valid_idx < max_p {
        return Err(SuperTrendError::NotEnoughValidData {
            needed: max_p,
            valid: len - first_valid_idx,
        });
    }
    let rows = combos.len();
    let cols = len;

    let expected_len = rows
        .checked_mul(cols)
        .ok_or(SuperTrendError::InvalidRange {
            start: sweep.period.0,
            end: sweep.period.1,
            step: sweep.period.2,
        })?;
    if trend_out.len() != expected_len {
        return Err(SuperTrendError::OutputLengthMismatch {
            expected: expected_len,
            got: trend_out.len(),
        });
    }
    if changed_out.len() != expected_len {
        return Err(SuperTrendError::OutputLengthMismatch {
            expected: expected_len,
            got: changed_out.len(),
        });
    }

    let simd = match simd {
        Kernel::Auto => detect_best_kernel(),
        Kernel::ScalarBatch => Kernel::Scalar,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::Avx512Batch => Kernel::Avx512,
        other => other,
    };

    // One ATR series per distinct period, errors propagated before any write.
    let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
    {
        let mut periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
        periods.sort_unstable();
        periods.dedup();
        for &p in &periods {
            let atr_input = AtrInput::from_slices(
                &high[first_valid_idx..],
                &low[first_valid_idx..],
                &close[first_valid_idx..],
                AtrParams { length: Some(p) },
            );
            let AtrOutput { values } = atr(&atr_input)?;
            atr_cache.insert(p, values);
        }
    }

    // NaN warm-up prefix for every row.
    for (row, combo) in combos.iter().enumerate() {
        let warmup = (first_valid_idx + combo.period.unwrap_or(10)).saturating_sub(1);
        let row_start = row * cols;
        let prefix = row_start..row_start + warmup.min(cols);
        trend_out[prefix.clone()].fill(f64::NAN);
        changed_out[prefix].fill(f64::NAN);
    }

    let hl2: Vec<f64> = (0..len).map(|i| 0.5 * (high[i] + low[i])).collect();

    let do_row = |row: usize, trend_row: &mut [f64], changed_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        let factor = combos[row].factor.unwrap();
        let atr_values = atr_cache
            .get(&period)
            .expect("ATR is cached for every period in the grid")
            .as_slice();
        match simd {
            Kernel::Scalar => supertrend_row_scalar_from_hl(
                &hl2,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => supertrend_row_avx2(
                high,
                low,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => supertrend_row_avx512(
                high,
                low,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => supertrend_row_scalar_from_hl(
                &hl2,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            _ => unreachable!(),
        }
    };
    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            trend_out
                .par_chunks_mut(cols)
                .zip(changed_out.par_chunks_mut(cols))
                .enumerate()
                .for_each(|(row, (tr, ch))| do_row(row, tr, ch));
        }

        #[cfg(target_arch = "wasm32")]
        {
            for (row, (tr, ch)) in trend_out
                .chunks_mut(cols)
                .zip(changed_out.chunks_mut(cols))
                .enumerate()
            {
                do_row(row, tr, ch);
            }
        }
    } else {
        for (row, (tr, ch)) in trend_out
            .chunks_mut(cols)
            .zip(changed_out.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, tr, ch);
        }
    }
    Ok(combos)
}
1933
/// Python binding for single-series SuperTrend.
///
/// Validates `kernel` in non-batch mode and runs the computation with the GIL
/// released. Returns `(trend, changed)` as two NumPy arrays.
///
/// # Errors
/// A computation failure is raised as `ValueError` carrying the
/// `SuperTrendError` message. Input-conversion and kernel-validation errors
/// propagate unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend")]
#[pyo3(signature = (high, low, close, period, factor, kernel=None))]
pub fn supertrend_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    factor: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = SuperTrendInput::from_slices(
        h,
        l,
        c,
        SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        },
    );

    let run = || supertrend_with_kernel(&input, kern).map(|out| (out.trend, out.changed));
    let (trend, changed) = py
        .allow_threads(run)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    let trend_arr = trend.into_pyarray(py);
    let changed_arr = changed.into_pyarray(py);
    Ok((trend_arr, changed_arr))
}
1968
// Python entry point `supertrend_batch`: evaluates SuperTrend for every (period, factor)
// pair produced by expanding `period_range` / `factor_range`, each given as
// `(start, end, step)`.
//
// Returns a dict with:
//   * `trend`, `changed` - `(rows, cols)` arrays, one row per parameter combination and
//     one column per input bar;
//   * `periods`, `factors` - the parameters of each row;
//   * `rows`, `cols` - the matrix dimensions.
//
// Raises `ValueError` when the sweep expands to no combinations, when `rows * cols`
// overflows, or when the batch computation fails. Buffer and kernel validation errors
// propagate unchanged via `?`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_batch")]
#[pyo3(signature = (high, low, close, period_range, factor_range, kernel=None))]
pub fn supertrend_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    // Borrow the NumPy inputs as slices (tuple elements are evaluated left to right).
    let (hs, ls, cs) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let requested = validate_kernel(kernel, true)?;

    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };
    let expanded = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if expanded.is_empty() {
        let (start, end, step) = sweep.period;
        return Err(PyValueError::new_err(format!(
            "supertrend: Invalid range: start={}, end={}, step={}",
            start, end, step
        )));
    }

    let n_rows = expanded.len();
    let n_cols = hs.len();
    let n_cells = n_rows
        .checked_mul(n_cols)
        .ok_or_else(|| PyValueError::new_err("supertrend: rows*cols overflow"))?;

    // Both result arrays are created uninitialised and handed to
    // `supertrend_batch_inner_into` as flat row-major buffers.
    // NOTE(review): this relies on the inner routine writing every cell - confirm.
    let trend_arr = unsafe { PyArray1::<f64>::new(py, [n_cells], false) };
    let trend_buf = unsafe { trend_arr.as_slice_mut()? };
    let changed_arr = unsafe { PyArray1::<f64>::new(py, [n_cells], false) };
    let changed_buf = unsafe { changed_arr.as_slice_mut()? };

    // The numeric work runs with the GIL released.
    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Translate the batch-level kernel into the per-row kernel the inner loop uses.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            supertrend_batch_inner_into(
                hs,
                ls,
                cs,
                &sweep,
                row_kernel,
                true,
                trend_buf,
                changed_buf,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let out = PyDict::new(py);
    out.set_item("trend", trend_arr.reshape((n_rows, n_cols))?)?;
    out.set_item("changed", changed_arr.reshape((n_rows, n_cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    out.set_item("periods", periods.into_pyarray(py))?;
    let factors: Vec<f64> = combos.iter().map(|c| c.factor.unwrap()).collect();
    out.set_item("factors", factors.into_pyarray(py))?;
    out.set_item("rows", n_rows)?;
    out.set_item("cols", n_cols)?;

    Ok(out)
}
2061
// Python-facing wrapper around the native incremental `SuperTrendStream`. It is exported
// to Python under the class name `SuperTrendStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendStream")]
pub struct SuperTrendStreamPy {
    // Native streaming state that every Python call delegates to.
    stream: SuperTrendStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendStreamPy {
    // Constructor: builds the native stream from `period` and `factor`. Any error from
    // `SuperTrendStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize, factor: f64) -> PyResult<Self> {
        SuperTrendStream::try_new(SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar and returns whatever the native `update` yields: `None` while no
    // value is available yet, otherwise the `(trend, changed)` pair.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
2086
/// Computes SuperTrend for `input` directly into caller-provided buffers.
///
/// Both destinations must have exactly the input series length. The first
/// `first_valid_idx + period - 1` entries of each are set to NaN. The buffers are then
/// passed to `supertrend_compute_into` for the remaining values.
///
/// # Errors
/// Propagates any error from `supertrend_prepare`. Returns
/// `SuperTrendError::OutputLengthMismatch` when a destination has the wrong length;
/// `trend_dst` is checked before `changed_dst`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[inline]
pub fn supertrend_into_slice(
    trend_dst: &mut [f64],
    changed_dst: &mut [f64],
    input: &SuperTrendInput,
    kern: Kernel,
) -> Result<(), SuperTrendError> {
    let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
        supertrend_prepare(input, kern)?;
    let expected = high.len();

    for got in [trend_dst.len(), changed_dst.len()] {
        if got != expected {
            return Err(SuperTrendError::OutputLengthMismatch { expected, got });
        }
    }

    // Warmup prefix: no SuperTrend value is produced for these bars.
    let warmup = first_valid_idx + period - 1;
    trend_dst[..warmup].fill(f64::NAN);
    changed_dst[..warmup].fill(f64::NAN);

    supertrend_compute_into(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        &atr_values,
        chosen,
        trend_dst,
        changed_dst,
    );
    Ok(())
}
2135
/// Serialized result returned by the `supertrend` wasm binding (`supertrend_js`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendJsResult {
    /// Flat row-major data: the first `cols` entries are the trend series and the next
    /// `cols` entries are the changed series.
    pub values: Vec<f64>,
    /// Number of logical rows in `values` (`supertrend_js` always sets 2).
    pub rows: usize,
    /// Length of each row, which is the input series length.
    pub cols: usize,
}
2143
/// wasm binding `supertrend`: computes SuperTrend for a single `(period, factor)` pair.
///
/// The returned object carries a flat `values` array of length `2 * len`: the trend
/// series followed by the changed series. It also sets `rows = 2` and `cols = len`,
/// where `len` is `high.len()`.
///
/// # Errors
/// Returns the computation error text, or a `Serialization error: ...` message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend)]
pub fn supertrend_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    let input = SuperTrendInput::from_slices(
        high,
        low,
        close,
        SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        },
    );

    // A single allocation holds both output rows back to back.
    let mut values = vec![0.0; n * 2];
    {
        let (trend_row, changed_row) = values.split_at_mut(n);
        supertrend_into_slice(trend_row, changed_row, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let result = SuperTrendJsResult {
        values,
        rows: 2,
        cols: n,
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2173
/// wasm pointer API: computes SuperTrend from raw input buffers into raw output buffers.
///
/// Every pointer must reference at least `len` `f64`s in wasm linear memory, for example
/// buffers obtained from `supertrend_alloc`.
///
/// Inputs and outputs may share memory. If either output range overlaps any input range,
/// or the two output ranges overlap each other, the result is computed into temporary
/// vectors first and then copied out one output at a time. This guarantees two things:
/// inputs are never read after being overwritten, and two live `&mut` slices never cover
/// the same memory. When both outputs overlap, the changed series is written last.
///
/// # Errors
/// Returns `"Null pointer provided"` if any pointer is null, or the computation's error
/// text if SuperTrend fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    period: usize,
    factor: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || trend_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte extent of each buffer. Saturating arithmetic keeps the overlap test from
    // wrapping. Comparing whole ranges, rather than only start addresses, also catches
    // partially overlapping buffers.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps =
        |a: usize, b: usize| a < b.saturating_add(span) && b < a.saturating_add(span);

    let trend_addr = trend_ptr as usize;
    let changed_addr = changed_ptr as usize;
    let needs_staging = overlaps(trend_addr, changed_addr)
        || [high_ptr as usize, low_ptr as usize, close_ptr as usize]
            .iter()
            .any(|&inp| overlaps(inp, trend_addr) || overlaps(inp, changed_addr));

    // SAFETY: the pointers are non-null (checked above), and the caller guarantees each
    // one references `len` initialised f64s.
    let (high, low, close) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let params = SuperTrendParams {
        period: Some(period),
        factor: Some(factor),
    };
    let input = SuperTrendInput::from_slices(high, low, close, params);

    if needs_staging {
        let output = supertrend_with_kernel(&input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: non-null, `len` f64s per the caller contract. Each mutable slice is
        // created, written and dropped before the next one exists, so even overlapping
        // outputs never produce two live `&mut` aliases. The input slices are not used
        // after this point.
        unsafe { std::slice::from_raw_parts_mut(trend_ptr, len) }
            .copy_from_slice(&output.trend);
        unsafe { std::slice::from_raw_parts_mut(changed_ptr, len) }
            .copy_from_slice(&output.changed);
    } else {
        // SAFETY: non-null, `len` f64s per the caller contract. The overlap test above
        // proved the two output ranges are disjoint from each other and from every input.
        let (trend_out, changed_out) = unsafe {
            (
                std::slice::from_raw_parts_mut(trend_ptr, len),
                std::slice::from_raw_parts_mut(changed_ptr, len),
            )
        };
        supertrend_into_slice(trend_out, changed_out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2237
/// Allocates an uninitialised buffer with room for `len` `f64`s and hands ownership to
/// the caller.
///
/// Pairs with `supertrend_free(ptr, len)`, which must be called with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2246
/// Releases a buffer obtained from `supertrend_alloc(len)`. A null pointer is ignored.
///
/// `len` must match the value given to `supertrend_alloc`. The buffer is rebuilt as a
/// `Vec` with that length and capacity before being dropped.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `supertrend_alloc(len)`, whose `Vec::with_capacity(len)` has
    // exactly `len` capacity. Rebuilding the Vec and dropping it frees that allocation.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2256
/// Sweep configuration deserialized from JS by the `supertrend_batch` wasm binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendBatchConfig {
    /// Period sweep as `(start, end, step)`; becomes `SuperTrendBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// Factor sweep as `(start, end, step)`; becomes `SuperTrendBatchRange::factor`.
    pub factor_range: (f64, f64, f64),
}
2263
/// Serialized result returned by the `supertrend_batch` wasm binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendBatchJsOutput {
    /// Row-major matrix of `rows` x `cols` values. For each parameter combination, its
    /// trend row comes first, immediately followed by its changed row.
    pub values: Vec<f64>,
    /// Period of each parameter combination (one entry per combination, i.e. `rows / 2`).
    pub periods: Vec<usize>,
    /// Factor of each parameter combination (one entry per combination).
    pub factors: Vec<f64>,
    /// Number of rows in `values`: twice the number of combinations.
    pub rows: usize,
    /// Length of each row, which is the input series length.
    pub cols: usize,
}
2273
/// wasm binding `supertrend_batch`: runs SuperTrend over the parameter sweep described by
/// `config`, a serialized `SuperTrendBatchConfig`.
///
/// The result's `values` holds `2 * combos` rows of `cols` entries. Each combination's
/// trend row is immediately followed by its changed row. `periods` and `factors` list one
/// entry per combination; unset values are reported as 10 and 3.0.
///
/// # Errors
/// Returns `Invalid config: ...` for a malformed config, the computation's error text, or
/// `Serialization error: ...`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend_batch)]
pub fn supertrend_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let SuperTrendBatchConfig {
        period_range,
        factor_range,
    } = serde_wasm_bindgen::from_value::<SuperTrendBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };
    let batch = supertrend_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Interleave the two matrices row by row: trend row r, then changed row r.
    let width = batch.cols;
    let mut values = Vec::with_capacity(batch.rows * 2 * width);
    for start in (0..batch.rows).map(|r| r * width) {
        let row = start..start + width;
        values.extend_from_slice(&batch.trend[row.clone()]);
        values.extend_from_slice(&batch.changed[row]);
    }

    let (periods, factors): (Vec<usize>, Vec<f64>) = batch
        .combos
        .iter()
        .map(|c| (c.period.unwrap_or(10), c.factor.unwrap_or(3.0)))
        .unzip();

    let out = SuperTrendBatchJsOutput {
        values,
        periods,
        factors,
        rows: batch.rows * 2,
        cols: width,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2321
2322#[cfg(test)]
2323mod tests {
2324    use super::*;
2325    use crate::skip_if_unsupported;
2326    use crate::utilities::data_loader::read_candles_from_csv;
2327    use crate::utilities::enums::Kernel;
2328
2329    #[test]
2330    fn test_supertrend_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
2331        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2332        let candles = read_candles_from_csv(file_path)?;
2333
2334        let params = SuperTrendParams {
2335            period: Some(10),
2336            factor: Some(3.0),
2337        };
2338        let input = SuperTrendInput::from_candles(&candles, params);
2339
2340        let baseline = supertrend_with_kernel(&input, Kernel::Auto)?;
2341
2342        let n = candles.close.len();
2343        let mut trend_out = vec![0.0; n];
2344        let mut changed_out = vec![0.0; n];
2345
2346        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2347        {
2348            supertrend_into(&input, &mut trend_out, &mut changed_out)?;
2349        }
2350
2351        assert_eq!(baseline.trend.len(), n);
2352        assert_eq!(baseline.changed.len(), n);
2353        assert_eq!(trend_out.len(), n);
2354        assert_eq!(changed_out.len(), n);
2355
2356        #[inline]
2357        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2358            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-9
2359        }
2360
2361        for i in 0..n {
2362            assert!(
2363                eq_or_both_nan(baseline.trend[i], trend_out[i]),
2364                "trend mismatch at {}: baseline={}, into={}",
2365                i,
2366                baseline.trend[i],
2367                trend_out[i]
2368            );
2369            assert!(
2370                eq_or_both_nan(baseline.changed[i], changed_out[i]),
2371                "changed mismatch at {}: baseline={}, into={}",
2372                i,
2373                baseline.changed[i],
2374                changed_out[i]
2375            );
2376        }
2377
2378        Ok(())
2379    }
2380
2381    fn check_supertrend_partial_params(
2382        test_name: &str,
2383        kernel: Kernel,
2384    ) -> Result<(), Box<dyn std::error::Error>> {
2385        skip_if_unsupported!(kernel, test_name);
2386        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2387        let candles = read_candles_from_csv(file_path)?;
2388
2389        let default_params = SuperTrendParams {
2390            period: None,
2391            factor: None,
2392        };
2393        let input = SuperTrendInput::from_candles(&candles, default_params);
2394        let output = supertrend_with_kernel(&input, kernel)?;
2395        assert_eq!(output.trend.len(), candles.close.len());
2396        assert_eq!(output.changed.len(), candles.close.len());
2397
2398        Ok(())
2399    }
2400
2401    fn check_supertrend_accuracy(
2402        test_name: &str,
2403        kernel: Kernel,
2404    ) -> Result<(), Box<dyn std::error::Error>> {
2405        skip_if_unsupported!(kernel, test_name);
2406        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2407        let candles = read_candles_from_csv(file_path)?;
2408
2409        let params = SuperTrendParams {
2410            period: Some(10),
2411            factor: Some(3.0),
2412        };
2413        let input = SuperTrendInput::from_candles(&candles, params);
2414        let st_result = supertrend_with_kernel(&input, kernel)?;
2415
2416        assert_eq!(st_result.trend.len(), candles.close.len());
2417        assert_eq!(st_result.changed.len(), candles.close.len());
2418
2419        let expected_last_five_trend = [
2420            61811.479454208165,
2421            61721.73150878735,
2422            61459.10835790861,
2423            61351.59752211775,
2424            61033.18776990598,
2425        ];
2426        let expected_last_five_changed = [0.0, 0.0, 0.0, 0.0, 0.0];
2427
2428        let start_index = st_result.trend.len() - 5;
2429        let trend_slice = &st_result.trend[start_index..];
2430        let changed_slice = &st_result.changed[start_index..];
2431
2432        for (i, &val) in trend_slice.iter().enumerate() {
2433            let exp = expected_last_five_trend[i];
2434            assert!(
2435                (val - exp).abs() < 1e-4,
2436                "[{}] Trend mismatch at idx {}: got {}, expected {}",
2437                test_name,
2438                i,
2439                val,
2440                exp
2441            );
2442        }
2443        for (i, &val) in changed_slice.iter().enumerate() {
2444            let exp = expected_last_five_changed[i];
2445            assert!(
2446                (val - exp).abs() < 1e-9,
2447                "[{}] Changed mismatch at idx {}: got {}, expected {}",
2448                test_name,
2449                i,
2450                val,
2451                exp
2452            );
2453        }
2454        Ok(())
2455    }
2456
2457    fn check_supertrend_default_candles(
2458        test_name: &str,
2459        kernel: Kernel,
2460    ) -> Result<(), Box<dyn std::error::Error>> {
2461        skip_if_unsupported!(kernel, test_name);
2462        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2463        let candles = read_candles_from_csv(file_path)?;
2464
2465        let input = SuperTrendInput::with_default_candles(&candles);
2466        let output = supertrend_with_kernel(&input, kernel)?;
2467        assert_eq!(output.trend.len(), candles.close.len());
2468        assert_eq!(output.changed.len(), candles.close.len());
2469        Ok(())
2470    }
2471
2472    fn check_supertrend_zero_period(
2473        test_name: &str,
2474        kernel: Kernel,
2475    ) -> Result<(), Box<dyn std::error::Error>> {
2476        skip_if_unsupported!(kernel, test_name);
2477        let high = [10.0, 12.0, 13.0];
2478        let low = [9.0, 11.0, 12.5];
2479        let close = [9.5, 11.5, 13.0];
2480        let params = SuperTrendParams {
2481            period: Some(0),
2482            factor: Some(3.0),
2483        };
2484        let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2485        let res = supertrend_with_kernel(&input, kernel);
2486        assert!(res.is_err(), "[{}] Should fail with zero period", test_name);
2487        Ok(())
2488    }
2489
2490    fn check_supertrend_period_exceeds_length(
2491        test_name: &str,
2492        kernel: Kernel,
2493    ) -> Result<(), Box<dyn std::error::Error>> {
2494        skip_if_unsupported!(kernel, test_name);
2495        let high = [10.0, 12.0, 13.0];
2496        let low = [9.0, 11.0, 12.5];
2497        let close = [9.5, 11.5, 13.0];
2498        let params = SuperTrendParams {
2499            period: Some(10),
2500            factor: Some(3.0),
2501        };
2502        let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2503        let res = supertrend_with_kernel(&input, kernel);
2504        assert!(
2505            res.is_err(),
2506            "[{}] Should fail with period > data.len()",
2507            test_name
2508        );
2509        Ok(())
2510    }
2511
2512    fn check_supertrend_very_small_dataset(
2513        test_name: &str,
2514        kernel: Kernel,
2515    ) -> Result<(), Box<dyn std::error::Error>> {
2516        skip_if_unsupported!(kernel, test_name);
2517        let high = [42.0];
2518        let low = [40.0];
2519        let close = [41.0];
2520        let params = SuperTrendParams {
2521            period: Some(10),
2522            factor: Some(3.0),
2523        };
2524        let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2525        let res = supertrend_with_kernel(&input, kernel);
2526        assert!(
2527            res.is_err(),
2528            "[{}] Should fail for data smaller than period",
2529            test_name
2530        );
2531        Ok(())
2532    }
2533
2534    fn check_supertrend_reinput(
2535        test_name: &str,
2536        kernel: Kernel,
2537    ) -> Result<(), Box<dyn std::error::Error>> {
2538        skip_if_unsupported!(kernel, test_name);
2539        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2540        let candles = read_candles_from_csv(file_path)?;
2541
2542        let first_params = SuperTrendParams {
2543            period: Some(10),
2544            factor: Some(3.0),
2545        };
2546        let first_input = SuperTrendInput::from_candles(&candles, first_params);
2547        let first_result = supertrend_with_kernel(&first_input, kernel)?;
2548
2549        let second_params = SuperTrendParams {
2550            period: Some(5),
2551            factor: Some(2.0),
2552        };
2553        let second_input = SuperTrendInput::from_slices(
2554            &first_result.trend,
2555            &first_result.trend,
2556            &first_result.trend,
2557            second_params,
2558        );
2559        let second_result = supertrend_with_kernel(&second_input, kernel)?;
2560        assert_eq!(second_result.trend.len(), first_result.trend.len());
2561        assert_eq!(second_result.changed.len(), first_result.changed.len());
2562        Ok(())
2563    }
2564
2565    fn check_supertrend_nan_handling(
2566        test_name: &str,
2567        kernel: Kernel,
2568    ) -> Result<(), Box<dyn std::error::Error>> {
2569        skip_if_unsupported!(kernel, test_name);
2570        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2571        let candles = read_candles_from_csv(file_path)?;
2572
2573        let params = SuperTrendParams {
2574            period: Some(10),
2575            factor: Some(3.0),
2576        };
2577        let input = SuperTrendInput::from_candles(&candles, params);
2578        let result = supertrend_with_kernel(&input, kernel)?;
2579        if result.trend.len() > 50 {
2580            for (i, &val) in result.trend[50..].iter().enumerate() {
2581                assert!(
2582                    !val.is_nan(),
2583                    "[{}] Found unexpected NaN at out-index {}",
2584                    test_name,
2585                    50 + i
2586                );
2587            }
2588        }
2589        Ok(())
2590    }
2591
2592    fn check_supertrend_streaming(
2593        test_name: &str,
2594        kernel: Kernel,
2595    ) -> Result<(), Box<dyn std::error::Error>> {
2596        skip_if_unsupported!(kernel, test_name);
2597        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2598        let candles = read_candles_from_csv(file_path)?;
2599
2600        let period = 10;
2601        let factor = 3.0;
2602        let params = SuperTrendParams {
2603            period: Some(period),
2604            factor: Some(factor),
2605        };
2606        let input = SuperTrendInput::from_candles(&candles, params.clone());
2607        let batch_output = supertrend_with_kernel(&input, kernel)?;
2608
2609        let mut stream = SuperTrendStream::try_new(params.clone())?;
2610        let mut stream_trend = Vec::with_capacity(candles.close.len());
2611        let mut stream_changed = Vec::with_capacity(candles.close.len());
2612
2613        for i in 0..candles.close.len() {
2614            let (h, l, c) = (candles.high[i], candles.low[i], candles.close[i]);
2615            match stream.update(h, l, c) {
2616                Some((trend, changed)) => {
2617                    stream_trend.push(trend);
2618                    stream_changed.push(changed);
2619                }
2620                None => {
2621                    stream_trend.push(f64::NAN);
2622                    stream_changed.push(f64::NAN);
2623                }
2624            }
2625        }
2626        assert_eq!(batch_output.trend.len(), stream_trend.len());
2627        assert_eq!(batch_output.changed.len(), stream_changed.len());
2628
2629        for (i, (&b, &s)) in batch_output
2630            .trend
2631            .iter()
2632            .zip(stream_trend.iter())
2633            .enumerate()
2634        {
2635            if b.is_nan() && s.is_nan() {
2636                continue;
2637            }
2638            let diff = (b - s).abs();
2639            assert!(
2640                diff < 1e-8,
2641                "[{}] Streaming trend mismatch at idx {}: batch={}, stream={}, diff={}",
2642                test_name,
2643                i,
2644                b,
2645                s,
2646                diff
2647            );
2648        }
2649        for (i, (&b, &s)) in batch_output
2650            .changed
2651            .iter()
2652            .zip(stream_changed.iter())
2653            .enumerate()
2654        {
2655            if b.is_nan() && s.is_nan() {
2656                continue;
2657            }
2658            let diff = (b - s).abs();
2659            assert!(
2660                diff < 1e-9,
2661                "[{}] Streaming changed mismatch at idx {}: batch={}, stream={}, diff={}",
2662                test_name,
2663                i,
2664                b,
2665                s,
2666                diff
2667            );
2668        }
2669        Ok(())
2670    }
2671
2672    #[cfg(debug_assertions)]
2673    fn check_supertrend_no_poison(
2674        test_name: &str,
2675        kernel: Kernel,
2676    ) -> Result<(), Box<dyn std::error::Error>> {
2677        skip_if_unsupported!(kernel, test_name);
2678
2679        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2680        let candles = read_candles_from_csv(file_path)?;
2681
2682        let test_params = vec![
2683            SuperTrendParams::default(),
2684            SuperTrendParams {
2685                period: Some(2),
2686                factor: Some(1.0),
2687            },
2688            SuperTrendParams {
2689                period: Some(5),
2690                factor: Some(0.5),
2691            },
2692            SuperTrendParams {
2693                period: Some(5),
2694                factor: Some(2.0),
2695            },
2696            SuperTrendParams {
2697                period: Some(5),
2698                factor: Some(3.5),
2699            },
2700            SuperTrendParams {
2701                period: Some(10),
2702                factor: Some(1.5),
2703            },
2704            SuperTrendParams {
2705                period: Some(14),
2706                factor: Some(2.5),
2707            },
2708            SuperTrendParams {
2709                period: Some(20),
2710                factor: Some(3.0),
2711            },
2712            SuperTrendParams {
2713                period: Some(50),
2714                factor: Some(2.0),
2715            },
2716            SuperTrendParams {
2717                period: Some(100),
2718                factor: Some(1.0),
2719            },
2720            SuperTrendParams {
2721                period: Some(10),
2722                factor: Some(0.1),
2723            },
2724            SuperTrendParams {
2725                period: Some(10),
2726                factor: Some(5.0),
2727            },
2728        ];
2729
2730        for (param_idx, params) in test_params.iter().enumerate() {
2731            let input = SuperTrendInput::from_candles(&candles, params.clone());
2732            let output = supertrend_with_kernel(&input, kernel)?;
2733
2734            for (i, &val) in output.trend.iter().enumerate() {
2735                if val.is_nan() {
2736                    continue;
2737                }
2738
2739                let bits = val.to_bits();
2740
2741                if bits == 0x11111111_11111111 {
2742                    panic!(
2743						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in trend \
2744						 with params: period={}, factor={} (param set {})",
2745						test_name, val, bits, i,
2746						params.period.unwrap_or(10),
2747						params.factor.unwrap_or(3.0),
2748						param_idx
2749					);
2750                }
2751
2752                if bits == 0x22222222_22222222 {
2753                    panic!(
2754						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in trend \
2755						 with params: period={}, factor={} (param set {})",
2756						test_name, val, bits, i,
2757						params.period.unwrap_or(10),
2758						params.factor.unwrap_or(3.0),
2759						param_idx
2760					);
2761                }
2762
2763                if bits == 0x33333333_33333333 {
2764                    panic!(
2765						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in trend \
2766						 with params: period={}, factor={} (param set {})",
2767						test_name, val, bits, i,
2768						params.period.unwrap_or(10),
2769						params.factor.unwrap_or(3.0),
2770						param_idx
2771					);
2772                }
2773            }
2774
2775            for (i, &val) in output.changed.iter().enumerate() {
2776                if val.is_nan() {
2777                    continue;
2778                }
2779
2780                let bits = val.to_bits();
2781
2782                if bits == 0x11111111_11111111 {
2783                    panic!(
2784						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in changed \
2785						 with params: period={}, factor={} (param set {})",
2786						test_name, val, bits, i,
2787						params.period.unwrap_or(10),
2788						params.factor.unwrap_or(3.0),
2789						param_idx
2790					);
2791                }
2792
2793                if bits == 0x22222222_22222222 {
2794                    panic!(
2795						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in changed \
2796						 with params: period={}, factor={} (param set {})",
2797						test_name, val, bits, i,
2798						params.period.unwrap_or(10),
2799						params.factor.unwrap_or(3.0),
2800						param_idx
2801					);
2802                }
2803
2804                if bits == 0x33333333_33333333 {
2805                    panic!(
2806						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in changed \
2807						 with params: period={}, factor={} (param set {})",
2808						test_name, val, bits, i,
2809						params.period.unwrap_or(10),
2810						params.factor.unwrap_or(3.0),
2811						param_idx
2812					);
2813                }
2814            }
2815        }
2816
2817        Ok(())
2818    }
2819
    /// Release-build stand-in for the poison check: it always succeeds. The real scan is
    /// only compiled under `debug_assertions`.
    #[cfg(not(debug_assertions))]
    fn check_supertrend_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
2827
2828    #[cfg(feature = "proptest")]
2829    #[allow(clippy::float_cmp)]
2830    fn check_supertrend_property(
2831        test_name: &str,
2832        kernel: Kernel,
2833    ) -> Result<(), Box<dyn std::error::Error>> {
2834        use proptest::prelude::*;
2835        skip_if_unsupported!(kernel, test_name);
2836
2837        let strat = (2usize..=50).prop_flat_map(|period| {
2838            let data_len = period * 2 + 50;
2839            (
2840                prop::collection::vec(
2841                    (100f64..10000f64).prop_filter("finite", |x| x.is_finite()),
2842                    data_len,
2843                ),
2844                Just(period),
2845                0.5f64..5.0f64,
2846            )
2847        });
2848
2849        proptest::test_runner::TestRunner::default()
2850            .run(&strat, |(base_prices, period, factor)| {
2851                let mut high = Vec::with_capacity(base_prices.len());
2852                let mut low = Vec::with_capacity(base_prices.len());
2853                let mut close = Vec::with_capacity(base_prices.len());
2854
2855                let mut rng_state = 42u64;
2856                for &base in &base_prices {
2857                    rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2858                    let rand1 = ((rng_state >> 32) as f64) / (u32::MAX as f64);
2859                    rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2860                    let rand2 = ((rng_state >> 32) as f64) / (u32::MAX as f64);
2861
2862                    let spread = base * (0.005 + rand1 * 0.025);
2863                    let h = base + spread;
2864                    let l = base - spread;
2865
2866                    let c = l + (h - l) * rand2;
2867
2868                    high.push(h);
2869                    low.push(l);
2870                    close.push(c);
2871                }
2872
2873                let params = SuperTrendParams {
2874                    period: Some(period),
2875                    factor: Some(factor),
2876                };
2877                let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2878
2879                let output = supertrend_with_kernel(&input, kernel).unwrap();
2880
2881                let ref_output = supertrend_with_kernel(&input, Kernel::Scalar).unwrap();
2882
2883                prop_assert_eq!(
2884                    output.trend.len(),
2885                    high.len(),
2886                    "[{}] Trend length mismatch",
2887                    test_name
2888                );
2889                prop_assert_eq!(
2890                    output.changed.len(),
2891                    high.len(),
2892                    "[{}] Changed length mismatch",
2893                    test_name
2894                );
2895
2896                let warmup_end = period - 1;
2897                for i in 0..warmup_end {
2898                    prop_assert!(
2899                        output.trend[i].is_nan(),
2900                        "[{}] Expected NaN during warmup at index {}",
2901                        test_name,
2902                        i
2903                    );
2904                    prop_assert!(
2905                        output.changed[i].is_nan(),
2906                        "[{}] Expected NaN in changed during warmup at index {}",
2907                        test_name,
2908                        i
2909                    );
2910                }
2911
2912                for i in warmup_end..output.trend.len() {
2913                    let val = output.trend[i];
2914                    if !val.is_nan() {
2915                        let global_high = high.iter().fold(f64::NEG_INFINITY, |a, &b| {
2916                            if b.is_finite() {
2917                                a.max(b)
2918                            } else {
2919                                a
2920                            }
2921                        });
2922                        let global_low = low.iter().fold(f64::INFINITY, |a, &b| {
2923                            if b.is_finite() {
2924                                a.min(b)
2925                            } else {
2926                                a
2927                            }
2928                        });
2929
2930                        let global_range = global_high - global_low;
2931
2932                        let margin = global_range * factor;
2933
2934                        prop_assert!(
2935                            val >= global_low - margin && val <= global_high + margin,
2936                            "[{}] Trend value {} at index {} outside global bounds [{}, {}]",
2937                            test_name,
2938                            val,
2939                            i,
2940                            global_low - margin,
2941                            global_high + margin
2942                        );
2943                    }
2944                }
2945
2946                for i in warmup_end..output.changed.len() {
2947                    let val = output.changed[i];
2948                    if !val.is_nan() {
2949                        prop_assert!(
2950                            val == 0.0 || val == 1.0,
2951                            "[{}] Changed value {} at index {} is not 0.0 or 1.0",
2952                            test_name,
2953                            val,
2954                            i
2955                        );
2956                    }
2957                }
2958
2959                for i in 0..output.trend.len() {
2960                    let trend_val = output.trend[i];
2961                    let ref_trend_val = ref_output.trend[i];
2962                    let changed_val = output.changed[i];
2963                    let ref_changed_val = ref_output.changed[i];
2964
2965                    if !trend_val.is_finite() || !ref_trend_val.is_finite() {
2966                        prop_assert_eq!(
2967                            trend_val.to_bits(),
2968                            ref_trend_val.to_bits(),
2969                            "[{}] NaN/Inf mismatch in trend at index {}",
2970                            test_name,
2971                            i
2972                        );
2973                    } else {
2974                        let ulp_diff = trend_val.to_bits().abs_diff(ref_trend_val.to_bits());
2975                        prop_assert!(
2976                            (trend_val - ref_trend_val).abs() <= 1e-9 || ulp_diff <= 5,
2977                            "[{}] Kernel mismatch in trend at index {}: {} vs {} (ULP={})",
2978                            test_name,
2979                            i,
2980                            trend_val,
2981                            ref_trend_val,
2982                            ulp_diff
2983                        );
2984                    }
2985
2986                    if !changed_val.is_nan() && !ref_changed_val.is_nan() {
2987                        prop_assert_eq!(
2988                            changed_val,
2989                            ref_changed_val,
2990                            "[{}] Kernel mismatch in changed at index {}: {} vs {}",
2991                            test_name,
2992                            i,
2993                            changed_val,
2994                            ref_changed_val
2995                        );
2996                    }
2997                }
2998
2999                if base_prices.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) {
3000                    let stable_start = (period * 2).min(output.trend.len());
3001                    if stable_start < output.trend.len() {
3002                        let stable_trend = output.trend[stable_start];
3003                        for i in (stable_start + 1)..output.trend.len() {
3004                            if !output.trend[i].is_nan() && !stable_trend.is_nan() {
3005                                prop_assert!(
3006                                    (output.trend[i] - stable_trend).abs() < 1e-9,
3007                                    "[{}] Trend not stable for constant prices at index {}",
3008                                    test_name,
3009                                    i
3010                                );
3011                            }
3012                        }
3013                    }
3014                }
3015
3016                if output.trend.len() > warmup_end + 1 {
3017                    for i in (warmup_end + 1)..output.changed.len() {
3018                        let changed_val = output.changed[i];
3019                        if !changed_val.is_nan() {
3020                            let curr_trend = output.trend[i];
3021                            let prev_trend = output.trend[i - 1];
3022
3023                            if !curr_trend.is_nan() && !prev_trend.is_nan() {
3024                                if changed_val == 1.0 {
3025                                    prop_assert!(
3026										(curr_trend - prev_trend).abs() > 1e-6,
3027										"[{}] Changed=1.0 at index {} but trend didn't switch: {} vs {}",
3028										test_name, i, prev_trend, curr_trend
3029									);
3030                                }
3031                            }
3032                        }
3033                    }
3034                }
3035
3036                Ok(())
3037            })
3038            .unwrap();
3039
3040        Ok(())
3041    }
3042
3043    macro_rules! generate_all_supertrend_tests {
3044        ($($test_fn:ident),*) => {
3045            paste::paste! {
3046                $(
3047                    #[test]
3048                    fn [<$test_fn _scalar_f64>]() {
3049                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3050                    }
3051                )*
3052                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3053                $(
3054                    #[test]
3055                    fn [<$test_fn _avx2_f64>]() {
3056                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3057                    }
3058                    #[test]
3059                    fn [<$test_fn _avx512_f64>]() {
3060                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3061                    }
3062                )*
3063            }
3064        }
3065    }
3066
3067    generate_all_supertrend_tests!(
3068        check_supertrend_partial_params,
3069        check_supertrend_accuracy,
3070        check_supertrend_default_candles,
3071        check_supertrend_zero_period,
3072        check_supertrend_period_exceeds_length,
3073        check_supertrend_very_small_dataset,
3074        check_supertrend_reinput,
3075        check_supertrend_nan_handling,
3076        check_supertrend_streaming,
3077        check_supertrend_no_poison
3078    );
3079
3080    #[cfg(feature = "proptest")]
3081    generate_all_supertrend_tests!(check_supertrend_property);
3082
3083    fn check_batch_default_row(
3084        test: &str,
3085        kernel: Kernel,
3086    ) -> Result<(), Box<dyn std::error::Error>> {
3087        skip_if_unsupported!(kernel, test);
3088
3089        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3090        let c = read_candles_from_csv(file)?;
3091
3092        let output = SuperTrendBatchBuilder::new()
3093            .kernel(kernel)
3094            .apply_candles(&c)?;
3095
3096        let def = SuperTrendParams::default();
3097        let row = output.trend_for(&def).expect("default row missing");
3098
3099        assert_eq!(row.len(), c.close.len());
3100
3101        let expected = [
3102            61811.479454208165,
3103            61721.73150878735,
3104            61459.10835790861,
3105            61351.59752211775,
3106            61033.18776990598,
3107        ];
3108        let start = row.len() - 5;
3109        for (i, &v) in row[start..].iter().enumerate() {
3110            assert!(
3111                (v - expected[i]).abs() < 1e-4,
3112                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
3113            );
3114        }
3115        Ok(())
3116    }
3117
3118    #[cfg(debug_assertions)]
3119    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
3120        skip_if_unsupported!(kernel, test);
3121
3122        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3123        let c = read_candles_from_csv(file)?;
3124
3125        let test_configs = vec![
3126            (2, 10, 2, 1.0, 3.0, 0.5),
3127            (5, 25, 5, 2.0, 2.0, 0.0),
3128            (10, 10, 0, 0.5, 4.0, 0.5),
3129            (2, 5, 1, 1.5, 1.5, 0.0),
3130            (30, 60, 15, 3.0, 3.0, 0.0),
3131            (20, 30, 5, 1.0, 3.0, 1.0),
3132            (8, 12, 1, 0.5, 2.5, 0.5),
3133        ];
3134
3135        for (cfg_idx, &(p_start, p_end, p_step, f_start, f_end, f_step)) in
3136            test_configs.iter().enumerate()
3137        {
3138            let output = SuperTrendBatchBuilder::new()
3139                .kernel(kernel)
3140                .period_range(p_start, p_end, p_step)
3141                .factor_range(f_start, f_end, f_step)
3142                .apply_candles(&c)?;
3143
3144            for (idx, &val) in output.trend.iter().enumerate() {
3145                if val.is_nan() {
3146                    continue;
3147                }
3148
3149                let bits = val.to_bits();
3150                let row = idx / output.cols;
3151                let col = idx % output.cols;
3152                let combo = &output.combos[row];
3153
3154                if bits == 0x11111111_11111111 {
3155                    panic!(
3156                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3157						at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3158                        test,
3159                        cfg_idx,
3160                        val,
3161                        bits,
3162                        row,
3163                        col,
3164                        idx,
3165                        combo.period.unwrap_or(10),
3166                        combo.factor.unwrap_or(3.0)
3167                    );
3168                }
3169
3170                if bits == 0x22222222_22222222 {
3171                    panic!(
3172                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3173						at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3174                        test,
3175                        cfg_idx,
3176                        val,
3177                        bits,
3178                        row,
3179                        col,
3180                        idx,
3181                        combo.period.unwrap_or(10),
3182                        combo.factor.unwrap_or(3.0)
3183                    );
3184                }
3185
3186                if bits == 0x33333333_33333333 {
3187                    panic!(
3188                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3189						at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3190                        test,
3191                        cfg_idx,
3192                        val,
3193                        bits,
3194                        row,
3195                        col,
3196                        idx,
3197                        combo.period.unwrap_or(10),
3198                        combo.factor.unwrap_or(3.0)
3199                    );
3200                }
3201            }
3202
3203            for (idx, &val) in output.changed.iter().enumerate() {
3204                if val.is_nan() {
3205                    continue;
3206                }
3207
3208                let bits = val.to_bits();
3209                let row = idx / output.cols;
3210                let col = idx % output.cols;
3211                let combo = &output.combos[row];
3212
3213                if bits == 0x11111111_11111111 {
3214                    panic!(
3215                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3216						at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3217                        test,
3218                        cfg_idx,
3219                        val,
3220                        bits,
3221                        row,
3222                        col,
3223                        idx,
3224                        combo.period.unwrap_or(10),
3225                        combo.factor.unwrap_or(3.0)
3226                    );
3227                }
3228
3229                if bits == 0x22222222_22222222 {
3230                    panic!(
3231                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3232						at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3233                        test,
3234                        cfg_idx,
3235                        val,
3236                        bits,
3237                        row,
3238                        col,
3239                        idx,
3240                        combo.period.unwrap_or(10),
3241                        combo.factor.unwrap_or(3.0)
3242                    );
3243                }
3244
3245                if bits == 0x33333333_33333333 {
3246                    panic!(
3247                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3248						at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3249                        test,
3250                        cfg_idx,
3251                        val,
3252                        bits,
3253                        row,
3254                        col,
3255                        idx,
3256                        combo.period.unwrap_or(10),
3257                        combo.factor.unwrap_or(3.0)
3258                    );
3259                }
3260            }
3261        }
3262
3263        Ok(())
3264    }
3265
3266    #[cfg(not(debug_assertions))]
3267    fn check_batch_no_poison(
3268        _test: &str,
3269        _kernel: Kernel,
3270    ) -> Result<(), Box<dyn std::error::Error>> {
3271        Ok(())
3272    }
3273
3274    macro_rules! gen_batch_tests {
3275        ($fn_name:ident) => {
3276            paste::paste! {
3277                #[test] fn [<$fn_name _scalar>]()      {
3278                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3279                }
3280                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3281                #[test] fn [<$fn_name _avx2>]()        {
3282                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3283                }
3284                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3285                #[test] fn [<$fn_name _avx512>]()      {
3286                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3287                }
3288                #[test] fn [<$fn_name _auto_detect>]() {
3289                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3290                }
3291            }
3292        };
3293    }
3294    gen_batch_tests!(check_batch_default_row);
3295    gen_batch_tests!(check_batch_no_poison);
3296}