Skip to main content

vector_ta/indicators/
supertrend_oscillator.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::atr::{AtrParams, AtrStream};
16use crate::utilities::data_loader::{source_type, Candles};
17use crate::utilities::enums::Kernel;
18use crate::utilities::helpers::{
19    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
20    make_uninit_matrix,
21};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::collections::HashMap;
28use std::mem::{ManuallyDrop, MaybeUninit};
29use thiserror::Error;
30
/// ATR period used when `length` is not supplied; also the divisor of the signal's adaptive factor.
const DEFAULT_LENGTH: usize = 10;
/// ATR multiplier used when `mult` is not supplied (band offset = ATR * mult).
const DEFAULT_MULT: f64 = 2.0;
/// Histogram smoothing period used when `smooth` is not supplied (EMA factor 2 / (smooth + 1)).
const DEFAULT_SMOOTH: usize = 72;
/// Factor applied to the unit-range oscillator, signal and histogram before they are emitted.
const OUTPUT_SCALE: f64 = 100.0;
35
36#[derive(Debug, Clone)]
37pub enum SuperTrendOscillatorData<'a> {
38    Candles {
39        candles: &'a Candles,
40        source: &'a str,
41    },
42    Slices {
43        high: &'a [f64],
44        low: &'a [f64],
45        source: &'a [f64],
46    },
47}
48
/// Output of a single-parameter oscillator run.
///
/// Each vector has one entry per input bar. Bars without a value (warm-up, or invalid
/// input) hold `NaN`. All values are multiplied by `OUTPUT_SCALE`.
#[derive(Debug, Clone)]
pub struct SuperTrendOscillatorOutput {
    /// Position of the source inside the SuperTrend band, clamped to [-1, 1] before scaling.
    pub oscillator: Vec<f64>,
    /// Adaptive moving average of the oscillator.
    pub signal: Vec<f64>,
    /// Smoothed difference between the oscillator and its signal.
    pub histogram: Vec<f64>,
}
55
56#[derive(Debug, Clone)]
57#[cfg_attr(
58    all(target_arch = "wasm32", feature = "wasm"),
59    derive(Serialize, Deserialize)
60)]
61pub struct SuperTrendOscillatorParams {
62    pub length: Option<usize>,
63    pub mult: Option<f64>,
64    pub smooth: Option<usize>,
65}
66
67impl Default for SuperTrendOscillatorParams {
68    fn default() -> Self {
69        Self {
70            length: Some(DEFAULT_LENGTH),
71            mult: Some(DEFAULT_MULT),
72            smooth: Some(DEFAULT_SMOOTH),
73        }
74    }
75}
76
77#[derive(Debug, Clone)]
78pub struct SuperTrendOscillatorInput<'a> {
79    pub data: SuperTrendOscillatorData<'a>,
80    pub params: SuperTrendOscillatorParams,
81}
82
83impl<'a> SuperTrendOscillatorInput<'a> {
84    #[inline]
85    pub fn from_candles(
86        candles: &'a Candles,
87        source: &'a str,
88        params: SuperTrendOscillatorParams,
89    ) -> Self {
90        Self {
91            data: SuperTrendOscillatorData::Candles { candles, source },
92            params,
93        }
94    }
95
96    #[inline]
97    pub fn from_slices(
98        high: &'a [f64],
99        low: &'a [f64],
100        source: &'a [f64],
101        params: SuperTrendOscillatorParams,
102    ) -> Self {
103        Self {
104            data: SuperTrendOscillatorData::Slices { high, low, source },
105            params,
106        }
107    }
108
109    #[inline]
110    pub fn with_default_candles(candles: &'a Candles) -> Self {
111        Self::from_candles(candles, "close", SuperTrendOscillatorParams::default())
112    }
113
114    #[inline]
115    pub fn get_length(&self) -> usize {
116        self.params.length.unwrap_or(DEFAULT_LENGTH)
117    }
118
119    #[inline]
120    pub fn get_mult(&self) -> f64 {
121        self.params.mult.unwrap_or(DEFAULT_MULT)
122    }
123
124    #[inline]
125    pub fn get_smooth(&self) -> usize {
126        self.params.smooth.unwrap_or(DEFAULT_SMOOTH)
127    }
128
129    #[inline]
130    pub fn as_refs(&'a self) -> (&'a [f64], &'a [f64], &'a [f64]) {
131        match &self.data {
132            SuperTrendOscillatorData::Candles { candles, source } => (
133                candles.high.as_slice(),
134                candles.low.as_slice(),
135                source_type(candles, source),
136            ),
137            SuperTrendOscillatorData::Slices { high, low, source } => (*high, *low, *source),
138        }
139    }
140}
141
142#[derive(Clone, Debug)]
143pub struct SuperTrendOscillatorBuilder {
144    length: Option<usize>,
145    mult: Option<f64>,
146    smooth: Option<usize>,
147    source: Option<String>,
148    kernel: Kernel,
149}
150
151impl Default for SuperTrendOscillatorBuilder {
152    fn default() -> Self {
153        Self {
154            length: None,
155            mult: None,
156            smooth: None,
157            source: None,
158            kernel: Kernel::Auto,
159        }
160    }
161}
162
163impl SuperTrendOscillatorBuilder {
164    #[inline]
165    pub fn new() -> Self {
166        Self::default()
167    }
168
169    #[inline]
170    pub fn length(mut self, value: usize) -> Self {
171        self.length = Some(value);
172        self
173    }
174
175    #[inline]
176    pub fn mult(mut self, value: f64) -> Self {
177        self.mult = Some(value);
178        self
179    }
180
181    #[inline]
182    pub fn smooth(mut self, value: usize) -> Self {
183        self.smooth = Some(value);
184        self
185    }
186
187    #[inline]
188    pub fn source<S: Into<String>>(mut self, value: S) -> Self {
189        self.source = Some(value.into());
190        self
191    }
192
193    #[inline]
194    pub fn kernel(mut self, value: Kernel) -> Self {
195        self.kernel = value;
196        self
197    }
198
199    #[inline]
200    pub fn apply(
201        self,
202        candles: &Candles,
203    ) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
204        let input = SuperTrendOscillatorInput::from_candles(
205            candles,
206            self.source.as_deref().unwrap_or("close"),
207            SuperTrendOscillatorParams {
208                length: self.length,
209                mult: self.mult,
210                smooth: self.smooth,
211            },
212        );
213        supertrend_oscillator_with_kernel(&input, self.kernel)
214    }
215
216    #[inline]
217    pub fn apply_slices(
218        self,
219        high: &[f64],
220        low: &[f64],
221        source: &[f64],
222    ) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
223        let input = SuperTrendOscillatorInput::from_slices(
224            high,
225            low,
226            source,
227            SuperTrendOscillatorParams {
228                length: self.length,
229                mult: self.mult,
230                smooth: self.smooth,
231            },
232        );
233        supertrend_oscillator_with_kernel(&input, self.kernel)
234    }
235
236    #[inline]
237    pub fn into_stream(self) -> Result<SuperTrendOscillatorStream, SuperTrendOscillatorError> {
238        SuperTrendOscillatorStream::try_new(SuperTrendOscillatorParams {
239            length: self.length,
240            mult: self.mult,
241            smooth: self.smooth,
242        })
243    }
244}
245
246#[derive(Debug, Error)]
247pub enum SuperTrendOscillatorError {
248    #[error("supertrend_oscillator: Empty input data.")]
249    EmptyInputData,
250    #[error(
251        "supertrend_oscillator: Input length mismatch: high={high}, low={low}, source={source_len}"
252    )]
253    DataLengthMismatch {
254        high: usize,
255        low: usize,
256        source_len: usize,
257    },
258    #[error("supertrend_oscillator: All input values are invalid.")]
259    AllValuesNaN,
260    #[error("supertrend_oscillator: Invalid length: length = {length}, data length = {data_len}")]
261    InvalidLength { length: usize, data_len: usize },
262    #[error("supertrend_oscillator: Invalid multiplier: {mult}")]
263    InvalidMultiplier { mult: f64 },
264    #[error("supertrend_oscillator: Invalid smooth: {smooth}")]
265    InvalidSmooth { smooth: usize },
266    #[error("supertrend_oscillator: Not enough valid data: needed = {needed}, valid = {valid}")]
267    NotEnoughValidData { needed: usize, valid: usize },
268    #[error("supertrend_oscillator: Output length mismatch: expected = {expected}, got = {got}")]
269    OutputLengthMismatch { expected: usize, got: usize },
270    #[error("supertrend_oscillator: Invalid range: start={start}, end={end}, step={step}")]
271    InvalidRange {
272        start: String,
273        end: String,
274        step: String,
275    },
276    #[error("supertrend_oscillator: Invalid float range: start={start}, end={end}, step={step}")]
277    InvalidFloatRange { start: f64, end: f64, step: f64 },
278    #[error("supertrend_oscillator: Invalid kernel for batch: {0:?}")]
279    InvalidKernelForBatch(Kernel),
280}
281
/// A bar is usable when high, low and source are all finite and `high >= low`.
///
/// The source is not required to lie inside `[low, high]`.
#[inline(always)]
fn valid_bar(high: f64, low: f64, source: f64) -> bool {
    let all_finite = [high, low, source].iter().all(|v| v.is_finite());
    all_finite && high >= low
}
286
287#[inline(always)]
288fn first_valid_bar(high: &[f64], low: &[f64], source: &[f64]) -> Option<usize> {
289    (0..source.len()).find(|&i| valid_bar(high[i], low[i], source[i]))
290}
291
292#[inline(always)]
293fn max_valid_run(high: &[f64], low: &[f64], source: &[f64]) -> usize {
294    let mut best = 0usize;
295    let mut cur = 0usize;
296    for i in 0..source.len() {
297        if valid_bar(high[i], low[i], source[i]) {
298            cur += 1;
299            if cur > best {
300                best = cur;
301            }
302        } else {
303            cur = 0;
304        }
305    }
306    best
307}
308
309#[inline(always)]
310fn normalized_kernel(kernel: Kernel) -> Kernel {
311    match kernel {
312        Kernel::Auto => detect_best_kernel(),
313        other if other.is_batch() => other.to_non_batch(),
314        other => other,
315    }
316}
317
/// Clamps `value` into `[-1.0, 1.0]`; `NaN` passes through unchanged.
#[inline(always)]
fn clamp_unit(value: f64) -> f64 {
    if value > 1.0 {
        1.0
    } else if value < -1.0 {
        -1.0
    } else {
        value
    }
}
322
/// Number of leading bars pre-filled with `NaN`: `first_valid + length - 1`, saturating.
#[inline(always)]
fn warmup_end(first_valid: usize, length: usize) -> usize {
    let extra = match length {
        0 => 0,
        n => n - 1,
    };
    first_valid.saturating_add(extra)
}
327
328#[inline(always)]
329fn validate_lengths(
330    high: &[f64],
331    low: &[f64],
332    source: &[f64],
333) -> Result<(), SuperTrendOscillatorError> {
334    if high.is_empty() || low.is_empty() || source.is_empty() {
335        return Err(SuperTrendOscillatorError::EmptyInputData);
336    }
337    if high.len() != low.len() || low.len() != source.len() {
338        return Err(SuperTrendOscillatorError::DataLengthMismatch {
339            high: high.len(),
340            low: low.len(),
341            source_len: source.len(),
342        });
343    }
344    Ok(())
345}
346
347#[inline(always)]
348fn validate_params(
349    length: usize,
350    mult: f64,
351    smooth: usize,
352    data_len: usize,
353) -> Result<(), SuperTrendOscillatorError> {
354    if length == 0 || length > data_len {
355        return Err(SuperTrendOscillatorError::InvalidLength { length, data_len });
356    }
357    if !mult.is_finite() || mult <= 0.0 {
358        return Err(SuperTrendOscillatorError::InvalidMultiplier { mult });
359    }
360    if smooth == 0 {
361        return Err(SuperTrendOscillatorError::InvalidSmooth { smooth });
362    }
363    Ok(())
364}
365
366fn compute_atr_series(high: &[f64], low: &[f64], source: &[f64], length: usize) -> Vec<f64> {
367    let mut out = vec![f64::NAN; source.len()];
368    let mut stream = AtrStream::try_new(AtrParams {
369        length: Some(length),
370    })
371    .expect("validated length");
372
373    for i in 0..source.len() {
374        if !valid_bar(high[i], low[i], source[i]) {
375            stream = AtrStream::try_new(AtrParams {
376                length: Some(length),
377            })
378            .expect("validated length");
379            continue;
380        }
381        if let Some(atr) = stream.update(high[i], low[i], source[i]) {
382            out[i] = atr;
383        }
384    }
385
386    out
387}
388
389#[inline(always)]
390fn supertrend_oscillator_row_scalar(
391    high: &[f64],
392    low: &[f64],
393    source: &[f64],
394    length: usize,
395    mult: f64,
396    smooth: usize,
397    atr_values: &[f64],
398    out_oscillator: &mut [f64],
399    out_signal: &mut [f64],
400    out_histogram: &mut [f64],
401) {
402    let hist_alpha = 2.0 / (smooth as f64 + 1.0);
403    let mut prev_source = f64::NAN;
404    let mut prev_upper = f64::NAN;
405    let mut prev_lower = f64::NAN;
406    let mut prev_trend = 0.0;
407    let mut ama_prev: Option<f64> = None;
408    let mut hist_prev: Option<f64> = None;
409    let length_f64 = length as f64;
410
411    for i in 0..source.len() {
412        let src = source[i];
413        if !valid_bar(high[i], low[i], src) {
414            out_oscillator[i] = f64::NAN;
415            out_signal[i] = f64::NAN;
416            out_histogram[i] = f64::NAN;
417            prev_source = f64::NAN;
418            prev_upper = f64::NAN;
419            prev_lower = f64::NAN;
420            prev_trend = 0.0;
421            ama_prev = None;
422            hist_prev = None;
423            continue;
424        }
425
426        if !atr_values[i].is_finite() {
427            out_oscillator[i] = f64::NAN;
428            out_signal[i] = f64::NAN;
429            out_histogram[i] = f64::NAN;
430            prev_source = src;
431            continue;
432        }
433
434        let mid = 0.5 * (high[i] + low[i]);
435        let band = atr_values[i] * mult;
436        let up = mid + band;
437        let dn = mid - band;
438
439        let upper = if prev_source.is_finite() && prev_upper.is_finite() && prev_source < prev_upper
440        {
441            up.min(prev_upper)
442        } else {
443            up
444        };
445        let lower = if prev_source.is_finite() && prev_lower.is_finite() && prev_source > prev_lower
446        {
447            dn.max(prev_lower)
448        } else {
449            dn
450        };
451
452        let trend = if prev_upper.is_finite() && src > prev_upper {
453            1.0
454        } else if prev_lower.is_finite() && src < prev_lower {
455            0.0
456        } else {
457            prev_trend
458        };
459
460        let supertrend = trend * lower + (1.0 - trend) * upper;
461        let width = upper - lower;
462        let osc = if width.is_finite() && width != 0.0 {
463            clamp_unit((src - supertrend) / width)
464        } else {
465            0.0
466        };
467        let alpha = (osc * osc) / length_f64;
468        let ama = match ama_prev {
469            Some(prev) => prev + alpha * (osc - prev),
470            None => osc,
471        };
472        let diff = osc - ama;
473        let hist = match hist_prev {
474            Some(prev) => prev + hist_alpha * (diff - prev),
475            None => diff,
476        };
477
478        out_oscillator[i] = osc * OUTPUT_SCALE;
479        out_signal[i] = ama * OUTPUT_SCALE;
480        out_histogram[i] = hist * OUTPUT_SCALE;
481
482        prev_source = src;
483        prev_upper = upper;
484        prev_lower = lower;
485        prev_trend = trend;
486        ama_prev = Some(ama);
487        hist_prev = Some(hist);
488    }
489}
490
491fn supertrend_oscillator_prepare<'a>(
492    input: &'a SuperTrendOscillatorInput<'a>,
493    kernel: Kernel,
494) -> Result<
495    (
496        &'a [f64],
497        &'a [f64],
498        &'a [f64],
499        usize,
500        f64,
501        usize,
502        usize,
503        Vec<f64>,
504        Kernel,
505    ),
506    SuperTrendOscillatorError,
507> {
508    let (high, low, source) = input.as_refs();
509    validate_lengths(high, low, source)?;
510
511    let length = input.get_length();
512    let mult = input.get_mult();
513    let smooth = input.get_smooth();
514    validate_params(length, mult, smooth, source.len())?;
515
516    let first_valid =
517        first_valid_bar(high, low, source).ok_or(SuperTrendOscillatorError::AllValuesNaN)?;
518    let max_run = max_valid_run(high, low, source);
519    if max_run < length {
520        return Err(SuperTrendOscillatorError::NotEnoughValidData {
521            needed: length,
522            valid: max_run,
523        });
524    }
525
526    let atr_values = compute_atr_series(high, low, source, length);
527
528    Ok((
529        high,
530        low,
531        source,
532        length,
533        mult,
534        smooth,
535        first_valid,
536        atr_values,
537        normalized_kernel(kernel),
538    ))
539}
540
541#[inline]
542pub fn supertrend_oscillator(
543    input: &SuperTrendOscillatorInput,
544) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
545    supertrend_oscillator_with_kernel(input, Kernel::Auto)
546}
547
548#[inline]
549pub fn supertrend_oscillator_with_kernel(
550    input: &SuperTrendOscillatorInput,
551    kernel: Kernel,
552) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
553    let (high, low, source, length, mult, smooth, first_valid, atr_values, _chosen) =
554        supertrend_oscillator_prepare(input, kernel)?;
555
556    let len = source.len();
557    let warmup = warmup_end(first_valid, length);
558    let mut oscillator = alloc_with_nan_prefix(len, warmup);
559    let mut signal = alloc_with_nan_prefix(len, warmup);
560    let mut histogram = alloc_with_nan_prefix(len, warmup);
561
562    supertrend_oscillator_row_scalar(
563        high,
564        low,
565        source,
566        length,
567        mult,
568        smooth,
569        &atr_values,
570        &mut oscillator,
571        &mut signal,
572        &mut histogram,
573    );
574
575    Ok(SuperTrendOscillatorOutput {
576        oscillator,
577        signal,
578        histogram,
579    })
580}
581
582#[inline]
583pub fn supertrend_oscillator_into_slice(
584    out_oscillator: &mut [f64],
585    out_signal: &mut [f64],
586    out_histogram: &mut [f64],
587    input: &SuperTrendOscillatorInput,
588    kernel: Kernel,
589) -> Result<(), SuperTrendOscillatorError> {
590    let (high, low, source, length, mult, smooth, _first_valid, atr_values, _chosen) =
591        supertrend_oscillator_prepare(input, kernel)?;
592    let len = source.len();
593    if out_oscillator.len() != len || out_signal.len() != len || out_histogram.len() != len {
594        return Err(SuperTrendOscillatorError::OutputLengthMismatch {
595            expected: len,
596            got: out_oscillator
597                .len()
598                .max(out_signal.len())
599                .max(out_histogram.len()),
600        });
601    }
602
603    supertrend_oscillator_row_scalar(
604        high,
605        low,
606        source,
607        length,
608        mult,
609        smooth,
610        &atr_values,
611        out_oscillator,
612        out_signal,
613        out_histogram,
614    );
615    Ok(())
616}
617
618#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
619#[inline]
620pub fn supertrend_oscillator_into(
621    input: &SuperTrendOscillatorInput,
622    out_oscillator: &mut [f64],
623    out_signal: &mut [f64],
624    out_histogram: &mut [f64],
625) -> Result<(), SuperTrendOscillatorError> {
626    supertrend_oscillator_into_slice(
627        out_oscillator,
628        out_signal,
629        out_histogram,
630        input,
631        Kernel::Auto,
632    )
633}
634
635#[derive(Clone, Debug)]
636pub struct SuperTrendOscillatorStream {
637    length: usize,
638    mult: f64,
639    hist_alpha: f64,
640    atr_stream: AtrStream,
641    prev_source: f64,
642    prev_upper: f64,
643    prev_lower: f64,
644    prev_trend: f64,
645    ama_prev: Option<f64>,
646    hist_prev: Option<f64>,
647}
648
649impl SuperTrendOscillatorStream {
650    #[inline]
651    pub fn try_new(params: SuperTrendOscillatorParams) -> Result<Self, SuperTrendOscillatorError> {
652        let length = params.length.unwrap_or(DEFAULT_LENGTH);
653        let mult = params.mult.unwrap_or(DEFAULT_MULT);
654        let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
655        validate_params(length, mult, smooth, length)?;
656        Ok(Self {
657            length,
658            mult,
659            hist_alpha: 2.0 / (smooth as f64 + 1.0),
660            atr_stream: AtrStream::try_new(AtrParams {
661                length: Some(length),
662            })
663            .expect("validated length"),
664            prev_source: f64::NAN,
665            prev_upper: f64::NAN,
666            prev_lower: f64::NAN,
667            prev_trend: 0.0,
668            ama_prev: None,
669            hist_prev: None,
670        })
671    }
672
673    #[inline]
674    fn reset(&mut self) {
675        self.atr_stream = AtrStream::try_new(AtrParams {
676            length: Some(self.length),
677        })
678        .expect("validated length");
679        self.prev_source = f64::NAN;
680        self.prev_upper = f64::NAN;
681        self.prev_lower = f64::NAN;
682        self.prev_trend = 0.0;
683        self.ama_prev = None;
684        self.hist_prev = None;
685    }
686
687    #[inline]
688    pub fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64, f64)> {
689        if !valid_bar(high, low, source) {
690            self.reset();
691            return None;
692        }
693
694        let atr = match self.atr_stream.update(high, low, source) {
695            Some(value) => value,
696            None => {
697                self.prev_source = source;
698                return None;
699            }
700        };
701
702        let mid = 0.5 * (high + low);
703        let up = mid + atr * self.mult;
704        let dn = mid - atr * self.mult;
705
706        let upper = if self.prev_source.is_finite()
707            && self.prev_upper.is_finite()
708            && self.prev_source < self.prev_upper
709        {
710            up.min(self.prev_upper)
711        } else {
712            up
713        };
714        let lower = if self.prev_source.is_finite()
715            && self.prev_lower.is_finite()
716            && self.prev_source > self.prev_lower
717        {
718            dn.max(self.prev_lower)
719        } else {
720            dn
721        };
722
723        let trend = if self.prev_upper.is_finite() && source > self.prev_upper {
724            1.0
725        } else if self.prev_lower.is_finite() && source < self.prev_lower {
726            0.0
727        } else {
728            self.prev_trend
729        };
730
731        let supertrend = trend * lower + (1.0 - trend) * upper;
732        let width = upper - lower;
733        let osc = if width.is_finite() && width != 0.0 {
734            clamp_unit((source - supertrend) / width)
735        } else {
736            0.0
737        };
738        let alpha = (osc * osc) / self.length as f64;
739        let ama = match self.ama_prev {
740            Some(prev) => prev + alpha * (osc - prev),
741            None => osc,
742        };
743        let diff = osc - ama;
744        let hist = match self.hist_prev {
745            Some(prev) => prev + self.hist_alpha * (diff - prev),
746            None => diff,
747        };
748
749        self.prev_source = source;
750        self.prev_upper = upper;
751        self.prev_lower = lower;
752        self.prev_trend = trend;
753        self.ama_prev = Some(ama);
754        self.hist_prev = Some(hist);
755
756        Some((osc * OUTPUT_SCALE, ama * OUTPUT_SCALE, hist * OUTPUT_SCALE))
757    }
758}
759
760#[derive(Debug, Clone)]
761pub struct SuperTrendOscillatorBatchOutput {
762    pub oscillator: Vec<f64>,
763    pub signal: Vec<f64>,
764    pub histogram: Vec<f64>,
765    pub combos: Vec<SuperTrendOscillatorParams>,
766    pub rows: usize,
767    pub cols: usize,
768}
769
770impl SuperTrendOscillatorBatchOutput {
771    #[inline]
772    pub fn row_for_params(&self, params: &SuperTrendOscillatorParams) -> Option<usize> {
773        self.combos.iter().position(|p| {
774            p.length == params.length && p.mult == params.mult && p.smooth == params.smooth
775        })
776    }
777}
778
779#[derive(Debug, Clone)]
780pub struct SuperTrendOscillatorBatchRange {
781    pub length: (usize, usize, usize),
782    pub mult: (f64, f64, f64),
783    pub smooth: (usize, usize, usize),
784}
785
786impl Default for SuperTrendOscillatorBatchRange {
787    fn default() -> Self {
788        Self {
789            length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
790            mult: (DEFAULT_MULT, DEFAULT_MULT, 0.0),
791            smooth: (DEFAULT_SMOOTH, DEFAULT_SMOOTH, 0),
792        }
793    }
794}
795
796#[derive(Clone, Debug)]
797pub struct SuperTrendOscillatorBatchBuilder {
798    range: SuperTrendOscillatorBatchRange,
799    source: Option<String>,
800    kernel: Kernel,
801}
802
803impl Default for SuperTrendOscillatorBatchBuilder {
804    fn default() -> Self {
805        Self {
806            range: SuperTrendOscillatorBatchRange::default(),
807            source: None,
808            kernel: Kernel::Auto,
809        }
810    }
811}
812
813impl SuperTrendOscillatorBatchBuilder {
814    #[inline]
815    pub fn new() -> Self {
816        Self::default()
817    }
818
819    #[inline]
820    pub fn kernel(mut self, value: Kernel) -> Self {
821        self.kernel = value;
822        self
823    }
824
825    #[inline]
826    pub fn source<S: Into<String>>(mut self, value: S) -> Self {
827        self.source = Some(value.into());
828        self
829    }
830
831    #[inline]
832    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
833        self.range.length = (start, end, step);
834        self
835    }
836
837    #[inline]
838    pub fn length_static(mut self, value: usize) -> Self {
839        self.range.length = (value, value, 0);
840        self
841    }
842
843    #[inline]
844    pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
845        self.range.mult = (start, end, step);
846        self
847    }
848
849    #[inline]
850    pub fn mult_static(mut self, value: f64) -> Self {
851        self.range.mult = (value, value, 0.0);
852        self
853    }
854
855    #[inline]
856    pub fn smooth_range(mut self, start: usize, end: usize, step: usize) -> Self {
857        self.range.smooth = (start, end, step);
858        self
859    }
860
861    #[inline]
862    pub fn smooth_static(mut self, value: usize) -> Self {
863        self.range.smooth = (value, value, 0);
864        self
865    }
866
867    #[inline]
868    pub fn apply_candles(
869        self,
870        candles: &Candles,
871    ) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
872        let source = source_type(candles, self.source.as_deref().unwrap_or("close"));
873        supertrend_oscillator_batch_with_kernel(
874            candles.high.as_slice(),
875            candles.low.as_slice(),
876            source,
877            &self.range,
878            self.kernel,
879        )
880    }
881
882    #[inline]
883    pub fn apply_slices(
884        self,
885        high: &[f64],
886        low: &[f64],
887        source: &[f64],
888    ) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
889        supertrend_oscillator_batch_with_kernel(high, low, source, &self.range, self.kernel)
890    }
891}
892
893#[inline]
894pub fn expand_grid_supertrend_oscillator(
895    range: &SuperTrendOscillatorBatchRange,
896) -> Result<Vec<SuperTrendOscillatorParams>, SuperTrendOscillatorError> {
897    fn axis_usize(
898        (start, end, step): (usize, usize, usize),
899    ) -> Result<Vec<usize>, SuperTrendOscillatorError> {
900        if step == 0 || start == end {
901            return Ok(vec![start]);
902        }
903        if start <= end {
904            let mut out = Vec::new();
905            let mut x = start;
906            while x <= end {
907                out.push(x);
908                match x.checked_add(step.max(1)) {
909                    Some(next) if next > x => x = next,
910                    _ => break,
911                }
912            }
913            if out.is_empty() {
914                return Err(SuperTrendOscillatorError::InvalidRange {
915                    start: start.to_string(),
916                    end: end.to_string(),
917                    step: step.to_string(),
918                });
919            }
920            return Ok(out);
921        }
922
923        let mut out = Vec::new();
924        let mut x = start;
925        let step = step.max(1);
926        while x >= end {
927            out.push(x);
928            if x == end {
929                break;
930            }
931            let next = x.saturating_sub(step);
932            if next == x || next < end {
933                break;
934            }
935            x = next;
936        }
937        if out.is_empty() {
938            return Err(SuperTrendOscillatorError::InvalidRange {
939                start: start.to_string(),
940                end: end.to_string(),
941                step: step.to_string(),
942            });
943        }
944        Ok(out)
945    }
946
947    fn axis_f64(
948        (start, end, step): (f64, f64, f64),
949    ) -> Result<Vec<f64>, SuperTrendOscillatorError> {
950        if !start.is_finite() || !end.is_finite() || !step.is_finite() {
951            return Err(SuperTrendOscillatorError::InvalidFloatRange { start, end, step });
952        }
953        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
954            return Ok(vec![start]);
955        }
956        let step = step.abs();
957        let mut out = Vec::new();
958        if start <= end {
959            let mut x = start;
960            while x <= end + 1e-12 {
961                out.push(x);
962                x += step;
963            }
964        } else {
965            let mut x = start;
966            while x + 1e-12 >= end {
967                out.push(x);
968                x -= step;
969            }
970        }
971        if out.is_empty() {
972            return Err(SuperTrendOscillatorError::InvalidFloatRange { start, end, step });
973        }
974        Ok(out)
975    }
976
977    let lengths = axis_usize(range.length)?;
978    let mults = axis_f64(range.mult)?;
979    let smooths = axis_usize(range.smooth)?;
980
981    let cap = lengths
982        .len()
983        .checked_mul(mults.len())
984        .and_then(|value| value.checked_mul(smooths.len()))
985        .ok_or(SuperTrendOscillatorError::InvalidRange {
986            start: range.length.0.to_string(),
987            end: range.length.1.to_string(),
988            step: range.length.2.to_string(),
989        })?;
990
991    let mut out = Vec::with_capacity(cap);
992    for &length in &lengths {
993        for &mult in &mults {
994            for &smooth in &smooths {
995                out.push(SuperTrendOscillatorParams {
996                    length: Some(length),
997                    mult: Some(mult),
998                    smooth: Some(smooth),
999                });
1000            }
1001        }
1002    }
1003    Ok(out)
1004}
1005
1006#[inline]
1007pub fn supertrend_oscillator_batch_with_kernel(
1008    high: &[f64],
1009    low: &[f64],
1010    source: &[f64],
1011    sweep: &SuperTrendOscillatorBatchRange,
1012    kernel: Kernel,
1013) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1014    let batch_kernel = match kernel {
1015        Kernel::Auto => detect_best_batch_kernel(),
1016        other if other.is_batch() => other,
1017        other => return Err(SuperTrendOscillatorError::InvalidKernelForBatch(other)),
1018    };
1019    supertrend_oscillator_batch_par_slice(high, low, source, sweep, batch_kernel.to_non_batch())
1020}
1021
1022#[inline]
1023pub fn supertrend_oscillator_batch_slice(
1024    high: &[f64],
1025    low: &[f64],
1026    source: &[f64],
1027    sweep: &SuperTrendOscillatorBatchRange,
1028    kernel: Kernel,
1029) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1030    supertrend_oscillator_batch_inner(high, low, source, sweep, kernel, false)
1031}
1032
1033#[inline]
1034pub fn supertrend_oscillator_batch_par_slice(
1035    high: &[f64],
1036    low: &[f64],
1037    source: &[f64],
1038    sweep: &SuperTrendOscillatorBatchRange,
1039    kernel: Kernel,
1040) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1041    supertrend_oscillator_batch_inner(high, low, source, sweep, kernel, true)
1042}
1043
/// Allocating batch driver: evaluates every parameter combination in `sweep` and returns
/// row-major `rows x cols` matrices (one row per entry of `combos`, `cols == source.len()`).
///
/// Checks run in this order: `validate_lengths`, sweep expansion, first valid bar (input with
/// no valid bar yields `AllValuesNaN`), the longest valid run versus the largest requested
/// `length` (`NotEnoughValidData`), then `validate_params` for each combination.
///
/// `_kernel` is ignored; every row is computed by `supertrend_oscillator_row_scalar`. When
/// `parallel` is true and the target is not `wasm32`, rows are distributed with rayon;
/// otherwise they run sequentially.
fn supertrend_oscillator_batch_inner(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    sweep: &SuperTrendOscillatorBatchRange,
    _kernel: Kernel,
    parallel: bool,
) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
    validate_lengths(high, low, source)?;
    let combos = expand_grid_supertrend_oscillator(sweep)?;
    let first_valid =
        first_valid_bar(high, low, source).ok_or(SuperTrendOscillatorError::AllValuesNaN)?;
    // Presumably the longest contiguous stretch of usable bars; it must be at least as long
    // as the largest `length` in the sweep.
    let max_run = max_valid_run(high, low, source);
    let max_length = combos
        .iter()
        .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
        .max()
        .unwrap_or(DEFAULT_LENGTH);
    if max_run < max_length {
        return Err(SuperTrendOscillatorError::NotEnoughValidData {
            needed: max_length,
            valid: max_run,
        });
    }
    for params in &combos {
        validate_params(
            params.length.unwrap_or(DEFAULT_LENGTH),
            params.mult.unwrap_or(DEFAULT_MULT),
            params.smooth.unwrap_or(DEFAULT_SMOOTH),
            source.len(),
        )?;
    }

    let rows = combos.len();
    let cols = source.len();
    // Overflow of rows * cols is reported through OutputLengthMismatch with a sentinel
    // `expected: usize::MAX`.
    let total = rows
        .checked_mul(cols)
        .ok_or(SuperTrendOscillatorError::OutputLengthMismatch {
            expected: usize::MAX,
            got: 0,
        })?;

    // Uninitialized row-major buffers; only the per-row warmup prefixes are initialized
    // below, and the row kernel is relied upon to fill the rest.
    let mut oscillator_matrix = make_uninit_matrix(rows, cols);
    let mut signal_matrix = make_uninit_matrix(rows, cols);
    let mut histogram_matrix = make_uninit_matrix(rows, cols);

    // One warmup boundary per row, derived from the first valid bar and that row's length.
    let warmups: Vec<usize> = combos
        .iter()
        .map(|params| warmup_end(first_valid, params.length.unwrap_or(DEFAULT_LENGTH)))
        .collect();
    init_matrix_prefixes(&mut oscillator_matrix, cols, &warmups);
    init_matrix_prefixes(&mut signal_matrix, cols, &warmups);
    init_matrix_prefixes(&mut histogram_matrix, cols, &warmups);

    // ManuallyDrop keeps the original Vecs from being freed when they go out of scope; the
    // buffers are re-adopted by `Vec::from_raw_parts` at the end, so without this they would
    // be freed twice. If a row panics, the buffers leak instead.
    let mut oscillator_guard = ManuallyDrop::new(oscillator_matrix);
    let mut signal_guard = ManuallyDrop::new(signal_matrix);
    let mut histogram_guard = ManuallyDrop::new(histogram_matrix);

    // SAFETY: each view covers exactly its guard's allocation and element type
    // `MaybeUninit<f64>`, which has no initialization requirement.
    let oscillator_mu: &mut [MaybeUninit<f64>] = unsafe {
        std::slice::from_raw_parts_mut(oscillator_guard.as_mut_ptr(), oscillator_guard.len())
    };
    let signal_mu: &mut [MaybeUninit<f64>] =
        unsafe { std::slice::from_raw_parts_mut(signal_guard.as_mut_ptr(), signal_guard.len()) };
    let histogram_mu: &mut [MaybeUninit<f64>] = unsafe {
        std::slice::from_raw_parts_mut(histogram_guard.as_mut_ptr(), histogram_guard.len())
    };

    // ATR depends only on `length`, so compute it once per distinct length and share it
    // across every (mult, smooth) combination with that length.
    let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
    let mut lengths: Vec<usize> = combos
        .iter()
        .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
        .collect();
    lengths.sort_unstable();
    lengths.dedup();
    for length in lengths {
        atr_cache.insert(length, compute_atr_series(high, low, source, length));
    }

    // Computes one row (one parameter combination) into its three `cols`-long chunks.
    let do_row = |row: usize,
                  row_oscillator: &mut [MaybeUninit<f64>],
                  row_signal: &mut [MaybeUninit<f64>],
                  row_histogram: &mut [MaybeUninit<f64>]| {
        let params = &combos[row];
        let length = params.length.unwrap_or(DEFAULT_LENGTH);
        let mult = params.mult.unwrap_or(DEFAULT_MULT);
        let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
        let atr_values = atr_cache.get(&length).expect("cached atr");

        // NOTE(review): `slice::from_raw_parts_mut` documents that the pointee must be
        // initialized, but only the warmup prefix is initialized at this point. This relies
        // on `supertrend_oscillator_row_scalar` writing every element before reading any.
        let dst_oscillator = unsafe {
            std::slice::from_raw_parts_mut(row_oscillator.as_mut_ptr() as *mut f64, cols)
        };
        let dst_signal =
            unsafe { std::slice::from_raw_parts_mut(row_signal.as_mut_ptr() as *mut f64, cols) };
        let dst_histogram =
            unsafe { std::slice::from_raw_parts_mut(row_histogram.as_mut_ptr() as *mut f64, cols) };

        supertrend_oscillator_row_scalar(
            high,
            low,
            source,
            length,
            mult,
            smooth,
            atr_values,
            dst_oscillator,
            dst_signal,
            dst_histogram,
        );
    };

    if parallel {
        // Native targets: rows are disjoint chunks, so they can be filled concurrently.
        #[cfg(not(target_arch = "wasm32"))]
        oscillator_mu
            .par_chunks_mut(cols)
            .zip(signal_mu.par_chunks_mut(cols))
            .zip(histogram_mu.par_chunks_mut(cols))
            .enumerate()
            .for_each(|(row, ((row_oscillator, row_signal), row_histogram))| {
                do_row(row, row_oscillator, row_signal, row_histogram)
            });

        // wasm32 has no rayon here, so the "parallel" request falls back to a serial loop.
        #[cfg(target_arch = "wasm32")]
        for (row, ((row_oscillator, row_signal), row_histogram)) in oscillator_mu
            .chunks_mut(cols)
            .zip(signal_mu.chunks_mut(cols))
            .zip(histogram_mu.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, row_oscillator, row_signal, row_histogram);
        }
    } else {
        for (row, ((row_oscillator, row_signal), row_histogram)) in oscillator_mu
            .chunks_mut(cols)
            .zip(signal_mu.chunks_mut(cols))
            .zip(histogram_mu.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, row_oscillator, row_signal, row_histogram);
        }
    }

    // SAFETY: each guard still owns its original allocation (ManuallyDrop never freed it),
    // `total == rows * cols` elements were allocated, and every element is assumed written:
    // prefixes by `init_matrix_prefixes`, the rest by the row kernel (TODO confirm the kernel
    // writes all `cols` entries). `MaybeUninit<f64>` and `f64` share layout.
    let oscillator = unsafe {
        Vec::from_raw_parts(
            oscillator_guard.as_mut_ptr() as *mut f64,
            total,
            oscillator_guard.capacity(),
        )
    };
    let signal = unsafe {
        Vec::from_raw_parts(
            signal_guard.as_mut_ptr() as *mut f64,
            total,
            signal_guard.capacity(),
        )
    };
    let histogram = unsafe {
        Vec::from_raw_parts(
            histogram_guard.as_mut_ptr() as *mut f64,
            total,
            histogram_guard.capacity(),
        )
    };

    Ok(SuperTrendOscillatorBatchOutput {
        oscillator,
        signal,
        histogram,
        combos,
        rows,
        cols,
    })
}
1216
1217fn supertrend_oscillator_batch_inner_into(
1218    high: &[f64],
1219    low: &[f64],
1220    source: &[f64],
1221    sweep: &SuperTrendOscillatorBatchRange,
1222    kernel: Kernel,
1223    parallel: bool,
1224    out_oscillator: &mut [f64],
1225    out_signal: &mut [f64],
1226    out_histogram: &mut [f64],
1227) -> Result<Vec<SuperTrendOscillatorParams>, SuperTrendOscillatorError> {
1228    validate_lengths(high, low, source)?;
1229    let combos = expand_grid_supertrend_oscillator(sweep)?;
1230    let max_run = max_valid_run(high, low, source);
1231    let max_length = combos
1232        .iter()
1233        .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
1234        .max()
1235        .unwrap_or(DEFAULT_LENGTH);
1236    if max_run < max_length {
1237        return Err(SuperTrendOscillatorError::NotEnoughValidData {
1238            needed: max_length,
1239            valid: max_run,
1240        });
1241    }
1242
1243    let rows = combos.len();
1244    let cols = source.len();
1245    let total = rows
1246        .checked_mul(cols)
1247        .ok_or(SuperTrendOscillatorError::OutputLengthMismatch {
1248            expected: usize::MAX,
1249            got: 0,
1250        })?;
1251    if out_oscillator.len() != total || out_signal.len() != total || out_histogram.len() != total {
1252        return Err(SuperTrendOscillatorError::OutputLengthMismatch {
1253            expected: total,
1254            got: out_oscillator
1255                .len()
1256                .max(out_signal.len())
1257                .max(out_histogram.len()),
1258        });
1259    }
1260
1261    let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1262    for params in &combos {
1263        let length = params.length.unwrap_or(DEFAULT_LENGTH);
1264        validate_params(
1265            length,
1266            params.mult.unwrap_or(DEFAULT_MULT),
1267            params.smooth.unwrap_or(DEFAULT_SMOOTH),
1268            cols,
1269        )?;
1270        atr_cache
1271            .entry(length)
1272            .or_insert_with(|| compute_atr_series(high, low, source, length));
1273    }
1274
1275    let _ = kernel;
1276    let do_row = |row: usize,
1277                  dst_oscillator: &mut [f64],
1278                  dst_signal: &mut [f64],
1279                  dst_histogram: &mut [f64]| {
1280        let params = &combos[row];
1281        let length = params.length.unwrap_or(DEFAULT_LENGTH);
1282        let mult = params.mult.unwrap_or(DEFAULT_MULT);
1283        let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
1284        let atr_values = atr_cache.get(&length).expect("cached atr");
1285
1286        supertrend_oscillator_row_scalar(
1287            high,
1288            low,
1289            source,
1290            length,
1291            mult,
1292            smooth,
1293            atr_values,
1294            dst_oscillator,
1295            dst_signal,
1296            dst_histogram,
1297        );
1298    };
1299
1300    if parallel {
1301        #[cfg(not(target_arch = "wasm32"))]
1302        out_oscillator
1303            .par_chunks_mut(cols)
1304            .zip(out_signal.par_chunks_mut(cols))
1305            .zip(out_histogram.par_chunks_mut(cols))
1306            .enumerate()
1307            .for_each(|(row, ((dst_oscillator, dst_signal), dst_histogram))| {
1308                do_row(row, dst_oscillator, dst_signal, dst_histogram)
1309            });
1310
1311        #[cfg(target_arch = "wasm32")]
1312        for (row, ((dst_oscillator, dst_signal), dst_histogram)) in out_oscillator
1313            .chunks_mut(cols)
1314            .zip(out_signal.chunks_mut(cols))
1315            .zip(out_histogram.chunks_mut(cols))
1316            .enumerate()
1317        {
1318            do_row(row, dst_oscillator, dst_signal, dst_histogram);
1319        }
1320    } else {
1321        for (row, ((dst_oscillator, dst_signal), dst_histogram)) in out_oscillator
1322            .chunks_mut(cols)
1323            .zip(out_signal.chunks_mut(cols))
1324            .zip(out_histogram.chunks_mut(cols))
1325            .enumerate()
1326        {
1327            do_row(row, dst_oscillator, dst_signal, dst_histogram);
1328        }
1329    }
1330
1331    Ok(combos)
1332}
1333
/// Python entry point: computes the SuperTrend Oscillator for one parameter set.
///
/// Returns `(oscillator, signal, histogram)` as NumPy arrays. The computation runs with the
/// GIL released; indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_oscillator")]
#[pyo3(signature = (high, low, source, length=DEFAULT_LENGTH, mult=DEFAULT_MULT, smooth=DEFAULT_SMOOTH, kernel=None))]
pub fn supertrend_oscillator_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    source: PyReadonlyArray1<'py, f64>,
    length: usize,
    mult: f64,
    smooth: usize,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let (high_values, low_values, source_values) =
        (high.as_slice()?, low.as_slice()?, source.as_slice()?);
    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input =
        SuperTrendOscillatorInput::from_slices(high_values, low_values, source_values, params);
    let chosen_kernel = validate_kernel(kernel, false)?;

    let outcome = py.allow_threads(|| supertrend_oscillator_with_kernel(&input, chosen_kernel));
    let result = outcome.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let oscillator = result.oscillator.into_pyarray(py);
    let signal = result.signal.into_pyarray(py);
    let histogram = result.histogram.into_pyarray(py);
    Ok((oscillator, signal, histogram))
}
1374
/// Python-facing streaming SuperTrend Oscillator, exposed as `SuperTrendOscillatorStream`.
///
/// Feed bars one at a time through `update`. It returns `(oscillator, signal, histogram)`
/// once the wrapped stream yields a value, and `None` otherwise.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendOscillatorStream")]
pub struct SuperTrendOscillatorStreamPy {
    stream: SuperTrendOscillatorStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendOscillatorStreamPy {
    /// Builds a stream from `length`, `mult` and `smooth` (module defaults when omitted).
    /// Invalid parameters raise `ValueError`.
    #[new]
    #[pyo3(signature = (length=DEFAULT_LENGTH, mult=DEFAULT_MULT, smooth=DEFAULT_SMOOTH))]
    fn new(length: usize, mult: f64, smooth: usize) -> PyResult<Self> {
        let params = SuperTrendOscillatorParams {
            length: Some(length),
            mult: Some(mult),
            smooth: Some(smooth),
        };
        match SuperTrendOscillatorStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Pushes one `(high, low, source)` bar and returns the latest outputs, if any.
    fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, source)
    }
}
1400
/// Python entry point for a SuperTrend Oscillator parameter sweep.
///
/// Each `*_range` is a `(start, end, step)` tuple expanded by
/// `expand_grid_supertrend_oscillator`. The returned dict contains:
/// - `oscillator`, `signal`, `histogram` as `(rows, cols)` arrays, one row per combination;
/// - `rows` and `cols`;
/// - per-row `lengths`, `mults` and `smooths`.
///
/// Rows are computed with the GIL released. Sweep and computation errors are raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_oscillator_batch")]
#[pyo3(signature = (high, low, source, length_range, mult_range, smooth_range, kernel=None))]
pub fn supertrend_oscillator_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    source: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    smooth_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_values = high.as_slice()?;
    let low_values = low.as_slice()?;
    let source_values = source.as_slice()?;

    let sweep = SuperTrendOscillatorBatchRange {
        length: length_range,
        mult: mult_range,
        smooth: smooth_range,
    };
    let combos = match expand_grid_supertrend_oscillator(&sweep) {
        Ok(combos) => combos,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let rows = combos.len();
    let cols = source_values.len();
    let total = match rows.checked_mul(cols) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // Uninitialized NumPy storage, filled by `supertrend_oscillator_batch_inner_into` below.
    let fresh_array = || unsafe { PyArray1::<f64>::new(py, [total], false) };
    let oscillator_arr = fresh_array();
    let signal_arr = fresh_array();
    let histogram_arr = fresh_array();
    let (out_oscillator, out_signal, out_histogram) = unsafe {
        (
            oscillator_arr.as_slice_mut()?,
            signal_arr.as_slice_mut()?,
            histogram_arr.as_slice_mut()?,
        )
    };
    let requested = validate_kernel(kernel, true)?;

    let outcome = py.allow_threads(|| {
        let batch_kernel = if matches!(requested, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            requested
        };
        supertrend_oscillator_batch_inner_into(
            high_values,
            low_values,
            source_values,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            out_oscillator,
            out_signal,
            out_histogram,
        )
    });
    outcome.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let mut lengths: Vec<usize> = Vec::with_capacity(rows);
    let mut mults: Vec<f64> = Vec::with_capacity(rows);
    let mut smooths: Vec<usize> = Vec::with_capacity(rows);
    for params in &combos {
        lengths.push(params.length.unwrap_or(DEFAULT_LENGTH));
        mults.push(params.mult.unwrap_or(DEFAULT_MULT));
        smooths.push(params.smooth.unwrap_or(DEFAULT_SMOOTH));
    }

    let shape = (rows, cols);
    let dict = PyDict::new(py);
    dict.set_item("oscillator", oscillator_arr.reshape(shape)?)?;
    dict.set_item("signal", signal_arr.reshape(shape)?)?;
    dict.set_item("histogram", histogram_arr.reshape(shape)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("mults", mults.into_pyarray(py))?;
    dict.set_item("smooths", smooths.into_pyarray(py))?;
    Ok(dict)
}
1480
/// Registers the SuperTrend Oscillator Python bindings on module `m`: the single-series
/// function, the batch sweep function, and the streaming class.
#[cfg(feature = "python")]
pub fn register_supertrend_oscillator_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(supertrend_oscillator_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(supertrend_oscillator_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<SuperTrendOscillatorStreamPy>()
}
1488
/// Plain-object payload returned to JavaScript by `supertrend_oscillator_js`.
/// It is serialized with `serde_wasm_bindgen`, so the field names are the JS property names.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorJsOutput {
    // Oscillator series, one value per input bar.
    oscillator: Vec<f64>,
    // Signal series, one value per input bar.
    signal: Vec<f64>,
    // Histogram series, one value per input bar.
    histogram: Vec<f64>,
}
1496
/// Sweep configuration deserialized from the JS `config` argument of
/// `supertrend_oscillator_batch_js`.
/// Each range must contain exactly three elements, `[start, end, step]`; any other length is
/// rejected by that function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorBatchConfig {
    // `[start, end, step]` for the ATR length.
    length_range: Vec<usize>,
    // `[start, end, step]` for the band multiplier.
    mult_range: Vec<f64>,
    // `[start, end, step]` for the smoothing length.
    smooth_range: Vec<usize>,
}
1504
/// Batch payload returned to JavaScript by `supertrend_oscillator_batch_js`.
/// The three matrices are flattened row-major, with `rows * cols` values each: row `i`
/// holds the results for `combos[i]`, and each row has `cols` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorBatchJsOutput {
    oscillator: Vec<f64>,
    signal: Vec<f64>,
    histogram: Vec<f64>,
    // Number of parameter combinations (matrix rows).
    rows: usize,
    // Number of input bars (matrix columns).
    cols: usize,
    // Parameter set for each row, in row order.
    combos: Vec<SuperTrendOscillatorParams>,
}
1515
/// JavaScript entry point: computes the SuperTrend Oscillator for one parameter set and
/// returns `{ oscillator, signal, histogram }` as a plain JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator")]
pub fn supertrend_oscillator_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<JsValue, JsValue> {
    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);
    let computed = match supertrend_oscillator(&input) {
        Ok(computed) => computed,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let payload = SuperTrendOscillatorJsOutput {
        oscillator: computed.oscillator,
        signal: computed.signal,
        histogram: computed.histogram,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1544
/// Computes the SuperTrend Oscillator from raw WASM-memory pointers.
///
/// `high_ptr`, `low_ptr` and `source_ptr` must each address `len` readable `f64`s, and
/// `out_ptr` must address `3 * len` writable `f64`s (for example a buffer from
/// `supertrend_oscillator_alloc(len)`). The output is packed as
/// `[oscillator | signal | histogram]`, each `len` long.
///
/// If the output buffer overlaps any input (in-place use), the result is first computed into
/// a scratch buffer and then copied out. This avoids live `&[f64]` and `&mut [f64]` slices
/// over the same memory.
///
/// # Errors
///
/// Fails on null pointers, when `3 * len` overflows `usize`, or when the indicator rejects
/// the input or parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || source_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_into",
        ));
    }
    let total = len
        .checked_mul(3)
        .ok_or_else(|| JsValue::from_str("len*3 overflow in supertrend_oscillator_into"))?;

    // Compare byte address ranges [start, end); empty ranges never overlap.
    let elem = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps_output = |ptr: *const f64| {
        let start = ptr as usize;
        let end = start.saturating_add(len.saturating_mul(elem));
        start < out_end && out_start < end
    };
    let in_place =
        overlaps_output(high_ptr) || overlaps_output(low_ptr) || overlaps_output(source_ptr);

    // Runs the indicator into the three given output slices.
    let compute = |out_oscillator: &mut [f64],
                   out_signal: &mut [f64],
                   out_histogram: &mut [f64]|
     -> Result<(), JsValue> {
        // SAFETY: the input pointers are non-null (checked above), and the caller guarantees
        // each addresses `len` initialized `f64`s. When an input overlaps the caller's output
        // buffer, the slices passed here point into a separate scratch buffer instead.
        let (high, low, source) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(source_ptr, len),
            )
        };
        let input = SuperTrendOscillatorInput::from_slices(
            high,
            low,
            source,
            SuperTrendOscillatorParams {
                length: Some(length),
                mult: Some(mult),
                smooth: Some(smooth),
            },
        );
        supertrend_oscillator_into_slice(
            out_oscillator,
            out_signal,
            out_histogram,
            &input,
            Kernel::Auto,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))
    };

    if in_place {
        let mut scratch = vec![0.0_f64; total];
        {
            let (scratch_oscillator, rest) = scratch.split_at_mut(len);
            let (scratch_signal, scratch_histogram) = rest.split_at_mut(len);
            compute(scratch_oscillator, scratch_signal, scratch_histogram)?;
        }
        // SAFETY: `out_ptr` is non-null and the caller guarantees `total` writable `f64`s.
        // The input slices created inside `compute` are no longer live.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
        Ok(())
    } else {
        // SAFETY: `out_ptr` is non-null and addresses `total` writable `f64`s, and it was
        // checked above not to overlap any input range, so this `&mut` aliases nothing.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        let (out_oscillator, rest) = out.split_at_mut(len);
        let (out_signal, out_histogram) = rest.split_at_mut(len);
        compute(out_oscillator, out_signal, out_histogram)
    }
}
1590
/// Computes the SuperTrend Oscillator from JS-provided slices into a packed output buffer.
///
/// `out_ptr` must address `3 * source.len()` writable `f64`s, laid out as
/// `[oscillator | signal | histogram]`. Returns an error for a null `out_ptr` or when the
/// indicator rejects the input or parameters.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator_into_host")]
pub fn supertrend_oscillator_into_host(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    out_ptr: *mut f64,
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_into_host",
        ));
    }

    let n = source.len();
    // SAFETY: `out_ptr` is non-null and the caller guarantees `3 * n` writable `f64`s.
    let packed = unsafe { std::slice::from_raw_parts_mut(out_ptr, n * 3) };
    let (out_oscillator, tail) = packed.split_at_mut(n);
    let (out_signal, out_histogram) = tail.split_at_mut(n);

    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);
    supertrend_oscillator_into_slice(
        out_oscillator,
        out_signal,
        out_histogram,
        &input,
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
1632
/// Allocates a zero-filled buffer of `3 * len` `f64`s in WASM memory and leaks it to the
/// caller. Release it with `supertrend_oscillator_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_alloc(len: usize) -> *mut f64 {
    let mut storage = ManuallyDrop::new(vec![0.0_f64; len * 3]);
    storage.as_mut_ptr()
}
1641
/// Frees a buffer obtained from `supertrend_oscillator_alloc(len)`; a null `ptr` is ignored.
/// `len` must be the value passed to the matching allocation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        let capacity = len * 3;
        // SAFETY: `ptr` came from `supertrend_oscillator_alloc(len)`, which leaked a Vec with
        // exactly this length and capacity.
        drop(unsafe { Vec::from_raw_parts(ptr, capacity, capacity) });
    }
}
1652
/// JavaScript entry point for a parameter sweep.
///
/// `config` must deserialize to `{ length_range, mult_range, smooth_range }`, each exactly
/// `[start, end, step]`. Returns `{ oscillator, signal, histogram, rows, cols, combos }` with
/// the matrices flattened row-major. Rows are computed sequentially with the scalar kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator_batch")]
pub fn supertrend_oscillator_batch_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: SuperTrendOscillatorBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = match (
        config.length_range.as_slice(),
        config.mult_range.as_slice(),
        config.smooth_range.as_slice(),
    ) {
        (
            &[length_start, length_end, length_step],
            &[mult_start, mult_end, mult_step],
            &[smooth_start, smooth_end, smooth_step],
        ) => SuperTrendOscillatorBatchRange {
            length: (length_start, length_end, length_step),
            mult: (mult_start, mult_end, mult_step),
            smooth: (smooth_start, smooth_end, smooth_step),
        },
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: ranges must have exactly 3 elements [start, end, step]",
            ));
        }
    };

    let batch = match supertrend_oscillator_batch_slice(high, low, source, &sweep, Kernel::Scalar)
    {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let payload = SuperTrendOscillatorBatchJsOutput {
        oscillator: batch.oscillator,
        signal: batch.signal,
        histogram: batch.histogram,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1701
/// Computes a parameter sweep from raw WASM-memory pointers into three caller-provided,
/// row-major output matrices, and returns the number of rows (parameter combinations).
///
/// Each input pointer must address `len` readable `f64`s. Each output pointer must address
/// `rows * len` writable `f64`s, where `rows` is the number of combinations produced by the
/// `(start, end, step)` ranges. Rows are computed sequentially with the scalar kernel.
///
/// The three output buffers must not overlap each other. If an output overlaps an input
/// (in-place use), the sweep is computed into scratch buffers and then copied out. This
/// avoids live `&[f64]` and `&mut [f64]` slices over the same memory.
///
/// # Errors
///
/// Fails on null pointers, an invalid sweep, `rows * len` overflow, overlapping output
/// buffers, or any error reported by the batch driver.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    oscillator_ptr: *mut f64,
    signal_ptr: *mut f64,
    histogram_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    smooth_start: usize,
    smooth_end: usize,
    smooth_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || source_ptr.is_null()
        || oscillator_ptr.is_null()
        || signal_ptr.is_null()
        || histogram_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_batch_into",
        ));
    }

    let sweep = SuperTrendOscillatorBatchRange {
        length: (length_start, length_end, length_step),
        mult: (mult_start, mult_end, mult_step),
        smooth: (smooth_start, smooth_end, smooth_step),
    };
    let rows = expand_grid_supertrend_oscillator(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte address ranges [start, end); empty ranges never overlap anything.
    let elem = std::mem::size_of::<f64>();
    let span = |ptr: *const f64, count: usize| {
        let start = ptr as usize;
        (start, start.saturating_add(count.saturating_mul(elem)))
    };
    let overlaps = |a: (usize, usize), b: (usize, usize)| a.0 < b.1 && b.0 < a.1;

    let outputs = [
        span(oscillator_ptr as *const f64, total),
        span(signal_ptr as *const f64, total),
        span(histogram_ptr as *const f64, total),
    ];
    if overlaps(outputs[0], outputs[1])
        || overlaps(outputs[0], outputs[2])
        || overlaps(outputs[1], outputs[2])
    {
        return Err(JsValue::from_str(
            "overlapping output buffers passed to supertrend_oscillator_batch_into",
        ));
    }
    let inputs = [span(high_ptr, len), span(low_ptr, len), span(source_ptr, len)];
    let in_place = inputs
        .iter()
        .any(|&input| outputs.iter().any(|&output| overlaps(input, output)));

    // Runs the sweep into the three given output matrices.
    let compute = |oscillator: &mut [f64],
                   signal: &mut [f64],
                   histogram: &mut [f64]|
     -> Result<(), JsValue> {
        // SAFETY: the input pointers are non-null (checked above), and the caller guarantees
        // each addresses `len` initialized `f64`s. When an input overlaps an output, the
        // slices passed here point into separate scratch buffers instead.
        let (high, low, source) = unsafe {
            (
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(source_ptr, len),
            )
        };
        supertrend_oscillator_batch_inner_into(
            high,
            low,
            source,
            &sweep,
            Kernel::Scalar,
            false,
            oscillator,
            signal,
            histogram,
        )
        .map(|_| ())
        .map_err(|e| JsValue::from_str(&e.to_string()))
    };

    if in_place {
        let mut oscillator = vec![0.0_f64; total];
        let mut signal = vec![0.0_f64; total];
        let mut histogram = vec![0.0_f64; total];
        compute(
            oscillator.as_mut_slice(),
            signal.as_mut_slice(),
            histogram.as_mut_slice(),
        )?;
        // SAFETY: each output pointer is non-null and addresses `total` writable `f64`s, the
        // outputs do not overlap each other, and no input slice is live any more.
        unsafe {
            std::slice::from_raw_parts_mut(oscillator_ptr, total).copy_from_slice(&oscillator);
            std::slice::from_raw_parts_mut(signal_ptr, total).copy_from_slice(&signal);
            std::slice::from_raw_parts_mut(histogram_ptr, total).copy_from_slice(&histogram);
        }
    } else {
        // SAFETY: each output pointer is non-null and addresses `total` writable `f64`s. The
        // outputs were checked above to be disjoint from each other and from every input.
        let (oscillator, signal, histogram) = unsafe {
            (
                std::slice::from_raw_parts_mut(oscillator_ptr, total),
                std::slice::from_raw_parts_mut(signal_ptr, total),
                std::slice::from_raw_parts_mut(histogram_ptr, total),
            )
        };
        compute(oscillator, signal, histogram)?;
    }
    Ok(rows)
}
1767
1768#[cfg(test)]
1769mod tests {
1770    use super::*;
1771    use crate::indicators::dispatch::{
1772        compute_cpu_batch, IndicatorBatchRequest, IndicatorDataRef, IndicatorParamSet, ParamKV,
1773        ParamValue,
1774    };
1775
1776    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
1777        assert_eq!(a.len(), b.len());
1778        for i in 0..a.len() {
1779            let lhs = a[i];
1780            let rhs = b[i];
1781            if lhs.is_nan() || rhs.is_nan() {
1782                assert!(
1783                    lhs.is_nan() && rhs.is_nan(),
1784                    "nan mismatch at {i}: {lhs} vs {rhs}"
1785                );
1786            } else {
1787                assert!(
1788                    (lhs - rhs).abs() <= tol,
1789                    "mismatch at {i}: {lhs} vs {rhs} with tol {tol}"
1790                );
1791            }
1792        }
1793    }
1794
1795    fn sample_hls(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
1796        let mut high = Vec::with_capacity(len);
1797        let mut low = Vec::with_capacity(len);
1798        let mut source = Vec::with_capacity(len);
1799
1800        for i in 0..len {
1801            let base = 100.0 + i as f64 * 0.21 + (i as f64 * 0.17).sin() * 2.0;
1802            let spread = 1.5 + (i as f64 * 0.11).cos().abs() * 1.25;
1803            let src = base + (i as f64 * 0.07).cos() * 0.6;
1804            high.push(base + spread);
1805            low.push(base - spread);
1806            source.push(src);
1807        }
1808
1809        (high, low, source)
1810    }
1811
1812    fn check_output_contract(kernel: Kernel) {
1813        let (high, low, source) = sample_hls(192);
1814        let input = SuperTrendOscillatorInput::from_slices(
1815            &high,
1816            &low,
1817            &source,
1818            SuperTrendOscillatorParams {
1819                length: Some(10),
1820                mult: Some(2.0),
1821                smooth: Some(72),
1822            },
1823        );
1824        let out = supertrend_oscillator_with_kernel(&input, kernel).expect("indicator");
1825        assert_eq!(out.oscillator.len(), source.len());
1826        assert_eq!(out.signal.len(), source.len());
1827        assert_eq!(out.histogram.len(), source.len());
1828        assert!(out.oscillator[..9].iter().all(|v| v.is_nan()));
1829        assert!(out.signal[..9].iter().all(|v| v.is_nan()));
1830        assert!(out.histogram[..9].iter().all(|v| v.is_nan()));
1831        assert!(out.oscillator[9..].iter().any(|v| v.is_finite()));
1832        assert!(out.signal[9..].iter().any(|v| v.is_finite()));
1833        assert!(out.histogram[9..].iter().any(|v| v.is_finite()));
1834    }
1835
1836    fn check_into_matches_api(kernel: Kernel) {
1837        let (high, low, source) = sample_hls(224);
1838        let input = SuperTrendOscillatorInput::from_slices(
1839            &high,
1840            &low,
1841            &source,
1842            SuperTrendOscillatorParams {
1843                length: Some(11),
1844                mult: Some(2.5),
1845                smooth: Some(20),
1846            },
1847        );
1848        let baseline = supertrend_oscillator_with_kernel(&input, kernel).expect("baseline");
1849        let mut oscillator = vec![0.0; source.len()];
1850        let mut signal = vec![0.0; source.len()];
1851        let mut histogram = vec![0.0; source.len()];
1852        supertrend_oscillator_into_slice(
1853            &mut oscillator,
1854            &mut signal,
1855            &mut histogram,
1856            &input,
1857            kernel,
1858        )
1859        .expect("into");
1860
1861        assert_close(&baseline.oscillator, &oscillator, 1e-12);
1862        assert_close(&baseline.signal, &signal, 1e-12);
1863        assert_close(&baseline.histogram, &histogram, 1e-12);
1864    }
1865
1866    fn check_stream_matches_batch() {
1867        let (high, low, source) = sample_hls(200);
1868        let input = SuperTrendOscillatorInput::from_slices(
1869            &high,
1870            &low,
1871            &source,
1872            SuperTrendOscillatorParams {
1873                length: Some(12),
1874                mult: Some(1.75),
1875                smooth: Some(18),
1876            },
1877        );
1878        let batch = supertrend_oscillator(&input).expect("batch");
1879        let mut stream = SuperTrendOscillatorStream::try_new(SuperTrendOscillatorParams {
1880            length: Some(12),
1881            mult: Some(1.75),
1882            smooth: Some(18),
1883        })
1884        .expect("stream");
1885
1886        let mut oscillator = vec![f64::NAN; source.len()];
1887        let mut signal = vec![f64::NAN; source.len()];
1888        let mut histogram = vec![f64::NAN; source.len()];
1889        for i in 0..source.len() {
1890            if let Some((osc, sig, hist)) = stream.update(high[i], low[i], source[i]) {
1891                oscillator[i] = osc;
1892                signal[i] = sig;
1893                histogram[i] = hist;
1894            }
1895        }
1896
1897        assert_close(&batch.oscillator, &oscillator, 1e-12);
1898        assert_close(&batch.signal, &signal, 1e-12);
1899        assert_close(&batch.histogram, &histogram, 1e-12);
1900    }
1901
1902    fn check_batch_single_matches_single(kernel: Kernel) {
1903        let (high, low, source) = sample_hls(180);
1904        let batch = supertrend_oscillator_batch_with_kernel(
1905            &high,
1906            &low,
1907            &source,
1908            &SuperTrendOscillatorBatchRange {
1909                length: (12, 12, 0),
1910                mult: (2.5, 2.5, 0.0),
1911                smooth: (18, 18, 0),
1912            },
1913            kernel,
1914        )
1915        .expect("batch");
1916        let single = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
1917            &high,
1918            &low,
1919            &source,
1920            SuperTrendOscillatorParams {
1921                length: Some(12),
1922                mult: Some(2.5),
1923                smooth: Some(18),
1924            },
1925        ))
1926        .expect("single");
1927
1928        assert_eq!(batch.rows, 1);
1929        assert_eq!(batch.cols, source.len());
1930        assert_close(&batch.oscillator[..source.len()], &single.oscillator, 1e-12);
1931        assert_close(&batch.signal[..source.len()], &single.signal, 1e-12);
1932        assert_close(&batch.histogram[..source.len()], &single.histogram, 1e-12);
1933    }
1934
1935    #[test]
1936    fn supertrend_oscillator_invalid_params() {
1937        let (high, low, source) = sample_hls(64);
1938
1939        let err = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
1940            &high,
1941            &low,
1942            &source,
1943            SuperTrendOscillatorParams {
1944                length: Some(0),
1945                mult: Some(2.0),
1946                smooth: Some(10),
1947            },
1948        ))
1949        .expect_err("invalid length");
1950        assert!(matches!(
1951            err,
1952            SuperTrendOscillatorError::InvalidLength { .. }
1953        ));
1954
1955        let err = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
1956            &high,
1957            &low,
1958            &source,
1959            SuperTrendOscillatorParams {
1960                length: Some(10),
1961                mult: Some(0.0),
1962                smooth: Some(10),
1963            },
1964        ))
1965        .expect_err("invalid mult");
1966        assert!(matches!(
1967            err,
1968            SuperTrendOscillatorError::InvalidMultiplier { .. }
1969        ));
1970
1971        let err = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
1972            &high,
1973            &low,
1974            &source,
1975            SuperTrendOscillatorParams {
1976                length: Some(10),
1977                mult: Some(2.0),
1978                smooth: Some(0),
1979            },
1980        ))
1981        .expect_err("invalid smooth");
1982        assert!(matches!(
1983            err,
1984            SuperTrendOscillatorError::InvalidSmooth { .. }
1985        ));
1986    }
1987
1988    #[test]
1989    fn supertrend_oscillator_dispatch_matches_direct() {
1990        let (high, low, source) = sample_hls(160);
1991        let combo = [
1992            ParamKV {
1993                key: "length",
1994                value: ParamValue::Int(12),
1995            },
1996            ParamKV {
1997                key: "mult",
1998                value: ParamValue::Float(2.5),
1999            },
2000            ParamKV {
2001                key: "smooth",
2002                value: ParamValue::Int(18),
2003            },
2004        ];
2005        let combos = [IndicatorParamSet { params: &combo }];
2006        let req = IndicatorBatchRequest {
2007            indicator_id: "supertrend_oscillator",
2008            output_id: Some("oscillator"),
2009            data: IndicatorDataRef::Ohlc {
2010                open: &source,
2011                high: &high,
2012                low: &low,
2013                close: &source,
2014            },
2015            combos: &combos,
2016            kernel: Kernel::Auto,
2017        };
2018        let out = compute_cpu_batch(req).expect("dispatch");
2019        let direct = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
2020            &high,
2021            &low,
2022            &source,
2023            SuperTrendOscillatorParams {
2024                length: Some(12),
2025                mult: Some(2.5),
2026                smooth: Some(18),
2027            },
2028        ))
2029        .expect("direct");
2030        assert_eq!(out.rows, 1);
2031        assert_eq!(out.cols, source.len());
2032        assert_close(&out.values_f64.expect("values"), &direct.oscillator, 1e-12);
2033    }
2034
2035    macro_rules! gen_kernel_tests {
2036        ($module:ident, $kernel:expr, $batch_kernel:expr) => {
2037            mod $module {
2038                use super::*;
2039
2040                #[test]
2041                fn output_contract() {
2042                    check_output_contract($kernel);
2043                }
2044
2045                #[test]
2046                fn into_matches_api() {
2047                    check_into_matches_api($kernel);
2048                }
2049
2050                #[test]
2051                fn batch_single_matches_single() {
2052                    check_batch_single_matches_single($batch_kernel);
2053                }
2054            }
2055        };
2056    }
2057
2058    gen_kernel_tests!(scalar_kernel, Kernel::Scalar, Kernel::ScalarBatch);
2059    gen_kernel_tests!(auto_kernel, Kernel::Auto, Kernel::Auto);
2060
2061    #[test]
2062    fn supertrend_oscillator_stream_matches_batch() {
2063        check_stream_matches_batch();
2064    }
2065}