Skip to main content

vector_ta/indicators/
dx.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19#[cfg(feature = "python")]
20use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
21#[cfg(feature = "python")]
22use pyo3::exceptions::PyValueError;
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(feature = "python")]
26use pyo3::types::{PyDict, PyList};
27
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33#[derive(Debug, Clone)]
34pub enum DxData<'a> {
35    Candles {
36        candles: &'a Candles,
37    },
38    HlcSlices {
39        high: &'a [f64],
40        low: &'a [f64],
41        close: &'a [f64],
42    },
43}
44
/// Result of a single-series DX computation: one value per input bar.
#[derive(Clone, Debug)]
pub struct DxOutput {
    /// DX values aligned with the input bars (warm-up bars are NaN).
    pub values: Vec<f64>,
}
49
/// Tunable DX parameters.
///
/// `period` is the Wilder smoothing length; `None` falls back to 14 wherever
/// the parameters are consumed.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DxParams {
    /// Smoothing period; `None` means "use the default of 14".
    pub period: Option<usize>,
}
58
59impl Default for DxParams {
60    fn default() -> Self {
61        Self { period: Some(14) }
62    }
63}
64
65impl DxParams {
66    pub fn generate_batch_params(period_range: (usize, usize, usize)) -> Vec<Self> {
67        let (start, end, step) = period_range;
68        let step = if step == 0 { 1 } else { step };
69        let mut params = Vec::new();
70
71        let mut period = start;
72        while period <= end {
73            params.push(Self {
74                period: Some(period),
75            });
76            period += step;
77        }
78
79        params
80    }
81}
82
83#[derive(Debug, Clone)]
84pub struct DxInput<'a> {
85    pub data: DxData<'a>,
86    pub params: DxParams,
87}
88
89impl<'a> DxInput<'a> {
90    #[inline]
91    pub fn from_candles(candles: &'a Candles, params: DxParams) -> Self {
92        Self {
93            data: DxData::Candles { candles },
94            params,
95        }
96    }
97    #[inline]
98    pub fn from_hlc_slices(
99        high: &'a [f64],
100        low: &'a [f64],
101        close: &'a [f64],
102        params: DxParams,
103    ) -> Self {
104        Self {
105            data: DxData::HlcSlices { high, low, close },
106            params,
107        }
108    }
109    #[inline]
110    pub fn with_default_candles(candles: &'a Candles) -> Self {
111        Self {
112            data: DxData::Candles { candles },
113            params: DxParams::default(),
114        }
115    }
116    #[inline]
117    pub fn get_period(&self) -> usize {
118        self.params.period.unwrap_or(14)
119    }
120}
121
122#[derive(Copy, Clone, Debug)]
123pub struct DxBuilder {
124    period: Option<usize>,
125    kernel: Kernel,
126}
127
128impl Default for DxBuilder {
129    fn default() -> Self {
130        Self {
131            period: None,
132            kernel: Kernel::Auto,
133        }
134    }
135}
136
137impl DxBuilder {
138    #[inline(always)]
139    pub fn new() -> Self {
140        Self::default()
141    }
142    #[inline(always)]
143    pub fn period(mut self, n: usize) -> Self {
144        self.period = Some(n);
145        self
146    }
147    #[inline(always)]
148    pub fn kernel(mut self, k: Kernel) -> Self {
149        self.kernel = k;
150        self
151    }
152    #[inline(always)]
153    pub fn apply(self, c: &Candles) -> Result<DxOutput, DxError> {
154        let p = DxParams {
155            period: self.period,
156        };
157        let i = DxInput::from_candles(c, p);
158        dx_with_kernel(&i, self.kernel)
159    }
160    #[inline(always)]
161    pub fn apply_hlc(self, high: &[f64], low: &[f64], close: &[f64]) -> Result<DxOutput, DxError> {
162        let p = DxParams {
163            period: self.period,
164        };
165        let i = DxInput::from_hlc_slices(high, low, close, p);
166        dx_with_kernel(&i, self.kernel)
167    }
168    #[inline(always)]
169    pub fn into_stream(self) -> Result<DxStream, DxError> {
170        let p = DxParams {
171            period: self.period,
172        };
173        DxStream::try_new(p)
174    }
175}
176
177#[derive(Debug, Error)]
178pub enum DxError {
179    #[error("dx: Empty data provided for DX.")]
180    EmptyInputData,
181    #[error("dx: Could not select candle field: {0}")]
182    SelectCandleFieldError(String),
183    #[error("dx: Invalid period: period = {period}, data length = {data_len}")]
184    InvalidPeriod { period: usize, data_len: usize },
185    #[error("dx: Not enough valid data: needed = {needed}, valid = {valid}")]
186    NotEnoughValidData { needed: usize, valid: usize },
187    #[error("dx: All high, low, and close values are NaN.")]
188    AllValuesNaN,
189    #[error("dx: output length mismatch: expected = {expected}, got = {got}")]
190    OutputLengthMismatch { expected: usize, got: usize },
191    #[error("dx: invalid kernel for batch: {0:?}")]
192    InvalidKernelForBatch(Kernel),
193    #[error("dx: invalid range: start={start}, end={end}, step={step}")]
194    InvalidRange {
195        start: usize,
196        end: usize,
197        step: usize,
198    },
199    #[error("dx: invalid input: {0}")]
200    InvalidInput(String),
201}
202
203#[inline(always)]
204fn dx_prepare<'a>(
205    input: &'a DxInput,
206    kernel: Kernel,
207) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), DxError> {
208    let (high, low, close) = match &input.data {
209        DxData::Candles { candles } => {
210            let h = candles
211                .select_candle_field("high")
212                .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
213            let l = candles
214                .select_candle_field("low")
215                .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
216            let c = candles
217                .select_candle_field("close")
218                .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
219            (h, l, c)
220        }
221        DxData::HlcSlices { high, low, close } => (*high, *low, *close),
222    };
223    let len = high.len().min(low.len()).min(close.len());
224    if len == 0 {
225        return Err(DxError::EmptyInputData);
226    }
227    let period = input.get_period();
228    if period == 0 || period > len {
229        return Err(DxError::InvalidPeriod {
230            period,
231            data_len: len,
232        });
233    }
234
235    let first = (0..len)
236        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
237        .ok_or(DxError::AllValuesNaN)?;
238    if len - first < period {
239        return Err(DxError::NotEnoughValidData {
240            needed: period,
241            valid: len - first,
242        });
243    }
244    let chosen = match kernel {
245        Kernel::Auto => Kernel::Scalar,
246        k => k,
247    };
248    Ok((high, low, close, len, first, chosen))
249}
250
251#[inline]
252pub fn dx(input: &DxInput) -> Result<DxOutput, DxError> {
253    dx_with_kernel(input, Kernel::Auto)
254}
255
256pub fn dx_with_kernel(input: &DxInput, kernel: Kernel) -> Result<DxOutput, DxError> {
257    let (h, l, c, len, first, chosen) = dx_prepare(input, kernel)?;
258    let warm = first + input.get_period() - 1;
259    let mut out = alloc_with_nan_prefix(len, warm);
260    unsafe {
261        match chosen {
262            Kernel::Scalar | Kernel::ScalarBatch => {
263                dx_scalar(h, l, c, input.get_period(), first, &mut out)
264            }
265            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266            Kernel::Avx2 | Kernel::Avx2Batch => {
267                dx_avx2(h, l, c, input.get_period(), first, &mut out)
268            }
269            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
270            Kernel::Avx512 | Kernel::Avx512Batch => {
271                dx_avx512(h, l, c, input.get_period(), first, &mut out)
272            }
273            _ => unreachable!(),
274        }
275    }
276    Ok(DxOutput { values: out })
277}
278
279#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
280#[inline]
281pub fn dx_into(input: &DxInput, out: &mut [f64]) -> Result<(), DxError> {
282    dx_into_slice(out, input, Kernel::Auto)
283}
284
/// Portable scalar DX kernel using Wilder smoothing.
///
/// Processes bars `first_valid_idx + 1 .. len`, where `len` is the length of
/// the shortest of `high`, `low`, `close` and `out`. It writes:
/// * NaN at `first_valid_idx`, and at every bar still inside the seeding
///   window (fewer than `period - 1` complete bars accumulated);
/// * the seed DX, computed from plain sums of +DM, -DM and true range, on the
///   bar that completes `period - 1` complete bars (NaN if that true-range
///   sum is zero);
/// * the Wilder-smoothed DX on every later bar, repeating the previous output
///   when +DI + -DI is zero;
/// * the previous output on any bar whose high, low or close is NaN.
///
/// Entries before `first_valid_idx` are not touched. `dx_with_kernel`
/// allocates `out` with `alloc_with_nan_prefix`. The function returns
/// without writing when `period == 0` or `first_valid_idx >= len`.
#[inline]
pub fn dx_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    // Include `out` in the clamp so the unchecked writes below can never run
    // past its end (this is a safe `pub fn`).
    let len = high.len().min(low.len()).min(close.len()).min(out.len());
    if period == 0 || first_valid_idx >= len {
        return;
    }

    let p_f64 = period as f64;
    let hundred = 100.0f64;
    // Number of complete bars summed before the seed DX is emitted.
    let seed_count = period - 1;

    let mut prev_high = high[first_valid_idx];
    let mut prev_low = low[first_valid_idx];
    let mut prev_close = close[first_valid_idx];

    // The first valid bar has no predecessor, so it never carries a DX value.
    // For `period == 1` this slot lies just past the caller's NaN prefix and
    // would otherwise be left unwritten.
    out[first_valid_idx] = f64::NAN;

    let mut plus_dm_sum = 0.0f64;
    let mut minus_dm_sum = 0.0f64;
    let mut tr_sum = 0.0f64;
    let mut initial_count: usize = 0;

    // SAFETY: every index used below is `i` or `i - 1` with
    // `first_valid_idx + 1 <= i < len`, and `len` is the minimum of all four
    // slice lengths, so every unchecked access is in bounds.
    unsafe {
        for i in (first_valid_idx + 1)..len {
            let h = *high.get_unchecked(i);
            let l = *low.get_unchecked(i);
            let cl = *close.get_unchecked(i);

            if h.is_nan() | l.is_nan() | cl.is_nan() {
                // Missing bar: repeat the previous output. Remembering the NaNs
                // as the previous bar gives the next bar zero directional
                // movement and a true range of `high - low` (`f64::max`
                // ignores NaN operands).
                *out.get_unchecked_mut(i) = *out.get_unchecked(i - 1);
                prev_high = h;
                prev_low = l;
                prev_close = cl;
                continue;
            }

            let up_move = h - prev_high;
            let down_move = prev_low - l;
            let (plus_dm, minus_dm) = if up_move > 0.0 && up_move > down_move {
                (up_move, 0.0)
            } else if down_move > 0.0 && down_move > up_move {
                (0.0, down_move)
            } else {
                (0.0, 0.0)
            };

            let tr = (h - l)
                .max((h - prev_close).abs())
                .max((l - prev_close).abs());

            let dx = if initial_count < seed_count {
                plus_dm_sum += plus_dm;
                minus_dm_sum += minus_dm;
                tr_sum += tr;
                initial_count += 1;

                if initial_count == seed_count {
                    let plus_di = (plus_dm_sum / tr_sum) * hundred;
                    let minus_di = (minus_dm_sum / tr_sum) * hundred;
                    let sum_di = plus_di + minus_di;
                    if sum_di != 0.0 {
                        hundred * ((plus_di - minus_di).abs() / sum_di)
                    } else {
                        0.0
                    }
                } else {
                    // Still seeding: no DX yet. Writing NaN matters when NaN
                    // bars pushed this bar past the caller's NaN prefix.
                    f64::NAN
                }
            } else {
                // Wilder smoothing: S = S - S / period + x.
                plus_dm_sum = plus_dm_sum - (plus_dm_sum / p_f64) + plus_dm;
                minus_dm_sum = minus_dm_sum - (minus_dm_sum / p_f64) + minus_dm;
                tr_sum = tr_sum - (tr_sum / p_f64) + tr;

                let (plus_di, minus_di) = if tr_sum != 0.0 {
                    (
                        (plus_dm_sum / tr_sum) * hundred,
                        (minus_dm_sum / tr_sum) * hundred,
                    )
                } else {
                    (0.0, 0.0)
                };
                let sum_di = plus_di + minus_di;
                if sum_di != 0.0 {
                    hundred * ((plus_di - minus_di).abs() / sum_di)
                } else {
                    *out.get_unchecked(i - 1)
                }
            };
            *out.get_unchecked_mut(i) = dx;

            prev_high = h;
            prev_low = l;
            prev_close = cl;
        }
    }
}
393
/// AVX-512 single-series DX entry point.
///
/// Routes periods up to 32 to the "short" variant and longer periods to the
/// "long" variant; both currently forward to `dx_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => dx_avx512_short(high, low, close, period, first_valid, out),
        _ => dx_avx512_long(high, low, close, period, first_valid, out),
    }
}
410
411#[inline]
412pub fn dx_avx2(
413    high: &[f64],
414    low: &[f64],
415    close: &[f64],
416    period: usize,
417    first_valid: usize,
418    out: &mut [f64],
419) {
420    dx_scalar(high, low, close, period, first_valid, out)
421}
422
/// AVX-512 variant for periods up to 32; currently forwards to `dx_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first_valid;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
435
/// AVX-512 variant for periods above 32; currently forwards to `dx_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first_valid;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
448
449#[derive(Debug, Clone)]
450pub struct DxStream {
451    period: usize,
452
453    p_f64: f64,
454    hundred: f64,
455
456    plus_dm_sum: f64,
457    minus_dm_sum: f64,
458    tr_sum: f64,
459
460    prev_high: f64,
461    prev_low: f64,
462    prev_close: f64,
463
464    initial_count: usize,
465    filled: bool,
466
467    last_dx: f64,
468}
469
470impl DxStream {
471    pub fn try_new(params: DxParams) -> Result<Self, DxError> {
472        let period = params.period.unwrap_or(14);
473        if period == 0 {
474            return Err(DxError::InvalidPeriod {
475                period,
476                data_len: 0,
477            });
478        }
479        Ok(Self {
480            period,
481            p_f64: period as f64,
482            hundred: 100.0,
483
484            plus_dm_sum: 0.0,
485            minus_dm_sum: 0.0,
486            tr_sum: 0.0,
487
488            prev_high: f64::NAN,
489            prev_low: f64::NAN,
490            prev_close: f64::NAN,
491
492            initial_count: 0,
493            filled: false,
494            last_dx: f64::NAN,
495        })
496    }
497
498    #[inline(always)]
499    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
500        if self.prev_high.is_nan() || self.prev_low.is_nan() || self.prev_close.is_nan() {
501            self.prev_high = high;
502            self.prev_low = low;
503            self.prev_close = close;
504            return None;
505        }
506
507        if high.is_nan() || low.is_nan() || close.is_nan() {
508            let carried = if self.filled { self.last_dx } else { f64::NAN };
509            self.prev_high = high;
510            self.prev_low = low;
511            self.prev_close = close;
512            return Some(carried);
513        }
514
515        let up_move = high - self.prev_high;
516        let down_move = self.prev_low - low;
517        let plus_dm = if up_move > 0.0 && up_move > down_move {
518            up_move
519        } else {
520            0.0
521        };
522        let minus_dm = if down_move > 0.0 && down_move > up_move {
523            down_move
524        } else {
525            0.0
526        };
527
528        let tr1 = high - low;
529        let tr2 = (high - self.prev_close).abs();
530        let tr3 = (low - self.prev_close).abs();
531        let tr = tr1.max(tr2).max(tr3);
532
533        let mut out: Option<f64> = None;
534
535        if self.initial_count < (self.period - 1) {
536            self.plus_dm_sum += plus_dm;
537            self.minus_dm_sum += minus_dm;
538            self.tr_sum += tr;
539            self.initial_count += 1;
540
541            if self.initial_count == (self.period - 1) {
542                let plus_di = (self.plus_dm_sum / self.tr_sum) * self.hundred;
543                let minus_di = (self.minus_dm_sum / self.tr_sum) * self.hundred;
544                let sum_di = plus_di + minus_di;
545
546                let dx = if sum_di != 0.0 {
547                    self.hundred * ((plus_di - minus_di).abs() / sum_di)
548                } else {
549                    0.0
550                };
551                self.filled = true;
552                self.last_dx = dx;
553                out = Some(dx);
554            }
555        } else {
556            self.plus_dm_sum = self.plus_dm_sum - (self.plus_dm_sum / self.p_f64) + plus_dm;
557            self.minus_dm_sum = self.minus_dm_sum - (self.minus_dm_sum / self.p_f64) + minus_dm;
558            self.tr_sum = self.tr_sum - (self.tr_sum / self.p_f64) + tr;
559
560            let plus_di = if self.tr_sum != 0.0 {
561                (self.plus_dm_sum / self.tr_sum) * self.hundred
562            } else {
563                0.0
564            };
565            let minus_di = if self.tr_sum != 0.0 {
566                (self.minus_dm_sum / self.tr_sum) * self.hundred
567            } else {
568                0.0
569            };
570            let sum_di = plus_di + minus_di;
571
572            let dx = if sum_di != 0.0 {
573                self.hundred * ((plus_di - minus_di).abs() / sum_di)
574            } else if self.filled {
575                self.last_dx
576            } else {
577                f64::NAN
578            };
579            self.last_dx = dx;
580            out = Some(dx);
581        }
582
583        self.prev_high = high;
584        self.prev_low = low;
585        self.prev_close = close;
586
587        out
588    }
589}
590
/// Inclusive sweep of DX periods for batch evaluation, as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct DxBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DxBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        DxBatchRange { period }
    }
}

