1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::DeviceArrayF32;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::{cuda_available, CudaSupertrend};
5use crate::indicators::atr::{atr, AtrData, AtrError, AtrInput, AtrOutput, AtrParams};
6use crate::utilities::data_loader::{source_type, Candles};
7#[cfg(all(feature = "python", feature = "cuda"))]
8use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
9use crate::utilities::enums::Kernel;
10use crate::utilities::helpers::{
11 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
12 make_uninit_matrix,
13};
14#[cfg(feature = "python")]
15use crate::utilities::kernel_validation::validate_kernel;
16use aligned_vec::{AVec, CACHELINE_ALIGN};
17#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
18use core::arch::x86_64::*;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use cust::context::Context;
21#[cfg(feature = "python")]
22use pyo3::exceptions::{PyBufferError, PyValueError};
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::convert::AsRef;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use std::sync::Arc;
33use thiserror::Error;
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use wasm_bindgen::prelude::*;
36
37#[derive(Debug, Clone)]
38pub enum SuperTrendData<'a> {
39 Candles {
40 candles: &'a Candles,
41 },
42 Slices {
43 high: &'a [f64],
44 low: &'a [f64],
45 close: &'a [f64],
46 },
47}
48
/// Tunable SuperTrend parameters. A `None` field means "use the default".
#[derive(Debug, Clone)]
pub struct SuperTrendParams {
    /// ATR lookback length (default 10).
    pub period: Option<usize>,
    /// ATR multiplier used to offset the bands from `(high + low) / 2` (default 3.0).
    pub factor: Option<f64>,
}

