1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::atr::{AtrParams, AtrStream};
16use crate::utilities::data_loader::{source_type, Candles};
17use crate::utilities::enums::Kernel;
18use crate::utilities::helpers::{
19 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
20 make_uninit_matrix,
21};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::collections::HashMap;
28use std::mem::{ManuallyDrop, MaybeUninit};
29use thiserror::Error;
30
/// Fallback for `length` (handed to the ATR stream and used as the divisor of
/// the adaptive signal weight) when the parameter is `None`.
const DEFAULT_LENGTH: usize = 10;
/// Fallback for `mult`, the factor applied to ATR to get the band half-width.
const DEFAULT_MULT: f64 = 2.0;
/// Fallback for `smooth`; the histogram weight is `2 / (smooth + 1)`.
const DEFAULT_SMOOTH: usize = 72;
/// Factor every emitted oscillator, signal and histogram value is multiplied by.
const OUTPUT_SCALE: f64 = 100.0;
35
/// Borrowed input series for the oscillator.
#[derive(Debug, Clone)]
pub enum SuperTrendOscillatorData<'a> {
    /// High/low are read from `candles`; `source` is a series name that is
    /// resolved with `source_type(candles, source)`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Caller-provided slices, used as-is.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        source: &'a [f64],
    },
}
48
/// Result of a single-parameter run: three series, each as long as the input.
#[derive(Debug, Clone)]
pub struct SuperTrendOscillatorOutput {
    /// Clamped, scaled position of the source relative to the SuperTrend bands.
    pub oscillator: Vec<f64>,
    /// Adaptive moving average of the oscillator.
    pub signal: Vec<f64>,
    /// Exponentially smoothed difference between oscillator and signal.
    pub histogram: Vec<f64>,
}
55
/// Tunable parameters; a `None` field falls back to the module default.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SuperTrendOscillatorParams {
    /// ATR length; also divides the adaptive signal weight. Must be non-zero.
    pub length: Option<usize>,
    /// ATR multiplier for the band half-width. Must be finite and positive.
    pub mult: Option<f64>,
    /// Histogram smoothing span (weight `2 / (smooth + 1)`). Must be non-zero.
    pub smooth: Option<usize>,
}
66
67impl Default for SuperTrendOscillatorParams {
68 fn default() -> Self {
69 Self {
70 length: Some(DEFAULT_LENGTH),
71 mult: Some(DEFAULT_MULT),
72 smooth: Some(DEFAULT_SMOOTH),
73 }
74 }
75}
76
/// Bundles the borrowed input series with the parameters for one run.
#[derive(Debug, Clone)]
pub struct SuperTrendOscillatorInput<'a> {
    /// Where high/low/source come from.
    pub data: SuperTrendOscillatorData<'a>,
    /// Parameters; unset fields resolve to the defaults via the `get_*` accessors.
    pub params: SuperTrendOscillatorParams,
}
82
83impl<'a> SuperTrendOscillatorInput<'a> {
84 #[inline]
85 pub fn from_candles(
86 candles: &'a Candles,
87 source: &'a str,
88 params: SuperTrendOscillatorParams,
89 ) -> Self {
90 Self {
91 data: SuperTrendOscillatorData::Candles { candles, source },
92 params,
93 }
94 }
95
96 #[inline]
97 pub fn from_slices(
98 high: &'a [f64],
99 low: &'a [f64],
100 source: &'a [f64],
101 params: SuperTrendOscillatorParams,
102 ) -> Self {
103 Self {
104 data: SuperTrendOscillatorData::Slices { high, low, source },
105 params,
106 }
107 }
108
109 #[inline]
110 pub fn with_default_candles(candles: &'a Candles) -> Self {
111 Self::from_candles(candles, "close", SuperTrendOscillatorParams::default())
112 }
113
114 #[inline]
115 pub fn get_length(&self) -> usize {
116 self.params.length.unwrap_or(DEFAULT_LENGTH)
117 }
118
119 #[inline]
120 pub fn get_mult(&self) -> f64 {
121 self.params.mult.unwrap_or(DEFAULT_MULT)
122 }
123
124 #[inline]
125 pub fn get_smooth(&self) -> usize {
126 self.params.smooth.unwrap_or(DEFAULT_SMOOTH)
127 }
128
129 #[inline]
130 pub fn as_refs(&'a self) -> (&'a [f64], &'a [f64], &'a [f64]) {
131 match &self.data {
132 SuperTrendOscillatorData::Candles { candles, source } => (
133 candles.high.as_slice(),
134 candles.low.as_slice(),
135 source_type(candles, source),
136 ),
137 SuperTrendOscillatorData::Slices { high, low, source } => (*high, *low, *source),
138 }
139 }
140}
141
/// Fluent builder for a single-parameter run or a streaming instance.
#[derive(Clone, Debug)]
pub struct SuperTrendOscillatorBuilder {
    // Unset parameters stay `None` and are resolved to defaults downstream.
    length: Option<usize>,
    mult: Option<f64>,
    smooth: Option<usize>,
    // Candle series name; `apply` falls back to "close" when unset.
    source: Option<String>,
    kernel: Kernel,
}
150
151impl Default for SuperTrendOscillatorBuilder {
152 fn default() -> Self {
153 Self {
154 length: None,
155 mult: None,
156 smooth: None,
157 source: None,
158 kernel: Kernel::Auto,
159 }
160 }
161}
162
163impl SuperTrendOscillatorBuilder {
164 #[inline]
165 pub fn new() -> Self {
166 Self::default()
167 }
168
169 #[inline]
170 pub fn length(mut self, value: usize) -> Self {
171 self.length = Some(value);
172 self
173 }
174
175 #[inline]
176 pub fn mult(mut self, value: f64) -> Self {
177 self.mult = Some(value);
178 self
179 }
180
181 #[inline]
182 pub fn smooth(mut self, value: usize) -> Self {
183 self.smooth = Some(value);
184 self
185 }
186
187 #[inline]
188 pub fn source<S: Into<String>>(mut self, value: S) -> Self {
189 self.source = Some(value.into());
190 self
191 }
192
193 #[inline]
194 pub fn kernel(mut self, value: Kernel) -> Self {
195 self.kernel = value;
196 self
197 }
198
199 #[inline]
200 pub fn apply(
201 self,
202 candles: &Candles,
203 ) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
204 let input = SuperTrendOscillatorInput::from_candles(
205 candles,
206 self.source.as_deref().unwrap_or("close"),
207 SuperTrendOscillatorParams {
208 length: self.length,
209 mult: self.mult,
210 smooth: self.smooth,
211 },
212 );
213 supertrend_oscillator_with_kernel(&input, self.kernel)
214 }
215
216 #[inline]
217 pub fn apply_slices(
218 self,
219 high: &[f64],
220 low: &[f64],
221 source: &[f64],
222 ) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
223 let input = SuperTrendOscillatorInput::from_slices(
224 high,
225 low,
226 source,
227 SuperTrendOscillatorParams {
228 length: self.length,
229 mult: self.mult,
230 smooth: self.smooth,
231 },
232 );
233 supertrend_oscillator_with_kernel(&input, self.kernel)
234 }
235
236 #[inline]
237 pub fn into_stream(self) -> Result<SuperTrendOscillatorStream, SuperTrendOscillatorError> {
238 SuperTrendOscillatorStream::try_new(SuperTrendOscillatorParams {
239 length: self.length,
240 mult: self.mult,
241 smooth: self.smooth,
242 })
243 }
244}
245
/// Errors reported by the SuperTrend oscillator entry points.
#[derive(Debug, Error)]
pub enum SuperTrendOscillatorError {
    /// At least one of the high/low/source slices is empty.
    #[error("supertrend_oscillator: Empty input data.")]
    EmptyInputData,
    /// The high/low/source slices do not all have the same length.
    #[error(
        "supertrend_oscillator: Input length mismatch: high={high}, low={low}, source={source_len}"
    )]
    DataLengthMismatch {
        high: usize,
        low: usize,
        source_len: usize,
    },
    /// No bar passes `valid_bar` (finite values with `high >= low`).
    #[error("supertrend_oscillator: All input values are invalid.")]
    AllValuesNaN,
    /// `length` is zero or exceeds the data length.
    #[error("supertrend_oscillator: Invalid length: length = {length}, data length = {data_len}")]
    InvalidLength { length: usize, data_len: usize },
    /// `mult` is non-finite or not strictly positive.
    #[error("supertrend_oscillator: Invalid multiplier: {mult}")]
    InvalidMultiplier { mult: f64 },
    /// `smooth` is zero.
    #[error("supertrend_oscillator: Invalid smooth: {smooth}")]
    InvalidSmooth { smooth: usize },
    /// The longest run of consecutive valid bars is shorter than `length`.
    #[error("supertrend_oscillator: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// An output buffer has the wrong size; also used when `rows * cols`
    /// overflows in the batch path.
    #[error("supertrend_oscillator: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// An integer sweep axis (or the resulting grid size) is unusable.
    #[error("supertrend_oscillator: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A float sweep axis has a non-finite bound or step.
    #[error("supertrend_oscillator: Invalid float range: start={start}, end={end}, step={step}")]
    InvalidFloatRange { start: f64, end: f64, step: f64 },
    /// A non-batch kernel was handed to the batch entry point.
    #[error("supertrend_oscillator: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