impl DxBatchRange {
    /// Wraps an explicit `(start, end, step)` triple.
    pub fn from_tuple(period: (usize, usize, usize)) -> Self {
        DxBatchRange { period }
    }
}
609
610#[derive(Clone, Debug, Default)]
611pub struct DxBatchBuilder {
612    range: DxBatchRange,
613    kernel: Kernel,
614}
615
616impl DxBatchBuilder {
617    pub fn new() -> Self {
618        Self::default()
619    }
620    pub fn kernel(mut self, k: Kernel) -> Self {
621        self.kernel = k;
622        self
623    }
624
625    #[inline]
626    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
627        self.range.period = (start, end, step);
628        self
629    }
630    #[inline]
631    pub fn period_static(mut self, p: usize) -> Self {
632        self.range.period = (p, p, 0);
633        self
634    }
635
636    pub fn apply_hlc(
637        self,
638        high: &[f64],
639        low: &[f64],
640        close: &[f64],
641    ) -> Result<DxBatchOutput, DxError> {
642        dx_batch_with_kernel(high, low, close, &self.range, self.kernel)
643    }
644
645    pub fn apply_candles(self, c: &Candles) -> Result<DxBatchOutput, DxError> {
646        let high = source_type(c, "high");
647        let low = source_type(c, "low");
648        let close = source_type(c, "close");
649        self.apply_hlc(high, low, close)
650    }
651}
652
653pub struct DxBatchOutput {
654    pub values: Vec<f64>,
655    pub combos: Vec<DxParams>,
656    pub rows: usize,
657    pub cols: usize,
658}
659
660impl DxBatchOutput {
661    pub fn row_for_params(&self, p: &DxParams) -> Option<usize> {
662        self.combos
663            .iter()
664            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
665    }
666    pub fn values_for(&self, p: &DxParams) -> Option<&[f64]> {
667        self.row_for_params(p).map(|row| {
668            let start = row * self.cols;
669            &self.values[start..start + self.cols]
670        })
671    }
672}
673
674#[inline(always)]
675fn expand_grid_checked(r: &DxBatchRange) -> Result<Vec<DxParams>, DxError> {
676    let (start, end, step) = r.period;
677
678    if step == 0 || start == end {
679        return Ok(vec![DxParams {
680            period: Some(start),
681        }]);
682    }
683
684    let mut out: Vec<usize> = Vec::new();
685    if start < end {
686        let mut v = start;
687        while v <= end {
688            out.push(v);
689            match v.checked_add(step) {
690                Some(next) if next != v => v = next,
691                _ => break,
692            }
693        }
694    } else {
695        let mut v = start;
696
697        loop {
698            out.push(v);
699            if v <= end {
700                break;
701            }
702            let dec = v.saturating_sub(step);
703            if dec == v {
704                break;
705            }
706            v = dec;
707        }
708
709        out.sort_unstable();
710    }
711    if out.is_empty() {
712        return Err(DxError::InvalidRange { start, end, step });
713    }
714    Ok(out
715        .into_iter()
716        .map(|p| DxParams { period: Some(p) })
717        .collect())
718}
719
720pub fn dx_batch_with_kernel(
721    high: &[f64],
722    low: &[f64],
723    close: &[f64],
724    sweep: &DxBatchRange,
725    k: Kernel,
726) -> Result<DxBatchOutput, DxError> {
727    let kernel = match k {
728        Kernel::Auto => Kernel::ScalarBatch,
729        other if other.is_batch() => other,
730        other => return Err(DxError::InvalidKernelForBatch(other)),
731    };
732
733    let simd = match kernel {
734        Kernel::Avx512Batch => Kernel::Avx512,
735        Kernel::Avx2Batch => Kernel::Avx2,
736        Kernel::ScalarBatch => Kernel::Scalar,
737        _ => unreachable!(),
738    };
739    dx_batch_par_slice(high, low, close, sweep, simd)
740}
741
742#[inline(always)]
743pub fn dx_batch_slice(
744    high: &[f64],
745    low: &[f64],
746    close: &[f64],
747    sweep: &DxBatchRange,
748    kern: Kernel,
749) -> Result<DxBatchOutput, DxError> {
750    dx_batch_inner(high, low, close, sweep, kern, false)
751}
752
753#[inline(always)]
754pub fn dx_batch_par_slice(
755    high: &[f64],
756    low: &[f64],
757    close: &[f64],
758    sweep: &DxBatchRange,
759    kern: Kernel,
760) -> Result<DxBatchOutput, DxError> {
761    dx_batch_inner(high, low, close, sweep, kern, true)
762}
763
764#[inline(always)]
765fn dx_batch_inner_into(
766    high: &[f64],
767    low: &[f64],
768    close: &[f64],
769    sweep: &DxBatchRange,
770    kern: Kernel,
771    parallel: bool,
772    out: &mut [f64],
773) -> Result<Vec<DxParams>, DxError> {
774    let combos_vec = expand_grid_checked(sweep)?;
775    let combos = combos_vec;
776
777    let len = high.len().min(low.len()).min(close.len());
778    let first = (0..len)
779        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
780        .ok_or(DxError::AllValuesNaN)?;
781    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
782    if len - first < max_p {
783        return Err(DxError::NotEnoughValidData {
784            needed: max_p,
785            valid: len - first,
786        });
787    }
788
789    let rows = combos.len();
790    let total = rows
791        .checked_mul(len)
792        .ok_or_else(|| DxError::InvalidInput("rows*cols overflow".into()))?;
793    if out.len() != total {
794        return Err(DxError::OutputLengthMismatch {
795            expected: total,
796            got: out.len(),
797        });
798    }
799
800    let actual = match kern {
801        Kernel::Auto => detect_best_batch_kernel(),
802        k => k,
803    };
804    let simd = match actual {
805        Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
806        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
807        Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
808        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
809        Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
810        _ => unreachable!(),
811    };
812
813    let (plus_dm, minus_dm, tr, carry) = dx_precompute_terms(high, low, close, first, len);
814
815    let do_row = |row: usize, dst_row: &mut [f64]| unsafe {
816        let p = combos[row].period.unwrap();
817        dx_row_scalar_precomputed(&plus_dm, &minus_dm, &tr, &carry, first, p, dst_row);
818    };
819
820    if parallel {
821        #[cfg(not(target_arch = "wasm32"))]
822        out.par_chunks_mut(len)
823            .enumerate()
824            .for_each(|(r, s)| do_row(r, s));
825        #[cfg(target_arch = "wasm32")]
826        for (r, s) in out.chunks_mut(len).enumerate() {
827            do_row(r, s);
828        }
829    } else {
830        for (r, s) in out.chunks_mut(len).enumerate() {
831            do_row(r, s);
832        }
833    }
834    Ok(combos)
835}
836
837#[inline(always)]
838fn dx_precompute_terms(
839    high: &[f64],
840    low: &[f64],
841    close: &[f64],
842    first: usize,
843    len: usize,
844) -> (AVec<f64>, AVec<f64>, AVec<f64>, Vec<u8>) {
845    let mut plus_dm: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
846    let mut minus_dm: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
847    let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
848
849    let mut carry: Vec<u8> = vec![0; len];
850
851    for _ in 0..len {
852        plus_dm.push(0.0);
853    }
854    for _ in 0..len {
855        minus_dm.push(0.0);
856    }
857    for _ in 0..len {
858        tr.push(0.0);
859    }
860
861    if len == 0 || first + 1 >= len {
862        return (plus_dm, minus_dm, tr, carry);
863    }
864
865    for i in (first + 1)..len {
866        let h = high[i];
867        let l = low[i];
868        let c = close[i];
869        if h.is_nan() || l.is_nan() || c.is_nan() {
870            carry[i] = 1;
871            continue;
872        }
873
874        let up_move = h - high[i - 1];
875        let down_move = low[i - 1] - l;
876        let pdm = if up_move > 0.0 && up_move > down_move {
877            up_move
878        } else {
879            0.0
880        };
881        let mdm = if down_move > 0.0 && down_move > up_move {
882            down_move
883        } else {
884            0.0
885        };
886
887        let tr1 = h - l;
888        let tr2 = (h - close[i - 1]).abs();
889        let tr3 = (l - close[i - 1]).abs();
890        let t = tr1.max(tr2).max(tr3);
891
892        plus_dm[i] = pdm;
893        minus_dm[i] = mdm;
894        tr[i] = t;
895    }
896
897    (plus_dm, minus_dm, tr, carry)
898}
899
/// Computes one batch row of DX from the terms produced by `dx_precompute_terms`.
///
/// Writes, for bars `first ..` (up to the shortest slice):
/// * NaN at `first` and at every bar still inside the seeding window;
/// * the seed DX on the bar completing `period - 1` non-carried bars;
/// * the Wilder-smoothed DX afterwards, repeating the previous output when
///   +DI + -DI is zero;
/// * the previous output on carried (NaN-input) bars.
///
/// Previously, `out[first]` with `period == 1`, and seeding bars pushed past
/// the warm-up prefix by NaN gaps, were left unwritten. The buffer from
/// `make_uninit_matrix` could then expose uninitialized memory.
/// `period == 0` now returns without writing instead of underflowing.
///
/// # Safety
/// The working length is clamped to the shortest of the five slices, so all
/// unchecked accesses stay in bounds. The function stays `unsafe` for
/// interface compatibility.
#[inline(always)]
unsafe fn dx_row_scalar_precomputed(
    plus_dm: &[f64],
    minus_dm: &[f64],
    tr: &[f64],
    carry: &[u8],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = out
        .len()
        .min(plus_dm.len())
        .min(minus_dm.len())
        .min(tr.len())
        .min(carry.len());
    if period == 0 || first >= len {
        return;
    }

    let p_f64 = period as f64;
    let hundred = 100.0f64;
    // Number of non-carried bars summed before the seed DX is emitted.
    let seed_count = period - 1;

    let mut plus_dm_sum = 0.0f64;
    let mut minus_dm_sum = 0.0f64;
    let mut tr_sum = 0.0f64;
    let mut initial_count: usize = 0;

    // No DX exists at the first valid bar; for `period == 1` this slot lies
    // just past the NaN prefix and would otherwise stay unwritten.
    *out.get_unchecked_mut(first) = f64::NAN;

    for i in (first + 1)..len {
        if *carry.get_unchecked(i) != 0 {
            // Missing bar: repeat the previous output (`i >= 1` here).
            *out.get_unchecked_mut(i) = *out.get_unchecked(i - 1);
            continue;
        }

        let pdm = *plus_dm.get_unchecked(i);
        let mdm = *minus_dm.get_unchecked(i);
        let t = *tr.get_unchecked(i);

        let dx = if initial_count < seed_count {
            plus_dm_sum += pdm;
            minus_dm_sum += mdm;
            tr_sum += t;
            initial_count += 1;
            if initial_count == seed_count {
                let plus_di = (plus_dm_sum / tr_sum) * hundred;
                let minus_di = (minus_dm_sum / tr_sum) * hundred;
                let sum_di = plus_di + minus_di;
                if sum_di != 0.0 {
                    hundred * ((plus_di - minus_di).abs() / sum_di)
                } else {
                    0.0
                }
            } else {
                // Still seeding: no DX yet.
                f64::NAN
            }
        } else {
            // Wilder smoothing: S = S - S / period + x.
            plus_dm_sum = plus_dm_sum - (plus_dm_sum / p_f64) + pdm;
            minus_dm_sum = minus_dm_sum - (minus_dm_sum / p_f64) + mdm;
            tr_sum = tr_sum - (tr_sum / p_f64) + t;
            let plus_di = if tr_sum != 0.0 {
                (plus_dm_sum / tr_sum) * hundred
            } else {
                0.0
            };
            let minus_di = if tr_sum != 0.0 {
                (minus_dm_sum / tr_sum) * hundred
            } else {
                0.0
            };
            let sum_di = plus_di + minus_di;
            if sum_di != 0.0 {
                hundred * ((plus_di - minus_di).abs() / sum_di)
            } else {
                *out.get_unchecked(i - 1)
            }
        };
        *out.get_unchecked_mut(i) = dx;
    }
}
979
980fn dx_batch_inner(
981    high: &[f64],
982    low: &[f64],
983    close: &[f64],
984    sweep: &DxBatchRange,
985    kern: Kernel,
986    parallel: bool,
987) -> Result<DxBatchOutput, DxError> {
988    let combos = expand_grid_checked(sweep)?;
989    let rows = combos.len();
990    let cols = high.len().min(low.len()).min(close.len());
991    if cols == 0 {
992        return Err(DxError::EmptyInputData);
993    }
994    let _ = rows
995        .checked_mul(cols)
996        .ok_or_else(|| DxError::InvalidInput("rows*cols overflow".into()))?;
997
998    let first = (0..cols)
999        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1000        .ok_or(DxError::AllValuesNaN)?;
1001    let warm: Vec<usize> = combos
1002        .iter()
1003        .map(|c| first + c.period.unwrap() - 1)
1004        .collect();
1005
1006    let mut buf_mu = make_uninit_matrix(rows, cols);
1007    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1008
1009    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1010    let out_slice: &mut [f64] =
1011        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1012
1013    let _ = dx_batch_inner_into(high, low, close, sweep, kern, parallel, out_slice)?;
1014
1015    let values = unsafe {
1016        Vec::from_raw_parts(
1017            guard.as_mut_ptr() as *mut f64,
1018            guard.len(),
1019            guard.capacity(),
1020        )
1021    };
1022    Ok(DxBatchOutput {
1023        values,
1024        combos,
1025        rows,
1026        cols,
1027    })
1028}
1029
1030#[inline(always)]
1031unsafe fn dx_row_scalar(
1032    high: &[f64],
1033    low: &[f64],
1034    close: &[f64],
1035    first: usize,
1036    period: usize,
1037    out: &mut [f64],
1038) {
1039    dx_scalar(high, low, close, period, first, out)
1040}
1041
/// AVX2 batch-row kernel; there is no dedicated implementation, so it
/// forwards to `dx_scalar` (which takes `period` before `first`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
1054
/// AVX-512 batch-row kernel: periods up to 32 use the "short" variant,
/// longer periods the "long" one. Both currently forward to `dx_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn dx_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => dx_row_avx512_short(high, low, close, first, period, out),
        _ => dx_row_avx512_long(high, low, close, first, period, out),
    }
}
1071
/// AVX-512 row kernel for short periods (`period <= 32`, as routed by `dx_row_avx512`).
///
/// Currently a forwarder to `dx_scalar`, which takes `period` before `first`.
///
/// # Safety
/// Adds no preconditions of its own; callers must satisfy whatever `dx_scalar`
/// requires (not visible from this function).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    dx_scalar(high, low, close, window, start, out)
}
1084
/// AVX-512 row kernel for long periods (`period > 32`, as routed by `dx_row_avx512`).
///
/// Currently a forwarder to `dx_scalar`, which takes `period` before `first`.
///
/// # Safety
/// Adds no preconditions of its own; callers must satisfy whatever `dx_scalar`
/// requires (not visible from this function).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let (window, start) = (period, first);
    dx_scalar(high, low, close, window, start, out)
}
1097
1098#[inline(always)]
1099pub fn expand_grid_dx(r: &DxBatchRange) -> Vec<DxParams> {
1100    expand_grid_checked(r).unwrap_or_else(|_| vec![])
1101}
1102
/// Unit, batch and property tests for the DX indicator.
///
/// Every `check_*` helper returns a `Result`. The test-generating macros below
/// `unwrap()` that result, so an `Err` (unreadable fixture, rejected input, ...) fails
/// the test instead of being silently discarded.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture shared by the data-driven tests (relative path, resolved against
    /// the test process's working directory).
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference values for the last five DX(14) outputs over `CANDLES_CSV`.
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        43.72121533411883,
        41.47251493226443,
        43.43041386436222,
        43.22673458811955,
        51.65514026197179,
    ];

    fn check_dx_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let default_params = DxParams { period: None };
        let input = DxInput::from_candles(&candles, default_params);
        let output = dx_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_dx_into_matches_api() -> Result<(), Box<dyn Error>> {
        let n = 512usize;
        let mut close = vec![0.0f64; n];
        for i in 0..n {
            let t = i as f64;
            close[i] = 100.0 + 0.1 * t + (t * 0.2).sin() * 2.0;
        }
        let mut high = vec![0.0f64; n];
        let mut low = vec![0.0f64; n];
        for i in 0..n {
            let t = i as f64;

            high[i] = close[i] + 0.6 + 0.05 * (t * 0.3).sin();
            low[i] = close[i] - 0.6 - 0.05 * (t * 0.3).cos();
            if low[i] > high[i] {
                core::mem::swap(&mut low[i], &mut high[i]);
            }
        }

        let params = DxParams { period: Some(14) };
        let input = DxInput::from_hlc_slices(&high, &low, &close, params);

        let base = dx(&input)?.values;

        let mut into_out = vec![0.0f64; n];
        dx_into(&input, &mut into_out)?;

        // NaN never compares equal to itself, so matching warmup NaNs are handled
        // explicitly.
        #[inline]
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        assert_eq!(base.len(), into_out.len());
        for i in 0..n {
            assert!(
                eq_or_both_nan(base[i], into_out[i]),
                "dx_into mismatch at {}: base={}, into={}",
                i,
                base[i],
                into_out[i]
            );
        }
        Ok(())
    }

    fn check_dx_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = DxInput::from_candles(&candles, DxParams::default());
        let result = dx_with_kernel(&input, kernel)?;
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - EXPECTED_LAST_FIVE[i]).abs();
            assert!(
                diff < 1e-4,
                "[{}] DX {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                EXPECTED_LAST_FIVE[i]
            );
        }
        Ok(())
    }

    fn check_dx_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = DxInput::with_default_candles(&candles);
        match input.data {
            DxData::Candles { .. } => {}
            _ => panic!("Expected DxData::Candles"),
        }
        let output = dx_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_dx_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [2.0, 2.5, 3.0];
        let low = [1.0, 1.2, 2.1];
        let close = [1.5, 2.3, 2.2];
        let params = DxParams { period: Some(0) };
        let input = DxInput::from_hlc_slices(&high, &low, &close, params);
        let res = dx_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DX should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_dx_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [3.0, 4.0];
        let low = [2.0, 3.0];
        let close = [2.5, 3.5];
        let params = DxParams { period: Some(14) };
        let input = DxInput::from_hlc_slices(&high, &low, &close, params);
        let res = dx_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DX should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_dx_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [3.0];
        let low = [2.0];
        let close = [2.5];
        let params = DxParams { period: Some(14) };
        let input = DxInput::from_hlc_slices(&high, &low, &close, params);
        let res = dx_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] DX should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_dx_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let first_params = DxParams { period: Some(14) };
        let first_input = DxInput::from_candles(&candles, first_params);
        let first_result = dx_with_kernel(&first_input, kernel)?;

        // Feed the first pass's output back in as high/low/close. Its NaN warmup
        // prefix shifts the second pass's warmup, hence the later start index below.
        let second_params = DxParams { period: Some(14) };
        let second_input = DxInput::from_hlc_slices(
            &first_result.values,
            &first_result.values,
            &first_result.values,
            second_params,
        );
        let second_result = dx_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] Expected no NaN after index 28, found NaN at idx {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    fn check_dx_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = DxInput::from_candles(&candles, DxParams { period: Some(14) });
        let res = dx_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 50 {
            for (i, &val) in res.values[50..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    50 + i
                );
            }
        }
        Ok(())
    }

    fn check_dx_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let high = source_type(&candles, "high");
        let low = source_type(&candles, "low");
        let close = source_type(&candles, "close");
        let period = 14;

        let input = DxInput::from_candles(
            &candles,
            DxParams {
                period: Some(period),
            },
        );
        let batch_output = dx_with_kernel(&input, kernel)?.values;

        let mut stream = DxStream::try_new(DxParams {
            period: Some(period),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for ((&h, &l), &c) in high.iter().zip(low).zip(close) {
            // `None` from the stream is the streaming counterpart of a NaN warmup slot.
            stream_values.push(stream.update(h, l, c).unwrap_or(f64::NAN));
        }

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] DX streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Generates `<check>_scalar_f64`, `<check>_avx2_f64` and `<check>_avx512_f64` tests.
    ///
    /// The `cfg` gate is repeated on each AVX function. An attribute written in front of a
    /// `$(...)*` repetition would only attach to the first item the repetition expands to.
    macro_rules! generate_all_dx_tests {
        ($($test_fn:ident),* $(,)?) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        };
    }

    /// Returns the name of the allocation helper whose debug poison pattern equals `bits`,
    /// or `None` if `bits` is not one of those patterns.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    #[cfg(debug_assertions)]
    fn check_dx_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = vec![
            DxParams::default(),
            DxParams { period: Some(2) },
            DxParams { period: Some(5) },
            DxParams { period: Some(7) },
            DxParams { period: Some(10) },
            DxParams { period: Some(14) },
            DxParams { period: Some(20) },
            DxParams { period: Some(30) },
            DxParams { period: Some(50) },
            DxParams { period: Some(100) },
            DxParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = DxInput::from_candles(&candles, params.clone());
            let output = dx_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        source,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_dx_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[allow(clippy::float_cmp)]
    fn check_dx_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Deterministic synthetic OHLC series: a linear trend plus a periodic
        // pseudo-random wiggle, with high >= price >= low by construction.
        let strat = (2usize..=50)
            .prop_flat_map(|period| {
                (
                    100.0f64..5000.0f64,
                    (period + 20)..400,
                    0.001f64..0.05f64,
                    -0.01f64..0.01f64,
                    Just(period),
                )
            })
            .prop_map(|(base_price, data_len, volatility, trend, period)| {
                let mut high = Vec::with_capacity(data_len);
                let mut low = Vec::with_capacity(data_len);
                let mut close = Vec::with_capacity(data_len);

                for i in 0..data_len {
                    let trend_component = trend * i as f64;
                    let random_component = ((i * 7 + 13) % 17) as f64 / 17.0 - 0.5;
                    let price =
                        base_price + trend_component + random_component * volatility * base_price;

                    let daily_volatility = volatility * price;
                    let h = price + daily_volatility * (0.5 + ((i * 3) % 7) as f64 / 14.0);
                    let l = price - daily_volatility * (0.5 + ((i * 5) % 7) as f64 / 14.0);
                    let c = l + (h - l) * (0.3 + ((i * 11) % 7) as f64 / 10.0);

                    high.push(h);
                    low.push(l);
                    close.push(c);
                }

                (high, low, close, period)
            });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, close, period)| {
                let params = DxParams {
                    period: Some(period),
                };
                let input = DxInput::from_hlc_slices(&high, &low, &close, params.clone());

                let DxOutput { values: out } = dx_with_kernel(&input, kernel).unwrap();
                let DxOutput { values: ref_out } =
                    dx_with_kernel(&input, Kernel::Scalar).unwrap();

                // Every defined output must lie in [0, 100], with a small rounding tolerance.
                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        prop_assert!(
                            val >= -1e-9 && val <= 100.0 + 1e-9,
                            "[{}] DX value {} at index {} is outside [0, 100] range",
                            test_name,
                            val,
                            i
                        );
                    }
                }

                // The generated inputs contain no NaNs, so the first `period - 1`
                // outputs are the warmup region.
                let warmup = period - 1;
                for i in 0..warmup {
                    prop_assert!(
                        out[i].is_nan(),
                        "[{}] Expected NaN during warmup at index {}, got {}",
                        test_name,
                        i,
                        out[i]
                    );
                }

                if out.len() > warmup + 10 {
                    for i in (warmup + 10)..out.len() {
                        prop_assert!(
                            !out[i].is_nan(),
                            "[{}] Unexpected NaN after warmup at index {}",
                            test_name,
                            i
                        );
                    }
                }

                // The kernel under test must agree with the scalar reference.
                for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
                    if val.is_nan() && ref_val.is_nan() {
                        continue;
                    }

                    let diff = (val - ref_val).abs();
                    prop_assert!(
                        diff < 1e-9,
                        "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
                        test_name,
                        i,
                        val,
                        ref_val,
                        diff
                    );
                }

                let all_same_high = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                let all_same_low = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
                let all_same_close = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

                if all_same_high && all_same_low && all_same_close && out.len() > warmup + 10 {
                    let stable_vals = &out[warmup + 10..];
                    for (i, &val) in stable_vals.iter().enumerate() {
                        if !val.is_nan() {
                            prop_assert!(
                                val < 1.0,
                                "[{}] With constant prices, expected DX < 1.0, got {} at index {}",
                                test_name,
                                val,
                                warmup + 10 + i
                            );
                        }
                    }
                }

                if period <= 20 && out.len() > 100 {
                    let mid = out.len() / 2;
                    let first_half_avg_price = close[..mid].iter().sum::<f64>() / mid as f64;
                    let second_half_avg_price =
                        close[mid..].iter().sum::<f64>() / (out.len() - mid) as f64;
                    let price_change = ((second_half_avg_price - first_half_avg_price)
                        / first_half_avg_price)
                        .abs();

                    if price_change > 0.05 {
                        let first_half_dx = &out[warmup..mid];
                        let second_half_dx = &out[mid..];

                        let first_avg = first_half_dx.iter().filter(|v| !v.is_nan()).sum::<f64>()
                            / first_half_dx.len() as f64;
                        let second_avg =
                            second_half_dx.iter().filter(|v| !v.is_nan()).sum::<f64>()
                                / second_half_dx.len() as f64;

                        prop_assert!(
                            second_avg > 20.0 || first_avg > 20.0,
                            "[{}] Expected higher average DX in trending market. First half avg: {}, Second half avg: {}",
                            test_name,
                            first_avg,
                            second_avg
                        );
                    }
                }

                // A steady 1%-per-bar uptrend should produce a strong DX reading.
                if period <= 14 && out.len() > 50 {
                    let trend_base = close[0];
                    let perfect_trend = (0..50)
                        .map(|i| {
                            let price = trend_base + (i as f64 * trend_base * 0.01);
                            let h = price * 1.005;
                            let l = price * 0.995;
                            let c = price;
                            (h, l, c)
                        })
                        .collect::<Vec<_>>();

                    let perfect_high: Vec<f64> = perfect_trend.iter().map(|&(h, _, _)| h).collect();
                    let perfect_low: Vec<f64> = perfect_trend.iter().map(|&(_, l, _)| l).collect();
                    let perfect_close: Vec<f64> =
                        perfect_trend.iter().map(|&(_, _, c)| c).collect();

                    let perfect_input = DxInput::from_hlc_slices(
                        &perfect_high,
                        &perfect_low,
                        &perfect_close,
                        params.clone(),
                    );
                    let DxOutput {
                        values: perfect_out,
                    } = dx_with_kernel(&perfect_input, kernel).unwrap();

                    if perfect_out.len() > warmup + 10 {
                        let stable_dx = &perfect_out[warmup + 10..];
                        let avg_dx = stable_dx.iter().filter(|v| !v.is_nan()).sum::<f64>()
                            / stable_dx.len() as f64;

                        prop_assert!(
                            avg_dx > 50.0,
                            "[{}] Expected high DX (>50) in perfect trend, got avg {}",
                            test_name,
                            avg_dx
                        );
                    }
                }

                // Alternating +/-1% plateaus (two bars each) model a ranging market.
                if period <= 14 && out.len() > 50 {
                    let range_base = close[0];
                    let ranging_data = (0..50)
                        .map(|i| {
                            let price = if i % 4 < 2 {
                                range_base * 1.01
                            } else {
                                range_base * 0.99
                            };
                            let h = price * 1.002;
                            let l = price * 0.998;
                            let c = price;
                            (h, l, c)
                        })
                        .collect::<Vec<_>>();

                    let ranging_high: Vec<f64> = ranging_data.iter().map(|&(h, _, _)| h).collect();
                    let ranging_low: Vec<f64> = ranging_data.iter().map(|&(_, l, _)| l).collect();
                    let ranging_close: Vec<f64> =
                        ranging_data.iter().map(|&(_, _, c)| c).collect();

                    let ranging_input = DxInput::from_hlc_slices(
                        &ranging_high,
                        &ranging_low,
                        &ranging_close,
                        params.clone(),
                    );
                    let DxOutput {
                        values: ranging_out,
                    } = dx_with_kernel(&ranging_input, kernel).unwrap();

                    if ranging_out.len() > warmup + 10 {
                        let stable_dx = &ranging_out[warmup + 10..];
                        let avg_dx = stable_dx.iter().filter(|v| !v.is_nan()).sum::<f64>()
                            / stable_dx.len() as f64;

                        prop_assert!(
                            avg_dx < 65.0,
                            "[{}] Expected moderate DX (<65) in ranging market, got avg {}",
                            test_name,
                            avg_dx
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_dx_tests!(
        check_dx_partial_params,
        check_dx_accuracy,
        check_dx_default_candles,
        check_dx_zero_period,
        check_dx_period_exceeds_length,
        check_dx_very_small_dataset,
        check_dx_reinput,
        check_dx_nan_handling,
        check_dx_streaming,
        check_dx_no_poison,
        check_dx_property
    );

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = DxBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = DxParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = EXPECTED_LAST_FIVE;
        let start = row.len().saturating_sub(5);
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-4,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    /// Generates scalar, AVX2, AVX-512 and auto-detect batch tests for one check.
    /// The check's `Result` is unwrapped so an `Err` fails the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }

                #[test]
                fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }

    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = DxBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 30, 5)
            .apply_candles(&c)?;

        // 10, 15, 20, 25, 30
        let expected_combos = 5;
        assert_eq!(output.combos.len(), expected_combos);
        assert_eq!(output.rows, expected_combos);
        assert_eq!(output.cols, c.close.len());

        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 20, 2),
            (14, 14, 0),
            (5, 50, 15),
            (100, 200, 50),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = DxBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c)?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    // Values are stored row-major: one row of `cols` entries per combo.
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    let combo = &output.combos[row];
                    panic!(
                        "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        source,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
}
1881
/// Python binding: compute DX over `high`/`low`/`close` for a single `period`.
///
/// `kernel` optionally names the compute kernel and is checked by `validate_kernel`.
/// The computation runs with the GIL released. Any `DxError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dx")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn dx_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<f64>,
    low: PyReadonlyArray1<f64>,
    close: PyReadonlyArray1<f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = DxInput::from_hlc_slices(
        h,
        l,
        c,
        DxParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| dx_with_kernel(&input, kern));
    let values = computed
        .map(|output| output.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(values.into_pyarray(py))
}
1906
/// Python binding: compute DX for every period in `period_range` = `(start, end, step)`.
///
/// Returns a dict with:
/// - `"values"`: a `(rows, cols)` float64 array, one row per period combination;
/// - `"periods"`: a uint64 array giving each row's period.
///
/// The batch computation runs with the GIL released. An invalid range, an unknown
/// kernel name, or a batch failure is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dx_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn dx_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<f64>,
    low: PyReadonlyArray1<f64>,
    close: PyReadonlyArray1<f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let sweep = DxBatchRange::from_tuple(period_range);
    // Validate the range before the kernel name so a bad range is reported first.
    expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let kern = validate_kernel(kernel, true)?;

    let DxBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = py
        .allow_threads(|| dx_batch_with_kernel(h, l, c, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Shape and period labels come from the batch result itself, so they always
    // describe exactly the buffer in `values`.
    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item(
        "values",
        PyArray1::from_vec(py, values).reshape((rows, cols))?,
    )?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1946
1947#[cfg(all(feature = "python", feature = "cuda"))]
1948use crate::cuda::dx_wrapper::CudaDx;
1949#[cfg(all(feature = "python", feature = "cuda"))]
1950use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32 as DeviceArrayF32Cuda;
1951#[cfg(all(feature = "python", feature = "cuda"))]
1952use cust::context::Context as CudaContext;
1953#[cfg(all(feature = "python", feature = "cuda"))]
1954use std::sync::Arc;
1955
// Python entry point `dx_cuda_batch_dev(high_f32, low_f32, close_f32, period_range, device_id=0)`.
//
// Runs the DX period sweep on the selected CUDA device (GIL released) and
// returns the device-resident result handle together with a dict holding the
// "periods" used for each output row. Raises ValueError when CUDA is missing
// or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dx_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn dx_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DxDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;
    let sweep = DxBatchRange::from_tuple(period_range);

    let (inner, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Capture the context handle and ordinal before launching the sweep.
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (arr, combos) = cuda
            .dx_batch_dev(high, low, close, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, combos, ctx, dev_id))
    })?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    let dict = PyDict::new(py);
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok((
        DxDeviceArrayF32Py {
            inner,
            _ctx: ctx,
            device_id: dev_id,
        },
        dict,
    ))
}
2001
// Python entry point
// `dx_cuda_many_series_one_param_dev(high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, period, device_id=0)`.
//
// Forwards the flat f32 inputs plus the caller-supplied `cols`/`rows` and a
// single `period` to `CudaDx::dx_many_series_one_param_time_major_dev` (GIL
// released) and wraps the device-resident result. Raises ValueError when CUDA
// is missing or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dx_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, period, device_id=0))]
pub fn dx_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DxDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high = high_tm_f32.as_slice()?;
    let low = low_tm_f32.as_slice()?;
    let close = close_tm_f32.as_slice()?;

    let (inner, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let arr = cuda
            .dx_many_series_one_param_time_major_dev(high, low, close, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev))
    })?;

    Ok(DxDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev,
    })
}
2036
// Python-visible handle to a DX result matrix of `f32` values held in a CUDA
// device buffer; produced by `dx_cuda_batch_dev` and
// `dx_cuda_many_series_one_param_dev`.
//
// Declared `unsendable`, so PyO3 restricts use of the object to the thread
// that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DxDeviceArrayF32Py {
    // Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32Cuda,
    // Context handle taken from the `CudaDx` instance that produced `inner`.
    // Rust drops struct fields in declaration order, so `inner` is released
    // before this handle — presumably so the buffer is freed while its context
    // is still alive; keep this field declared after `inner`.
    pub(crate) _ctx: Arc<CudaContext>,
    // Device ordinal reported by `CudaDx::device_id()` when the array was made;
    // used as the fallback ordinal in `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2044
