Skip to main content

vector_ta/indicators/
di.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::convert::AsRef;
30use std::error::Error;
31use thiserror::Error;
32
33#[derive(Debug, Clone)]
34pub enum DiData<'a> {
35    Candles {
36        candles: &'a Candles,
37    },
38    Slices {
39        high: &'a [f64],
40        low: &'a [f64],
41        close: &'a [f64],
42    },
43}
44
/// Result of a single-series DI computation.
///
/// `plus` holds +DI and `minus` holds -DI. The functions in this module size
/// both vectors to the length of the input series.
#[derive(Debug, Clone)]
pub struct DiOutput {
    /// +DI series: smoothed +DM scaled by `100 / smoothed true range`.
    pub plus: Vec<f64>,
    /// -DI series: smoothed -DM scaled by `100 / smoothed true range`.
    pub minus: Vec<f64>,
}
50
/// Tunable parameters for the DI indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DiParams {
    /// Smoothing period. Consumers in this module treat `None` as 14.
    pub period: Option<usize>,
}

impl Default for DiParams {
    /// A 14-bar period.
    fn default() -> Self {
        let period = Some(14);
        DiParams { period }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct DiInput<'a> {
68    pub data: DiData<'a>,
69    pub params: DiParams,
70}
71
72impl<'a> DiInput<'a> {
73    #[inline(always)]
74    pub fn from_candles(candles: &'a Candles, params: DiParams) -> Self {
75        Self {
76            data: DiData::Candles { candles },
77            params,
78        }
79    }
80    #[inline(always)]
81    pub fn from_slices(
82        high: &'a [f64],
83        low: &'a [f64],
84        close: &'a [f64],
85        params: DiParams,
86    ) -> Self {
87        Self {
88            data: DiData::Slices { high, low, close },
89            params,
90        }
91    }
92    #[inline(always)]
93    pub fn with_default_candles(candles: &'a Candles) -> Self {
94        Self {
95            data: DiData::Candles { candles },
96            params: DiParams::default(),
97        }
98    }
99    #[inline(always)]
100    pub fn get_period(&self) -> usize {
101        self.params.period.unwrap_or(14)
102    }
103    #[inline(always)]
104    pub fn as_slices(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
105        match &self.data {
106            DiData::Candles { candles } => (
107                source_type(candles, "high"),
108                source_type(candles, "low"),
109                source_type(candles, "close"),
110            ),
111            DiData::Slices { high, low, close } => (*high, *low, *close),
112        }
113    }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub struct DiBuilder {
118    period: Option<usize>,
119    kernel: Kernel,
120}
121
122impl Default for DiBuilder {
123    fn default() -> Self {
124        Self {
125            period: None,
126            kernel: Kernel::Auto,
127        }
128    }
129}
130
131impl DiBuilder {
132    #[inline(always)]
133    pub fn new() -> Self {
134        Self::default()
135    }
136    #[inline(always)]
137    pub fn period(mut self, n: usize) -> Self {
138        self.period = Some(n);
139        self
140    }
141    #[inline(always)]
142    pub fn kernel(mut self, k: Kernel) -> Self {
143        self.kernel = k;
144        self
145    }
146    #[inline(always)]
147    pub fn apply(self, candles: &Candles) -> Result<DiOutput, DiError> {
148        let params = DiParams {
149            period: self.period,
150        };
151        let input = DiInput::from_candles(candles, params);
152        di_with_kernel(&input, self.kernel)
153    }
154    #[inline(always)]
155    pub fn apply_slices(
156        self,
157        high: &[f64],
158        low: &[f64],
159        close: &[f64],
160    ) -> Result<DiOutput, DiError> {
161        let params = DiParams {
162            period: self.period,
163        };
164        let input = DiInput::from_slices(high, low, close, params);
165        di_with_kernel(&input, self.kernel)
166    }
167    #[inline(always)]
168    pub fn into_stream(self) -> Result<DiStream, DiError> {
169        let params = DiParams {
170            period: self.period,
171        };
172        DiStream::try_new(params)
173    }
174}
175
176#[derive(Debug, Error)]
177pub enum DiError {
178    #[error("di: Empty data provided.")]
179    EmptyInputData,
180    #[error("di: Empty data provided for DI.")]
181    EmptyData,
182    #[error("di: Invalid period: period = {period}, data length = {data_len}")]
183    InvalidPeriod { period: usize, data_len: usize },
184    #[error("di: Not enough valid data: needed = {needed}, valid = {valid}")]
185    NotEnoughValidData { needed: usize, valid: usize },
186    #[error("di: All values are NaN.")]
187    AllValuesNaN,
188    #[error("di: Output length mismatch: expected {expected}, got {got}")]
189    OutputLengthMismatch { expected: usize, got: usize },
190    #[error("di: Invalid range expansion: start={start}, end={end}, step={step}")]
191    InvalidRange {
192        start: usize,
193        end: usize,
194        step: usize,
195    },
196    #[error("di: Invalid kernel for batch path: {0:?}")]
197    InvalidKernelForBatch(Kernel),
198}
199
200#[inline]
201pub fn di(input: &DiInput) -> Result<DiOutput, DiError> {
202    di_with_kernel(input, Kernel::Auto)
203}
204
205#[inline(always)]
206fn di_prepare<'a>(
207    input: &'a DiInput<'a>,
208    kernel: Kernel,
209) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), DiError> {
210    let (high, low, close) = match &input.data {
211        DiData::Candles { candles } => {
212            let h = source_type(candles, "high");
213            let l = source_type(candles, "low");
214            let c = source_type(candles, "close");
215            (h, l, c)
216        }
217        DiData::Slices { high, low, close } => (*high, *low, *close),
218    };
219    let n = high.len();
220    if n == 0 || low.len() != n || close.len() != n {
221        return Err(DiError::EmptyInputData);
222    }
223    let period = input.get_period();
224    if period == 0 || period > n {
225        return Err(DiError::InvalidPeriod {
226            period,
227            data_len: n,
228        });
229    }
230    let first_valid_idx =
231        (0..n).find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()));
232    let first_idx = match first_valid_idx {
233        Some(idx) => idx,
234        None => return Err(DiError::AllValuesNaN),
235    };
236    if (n - first_idx) < period {
237        return Err(DiError::NotEnoughValidData {
238            needed: period,
239            valid: n - first_idx,
240        });
241    }
242
243    let chosen = match kernel {
244        Kernel::Auto => Kernel::Scalar,
245        other => other,
246    };
247
248    Ok((high, low, close, period, first_idx, chosen))
249}
250
251pub fn di_with_kernel(input: &DiInput, kernel: Kernel) -> Result<DiOutput, DiError> {
252    let (high, low, close, period, first_idx, chosen) = di_prepare(input, kernel)?;
253
254    unsafe {
255        match chosen {
256            Kernel::Scalar | Kernel::ScalarBatch => di_scalar(high, low, close, period, first_idx),
257            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
258            Kernel::Avx2 | Kernel::Avx2Batch => di_avx2(high, low, close, period, first_idx),
259            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260            Kernel::Avx512 | Kernel::Avx512Batch => di_avx512(high, low, close, period, first_idx),
261            _ => unreachable!(),
262        }
263    }
264}
265
266#[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
267#[inline(always)]
268pub unsafe fn di_avx2_into(
269    high: &[f64],
270    low: &[f64],
271    close: &[f64],
272    period: usize,
273    first_idx: usize,
274    out_plus: &mut [f64],
275    out_minus: &mut [f64],
276) {
277    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
278}
279
280#[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
281#[inline(always)]
282pub unsafe fn di_avx512_into(
283    high: &[f64],
284    low: &[f64],
285    close: &[f64],
286    period: usize,
287    first_idx: usize,
288    out_plus: &mut [f64],
289    out_minus: &mut [f64],
290) {
291    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
292}
293
294#[inline(always)]
295fn di_compute_into(
296    high: &[f64],
297    low: &[f64],
298    close: &[f64],
299    period: usize,
300    first_idx: usize,
301    kernel: Kernel,
302    out_plus: &mut [f64],
303    out_minus: &mut [f64],
304) {
305    unsafe {
306        match kernel {
307            Kernel::Scalar | Kernel::ScalarBatch => {
308                di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
309            }
310            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
311            Kernel::Avx2 | Kernel::Avx2Batch => {
312                di_avx2_into(high, low, close, period, first_idx, out_plus, out_minus)
313            }
314            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315            Kernel::Avx512 | Kernel::Avx512Batch => {
316                di_avx512_into(high, low, close, period, first_idx, out_plus, out_minus)
317            }
318            _ => unreachable!(),
319        }
320    }
321}
322
323pub fn di_into_slice(
324    dst_plus: &mut [f64],
325    dst_minus: &mut [f64],
326    input: &DiInput,
327    kern: Kernel,
328) -> Result<(), DiError> {
329    let (high, low, close, period, first_idx, chosen) = di_prepare(input, kern)?;
330
331    let n = high.len();
332    if dst_plus.len() != n || dst_minus.len() != n {
333        return Err(DiError::OutputLengthMismatch {
334            expected: n,
335            got: dst_plus.len().min(dst_minus.len()),
336        });
337    }
338
339    di_compute_into(
340        high, low, close, period, first_idx, chosen, dst_plus, dst_minus,
341    );
342
343    let warmup_end = first_idx + period - 1;
344    for v in &mut dst_plus[..warmup_end] {
345        *v = f64::from_bits(0x7ff8_0000_0000_0000);
346    }
347    for v in &mut dst_minus[..warmup_end] {
348        *v = f64::from_bits(0x7ff8_0000_0000_0000);
349    }
350
351    Ok(())
352}
353
354#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
355pub fn di_into(
356    input: &DiInput,
357    out_plus: &mut [f64],
358    out_minus: &mut [f64],
359) -> Result<(), DiError> {
360    let (high, low, close, period, first_idx, chosen) = di_prepare(input, Kernel::Auto)?;
361
362    let n = high.len();
363    if out_plus.len() != n || out_minus.len() != n {
364        return Err(DiError::OutputLengthMismatch {
365            expected: n,
366            got: out_plus.len().min(out_minus.len()),
367        });
368    }
369
370    di_compute_into(
371        high, low, close, period, first_idx, chosen, out_plus, out_minus,
372    );
373
374    let warmup_end = first_idx + period - 1;
375    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
376    for v in &mut out_plus[..warmup_end] {
377        *v = qnan;
378    }
379    for v in &mut out_minus[..warmup_end] {
380        *v = qnan;
381    }
382
383    Ok(())
384}
385
/// Scalar DI kernel that writes +DI / -DI into `out_plus` / `out_minus`.
///
/// Algorithm:
/// - The bar at `first_idx` only supplies the "previous" high/low/close.
/// - Bars `first_idx + 1 .. first_idx + period` are summed plainly into +DM,
///   -DM and true-range accumulators.
/// - The first output is written at index `first_idx + period - 1`.
/// - After that, each accumulator is smoothed as `sum * (1 - 1/period) + current`.
/// - Each output is `100 * dm_sum / tr_sum`, or 0 when the smoothed true range is 0.
///
/// Entries before `first_idx + period - 1` are left untouched.
///
/// # Safety
/// The body uses only bounds-checked indexing. Callers are expected to pass:
/// - `period >= 1`;
/// - `first_idx + period <= high.len()`;
/// - `low`, `close` and both outputs at least `high.len()` long.
///
/// Otherwise the function may panic on an out-of-range index.
#[inline(always)]
pub unsafe fn di_scalar_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    // Splits one bar's movement into (+DM, -DM); at most one side is non-zero.
    #[inline(always)]
    fn directional(up: f64, down: f64) -> (f64, f64) {
        let plus = if up > down && up > 0.0 { up } else { 0.0 };
        let minus = if down > up && down > 0.0 { down } else { 0.0 };
        (plus, minus)
    }

    // Widest of high-low, |high - prev close| and |low - prev close|.
    // Strict `>` comparisons keep a NaN high-low range as NaN, like the original loop.
    #[inline(always)]
    fn true_range(h: f64, l: f64, prev_close: f64) -> f64 {
        let mut widest = h - l;
        let gap_up = (h - prev_close).abs();
        let gap_down = (l - prev_close).abs();
        if gap_up > widest {
            widest = gap_up;
        }
        if gap_down > widest {
            widest = gap_down;
        }
        widest
    }

    // 100 / smoothed TR, guarding against a flat market.
    #[inline(always)]
    fn percent_scale(tr_sum: f64) -> f64 {
        if tr_sum == 0.0 {
            0.0
        } else {
            100.0 / tr_sum
        }
    }

    let len = high.len();
    if len == 0 {
        return;
    }

    let keep = 1.0 - (period as f64).recip();

    let (mut ph, mut pl, mut pc) = (high[first_idx], low[first_idx], close[first_idx]);
    let seed_end = first_idx + period;
    let (mut sum_plus, mut sum_minus, mut sum_tr) = (0.0_f64, 0.0_f64, 0.0_f64);

    // Seed: plain sums over the first `period - 1` transitions.
    for i in (first_idx + 1)..seed_end {
        let (h, l, c) = (high[i], low[i], close[i]);
        let (p, m) = directional(h - ph, pl - l);
        sum_plus += p;
        sum_minus += m;
        sum_tr += true_range(h, l, pc);
        ph = h;
        pl = l;
        pc = c;
    }

    let first_out = seed_end - 1;
    let scale = percent_scale(sum_tr);
    out_plus[first_out] = sum_plus * scale;
    out_minus[first_out] = sum_minus * scale;

    // Recursive smoothing for every later bar.
    for j in seed_end..len {
        let (h, l, c) = (high[j], low[j], close[j]);
        let (p, m) = directional(h - ph, pl - l);
        sum_plus = sum_plus.mul_add(keep, p);
        sum_minus = sum_minus.mul_add(keep, m);
        sum_tr = sum_tr.mul_add(keep, true_range(h, l, pc));

        let scale = percent_scale(sum_tr);
        out_plus[j] = sum_plus * scale;
        out_minus[j] = sum_minus * scale;

        ph = h;
        pl = l;
        pc = c;
    }
}
491
492#[inline(always)]
493pub unsafe fn di_scalar(
494    high: &[f64],
495    low: &[f64],
496    close: &[f64],
497    period: usize,
498    first_idx: usize,
499) -> Result<DiOutput, DiError> {
500    let n = high.len();
501    let mut plus_di = alloc_with_nan_prefix(n, first_idx + period - 1);
502    let mut minus_di = alloc_with_nan_prefix(n, first_idx + period - 1);
503    di_scalar_into(
504        high,
505        low,
506        close,
507        period,
508        first_idx,
509        &mut plus_di,
510        &mut minus_di,
511    );
512    Ok(DiOutput {
513        plus: plus_di,
514        minus: minus_di,
515    })
516}
517
/// AVX2 "into" entry point (nightly-avx, x86_64).
///
/// Currently this forwards to [`di_scalar_into`]; there is no vectorised body.
///
/// # Safety
/// Same requirements as [`di_scalar_into`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx2_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    // SAFETY: the caller upholds the `di_scalar_into` contract; arguments pass through untouched.
    unsafe { di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus) }
}
531
/// Allocating AVX2 DI (nightly-avx, x86_64). Currently delegates to [`di_scalar`].
///
/// # Safety
/// Same requirements as [`di_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    // SAFETY: forwarded contract of `di_scalar`.
    unsafe { di_scalar(high, low, close, period, first_idx) }
}
543
/// AVX-512 "into" entry point (nightly-avx, x86_64).
///
/// Currently this forwards to [`di_scalar_into`]; there is no vectorised body.
///
/// # Safety
/// Same requirements as [`di_scalar_into`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    // SAFETY: the caller upholds the `di_scalar_into` contract; arguments pass through untouched.
    unsafe { di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus) }
}
557
/// Allocating AVX-512 DI (nightly-avx, x86_64). Currently delegates to [`di_scalar`].
///
/// # Safety
/// Same requirements as [`di_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    // SAFETY: forwarded contract of `di_scalar`.
    unsafe { di_scalar(high, low, close, period, first_idx) }
}
569
/// Short-period AVX-512 variant; presently identical to [`di_avx512`].
///
/// # Safety
/// Same requirements as [`di_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    // SAFETY: forwarded contract of `di_avx512`.
    unsafe { di_avx512(high, low, close, period, first_idx) }
}
581
/// Long-period AVX-512 variant; presently identical to [`di_avx512`].
///
/// # Safety
/// Same requirements as [`di_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    // SAFETY: forwarded contract of `di_avx512`.
    unsafe { di_avx512(high, low, close, period, first_idx) }
}
593
594#[derive(Debug, Clone)]
595pub struct DiStream {
596    period: usize,
597    keep: f64,
598
599    prev_h: f64,
600    prev_l: f64,
601    prev_c: f64,
602    have_prev: bool,
603
604    warm_plus: f64,
605    warm_minus: f64,
606    warm_tr: f64,
607    warm_count: usize,
608
609    cur_plus: f64,
610    cur_minus: f64,
611    cur_tr: f64,
612    warmed: bool,
613}
614
615impl DiStream {
616    pub fn try_new(params: DiParams) -> Result<Self, DiError> {
617        let period = params.period.unwrap_or(14);
618        if period == 0 {
619            return Err(DiError::InvalidPeriod {
620                period,
621                data_len: 0,
622            });
623        }
624        let pf = period as f64;
625        Ok(Self {
626            period,
627            keep: 1.0 - pf.recip(),
628
629            prev_h: f64::NAN,
630            prev_l: f64::NAN,
631            prev_c: f64::NAN,
632            have_prev: false,
633
634            warm_plus: 0.0,
635            warm_minus: 0.0,
636            warm_tr: 0.0,
637            warm_count: 0,
638
639            cur_plus: 0.0,
640            cur_minus: 0.0,
641            cur_tr: 0.0,
642            warmed: false,
643        })
644    }
645
646    #[inline(always)]
647    fn reset(&mut self) {
648        self.prev_h = f64::NAN;
649        self.prev_l = f64::NAN;
650        self.prev_c = f64::NAN;
651        self.have_prev = false;
652
653        self.warm_plus = 0.0;
654        self.warm_minus = 0.0;
655        self.warm_tr = 0.0;
656        self.warm_count = 0;
657
658        self.cur_plus = 0.0;
659        self.cur_minus = 0.0;
660        self.cur_tr = 0.0;
661        self.warmed = false;
662    }
663
664    #[inline(always)]
665    fn tr_fast(high: f64, low: f64, prev_close: f64) -> f64 {
666        let hi = if high > prev_close { high } else { prev_close };
667        let lo = if low < prev_close { low } else { prev_close };
668        hi - lo
669    }
670
671    #[inline(always)]
672    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
673        if high.is_nan() || low.is_nan() || close.is_nan() {
674            self.reset();
675            return None;
676        }
677
678        if !self.have_prev {
679            self.prev_h = high;
680            self.prev_l = low;
681            self.prev_c = close;
682            self.have_prev = true;
683            self.warm_count = 1;
684            return None;
685        }
686
687        let dp = high - self.prev_h;
688        let dm = self.prev_l - low;
689
690        let inc_p = if dp > dm && dp > 0.0 { dp } else { 0.0 };
691        let inc_m = if dm > dp && dm > 0.0 { dm } else { 0.0 };
692        let tr = Self::tr_fast(high, low, self.prev_c);
693
694        self.prev_h = high;
695        self.prev_l = low;
696        self.prev_c = close;
697
698        if !self.warmed {
699            self.warm_plus += inc_p;
700            self.warm_minus += inc_m;
701            self.warm_tr += tr;
702            self.warm_count += 1;
703
704            if self.warm_count < self.period {
705                return None;
706            }
707
708            self.cur_plus = self.warm_plus;
709            self.cur_minus = self.warm_minus;
710            self.cur_tr = self.warm_tr;
711            self.warmed = true;
712
713            let scale = if self.cur_tr == 0.0 {
714                0.0
715            } else {
716                100.0 / self.cur_tr
717            };
718            return Some((self.cur_plus * scale, self.cur_minus * scale));
719        }
720
721        self.cur_plus = self.cur_plus.mul_add(self.keep, inc_p);
722        self.cur_minus = self.cur_minus.mul_add(self.keep, inc_m);
723        self.cur_tr = self.cur_tr.mul_add(self.keep, tr);
724
725        let scale = if self.cur_tr == 0.0 {
726            0.0
727        } else {
728            100.0 / self.cur_tr
729        };
730        Some((self.cur_plus * scale, self.cur_minus * scale))
731    }
732}
733
/// Period sweep for batch DI, given as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct DiBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DiBatchRange {
    /// Periods 14 through 263 with step 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        DiBatchRange {
            period: (start, end, step),
        }
    }
}
746
747#[derive(Clone, Debug, Default)]
748pub struct DiBatchBuilder {
749    range: DiBatchRange,
750    kernel: Kernel,
751}
752
753impl DiBatchBuilder {
754    pub fn new() -> Self {
755        Self::default()
756    }
757    pub fn kernel(mut self, k: Kernel) -> Self {
758        self.kernel = k;
759        self
760    }
761    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
762        self.range.period = (start, end, step);
763        self
764    }
765    pub fn period_static(mut self, p: usize) -> Self {
766        self.range.period = (p, p, 0);
767        self
768    }
769    pub fn apply_slices(
770        self,
771        high: &[f64],
772        low: &[f64],
773        close: &[f64],
774    ) -> Result<DiBatchOutput, DiError> {
775        di_batch_with_kernel(high, low, close, &self.range, self.kernel)
776    }
777    pub fn apply_candles(self, c: &Candles) -> Result<DiBatchOutput, DiError> {
778        let h = source_type(c, "high");
779        let l = source_type(c, "low");
780        let cl = source_type(c, "close");
781        self.apply_slices(h, l, cl)
782    }
783    pub fn with_default_candles(c: &Candles) -> Result<DiBatchOutput, DiError> {
784        DiBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
785    }
786}
787
788pub fn di_batch_with_kernel(
789    high: &[f64],
790    low: &[f64],
791    close: &[f64],
792    sweep: &DiBatchRange,
793    k: Kernel,
794) -> Result<DiBatchOutput, DiError> {
795    let kernel = match k {
796        Kernel::Auto => detect_best_batch_kernel(),
797        other if other.is_batch() => other,
798        other => return Err(DiError::InvalidKernelForBatch(other)),
799    };
800    let simd = match kernel {
801        Kernel::Avx512Batch => Kernel::Avx512,
802        Kernel::Avx2Batch => Kernel::Avx2,
803        Kernel::ScalarBatch => Kernel::Scalar,
804        _ => unreachable!(),
805    };
806    di_batch_par_slice(high, low, close, sweep, simd)
807}
808
809#[derive(Clone, Debug)]
810pub struct DiBatchOutput {
811    pub plus: Vec<f64>,
812    pub minus: Vec<f64>,
813    pub combos: Vec<DiParams>,
814    pub rows: usize,
815    pub cols: usize,
816}
817impl DiBatchOutput {
818    pub fn row_for_params(&self, p: &DiParams) -> Option<usize> {
819        self.combos
820            .iter()
821            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
822    }
823    pub fn plus_for(&self, p: &DiParams) -> Option<&[f64]> {
824        self.row_for_params(p).map(|row| {
825            let start = row * self.cols;
826            &self.plus[start..start + self.cols]
827        })
828    }
829    pub fn minus_for(&self, p: &DiParams) -> Option<&[f64]> {
830        self.row_for_params(p).map(|row| {
831            let start = row * self.cols;
832            &self.minus[start..start + self.cols]
833        })
834    }
835}
836
837#[inline(always)]
838fn expand_grid(r: &DiBatchRange) -> Vec<DiParams> {
839    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
840        if step == 0 || start == end {
841            return vec![start];
842        }
843        if start < end {
844            return (start..=end).step_by(step.max(1)).collect();
845        }
846
847        let mut v = Vec::new();
848        let mut cur = start;
849        while cur >= end {
850            v.push(cur);
851            if cur == end {
852                break;
853            }
854            cur = cur.saturating_sub(step.max(1));
855            if cur < end {
856                break;
857            }
858        }
859        v
860    }
861    let periods = axis_usize(r.period);
862    let mut out = Vec::with_capacity(periods.len());
863    for &p in &periods {
864        out.push(DiParams { period: Some(p) });
865    }
866    out
867}
868
869#[inline(always)]
870pub fn di_batch_slice(
871    high: &[f64],
872    low: &[f64],
873    close: &[f64],
874    sweep: &DiBatchRange,
875    kern: Kernel,
876) -> Result<DiBatchOutput, DiError> {
877    di_batch_inner(high, low, close, sweep, kern, false)
878}
879
880#[inline(always)]
881pub fn di_batch_par_slice(
882    high: &[f64],
883    low: &[f64],
884    close: &[f64],
885    sweep: &DiBatchRange,
886    kern: Kernel,
887) -> Result<DiBatchOutput, DiError> {
888    di_batch_inner(high, low, close, sweep, kern, true)
889}
890
891#[inline(always)]
892fn di_batch_inner_into(
893    high: &[f64],
894    low: &[f64],
895    close: &[f64],
896    sweep: &DiBatchRange,
897    kern: Kernel,
898    parallel: bool,
899    out_plus: &mut [f64],
900    out_minus: &mut [f64],
901) -> Result<Vec<DiParams>, DiError> {
902    let combos = expand_grid(sweep);
903    if combos.is_empty() {
904        let (s, e, st) = sweep.period;
905        return Err(DiError::InvalidRange {
906            start: s,
907            end: e,
908            step: st,
909        });
910    }
911    let n = high.len();
912    if n == 0 || low.len() != n || close.len() != n {
913        return Err(DiError::EmptyInputData);
914    }
915    let first = (0..n)
916        .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
917        .ok_or(DiError::AllValuesNaN)?;
918    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
919    if n - first < max_p {
920        return Err(DiError::NotEnoughValidData {
921            needed: max_p,
922            valid: n - first,
923        });
924    }
925
926    let rows = combos.len();
927    let cols = n;
928
929    unsafe {
930        let total = rows.checked_mul(cols).ok_or(DiError::InvalidRange {
931            start: sweep.period.0,
932            end: sweep.period.1,
933            step: sweep.period.2,
934        })?;
935        let plus_mu = std::slice::from_raw_parts_mut(
936            out_plus.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
937            total,
938        );
939        let minus_mu = std::slice::from_raw_parts_mut(
940            out_minus.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
941            total,
942        );
943        let warm: Vec<usize> = combos
944            .iter()
945            .map(|c| first.saturating_add(c.period.unwrap()).saturating_sub(1))
946            .collect();
947        init_matrix_prefixes(plus_mu, cols, &warm);
948        init_matrix_prefixes(minus_mu, cols, &warm);
949    }
950
951    let mut up = vec![0.0f64; n];
952    let mut dn = vec![0.0f64; n];
953    let mut tr = vec![0.0f64; n];
954    {
955        let mut prev_h = high[first];
956        let mut prev_l = low[first];
957        let mut prev_c = close[first];
958        let mut i = first + 1;
959        while i < n {
960            let ch = high[i];
961            let cl = low[i];
962            let dp = ch - prev_h;
963            let dm = prev_l - cl;
964            if dp > dm && dp > 0.0 {
965                up[i] = dp;
966            }
967            if dm > dp && dm > 0.0 {
968                dn[i] = dm;
969            }
970            let mut t = ch - cl;
971            let t2 = (ch - prev_c).abs();
972            let t3 = (cl - prev_c).abs();
973            if t2 > t {
974                t = t2;
975            }
976            if t3 > t {
977                t = t3;
978            }
979            tr[i] = t;
980            prev_h = ch;
981            prev_l = cl;
982            prev_c = close[i];
983            i += 1;
984        }
985    }
986
987    let do_row = |row: usize, out_plus: &mut [f64], out_minus: &mut [f64]| unsafe {
988        let period = combos[row].period.unwrap();
989        let result = di_row_scalar_precomputed(&up, &dn, &tr, period, first, out_plus, out_minus);
990        debug_assert!(result.is_ok());
991    };
992
993    if parallel {
994        #[cfg(not(target_arch = "wasm32"))]
995        {
996            out_plus
997                .par_chunks_mut(cols)
998                .zip(out_minus.par_chunks_mut(cols))
999                .enumerate()
1000                .for_each(|(row, (pl, mi))| do_row(row, pl, mi));
1001        }
1002
1003        #[cfg(target_arch = "wasm32")]
1004        {
1005            for (row, (pl, mi)) in out_plus
1006                .chunks_mut(cols)
1007                .zip(out_minus.chunks_mut(cols))
1008                .enumerate()
1009            {
1010                do_row(row, pl, mi);
1011            }
1012        }
1013    } else {
1014        for (row, (pl, mi)) in out_plus
1015            .chunks_mut(cols)
1016            .zip(out_minus.chunks_mut(cols))
1017            .enumerate()
1018        {
1019            do_row(row, pl, mi);
1020        }
1021    }
1022
1023    Ok(combos)
1024}
1025
1026#[inline(always)]
1027fn di_batch_inner(
1028    high: &[f64],
1029    low: &[f64],
1030    close: &[f64],
1031    sweep: &DiBatchRange,
1032    kern: Kernel,
1033    parallel: bool,
1034) -> Result<DiBatchOutput, DiError> {
1035    let combos = expand_grid(sweep);
1036    if combos.is_empty() {
1037        let (s, e, st) = sweep.period;
1038        return Err(DiError::InvalidRange {
1039            start: s,
1040            end: e,
1041            step: st,
1042        });
1043    }
1044    let n = high.len();
1045    if n == 0 || low.len() != n || close.len() != n {
1046        return Err(DiError::EmptyInputData);
1047    }
1048    let first = (0..n)
1049        .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
1050        .ok_or(DiError::AllValuesNaN)?;
1051    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1052    if n - first < max_p {
1053        return Err(DiError::NotEnoughValidData {
1054            needed: max_p,
1055            valid: n - first,
1056        });
1057    }
1058
1059    let rows = combos.len();
1060    let cols = n;
1061
1062    let _ = rows.checked_mul(cols).ok_or(DiError::InvalidRange {
1063        start: sweep.period.0,
1064        end: sweep.period.1,
1065        step: sweep.period.2,
1066    })?;
1067
1068    let warmup_periods: Vec<usize> = combos
1069        .iter()
1070        .map(|c| first.saturating_add(c.period.unwrap()).saturating_sub(1))
1071        .collect();
1072
1073    let mut plus_mu = make_uninit_matrix(rows, cols);
1074    let mut minus_mu = make_uninit_matrix(rows, cols);
1075
1076    init_matrix_prefixes(&mut plus_mu, cols, &warmup_periods);
1077    init_matrix_prefixes(&mut minus_mu, cols, &warmup_periods);
1078
1079    let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
1080    let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
1081    let plus: &mut [f64] = unsafe {
1082        core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
1083    };
1084    let minus: &mut [f64] = unsafe {
1085        core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
1086    };
1087
1088    let mut up = vec![0.0f64; n];
1089    let mut dn = vec![0.0f64; n];
1090    let mut tr = vec![0.0f64; n];
1091    {
1092        let mut prev_h = high[first];
1093        let mut prev_l = low[first];
1094        let mut prev_c = close[first];
1095        let mut i = first + 1;
1096        while i < n {
1097            let ch = high[i];
1098            let cl = low[i];
1099            let dp = ch - prev_h;
1100            let dm = prev_l - cl;
1101            if dp > dm && dp > 0.0 {
1102                up[i] = dp;
1103            }
1104            if dm > dp && dm > 0.0 {
1105                dn[i] = dm;
1106            }
1107            let mut t = ch - cl;
1108            let t2 = (ch - prev_c).abs();
1109            let t3 = (cl - prev_c).abs();
1110            if t2 > t {
1111                t = t2;
1112            }
1113            if t3 > t {
1114                t = t3;
1115            }
1116            tr[i] = t;
1117            prev_h = ch;
1118            prev_l = cl;
1119            prev_c = close[i];
1120            i += 1;
1121        }
1122    }
1123
1124    let do_row = |row: usize, out_plus: &mut [f64], out_minus: &mut [f64]| unsafe {
1125        let period = combos[row].period.unwrap();
1126        let result = di_row_scalar_precomputed(&up, &dn, &tr, period, first, out_plus, out_minus);
1127        debug_assert!(result.is_ok());
1128    };
1129
1130    if parallel {
1131        #[cfg(not(target_arch = "wasm32"))]
1132        {
1133            plus.par_chunks_mut(cols)
1134                .zip(minus.par_chunks_mut(cols))
1135                .enumerate()
1136                .for_each(|(row, (pl, mi))| do_row(row, pl, mi));
1137        }
1138
1139        #[cfg(target_arch = "wasm32")]
1140        {
1141            for (row, (pl, mi)) in plus
1142                .chunks_mut(cols)
1143                .zip(minus.chunks_mut(cols))
1144                .enumerate()
1145            {
1146                do_row(row, pl, mi);
1147            }
1148        }
1149    } else {
1150        for (row, (pl, mi)) in plus
1151            .chunks_mut(cols)
1152            .zip(minus.chunks_mut(cols))
1153            .enumerate()
1154        {
1155            do_row(row, pl, mi);
1156        }
1157    }
1158
1159    let plus = unsafe {
1160        Vec::from_raw_parts(
1161            plus_guard.as_mut_ptr() as *mut f64,
1162            plus_guard.len(),
1163            plus_guard.capacity(),
1164        )
1165    };
1166    let minus = unsafe {
1167        Vec::from_raw_parts(
1168            minus_guard.as_mut_ptr() as *mut f64,
1169            minus_guard.len(),
1170            minus_guard.capacity(),
1171        )
1172    };
1173
1174    Ok(DiBatchOutput {
1175        plus,
1176        minus,
1177        combos,
1178        rows,
1179        cols,
1180    })
1181}
1182
1183#[inline(always)]
1184pub unsafe fn di_row_scalar(
1185    high: &[f64],
1186    low: &[f64],
1187    close: &[f64],
1188    period: usize,
1189    first: usize,
1190    out_plus: &mut [f64],
1191    out_minus: &mut [f64],
1192) -> Result<(), DiError> {
1193    let n = high.len();
1194    if n == 0 {
1195        return Ok(());
1196    }
1197
1198    let pf = period as f64;
1199    let invp = pf.recip();
1200    let keep = 1.0 - invp;
1201
1202    let mut prev_h = high[first];
1203    let mut prev_l = low[first];
1204    let mut prev_c = close[first];
1205
1206    let start = first + 1;
1207    let stop = first + period;
1208    let mut plus_dm_sum = 0.0;
1209    let mut minus_dm_sum = 0.0;
1210    let mut tr_sum = 0.0;
1211
1212    let mut i = start;
1213    while i < stop {
1214        let ch = high[i];
1215        let cl = low[i];
1216        let cc = close[i];
1217
1218        let dp = ch - prev_h;
1219        let dm = prev_l - cl;
1220        if dp > dm && dp > 0.0 {
1221            plus_dm_sum += dp;
1222        }
1223        if dm > dp && dm > 0.0 {
1224            minus_dm_sum += dm;
1225        }
1226
1227        let mut tr = ch - cl;
1228        let tr2 = (ch - prev_c).abs();
1229        let tr3 = (cl - prev_c).abs();
1230        if tr2 > tr {
1231            tr = tr2;
1232        }
1233        if tr3 > tr {
1234            tr = tr3;
1235        }
1236        tr_sum += tr;
1237
1238        prev_h = ch;
1239        prev_l = cl;
1240        prev_c = cc;
1241        i += 1;
1242    }
1243
1244    let mut cur_plus = plus_dm_sum;
1245    let mut cur_minus = minus_dm_sum;
1246    let mut cur_tr = tr_sum;
1247
1248    let mut idx = stop - 1;
1249    let mut scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1250    out_plus[idx] = cur_plus * scale;
1251    out_minus[idx] = cur_minus * scale;
1252    idx += 1;
1253
1254    while idx < n {
1255        let ch = high[idx];
1256        let cl = low[idx];
1257        let cc = close[idx];
1258
1259        let dp = ch - prev_h;
1260        let dm = prev_l - cl;
1261        let inc_p = if dp > dm && dp > 0.0 { dp } else { 0.0 };
1262        let inc_m = if dm > dp && dm > 0.0 { dm } else { 0.0 };
1263
1264        cur_plus = cur_plus.mul_add(keep, inc_p);
1265        cur_minus = cur_minus.mul_add(keep, inc_m);
1266
1267        let mut tr = ch - cl;
1268        let tr2 = (ch - prev_c).abs();
1269        let tr3 = (cl - prev_c).abs();
1270        if tr2 > tr {
1271            tr = tr2;
1272        }
1273        if tr3 > tr {
1274            tr = tr3;
1275        }
1276        cur_tr = cur_tr.mul_add(keep, tr);
1277
1278        scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1279        out_plus[idx] = cur_plus * scale;
1280        out_minus[idx] = cur_minus * scale;
1281
1282        prev_h = ch;
1283        prev_l = cl;
1284        prev_c = cc;
1285        idx += 1;
1286    }
1287    Ok(())
1288}
1289
1290#[inline(always)]
1291pub unsafe fn di_row_scalar_precomputed(
1292    up: &[f64],
1293    dn: &[f64],
1294    tr: &[f64],
1295    period: usize,
1296    first: usize,
1297    out_plus: &mut [f64],
1298    out_minus: &mut [f64],
1299) -> Result<(), DiError> {
1300    let n = up.len();
1301    if n == 0 {
1302        return Ok(());
1303    }
1304
1305    let pf = period as f64;
1306    let invp = pf.recip();
1307    let keep = 1.0 - invp;
1308
1309    let start = first + 1;
1310    let stop = first + period;
1311    let mut plus_dm_sum = 0.0;
1312    let mut minus_dm_sum = 0.0;
1313    let mut tr_sum = 0.0;
1314
1315    let mut i = start;
1316    while i < stop {
1317        plus_dm_sum += up[i];
1318        minus_dm_sum += dn[i];
1319        tr_sum += tr[i];
1320        i += 1;
1321    }
1322
1323    let mut cur_plus = plus_dm_sum;
1324    let mut cur_minus = minus_dm_sum;
1325    let mut cur_tr = tr_sum;
1326
1327    let mut idx = stop - 1;
1328    let mut scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1329    out_plus[idx] = cur_plus * scale;
1330    out_minus[idx] = cur_minus * scale;
1331    idx += 1;
1332
1333    while idx < n {
1334        cur_plus = cur_plus.mul_add(keep, up[idx]);
1335        cur_minus = cur_minus.mul_add(keep, dn[idx]);
1336        cur_tr = cur_tr.mul_add(keep, tr[idx]);
1337        scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1338        out_plus[idx] = cur_plus * scale;
1339        out_minus[idx] = cur_minus * scale;
1340        idx += 1;
1341    }
1342    Ok(())
1343}
1344
/// AVX2 entry point for a single DI row.
///
/// There is no dedicated AVX2 implementation: this forwards unchanged to
/// [`di_row_scalar`], so results are identical to the scalar kernel.
///
/// # Safety
/// Same contract as [`di_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) -> Result<(), DiError> {
    // SAFETY: the caller upholds `di_row_scalar`'s contract, which is forwarded as-is.
    unsafe { di_row_scalar(high, low, close, period, first, out_plus, out_minus) }
}
1358
/// AVX-512 entry point for a single DI row.
///
/// There is no dedicated AVX-512 implementation: this forwards unchanged to
/// [`di_row_scalar`], so results are identical to the scalar kernel.
///
/// # Safety
/// Same contract as [`di_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) -> Result<(), DiError> {
    // SAFETY: the caller upholds `di_row_scalar`'s contract, which is forwarded as-is.
    unsafe { di_row_scalar(high, low, close, period, first, out_plus, out_minus) }
}
1372
/// AVX-512 entry point for short periods.
///
/// Delegates to [`di_row_avx512`], which currently uses the scalar kernel, so there is no
/// separate short-period specialization yet.
///
/// # Safety
/// Same contract as [`di_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) -> Result<(), DiError> {
    // SAFETY: the caller upholds `di_row_avx512`'s contract, which is forwarded as-is.
    unsafe { di_row_avx512(high, low, close, period, first, out_plus, out_minus) }
}
1386
/// AVX-512 entry point for long periods.
///
/// Delegates to [`di_row_avx512`], which currently uses the scalar kernel, so there is no
/// separate long-period specialization yet.
///
/// # Safety
/// Same contract as [`di_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) -> Result<(), DiError> {
    // SAFETY: the caller upholds `di_row_avx512`'s contract, which is forwarded as-is.
    unsafe { di_row_avx512(high, low, close, period, first, out_plus, out_minus) }
}
1400
/// True range of a bar: the largest of `high - low`, `|high - prev_close|` and
/// `|low - prev_close|`.
///
/// Candidates are compared with `>` in that order, so a NaN `high - low` is never replaced,
/// and a NaN gap candidate never replaces the current best.
#[inline(always)]
fn true_range(current_high: f64, current_low: f64, prev_close: f64) -> f64 {
    let gaps = [
        (current_high - prev_close).abs(),
        (current_low - prev_close).abs(),
    ];
    gaps.iter()
        .fold(current_high - current_low, |best, &gap| {
            if gap > best {
                gap
            } else {
                best
            }
        })
}
1414
1415#[cfg(test)]
1416mod tests {
1417    use super::*;
1418    use crate::skip_if_unsupported;
1419    use crate::utilities::data_loader::read_candles_from_csv;
1420
1421    fn check_di_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1422        skip_if_unsupported!(kernel, test_name);
1423        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1424        let candles = read_candles_from_csv(file_path)?;
1425        let default_params = DiParams { period: None };
1426        let input_default = DiInput::from_candles(&candles, default_params);
1427        let output_default = di_with_kernel(&input_default, kernel)?;
1428        assert_eq!(output_default.plus.len(), candles.close.len());
1429        assert_eq!(output_default.minus.len(), candles.close.len());
1430        let params_period_10 = DiParams { period: Some(10) };
1431        let input_period_10 = DiInput::from_candles(&candles, params_period_10);
1432        let output_period_10 = di_with_kernel(&input_period_10, kernel)?;
1433        assert_eq!(output_period_10.plus.len(), candles.close.len());
1434        assert_eq!(output_period_10.minus.len(), candles.close.len());
1435        Ok(())
1436    }
1437    fn check_di_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1438        skip_if_unsupported!(kernel, test_name);
1439        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1440        let candles = read_candles_from_csv(file_path)?;
1441        let params = DiParams { period: Some(14) };
1442        let input = DiInput::from_candles(&candles, params);
1443        let di_result = di_with_kernel(&input, kernel)?;
1444        assert_eq!(di_result.plus.len(), candles.close.len());
1445        assert_eq!(di_result.minus.len(), candles.close.len());
1446        let test_plus = [
1447            10.99067007335658,
1448            11.306993269828585,
1449            10.948661818939213,
1450            10.683207768215592,
1451            9.802180952619183,
1452        ];
1453        let test_minus = [
1454            28.06728094177839,
1455            27.331240567633152,
1456            27.759989125359493,
1457            26.951434842917386,
1458            30.748897303623057,
1459        ];
1460        if di_result.plus.len() > 5 {
1461            let plus_tail = &di_result.plus[di_result.plus.len() - 5..];
1462            let minus_tail = &di_result.minus[di_result.minus.len() - 5..];
1463            for i in 0..5 {
1464                assert!(
1465                    (plus_tail[i] - test_plus[i]).abs() < 1e-6,
1466                    "Mismatch in +DI at tail index {}: expected {}, got {}",
1467                    i,
1468                    test_plus[i],
1469                    plus_tail[i]
1470                );
1471                assert!(
1472                    (minus_tail[i] - test_minus[i]).abs() < 1e-6,
1473                    "Mismatch in -DI at tail index {}: expected {}, got {}",
1474                    i,
1475                    test_minus[i],
1476                    minus_tail[i]
1477                );
1478            }
1479        }
1480        Ok(())
1481    }
1482    fn check_di_with_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1483        skip_if_unsupported!(kernel, test_name);
1484        let high = [10.0, 11.0, 12.0];
1485        let low = [9.0, 8.0, 7.0];
1486        let close = [9.5, 10.0, 11.0];
1487        let params = DiParams { period: Some(0) };
1488        let input = DiInput::from_slices(&high, &low, &close, params);
1489        let result = di_with_kernel(&input, kernel);
1490        assert!(result.is_err());
1491        Ok(())
1492    }
1493    fn check_di_with_period_exceeding_data_length(
1494        test_name: &str,
1495        kernel: Kernel,
1496    ) -> Result<(), Box<dyn Error>> {
1497        skip_if_unsupported!(kernel, test_name);
1498        let high = [10.0, 11.0, 12.0];
1499        let low = [9.0, 8.0, 7.0];
1500        let close = [9.5, 10.0, 11.0];
1501        let params = DiParams { period: Some(10) };
1502        let input = DiInput::from_slices(&high, &low, &close, params);
1503        let result = di_with_kernel(&input, kernel);
1504        assert!(result.is_err());
1505        Ok(())
1506    }
1507    fn check_di_very_small_data_set(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1508        skip_if_unsupported!(kernel, test_name);
1509        let high = [42.0];
1510        let low = [41.0];
1511        let close = [41.5];
1512        let params = DiParams { period: Some(14) };
1513        let input = DiInput::from_slices(&high, &low, &close, params);
1514        let result = di_with_kernel(&input, kernel);
1515        assert!(result.is_err());
1516        Ok(())
1517    }
1518    fn check_di_with_slice_data_reinput(
1519        test_name: &str,
1520        kernel: Kernel,
1521    ) -> Result<(), Box<dyn Error>> {
1522        skip_if_unsupported!(kernel, test_name);
1523        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1524        let candles = read_candles_from_csv(file_path)?;
1525        let first_params = DiParams { period: Some(14) };
1526        let first_input = DiInput::from_candles(&candles, first_params);
1527        let first_result = di_with_kernel(&first_input, kernel)?;
1528        assert_eq!(first_result.plus.len(), candles.close.len());
1529        assert_eq!(first_result.minus.len(), candles.close.len());
1530        let second_params = DiParams { period: Some(14) };
1531        let second_input = DiInput::from_slices(
1532            &first_result.plus,
1533            &first_result.minus,
1534            &candles.close,
1535            second_params,
1536        );
1537        let second_result = di_with_kernel(&second_input, kernel)?;
1538        assert_eq!(second_result.plus.len(), first_result.plus.len());
1539        assert_eq!(second_result.minus.len(), first_result.minus.len());
1540        Ok(())
1541    }
1542    fn check_di_accuracy_nan_check(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1543        skip_if_unsupported!(kernel, test_name);
1544        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1545        let candles = read_candles_from_csv(file_path)?;
1546        let params = DiParams { period: Some(14) };
1547        let input = DiInput::from_candles(&candles, params);
1548        let di_result = di_with_kernel(&input, kernel)?;
1549        assert_eq!(di_result.plus.len(), candles.close.len());
1550        assert_eq!(di_result.minus.len(), candles.close.len());
1551        if di_result.plus.len() > 40 {
1552            for i in 40..di_result.plus.len() {
1553                assert!(!di_result.plus[i].is_nan());
1554                assert!(!di_result.minus[i].is_nan());
1555            }
1556        }
1557        Ok(())
1558    }
1559    #[cfg(debug_assertions)]
1560    fn check_di_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1561        skip_if_unsupported!(kernel, test_name);
1562
1563        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1564        let candles = read_candles_from_csv(file_path)?;
1565
1566        let test_params = vec![
1567            DiParams::default(),
1568            DiParams { period: Some(2) },
1569            DiParams { period: Some(5) },
1570            DiParams { period: Some(7) },
1571            DiParams { period: Some(10) },
1572            DiParams { period: Some(20) },
1573            DiParams { period: Some(30) },
1574            DiParams { period: Some(50) },
1575            DiParams { period: Some(100) },
1576            DiParams { period: Some(200) },
1577        ];
1578
1579        for (param_idx, params) in test_params.iter().enumerate() {
1580            let input = DiInput::from_candles(&candles, params.clone());
1581            let output = di_with_kernel(&input, kernel)?;
1582
1583            for (i, &val) in output.plus.iter().enumerate() {
1584                if val.is_nan() {
1585                    continue;
1586                }
1587
1588                let bits = val.to_bits();
1589
1590                if bits == 0x11111111_11111111 {
1591                    panic!(
1592                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1593						 in plus output with params: {:?} (param set {})",
1594                        test_name, val, bits, i, params, param_idx
1595                    );
1596                }
1597
1598                if bits == 0x22222222_22222222 {
1599                    panic!(
1600                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1601						 in plus output with params: {:?} (param set {})",
1602                        test_name, val, bits, i, params, param_idx
1603                    );
1604                }
1605
1606                if bits == 0x33333333_33333333 {
1607                    panic!(
1608                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1609						 in plus output with params: {:?} (param set {})",
1610                        test_name, val, bits, i, params, param_idx
1611                    );
1612                }
1613            }
1614
1615            for (i, &val) in output.minus.iter().enumerate() {
1616                if val.is_nan() {
1617                    continue;
1618                }
1619
1620                let bits = val.to_bits();
1621
1622                if bits == 0x11111111_11111111 {
1623                    panic!(
1624                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1625						 in minus output with params: {:?} (param set {})",
1626                        test_name, val, bits, i, params, param_idx
1627                    );
1628                }
1629
1630                if bits == 0x22222222_22222222 {
1631                    panic!(
1632                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1633						 in minus output with params: {:?} (param set {})",
1634                        test_name, val, bits, i, params, param_idx
1635                    );
1636                }
1637
1638                if bits == 0x33333333_33333333 {
1639                    panic!(
1640                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1641						 in minus output with params: {:?} (param set {})",
1642                        test_name, val, bits, i, params, param_idx
1643                    );
1644                }
1645            }
1646        }
1647
1648        Ok(())
1649    }
1650
1651    #[cfg(not(debug_assertions))]
1652    fn check_di_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1653        Ok(())
1654    }
1655
    /// Property test over synthetic OHLC series.
    ///
    /// For periods in `2..=50` it checks that the `kernel` output:
    /// * has the input length and is NaN for the first `period - 1` bars;
    /// * is finite and within `[0, 100]` afterwards;
    /// * matches `Kernel::Scalar` to within 1e-9 (NaN positions must agree).
    ///
    /// It also runs fixed scenarios: constant prices give values below 1.0, a volatility
    /// spike keeps values in range, and (for periods <= 20) a steady uptrend has
    /// +DI > -DI on most bars of its second half.
    #[cfg(test)]
    #[allow(clippy::float_cmp)]
    fn check_di_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: (base price, length >= period + 20, volatility, per-bar trend, period).
        let strat = (2usize..=50)
            .prop_flat_map(|period| {
                (
                    100.0f64..5000.0f64,
                    (period + 20)..400,
                    0.001f64..0.05f64,
                    -0.01f64..0.01f64,
                    Just(period),
                )
            })
            // Deterministic pseudo-random walk driven by index arithmetic; no NaNs are
            // produced, so the first valid index is 0.
            .prop_map(|(base_price, data_len, volatility, trend, period)| {
                let mut high = Vec::with_capacity(data_len);
                let mut low = Vec::with_capacity(data_len);
                let mut close = Vec::with_capacity(data_len);

                let mut price = base_price;

                for i in 0..data_len {
                    // The trend term grows with the bar index and compounds into `price`.
                    let trend_component = trend * i as f64;
                    let random_component = ((i * 137 + 11) % 100) as f64 / 100.0 - 0.5;
                    price = price * (1.0 + trend_component + random_component * volatility);

                    // Clamp so prices never drop below 1.0.
                    price = price.max(1.0);

                    let daily_range = price * volatility * (1.0 + ((i * 73) % 50) as f64 / 100.0);
                    let h = price + daily_range * 0.5;
                    let l = price - daily_range * 0.5;

                    // Close sits between low and high (factor in [0.0, 0.99]).
                    let close_factor = ((i * 29 + 7) % 100) as f64 / 100.0;
                    let c = l + (h - l) * close_factor;

                    high.push(h);
                    low.push(l);
                    close.push(c);
                }

                (high, low, close, period)
            });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, close, period)| {
                let params = DiParams {
                    period: Some(period),
                };
                let input = DiInput::from_slices(&high, &low, &close, params);

                let output = di_with_kernel(&input, kernel)?;

                // Scalar kernel serves as the reference implementation.
                let ref_output = di_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(output.plus.len(), high.len());
                prop_assert_eq!(output.minus.len(), high.len());

                // With first == 0 (no NaNs in the generated data) the first value is at
                // `period - 1`; everything before it must be NaN.
                let warmup_end = period - 1;
                for i in 0..warmup_end {
                    prop_assert!(
                        output.plus[i].is_nan(),
                        "Expected NaN at index {} during warmup, got {}",
                        i,
                        output.plus[i]
                    );
                    prop_assert!(
                        output.minus[i].is_nan(),
                        "Expected NaN at index {} during warmup, got {}",
                        i,
                        output.minus[i]
                    );
                }

                // Post-warmup values must be finite percentages.
                for i in warmup_end..high.len() {
                    let plus_val = output.plus[i];
                    let minus_val = output.minus[i];

                    prop_assert!(
                        plus_val.is_finite(),
                        "Expected finite +DI at index {}, got {}",
                        i,
                        plus_val
                    );
                    prop_assert!(
                        minus_val.is_finite(),
                        "Expected finite -DI at index {}, got {}",
                        i,
                        minus_val
                    );

                    prop_assert!(
                        plus_val >= 0.0 && plus_val <= 100.0,
                        "+DI at index {} = {} is out of range [0, 100]",
                        i,
                        plus_val
                    );
                    prop_assert!(
                        minus_val >= 0.0 && minus_val <= 100.0,
                        "-DI at index {} = {} is out of range [0, 100]",
                        i,
                        minus_val
                    );
                }

                // Kernel-vs-scalar agreement: NaN positions must coincide, values within 1e-9.
                for i in 0..high.len() {
                    let plus_val = output.plus[i];
                    let minus_val = output.minus[i];
                    let ref_plus = ref_output.plus[i];
                    let ref_minus = ref_output.minus[i];

                    if plus_val.is_nan() || ref_plus.is_nan() {
                        prop_assert_eq!(
                            plus_val.is_nan(),
                            ref_plus.is_nan(),
                            "NaN mismatch in +DI at index {}",
                            i
                        );
                    } else {
                        prop_assert!(
                            (plus_val - ref_plus).abs() <= 1e-9,
                            "+DI mismatch at index {}: {} vs {} (diff: {})",
                            i,
                            plus_val,
                            ref_plus,
                            (plus_val - ref_plus).abs()
                        );
                    }

                    if minus_val.is_nan() || ref_minus.is_nan() {
                        prop_assert_eq!(
                            minus_val.is_nan(),
                            ref_minus.is_nan(),
                            "NaN mismatch in -DI at index {}",
                            i
                        );
                    } else {
                        prop_assert!(
                            (minus_val - ref_minus).abs() <= 1e-9,
                            "-DI mismatch at index {}: {} vs {} (diff: {})",
                            i,
                            minus_val,
                            ref_minus,
                            (minus_val - ref_minus).abs()
                        );
                    }
                }

                // Constant prices: no directional movement and zero true range. Checked only
                // if the call succeeds; for period >= 45 the 50-bar range below is empty.
                let constant_high = vec![100.0; 50];
                let constant_low = vec![100.0; 50];
                let constant_close = vec![100.0; 50];
                let const_params = DiParams {
                    period: Some(period),
                };
                let const_input = DiInput::from_slices(
                    &constant_high,
                    &constant_low,
                    &constant_close,
                    const_params,
                );

                if let Ok(const_output) = di_with_kernel(&const_input, kernel) {
                    for i in (period + 5)..constant_high.len() {
                        prop_assert!(
                            const_output.plus[i] < 1.0,
                            "Expected near-zero +DI for constant prices, got {} at index {}",
                            const_output.plus[i],
                            i
                        );
                        prop_assert!(
                            const_output.minus[i] < 1.0,
                            "Expected near-zero -DI for constant prices, got {} at index {}",
                            const_output.minus[i],
                            i
                        );
                    }
                }

                // Single wide bar at index 15 on an otherwise flat 30-bar series; every
                // non-NaN output must stay within [0, 100]. Period is capped at 10.
                let mut spike_high = vec![100.0; 30];
                let mut spike_low = vec![99.0; 30];
                let mut spike_close = vec![99.5; 30];

                spike_high[15] = 120.0;
                spike_low[15] = 80.0;
                spike_close[15] = 100.0;

                let spike_params = DiParams {
                    period: Some(period.min(10)),
                };
                let spike_input =
                    DiInput::from_slices(&spike_high, &spike_low, &spike_close, spike_params);

                if let Ok(spike_output) = di_with_kernel(&spike_input, kernel) {
                    for (i, (&plus, &minus)) in spike_output
                        .plus
                        .iter()
                        .zip(spike_output.minus.iter())
                        .enumerate()
                    {
                        if !plus.is_nan() {
                            prop_assert!(
                                plus >= 0.0 && plus <= 100.0,
                                "Volatility spike caused +DI out of range at {}: {}",
                                i,
                                plus
                            );
                        }
                        if !minus.is_nan() {
                            prop_assert!(
                                minus >= 0.0 && minus <= 100.0,
                                "Volatility spike caused -DI out of range at {}: {}",
                                i,
                                minus
                            );
                        }
                    }
                }

                // Steady uptrend (each bar 2.0 higher, constant 2.0 range): over the second
                // half, +DI should beat -DI on more bars than not.
                if period <= 20 {
                    let trend_len = 200;
                    let mut uptrend_high = Vec::with_capacity(trend_len);
                    let mut uptrend_low = Vec::with_capacity(trend_len);
                    let mut uptrend_close = Vec::with_capacity(trend_len);

                    for i in 0..trend_len {
                        let price = 100.0 + i as f64 * 2.0;
                        uptrend_high.push(price + 1.0);
                        uptrend_low.push(price - 1.0);
                        uptrend_close.push(price);
                    }

                    let trend_params = DiParams {
                        period: Some(period),
                    };
                    let trend_input = DiInput::from_slices(
                        &uptrend_high,
                        &uptrend_low,
                        &uptrend_close,
                        trend_params,
                    );

                    if let Ok(trend_output) = di_with_kernel(&trend_input, kernel) {
                        let check_start = trend_len / 2;
                        let mut plus_wins = 0;
                        let mut minus_wins = 0;

                        for i in check_start..trend_len {
                            if trend_output.plus[i] > trend_output.minus[i] {
                                plus_wins += 1;
                            } else {
                                minus_wins += 1;
                            }
                        }

                        if plus_wins + minus_wins > 0 {
                            prop_assert!(
								plus_wins > minus_wins,
								"In uptrend, expected +DI > -DI more often. +DI wins: {}, -DI wins: {} (period: {})",
								plus_wins, minus_wins, period
							);
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1930
1931    #[test]
1932    fn test_di_into_matches_api() -> Result<(), Box<dyn Error>> {
1933        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1934        let candles = read_candles_from_csv(file_path)?;
1935
1936        let params = DiParams::default();
1937        let input = DiInput::from_candles(&candles, params);
1938
1939        let baseline = di(&input)?;
1940
1941        let n = candles.close.len();
1942        let mut plus = vec![0.0_f64; n];
1943        let mut minus = vec![0.0_f64; n];
1944        di_into(&input, &mut plus, &mut minus)?;
1945
1946        assert_eq!(baseline.plus.len(), plus.len());
1947        assert_eq!(baseline.minus.len(), minus.len());
1948
1949        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1950            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1951        }
1952
1953        for i in 0..n {
1954            assert!(
1955                eq_or_both_nan(baseline.plus[i], plus[i]),
1956                "+DI mismatch at index {}: baseline={} vs into={}",
1957                i,
1958                baseline.plus[i],
1959                plus[i]
1960            );
1961            assert!(
1962                eq_or_both_nan(baseline.minus[i], minus[i]),
1963                "-DI mismatch at index {}: baseline={} vs into={}",
1964                i,
1965                baseline.minus[i],
1966                minus[i]
1967            );
1968        }
1969        Ok(())
1970    }
1971
1972    macro_rules! generate_all_di_tests {
1973        ($($test_fn:ident),*) => {
1974            paste::paste! {
1975                $(
1976                    #[test]
1977                    fn [<$test_fn _scalar_f64>]() {
1978                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1979                    }
1980                )*
1981                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1982                $(
1983                    #[test]
1984                    fn [<$test_fn _avx2_f64>]() {
1985                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1986                    }
1987                    #[test]
1988                    fn [<$test_fn _avx512_f64>]() {
1989                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1990                    }
1991                )*
1992            }
1993        }
1994    }
    // Single-series checks. Each name is expanded by `generate_all_di_tests!` into
    // per-kernel `#[test]` functions (`_scalar_f64`, `_avx2_f64`, `_avx512_f64`).
    generate_all_di_tests!(
        check_di_partial_params,
        check_di_accuracy,
        check_di_with_zero_period,
        check_di_with_period_exceeding_data_length,
        check_di_very_small_data_set,
        check_di_with_slice_data_reinput,
        check_di_accuracy_nan_check,
        check_di_no_poison
    );

    // Property-based check, registered through its own invocation of the same macro.
    #[cfg(test)]
    generate_all_di_tests!(check_di_property);
2008    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2009        skip_if_unsupported!(kernel, test);
2010
2011        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2012        let c = read_candles_from_csv(file)?;
2013
2014        let output = DiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2015
2016        let def = DiParams::default();
2017        let row = output.plus_for(&def).expect("default row missing");
2018
2019        assert_eq!(row.len(), c.close.len());
2020
2021        let scalar = DiInput::from_candles(&c, DiParams::default());
2022        let scalar_out = di_with_kernel(&scalar, Kernel::Scalar)?;
2023        let plus_tail = &row[row.len() - 5..];
2024        let scalar_tail = &scalar_out.plus[scalar_out.plus.len() - 5..];
2025        for (i, (&a, &b)) in plus_tail.iter().zip(scalar_tail.iter()).enumerate() {
2026            assert!(
2027                (a - b).abs() < 1e-8,
2028                "[{test}] batch/scalar plus mismatch idx={i}: {a} vs {b}"
2029            );
2030        }
2031        Ok(())
2032    }
2033
2034    fn check_batch_period_range(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2035        skip_if_unsupported!(kernel, test);
2036
2037        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2038        let c = read_candles_from_csv(file)?;
2039
2040        let output = DiBatchBuilder::new()
2041            .kernel(kernel)
2042            .period_range(10, 20, 5)
2043            .apply_candles(&c)?;
2044
2045        assert_eq!(output.rows, 3);
2046        assert_eq!(output.cols, c.close.len());
2047
2048        let periods = [10, 15, 20];
2049        for (i, p) in periods.iter().enumerate() {
2050            let param = DiParams { period: Some(*p) };
2051            let plus = output.plus_for(&param).expect("plus missing");
2052            let minus = output.minus_for(&param).expect("minus missing");
2053            assert_eq!(plus.len(), c.close.len());
2054            assert_eq!(minus.len(), c.close.len());
2055        }
2056        Ok(())
2057    }
2058
2059    #[cfg(debug_assertions)]
2060    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2061        skip_if_unsupported!(kernel, test);
2062
2063        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2064        let c = read_candles_from_csv(file)?;
2065
2066        let test_configs = vec![
2067            (2, 10, 2),
2068            (5, 25, 5),
2069            (30, 60, 15),
2070            (2, 5, 1),
2071            (10, 10, 0),
2072            (14, 14, 0),
2073            (50, 50, 0),
2074            (7, 21, 7),
2075        ];
2076
2077        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2078            let output = DiBatchBuilder::new()
2079                .kernel(kernel)
2080                .period_range(p_start, p_end, p_step)
2081                .apply_candles(&c)?;
2082
2083            for (idx, &val) in output.plus.iter().enumerate() {
2084                if val.is_nan() {
2085                    continue;
2086                }
2087
2088                let bits = val.to_bits();
2089                let row = idx / output.cols;
2090                let col = idx % output.cols;
2091                let combo = &output.combos[row];
2092
2093                if bits == 0x11111111_11111111 {
2094                    panic!(
2095                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2096						at row {} col {} (flat index {}) in plus output with params: period={}",
2097                        test,
2098                        cfg_idx,
2099                        val,
2100                        bits,
2101                        row,
2102                        col,
2103                        idx,
2104                        combo.period.unwrap_or(14)
2105                    );
2106                }
2107
2108                if bits == 0x22222222_22222222 {
2109                    panic!(
2110                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2111						at row {} col {} (flat index {}) in plus output with params: period={}",
2112                        test,
2113                        cfg_idx,
2114                        val,
2115                        bits,
2116                        row,
2117                        col,
2118                        idx,
2119                        combo.period.unwrap_or(14)
2120                    );
2121                }
2122
2123                if bits == 0x33333333_33333333 {
2124                    panic!(
2125                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2126						at row {} col {} (flat index {}) in plus output with params: period={}",
2127                        test,
2128                        cfg_idx,
2129                        val,
2130                        bits,
2131                        row,
2132                        col,
2133                        idx,
2134                        combo.period.unwrap_or(14)
2135                    );
2136                }
2137            }
2138
2139            for (idx, &val) in output.minus.iter().enumerate() {
2140                if val.is_nan() {
2141                    continue;
2142                }
2143
2144                let bits = val.to_bits();
2145                let row = idx / output.cols;
2146                let col = idx % output.cols;
2147                let combo = &output.combos[row];
2148
2149                if bits == 0x11111111_11111111 {
2150                    panic!(
2151                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2152						at row {} col {} (flat index {}) in minus output with params: period={}",
2153                        test,
2154                        cfg_idx,
2155                        val,
2156                        bits,
2157                        row,
2158                        col,
2159                        idx,
2160                        combo.period.unwrap_or(14)
2161                    );
2162                }
2163
2164                if bits == 0x22222222_22222222 {
2165                    panic!(
2166                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2167						at row {} col {} (flat index {}) in minus output with params: period={}",
2168                        test,
2169                        cfg_idx,
2170                        val,
2171                        bits,
2172                        row,
2173                        col,
2174                        idx,
2175                        combo.period.unwrap_or(14)
2176                    );
2177                }
2178
2179                if bits == 0x33333333_33333333 {
2180                    panic!(
2181                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2182						at row {} col {} (flat index {}) in minus output with params: period={}",
2183                        test,
2184                        cfg_idx,
2185                        val,
2186                        bits,
2187                        row,
2188                        col,
2189                        idx,
2190                        combo.period.unwrap_or(14)
2191                    );
2192                }
2193            }
2194        }
2195
2196        Ok(())
2197    }
2198
2199    #[cfg(not(debug_assertions))]
2200    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2201        Ok(())
2202    }
2203
2204    macro_rules! gen_batch_tests {
2205        ($fn_name:ident) => {
2206            paste::paste! {
2207                #[test] fn [<$fn_name _scalar>]()      {
2208                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2209                }
2210                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2211                #[test] fn [<$fn_name _avx2>]()        {
2212                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2213                }
2214                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2215                #[test] fn [<$fn_name _avx512>]()      {
2216                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2217                }
2218                #[test] fn [<$fn_name _auto_detect>]() {
2219                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2220                }
2221            }
2222        };
2223    }
2224
2225    gen_batch_tests!(check_batch_default_row);
2226    gen_batch_tests!(check_batch_period_range);
2227    gen_batch_tests!(check_batch_no_poison);
2228}
2229
/// Python binding: computes +DI and -DI for a single `period`.
///
/// `kernel` optionally names the compute kernel; it is checked by `validate_kernel`.
/// The computation runs with the GIL released. Returns the tuple `(plus, minus)` of
/// NumPy arrays, and any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "di")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn di_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let chosen = validate_kernel(kernel, false)?;

    let input = DiInput::from_slices(h, l, c, DiParams { period: Some(period) });

    let DiOutput { plus, minus } = py
        .allow_threads(|| di_with_kernel(&input, chosen))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((plus.into_pyarray(py), minus.into_pyarray(py)))
}
2257
/// Python wrapper around the incremental `DiStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "DiStream")]
pub struct DiStreamPy {
    inner: DiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DiStreamPy {
    /// Builds a stream for `period`; a rejected parameter is raised as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        DiStream::try_new(DiParams { period: Some(period) })
            .map(|inner| DiStreamPy { inner })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns whatever `DiStream::update` yields: `(plus, minus)`,
    /// or `None`. `None` presumably means the stream is still warming up; confirm.
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let stream = &mut self.inner;
        stream.update(high, low, close)
    }
}
2280
/// Python binding for a period sweep of +DI / -DI.
///
/// Returns a dict with:
/// * `plus` and `minus`: `(rows, cols)` NumPy matrices, one row per period produced by
///   `expand_grid` over `period_range`;
/// * `periods`: a `u64` array of those periods.
///
/// Raises `ValueError` when `high`, `low` and `close` differ in length, when every bar
/// contains a NaN, when `rows * cols` overflows, or when the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "di_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn di_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let cols = high_slice.len();
    // The leading-NaN scan below indexes `low` and `close` with `high`'s indices;
    // reject mismatched lengths here instead of panicking on an out-of-bounds read.
    if low_slice.len() != cols || close_slice.len() != cols {
        return Err(PyValueError::new_err(
            "di: high, low, and close must have the same length",
        ));
    }

    let sweep = DiBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("di: size overflow in rows*cols"))?;

    // Uninitialised output buffers. The NaN warm-up prefixes are written below; the
    // rest of each row is expected to be filled by `di_batch_inner_into`.
    let out_plus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_minus = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // SAFETY: both arrays were just created and are not yet visible to any other code.
    let pl = unsafe { out_plus.as_slice_mut()? };
    let mi = unsafe { out_minus.as_slice_mut()? };

    let first = (0..cols)
        .find(|&i| !(high_slice[i].is_nan() || low_slice[i].is_nan() || close_slice[i].is_nan()))
        .ok_or_else(|| PyValueError::new_err("di: All values are NaN"))?;

    // Warm-up length per row, clamped to the row width. Without the clamp, a period of 0
    // underflows `first + period - 1`, and a period longer than the data asks for a
    // prefix past the end of the row. Invalid periods are still reported as errors by
    // `di_batch_inner_into` below.
    let warm: Vec<usize> = combos
        .iter()
        .map(|p| (first + p.period.unwrap()).saturating_sub(1).min(cols))
        .collect();

    // SAFETY: `pl`/`mi` are valid for `total` f64 elements, and `MaybeUninit<f64>` has
    // the same layout as `f64`. The reinterpreted views are only used inside this block.
    unsafe {
        let plus_mu = std::slice::from_raw_parts_mut(
            pl.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        );
        let minus_mu = std::slice::from_raw_parts_mut(
            mi.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        );
        init_matrix_prefixes(plus_mu, cols, &warm);
        init_matrix_prefixes(minus_mu, cols, &warm);
    }

    let combos = py
        .allow_threads(|| {
            // Map the batch kernel selection onto the per-row SIMD kernel.
            let simd = match kern {
                Kernel::Auto => match detect_best_batch_kernel() {
                    Kernel::Avx512Batch => Kernel::Avx512,
                    Kernel::Avx2Batch => Kernel::Avx2,
                    _ => Kernel::Scalar,
                },
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                k => k,
            };

            di_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                simd,
                true,
                pl,
                mi,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("plus", out_plus.reshape((rows, cols))?)?;
    dict.set_item("minus", out_minus.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2376
2377#[cfg(all(feature = "python", feature = "cuda"))]
2378use crate::cuda::cuda_available;
2379#[cfg(all(feature = "python", feature = "cuda"))]
2380use crate::cuda::moving_averages::DeviceArrayF32;
2381#[cfg(all(feature = "python", feature = "cuda"))]
2382use crate::cuda::CudaDi;
2383#[cfg(all(feature = "python", feature = "cuda"))]
2384use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2385#[cfg(all(feature = "python", feature = "cuda"))]
2386use cust::context::Context;
2387#[cfg(all(feature = "python", feature = "cuda"))]
2388use cust::memory::DeviceBuffer;
2389#[cfg(all(feature = "python", feature = "cuda"))]
2390use pyo3::prelude::*;
2391#[cfg(all(feature = "python", feature = "cuda"))]
2392use std::os::raw::c_void;
2393#[cfg(all(feature = "python", feature = "cuda"))]
2394use std::sync::Arc;
2395
/// Python handle for an `f32` matrix living on a CUDA device, as returned by the DI
/// CUDA entry points. It is exposed through `__cuda_array_interface__` and DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32DiPy {
    pub(crate) inner: DeviceArrayF32,
    // Stored next to the buffer and never read here; presumably it keeps the CUDA
    // context alive while the allocation is reachable from Python. Confirm.
    pub(crate) ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32DiPy {
    /// CUDA Array Interface (version 3) description of the buffer.
    ///
    /// The buffer is a row-major `(rows, cols)` little-endian `f32` matrix. The `data`
    /// entry is the device pointer paired with a read-only flag of `False`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack `(device_type, device_id)` pair. `2` is presumably DLPack's `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the device
    /// allocation.
    ///
    /// A parseable `dl_device` that differs from this buffer's device is rejected with
    /// `ValueError`; cross-device copies are not implemented. `stream` is accepted but
    /// unused. After export this object holds an empty 0 x 0 placeholder.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // Requests that cannot be parsed as an `(i32, i32)` tuple are ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let copy_requested = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }
        let _ = stream;

        // Swap in an empty placeholder so `self` stays valid once the real buffer has
        // been handed to the capsule.
        let placeholder = DeviceArrayF32 {
            buf: DeviceBuffer::from_slice(&[])
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
        };
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2473
/// CUDA binding for a period sweep of +DI / -DI, with results left on the device.
///
/// Returns a dict with:
/// * `plus` and `minus`: `DeviceArrayF32DiPy` handles;
/// * `periods`: a `u64` array of the swept periods;
/// * `rows`: the number of periods;
/// * `cols`: `len(high_f32)`.
///
/// Raises `ValueError` when CUDA is unavailable or the device computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "di_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn di_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::IntoPyArray;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (high_f32.as_slice()?, low_f32.as_slice()?, close_f32.as_slice()?);
    let sweep = DiBatchRange {
        period: period_range,
    };

    // All device work happens with the GIL released.
    let (plus_dev, minus_dev, combos, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let computed = cuda
            .di_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((
            computed.0,
            computed.1,
            computed.2,
            cuda.context_arc(),
            cuda.device_id(),
        ))
    })?;

    let wrap = |inner: DeviceArrayF32, ctx: Arc<Context>| {
        Py::new(
            py,
            DeviceArrayF32DiPy {
                inner,
                ctx,
                device_id: dev_id,
            },
        )
    };

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("plus", wrap(plus_dev, ctx.clone())?)?;
    dict.set_item("minus", wrap(minus_dev, ctx)?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", h.len())?;
    Ok(dict)
}
2537
/// CUDA binding: +DI / -DI for many series sharing one `period`, kept on the device.
///
/// The three inputs must be 2-D `f32` arrays of identical shape `(rows, cols)` and are
/// read as contiguous slices by the time-major kernel. Judging by the name, rows are
/// time steps and columns are series; confirm.
///
/// Returns a dict with `plus` and `minus` device handles plus the `rows`, `cols` and
/// `period` that were used.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "di_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
pub fn di_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let shapes_agree = shape == low_tm_f32.shape() && low_tm_f32.shape() == close_tm_f32.shape();
    if !shapes_agree || shape.len() != 2 {
        return Err(PyValueError::new_err(
            "expected three 2D arrays of same shape",
        ));
    }
    let (rows, cols) = (shape[0], shape[1]);

    let (h, l, c) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // All device work happens with the GIL released.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let outputs = cuda
            .di_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((outputs, cuda.context_arc(), cuda.device_id()))
    })?;

    let wrap = |inner: DeviceArrayF32, ctx: Arc<Context>| {
        Py::new(
            py,
            DeviceArrayF32DiPy {
                inner,
                ctx,
                device_id: dev_id,
            },
        )
    };

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("plus", wrap(pair.plus, ctx.clone())?)?;
    dict.set_item("minus", wrap(pair.minus, ctx)?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    dict.set_item("period", period)?;
    Ok(dict)
}
2600
/// Serializable result handed to JavaScript by `di_js`.
///
/// As filled by `di_js`, `values` holds the +DI series followed by the -DI series,
/// forming a row-major `rows` x `cols` block with `rows = 2`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiJsResult {
    /// Flattened output rows, row-major.
    pub values: Vec<f64>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Length of each row (the input series length).
    pub cols: usize,
}
2608
/// WASM binding: computes +DI / -DI for one `period` with the auto-selected kernel.
///
/// Returns a serialized `DiJsResult` whose `values` hold the +DI series followed by
/// the -DI series (`rows = 2`, `cols = high.len()`). Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = di)]
pub fn di_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let input = DiInput::from_slices(high, low, close, DiParams { period: Some(period) });
    let DiOutput { plus, minus } = di_with_kernel(&input, crate::utilities::enums::Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = high.len();
    let capacity = cols
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("di_js: size overflow"))?;

    let mut values = Vec::with_capacity(capacity);
    values.extend(plus);
    values.extend(minus);

    serde_wasm_bindgen::to_value(&DiJsResult {
        values,
        rows: 2,
        cols,
    })
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
2632
/// Allocates an uninitialised buffer with room for `len` `f64` values and returns its
/// pointer for use from JavaScript.
///
/// The allocation is intentionally leaked; release it with `di_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_alloc(len: usize) -> *mut f64 {
    let mut storage: Vec<f64> = Vec::with_capacity(len);
    let raw = storage.as_mut_ptr();
    // Ownership of the allocation passes to the caller until `di_free`.
    std::mem::forget(storage);
    raw
}
2641
/// Releases a buffer obtained from `di_alloc` with the same `len`. Null is ignored.
///
/// The vector is rebuilt with length 0 and capacity `len`. Only the capacity matters
/// for freeing the allocation, whereas claiming `len` initialised elements (as
/// `from_raw_parts(ptr, len, len)` did) is undefined behaviour if the caller never
/// wrote the whole buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` came from `di_alloc(len)`, i.e. a leaked `Vec<f64>` with
        // capacity `len`. A length of 0 asserts nothing about initialisation, and
        // `f64` has no destructor, so dropping the vector only frees the memory.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