281
/// A bar is usable when all three values are finite and `high` is not below `low`.
#[inline(always)]
fn valid_bar(high: f64, low: f64, source: f64) -> bool {
    let all_finite = [high, low, source].iter().all(|value| value.is_finite());
    all_finite && high >= low
}
286
287#[inline(always)]
288fn first_valid_bar(high: &[f64], low: &[f64], source: &[f64]) -> Option<usize> {
289 (0..source.len()).find(|&i| valid_bar(high[i], low[i], source[i]))
290}
291
292#[inline(always)]
293fn max_valid_run(high: &[f64], low: &[f64], source: &[f64]) -> usize {
294 let mut best = 0usize;
295 let mut cur = 0usize;
296 for i in 0..source.len() {
297 if valid_bar(high[i], low[i], source[i]) {
298 cur += 1;
299 if cur > best {
300 best = cur;
301 }
302 } else {
303 cur = 0;
304 }
305 }
306 best
307}
308
309#[inline(always)]
310fn normalized_kernel(kernel: Kernel) -> Kernel {
311 match kernel {
312 Kernel::Auto => detect_best_kernel(),
313 other if other.is_batch() => other.to_non_batch(),
314 other => other,
315 }
316}
317
/// Restricts `value` to `[-1, 1]`; NaN is passed through unchanged.
#[inline(always)]
fn clamp_unit(value: f64) -> f64 {
    // Both comparisons are false for NaN, so it falls through untouched.
    if value < -1.0 {
        -1.0
    } else if value > 1.0 {
        1.0
    } else {
        value
    }
}
322
/// Index `first_valid + (length - 1)`, with both steps saturating rather than
/// wrapping.
#[inline(always)]
fn warmup_end(first_valid: usize, length: usize) -> usize {
    let lookback = length.saturating_sub(1);
    first_valid.saturating_add(lookback)
}
327
328#[inline(always)]
329fn validate_lengths(
330 high: &[f64],
331 low: &[f64],
332 source: &[f64],
333) -> Result<(), SuperTrendOscillatorError> {
334 if high.is_empty() || low.is_empty() || source.is_empty() {
335 return Err(SuperTrendOscillatorError::EmptyInputData);
336 }
337 if high.len() != low.len() || low.len() != source.len() {
338 return Err(SuperTrendOscillatorError::DataLengthMismatch {
339 high: high.len(),
340 low: low.len(),
341 source_len: source.len(),
342 });
343 }
344 Ok(())
345}
346
347#[inline(always)]
348fn validate_params(
349 length: usize,
350 mult: f64,
351 smooth: usize,
352 data_len: usize,
353) -> Result<(), SuperTrendOscillatorError> {
354 if length == 0 || length > data_len {
355 return Err(SuperTrendOscillatorError::InvalidLength { length, data_len });
356 }
357 if !mult.is_finite() || mult <= 0.0 {
358 return Err(SuperTrendOscillatorError::InvalidMultiplier { mult });
359 }
360 if smooth == 0 {
361 return Err(SuperTrendOscillatorError::InvalidSmooth { smooth });
362 }
363 Ok(())
364}
365
366fn compute_atr_series(high: &[f64], low: &[f64], source: &[f64], length: usize) -> Vec<f64> {
367 let mut out = vec![f64::NAN; source.len()];
368 let mut stream = AtrStream::try_new(AtrParams {
369 length: Some(length),
370 })
371 .expect("validated length");
372
373 for i in 0..source.len() {
374 if !valid_bar(high[i], low[i], source[i]) {
375 stream = AtrStream::try_new(AtrParams {
376 length: Some(length),
377 })
378 .expect("validated length");
379 continue;
380 }
381 if let Some(atr) = stream.update(high[i], low[i], source[i]) {
382 out[i] = atr;
383 }
384 }
385
386 out
387}
388
389#[inline(always)]
390fn supertrend_oscillator_row_scalar(
391 high: &[f64],
392 low: &[f64],
393 source: &[f64],
394 length: usize,
395 mult: f64,
396 smooth: usize,
397 atr_values: &[f64],
398 out_oscillator: &mut [f64],
399 out_signal: &mut [f64],
400 out_histogram: &mut [f64],
401) {
402 let hist_alpha = 2.0 / (smooth as f64 + 1.0);
403 let mut prev_source = f64::NAN;
404 let mut prev_upper = f64::NAN;
405 let mut prev_lower = f64::NAN;
406 let mut prev_trend = 0.0;
407 let mut ama_prev: Option<f64> = None;
408 let mut hist_prev: Option<f64> = None;
409 let length_f64 = length as f64;
410
411 for i in 0..source.len() {
412 let src = source[i];
413 if !valid_bar(high[i], low[i], src) {
414 out_oscillator[i] = f64::NAN;
415 out_signal[i] = f64::NAN;
416 out_histogram[i] = f64::NAN;
417 prev_source = f64::NAN;
418 prev_upper = f64::NAN;
419 prev_lower = f64::NAN;
420 prev_trend = 0.0;
421 ama_prev = None;
422 hist_prev = None;
423 continue;
424 }
425
426 if !atr_values[i].is_finite() {
427 out_oscillator[i] = f64::NAN;
428 out_signal[i] = f64::NAN;
429 out_histogram[i] = f64::NAN;
430 prev_source = src;
431 continue;
432 }
433
434 let mid = 0.5 * (high[i] + low[i]);
435 let band = atr_values[i] * mult;
436 let up = mid + band;
437 let dn = mid - band;
438
439 let upper = if prev_source.is_finite() && prev_upper.is_finite() && prev_source < prev_upper
440 {
441 up.min(prev_upper)
442 } else {
443 up
444 };
445 let lower = if prev_source.is_finite() && prev_lower.is_finite() && prev_source > prev_lower
446 {
447 dn.max(prev_lower)
448 } else {
449 dn
450 };
451
452 let trend = if prev_upper.is_finite() && src > prev_upper {
453 1.0
454 } else if prev_lower.is_finite() && src < prev_lower {
455 0.0
456 } else {
457 prev_trend
458 };
459
460 let supertrend = trend * lower + (1.0 - trend) * upper;
461 let width = upper - lower;
462 let osc = if width.is_finite() && width != 0.0 {
463 clamp_unit((src - supertrend) / width)
464 } else {
465 0.0
466 };
467 let alpha = (osc * osc) / length_f64;
468 let ama = match ama_prev {
469 Some(prev) => prev + alpha * (osc - prev),
470 None => osc,
471 };
472 let diff = osc - ama;
473 let hist = match hist_prev {
474 Some(prev) => prev + hist_alpha * (diff - prev),
475 None => diff,
476 };
477
478 out_oscillator[i] = osc * OUTPUT_SCALE;
479 out_signal[i] = ama * OUTPUT_SCALE;
480 out_histogram[i] = hist * OUTPUT_SCALE;
481
482 prev_source = src;
483 prev_upper = upper;
484 prev_lower = lower;
485 prev_trend = trend;
486 ama_prev = Some(ama);
487 hist_prev = Some(hist);
488 }
489}
490
491fn supertrend_oscillator_prepare<'a>(
492 input: &'a SuperTrendOscillatorInput<'a>,
493 kernel: Kernel,
494) -> Result<
495 (
496 &'a [f64],
497 &'a [f64],
498 &'a [f64],
499 usize,
500 f64,
501 usize,
502 usize,
503 Vec<f64>,
504 Kernel,
505 ),
506 SuperTrendOscillatorError,
507> {
508 let (high, low, source) = input.as_refs();
509 validate_lengths(high, low, source)?;
510
511 let length = input.get_length();
512 let mult = input.get_mult();
513 let smooth = input.get_smooth();
514 validate_params(length, mult, smooth, source.len())?;
515
516 let first_valid =
517 first_valid_bar(high, low, source).ok_or(SuperTrendOscillatorError::AllValuesNaN)?;
518 let max_run = max_valid_run(high, low, source);
519 if max_run < length {
520 return Err(SuperTrendOscillatorError::NotEnoughValidData {
521 needed: length,
522 valid: max_run,
523 });
524 }
525
526 let atr_values = compute_atr_series(high, low, source, length);
527
528 Ok((
529 high,
530 low,
531 source,
532 length,
533 mult,
534 smooth,
535 first_valid,
536 atr_values,
537 normalized_kernel(kernel),
538 ))
539}
540
541#[inline]
542pub fn supertrend_oscillator(
543 input: &SuperTrendOscillatorInput,
544) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
545 supertrend_oscillator_with_kernel(input, Kernel::Auto)
546}
547
548#[inline]
549pub fn supertrend_oscillator_with_kernel(
550 input: &SuperTrendOscillatorInput,
551 kernel: Kernel,
552) -> Result<SuperTrendOscillatorOutput, SuperTrendOscillatorError> {
553 let (high, low, source, length, mult, smooth, first_valid, atr_values, _chosen) =
554 supertrend_oscillator_prepare(input, kernel)?;
555
556 let len = source.len();
557 let warmup = warmup_end(first_valid, length);
558 let mut oscillator = alloc_with_nan_prefix(len, warmup);
559 let mut signal = alloc_with_nan_prefix(len, warmup);
560 let mut histogram = alloc_with_nan_prefix(len, warmup);
561
562 supertrend_oscillator_row_scalar(
563 high,
564 low,
565 source,
566 length,
567 mult,
568 smooth,
569 &atr_values,
570 &mut oscillator,
571 &mut signal,
572 &mut histogram,
573 );
574
575 Ok(SuperTrendOscillatorOutput {
576 oscillator,
577 signal,
578 histogram,
579 })
580}
581
582#[inline]
583pub fn supertrend_oscillator_into_slice(
584 out_oscillator: &mut [f64],
585 out_signal: &mut [f64],
586 out_histogram: &mut [f64],
587 input: &SuperTrendOscillatorInput,
588 kernel: Kernel,
589) -> Result<(), SuperTrendOscillatorError> {
590 let (high, low, source, length, mult, smooth, _first_valid, atr_values, _chosen) =
591 supertrend_oscillator_prepare(input, kernel)?;
592 let len = source.len();
593 if out_oscillator.len() != len || out_signal.len() != len || out_histogram.len() != len {
594 return Err(SuperTrendOscillatorError::OutputLengthMismatch {
595 expected: len,
596 got: out_oscillator
597 .len()
598 .max(out_signal.len())
599 .max(out_histogram.len()),
600 });
601 }
602
603 supertrend_oscillator_row_scalar(
604 high,
605 low,
606 source,
607 length,
608 mult,
609 smooth,
610 &atr_values,
611 out_oscillator,
612 out_signal,
613 out_histogram,
614 );
615 Ok(())
616}
617
618#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
619#[inline]
620pub fn supertrend_oscillator_into(
621 input: &SuperTrendOscillatorInput,
622 out_oscillator: &mut [f64],
623 out_signal: &mut [f64],
624 out_histogram: &mut [f64],
625) -> Result<(), SuperTrendOscillatorError> {
626 supertrend_oscillator_into_slice(
627 out_oscillator,
628 out_signal,
629 out_histogram,
630 input,
631 Kernel::Auto,
632 )
633}
634
/// Incremental (bar-by-bar) version of the oscillator.
#[derive(Clone, Debug)]
pub struct SuperTrendOscillatorStream {
    // Resolved, validated parameters.
    length: usize,
    mult: f64,
    // Histogram weight, 2 / (smooth + 1).
    hist_alpha: f64,
    // Supplies ATR for each bar; replaced by a fresh stream on reset.
    atr_stream: AtrStream,
    // State carried from the previous usable bar (NaN / None = nothing yet).
    prev_source: f64,
    prev_upper: f64,
    prev_lower: f64,
    // 1.0 = up-trend, 0.0 = down-trend.
    prev_trend: f64,
    ama_prev: Option<f64>,
    hist_prev: Option<f64>,
}
648
649impl SuperTrendOscillatorStream {
650 #[inline]
651 pub fn try_new(params: SuperTrendOscillatorParams) -> Result<Self, SuperTrendOscillatorError> {
652 let length = params.length.unwrap_or(DEFAULT_LENGTH);
653 let mult = params.mult.unwrap_or(DEFAULT_MULT);
654 let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
655 validate_params(length, mult, smooth, length)?;
656 Ok(Self {
657 length,
658 mult,
659 hist_alpha: 2.0 / (smooth as f64 + 1.0),
660 atr_stream: AtrStream::try_new(AtrParams {
661 length: Some(length),
662 })
663 .expect("validated length"),
664 prev_source: f64::NAN,
665 prev_upper: f64::NAN,
666 prev_lower: f64::NAN,
667 prev_trend: 0.0,
668 ama_prev: None,
669 hist_prev: None,
670 })
671 }
672
673 #[inline]
674 fn reset(&mut self) {
675 self.atr_stream = AtrStream::try_new(AtrParams {
676 length: Some(self.length),
677 })
678 .expect("validated length");
679 self.prev_source = f64::NAN;
680 self.prev_upper = f64::NAN;
681 self.prev_lower = f64::NAN;
682 self.prev_trend = 0.0;
683 self.ama_prev = None;
684 self.hist_prev = None;
685 }
686
687 #[inline]
688 pub fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64, f64)> {
689 if !valid_bar(high, low, source) {
690 self.reset();
691 return None;
692 }
693
694 let atr = match self.atr_stream.update(high, low, source) {
695 Some(value) => value,
696 None => {
697 self.prev_source = source;
698 return None;
699 }
700 };
701
702 let mid = 0.5 * (high + low);
703 let up = mid + atr * self.mult;
704 let dn = mid - atr * self.mult;
705
706 let upper = if self.prev_source.is_finite()
707 && self.prev_upper.is_finite()
708 && self.prev_source < self.prev_upper
709 {
710 up.min(self.prev_upper)
711 } else {
712 up
713 };
714 let lower = if self.prev_source.is_finite()
715 && self.prev_lower.is_finite()
716 && self.prev_source > self.prev_lower
717 {
718 dn.max(self.prev_lower)
719 } else {
720 dn
721 };
722
723 let trend = if self.prev_upper.is_finite() && source > self.prev_upper {
724 1.0
725 } else if self.prev_lower.is_finite() && source < self.prev_lower {
726 0.0
727 } else {
728 self.prev_trend
729 };
730
731 let supertrend = trend * lower + (1.0 - trend) * upper;
732 let width = upper - lower;
733 let osc = if width.is_finite() && width != 0.0 {
734 clamp_unit((source - supertrend) / width)
735 } else {
736 0.0
737 };
738 let alpha = (osc * osc) / self.length as f64;
739 let ama = match self.ama_prev {
740 Some(prev) => prev + alpha * (osc - prev),
741 None => osc,
742 };
743 let diff = osc - ama;
744 let hist = match self.hist_prev {
745 Some(prev) => prev + self.hist_alpha * (diff - prev),
746 None => diff,
747 };
748
749 self.prev_source = source;
750 self.prev_upper = upper;
751 self.prev_lower = lower;
752 self.prev_trend = trend;
753 self.ama_prev = Some(ama);
754 self.hist_prev = Some(hist);
755
756 Some((osc * OUTPUT_SCALE, ama * OUTPUT_SCALE, hist * OUTPUT_SCALE))
757 }
758}
759
/// Result of a parameter sweep: three flattened matrices of `rows * cols`
/// values, filled one `cols`-long row per parameter combination.
#[derive(Debug, Clone)]
pub struct SuperTrendOscillatorBatchOutput {
    pub oscillator: Vec<f64>,
    pub signal: Vec<f64>,
    pub histogram: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<SuperTrendOscillatorParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
769
770impl SuperTrendOscillatorBatchOutput {
771 #[inline]
772 pub fn row_for_params(&self, params: &SuperTrendOscillatorParams) -> Option<usize> {
773 self.combos.iter().position(|p| {
774 p.length == params.length && p.mult == params.mult && p.smooth == params.smooth
775 })
776 }
777}
778
/// Sweep definition: each axis is `(start, end, step)`.
///
/// A zero step (or `start == end`) makes the axis a single value; `start` may
/// be greater than `end`, in which case the axis is walked downwards.
#[derive(Debug, Clone)]
pub struct SuperTrendOscillatorBatchRange {
    pub length: (usize, usize, usize),
    pub mult: (f64, f64, f64),
    pub smooth: (usize, usize, usize),
}
785
786impl Default for SuperTrendOscillatorBatchRange {
787 fn default() -> Self {
788 Self {
789 length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
790 mult: (DEFAULT_MULT, DEFAULT_MULT, 0.0),
791 smooth: (DEFAULT_SMOOTH, DEFAULT_SMOOTH, 0),
792 }
793 }
794}
795
/// Fluent builder for a batch (parameter sweep) run.
#[derive(Clone, Debug)]
pub struct SuperTrendOscillatorBatchBuilder {
    range: SuperTrendOscillatorBatchRange,
    // Candle series name; `apply_candles` falls back to "close" when unset.
    source: Option<String>,
    kernel: Kernel,
}
802
803impl Default for SuperTrendOscillatorBatchBuilder {
804 fn default() -> Self {
805 Self {
806 range: SuperTrendOscillatorBatchRange::default(),
807 source: None,
808 kernel: Kernel::Auto,
809 }
810 }
811}
812
813impl SuperTrendOscillatorBatchBuilder {
814 #[inline]
815 pub fn new() -> Self {
816 Self::default()
817 }
818
819 #[inline]
820 pub fn kernel(mut self, value: Kernel) -> Self {
821 self.kernel = value;
822 self
823 }
824
825 #[inline]
826 pub fn source<S: Into<String>>(mut self, value: S) -> Self {
827 self.source = Some(value.into());
828 self
829 }
830
831 #[inline]
832 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
833 self.range.length = (start, end, step);
834 self
835 }
836
837 #[inline]
838 pub fn length_static(mut self, value: usize) -> Self {
839 self.range.length = (value, value, 0);
840 self
841 }
842
843 #[inline]
844 pub fn mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
845 self.range.mult = (start, end, step);
846 self
847 }
848
849 #[inline]
850 pub fn mult_static(mut self, value: f64) -> Self {
851 self.range.mult = (value, value, 0.0);
852 self
853 }
854
855 #[inline]
856 pub fn smooth_range(mut self, start: usize, end: usize, step: usize) -> Self {
857 self.range.smooth = (start, end, step);
858 self
859 }
860
861 #[inline]
862 pub fn smooth_static(mut self, value: usize) -> Self {
863 self.range.smooth = (value, value, 0);
864 self
865 }
866
867 #[inline]
868 pub fn apply_candles(
869 self,
870 candles: &Candles,
871 ) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
872 let source = source_type(candles, self.source.as_deref().unwrap_or("close"));
873 supertrend_oscillator_batch_with_kernel(
874 candles.high.as_slice(),
875 candles.low.as_slice(),
876 source,
877 &self.range,
878 self.kernel,
879 )
880 }
881
882 #[inline]
883 pub fn apply_slices(
884 self,
885 high: &[f64],
886 low: &[f64],
887 source: &[f64],
888 ) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
889 supertrend_oscillator_batch_with_kernel(high, low, source, &self.range, self.kernel)
890 }
891}
892
893#[inline]
894pub fn expand_grid_supertrend_oscillator(
895 range: &SuperTrendOscillatorBatchRange,
896) -> Result<Vec<SuperTrendOscillatorParams>, SuperTrendOscillatorError> {
897 fn axis_usize(
898 (start, end, step): (usize, usize, usize),
899 ) -> Result<Vec<usize>, SuperTrendOscillatorError> {
900 if step == 0 || start == end {
901 return Ok(vec![start]);
902 }
903 if start <= end {
904 let mut out = Vec::new();
905 let mut x = start;
906 while x <= end {
907 out.push(x);
908 match x.checked_add(step.max(1)) {
909 Some(next) if next > x => x = next,
910 _ => break,
911 }
912 }
913 if out.is_empty() {
914 return Err(SuperTrendOscillatorError::InvalidRange {
915 start: start.to_string(),
916 end: end.to_string(),
917 step: step.to_string(),
918 });
919 }
920 return Ok(out);
921 }
922
923 let mut out = Vec::new();
924 let mut x = start;
925 let step = step.max(1);
926 while x >= end {
927 out.push(x);
928 if x == end {
929 break;
930 }
931 let next = x.saturating_sub(step);
932 if next == x || next < end {
933 break;
934 }
935 x = next;
936 }
937 if out.is_empty() {
938 return Err(SuperTrendOscillatorError::InvalidRange {
939 start: start.to_string(),
940 end: end.to_string(),
941 step: step.to_string(),
942 });
943 }
944 Ok(out)
945 }
946
947 fn axis_f64(
948 (start, end, step): (f64, f64, f64),
949 ) -> Result<Vec<f64>, SuperTrendOscillatorError> {
950 if !start.is_finite() || !end.is_finite() || !step.is_finite() {
951 return Err(SuperTrendOscillatorError::InvalidFloatRange { start, end, step });
952 }
953 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
954 return Ok(vec![start]);
955 }
956 let step = step.abs();
957 let mut out = Vec::new();
958 if start <= end {
959 let mut x = start;
960 while x <= end + 1e-12 {
961 out.push(x);
962 x += step;
963 }
964 } else {
965 let mut x = start;
966 while x + 1e-12 >= end {
967 out.push(x);
968 x -= step;
969 }
970 }
971 if out.is_empty() {
972 return Err(SuperTrendOscillatorError::InvalidFloatRange { start, end, step });
973 }
974 Ok(out)
975 }
976
977 let lengths = axis_usize(range.length)?;
978 let mults = axis_f64(range.mult)?;
979 let smooths = axis_usize(range.smooth)?;
980
981 let cap = lengths
982 .len()
983 .checked_mul(mults.len())
984 .and_then(|value| value.checked_mul(smooths.len()))
985 .ok_or(SuperTrendOscillatorError::InvalidRange {
986 start: range.length.0.to_string(),
987 end: range.length.1.to_string(),
988 step: range.length.2.to_string(),
989 })?;
990
991 let mut out = Vec::with_capacity(cap);
992 for &length in &lengths {
993 for &mult in &mults {
994 for &smooth in &smooths {
995 out.push(SuperTrendOscillatorParams {
996 length: Some(length),
997 mult: Some(mult),
998 smooth: Some(smooth),
999 });
1000 }
1001 }
1002 }
1003 Ok(out)
1004}
1005
1006#[inline]
1007pub fn supertrend_oscillator_batch_with_kernel(
1008 high: &[f64],
1009 low: &[f64],
1010 source: &[f64],
1011 sweep: &SuperTrendOscillatorBatchRange,
1012 kernel: Kernel,
1013) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1014 let batch_kernel = match kernel {
1015 Kernel::Auto => detect_best_batch_kernel(),
1016 other if other.is_batch() => other,
1017 other => return Err(SuperTrendOscillatorError::InvalidKernelForBatch(other)),
1018 };
1019 supertrend_oscillator_batch_par_slice(high, low, source, sweep, batch_kernel.to_non_batch())
1020}
1021
1022#[inline]
1023pub fn supertrend_oscillator_batch_slice(
1024 high: &[f64],
1025 low: &[f64],
1026 source: &[f64],
1027 sweep: &SuperTrendOscillatorBatchRange,
1028 kernel: Kernel,
1029) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1030 supertrend_oscillator_batch_inner(high, low, source, sweep, kernel, false)
1031}
1032
1033#[inline]
1034pub fn supertrend_oscillator_batch_par_slice(
1035 high: &[f64],
1036 low: &[f64],
1037 source: &[f64],
1038 sweep: &SuperTrendOscillatorBatchRange,
1039 kernel: Kernel,
1040) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1041 supertrend_oscillator_batch_inner(high, low, source, sweep, kernel, true)
1042}
1043
1044fn supertrend_oscillator_batch_inner(
1045 high: &[f64],
1046 low: &[f64],
1047 source: &[f64],
1048 sweep: &SuperTrendOscillatorBatchRange,
1049 _kernel: Kernel,
1050 parallel: bool,
1051) -> Result<SuperTrendOscillatorBatchOutput, SuperTrendOscillatorError> {
1052 validate_lengths(high, low, source)?;
1053 let combos = expand_grid_supertrend_oscillator(sweep)?;
1054 let first_valid =
1055 first_valid_bar(high, low, source).ok_or(SuperTrendOscillatorError::AllValuesNaN)?;
1056 let max_run = max_valid_run(high, low, source);
1057 let max_length = combos
1058 .iter()
1059 .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
1060 .max()
1061 .unwrap_or(DEFAULT_LENGTH);
1062 if max_run < max_length {
1063 return Err(SuperTrendOscillatorError::NotEnoughValidData {
1064 needed: max_length,
1065 valid: max_run,
1066 });
1067 }
1068 for params in &combos {
1069 validate_params(
1070 params.length.unwrap_or(DEFAULT_LENGTH),
1071 params.mult.unwrap_or(DEFAULT_MULT),
1072 params.smooth.unwrap_or(DEFAULT_SMOOTH),
1073 source.len(),
1074 )?;
1075 }
1076
1077 let rows = combos.len();
1078 let cols = source.len();
1079 let total = rows
1080 .checked_mul(cols)
1081 .ok_or(SuperTrendOscillatorError::OutputLengthMismatch {
1082 expected: usize::MAX,
1083 got: 0,
1084 })?;
1085
1086 let mut oscillator_matrix = make_uninit_matrix(rows, cols);
1087 let mut signal_matrix = make_uninit_matrix(rows, cols);
1088 let mut histogram_matrix = make_uninit_matrix(rows, cols);
1089
1090 let warmups: Vec<usize> = combos
1091 .iter()
1092 .map(|params| warmup_end(first_valid, params.length.unwrap_or(DEFAULT_LENGTH)))
1093 .collect();
1094 init_matrix_prefixes(&mut oscillator_matrix, cols, &warmups);
1095 init_matrix_prefixes(&mut signal_matrix, cols, &warmups);
1096 init_matrix_prefixes(&mut histogram_matrix, cols, &warmups);
1097
1098 let mut oscillator_guard = ManuallyDrop::new(oscillator_matrix);
1099 let mut signal_guard = ManuallyDrop::new(signal_matrix);
1100 let mut histogram_guard = ManuallyDrop::new(histogram_matrix);
1101
1102 let oscillator_mu: &mut [MaybeUninit<f64>] = unsafe {
1103 std::slice::from_raw_parts_mut(oscillator_guard.as_mut_ptr(), oscillator_guard.len())
1104 };
1105 let signal_mu: &mut [MaybeUninit<f64>] =
1106 unsafe { std::slice::from_raw_parts_mut(signal_guard.as_mut_ptr(), signal_guard.len()) };
1107 let histogram_mu: &mut [MaybeUninit<f64>] = unsafe {
1108 std::slice::from_raw_parts_mut(histogram_guard.as_mut_ptr(), histogram_guard.len())
1109 };
1110
1111 let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1112 let mut lengths: Vec<usize> = combos
1113 .iter()
1114 .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
1115 .collect();
1116 lengths.sort_unstable();
1117 lengths.dedup();
1118 for length in lengths {
1119 atr_cache.insert(length, compute_atr_series(high, low, source, length));
1120 }
1121
1122 let do_row = |row: usize,
1123 row_oscillator: &mut [MaybeUninit<f64>],
1124 row_signal: &mut [MaybeUninit<f64>],
1125 row_histogram: &mut [MaybeUninit<f64>]| {
1126 let params = &combos[row];
1127 let length = params.length.unwrap_or(DEFAULT_LENGTH);
1128 let mult = params.mult.unwrap_or(DEFAULT_MULT);
1129 let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
1130 let atr_values = atr_cache.get(&length).expect("cached atr");
1131
1132 let dst_oscillator = unsafe {
1133 std::slice::from_raw_parts_mut(row_oscillator.as_mut_ptr() as *mut f64, cols)
1134 };
1135 let dst_signal =
1136 unsafe { std::slice::from_raw_parts_mut(row_signal.as_mut_ptr() as *mut f64, cols) };
1137 let dst_histogram =
1138 unsafe { std::slice::from_raw_parts_mut(row_histogram.as_mut_ptr() as *mut f64, cols) };
1139
1140 supertrend_oscillator_row_scalar(
1141 high,
1142 low,
1143 source,
1144 length,
1145 mult,
1146 smooth,
1147 atr_values,
1148 dst_oscillator,
1149 dst_signal,
1150 dst_histogram,
1151 );
1152 };
1153
1154 if parallel {
1155 #[cfg(not(target_arch = "wasm32"))]
1156 oscillator_mu
1157 .par_chunks_mut(cols)
1158 .zip(signal_mu.par_chunks_mut(cols))
1159 .zip(histogram_mu.par_chunks_mut(cols))
1160 .enumerate()
1161 .for_each(|(row, ((row_oscillator, row_signal), row_histogram))| {
1162 do_row(row, row_oscillator, row_signal, row_histogram)
1163 });
1164
1165 #[cfg(target_arch = "wasm32")]
1166 for (row, ((row_oscillator, row_signal), row_histogram)) in oscillator_mu
1167 .chunks_mut(cols)
1168 .zip(signal_mu.chunks_mut(cols))
1169 .zip(histogram_mu.chunks_mut(cols))
1170 .enumerate()
1171 {
1172 do_row(row, row_oscillator, row_signal, row_histogram);
1173 }
1174 } else {
1175 for (row, ((row_oscillator, row_signal), row_histogram)) in oscillator_mu
1176 .chunks_mut(cols)
1177 .zip(signal_mu.chunks_mut(cols))
1178 .zip(histogram_mu.chunks_mut(cols))
1179 .enumerate()
1180 {
1181 do_row(row, row_oscillator, row_signal, row_histogram);
1182 }
1183 }
1184
1185 let oscillator = unsafe {
1186 Vec::from_raw_parts(
1187 oscillator_guard.as_mut_ptr() as *mut f64,
1188 total,
1189 oscillator_guard.capacity(),
1190 )
1191 };
1192 let signal = unsafe {
1193 Vec::from_raw_parts(
1194 signal_guard.as_mut_ptr() as *mut f64,
1195 total,
1196 signal_guard.capacity(),
1197 )
1198 };
1199 let histogram = unsafe {
1200 Vec::from_raw_parts(
1201 histogram_guard.as_mut_ptr() as *mut f64,
1202 total,
1203 histogram_guard.capacity(),
1204 )
1205 };
1206
1207 Ok(SuperTrendOscillatorBatchOutput {
1208 oscillator,
1209 signal,
1210 histogram,
1211 combos,
1212 rows,
1213 cols,
1214 })
1215}
1216
1217fn supertrend_oscillator_batch_inner_into(
1218 high: &[f64],
1219 low: &[f64],
1220 source: &[f64],
1221 sweep: &SuperTrendOscillatorBatchRange,
1222 kernel: Kernel,
1223 parallel: bool,
1224 out_oscillator: &mut [f64],
1225 out_signal: &mut [f64],
1226 out_histogram: &mut [f64],
1227) -> Result<Vec<SuperTrendOscillatorParams>, SuperTrendOscillatorError> {
1228 validate_lengths(high, low, source)?;
1229 let combos = expand_grid_supertrend_oscillator(sweep)?;
1230 let max_run = max_valid_run(high, low, source);
1231 let max_length = combos
1232 .iter()
1233 .map(|params| params.length.unwrap_or(DEFAULT_LENGTH))
1234 .max()
1235 .unwrap_or(DEFAULT_LENGTH);
1236 if max_run < max_length {
1237 return Err(SuperTrendOscillatorError::NotEnoughValidData {
1238 needed: max_length,
1239 valid: max_run,
1240 });
1241 }
1242
1243 let rows = combos.len();
1244 let cols = source.len();
1245 let total = rows
1246 .checked_mul(cols)
1247 .ok_or(SuperTrendOscillatorError::OutputLengthMismatch {
1248 expected: usize::MAX,
1249 got: 0,
1250 })?;
1251 if out_oscillator.len() != total || out_signal.len() != total || out_histogram.len() != total {
1252 return Err(SuperTrendOscillatorError::OutputLengthMismatch {
1253 expected: total,
1254 got: out_oscillator
1255 .len()
1256 .max(out_signal.len())
1257 .max(out_histogram.len()),
1258 });
1259 }
1260
1261 let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1262 for params in &combos {
1263 let length = params.length.unwrap_or(DEFAULT_LENGTH);
1264 validate_params(
1265 length,
1266 params.mult.unwrap_or(DEFAULT_MULT),
1267 params.smooth.unwrap_or(DEFAULT_SMOOTH),
1268 cols,
1269 )?;
1270 atr_cache
1271 .entry(length)
1272 .or_insert_with(|| compute_atr_series(high, low, source, length));
1273 }
1274
1275 let _ = kernel;
1276 let do_row = |row: usize,
1277 dst_oscillator: &mut [f64],
1278 dst_signal: &mut [f64],
1279 dst_histogram: &mut [f64]| {
1280 let params = &combos[row];
1281 let length = params.length.unwrap_or(DEFAULT_LENGTH);
1282 let mult = params.mult.unwrap_or(DEFAULT_MULT);
1283 let smooth = params.smooth.unwrap_or(DEFAULT_SMOOTH);
1284 let atr_values = atr_cache.get(&length).expect("cached atr");
1285
1286 supertrend_oscillator_row_scalar(
1287 high,
1288 low,
1289 source,
1290 length,
1291 mult,
1292 smooth,
1293 atr_values,
1294 dst_oscillator,
1295 dst_signal,
1296 dst_histogram,
1297 );
1298 };
1299
1300 if parallel {
1301 #[cfg(not(target_arch = "wasm32"))]
1302 out_oscillator
1303 .par_chunks_mut(cols)
1304 .zip(out_signal.par_chunks_mut(cols))
1305 .zip(out_histogram.par_chunks_mut(cols))
1306 .enumerate()
1307 .for_each(|(row, ((dst_oscillator, dst_signal), dst_histogram))| {
1308 do_row(row, dst_oscillator, dst_signal, dst_histogram)
1309 });
1310
1311 #[cfg(target_arch = "wasm32")]
1312 for (row, ((dst_oscillator, dst_signal), dst_histogram)) in out_oscillator
1313 .chunks_mut(cols)
1314 .zip(out_signal.chunks_mut(cols))
1315 .zip(out_histogram.chunks_mut(cols))
1316 .enumerate()
1317 {
1318 do_row(row, dst_oscillator, dst_signal, dst_histogram);
1319 }
1320 } else {
1321 for (row, ((dst_oscillator, dst_signal), dst_histogram)) in out_oscillator
1322 .chunks_mut(cols)
1323 .zip(out_signal.chunks_mut(cols))
1324 .zip(out_histogram.chunks_mut(cols))
1325 .enumerate()
1326 {
1327 do_row(row, dst_oscillator, dst_signal, dst_histogram);
1328 }
1329 }
1330
1331 Ok(combos)
1332}
1333
// Python entry point `supertrend_oscillator(high, low, source, length, mult,
// smooth, kernel=None)`.
//
// Borrows the three NumPy inputs as contiguous slices, runs the indicator with
// the GIL released, and returns `(oscillator, signal, histogram)` as new NumPy
// arrays. Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_oscillator")]
#[pyo3(signature = (high, low, source, length=DEFAULT_LENGTH, mult=DEFAULT_MULT, smooth=DEFAULT_SMOOTH, kernel=None))]
pub fn supertrend_oscillator_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    source: PyReadonlyArray1<'py, f64>,
    length: usize,
    mult: f64,
    smooth: usize,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    // `as_slice` fails for non-contiguous arrays; that error is propagated as-is.
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let source = source.as_slice()?;

    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);
    let kernel = validate_kernel(kernel, false)?;

    // The computation touches no Python objects, so it runs without the GIL.
    match py.allow_threads(|| supertrend_oscillator_with_kernel(&input, kernel)) {
        Ok(out) => Ok((
            out.oscillator.into_pyarray(py),
            out.signal.into_pyarray(py),
            out.histogram.into_pyarray(py),
        )),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1374
// Python-visible wrapper (exposed as `SuperTrendOscillatorStream`) around the
// native `SuperTrendOscillatorStream`; all work is delegated to `stream`.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendOscillatorStream")]
pub struct SuperTrendOscillatorStreamPy {
    // The wrapped native streaming calculator.
    stream: SuperTrendOscillatorStream,
}
1380
#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendOscillatorStreamPy {
    // Constructor: builds the native stream from the given parameters and
    // raises `ValueError` when `try_new` rejects them.
    #[new]
    #[pyo3(signature = (length=DEFAULT_LENGTH, mult=DEFAULT_MULT, smooth=DEFAULT_SMOOTH))]
    fn new(length: usize, mult: f64, smooth: usize) -> PyResult<Self> {
        let params = SuperTrendOscillatorParams {
            length: Some(length),
            mult: Some(mult),
            smooth: Some(smooth),
        };
        match SuperTrendOscillatorStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Feeds one bar to the wrapped stream and forwards its result unchanged:
    // `None`, or a 3-tuple whose consumers in this file read it as
    // (oscillator, signal, histogram).
    fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64, f64)> {
        self.stream.update(high, low, source)
    }
}
1400
// Python entry point `supertrend_oscillator_batch(high, low, source,
// length_range, mult_range, smooth_range, kernel=None)`.
//
// Each range is a `(start, end, step)` tuple. The sweep is expanded into
// parameter combinations, one output row per combination, and the result is a
// dict with:
//   "oscillator", "signal", "histogram" -> arrays reshaped to (rows, cols)
//   "rows", "cols"                      -> matrix dimensions
//   "lengths", "mults", "smooths"       -> per-row parameter values
// Sweep and indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_oscillator_batch")]
#[pyo3(signature = (high, low, source, length_range, mult_range, smooth_range, kernel=None))]
pub fn supertrend_oscillator_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    source: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    mult_range: (f64, f64, f64),
    smooth_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let source = source.as_slice()?;

    let sweep = SuperTrendOscillatorBatchRange {
        length: length_range,
        mult: mult_range,
        smooth: smooth_range,
    };
    // Expanded here only to size the outputs and to report the per-row
    // parameters; the batch routine expands the sweep again internally.
    let combos = match expand_grid_supertrend_oscillator(&sweep) {
        Ok(expanded) => expanded,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    let (rows, cols) = (combos.len(), source.len());
    let total = match rows.checked_mul(cols) {
        Some(cells) => cells,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The arrays are allocated without initialisation; the batch routine is
    // relied upon to write every element before they are handed to Python.
    let oscillator_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let signal_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let histogram_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_oscillator = unsafe { oscillator_arr.as_slice_mut()? };
    let out_signal = unsafe { signal_arr.as_slice_mut()? };
    let out_histogram = unsafe { histogram_arr.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        supertrend_oscillator_batch_inner_into(
            high,
            low,
            source,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            out_oscillator,
            out_signal,
            out_histogram,
        )
    });
    if let Err(err) = computed {
        return Err(PyValueError::new_err(err.to_string()));
    }

    // Split the combinations into one column per parameter.
    let mut lengths: Vec<usize> = Vec::with_capacity(rows);
    let mut mults: Vec<f64> = Vec::with_capacity(rows);
    let mut smooths: Vec<usize> = Vec::with_capacity(rows);
    for combo in &combos {
        lengths.push(combo.length.unwrap_or(DEFAULT_LENGTH));
        mults.push(combo.mult.unwrap_or(DEFAULT_MULT));
        smooths.push(combo.smooth.unwrap_or(DEFAULT_SMOOTH));
    }

    let dict = PyDict::new(py);
    dict.set_item("oscillator", oscillator_arr.reshape((rows, cols))?)?;
    dict.set_item("signal", signal_arr.reshape((rows, cols))?)?;
    dict.set_item("histogram", histogram_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("mults", mults.into_pyarray(py))?;
    dict.set_item("smooths", smooths.into_pyarray(py))?;
    Ok(dict)
}
1480
/// Registers this file's Python bindings on module `m`: the
/// `supertrend_oscillator_py` and `supertrend_oscillator_batch_py` functions and
/// the `SuperTrendOscillatorStreamPy` class.
///
/// # Errors
/// Propagates any error raised while adding a function or class to the module.
#[cfg(feature = "python")]
pub fn register_supertrend_oscillator_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(supertrend_oscillator_py, m)?)?;
    m.add_function(wrap_pyfunction!(supertrend_oscillator_batch_py, m)?)?;
    m.add_class::<SuperTrendOscillatorStreamPy>()?;
    Ok(())
}
1488
/// Serializable payload that `supertrend_oscillator_js` converts into a
/// `JsValue`: the three output series of a single indicator run.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorJsOutput {
    /// Oscillator series copied from the indicator output.
    oscillator: Vec<f64>,
    /// Signal series copied from the indicator output.
    signal: Vec<f64>,
    /// Histogram series copied from the indicator output.
    histogram: Vec<f64>,
}
1496
/// Sweep configuration deserialized from the JavaScript `config` argument of
/// `supertrend_oscillator_batch_js`.
///
/// That function rejects the config unless every range has exactly three
/// elements, which it reads as `[start, end, step]`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorBatchConfig {
    /// `[start, end, step]` for the `length` parameter.
    length_range: Vec<usize>,
    /// `[start, end, step]` for the `mult` parameter.
    mult_range: Vec<f64>,
    /// `[start, end, step]` for the `smooth` parameter.
    smooth_range: Vec<usize>,
}
1504
/// Serializable payload that `supertrend_oscillator_batch_js` converts into a
/// `JsValue`; every field is copied from the batch result of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SuperTrendOscillatorBatchJsOutput {
    /// Oscillator values for all rows in one flat vector
    /// (presumably `rows * cols` long; confirm against the batch routine).
    oscillator: Vec<f64>,
    /// Signal values for all rows, laid out like `oscillator`.
    signal: Vec<f64>,
    /// Histogram values for all rows, laid out like `oscillator`.
    histogram: Vec<f64>,
    /// Row count reported by the batch result.
    rows: usize,
    /// Column count reported by the batch result.
    cols: usize,
    /// Parameter combinations reported by the batch result.
    combos: Vec<SuperTrendOscillatorParams>,
}
1515
// JavaScript entry point `supertrend_oscillator`: runs the indicator over the
// given slices and returns an object with `oscillator`, `signal` and
// `histogram` arrays. Indicator and serialization failures are returned as
// string-valued `JsValue` errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator")]
pub fn supertrend_oscillator_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<JsValue, JsValue> {
    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);

    let out = match supertrend_oscillator(&input) {
        Ok(out) => out,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = SuperTrendOscillatorJsOutput {
        oscillator: out.oscillator,
        signal: out.signal,
        histogram: out.histogram,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1544
// Pointer-based wasm entry point: reads `len` values from each input pointer
// and writes `3 * len` values to `out_ptr`, laid out as
// [oscillator | signal | histogram], each section `len` long.
//
// Returns an error for any null pointer or when the indicator itself fails.
// Pointer validity beyond the null check is the caller's responsibility.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<(), JsValue> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || source_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_into",
        ));
    }

    // SAFETY: the pointers are non-null (checked above). The caller must ensure
    // each input pointer is valid for `len` reads, `out_ptr` is valid for
    // `len * 3` writes, and the output region does not overlap the inputs.
    let (high, low, source, out) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(source_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, len * 3),
        )
    };

    // Carve the flat output buffer into its three equally sized sections.
    let (out_oscillator, tail) = out.split_at_mut(len);
    let (out_signal, out_histogram) = tail.split_at_mut(len);

    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);

    match supertrend_oscillator_into_slice(
        out_oscillator,
        out_signal,
        out_histogram,
        &input,
        Kernel::Auto,
    ) {
        Ok(done) => Ok(done),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1590
// Variant of the pointer-based entry point that takes the inputs as slices and
// only the output as a raw pointer. Writes `3 * source.len()` values to
// `out_ptr`, laid out as [oscillator | signal | histogram].
//
// Returns an error when `out_ptr` is null or the indicator itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator_into_host")]
pub fn supertrend_oscillator_into_host(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    out_ptr: *mut f64,
    length: usize,
    mult: f64,
    smooth: usize,
) -> Result<(), JsValue> {
    if out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_into_host",
        ));
    }

    let cols = source.len();

    // SAFETY: `out_ptr` is non-null (checked above). The caller must ensure it
    // is valid for `cols * 3` writes and does not overlap the input slices.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, cols * 3) };

    // Carve the flat output buffer into its three equally sized sections.
    let (out_oscillator, tail) = out.split_at_mut(cols);
    let (out_signal, out_histogram) = tail.split_at_mut(cols);

    let params = SuperTrendOscillatorParams {
        length: Some(length),
        mult: Some(mult),
        smooth: Some(smooth),
    };
    let input = SuperTrendOscillatorInput::from_slices(high, low, source, params);

    match supertrend_oscillator_into_slice(
        out_oscillator,
        out_signal,
        out_histogram,
        &input,
        Kernel::Auto,
    ) {
        Ok(done) => Ok(done),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1632
// Allocates a zero-filled buffer of `len * 3` f64 values and hands ownership
// to the caller as a raw pointer. The buffer is never freed by Rust on its
// own; release it with `supertrend_oscillator_free(ptr, len)` using the same
// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor so the allocation
    // outlives this call.
    let mut storage = std::mem::ManuallyDrop::new(vec![0.0_f64; len * 3]);
    storage.as_mut_ptr()
}
1641
// Releases a buffer of `len * 3` f64 values by rebuilding the owning Vec and
// dropping it. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        let elems = len * 3;
        // SAFETY: the caller must pass a pointer that came from a Vec<f64>
        // allocation with length and capacity `len * 3` (as produced by
        // `supertrend_oscillator_alloc(len)`) and must not use it afterwards.
        drop(unsafe { Vec::from_raw_parts(ptr, elems, elems) });
    }
}
1652
// JavaScript entry point `supertrend_oscillator_batch`: deserializes the sweep
// `config`, runs the batch with the scalar kernel, and returns an object with
// the three flat output vectors plus `rows`, `cols` and `combos`.
//
// Errors (as string-valued `JsValue`s): undecodable config, a range that does
// not have exactly three elements, a batch failure, or a serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "supertrend_oscillator_batch")]
pub fn supertrend_oscillator_batch_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: SuperTrendOscillatorBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };

    // Each range must be exactly `[start, end, step]`; the slice patterns check
    // the length and extract the three values in one step.
    let sweep = match (
        config.length_range.as_slice(),
        config.mult_range.as_slice(),
        config.smooth_range.as_slice(),
    ) {
        (
            &[length_start, length_end, length_step],
            &[mult_start, mult_end, mult_step],
            &[smooth_start, smooth_end, smooth_step],
        ) => SuperTrendOscillatorBatchRange {
            length: (length_start, length_end, length_step),
            mult: (mult_start, mult_end, mult_step),
            smooth: (smooth_start, smooth_end, smooth_step),
        },
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: ranges must have exactly 3 elements [start, end, step]",
            ))
        }
    };

    let batch = match supertrend_oscillator_batch_slice(high, low, source, &sweep, Kernel::Scalar)
    {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = SuperTrendOscillatorBatchJsOutput {
        oscillator: batch.oscillator,
        signal: batch.signal,
        histogram: batch.histogram,
        rows: batch.rows,
        cols: batch.cols,
        combos: batch.combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1701
// Pointer-based wasm batch entry point. Reads `len` values from each input
// pointer, expands the sweep given by the nine range arguments, and writes
// `rows * len` values to each of the three output pointers (one row of `len`
// values per combination). Rows are computed sequentially with the scalar
// kernel.
//
// Returns the number of rows on success. Errors (as string-valued `JsValue`s):
// any null pointer, an invalid sweep, `rows * len` overflow, or a batch failure.
// Pointer validity beyond the null check is the caller's responsibility.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_oscillator_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    oscillator_ptr: *mut f64,
    signal_ptr: *mut f64,
    histogram_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    mult_start: f64,
    mult_end: f64,
    mult_step: f64,
    smooth_start: usize,
    smooth_end: usize,
    smooth_step: usize,
) -> Result<usize, JsValue> {
    let inputs_null = high_ptr.is_null() || low_ptr.is_null() || source_ptr.is_null();
    let outputs_null =
        oscillator_ptr.is_null() || signal_ptr.is_null() || histogram_ptr.is_null();
    if inputs_null || outputs_null {
        return Err(JsValue::from_str(
            "null pointer passed to supertrend_oscillator_batch_into",
        ));
    }

    // SAFETY: the input pointers are non-null (checked above); the caller must
    // ensure each is valid for `len` reads.
    let (high, low, source) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(source_ptr, len),
        )
    };

    let sweep = SuperTrendOscillatorBatchRange {
        length: (length_start, length_end, length_step),
        mult: (mult_start, mult_end, mult_step),
        smooth: (smooth_start, smooth_end, smooth_step),
    };

    // The sweep is expanded here only to learn how many rows the outputs hold.
    let rows = match expand_grid_supertrend_oscillator(&sweep) {
        Ok(combos) => combos.len(),
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };
    let total = match rows.checked_mul(len) {
        Some(cells) => cells,
        None => return Err(JsValue::from_str("rows*cols overflow")),
    };

    // SAFETY: the output pointers are non-null (checked above); the caller must
    // ensure each is valid for `total` writes and that the three regions
    // overlap neither each other nor the inputs.
    let (oscillator, signal, histogram) = unsafe {
        (
            std::slice::from_raw_parts_mut(oscillator_ptr, total),
            std::slice::from_raw_parts_mut(signal_ptr, total),
            std::slice::from_raw_parts_mut(histogram_ptr, total),
        )
    };

    match supertrend_oscillator_batch_inner_into(
        high,
        low,
        source,
        &sweep,
        Kernel::Scalar,
        false,
        oscillator,
        signal,
        histogram,
    ) {
        Ok(_) => Ok(rows),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1767
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu_batch, IndicatorBatchRequest, IndicatorDataRef, IndicatorParamSet, ParamKV,
        ParamValue,
    };

    /// Builds a fully specified parameter set.
    fn params(length: usize, mult: f64, smooth: usize) -> SuperTrendOscillatorParams {
        SuperTrendOscillatorParams {
            length: Some(length),
            mult: Some(mult),
            smooth: Some(smooth),
        }
    }

    /// Asserts two series are element-wise equal within `tol`; NaN matches only NaN.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len());
        for (i, (&lhs, &rhs)) in a.iter().zip(b).enumerate() {
            if lhs.is_nan() || rhs.is_nan() {
                assert!(
                    lhs.is_nan() && rhs.is_nan(),
                    "nan mismatch at {i}: {lhs} vs {rhs}"
                );
            } else {
                assert!(
                    (lhs - rhs).abs() <= tol,
                    "mismatch at {i}: {lhs} vs {rhs} with tol {tol}"
                );
            }
        }
    }

    /// Generates deterministic synthetic `(high, low, source)` series of `len` bars:
    /// a drifting, oscillating base with a varying spread around it.
    fn sample_hls(len: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(len);
        let mut low = Vec::with_capacity(len);
        let mut source = Vec::with_capacity(len);

        for t in (0..len).map(|i| i as f64) {
            let base = 100.0 + t * 0.21 + (t * 0.17).sin() * 2.0;
            let spread = 1.5 + (t * 0.11).cos().abs() * 1.25;
            high.push(base + spread);
            low.push(base - spread);
            source.push(base + (t * 0.07).cos() * 0.6);
        }

        (high, low, source)
    }

    /// Output vectors match the input length, start with a NaN warm-up, and
    /// contain finite values afterwards.
    fn check_output_contract(kernel: Kernel) {
        let (high, low, source) = sample_hls(192);
        let input =
            SuperTrendOscillatorInput::from_slices(&high, &low, &source, params(10, 2.0, 72));
        let out = supertrend_oscillator_with_kernel(&input, kernel).expect("indicator");

        let series = [&out.oscillator, &out.signal, &out.histogram];
        for values in series {
            assert_eq!(values.len(), source.len());
        }
        for values in series {
            assert!(values[..9].iter().all(|v| v.is_nan()));
        }
        for values in series {
            assert!(values[9..].iter().any(|v| v.is_finite()));
        }
    }

    /// The write-into-slices API produces the same values as the allocating API.
    fn check_into_matches_api(kernel: Kernel) {
        let (high, low, source) = sample_hls(224);
        let input =
            SuperTrendOscillatorInput::from_slices(&high, &low, &source, params(11, 2.5, 20));
        let baseline = supertrend_oscillator_with_kernel(&input, kernel).expect("baseline");

        let n = source.len();
        let (mut oscillator, mut signal, mut histogram) = (vec![0.0; n], vec![0.0; n], vec![0.0; n]);
        supertrend_oscillator_into_slice(
            &mut oscillator,
            &mut signal,
            &mut histogram,
            &input,
            kernel,
        )
        .expect("into");

        assert_close(&baseline.oscillator, &oscillator, 1e-12);
        assert_close(&baseline.signal, &signal, 1e-12);
        assert_close(&baseline.histogram, &histogram, 1e-12);
    }

    /// Feeding bars one at a time through the stream reproduces the one-shot output.
    fn check_stream_matches_batch() {
        let (high, low, source) = sample_hls(200);
        let input =
            SuperTrendOscillatorInput::from_slices(&high, &low, &source, params(12, 1.75, 18));
        let batch = supertrend_oscillator(&input).expect("batch");
        let mut stream =
            SuperTrendOscillatorStream::try_new(params(12, 1.75, 18)).expect("stream");

        let n = source.len();
        let mut oscillator = vec![f64::NAN; n];
        let mut signal = vec![f64::NAN; n];
        let mut histogram = vec![f64::NAN; n];
        for (i, ((&h, &l), &s)) in high.iter().zip(&low).zip(&source).enumerate() {
            // Bars for which the stream yields nothing stay NaN.
            if let Some((osc, sig, hist)) = stream.update(h, l, s) {
                oscillator[i] = osc;
                signal[i] = sig;
                histogram[i] = hist;
            }
        }

        assert_close(&batch.oscillator, &oscillator, 1e-12);
        assert_close(&batch.signal, &signal, 1e-12);
        assert_close(&batch.histogram, &histogram, 1e-12);
    }

    /// A one-combination sweep yields a single row equal to the direct computation.
    fn check_batch_single_matches_single(kernel: Kernel) {
        let (high, low, source) = sample_hls(180);
        let sweep = SuperTrendOscillatorBatchRange {
            length: (12, 12, 0),
            mult: (2.5, 2.5, 0.0),
            smooth: (18, 18, 0),
        };
        let batch = supertrend_oscillator_batch_with_kernel(&high, &low, &source, &sweep, kernel)
            .expect("batch");
        let single = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
            &high,
            &low,
            &source,
            params(12, 2.5, 18),
        ))
        .expect("single");

        let n = source.len();
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, n);
        assert_close(&batch.oscillator[..n], &single.oscillator, 1e-12);
        assert_close(&batch.signal[..n], &single.signal, 1e-12);
        assert_close(&batch.histogram[..n], &single.histogram, 1e-12);
    }

    #[test]
    fn supertrend_oscillator_invalid_params() {
        let (high, low, source) = sample_hls(64);
        let run = |length: usize, mult: f64, smooth: usize| {
            supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
                &high,
                &low,
                &source,
                params(length, mult, smooth),
            ))
        };

        let err = run(0, 2.0, 10).expect_err("invalid length");
        assert!(matches!(
            err,
            SuperTrendOscillatorError::InvalidLength { .. }
        ));

        let err = run(10, 0.0, 10).expect_err("invalid mult");
        assert!(matches!(
            err,
            SuperTrendOscillatorError::InvalidMultiplier { .. }
        ));

        let err = run(10, 2.0, 0).expect_err("invalid smooth");
        assert!(matches!(
            err,
            SuperTrendOscillatorError::InvalidSmooth { .. }
        ));
    }

    #[test]
    fn supertrend_oscillator_dispatch_matches_direct() {
        let (high, low, source) = sample_hls(160);

        // Same parameters, expressed as dispatch key/value pairs.
        let combo = [
            ParamKV {
                key: "length",
                value: ParamValue::Int(12),
            },
            ParamKV {
                key: "mult",
                value: ParamValue::Float(2.5),
            },
            ParamKV {
                key: "smooth",
                value: ParamValue::Int(18),
            },
        ];
        let combos = [IndicatorParamSet { params: &combo }];
        let data = IndicatorDataRef::Ohlc {
            open: &source,
            high: &high,
            low: &low,
            close: &source,
        };
        let out = compute_cpu_batch(IndicatorBatchRequest {
            indicator_id: "supertrend_oscillator",
            output_id: Some("oscillator"),
            data,
            combos: &combos,
            kernel: Kernel::Auto,
        })
        .expect("dispatch");

        let direct = supertrend_oscillator(&SuperTrendOscillatorInput::from_slices(
            &high,
            &low,
            &source,
            params(12, 2.5, 18),
        ))
        .expect("direct");

        assert_eq!(out.rows, 1);
        assert_eq!(out.cols, source.len());
        assert_close(&out.values_f64.expect("values"), &direct.oscillator, 1e-12);
    }

    // Instantiates the kernel-parameterised checks as a module of `#[test]`s.
    macro_rules! gen_kernel_tests {
        ($module:ident, $kernel:expr, $batch_kernel:expr) => {
            mod $module {
                use super::*;

                #[test]
                fn output_contract() {
                    check_output_contract($kernel);
                }

                #[test]
                fn into_matches_api() {
                    check_into_matches_api($kernel);
                }

                #[test]
                fn batch_single_matches_single() {
                    check_batch_single_matches_single($batch_kernel);
                }
            }
        };
    }

    gen_kernel_tests!(scalar_kernel, Kernel::Scalar, Kernel::ScalarBatch);
    gen_kernel_tests!(auto_kernel, Kernel::Auto, Kernel::Auto);

    #[test]
    fn supertrend_oscillator_stream_matches_batch() {
        check_stream_matches_batch();
    }
}