impl Default for SuperTrendParams {
    fn default() -> Self {
        let (period, factor) = (10, 3.0);
        Self {
            period: Some(period),
            factor: Some(factor),
        }
    }
}
62
63#[derive(Debug, Clone)]
64pub struct SuperTrendInput<'a> {
65 pub data: SuperTrendData<'a>,
66 pub params: SuperTrendParams,
67}
68
69impl<'a> SuperTrendInput<'a> {
70 #[inline]
71 pub fn from_candles(candles: &'a Candles, params: SuperTrendParams) -> Self {
72 Self {
73 data: SuperTrendData::Candles { candles },
74 params,
75 }
76 }
77 #[inline]
78 pub fn from_slices(
79 high: &'a [f64],
80 low: &'a [f64],
81 close: &'a [f64],
82 params: SuperTrendParams,
83 ) -> Self {
84 Self {
85 data: SuperTrendData::Slices { high, low, close },
86 params,
87 }
88 }
89 #[inline]
90 pub fn with_default_candles(candles: &'a Candles) -> Self {
91 Self {
92 data: SuperTrendData::Candles { candles },
93 params: SuperTrendParams::default(),
94 }
95 }
96 #[inline]
97 pub fn get_period(&self) -> usize {
98 self.params.period.unwrap_or(10)
99 }
100 #[inline]
101 pub fn get_factor(&self) -> f64 {
102 self.params.factor.unwrap_or(3.0)
103 }
104 #[inline(always)]
105 fn as_hlc(&self) -> (&[f64], &[f64], &[f64]) {
106 match &self.data {
107 SuperTrendData::Candles { candles } => (
108 source_type(candles, "high"),
109 source_type(candles, "low"),
110 source_type(candles, "close"),
111 ),
112 SuperTrendData::Slices { high, low, close } => (*high, *low, *close),
113 }
114 }
115}
116
/// Result of a single-series SuperTrend run (one entry per input bar).
#[derive(Debug, Clone)]
pub struct SuperTrendOutput {
    /// Active band: the upper band while in the upper state, the lower band otherwise.
    pub trend: Vec<f64>,
    /// `1.0` on bars where the trend side flipped, `0.0` otherwise.
    pub changed: Vec<f64>,
}
122
123#[derive(Copy, Clone, Debug)]
124pub struct SuperTrendBuilder {
125 period: Option<usize>,
126 factor: Option<f64>,
127 kernel: Kernel,
128}
129impl Default for SuperTrendBuilder {
130 fn default() -> Self {
131 Self {
132 period: None,
133 factor: None,
134 kernel: Kernel::Auto,
135 }
136 }
137}
138impl SuperTrendBuilder {
139 #[inline]
140 pub fn new() -> Self {
141 Self::default()
142 }
143 #[inline]
144 pub fn period(mut self, n: usize) -> Self {
145 self.period = Some(n);
146 self
147 }
148 #[inline]
149 pub fn factor(mut self, x: f64) -> Self {
150 self.factor = Some(x);
151 self
152 }
153 #[inline]
154 pub fn kernel(mut self, k: Kernel) -> Self {
155 self.kernel = k;
156 self
157 }
158 #[inline]
159 pub fn apply(self, c: &Candles) -> Result<SuperTrendOutput, SuperTrendError> {
160 let p = SuperTrendParams {
161 period: self.period,
162 factor: self.factor,
163 };
164 let i = SuperTrendInput::from_candles(c, p);
165 supertrend_with_kernel(&i, self.kernel)
166 }
167 #[inline]
168 pub fn apply_slices(
169 self,
170 high: &[f64],
171 low: &[f64],
172 close: &[f64],
173 ) -> Result<SuperTrendOutput, SuperTrendError> {
174 let p = SuperTrendParams {
175 period: self.period,
176 factor: self.factor,
177 };
178 let i = SuperTrendInput::from_slices(high, low, close, p);
179 supertrend_with_kernel(&i, self.kernel)
180 }
181 #[inline]
182 pub fn into_stream(self) -> Result<SuperTrendStream, SuperTrendError> {
183 let p = SuperTrendParams {
184 period: self.period,
185 factor: self.factor,
186 };
187 SuperTrendStream::try_new(p)
188 }
189}
190
191#[derive(Debug, Error)]
192pub enum SuperTrendError {
193 #[error("supertrend: Empty data provided.")]
194 EmptyInputData,
195 #[error("supertrend: All values are NaN.")]
196 AllValuesNaN,
197 #[error("supertrend: Invalid period: period = {period}, data length = {data_len}")]
198 InvalidPeriod { period: usize, data_len: usize },
199 #[error("supertrend: Not enough valid data: needed = {needed}, valid = {valid}")]
200 NotEnoughValidData { needed: usize, valid: usize },
201 #[error("supertrend: Output slice length mismatch: expected = {expected}, got = {got}")]
202 OutputLengthMismatch { expected: usize, got: usize },
203 #[error("supertrend: Invalid range: start={start}, end={end}, step={step}")]
204 InvalidRange {
205 start: usize,
206 end: usize,
207 step: usize,
208 },
209 #[error("supertrend: Invalid factor range: start={start}, end={end}, step={step}")]
210 InvalidFactorRange { start: f64, end: f64, step: f64 },
211 #[error("supertrend: Invalid kernel for batch: {0:?}")]
212 InvalidKernelForBatch(Kernel),
213 #[error(transparent)]
214 AtrError(#[from] AtrError),
215}
216
217#[inline]
218pub fn supertrend(input: &SuperTrendInput) -> Result<SuperTrendOutput, SuperTrendError> {
219 supertrend_with_kernel(input, Kernel::Auto)
220}
221
222#[inline(always)]
223fn supertrend_prepare<'a>(
224 input: &'a SuperTrendInput,
225 kernel: Kernel,
226) -> Result<
227 (
228 &'a [f64],
229 &'a [f64],
230 &'a [f64],
231 usize,
232 f64,
233 usize,
234 Vec<f64>,
235 Kernel,
236 ),
237 SuperTrendError,
238> {
239 let (high, low, close) = input.as_hlc();
240
241 if high.is_empty() || low.is_empty() || close.is_empty() {
242 return Err(SuperTrendError::EmptyInputData);
243 }
244
245 let period = input.get_period();
246 if period == 0 || period > high.len() {
247 return Err(SuperTrendError::InvalidPeriod {
248 period,
249 data_len: high.len(),
250 });
251 }
252
253 let factor = input.get_factor();
254 let len = high.len();
255
256 let mut first_valid_idx = None;
257 for i in 0..len {
258 if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
259 first_valid_idx = Some(i);
260 break;
261 }
262 }
263
264 let first_valid_idx = match first_valid_idx {
265 Some(idx) => idx,
266 None => return Err(SuperTrendError::AllValuesNaN),
267 };
268
269 if (len - first_valid_idx) < period {
270 return Err(SuperTrendError::NotEnoughValidData {
271 needed: period,
272 valid: len - first_valid_idx,
273 });
274 }
275
276 let atr_input = AtrInput::from_slices(
277 &high[first_valid_idx..],
278 &low[first_valid_idx..],
279 &close[first_valid_idx..],
280 AtrParams {
281 length: Some(period),
282 },
283 );
284 let AtrOutput { values: atr_values } = atr(&atr_input)?;
285
286 let chosen = match kernel {
287 Kernel::Auto => Kernel::Scalar,
288 other => other,
289 };
290
291 Ok((
292 high,
293 low,
294 close,
295 period,
296 factor,
297 first_valid_idx,
298 atr_values,
299 chosen,
300 ))
301}
302
303#[inline(always)]
304fn supertrend_compute_into(
305 high: &[f64],
306 low: &[f64],
307 close: &[f64],
308 period: usize,
309 factor: f64,
310 first_valid_idx: usize,
311 atr_values: &[f64],
312 kernel: Kernel,
313 trend_out: &mut [f64],
314 changed_out: &mut [f64],
315) {
316 unsafe {
317 match kernel {
318 Kernel::Scalar | Kernel::ScalarBatch => {
319 supertrend_scalar(
320 high,
321 low,
322 close,
323 period,
324 factor,
325 first_valid_idx,
326 &atr_values,
327 trend_out,
328 changed_out,
329 );
330 }
331 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
332 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
333 supertrend_scalar(
334 high,
335 low,
336 close,
337 period,
338 factor,
339 first_valid_idx,
340 &atr_values,
341 trend_out,
342 changed_out,
343 );
344 }
345 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
346 Kernel::Avx2 | Kernel::Avx2Batch => {
347 supertrend_avx2(
348 high,
349 low,
350 close,
351 period,
352 factor,
353 first_valid_idx,
354 &atr_values,
355 trend_out,
356 changed_out,
357 );
358 }
359 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
360 Kernel::Avx512 | Kernel::Avx512Batch => {
361 supertrend_avx512(
362 high,
363 low,
364 close,
365 period,
366 factor,
367 first_valid_idx,
368 &atr_values,
369 trend_out,
370 changed_out,
371 );
372 }
373 _ => unreachable!(),
374 }
375 }
376}
377
378pub fn supertrend_with_kernel(
379 input: &SuperTrendInput,
380 kernel: Kernel,
381) -> Result<SuperTrendOutput, SuperTrendError> {
382 let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
383 supertrend_prepare(input, kernel)?;
384
385 let len = high.len();
386 let mut trend = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
387 let mut changed = alloc_with_nan_prefix(len, first_valid_idx + period - 1);
388
389 supertrend_compute_into(
390 high,
391 low,
392 close,
393 period,
394 factor,
395 first_valid_idx,
396 &atr_values,
397 chosen,
398 &mut trend,
399 &mut changed,
400 );
401
402 Ok(SuperTrendOutput { trend, changed })
403}
404
405#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
406#[inline]
407pub fn supertrend_into(
408 input: &SuperTrendInput,
409 trend_out: &mut [f64],
410 changed_out: &mut [f64],
411) -> Result<(), SuperTrendError> {
412 let (high, _low, _close) = input.as_hlc();
413 let len = high.len();
414
415 if trend_out.len() != len {
416 return Err(SuperTrendError::OutputLengthMismatch {
417 expected: len,
418 got: trend_out.len(),
419 });
420 }
421 if changed_out.len() != len {
422 return Err(SuperTrendError::OutputLengthMismatch {
423 expected: len,
424 got: changed_out.len(),
425 });
426 }
427
428 let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
429 supertrend_prepare(input, Kernel::Auto)?;
430
431 let warmup_end = first_valid_idx + period - 1;
432 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
433 for v in &mut trend_out[..warmup_end.min(len)] {
434 *v = qnan;
435 }
436 for v in &mut changed_out[..warmup_end.min(len)] {
437 *v = qnan;
438 }
439
440 supertrend_compute_into(
441 high,
442 low,
443 close,
444 period,
445 factor,
446 first_valid_idx,
447 &atr_values,
448 chosen,
449 trend_out,
450 changed_out,
451 );
452
453 Ok(())
454}
455
/// Scalar SuperTrend kernel. Fills `trend`/`changed` from the seed bar
/// `first_valid_idx + period - 1` through the end of `high`.
///
/// * `atr_values` is indexed relative to `first_valid_idx`: entry `k` belongs to bar
///   `first_valid_idx + k`, and entry `period - 1` seeds the first band pair.
/// * Bars before the seed bar are not written; callers pre-fill them (e.g. with NaN).
/// * `changed[i]` is `1.0` when the close leaves the active band's side and the trend
///   flips, otherwise `0.0`. The seed bar is always `0.0`.
/// * Band math uses fused multiply-add for the seed and every later bar, matching
///   `SuperTrendStream` and `supertrend_scalar_classic`.
///
/// Returns without writing when `period == 0`, or when fewer than `period` bars exist
/// from `first_valid_idx`.
///
/// # Panics
/// * If `low`, `close`, `trend` or `changed` is shorter than `high`.
/// * If `atr_values` has fewer than `high.len() - first_valid_idx` entries.
#[inline(always)]
pub fn supertrend_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let len = high.len();
    // A zero period has no seed bar (`period - 1` would underflow).
    if period == 0 {
        return;
    }
    let start = first_valid_idx + period;
    if start > len {
        return;
    }

    // Validate once so the loop below is bounds-safe. This is a safe fn, so it must not
    // read or write past any slice.
    assert!(
        low.len() >= len && close.len() >= len,
        "supertrend_scalar: `low`/`close` must be at least as long as `high`"
    );
    assert!(
        trend.len() >= len && changed.len() >= len,
        "supertrend_scalar: output slices must be at least as long as `high`"
    );
    assert!(
        atr_values.len() >= len - first_valid_idx,
        "supertrend_scalar: `atr_values` must cover high[first_valid_idx..]"
    );

    // Re-slice to a common length so the optimizer can drop per-iteration bounds checks.
    let (low, close) = (&low[..len], &close[..len]);
    let trend = &mut trend[..len];
    let changed = &mut changed[..len];
    let atr = &atr_values[..len - first_valid_idx];

    // Seed bar: first bar with a defined ATR.
    let warmup = start - 1;
    let hl2_w = (high[warmup] + low[warmup]) * 0.5;
    let atr_w = atr[period - 1];
    let neg_factor = -factor;
    let mut prev_upper_band = factor.mul_add(atr_w, hl2_w);
    let mut prev_lower_band = neg_factor.mul_add(atr_w, hl2_w);

    let mut last_close = close[warmup];
    let mut upper_state = last_close <= prev_upper_band;
    trend[warmup] = if upper_state {
        prev_upper_band
    } else {
        prev_lower_band
    };
    changed[warmup] = 0.0;

    for i in start..len {
        let atr_i = atr[i - first_valid_idx];
        let hl2 = (high[i] + low[i]) * 0.5;
        let upper_basic = factor.mul_add(atr_i, hl2);
        let lower_basic = neg_factor.mul_add(atr_i, hl2);

        // Bands only ratchet toward price while the previous close stayed inside them.
        let curr_upper_band = if last_close <= prev_upper_band {
            upper_basic.min(prev_upper_band)
        } else {
            upper_basic
        };
        let curr_lower_band = if last_close >= prev_lower_band {
            lower_basic.max(prev_lower_band)
        } else {
            lower_basic
        };

        let curr_close = close[i];
        // Written as "stay if inside, else flip" so a NaN close flips, like the original.
        if upper_state {
            if curr_close <= curr_upper_band {
                trend[i] = curr_upper_band;
                changed[i] = 0.0;
            } else {
                trend[i] = curr_lower_band;
                changed[i] = 1.0;
                upper_state = false;
            }
        } else if curr_close >= curr_lower_band {
            trend[i] = curr_lower_band;
            changed[i] = 0.0;
        } else {
            trend[i] = curr_upper_band;
            changed[i] = 1.0;
            upper_state = true;
        }

        prev_upper_band = curr_upper_band;
        prev_lower_band = curr_lower_band;
        last_close = curr_close;
    }
}
550
/// AVX2 entry point for the single-series SuperTrend kernel.
///
/// Each bar's bands depend on the previous bar's bands, so this currently delegates to
/// [`supertrend_scalar`] and yields identical results.
///
/// # Safety
/// Declared `unsafe` to match the SIMD kernel family. The current body performs no unsafe
/// operations itself. Callers are expected to select it only on AVX2-capable CPUs.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend,
        changed,
    )
}
576
/// AVX-512 entry point: routes `period <= 32` to the short variant and longer periods to
/// the long variant.
///
/// # Safety
/// Declared `unsafe` to match the SIMD kernel family. Both variants currently forward to
/// the scalar kernel. Callers are expected to select it only on AVX-512-capable CPUs.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    match period {
        0..=32 => supertrend_avx512_short(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        ),
        _ => supertrend_avx512_long(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        ),
    }
}
616
/// AVX-512 variant for short periods (`period <= 32`); currently the scalar kernel.
///
/// # Safety
/// Declared `unsafe` to match the SIMD kernel family. The body performs no unsafe
/// operations itself.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend,
        changed,
    )
}
642
/// AVX-512 variant for long periods (`period > 32`); currently the scalar kernel.
///
/// # Safety
/// Declared `unsafe` to match the SIMD kernel family. The body performs no unsafe
/// operations itself.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn supertrend_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend,
        changed,
    )
}
668
669#[inline]
670pub unsafe fn supertrend_scalar_classic(
671 high: &[f64],
672 low: &[f64],
673 close: &[f64],
674 period: usize,
675 factor: f64,
676 trend_out: &mut [f64],
677 changed_out: &mut [f64],
678) -> Result<(), SuperTrendError> {
679 let n = high.len();
680
681 let mut first_valid = None;
682 for i in 0..n {
683 if high[i].is_finite() && low[i].is_finite() && close[i].is_finite() {
684 first_valid = Some(i);
685 break;
686 }
687 }
688
689 let first_valid = first_valid.ok_or(SuperTrendError::AllValuesNaN)?;
690
691 if n - first_valid < period {
692 return Err(SuperTrendError::NotEnoughValidData {
693 needed: period,
694 valid: n - first_valid,
695 });
696 }
697
698 let warmup = first_valid + period - 1;
699 for i in 0..warmup.min(n) {
700 trend_out[i] = f64::NAN;
701 changed_out[i] = f64::NAN;
702 }
703
704 let mut tr_values = vec![0.0; n];
705
706 if first_valid < n {
707 tr_values[first_valid] = high[first_valid] - low[first_valid];
708 }
709
710 for i in (first_valid + 1)..n {
711 let high_low = high[i] - low[i];
712 let high_close = (high[i] - close[i - 1]).abs();
713 let low_close = (low[i] - close[i - 1]).abs();
714 tr_values[i] = high_low.max(high_close).max(low_close);
715 }
716
717 let mut atr_values = vec![f64::NAN; n];
718
719 let mut atr_sum = 0.0;
720 for i in first_valid..(first_valid + period).min(n) {
721 atr_sum += tr_values[i];
722 }
723
724 if first_valid + period <= n {
725 atr_values[first_valid + period - 1] = atr_sum / period as f64;
726
727 let alpha = 1.0 / period as f64;
728 let alpha_1minus = 1.0 - alpha;
729
730 for i in (first_valid + period)..n {
731 atr_values[i] = alpha * tr_values[i] + alpha_1minus * atr_values[i - 1];
732 }
733 }
734
735 if warmup >= n {
736 return Ok(());
737 }
738
739 let half_range = (high[warmup] + low[warmup]) / 2.0;
740 let mut prev_upper_band = factor.mul_add(atr_values[warmup], half_range);
741 let mut prev_lower_band = (-factor).mul_add(atr_values[warmup], half_range);
742
743 let mut last_close = close[warmup];
744 let mut upper_state = if last_close <= prev_upper_band {
745 trend_out[warmup] = prev_upper_band;
746 true
747 } else {
748 trend_out[warmup] = prev_lower_band;
749 false
750 };
751 changed_out[warmup] = 0.0;
752
753 for i in (warmup + 1)..n {
754 let half_range = (high[i] + low[i]) / 2.0;
755 let upper_basic = factor.mul_add(atr_values[i], half_range);
756 let lower_basic = (-factor).mul_add(atr_values[i], half_range);
757
758 let prev_close = last_close;
759 let mut curr_upper_band = upper_basic;
760 let mut curr_lower_band = lower_basic;
761 if prev_close <= prev_upper_band {
762 curr_upper_band = curr_upper_band.min(prev_upper_band);
763 }
764 if prev_close >= prev_lower_band {
765 curr_lower_band = curr_lower_band.max(prev_lower_band);
766 }
767
768 let curr_close = close[i];
769 if upper_state {
770 if curr_close <= curr_upper_band {
771 trend_out[i] = curr_upper_band;
772 changed_out[i] = 0.0;
773 } else {
774 trend_out[i] = curr_lower_band;
775 changed_out[i] = 1.0;
776 upper_state = false;
777 }
778 } else {
779 if curr_close >= curr_lower_band {
780 trend_out[i] = curr_lower_band;
781 changed_out[i] = 0.0;
782 } else {
783 trend_out[i] = curr_upper_band;
784 changed_out[i] = 1.0;
785 upper_state = true;
786 }
787 }
788
789 prev_upper_band = curr_upper_band;
790 prev_lower_band = curr_lower_band;
791 last_close = curr_close;
792 }
793
794 Ok(())
795}
796
797#[derive(Debug, Clone)]
798pub struct SuperTrendStream {
799 pub period: usize,
800 pub factor: f64,
801 atr_stream: crate::indicators::atr::AtrStream,
802
803 prev_upper_band: f64,
804 prev_lower_band: f64,
805 prev_close: f64,
806 upper_state: bool,
807 warmed: bool,
808}
809
810impl SuperTrendStream {
811 #[inline]
812 pub fn try_new(params: SuperTrendParams) -> Result<Self, SuperTrendError> {
813 let period = params.period.unwrap_or(10);
814 let factor = params.factor.unwrap_or(3.0);
815 let atr_stream = crate::indicators::atr::AtrStream::try_new(AtrParams {
816 length: Some(period),
817 })?;
818 Ok(Self {
819 period,
820 factor,
821 atr_stream,
822 prev_upper_band: f64::NAN,
823 prev_lower_band: f64::NAN,
824 prev_close: f64::NAN,
825 upper_state: false,
826 warmed: false,
827 })
828 }
829
830 #[inline(always)]
831 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
832 let atr = match self.atr_stream.update(high, low, close) {
833 Some(v) => v,
834 None => return None,
835 };
836
837 let hl2 = (high + low) * 0.5;
838 let upper_basic = self.factor.mul_add(atr, hl2);
839 let lower_basic = (-self.factor).mul_add(atr, hl2);
840
841 if !self.warmed {
842 self.prev_upper_band = upper_basic;
843 self.prev_lower_band = lower_basic;
844 self.upper_state = close <= self.prev_upper_band;
845 let trend = if self.upper_state {
846 self.prev_upper_band
847 } else {
848 self.prev_lower_band
849 };
850 self.prev_close = close;
851 self.warmed = true;
852 return Some((trend, 0.0));
853 }
854
855 let mut curr_upper_band = upper_basic;
856 if self.prev_close <= self.prev_upper_band {
857 curr_upper_band = curr_upper_band.min(self.prev_upper_band);
858 }
859 let mut curr_lower_band = lower_basic;
860 if self.prev_close >= self.prev_lower_band {
861 curr_lower_band = curr_lower_band.max(self.prev_lower_band);
862 }
863
864 let mut changed = 0.0;
865 let trend = if self.upper_state {
866 if close <= curr_upper_band {
867 curr_upper_band
868 } else {
869 changed = 1.0;
870 self.upper_state = false;
871 curr_lower_band
872 }
873 } else {
874 if close >= curr_lower_band {
875 curr_lower_band
876 } else {
877 changed = 1.0;
878 self.upper_state = true;
879 curr_upper_band
880 }
881 };
882
883 self.prev_upper_band = curr_upper_band;
884 self.prev_lower_band = curr_lower_band;
885 self.prev_close = close;
886
887 Some((trend, changed))
888 }
889}
890
/// Batch parameter sweep: an inclusive `(start, end, step)` triple for each axis.
#[derive(Clone, Debug)]
pub struct SuperTrendBatchRange {
    /// Period sweep.
    pub period: (usize, usize, usize),
    /// Factor sweep.
    pub factor: (f64, f64, f64),
}