2651
/// WASM raw-pointer binding: computes +DI / -DI for one `period` into caller memory.
///
/// `high_ptr`, `low_ptr` and `close_ptr` must each point to `len` readable values.
/// `out_ptr` must point to `2 * len` writable values: the first half receives +DI and
/// the second half -DI.
///
/// Errors (returned as JS strings) cover null pointers, indicator failures, and
/// `2 * len` overflow.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_null = [high_ptr, low_ptr, close_ptr].iter().any(|p| p.is_null()) || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("null pointer to di_into"));
    }

    // SAFETY: the caller guarantees each input pointer addresses `len` valid f64s.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };

    let input = DiInput::from_slices(h, l, c, DiParams { period: Some(period) });
    let DiOutput { plus, minus } = di_with_kernel(&input, crate::utilities::enums::Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let total = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("di_into: size overflow"))?;

    // SAFETY: the caller guarantees `out_ptr` addresses `2 * len` writable f64s.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    let (plus_half, minus_half) = out.split_at_mut(len);
    plus_half.copy_from_slice(&plus);
    minus_half.copy_from_slice(&minus);
    Ok(())
}
2687
/// JS-supplied configuration for `di_batch`, deserialized from the `config` argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiBatchConfig {
    /// `(start, end, step)` of the period sweep; becomes `DiBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