// Python protocol methods that let CUDA-aware libraries consume the device
// buffer, either via the CUDA Array Interface or via DLPack.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DxDeviceArrayF32Py {
    // CUDA Array Interface (version 3) descriptor: a row-major `(rows, cols)`
    // matrix of little-endian float32 (`"<f4"`) with byte strides.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);

        d.set_item("shape", (inner.rows, inner.cols))?;

        d.set_item("typestr", "<f4")?;

        // Byte strides for a C-contiguous layout: one row spans `cols` f32 values.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // Zero-sized arrays advertise a null data pointer instead of the
        // buffer's device pointer.
        let size = inner.rows.saturating_mul(inner.cols);
        let ptr = if size == 0 {
            0usize
        } else {
            inner.device_ptr() as usize
        };
        // `(pointer, read_only)`: `false` marks the buffer as writable.
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device query returning `(device_type, device_id)`; `2` is
    // DLPack's `kDLCUDA`. The ordinal is asked from the CUDA driver for the
    // buffer's pointer; if that call fails or yields a negative ordinal (e.g.
    // possibly for the empty placeholder left behind by `__dlpack__`), the
    // stored `device_id` is reported instead.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        unsafe {
            use cust::sys::cuPointerGetAttribute;
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            // Sentinel: stays -1 unless the driver writes an ordinal.
            let mut dev_ordinal: i32 = -1;
            let res = cuPointerGetAttribute(
                &mut dev_ordinal as *mut _ as *mut std::ffi::c_void,
                attr,
                self.inner.device_ptr(),
            );
            if res == cust::sys::CUresult::CUDA_SUCCESS && dev_ordinal >= 0 {
                return Ok((2, dev_ordinal));
            }
            Ok((2, self.device_id as i32))
        }
    }

    // Exports the matrix as a DLPack capsule, moving the device buffer out of
    // `self` and handing it to the export helper.
    //
    // * `dl_device`: when it extracts as `(i32, i32)` and differs from this
    //   array's device, a ValueError is raised (cross-device copy is not
    //   implemented); values that do not extract are ignored.
    // * `copy`: only consulted on a device mismatch, to choose the error text.
    // * `stream`: accepted for protocol compatibility but unused.
    // * `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    //
    // Afterwards `self.inner` is an empty 0x0 placeholder.
    // NOTE(review): a second `__dlpack__` call therefore exports an empty
    // tensor instead of raising — confirm consumers never re-export.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // Stream synchronization is not performed.
        let _ = stream;

        // Swap an empty buffer into `self` so ownership of the real buffer can
        // be moved into the capsule.
        let dummy = cust::memory::DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Cuda {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
2149
// Python class `DxStream`: incremental DX over a stream of (high, low, close)
// bars, backed by the Rust `DxStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DxStream")]
pub struct DxStreamPy {
    inner: DxStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DxStreamPy {
    // Builds a stream for `period`; parameter errors are raised as ValueError.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        DxStream::try_new(DxParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar; returns whatever the underlying stream yields (None while
    // no value is available).
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.inner.update(high, low, close)
    }
}
2173
2174#[inline]
2175pub fn dx_into_slice(dst: &mut [f64], input: &DxInput, kern: Kernel) -> Result<(), DxError> {
2176    let (h, l, c, len, first, chosen) = dx_prepare(input, kern)?;
2177    if dst.len() != len {
2178        return Err(DxError::OutputLengthMismatch {
2179            expected: len,
2180            got: dst.len(),
2181        });
2182    }
2183    unsafe {
2184        match chosen {
2185            Kernel::Scalar | Kernel::ScalarBatch => {
2186                dx_scalar(h, l, c, input.get_period(), first, dst)
2187            }
2188            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2189            Kernel::Avx2 | Kernel::Avx2Batch => dx_avx2(h, l, c, input.get_period(), first, dst),
2190            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2191            Kernel::Avx512 | Kernel::Avx512Batch => {
2192                dx_avx512(h, l, c, input.get_period(), first, dst)
2193            }
2194            _ => unreachable!(),
2195        }
2196    }
2197    let warm = first + input.get_period() - 1;
2198    for v in &mut dst[..warm] {
2199        *v = f64::NAN;
2200    }
2201    Ok(())
2202}
2203
/// WASM binding: computes DX over `high`/`low`/`close` and returns a new array
/// whose length is the shortest of the three inputs.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = DxParams {
        period: Some(period),
    };
    let input = DxInput::from_hlc_slices(high, low, close, params);
    let len = high.len().min(low.len()).min(close.len());
    let mut values = vec![0.0; len];
    match dx_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2220