impl Default for SuperTrendBatchRange {
    fn default() -> Self {
        let period = (10, 259, 1);
        let factor = (3.0, 3.0, 0.0);
        Self { period, factor }
    }
}
904
905#[derive(Clone, Debug, Default)]
906pub struct SuperTrendBatchBuilder {
907 range: SuperTrendBatchRange,
908 kernel: Kernel,
909}
910impl SuperTrendBatchBuilder {
911 pub fn new() -> Self {
912 Self::default()
913 }
914 pub fn kernel(mut self, k: Kernel) -> Self {
915 self.kernel = k;
916 self
917 }
918 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
919 self.range.period = (start, end, step);
920 self
921 }
922 pub fn period_static(mut self, p: usize) -> Self {
923 self.range.period = (p, p, 0);
924 self
925 }
926 pub fn factor_range(mut self, start: f64, end: f64, step: f64) -> Self {
927 self.range.factor = (start, end, step);
928 self
929 }
930 pub fn factor_static(mut self, x: f64) -> Self {
931 self.range.factor = (x, x, 0.0);
932 self
933 }
934 pub fn apply_slices(
935 self,
936 high: &[f64],
937 low: &[f64],
938 close: &[f64],
939 ) -> Result<SuperTrendBatchOutput, SuperTrendError> {
940 supertrend_batch_with_kernel(high, low, close, &self.range, self.kernel)
941 }
942 pub fn apply_candles(self, c: &Candles) -> Result<SuperTrendBatchOutput, SuperTrendError> {
943 let high = source_type(c, "high");
944 let low = source_type(c, "low");
945 let close = source_type(c, "close");
946 self.apply_slices(high, low, close)
947 }
948 pub fn with_default_candles(
949 c: &Candles,
950 k: Kernel,
951 ) -> Result<SuperTrendBatchOutput, SuperTrendError> {
952 SuperTrendBatchBuilder::new().kernel(k).apply_candles(c)
953 }
954}
955
956pub struct SuperTrendBatchOutput {
957 pub trend: Vec<f64>,
958 pub changed: Vec<f64>,
959 pub combos: Vec<SuperTrendParams>,
960 pub rows: usize,
961 pub cols: usize,
962}
963impl SuperTrendBatchOutput {
964 pub fn row_for_params(&self, p: &SuperTrendParams) -> Option<usize> {
965 self.combos.iter().position(|c| {
966 c.period.unwrap_or(10) == p.period.unwrap_or(10)
967 && (c.factor.unwrap_or(3.0) - p.factor.unwrap_or(3.0)).abs() < 1e-12
968 })
969 }
970 pub fn trend_for(&self, p: &SuperTrendParams) -> Option<&[f64]> {
971 self.row_for_params(p).map(|row| {
972 let start = row * self.cols;
973 &self.trend[start..start + self.cols]
974 })
975 }
976 pub fn changed_for(&self, p: &SuperTrendParams) -> Option<&[f64]> {
977 self.row_for_params(p).map(|row| {
978 let start = row * self.cols;
979 &self.changed[start..start + self.cols]
980 })
981 }
982}
983
/// Python handle for a SuperTrend result matrix stored in CUDA device memory.
///
/// Holding `_ctx` (an `Arc` to the CUDA context) keeps that context object alive as long
/// as this handle exists. `device_id` is reported through `__dlpack_device__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct SupertrendDeviceArrayF32Py {
    /// Device buffer plus its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Owning CUDA context.
    pub(crate) _ctx: Arc<Context>,
    /// CUDA device ordinal the buffer lives on.
    pub(crate) device_id: u32,
}
991
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SupertrendDeviceArrayF32Py {
    /// CUDA Array Interface (v3) describing a row-major `float32` matrix.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let item = std::mem::size_of::<f32>();
        let iface = pyo3::types::PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.inner.cols * item, item))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple: device type 2 plus this buffer's device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership out of this handle.
    ///
    /// The handle is left holding an empty `0 x 0` buffer afterwards. Cross-device
    /// requests and `copy=True` are rejected; `stream` is accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // Only a well-formed `(device_type, device_id)` tuple is checked; anything else is
        // ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let copy_requested = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "__dlpack__: requested device does not match producer buffer"
                };
                return Err(PyBufferError::new_err(reason));
            }
        }

        let _ = stream;

        if let Some(flag) = copy.as_ref() {
            if flag.extract::<bool>(py)? {
                return Err(PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for supertrend CUDA buffers",
                ));
            }
        }

        let placeholder = DeviceArrayF32 {
            buf: DeviceBuffer::from_slice(&[])
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
        };
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let taken = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, taken.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1080
/// Runs the SuperTrend parameter sweep on the GPU.
///
/// Returns a dict with:
/// * `trend`/`changed`: device-resident matrices;
/// * `periods`/`factors`: the parameter value for each row.
///
/// Inputs are converted to `f32` before upload.
///
/// # Errors
/// * `ValueError` if CUDA is unavailable.
/// * `ValueError` if the high/low/close lengths differ.
/// * `ValueError` if the CUDA kernel reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "supertrend_cuda_batch_dev")]
#[pyo3(signature = (high, low, close, period_range, factor_range, device_id=0))]
pub fn supertrend_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::IntoPyArray;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    // Same guard as the many-series entry point: mismatched series cannot form a batch.
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err("length mismatch"));
    }
    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };
    let (trend, changed, combos, ctx_arc, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaSupertrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let h32: Vec<f32> = h.iter().map(|&v| v as f32).collect();
        let l32: Vec<f32> = l.iter().map(|&v| v as f32).collect();
        let c32: Vec<f32> = c.iter().map(|&v| v as f32).collect();
        let (trend, changed, combos) = cuda
            .supertrend_batch_dev(&h32, &l32, &c32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx_arc = cuda.context_arc();
        let dev_id = cuda.device_id();
        Ok((trend, changed, combos, ctx_arc, dev_id))
    })?;

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "trend",
        Py::new(
            py,
            SupertrendDeviceArrayF32Py {
                inner: trend,
                _ctx: ctx_arc.clone(),
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "changed",
        Py::new(
            py,
            SupertrendDeviceArrayF32Py {
                inner: changed,
                _ctx: ctx_arc,
                device_id: dev_id,
            },
        )?,
    )?;
    let periods: Vec<usize> = combos.iter().map(|p| p.period.unwrap()).collect();
    let factors: Vec<f64> = combos.iter().map(|p| p.factor.unwrap()).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("factors", factors.into_pyarray(py))?;
    Ok(dict)
}
1147
/// Runs SuperTrend with one parameter set over many time-major series on the GPU.
///
/// Returns a dict with device matrices `trend`/`changed` plus the `cols`/`rows` that were
/// passed in. Inputs are converted to `f32` before upload.
///
/// # Errors
/// * `ValueError` if CUDA is unavailable.
/// * `ValueError` if the high/low/close lengths differ.
/// * `ValueError` if the CUDA kernel reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "supertrend_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, close_tm, cols, rows, period, factor, device_id=0))]
pub fn supertrend_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm: numpy::PyReadonlyArray1<'py, f64>,
    low_tm: numpy::PyReadonlyArray1<'py, f64>,
    close_tm: numpy::PyReadonlyArray1<'py, f64>,
    cols: usize,
    rows: usize,
    period: usize,
    factor: f64,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (h, l, c) = (high_tm.as_slice()?, low_tm.as_slice()?, close_tm.as_slice()?);
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err("length mismatch"));
    }

    let to_f32 = |src: &[f64]| -> Vec<f32> { src.iter().map(|&v| v as f32).collect() };
    let (h32, l32, c32) = (to_f32(h), to_f32(l), to_f32(c));

    let (out, ctx_arc, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaSupertrend::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .supertrend_many_series_one_param_time_major_dev(
                &h32,
                &l32,
                &c32,
                cols,
                rows,
                period,
                factor as f32,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((out, cuda.context_arc(), cuda.device_id()))
    })?;

    // Each device matrix keeps its own reference to the CUDA context.
    let wrap = |inner: DeviceArrayF32, ctx: Arc<Context>| {
        Py::new(
            py,
            SupertrendDeviceArrayF32Py {
                inner,
                _ctx: ctx,
                device_id: dev_id,
            },
        )
    };

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("trend", wrap(out.plus, ctx_arc.clone())?)?;
    dict.set_item("changed", wrap(out.minus, ctx_arc)?)?;
    dict.set_item("cols", cols)?;
    dict.set_item("rows", rows)?;
    Ok(dict)
}
1221
1222#[inline(always)]
1223fn expand_grid(r: &SuperTrendBatchRange) -> Result<Vec<SuperTrendParams>, SuperTrendError> {
1224 fn axis_usize(
1225 (start, end, step): (usize, usize, usize),
1226 ) -> Result<Vec<usize>, SuperTrendError> {
1227 if step == 0 || start == end {
1228 return Ok(vec![start]);
1229 }
1230 if start < end {
1231 let v: Vec<usize> = (start..=end).step_by(step.max(1)).collect();
1232 if v.is_empty() {
1233 return Err(SuperTrendError::InvalidRange { start, end, step });
1234 }
1235 return Ok(v);
1236 }
1237 let mut v = Vec::new();
1238 let mut cur = start;
1239 let st = step.max(1);
1240 while cur >= end {
1241 v.push(cur);
1242 let next = cur.saturating_sub(st);
1243 if next == cur {
1244 break;
1245 }
1246 cur = next;
1247 }
1248 if v.is_empty() {
1249 return Err(SuperTrendError::InvalidRange { start, end, step });
1250 }
1251 Ok(v)
1252 }
1253 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, SuperTrendError> {
1254 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1255 return Ok(vec![start]);
1256 }
1257 let st = step.abs();
1258 if start < end {
1259 let mut v = Vec::new();
1260 let mut x = start;
1261 while x <= end + 1e-12 {
1262 v.push(x);
1263 x += st;
1264 }
1265 if v.is_empty() {
1266 return Err(SuperTrendError::InvalidFactorRange { start, end, step });
1267 }
1268 return Ok(v);
1269 }
1270 let mut v = Vec::new();
1271 let mut x = start;
1272 while x + 1e-12 >= end {
1273 v.push(x);
1274 x -= st;
1275 }
1276 if v.is_empty() {
1277 return Err(SuperTrendError::InvalidFactorRange { start, end, step });
1278 }
1279 Ok(v)
1280 }
1281 let periods = axis_usize(r.period)?;
1282 let factors = axis_f64(r.factor)?;
1283 let cap = periods
1284 .len()
1285 .checked_mul(factors.len())
1286 .ok_or(SuperTrendError::InvalidRange {
1287 start: r.period.0,
1288 end: r.period.1,
1289 step: r.period.2,
1290 })?;
1291 let mut out = Vec::with_capacity(cap);
1292 for &p in &periods {
1293 for &f in &factors {
1294 out.push(SuperTrendParams {
1295 period: Some(p),
1296 factor: Some(f),
1297 });
1298 }
1299 }
1300 Ok(out)
1301}
1302
1303pub fn supertrend_batch_with_kernel(
1304 high: &[f64],
1305 low: &[f64],
1306 close: &[f64],
1307 sweep: &SuperTrendBatchRange,
1308 k: Kernel,
1309) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1310 let kernel = match k {
1311 Kernel::Auto => detect_best_batch_kernel(),
1312 other if other.is_batch() => other,
1313 _ => {
1314 return Err(SuperTrendError::InvalidKernelForBatch(k));
1315 }
1316 };
1317 let simd = match kernel {
1318 Kernel::Avx512Batch => Kernel::Avx512,
1319 Kernel::Avx2Batch => Kernel::Avx2,
1320 Kernel::ScalarBatch => Kernel::Scalar,
1321 _ => unreachable!(),
1322 };
1323 supertrend_batch_par_slice(high, low, close, sweep, simd)
1324}
1325
1326#[inline(always)]
1327pub fn supertrend_batch_slice(
1328 high: &[f64],
1329 low: &[f64],
1330 close: &[f64],
1331 sweep: &SuperTrendBatchRange,
1332 kern: Kernel,
1333) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1334 supertrend_batch_inner(high, low, close, sweep, kern, false)
1335}
1336
1337#[inline(always)]
1338pub fn supertrend_batch_par_slice(
1339 high: &[f64],
1340 low: &[f64],
1341 close: &[f64],
1342 sweep: &SuperTrendBatchRange,
1343 kern: Kernel,
1344) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1345 supertrend_batch_inner(high, low, close, sweep, kern, true)
1346}
1347
1348#[inline(always)]
1349fn supertrend_batch_inner(
1350 high: &[f64],
1351 low: &[f64],
1352 close: &[f64],
1353 sweep: &SuperTrendBatchRange,
1354 kern: Kernel,
1355 parallel: bool,
1356) -> Result<SuperTrendBatchOutput, SuperTrendError> {
1357 let combos = expand_grid(sweep)?;
1358 if combos.is_empty() {
1359 return Err(SuperTrendError::InvalidRange {
1360 start: sweep.period.0,
1361 end: sweep.period.1,
1362 step: sweep.period.2,
1363 });
1364 }
1365 let len = high.len();
1366 let mut first_valid_idx = None;
1367 for i in 0..len {
1368 if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
1369 first_valid_idx = Some(i);
1370 break;
1371 }
1372 }
1373 let first_valid_idx = match first_valid_idx {
1374 Some(idx) => idx,
1375 None => return Err(SuperTrendError::AllValuesNaN),
1376 };
1377 let max_p = combos.iter().map(|c| c.period.unwrap_or(10)).max().unwrap();
1378 if len - first_valid_idx < max_p {
1379 return Err(SuperTrendError::NotEnoughValidData {
1380 needed: max_p,
1381 valid: len - first_valid_idx,
1382 });
1383 }
1384 let rows = combos.len();
1385 let cols = len;
1386
1387 rows.checked_mul(cols)
1388 .ok_or(SuperTrendError::InvalidRange {
1389 start: sweep.period.0,
1390 end: sweep.period.1,
1391 step: sweep.period.2,
1392 })?;
1393
1394 let mut trend_mu = make_uninit_matrix(rows, cols);
1395 let mut changed_mu = make_uninit_matrix(rows, cols);
1396
1397 let warm: Vec<usize> = combos
1398 .iter()
1399 .map(|c| first_valid_idx + c.period.unwrap_or(10) - 1)
1400 .collect();
1401
1402 init_matrix_prefixes(&mut trend_mu, cols, &warm);
1403 init_matrix_prefixes(&mut changed_mu, cols, &warm);
1404
1405 let mut trend_guard = core::mem::ManuallyDrop::new(trend_mu);
1406 let mut changed_guard = core::mem::ManuallyDrop::new(changed_mu);
1407
1408 let trend: &mut [f64] = unsafe {
1409 core::slice::from_raw_parts_mut(trend_guard.as_mut_ptr() as *mut f64, trend_guard.len())
1410 };
1411 let changed: &mut [f64] = unsafe {
1412 core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1413 };
1414
1415 let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
1416 {
1417 let mut periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
1418 periods.sort_unstable();
1419 periods.dedup();
1420 for &p in &periods {
1421 let atr_input = AtrInput::from_slices(
1422 &high[first_valid_idx..],
1423 &low[first_valid_idx..],
1424 &close[first_valid_idx..],
1425 AtrParams { length: Some(p) },
1426 );
1427 let AtrOutput { values } = atr(&atr_input)?;
1428 atr_cache.insert(p, values);
1429 }
1430 }
1431
1432 let hl2: Vec<f64> = (0..len).map(|i| 0.5 * (high[i] + low[i])).collect();
1433
1434 let do_row = |row: usize, trend_row: &mut [f64], changed_row: &mut [f64]| unsafe {
1435 let period = combos[row].period.unwrap();
1436 let factor = combos[row].factor.unwrap();
1437 let atr_values = atr_cache.get(&period).unwrap().as_slice();
1438 match kern {
1439 Kernel::Scalar => supertrend_row_scalar_from_hl(
1440 &hl2,
1441 close,
1442 period,
1443 factor,
1444 first_valid_idx,
1445 atr_values,
1446 trend_row,
1447 changed_row,
1448 ),
1449 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1450 Kernel::Avx2 | Kernel::Avx512 => supertrend_row_scalar_from_hl(
1451 &hl2,
1452 close,
1453 period,
1454 factor,
1455 first_valid_idx,
1456 atr_values,
1457 trend_row,
1458 changed_row,
1459 ),
1460 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1461 Kernel::Avx2 => supertrend_row_avx2(
1462 high,
1463 low,
1464 close,
1465 period,
1466 factor,
1467 first_valid_idx,
1468 atr_values,
1469 trend_row,
1470 changed_row,
1471 ),
1472 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1473 Kernel::Avx512 => supertrend_row_avx512(
1474 high,
1475 low,
1476 close,
1477 period,
1478 factor,
1479 first_valid_idx,
1480 atr_values,
1481 trend_row,
1482 changed_row,
1483 ),
1484 _ => unreachable!(),
1485 }
1486 };
1487 if parallel {
1488 #[cfg(not(target_arch = "wasm32"))]
1489 {
1490 trend
1491 .par_chunks_mut(cols)
1492 .zip(changed.par_chunks_mut(cols))
1493 .enumerate()
1494 .for_each(|(row, (tr, ch))| do_row(row, tr, ch));
1495 }
1496
1497 #[cfg(target_arch = "wasm32")]
1498 {
1499 for (row, (tr, ch)) in trend
1500 .chunks_mut(cols)
1501 .zip(changed.chunks_mut(cols))
1502 .enumerate()
1503 {
1504 do_row(row, tr, ch);
1505 }
1506 }
1507 } else {
1508 for (row, (tr, ch)) in trend
1509 .chunks_mut(cols)
1510 .zip(changed.chunks_mut(cols))
1511 .enumerate()
1512 {
1513 do_row(row, tr, ch);
1514 }
1515 }
1516
1517 let trend_vec = unsafe {
1518 Vec::from_raw_parts(
1519 trend_guard.as_mut_ptr() as *mut f64,
1520 trend_guard.len(),
1521 trend_guard.capacity(),
1522 )
1523 };
1524 let changed_vec = unsafe {
1525 Vec::from_raw_parts(
1526 changed_guard.as_mut_ptr() as *mut f64,
1527 changed_guard.len(),
1528 changed_guard.capacity(),
1529 )
1530 };
1531
1532 Ok(SuperTrendBatchOutput {
1533 trend: trend_vec,
1534 changed: changed_vec,
1535 combos,
1536 rows,
1537 cols,
1538 })
1539}
1540
1541#[inline(always)]
1542unsafe fn supertrend_row_scalar(
1543 high: &[f64],
1544 low: &[f64],
1545 close: &[f64],
1546 period: usize,
1547 factor: f64,
1548 first_valid_idx: usize,
1549 atr_values: &[f64],
1550 trend: &mut [f64],
1551 changed: &mut [f64],
1552) {
1553 supertrend_scalar(
1554 high,
1555 low,
1556 close,
1557 period,
1558 factor,
1559 first_valid_idx,
1560 atr_values,
1561 trend,
1562 changed,
1563 );
1564}
1565
/// Scalar SuperTrend row kernel working from a precomputed midpoint series
/// (`hl2[i] = (high[i] + low[i]) / 2`).
///
/// **Seed bar.** The bands are seeded at `first_valid_idx + period - 1` as
/// `hl2 ± factor * atr`. The trend starts on the upper band when the close is at
/// or below it, otherwise on the lower band.
///
/// **Each following bar:**
/// * The new upper band is the basic band, capped at the previous upper band
///   while the previous close stayed at or below it.
/// * The new lower band is the basic band, floored at the previous lower band
///   while the previous close stayed at or above it.
/// * The trend flips to the other band when the close fails to stay on the
///   current side. A NaN close also counts as failing, so it flips.
///
/// **Outputs.** `trend[i]` receives the active band. `changed[i]` is `1.0` on a
/// flip and `0.0` otherwise (including at the seed bar).
///
/// Entries before the seed bar are not written. If the seed bar would lie past
/// `hl2.len()`, nothing is written at all. `atr_values[j]` is the ATR for bar
/// `first_valid_idx + j`.
///
/// # Safety
/// No bounds checks are performed. The caller must guarantee that:
/// * `period >= 1`;
/// * `close`, `trend` and `changed` hold at least `hl2.len()` elements;
/// * `atr_values` holds at least `hl2.len() - first_valid_idx` elements.
#[inline(always)]
unsafe fn supertrend_row_scalar_from_hl(
    hl2: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let n = hl2.len();
    let first_output = first_valid_idx + period;
    if first_output > n {
        return;
    }
    let seed = first_output - 1;

    // Seed both bands at the first bar with a complete ATR window.
    let seed_mid = *hl2.get_unchecked(seed);
    let seed_atr = *atr_values.get_unchecked(period - 1);
    let mut upper = factor.mul_add(seed_atr, seed_mid);
    let mut lower = (-factor).mul_add(seed_atr, seed_mid);

    let mut prev_close = *close.get_unchecked(seed);
    // True while the trend line follows the upper band.
    let mut on_upper = prev_close <= upper;
    *trend.get_unchecked_mut(seed) = if on_upper { upper } else { lower };
    *changed.get_unchecked_mut(seed) = 0.0;

    for idx in first_output..n {
        let atr_now = *atr_values.get_unchecked(idx - first_valid_idx);
        let mid = *hl2.get_unchecked(idx);

        // Ratchet: a band may only tighten while the previous close respected it.
        let basic_upper = factor.mul_add(atr_now, mid);
        let upper_now = if prev_close <= upper {
            basic_upper.min(upper)
        } else {
            basic_upper
        };
        let basic_lower = (-factor).mul_add(atr_now, mid);
        let lower_now = if prev_close >= lower {
            basic_lower.max(lower)
        } else {
            basic_lower
        };

        let close_now = *close.get_unchecked(idx);
        // Negated "stays on this side" tests are intentional: a NaN close must
        // count as a flip, exactly as when `close <= band` / `close >= band` fail.
        #[allow(clippy::neg_cmp_op_on_partial_ord)]
        let flipped = if on_upper {
            !(close_now <= upper_now)
        } else {
            !(close_now >= lower_now)
        };
        if flipped {
            on_upper = !on_upper;
        }
        *trend.get_unchecked_mut(idx) = if on_upper { upper_now } else { lower_now };
        *changed.get_unchecked_mut(idx) = if flipped { 1.0 } else { 0.0 };

        upper = upper_now;
        lower = lower_now;
        prev_close = close_now;
    }
}
1649
/// AVX2 row kernel slot for the batch path (nightly-avx builds on x86_64 only).
///
/// There is no vectorized body yet: it forwards to `supertrend_scalar` with the
/// raw `high`/`low` inputs.
///
/// # Safety
/// Same preconditions as `supertrend_scalar`. NOTE(review): confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (trend_out, changed_out) = (trend, changed);
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend_out,
        changed_out,
    );
}
1675
/// AVX-512 row kernel entry (nightly-avx builds on x86_64 only).
///
/// Periods above 32 go to the "long" variant; all others go to the "short"
/// variant.
///
/// # Safety
/// Same preconditions as the variant it dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    if period > 32 {
        supertrend_row_avx512_long(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        )
    } else {
        supertrend_row_avx512_short(
            high,
            low,
            close,
            period,
            factor,
            first_valid_idx,
            atr_values,
            trend,
            changed,
        )
    }
}
1715
/// AVX-512 row kernel for short periods (`period <= 32`).
///
/// There is no vectorized body yet: it forwards to `supertrend_scalar`.
///
/// # Safety
/// Same preconditions as `supertrend_scalar`. NOTE(review): confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (trend_out, changed_out) = (trend, changed);
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend_out,
        changed_out,
    );
}
1741
/// AVX-512 row kernel for long periods (`period > 32`).
///
/// There is no vectorized body yet: it forwards to `supertrend_scalar`.
///
/// # Safety
/// Same preconditions as `supertrend_scalar`. NOTE(review): confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn supertrend_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
    first_valid_idx: usize,
    atr_values: &[f64],
    trend: &mut [f64],
    changed: &mut [f64],
) {
    let (trend_out, changed_out) = (trend, changed);
    supertrend_scalar(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        atr_values,
        trend_out,
        changed_out,
    );
}
1767
/// Batch SuperTrend that writes into caller-provided, row-major output buffers.
///
/// Uses the same grid expansion, validation and per-row kernels as the
/// allocating batch path. The `rows x cols` trend and changed matrices are
/// written into `trend_out` / `changed_out`, which must each be exactly
/// `rows * cols` long. The first `first_valid_idx + period - 1` entries of each
/// row are set to NaN.
///
/// ATR is computed once per distinct period, before any row runs. Rows that
/// share a period therefore reuse it, and ATR errors are returned instead of
/// panicking inside the (possibly parallel) row loop.
///
/// Returns the expanded parameter grid, one entry per output row.
///
/// # Errors
/// * [`SuperTrendError::InvalidRange`] for an empty grid, a zero period, or a
///   `rows * cols` overflow.
/// * [`SuperTrendError::AllValuesNaN`] / [`SuperTrendError::NotEnoughValidData`]
///   for unusable input.
/// * [`SuperTrendError::OutputLengthMismatch`] if either output buffer has the
///   wrong length.
/// * Any error returned by [`atr`].
///
/// # Panics
/// Panics if `high`, `low` and `close` differ in length. The row kernels read
/// them without bounds checks.
#[cfg(feature = "python")]
#[inline(always)]
pub fn supertrend_batch_inner_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &SuperTrendBatchRange,
    simd: Kernel,
    parallel: bool,
    trend_out: &mut [f64],
    changed_out: &mut [f64],
) -> Result<Vec<SuperTrendParams>, SuperTrendError> {
    assert!(
        high.len() == low.len() && low.len() == close.len(),
        "supertrend: high/low/close length mismatch ({}, {}, {})",
        high.len(),
        low.len(),
        close.len()
    );

    let invalid_range = || SuperTrendError::InvalidRange {
        start: sweep.period.0,
        end: sweep.period.1,
        step: sweep.period.2,
    };

    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(invalid_range());
    }
    // A zero period would underflow `first_valid_idx + period - 1` below and in
    // the row kernel.
    if combos.iter().any(|c| c.period == Some(0)) {
        return Err(invalid_range());
    }

    let len = high.len();
    let first_valid_idx = (0..len)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
        .ok_or(SuperTrendError::AllValuesNaN)?;
    let max_p = combos
        .iter()
        .map(|c| c.period.unwrap_or(10))
        .max()
        .expect("combos is non-empty");
    let valid = len - first_valid_idx;
    if valid < max_p {
        return Err(SuperTrendError::NotEnoughValidData {
            needed: max_p,
            valid,
        });
    }
    let rows = combos.len();
    let cols = len;

    let expected_len = rows.checked_mul(cols).ok_or_else(invalid_range)?;
    if trend_out.len() != expected_len {
        return Err(SuperTrendError::OutputLengthMismatch {
            expected: expected_len,
            got: trend_out.len(),
        });
    }
    if changed_out.len() != expected_len {
        return Err(SuperTrendError::OutputLengthMismatch {
            expected: expected_len,
            got: changed_out.len(),
        });
    }

    // One ATR series per distinct period, computed up front (see doc comment).
    let mut atr_cache: HashMap<usize, Vec<f64>> = HashMap::new();
    {
        let mut periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
        periods.sort_unstable();
        periods.dedup();
        for &p in &periods {
            let atr_input = AtrInput::from_slices(
                &high[first_valid_idx..],
                &low[first_valid_idx..],
                &close[first_valid_idx..],
                AtrParams { length: Some(p) },
            );
            let AtrOutput { values } = atr(&atr_input)?;
            atr_cache.insert(p, values);
        }
    }

    // NaN-fill each row's warm-up prefix; the row kernels write the rest.
    for (row, combo) in combos.iter().enumerate() {
        let warmup = (first_valid_idx + combo.period.unwrap_or(10) - 1).min(cols);
        let row_start = row * cols;
        trend_out[row_start..row_start + warmup].fill(f64::NAN);
        changed_out[row_start..row_start + warmup].fill(f64::NAN);
    }

    let hl2: Vec<f64> = (0..len).map(|i| 0.5 * (high[i] + low[i])).collect();

    // SAFETY (closure body): input lengths are equal (asserted above), every
    // period is >= 1, and `len - first_valid_idx >= period`, which satisfies the
    // row kernels' preconditions.
    let do_row = |row: usize, trend_row: &mut [f64], changed_row: &mut [f64]| unsafe {
        let period = combos[row].period.unwrap();
        let factor = combos[row].factor.unwrap();
        let atr_values = atr_cache.get(&period).unwrap().as_slice();
        match simd {
            Kernel::Scalar => supertrend_row_scalar_from_hl(
                &hl2,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => supertrend_row_avx2(
                high,
                low,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => supertrend_row_avx512(
                high,
                low,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => supertrend_row_scalar_from_hl(
                &hl2,
                close,
                period,
                factor,
                first_valid_idx,
                atr_values,
                trend_row,
                changed_row,
            ),
            _ => unreachable!(),
        }
    };
    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            trend_out
                .par_chunks_mut(cols)
                .zip(changed_out.par_chunks_mut(cols))
                .enumerate()
                .for_each(|(row, (tr, ch))| do_row(row, tr, ch));
        }

        #[cfg(target_arch = "wasm32")]
        {
            for (row, (tr, ch)) in trend_out
                .chunks_mut(cols)
                .zip(changed_out.chunks_mut(cols))
                .enumerate()
            {
                do_row(row, tr, ch);
            }
        }
    } else {
        for (row, (tr, ch)) in trend_out
            .chunks_mut(cols)
            .zip(changed_out.chunks_mut(cols))
            .enumerate()
        {
            do_row(row, tr, ch);
        }
    }
    Ok(combos)
}
1933
/// Python binding: single-series SuperTrend.
///
/// Returns `(trend, changed)` as 1-D float64 arrays with the same length as the
/// inputs. Any computation error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend")]
#[pyo3(signature = (high, low, close, period, factor, kernel=None))]
pub fn supertrend_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    factor: f64,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, numpy::PyArray1<f64>>,
    Bound<'py, numpy::PyArray1<f64>>,
)> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = SuperTrendInput::from_slices(
        h,
        l,
        c,
        SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        },
    );

    // Release the GIL for the computation itself.
    let computed = py.allow_threads(|| {
        supertrend_with_kernel(&input, chosen_kernel).map(|out| (out.trend, out.changed))
    });
    let (trend, changed) = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((trend.into_pyarray(py), changed.into_pyarray(py)))
}
1968
/// Python binding: batch SuperTrend over a `(period, factor)` grid.
///
/// `period_range` and `factor_range` are `(start, end, step)` triples.
///
/// Returns a dict with:
/// * `trend` and `changed`: `(rows, cols)` float64 arrays, one row per
///   combination, where `cols == len(high)`;
/// * `periods` and `factors`: the per-row parameters;
/// * `rows` and `cols`: the matrix dimensions.
///
/// Raises `ValueError` for mismatched input lengths, an invalid range, or any
/// computation error.
#[cfg(feature = "python")]
#[pyfunction(name = "supertrend_batch")]
#[pyo3(signature = (high, low, close, period_range, factor_range, kernel=None))]
pub fn supertrend_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    factor_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    // Reject mismatched series with a ValueError here, rather than failing
    // inside the batch kernel, which assumes equal lengths.
    if high_slice.len() != low_slice.len() || low_slice.len() != close_slice.len() {
        return Err(PyValueError::new_err(
            "supertrend: high/low/close length mismatch",
        ));
    }
    let kern = validate_kernel(kernel, true)?;

    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };

    let grid_combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if grid_combos.is_empty() {
        return Err(PyValueError::new_err(format!(
            "supertrend: Invalid range: start={}, end={}, step={}",
            sweep.period.0, sweep.period.1, sweep.period.2
        )));
    }
    let rows = grid_combos.len();
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("supertrend: rows*cols overflow"))?;

    // Uninitialized arrays: `supertrend_batch_inner_into` writes every element
    // (NaN warm-up prefixes plus the computed tail) before they are exposed.
    let trend_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let trend_out = unsafe { trend_arr.as_slice_mut()? };
    let changed_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let changed_out = unsafe { changed_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            supertrend_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                simd,
                true,
                trend_out,
                changed_out,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("trend", trend_arr.reshape((rows, cols))?)?;
    dict.set_item("changed", changed_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "factors",
        combos
            .iter()
            .map(|p| p.factor.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;

    Ok(dict)
}
2061
/// Python-facing wrapper, exposed as `SuperTrendStream`, around the Rust
/// `SuperTrendStream` streaming state.
#[cfg(feature = "python")]
#[pyclass(name = "SuperTrendStream")]
pub struct SuperTrendStreamPy {
    // The Rust streaming calculator that each Python `update` call advances.
    stream: SuperTrendStream,
}
2067
#[cfg(feature = "python")]
#[pymethods]
impl SuperTrendStreamPy {
    /// Creates a streaming SuperTrend with the given `period` and `factor`.
    ///
    /// Raises `ValueError` if `SuperTrendStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize, factor: f64) -> PyResult<Self> {
        SuperTrendStream::try_new(SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        })
        .map(|stream| SuperTrendStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns the stream's `(f64, f64)` result, or `None`.
    ///
    /// NOTE(review): presumably `(trend, changed)` once warmed up; confirm
    /// against `SuperTrendStream::update`.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let stream = &mut self.stream;
        stream.update(high, low, close)
    }
}
2086
/// Computes single-series SuperTrend into caller-provided buffers (wasm builds).
///
/// Both buffers must match the input length. The first
/// `first_valid_idx + period - 1` entries of each are set to NaN, and the rest is
/// written by `supertrend_compute_into` using the kernel chosen by
/// `supertrend_prepare`.
///
/// # Errors
/// Any error from `supertrend_prepare`, or
/// [`SuperTrendError::OutputLengthMismatch`] for a wrongly sized buffer (the
/// trend buffer is checked first).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[inline]
pub fn supertrend_into_slice(
    trend_dst: &mut [f64],
    changed_dst: &mut [f64],
    input: &SuperTrendInput,
    kern: Kernel,
) -> Result<(), SuperTrendError> {
    let (high, low, close, period, factor, first_valid_idx, atr_values, chosen) =
        supertrend_prepare(input, kern)?;

    let expected = high.len();
    for got in [trend_dst.len(), changed_dst.len()] {
        if got != expected {
            return Err(SuperTrendError::OutputLengthMismatch { expected, got });
        }
    }

    let warmup_end = first_valid_idx + period - 1;
    trend_dst[..warmup_end].fill(f64::NAN);
    changed_dst[..warmup_end].fill(f64::NAN);

    supertrend_compute_into(
        high,
        low,
        close,
        period,
        factor,
        first_valid_idx,
        &atr_values,
        chosen,
        trend_dst,
        changed_dst,
    );

    Ok(())
}
2135
/// Serialized result of the `supertrend` wasm export.
///
/// As populated by `supertrend_js`, `values` holds the trend row followed by
/// the changed row, so `rows` is 2 and `cols` is the input length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendJsResult {
    // Flattened row-major `rows x cols` data.
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
2143
/// wasm export `supertrend`: single-series SuperTrend with the auto-selected
/// kernel.
///
/// Returns a `{ values, rows: 2, cols }` object. `values` holds the trend series
/// followed by the changed series, each `cols` long.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend)]
pub fn supertrend_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    factor: f64,
) -> Result<JsValue, JsValue> {
    let cols = high.len();
    let input = SuperTrendInput::from_slices(
        high,
        low,
        close,
        SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        },
    );

    let mut values = vec![0.0; cols * 2];
    {
        let (trend_half, changed_half) = values.split_at_mut(cols);
        supertrend_into_slice(trend_half, changed_half, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let result = SuperTrendJsResult {
        values,
        rows: 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2173
/// wasm export: computes SuperTrend from raw input pointers into caller-owned
/// output buffers (e.g. allocated with `supertrend_alloc`).
///
/// All five pointers must reference at least `len` `f64`s. If an output buffer
/// overlaps any input buffer (in-place use), the result is first computed into
/// temporary vectors and then copied out.
///
/// # Errors
/// Returns a string `JsValue` error if:
/// * any pointer is null;
/// * the trend and changed buffers overlap (they would alias as two `&mut`
///   slices);
/// * the computation itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    trend_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    period: usize,
    factor: f64,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || trend_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte span covered by each `len`-element buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    // True when `[a, a + span)` and `[b, b + span)` intersect. Comparing only start
    // addresses would miss partial overlaps.
    let overlaps = |a: usize, b: usize| {
        span != 0 && a < b.saturating_add(span) && b < a.saturating_add(span)
    };

    let trend_addr = trend_ptr as usize;
    let changed_addr = changed_ptr as usize;
    // Both outputs become `&mut [f64]` below; overlapping them would be UB.
    if overlaps(trend_addr, changed_addr) {
        return Err(JsValue::from_str(
            "trend and changed output buffers must not overlap",
        ));
    }

    // SAFETY: pointers are non-null and, per the caller contract, valid for
    // `len` f64s. Output slices are only created when they cannot overlap a live
    // input slice, or after the computation that reads the inputs has finished.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let params = SuperTrendParams {
            period: Some(period),
            factor: Some(factor),
        };
        let input = SuperTrendInput::from_slices(high, low, close, params);

        let has_aliasing = [high_ptr as usize, low_ptr as usize, close_ptr as usize]
            .iter()
            .any(|&inp| overlaps(inp, trend_addr) || overlaps(inp, changed_addr));

        if has_aliasing {
            let output = supertrend_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let trend_out = std::slice::from_raw_parts_mut(trend_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);

            trend_out.copy_from_slice(&output.trend);
            changed_out.copy_from_slice(&output.changed);
        } else {
            let trend_out = std::slice::from_raw_parts_mut(trend_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);

            supertrend_into_slice(trend_out, changed_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
2237
/// wasm export: allocates an uninitialized buffer with room for `len` `f64`s and
/// hands its pointer to the caller.
///
/// Release it with `supertrend_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_alloc(len: usize) -> *mut f64 {
    // Never dropped: ownership of the allocation passes to the caller.
    std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len)).as_mut_ptr()
}
2246
/// wasm export: releases a buffer previously returned by `supertrend_alloc(len)`.
///
/// `len` must equal the value passed to `supertrend_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn supertrend_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a leaked `Vec::<f64>::with_capacity(len)`, so
    // `(ptr, len)` describes that allocation's capacity.
    //
    // The Vec is rebuilt with length 0 because `from_raw_parts` requires the
    // first `length` elements to be initialized, and the caller may never have
    // written the buffer. `f64` has no destructor, so only the capacity matters
    // for deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2256
/// Configuration object accepted by the `supertrend_batch` wasm export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendBatchConfig {
    // `(start, end, step)` for the ATR period axis.
    pub period_range: (usize, usize, usize),
    // `(start, end, step)` for the band multiplier axis.
    pub factor_range: (f64, f64, f64),
}
2263
/// Serialized result of the `supertrend_batch` wasm export.
///
/// As populated by `supertrend_batch_js`, each parameter combination
/// contributes two consecutive rows to `values`: its trend row, then its changed
/// row. `rows` is therefore twice the number of combinations. `periods[k]` and
/// `factors[k]` describe combination `k`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SuperTrendBatchJsOutput {
    // Flattened row-major `rows x cols` data.
    pub values: Vec<f64>,
    pub periods: Vec<usize>,
    pub factors: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