/// Serializable batch result returned to JavaScript by `di_batch`.
///
/// As filled by `di_batch_unified_js`, `values` holds, for each period in turn, its
/// +DI row followed by its -DI row. `rows` is therefore twice the number of periods.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiBatchJsOutput {
    /// Flattened rows, row-major.
    pub values: Vec<f64>,
    /// Number of rows in `values` (two per period).
    pub rows: usize,
    /// Length of each row (the input series length).
    pub cols: usize,
    /// Period of each +DI/-DI row pair, in output order.
    pub periods: Vec<usize>,
}
2702
/// WASM binding for a period sweep of +DI / -DI, computed with the scalar kernel.
///
/// `config` must deserialize into `DiBatchConfig`. The result is a serialized
/// `DiBatchJsOutput` in which each period contributes its +DI row and then its -DI
/// row. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = di_batch)]
pub fn di_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: DiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = DiBatchRange {
        period: cfg.period_range,
    };
    let batch = di_batch_slice(
        high,
        low,
        close,
        &sweep,
        crate::utilities::enums::Kernel::Scalar,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = batch.cols;
    let rows = batch
        .rows
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("di_batch_unified_js: rows overflow"))?;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("di_batch_unified_js: size overflow"))?;

    // Interleave per period: the +DI row, then the matching -DI row.
    let mut values = Vec::with_capacity(total);
    for offset in (0..batch.rows).map(|r| r * cols) {
        let span = offset..offset + cols;
        values.extend_from_slice(&batch.plus[span.clone()]);
        values.extend_from_slice(&batch.minus[span]);
    }

    let periods = batch
        .combos
        .iter()
        .map(|p| p.period.unwrap())
        .collect::<Vec<_>>();

    serde_wasm_bindgen::to_value(&DiBatchJsOutput {
        values,
        rows,
        cols,
        periods,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2754
/// WASM raw-pointer binding for a period sweep of +DI / -DI into caller memory.
///
/// The input pointers must each address `len` values. `plus_ptr` and `minus_ptr` must
/// each address `rows * len` writable values, where `rows` is the number of periods
/// in `(period_start, period_end, period_step)`.
///
/// Returns `rows` on success; errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    let inputs = [high_ptr, low_ptr, close_ptr];
    if inputs.iter().any(|p| p.is_null()) || plus_ptr.is_null() || minus_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DiBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("di_batch_into: size overflow"))?;

    // SAFETY: the caller guarantees the input pointers address `len` readable f64s and
    // the output pointers address `rows * len` writable f64s each.
    let (high, low, close, out_plus, out_minus) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts_mut(plus_ptr, total_size),
            std::slice::from_raw_parts_mut(minus_ptr, total_size),
        )
    };

    di_batch_inner_into(
        high,
        low,
        close,
        &sweep,
        Kernel::Auto,
        false,
        out_plus,
        out_minus,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}