/// WASM binding: computes DX from three `len`-long input buffers into
/// `out_ptr` (also `len` long).
///
/// When `out_ptr` is the same pointer as one of the inputs, the result is
/// computed into a scratch vector first and then copied over.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_null = h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("null pointer"));
    }

    let to_js = |e: DxError| JsValue::from_str(&e.to_string());

    unsafe {
        let high = core::slice::from_raw_parts(h_ptr, len);
        let low = core::slice::from_raw_parts(l_ptr, len);
        let close = core::slice::from_raw_parts(c_ptr, len);
        let input = DxInput::from_hlc_slices(
            high,
            low,
            close,
            DxParams {
                period: Some(period),
            },
        );

        let out_aliases_input = [h_ptr, l_ptr, c_ptr]
            .iter()
            .any(|&p| p as *mut f64 == out_ptr);

        if out_aliases_input {
            let mut scratch = vec![0.0; len];
            dx_into_slice(&mut scratch, &input, detect_best_kernel()).map_err(to_js)?;
            core::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = core::slice::from_raw_parts_mut(out_ptr, len);
            dx_into_slice(out, &input, detect_best_kernel()).map_err(to_js)?;
        }
    }
    Ok(())
}
2272
/// WASM helper: reserves room for `len` f64 values and hands the raw pointer
/// to JS. The memory is uninitialized; release it with `dx_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut storage = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2281
/// WASM helper: frees a buffer obtained from `dx_alloc(len)`. Null is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `dx_alloc(len)`, whose Vec had capacity `len`.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2291
/// Configuration object deserialized from the JS `config` argument of the
/// `dx_batch` WASM binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DxBatchConfig {
    /// `(start, end, step)` sweep over the DX period, passed to
    /// `DxBatchRange::from_tuple`.
    pub period_range: (usize, usize, usize),
}
2297
/// Result object serialized back to JS by the `dx_batch` WASM binding.
/// Field declaration order is the serialization order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DxBatchJsOutput {
    /// Flattened `rows * cols` matrix, one row of DX values per combination.
    pub values: Vec<f64>,
    /// Parameters used for each row, in row order.
    pub combos: Vec<DxParams>,
    /// Number of parameter combinations (rows of `values`).
    pub rows: usize,
    /// Values per row: the shortest of the three input lengths.
    pub cols: usize,
}
2306
/// WASM binding `dx_batch(high, low, close, config)`: computes DX for every
/// period in `config.period_range` and returns a [`DxBatchJsOutput`].
///
/// `values` is a flattened `rows * cols` matrix (one row per period), where
/// `cols` is the shortest input length. Each row's first `first + period - 1`
/// entries are NaN warm-up, `first` being the first index where all three
/// inputs are non-NaN.
///
/// # Errors
/// Returns a JS string error for an invalid config, an invalid period grid,
/// all-NaN input, or any failure reported by the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "dx_batch")]
pub fn dx_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: DxBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = DxBatchRange::from_tuple(cfg.period_range);

    // Expand the grid once: it drives both the row count and the warm-up lengths.
    let grid = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = grid.len();
    let cols = high.len().min(low.len()).min(close.len());

    let first = (0..cols)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
        .ok_or_else(|| JsValue::from_str("AllValuesNaN"))?;
    // Clamp so degenerate periods (0, or longer than the data) can never hand
    // init_matrix_prefixes an underflowed or out-of-row length; the kernel call
    // below still rejects such periods with a proper error.
    let warm: Vec<usize> = grid
        .iter()
        .map(|p| (first + p.period.unwrap()).saturating_sub(1).min(cols))
        .collect();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // SAFETY: MaybeUninit<f64> has the same layout as f64 and the slice covers
    // exactly the `rows * cols` cells owned by `buf_mu`.
    let out_slice: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len())
    };
    // `buf_mu` is still an owned Vec here, so an early `?` frees the matrix
    // instead of leaking it.
    let combos = dx_batch_inner_into(
        high,
        low,
        close,
        &sweep,
        detect_best_batch_kernel(),
        false,
        out_slice,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Hand the allocation over to a Vec<f64> without freeing it.
    let mut buf_mu = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: warm-up prefixes were written by init_matrix_prefixes and the
    // remaining cells by dx_batch_inner_into (assumed to fill every cell past
    // each row's warm-up); pointer, length and capacity come from the same Vec.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_mu.as_mut_ptr() as *mut f64,
            buf_mu.len(),
            buf_mu.capacity(),
        )
    };

    let js = DxBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2364