2273
/// wasm export `supertrend_batch`: batch SuperTrend over a period/factor grid.
///
/// `config` must deserialize into `SuperTrendBatchConfig`. For each combination
/// the output holds its trend row followed by its changed row, so the reported
/// `rows` is twice the number of combinations.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = supertrend_batch)]
pub fn supertrend_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let SuperTrendBatchConfig {
        period_range,
        factor_range,
    } = serde_wasm_bindgen::from_value::<SuperTrendBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SuperTrendBatchRange {
        period: period_range,
        factor: factor_range,
    };
    let batch = supertrend_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = batch.cols;
    let mut values = Vec::with_capacity(batch.rows * 2 * cols);
    for row in 0..batch.rows {
        let span = row * cols..(row + 1) * cols;
        values.extend_from_slice(&batch.trend[span.clone()]);
        values.extend_from_slice(&batch.changed[span]);
    }

    let (periods, factors): (Vec<usize>, Vec<f64>) = batch
        .combos
        .iter()
        .map(|c| (c.period.unwrap_or(10), c.factor.unwrap_or(3.0)))
        .unzip();

    let out = SuperTrendBatchJsOutput {
        values,
        periods,
        factors,
        rows: batch.rows * 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2321
2322#[cfg(test)]
2323mod tests {
2324 use super::*;
2325 use crate::skip_if_unsupported;
2326 use crate::utilities::data_loader::read_candles_from_csv;
2327 use crate::utilities::enums::Kernel;
2328
2329 #[test]
2330 fn test_supertrend_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
2331 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2332 let candles = read_candles_from_csv(file_path)?;
2333
2334 let params = SuperTrendParams {
2335 period: Some(10),
2336 factor: Some(3.0),
2337 };
2338 let input = SuperTrendInput::from_candles(&candles, params);
2339
2340 let baseline = supertrend_with_kernel(&input, Kernel::Auto)?;
2341
2342 let n = candles.close.len();
2343 let mut trend_out = vec![0.0; n];
2344 let mut changed_out = vec![0.0; n];
2345
2346 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2347 {
2348 supertrend_into(&input, &mut trend_out, &mut changed_out)?;
2349 }
2350
2351 assert_eq!(baseline.trend.len(), n);
2352 assert_eq!(baseline.changed.len(), n);
2353 assert_eq!(trend_out.len(), n);
2354 assert_eq!(changed_out.len(), n);
2355
2356 #[inline]
2357 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2358 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-9
2359 }
2360
2361 for i in 0..n {
2362 assert!(
2363 eq_or_both_nan(baseline.trend[i], trend_out[i]),
2364 "trend mismatch at {}: baseline={}, into={}",
2365 i,
2366 baseline.trend[i],
2367 trend_out[i]
2368 );
2369 assert!(
2370 eq_or_both_nan(baseline.changed[i], changed_out[i]),
2371 "changed mismatch at {}: baseline={}, into={}",
2372 i,
2373 baseline.changed[i],
2374 changed_out[i]
2375 );
2376 }
2377
2378 Ok(())
2379 }
2380
2381 fn check_supertrend_partial_params(
2382 test_name: &str,
2383 kernel: Kernel,
2384 ) -> Result<(), Box<dyn std::error::Error>> {
2385 skip_if_unsupported!(kernel, test_name);
2386 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2387 let candles = read_candles_from_csv(file_path)?;
2388
2389 let default_params = SuperTrendParams {
2390 period: None,
2391 factor: None,
2392 };
2393 let input = SuperTrendInput::from_candles(&candles, default_params);
2394 let output = supertrend_with_kernel(&input, kernel)?;
2395 assert_eq!(output.trend.len(), candles.close.len());
2396 assert_eq!(output.changed.len(), candles.close.len());
2397
2398 Ok(())
2399 }
2400
2401 fn check_supertrend_accuracy(
2402 test_name: &str,
2403 kernel: Kernel,
2404 ) -> Result<(), Box<dyn std::error::Error>> {
2405 skip_if_unsupported!(kernel, test_name);
2406 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2407 let candles = read_candles_from_csv(file_path)?;
2408
2409 let params = SuperTrendParams {
2410 period: Some(10),
2411 factor: Some(3.0),
2412 };
2413 let input = SuperTrendInput::from_candles(&candles, params);
2414 let st_result = supertrend_with_kernel(&input, kernel)?;
2415
2416 assert_eq!(st_result.trend.len(), candles.close.len());
2417 assert_eq!(st_result.changed.len(), candles.close.len());
2418
2419 let expected_last_five_trend = [
2420 61811.479454208165,
2421 61721.73150878735,
2422 61459.10835790861,
2423 61351.59752211775,
2424 61033.18776990598,
2425 ];
2426 let expected_last_five_changed = [0.0, 0.0, 0.0, 0.0, 0.0];
2427
2428 let start_index = st_result.trend.len() - 5;
2429 let trend_slice = &st_result.trend[start_index..];
2430 let changed_slice = &st_result.changed[start_index..];
2431
2432 for (i, &val) in trend_slice.iter().enumerate() {
2433 let exp = expected_last_five_trend[i];
2434 assert!(
2435 (val - exp).abs() < 1e-4,
2436 "[{}] Trend mismatch at idx {}: got {}, expected {}",
2437 test_name,
2438 i,
2439 val,
2440 exp
2441 );
2442 }
2443 for (i, &val) in changed_slice.iter().enumerate() {
2444 let exp = expected_last_five_changed[i];
2445 assert!(
2446 (val - exp).abs() < 1e-9,
2447 "[{}] Changed mismatch at idx {}: got {}, expected {}",
2448 test_name,
2449 i,
2450 val,
2451 exp
2452 );
2453 }
2454 Ok(())
2455 }
2456
2457 fn check_supertrend_default_candles(
2458 test_name: &str,
2459 kernel: Kernel,
2460 ) -> Result<(), Box<dyn std::error::Error>> {
2461 skip_if_unsupported!(kernel, test_name);
2462 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2463 let candles = read_candles_from_csv(file_path)?;
2464
2465 let input = SuperTrendInput::with_default_candles(&candles);
2466 let output = supertrend_with_kernel(&input, kernel)?;
2467 assert_eq!(output.trend.len(), candles.close.len());
2468 assert_eq!(output.changed.len(), candles.close.len());
2469 Ok(())
2470 }
2471
2472 fn check_supertrend_zero_period(
2473 test_name: &str,
2474 kernel: Kernel,
2475 ) -> Result<(), Box<dyn std::error::Error>> {
2476 skip_if_unsupported!(kernel, test_name);
2477 let high = [10.0, 12.0, 13.0];
2478 let low = [9.0, 11.0, 12.5];
2479 let close = [9.5, 11.5, 13.0];
2480 let params = SuperTrendParams {
2481 period: Some(0),
2482 factor: Some(3.0),
2483 };
2484 let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2485 let res = supertrend_with_kernel(&input, kernel);
2486 assert!(res.is_err(), "[{}] Should fail with zero period", test_name);
2487 Ok(())
2488 }
2489
2490 fn check_supertrend_period_exceeds_length(
2491 test_name: &str,
2492 kernel: Kernel,
2493 ) -> Result<(), Box<dyn std::error::Error>> {
2494 skip_if_unsupported!(kernel, test_name);
2495 let high = [10.0, 12.0, 13.0];
2496 let low = [9.0, 11.0, 12.5];
2497 let close = [9.5, 11.5, 13.0];
2498 let params = SuperTrendParams {
2499 period: Some(10),
2500 factor: Some(3.0),
2501 };
2502 let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2503 let res = supertrend_with_kernel(&input, kernel);
2504 assert!(
2505 res.is_err(),
2506 "[{}] Should fail with period > data.len()",
2507 test_name
2508 );
2509 Ok(())
2510 }
2511
2512 fn check_supertrend_very_small_dataset(
2513 test_name: &str,
2514 kernel: Kernel,
2515 ) -> Result<(), Box<dyn std::error::Error>> {
2516 skip_if_unsupported!(kernel, test_name);
2517 let high = [42.0];
2518 let low = [40.0];
2519 let close = [41.0];
2520 let params = SuperTrendParams {
2521 period: Some(10),
2522 factor: Some(3.0),
2523 };
2524 let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2525 let res = supertrend_with_kernel(&input, kernel);
2526 assert!(
2527 res.is_err(),
2528 "[{}] Should fail for data smaller than period",
2529 test_name
2530 );
2531 Ok(())
2532 }
2533
2534 fn check_supertrend_reinput(
2535 test_name: &str,
2536 kernel: Kernel,
2537 ) -> Result<(), Box<dyn std::error::Error>> {
2538 skip_if_unsupported!(kernel, test_name);
2539 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2540 let candles = read_candles_from_csv(file_path)?;
2541
2542 let first_params = SuperTrendParams {
2543 period: Some(10),
2544 factor: Some(3.0),
2545 };
2546 let first_input = SuperTrendInput::from_candles(&candles, first_params);
2547 let first_result = supertrend_with_kernel(&first_input, kernel)?;
2548
2549 let second_params = SuperTrendParams {
2550 period: Some(5),
2551 factor: Some(2.0),
2552 };
2553 let second_input = SuperTrendInput::from_slices(
2554 &first_result.trend,
2555 &first_result.trend,
2556 &first_result.trend,
2557 second_params,
2558 );
2559 let second_result = supertrend_with_kernel(&second_input, kernel)?;
2560 assert_eq!(second_result.trend.len(), first_result.trend.len());
2561 assert_eq!(second_result.changed.len(), first_result.changed.len());
2562 Ok(())
2563 }
2564
2565 fn check_supertrend_nan_handling(
2566 test_name: &str,
2567 kernel: Kernel,
2568 ) -> Result<(), Box<dyn std::error::Error>> {
2569 skip_if_unsupported!(kernel, test_name);
2570 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2571 let candles = read_candles_from_csv(file_path)?;
2572
2573 let params = SuperTrendParams {
2574 period: Some(10),
2575 factor: Some(3.0),
2576 };
2577 let input = SuperTrendInput::from_candles(&candles, params);
2578 let result = supertrend_with_kernel(&input, kernel)?;
2579 if result.trend.len() > 50 {
2580 for (i, &val) in result.trend[50..].iter().enumerate() {
2581 assert!(
2582 !val.is_nan(),
2583 "[{}] Found unexpected NaN at out-index {}",
2584 test_name,
2585 50 + i
2586 );
2587 }
2588 }
2589 Ok(())
2590 }
2591
2592 fn check_supertrend_streaming(
2593 test_name: &str,
2594 kernel: Kernel,
2595 ) -> Result<(), Box<dyn std::error::Error>> {
2596 skip_if_unsupported!(kernel, test_name);
2597 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2598 let candles = read_candles_from_csv(file_path)?;
2599
2600 let period = 10;
2601 let factor = 3.0;
2602 let params = SuperTrendParams {
2603 period: Some(period),
2604 factor: Some(factor),
2605 };
2606 let input = SuperTrendInput::from_candles(&candles, params.clone());
2607 let batch_output = supertrend_with_kernel(&input, kernel)?;
2608
2609 let mut stream = SuperTrendStream::try_new(params.clone())?;
2610 let mut stream_trend = Vec::with_capacity(candles.close.len());
2611 let mut stream_changed = Vec::with_capacity(candles.close.len());
2612
2613 for i in 0..candles.close.len() {
2614 let (h, l, c) = (candles.high[i], candles.low[i], candles.close[i]);
2615 match stream.update(h, l, c) {
2616 Some((trend, changed)) => {
2617 stream_trend.push(trend);
2618 stream_changed.push(changed);
2619 }
2620 None => {
2621 stream_trend.push(f64::NAN);
2622 stream_changed.push(f64::NAN);
2623 }
2624 }
2625 }
2626 assert_eq!(batch_output.trend.len(), stream_trend.len());
2627 assert_eq!(batch_output.changed.len(), stream_changed.len());
2628
2629 for (i, (&b, &s)) in batch_output
2630 .trend
2631 .iter()
2632 .zip(stream_trend.iter())
2633 .enumerate()
2634 {
2635 if b.is_nan() && s.is_nan() {
2636 continue;
2637 }
2638 let diff = (b - s).abs();
2639 assert!(
2640 diff < 1e-8,
2641 "[{}] Streaming trend mismatch at idx {}: batch={}, stream={}, diff={}",
2642 test_name,
2643 i,
2644 b,
2645 s,
2646 diff
2647 );
2648 }
2649 for (i, (&b, &s)) in batch_output
2650 .changed
2651 .iter()
2652 .zip(stream_changed.iter())
2653 .enumerate()
2654 {
2655 if b.is_nan() && s.is_nan() {
2656 continue;
2657 }
2658 let diff = (b - s).abs();
2659 assert!(
2660 diff < 1e-9,
2661 "[{}] Streaming changed mismatch at idx {}: batch={}, stream={}, diff={}",
2662 test_name,
2663 i,
2664 b,
2665 s,
2666 diff
2667 );
2668 }
2669 Ok(())
2670 }
2671
2672 #[cfg(debug_assertions)]
2673 fn check_supertrend_no_poison(
2674 test_name: &str,
2675 kernel: Kernel,
2676 ) -> Result<(), Box<dyn std::error::Error>> {
2677 skip_if_unsupported!(kernel, test_name);
2678
2679 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2680 let candles = read_candles_from_csv(file_path)?;
2681
2682 let test_params = vec![
2683 SuperTrendParams::default(),
2684 SuperTrendParams {
2685 period: Some(2),
2686 factor: Some(1.0),
2687 },
2688 SuperTrendParams {
2689 period: Some(5),
2690 factor: Some(0.5),
2691 },
2692 SuperTrendParams {
2693 period: Some(5),
2694 factor: Some(2.0),
2695 },
2696 SuperTrendParams {
2697 period: Some(5),
2698 factor: Some(3.5),
2699 },
2700 SuperTrendParams {
2701 period: Some(10),
2702 factor: Some(1.5),
2703 },
2704 SuperTrendParams {
2705 period: Some(14),
2706 factor: Some(2.5),
2707 },
2708 SuperTrendParams {
2709 period: Some(20),
2710 factor: Some(3.0),
2711 },
2712 SuperTrendParams {
2713 period: Some(50),
2714 factor: Some(2.0),
2715 },
2716 SuperTrendParams {
2717 period: Some(100),
2718 factor: Some(1.0),
2719 },
2720 SuperTrendParams {
2721 period: Some(10),
2722 factor: Some(0.1),
2723 },
2724 SuperTrendParams {
2725 period: Some(10),
2726 factor: Some(5.0),
2727 },
2728 ];
2729
2730 for (param_idx, params) in test_params.iter().enumerate() {
2731 let input = SuperTrendInput::from_candles(&candles, params.clone());
2732 let output = supertrend_with_kernel(&input, kernel)?;
2733
2734 for (i, &val) in output.trend.iter().enumerate() {
2735 if val.is_nan() {
2736 continue;
2737 }
2738
2739 let bits = val.to_bits();
2740
2741 if bits == 0x11111111_11111111 {
2742 panic!(
2743 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in trend \
2744 with params: period={}, factor={} (param set {})",
2745 test_name, val, bits, i,
2746 params.period.unwrap_or(10),
2747 params.factor.unwrap_or(3.0),
2748 param_idx
2749 );
2750 }
2751
2752 if bits == 0x22222222_22222222 {
2753 panic!(
2754 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in trend \
2755 with params: period={}, factor={} (param set {})",
2756 test_name, val, bits, i,
2757 params.period.unwrap_or(10),
2758 params.factor.unwrap_or(3.0),
2759 param_idx
2760 );
2761 }
2762
2763 if bits == 0x33333333_33333333 {
2764 panic!(
2765 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in trend \
2766 with params: period={}, factor={} (param set {})",
2767 test_name, val, bits, i,
2768 params.period.unwrap_or(10),
2769 params.factor.unwrap_or(3.0),
2770 param_idx
2771 );
2772 }
2773 }
2774
2775 for (i, &val) in output.changed.iter().enumerate() {
2776 if val.is_nan() {
2777 continue;
2778 }
2779
2780 let bits = val.to_bits();
2781
2782 if bits == 0x11111111_11111111 {
2783 panic!(
2784 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in changed \
2785 with params: period={}, factor={} (param set {})",
2786 test_name, val, bits, i,
2787 params.period.unwrap_or(10),
2788 params.factor.unwrap_or(3.0),
2789 param_idx
2790 );
2791 }
2792
2793 if bits == 0x22222222_22222222 {
2794 panic!(
2795 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in changed \
2796 with params: period={}, factor={} (param set {})",
2797 test_name, val, bits, i,
2798 params.period.unwrap_or(10),
2799 params.factor.unwrap_or(3.0),
2800 param_idx
2801 );
2802 }
2803
2804 if bits == 0x33333333_33333333 {
2805 panic!(
2806 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in changed \
2807 with params: period={}, factor={} (param set {})",
2808 test_name, val, bits, i,
2809 params.period.unwrap_or(10),
2810 params.factor.unwrap_or(3.0),
2811 param_idx
2812 );
2813 }
2814 }
2815 }
2816
2817 Ok(())
2818 }
2819
2820 #[cfg(not(debug_assertions))]
2821 fn check_supertrend_no_poison(
2822 _test_name: &str,
2823 _kernel: Kernel,
2824 ) -> Result<(), Box<dyn std::error::Error>> {
2825 Ok(())
2826 }
2827
2828 #[cfg(feature = "proptest")]
2829 #[allow(clippy::float_cmp)]
2830 fn check_supertrend_property(
2831 test_name: &str,
2832 kernel: Kernel,
2833 ) -> Result<(), Box<dyn std::error::Error>> {
2834 use proptest::prelude::*;
2835 skip_if_unsupported!(kernel, test_name);
2836
2837 let strat = (2usize..=50).prop_flat_map(|period| {
2838 let data_len = period * 2 + 50;
2839 (
2840 prop::collection::vec(
2841 (100f64..10000f64).prop_filter("finite", |x| x.is_finite()),
2842 data_len,
2843 ),
2844 Just(period),
2845 0.5f64..5.0f64,
2846 )
2847 });
2848
2849 proptest::test_runner::TestRunner::default()
2850 .run(&strat, |(base_prices, period, factor)| {
2851 let mut high = Vec::with_capacity(base_prices.len());
2852 let mut low = Vec::with_capacity(base_prices.len());
2853 let mut close = Vec::with_capacity(base_prices.len());
2854
2855 let mut rng_state = 42u64;
2856 for &base in &base_prices {
2857 rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2858 let rand1 = ((rng_state >> 32) as f64) / (u32::MAX as f64);
2859 rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
2860 let rand2 = ((rng_state >> 32) as f64) / (u32::MAX as f64);
2861
2862 let spread = base * (0.005 + rand1 * 0.025);
2863 let h = base + spread;
2864 let l = base - spread;
2865
2866 let c = l + (h - l) * rand2;
2867
2868 high.push(h);
2869 low.push(l);
2870 close.push(c);
2871 }
2872
2873 let params = SuperTrendParams {
2874 period: Some(period),
2875 factor: Some(factor),
2876 };
2877 let input = SuperTrendInput::from_slices(&high, &low, &close, params);
2878
2879 let output = supertrend_with_kernel(&input, kernel).unwrap();
2880
2881 let ref_output = supertrend_with_kernel(&input, Kernel::Scalar).unwrap();
2882
2883 prop_assert_eq!(
2884 output.trend.len(),
2885 high.len(),
2886 "[{}] Trend length mismatch",
2887 test_name
2888 );
2889 prop_assert_eq!(
2890 output.changed.len(),
2891 high.len(),
2892 "[{}] Changed length mismatch",
2893 test_name
2894 );
2895
2896 let warmup_end = period - 1;
2897 for i in 0..warmup_end {
2898 prop_assert!(
2899 output.trend[i].is_nan(),
2900 "[{}] Expected NaN during warmup at index {}",
2901 test_name,
2902 i
2903 );
2904 prop_assert!(
2905 output.changed[i].is_nan(),
2906 "[{}] Expected NaN in changed during warmup at index {}",
2907 test_name,
2908 i
2909 );
2910 }
2911
2912 for i in warmup_end..output.trend.len() {
2913 let val = output.trend[i];
2914 if !val.is_nan() {
2915 let global_high = high.iter().fold(f64::NEG_INFINITY, |a, &b| {
2916 if b.is_finite() {
2917 a.max(b)
2918 } else {
2919 a
2920 }
2921 });
2922 let global_low = low.iter().fold(f64::INFINITY, |a, &b| {
2923 if b.is_finite() {
2924 a.min(b)
2925 } else {
2926 a
2927 }
2928 });
2929
2930 let global_range = global_high - global_low;
2931
2932 let margin = global_range * factor;
2933
2934 prop_assert!(
2935 val >= global_low - margin && val <= global_high + margin,
2936 "[{}] Trend value {} at index {} outside global bounds [{}, {}]",
2937 test_name,
2938 val,
2939 i,
2940 global_low - margin,
2941 global_high + margin
2942 );
2943 }
2944 }
2945
2946 for i in warmup_end..output.changed.len() {
2947 let val = output.changed[i];
2948 if !val.is_nan() {
2949 prop_assert!(
2950 val == 0.0 || val == 1.0,
2951 "[{}] Changed value {} at index {} is not 0.0 or 1.0",
2952 test_name,
2953 val,
2954 i
2955 );
2956 }
2957 }
2958
2959 for i in 0..output.trend.len() {
2960 let trend_val = output.trend[i];
2961 let ref_trend_val = ref_output.trend[i];
2962 let changed_val = output.changed[i];
2963 let ref_changed_val = ref_output.changed[i];
2964
2965 if !trend_val.is_finite() || !ref_trend_val.is_finite() {
2966 prop_assert_eq!(
2967 trend_val.to_bits(),
2968 ref_trend_val.to_bits(),
2969 "[{}] NaN/Inf mismatch in trend at index {}",
2970 test_name,
2971 i
2972 );
2973 } else {
2974 let ulp_diff = trend_val.to_bits().abs_diff(ref_trend_val.to_bits());
2975 prop_assert!(
2976 (trend_val - ref_trend_val).abs() <= 1e-9 || ulp_diff <= 5,
2977 "[{}] Kernel mismatch in trend at index {}: {} vs {} (ULP={})",
2978 test_name,
2979 i,
2980 trend_val,
2981 ref_trend_val,
2982 ulp_diff
2983 );
2984 }
2985
2986 if !changed_val.is_nan() && !ref_changed_val.is_nan() {
2987 prop_assert_eq!(
2988 changed_val,
2989 ref_changed_val,
2990 "[{}] Kernel mismatch in changed at index {}: {} vs {}",
2991 test_name,
2992 i,
2993 changed_val,
2994 ref_changed_val
2995 );
2996 }
2997 }
2998
2999 if base_prices.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) {
3000 let stable_start = (period * 2).min(output.trend.len());
3001 if stable_start < output.trend.len() {
3002 let stable_trend = output.trend[stable_start];
3003 for i in (stable_start + 1)..output.trend.len() {
3004 if !output.trend[i].is_nan() && !stable_trend.is_nan() {
3005 prop_assert!(
3006 (output.trend[i] - stable_trend).abs() < 1e-9,
3007 "[{}] Trend not stable for constant prices at index {}",
3008 test_name,
3009 i
3010 );
3011 }
3012 }
3013 }
3014 }
3015
3016 if output.trend.len() > warmup_end + 1 {
3017 for i in (warmup_end + 1)..output.changed.len() {
3018 let changed_val = output.changed[i];
3019 if !changed_val.is_nan() {
3020 let curr_trend = output.trend[i];
3021 let prev_trend = output.trend[i - 1];
3022
3023 if !curr_trend.is_nan() && !prev_trend.is_nan() {
3024 if changed_val == 1.0 {
3025 prop_assert!(
3026 (curr_trend - prev_trend).abs() > 1e-6,
3027 "[{}] Changed=1.0 at index {} but trend didn't switch: {} vs {}",
3028 test_name, i, prev_trend, curr_trend
3029 );
3030 }
3031 }
3032 }
3033 }
3034 }
3035
3036 Ok(())
3037 })
3038 .unwrap();
3039
3040 Ok(())
3041 }
3042
3043 macro_rules! generate_all_supertrend_tests {
3044 ($($test_fn:ident),*) => {
3045 paste::paste! {
3046 $(
3047 #[test]
3048 fn [<$test_fn _scalar_f64>]() {
3049 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3050 }
3051 )*
3052 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3053 $(
3054 #[test]
3055 fn [<$test_fn _avx2_f64>]() {
3056 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3057 }
3058 #[test]
3059 fn [<$test_fn _avx512_f64>]() {
3060 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3061 }
3062 )*
3063 }
3064 }
3065 }
3066
3067 generate_all_supertrend_tests!(
3068 check_supertrend_partial_params,
3069 check_supertrend_accuracy,
3070 check_supertrend_default_candles,
3071 check_supertrend_zero_period,
3072 check_supertrend_period_exceeds_length,
3073 check_supertrend_very_small_dataset,
3074 check_supertrend_reinput,
3075 check_supertrend_nan_handling,
3076 check_supertrend_streaming,
3077 check_supertrend_no_poison
3078 );
3079
3080 #[cfg(feature = "proptest")]
3081 generate_all_supertrend_tests!(check_supertrend_property);
3082
3083 fn check_batch_default_row(
3084 test: &str,
3085 kernel: Kernel,
3086 ) -> Result<(), Box<dyn std::error::Error>> {
3087 skip_if_unsupported!(kernel, test);
3088
3089 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3090 let c = read_candles_from_csv(file)?;
3091
3092 let output = SuperTrendBatchBuilder::new()
3093 .kernel(kernel)
3094 .apply_candles(&c)?;
3095
3096 let def = SuperTrendParams::default();
3097 let row = output.trend_for(&def).expect("default row missing");
3098
3099 assert_eq!(row.len(), c.close.len());
3100
3101 let expected = [
3102 61811.479454208165,
3103 61721.73150878735,
3104 61459.10835790861,
3105 61351.59752211775,
3106 61033.18776990598,
3107 ];
3108 let start = row.len() - 5;
3109 for (i, &v) in row[start..].iter().enumerate() {
3110 assert!(
3111 (v - expected[i]).abs() < 1e-4,
3112 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
3113 );
3114 }
3115 Ok(())
3116 }
3117
3118 #[cfg(debug_assertions)]
3119 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
3120 skip_if_unsupported!(kernel, test);
3121
3122 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3123 let c = read_candles_from_csv(file)?;
3124
3125 let test_configs = vec![
3126 (2, 10, 2, 1.0, 3.0, 0.5),
3127 (5, 25, 5, 2.0, 2.0, 0.0),
3128 (10, 10, 0, 0.5, 4.0, 0.5),
3129 (2, 5, 1, 1.5, 1.5, 0.0),
3130 (30, 60, 15, 3.0, 3.0, 0.0),
3131 (20, 30, 5, 1.0, 3.0, 1.0),
3132 (8, 12, 1, 0.5, 2.5, 0.5),
3133 ];
3134
3135 for (cfg_idx, &(p_start, p_end, p_step, f_start, f_end, f_step)) in
3136 test_configs.iter().enumerate()
3137 {
3138 let output = SuperTrendBatchBuilder::new()
3139 .kernel(kernel)
3140 .period_range(p_start, p_end, p_step)
3141 .factor_range(f_start, f_end, f_step)
3142 .apply_candles(&c)?;
3143
3144 for (idx, &val) in output.trend.iter().enumerate() {
3145 if val.is_nan() {
3146 continue;
3147 }
3148
3149 let bits = val.to_bits();
3150 let row = idx / output.cols;
3151 let col = idx % output.cols;
3152 let combo = &output.combos[row];
3153
3154 if bits == 0x11111111_11111111 {
3155 panic!(
3156 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3157 at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3158 test,
3159 cfg_idx,
3160 val,
3161 bits,
3162 row,
3163 col,
3164 idx,
3165 combo.period.unwrap_or(10),
3166 combo.factor.unwrap_or(3.0)
3167 );
3168 }
3169
3170 if bits == 0x22222222_22222222 {
3171 panic!(
3172 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3173 at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3174 test,
3175 cfg_idx,
3176 val,
3177 bits,
3178 row,
3179 col,
3180 idx,
3181 combo.period.unwrap_or(10),
3182 combo.factor.unwrap_or(3.0)
3183 );
3184 }
3185
3186 if bits == 0x33333333_33333333 {
3187 panic!(
3188 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3189 at row {} col {} (flat index {}) in trend with params: period={}, factor={}",
3190 test,
3191 cfg_idx,
3192 val,
3193 bits,
3194 row,
3195 col,
3196 idx,
3197 combo.period.unwrap_or(10),
3198 combo.factor.unwrap_or(3.0)
3199 );
3200 }
3201 }
3202
3203 for (idx, &val) in output.changed.iter().enumerate() {
3204 if val.is_nan() {
3205 continue;
3206 }
3207
3208 let bits = val.to_bits();
3209 let row = idx / output.cols;
3210 let col = idx % output.cols;
3211 let combo = &output.combos[row];
3212
3213 if bits == 0x11111111_11111111 {
3214 panic!(
3215 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
3216 at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3217 test,
3218 cfg_idx,
3219 val,
3220 bits,
3221 row,
3222 col,
3223 idx,
3224 combo.period.unwrap_or(10),
3225 combo.factor.unwrap_or(3.0)
3226 );
3227 }
3228
3229 if bits == 0x22222222_22222222 {
3230 panic!(
3231 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
3232 at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3233 test,
3234 cfg_idx,
3235 val,
3236 bits,
3237 row,
3238 col,
3239 idx,
3240 combo.period.unwrap_or(10),
3241 combo.factor.unwrap_or(3.0)
3242 );
3243 }
3244
3245 if bits == 0x33333333_33333333 {
3246 panic!(
3247 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
3248 at row {} col {} (flat index {}) in changed with params: period={}, factor={}",
3249 test,
3250 cfg_idx,
3251 val,
3252 bits,
3253 row,
3254 col,
3255 idx,
3256 combo.period.unwrap_or(10),
3257 combo.factor.unwrap_or(3.0)
3258 );
3259 }
3260 }
3261 }
3262
3263 Ok(())
3264 }
3265
3266 #[cfg(not(debug_assertions))]
3267 fn check_batch_no_poison(
3268 _test: &str,
3269 _kernel: Kernel,
3270 ) -> Result<(), Box<dyn std::error::Error>> {
3271 Ok(())
3272 }
3273
3274 macro_rules! gen_batch_tests {
3275 ($fn_name:ident) => {
3276 paste::paste! {
3277 #[test] fn [<$fn_name _scalar>]() {
3278 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3279 }
3280 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3281 #[test] fn [<$fn_name _avx2>]() {
3282 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3283 }
3284 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3285 #[test] fn [<$fn_name _avx512>]() {
3286 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3287 }
3288 #[test] fn [<$fn_name _auto_detect>]() {
3289 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3290 }
3291 }
3292 };
3293 }
3294 gen_batch_tests!(check_batch_default_row);
3295 gen_batch_tests!(check_batch_no_poison);
3296}