/// WASM binding: computes DX for every period in
/// `(period_start, period_end, period_step)` and writes the flattened
/// `rows * len` matrix (one row per period) into `out_ptr`.
///
/// The caller must provide three `len`-long input buffers and an output buffer
/// of `rows * len` f64 values, where `rows` is the number of combinations the
/// sweep expands to. Each row's first `first + period - 1` entries are NaN,
/// `first` being the first index where all three inputs are non-NaN.
///
/// If the output range overlaps any input range, the result is computed into
/// a temporary buffer first and then copied; otherwise it is written directly
/// into the caller's memory.
///
/// # Errors
/// Returns a JS string error for null pointers, an invalid period grid,
/// all-NaN input, a size overflow, or any failure from the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DxBatchRange::from_tuple((period_start, period_end, period_step));
    // Use the same grid expansion as the kernel itself, so the caller-visible
    // output size always matches what the batch routines produce.
    let rows = expand_grid_checked(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows * len overflows usize"))?;

    // Detect any byte-range overlap between the output and an input, not just
    // identical start pointers.
    let elem = core::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        let end = start.saturating_add(len.saturating_mul(elem));
        start < out_end && out_start < end
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: the caller guarantees each input pointer addresses `len` readable
    // f64 values for the duration of this call.
    let (high, low, close) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };

    if aliased {
        let result = dx_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        if result.values.len() != total {
            return Err(JsValue::from_str("output size mismatch"));
        }
        // SAFETY: the caller guarantees `out_ptr` addresses `total` writable f64
        // values; the input slices are not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&result.values);
    } else {
        let first = high
            .iter()
            .zip(low)
            .zip(close)
            .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
            .ok_or_else(|| JsValue::from_str("All values are NaN"))?;

        // SAFETY: the caller guarantees `out_ptr` addresses `total` writable f64
        // values, and the overlap check above proved they are disjoint from the
        // inputs.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        let combos = dx_batch_inner_into(
            high,
            low,
            close,
            &sweep,
            detect_best_batch_kernel(),
            false,
            out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // Apply the NaN warm-up prefix of every row (the other batch entry
        // points initialise these before the kernel runs). `len > 0` here
        // because `first` was found.
        for (row, params) in out.chunks_mut(len).zip(&combos) {
            let warm = (first + params.period.unwrap()).saturating_sub(1).min(len);
            row[..warm].fill(f64::NAN);
        }
    }
    Ok(())
}