Skip to main content

vector_ta/indicators/
dm.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::dm_wrapper::CudaDm;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
5use crate::utilities::data_loader::Candles;
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
9    make_uninit_matrix,
10};
11#[cfg(feature = "python")]
12use crate::utilities::kernel_validation::validate_kernel;
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(not(target_arch = "wasm32"))]
16use rayon::prelude::*;
17use std::mem::MaybeUninit;
18use thiserror::Error;
19
/// Where the high/low series for a DM calculation come from.
///
/// Both variants only borrow their data for the lifetime `'a`.
#[derive(Debug, Clone)]
pub enum DmData<'a> {
    /// Read the `"high"` and `"low"` fields from a borrowed candle container.
    Candles { candles: &'a Candles },
    /// Use caller-provided high and low slices directly.
    Slices { high: &'a [f64], low: &'a [f64] },
}
25
/// Result of a single DM run: one value per input bar for each direction.
#[derive(Clone, Debug)]
pub struct DmOutput {
    /// Smoothed positive directional movement series.
    pub plus: Vec<f64>,
    /// Smoothed negative directional movement series.
    pub minus: Vec<f64>,
}
31
/// Tunable parameters of the DM indicator.
#[derive(Clone, Debug)]
pub struct DmParams {
    /// Smoothing period; `None` means "use the default period".
    pub period: Option<usize>,
}

impl Default for DmParams {
    /// Default parameters: a period of 14.
    fn default() -> Self {
        DmParams { period: Some(14) }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct DmInput<'a> {
45    pub data: DmData<'a>,
46    pub params: DmParams,
47}
48
49impl<'a> DmInput<'a> {
50    #[inline]
51    pub fn from_candles(candles: &'a Candles, params: DmParams) -> Self {
52        Self {
53            data: DmData::Candles { candles },
54            params,
55        }
56    }
57    #[inline]
58    pub fn from_slices(high: &'a [f64], low: &'a [f64], params: DmParams) -> Self {
59        Self {
60            data: DmData::Slices { high, low },
61            params,
62        }
63    }
64    #[inline]
65    pub fn with_default_candles(candles: &'a Candles) -> Self {
66        Self {
67            data: DmData::Candles { candles },
68            params: DmParams::default(),
69        }
70    }
71    #[inline]
72    pub fn get_period(&self) -> usize {
73        self.params
74            .period
75            .unwrap_or_else(|| DmParams::default().period.unwrap())
76    }
77}
78
79#[derive(Copy, Clone, Debug)]
80pub struct DmBuilder {
81    period: Option<usize>,
82    kernel: Kernel,
83}
84
85impl Default for DmBuilder {
86    fn default() -> Self {
87        Self {
88            period: None,
89            kernel: Kernel::Auto,
90        }
91    }
92}
93
94impl DmBuilder {
95    #[inline(always)]
96    pub fn new() -> Self {
97        Self::default()
98    }
99    #[inline(always)]
100    pub fn period(mut self, n: usize) -> Self {
101        self.period = Some(n);
102        self
103    }
104    #[inline(always)]
105    pub fn kernel(mut self, k: Kernel) -> Self {
106        self.kernel = k;
107        self
108    }
109
110    #[inline(always)]
111    pub fn apply(self, candles: &Candles) -> Result<DmOutput, DmError> {
112        let p = DmParams {
113            period: self.period,
114        };
115        let i = DmInput::from_candles(candles, p);
116        dm_with_kernel(&i, self.kernel)
117    }
118
119    #[inline(always)]
120    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DmOutput, DmError> {
121        let p = DmParams {
122            period: self.period,
123        };
124        let i = DmInput::from_slices(high, low, p);
125        dm_with_kernel(&i, self.kernel)
126    }
127
128    #[inline(always)]
129    pub fn into_stream(self) -> Result<DmStream, DmError> {
130        let p = DmParams {
131            period: self.period,
132        };
133        DmStream::try_new(p)
134    }
135}
136
/// Errors reported by the DM indicator entry points in this module.
#[derive(Debug, Error)]
pub enum DmError {
    /// The input is empty, the high/low lengths differ, or a candle field
    /// could not be selected.
    #[error("dm: Empty data provided (or high/low length mismatch).")]
    EmptyInputData,
    /// The period is zero or larger than the input length.
    #[error("dm: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` bars remain from the first bar whose high and low
    /// are both non-NaN.
    #[error("dm: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// No bar has both a non-NaN high and a non-NaN low.
    #[error("dm: All values are NaN.")]
    AllValuesNaN,
    /// A caller-supplied output buffer does not match the input length.
    #[error("dm: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A `(start, end, step)` period sweep could not be expanded.
    #[error("dm: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("dm: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
158
159#[inline]
160pub fn dm(input: &DmInput) -> Result<DmOutput, DmError> {
161    dm_with_kernel(input, Kernel::Auto)
162}
163
164#[inline(always)]
165fn dm_prepare<'a>(
166    input: &'a DmInput,
167    kernel: Kernel,
168) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), DmError> {
169    let (high, low) = match &input.data {
170        DmData::Candles { candles } => {
171            let h = candles
172                .select_candle_field("high")
173                .map_err(|_| DmError::EmptyInputData)?;
174            let l = candles
175                .select_candle_field("low")
176                .map_err(|_| DmError::EmptyInputData)?;
177            (h, l)
178        }
179        DmData::Slices { high, low } => (*high, *low),
180    };
181
182    if high.is_empty() || low.is_empty() || high.len() != low.len() {
183        return Err(DmError::EmptyInputData);
184    }
185
186    let period = input.get_period();
187    if period == 0 || period > high.len() {
188        return Err(DmError::InvalidPeriod {
189            period,
190            data_len: high.len(),
191        });
192    }
193
194    let first = high
195        .iter()
196        .zip(low.iter())
197        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
198        .ok_or(DmError::AllValuesNaN)?;
199
200    if high.len() - first < period {
201        return Err(DmError::NotEnoughValidData {
202            needed: period,
203            valid: high.len() - first,
204        });
205    }
206
207    let chosen = match kernel {
208        Kernel::Auto => Kernel::Scalar,
209        k => k,
210    };
211    Ok((high, low, period, first, chosen))
212}
213
/// Scalar DM kernel writing into caller-provided buffers.
///
/// The first output is written at index `first + period - 1` and holds the
/// plain sum of the directional moves over the preceding `period - 1` bar
/// transitions. Every later index `j` applies the recurrence
/// `sum = sum - sum / period + move`. Indices before `first + period - 1`
/// are left untouched.
///
/// Callers must guarantee `first + period - 1 < high.len()` and that both
/// output slices are at least `high.len()` long: all element accesses are
/// unchecked.
#[inline(always)]
fn dm_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    let len = high.len();
    if len == 0 {
        return;
    }

    let seed_idx = first + period - 1;
    let mut acc_plus = 0.0f64;
    let mut acc_minus = 0.0f64;

    // SAFETY: relies on the caller contract documented above; every index
    // used below lies in `first..len`.
    unsafe {
        let mut last_high = *high.get_unchecked(first);
        let mut last_low = *low.get_unchecked(first);

        // Warm-up: accumulate raw directional moves up to the seed index.
        for idx in (first + 1)..=seed_idx {
            let cur_high = *high.get_unchecked(idx);
            let cur_low = *low.get_unchecked(idx);
            let up = cur_high - last_high;
            let down = last_low - cur_low;
            last_high = cur_high;
            last_low = cur_low;

            if up > 0.0 && up > down {
                acc_plus += up;
            } else if down > 0.0 && down > up {
                acc_minus += down;
            }
        }

        *plus_out.get_unchecked_mut(seed_idx) = acc_plus;
        *minus_out.get_unchecked_mut(seed_idx) = acc_minus;

        if seed_idx + 1 >= len {
            return;
        }
        let recip = 1.0 / (period as f64);

        // Smoothing phase: one recurrence step per remaining bar.
        for idx in (seed_idx + 1)..len {
            let cur_high = *high.get_unchecked(idx);
            let cur_low = *low.get_unchecked(idx);
            let up = cur_high - last_high;
            let down = last_low - cur_low;
            last_high = cur_high;
            last_low = cur_low;

            // At most one direction contributes per bar.
            let (add_plus, add_minus) = if up > 0.0 && up > down {
                (up, 0.0)
            } else if down > 0.0 && down > up {
                (0.0, down)
            } else {
                (0.0, 0.0)
            };

            #[cfg(target_feature = "fma")]
            {
                acc_plus = (-recip).mul_add(acc_plus, acc_plus + add_plus);
                acc_minus = (-recip).mul_add(acc_minus, acc_minus + add_minus);
            }
            #[cfg(not(target_feature = "fma"))]
            {
                acc_plus = acc_plus - (acc_plus * recip) + add_plus;
                acc_minus = acc_minus - (acc_minus * recip) + add_minus;
            }

            *plus_out.get_unchecked_mut(idx) = acc_plus;
            *minus_out.get_unchecked_mut(idx) = acc_minus;
        }
    }
}
299
300#[inline(always)]
301fn dm_compute_into(
302    high: &[f64],
303    low: &[f64],
304    period: usize,
305    first: usize,
306    kernel: Kernel,
307    plus_out: &mut [f64],
308    minus_out: &mut [f64],
309) {
310    match kernel {
311        Kernel::Scalar | Kernel::ScalarBatch => {
312            dm_compute_into_scalar(high, low, period, first, plus_out, minus_out)
313        }
314        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315        Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
316            dm_compute_into_avx2(high, low, period, first, plus_out, minus_out)
317        },
318        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
319        Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
320            dm_compute_into_avx512(high, low, period, first, plus_out, minus_out)
321        },
322        _ => unreachable!(),
323    }
324}
325
/// AVX2 DM kernel writing into caller-provided buffers.
///
/// Writes the warm-up sums at index `first + period - 1`, then one smoothed
/// value per later index using `sum = sum - sum / period + move`. Indices
/// before `first + period - 1` are not written.
///
/// The bar-to-bar differences are computed four at a time with 256-bit
/// vectors; the smoothing recurrence itself is applied one lane after the
/// other because each step depends on the previous sum.
///
/// # Safety
/// The CPU must support AVX2. `first + period - 1` must be a valid index into
/// `high`/`low`, and `plus_out`/`minus_out` must be at least `high.len()`
/// long: all loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dm_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    debug_assert_eq!(high.len(), low.len());
    let n = high.len();
    if n == 0 {
        return;
    }

    // Index of the first written output (end of the warm-up window).
    let end_init = first + period - 1;
    let inv_p = 1.0 / (period as f64);
    let zero = _mm256_setzero_pd();

    let mut sum_plus = 0.0f64;
    let mut sum_minus = 0.0f64;
    let mut i = first + 1;
    let warm_stop = end_init + 1;
    // Warm-up, vector part: four bar transitions per iteration.
    while i + 4 <= warm_stop {
        // Up-move: current high minus previous high, per lane.
        let hc = _mm256_loadu_pd(high.as_ptr().add(i));
        let hp = _mm256_loadu_pd(high.as_ptr().add(i - 1));
        let dp = _mm256_sub_pd(hc, hp);

        // Down-move: previous low minus current low, per lane.
        let lp = _mm256_loadu_pd(low.as_ptr().add(i - 1));
        let lc = _mm256_loadu_pd(low.as_ptr().add(i));
        let dm = _mm256_sub_pd(lp, lc);

        let dp_pos = _mm256_max_pd(dp, zero);
        let dm_pos = _mm256_max_pd(dm, zero);

        // Keep a lane's move only where it is strictly greater than the
        // opposite move; all other lanes become zero.
        let p_mask = _mm256_cmp_pd(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm256_cmp_pd(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm256_and_pd(dp_pos, p_mask);
        let m_vec = _mm256_and_pd(dm_pos, m_mask);

        // Spill to the stack and add the four lanes to the running sums.
        let mut p_buf = [0.0f64; 4];
        let mut m_buf = [0.0f64; 4];
        _mm256_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm256_storeu_pd(m_buf.as_mut_ptr(), m_vec);
        sum_plus += p_buf.iter().sum::<f64>();
        sum_minus += m_buf.iter().sum::<f64>();
        i += 4;
    }
    // Warm-up, scalar tail for the remaining (< 4) transitions.
    while i < warm_stop {
        let dp = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let dm = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if dp > 0.0 && dp > dm {
            sum_plus += dp;
        } else if dm > 0.0 && dm > dp {
            sum_minus += dm;
        }
        i += 1;
    }

    *plus_out.get_unchecked_mut(end_init) = sum_plus;
    *minus_out.get_unchecked_mut(end_init) = sum_minus;

    if end_init + 1 >= n {
        return;
    }

    let mut j = end_init + 1;
    // Smoothing, vector part: compute four moves at once, then apply the
    // recurrence lane by lane.
    while j + 4 <= n {
        let hc = _mm256_loadu_pd(high.as_ptr().add(j));
        let hp = _mm256_loadu_pd(high.as_ptr().add(j - 1));
        let dp = _mm256_sub_pd(hc, hp);

        let lp = _mm256_loadu_pd(low.as_ptr().add(j - 1));
        let lc = _mm256_loadu_pd(low.as_ptr().add(j));
        let dm = _mm256_sub_pd(lp, lc);

        let dp_pos = _mm256_max_pd(dp, zero);
        let dm_pos = _mm256_max_pd(dm, zero);

        let p_mask = _mm256_cmp_pd(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm256_cmp_pd(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm256_and_pd(dp_pos, p_mask);
        let m_vec = _mm256_and_pd(dm_pos, m_mask);

        let mut p_buf = [0.0f64; 4];
        let mut m_buf = [0.0f64; 4];
        _mm256_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm256_storeu_pd(m_buf.as_mut_ptr(), m_vec);

        // Compile-time choice: fused multiply-add when the build target has
        // FMA, separate multiply/subtract/add otherwise.
        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[0]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[0]);
            *plus_out.get_unchecked_mut(j) = sum_plus;
            *minus_out.get_unchecked_mut(j) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[1]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[1]);
            *plus_out.get_unchecked_mut(j + 1) = sum_plus;
            *minus_out.get_unchecked_mut(j + 1) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[2]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[2]);
            *plus_out.get_unchecked_mut(j + 2) = sum_plus;
            *minus_out.get_unchecked_mut(j + 2) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[3]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[3]);
            *plus_out.get_unchecked_mut(j + 3) = sum_plus;
            *minus_out.get_unchecked_mut(j + 3) = sum_minus;
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[0];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[0];
            *plus_out.get_unchecked_mut(j) = sum_plus;
            *minus_out.get_unchecked_mut(j) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[1];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[1];
            *plus_out.get_unchecked_mut(j + 1) = sum_plus;
            *minus_out.get_unchecked_mut(j + 1) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[2];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[2];
            *plus_out.get_unchecked_mut(j + 2) = sum_plus;
            *minus_out.get_unchecked_mut(j + 2) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[3];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[3];
            *plus_out.get_unchecked_mut(j + 3) = sum_plus;
            *minus_out.get_unchecked_mut(j + 3) = sum_minus;
        }
        j += 4;
    }

    // Smoothing, scalar tail for the remaining (< 4) bars.
    while j < n {
        let dp = *high.get_unchecked(j) - *high.get_unchecked(j - 1);
        let dm = *low.get_unchecked(j - 1) - *low.get_unchecked(j);

        let (p, m) = if dp > 0.0 && dp > dm {
            (dp, 0.0)
        } else if dm > 0.0 && dm > dp {
            (0.0, dm)
        } else {
            (0.0, 0.0)
        };

        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m);
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p;
            sum_minus = sum_minus - (sum_minus * inv_p) + m;
        }
        *plus_out.get_unchecked_mut(j) = sum_plus;
        *minus_out.get_unchecked_mut(j) = sum_minus;
        j += 1;
    }
}
491
/// AVX-512 DM kernel writing into caller-provided buffers.
///
/// Writes the warm-up sums at index `first + period - 1`, then one smoothed
/// value per later index using `sum = sum - sum / period + move`. Indices
/// before `first + period - 1` are not written.
///
/// Bar-to-bar differences are computed eight at a time with 512-bit vectors;
/// the smoothing recurrence is applied lane by lane because each step depends
/// on the previous sum.
///
/// # Safety
/// The CPU must support AVX-512F. `first + period - 1` must be a valid index
/// into `high`/`low`, and `plus_out`/`minus_out` must be at least
/// `high.len()` long: all loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn dm_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    debug_assert_eq!(high.len(), low.len());
    let n = high.len();
    if n == 0 {
        return;
    }

    // Index of the first written output (end of the warm-up window).
    let end_init = first + period - 1;
    let inv_p = 1.0 / (period as f64);
    let zero = _mm512_set1_pd(0.0);

    let mut sum_plus = 0.0f64;
    let mut sum_minus = 0.0f64;
    let mut i = first + 1;
    let warm_stop = end_init + 1;
    // Warm-up, vector part: eight bar transitions per iteration.
    while i + 8 <= warm_stop {
        // Up-move: current high minus previous high, per lane.
        let hc = _mm512_loadu_pd(high.as_ptr().add(i));
        let hp = _mm512_loadu_pd(high.as_ptr().add(i - 1));
        let dp = _mm512_sub_pd(hc, hp);

        // Down-move: previous low minus current low, per lane.
        let lp = _mm512_loadu_pd(low.as_ptr().add(i - 1));
        let lc = _mm512_loadu_pd(low.as_ptr().add(i));
        let dm = _mm512_sub_pd(lp, lc);

        let dp_pos = _mm512_max_pd(dp, zero);
        let dm_pos = _mm512_max_pd(dm, zero);

        // Keep a lane's move only where it is strictly greater than the
        // opposite move; masked-out lanes are zeroed.
        let p_mask = _mm512_cmp_pd_mask(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm512_cmp_pd_mask(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm512_maskz_mov_pd(p_mask, dp_pos);
        let m_vec = _mm512_maskz_mov_pd(m_mask, dm_pos);

        // Spill to the stack and add the eight lanes to the running sums.
        let mut p_buf = [0.0f64; 8];
        let mut m_buf = [0.0f64; 8];
        _mm512_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm512_storeu_pd(m_buf.as_mut_ptr(), m_vec);
        for k in 0..8 {
            sum_plus += p_buf[k];
            sum_minus += m_buf[k];
        }
        i += 8;
    }
    // Warm-up, scalar tail for the remaining (< 8) transitions.
    while i < warm_stop {
        let dp = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let dm = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if dp > 0.0 && dp > dm {
            sum_plus += dp;
        } else if dm > 0.0 && dm > dp {
            sum_minus += dm;
        }
        i += 1;
    }
    *plus_out.get_unchecked_mut(end_init) = sum_plus;
    *minus_out.get_unchecked_mut(end_init) = sum_minus;

    if end_init + 1 >= n {
        return;
    }

    let mut j = end_init + 1;
    // Smoothing, vector part: compute eight moves at once, then apply the
    // recurrence lane by lane.
    while j + 8 <= n {
        let hc = _mm512_loadu_pd(high.as_ptr().add(j));
        let hp = _mm512_loadu_pd(high.as_ptr().add(j - 1));
        let dp = _mm512_sub_pd(hc, hp);

        let lp = _mm512_loadu_pd(low.as_ptr().add(j - 1));
        let lc = _mm512_loadu_pd(low.as_ptr().add(j));
        let dm = _mm512_sub_pd(lp, lc);

        let dp_pos = _mm512_max_pd(dp, zero);
        let dm_pos = _mm512_max_pd(dm, zero);

        let p_mask = _mm512_cmp_pd_mask(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm512_cmp_pd_mask(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm512_maskz_mov_pd(p_mask, dp_pos);
        let m_vec = _mm512_maskz_mov_pd(m_mask, dm_pos);

        let mut p_buf = [0.0f64; 8];
        let mut m_buf = [0.0f64; 8];
        _mm512_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm512_storeu_pd(m_buf.as_mut_ptr(), m_vec);

        // Compile-time choice: fused multiply-add when the build target has
        // FMA, separate multiply/subtract/add otherwise.
        #[cfg(target_feature = "fma")]
        {
            for t in 0..8 {
                sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[t]);
                sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[t]);
                *plus_out.get_unchecked_mut(j + t) = sum_plus;
                *minus_out.get_unchecked_mut(j + t) = sum_minus;
            }
        }
        #[cfg(not(target_feature = "fma"))]
        {
            for t in 0..8 {
                sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[t];
                sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[t];
                *plus_out.get_unchecked_mut(j + t) = sum_plus;
                *minus_out.get_unchecked_mut(j + t) = sum_minus;
            }
        }
        j += 8;
    }
    // Smoothing, scalar tail for the remaining (< 8) bars.
    while j < n {
        let dp = *high.get_unchecked(j) - *high.get_unchecked(j - 1);
        let dm = *low.get_unchecked(j - 1) - *low.get_unchecked(j);

        let (p, m) = if dp > 0.0 && dp > dm {
            (dp, 0.0)
        } else if dm > 0.0 && dm > dp {
            (0.0, dm)
        } else {
            (0.0, 0.0)
        };

        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m);
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p;
            sum_minus = sum_minus - (sum_minus * inv_p) + m;
        }
        *plus_out.get_unchecked_mut(j) = sum_plus;
        *minus_out.get_unchecked_mut(j) = sum_minus;
        j += 1;
    }
}
631
632pub fn dm_with_kernel(input: &DmInput, kernel: Kernel) -> Result<DmOutput, DmError> {
633    let (high, low, period, first, chosen) = dm_prepare(input, kernel)?;
634    let warm = first + period - 1;
635
636    let mut plus = alloc_with_nan_prefix(high.len(), warm);
637    let mut minus = alloc_with_nan_prefix(high.len(), warm);
638
639    dm_compute_into(high, low, period, first, chosen, &mut plus, &mut minus);
640    Ok(DmOutput { plus, minus })
641}
642
643#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
644#[inline]
645pub fn dm_into(
646    input: &DmInput,
647    plus_out: &mut [f64],
648    minus_out: &mut [f64],
649) -> Result<(), DmError> {
650    let (high, low, period, first, chosen) = dm_prepare(input, Kernel::Auto)?;
651
652    if plus_out.len() != high.len() {
653        return Err(DmError::OutputLengthMismatch {
654            expected: high.len(),
655            got: plus_out.len(),
656        });
657    }
658    if minus_out.len() != high.len() {
659        return Err(DmError::OutputLengthMismatch {
660            expected: high.len(),
661            got: minus_out.len(),
662        });
663    }
664
665    let warm = first + period - 1;
666    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
667    let warm_end = warm.min(high.len());
668    for v in &mut plus_out[..warm_end] {
669        *v = qnan;
670    }
671    for v in &mut minus_out[..warm_end] {
672        *v = qnan;
673    }
674
675    dm_compute_into(high, low, period, first, chosen, plus_out, minus_out);
676    Ok(())
677}
678
679#[inline]
680pub fn dm_into_slice(
681    plus_dst: &mut [f64],
682    minus_dst: &mut [f64],
683    input: &DmInput,
684    kernel: Kernel,
685) -> Result<(), DmError> {
686    let (high, low, period, first, chosen) = dm_prepare(input, kernel)?;
687    if plus_dst.len() != high.len() {
688        return Err(DmError::OutputLengthMismatch {
689            expected: high.len(),
690            got: plus_dst.len(),
691        });
692    }
693    if minus_dst.len() != high.len() {
694        return Err(DmError::OutputLengthMismatch {
695            expected: high.len(),
696            got: minus_dst.len(),
697        });
698    }
699
700    dm_compute_into(high, low, period, first, chosen, plus_dst, minus_dst);
701
702    let warm = first + period - 1;
703    for v in &mut plus_dst[..warm] {
704        *v = f64::NAN;
705    }
706    for v in &mut minus_dst[..warm] {
707        *v = f64::NAN;
708    }
709    Ok(())
710}
711
712#[inline]
713pub unsafe fn dm_scalar(
714    high: &[f64],
715    low: &[f64],
716    period: usize,
717    first_valid_idx: usize,
718) -> Result<DmOutput, DmError> {
719    let warm = first_valid_idx + period - 1;
720    let mut plus_dm = alloc_with_nan_prefix(high.len(), warm);
721    let mut minus_dm = alloc_with_nan_prefix(high.len(), warm);
722
723    dm_compute_into_scalar(
724        high,
725        low,
726        period,
727        first_valid_idx,
728        &mut plus_dm,
729        &mut minus_dm,
730    );
731
732    Ok(DmOutput {
733        plus: plus_dm,
734        minus: minus_dm,
735    })
736}
737
/// Runs the AVX2 kernel on raw slices and returns freshly allocated outputs.
///
/// No validation is performed here.
///
/// # Safety
/// The CPU must support AVX2, `high.len()` must equal `low.len()`, and for
/// non-empty input `first_valid_idx + period - 1` must be a valid index.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let len = high.len();
    let warmup = first_valid_idx + period - 1;
    let mut out = DmOutput {
        plus: alloc_with_nan_prefix(len, warmup),
        minus: alloc_with_nan_prefix(len, warmup),
    };
    dm_compute_into_avx2(
        high,
        low,
        period,
        first_valid_idx,
        &mut out.plus,
        &mut out.minus,
    );
    Ok(out)
}
762
/// Runs the AVX-512 kernel on raw slices and returns freshly allocated outputs.
///
/// No validation is performed here.
///
/// # Safety
/// The CPU must support AVX-512F, `high.len()` must equal `low.len()`, and
/// for non-empty input `first_valid_idx + period - 1` must be a valid index.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let len = high.len();
    let warmup = first_valid_idx + period - 1;
    let mut out = DmOutput {
        plus: alloc_with_nan_prefix(len, warmup),
        minus: alloc_with_nan_prefix(len, warmup),
    };
    dm_compute_into_avx512(
        high,
        low,
        period,
        first_valid_idx,
        &mut out.plus,
        &mut out.minus,
    );
    Ok(out)
}
787
/// Short-period AVX-512 entry point; forwards unchanged to [`dm_avx512`].
///
/// # Safety
/// Same requirements as [`dm_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let result = dm_avx512(high, low, period, first_valid_idx);
    result
}
798
/// Long-period AVX-512 entry point; forwards unchanged to [`dm_avx512`].
///
/// # Safety
/// Same requirements as [`dm_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let result = dm_avx512(high, low, period, first_valid_idx);
    result
}
809
810#[derive(Debug, Clone)]
811pub struct DmStream {
812    period: usize,
813    inv_period: f64,
814    sum_plus: f64,
815    sum_minus: f64,
816    prev_high: f64,
817    prev_low: f64,
818    count: usize,
819}
820
821impl DmStream {
822    pub fn try_new(params: DmParams) -> Result<Self, DmError> {
823        let period = params.period.unwrap_or(14);
824        if period == 0 {
825            return Err(DmError::InvalidPeriod {
826                period,
827                data_len: 0,
828            });
829        }
830        let inv = 1.0 / (period as f64);
831        Ok(Self {
832            period,
833            inv_period: inv,
834            sum_plus: 0.0,
835            sum_minus: 0.0,
836            prev_high: f64::NAN,
837            prev_low: f64::NAN,
838            count: 0,
839        })
840    }
841
842    #[inline(always)]
843    pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
844        if self.count == 0 {
845            self.prev_high = high;
846            self.prev_low = low;
847        }
848
849        let dp = high - self.prev_high;
850        let dm = self.prev_low - low;
851
852        self.prev_high = high;
853        self.prev_low = low;
854
855        let dp_pos = dp.max(0.0);
856        let dm_pos = dm.max(0.0);
857
858        let plus_val = if dp_pos > dm_pos { dp_pos } else { 0.0 };
859        let minus_val = if dm_pos > dp_pos { dm_pos } else { 0.0 };
860
861        if self.count < self.period - 1 {
862            self.sum_plus += plus_val;
863            self.sum_minus += minus_val;
864            self.count += 1;
865            return None;
866        } else if self.count == self.period - 1 {
867            self.sum_plus += plus_val;
868            self.sum_minus += minus_val;
869            self.count += 1;
870            return Some((self.sum_plus, self.sum_minus));
871        }
872
873        #[cfg(target_feature = "fma")]
874        {
875            self.sum_plus = (-self.inv_period).mul_add(self.sum_plus, self.sum_plus + plus_val);
876            self.sum_minus = (-self.inv_period).mul_add(self.sum_minus, self.sum_minus + minus_val);
877        }
878        #[cfg(not(target_feature = "fma"))]
879        {
880            self.sum_plus = self.sum_plus - (self.sum_plus * self.inv_period) + plus_val;
881            self.sum_minus = self.sum_minus - (self.sum_minus * self.inv_period) + minus_val;
882        }
883
884        Some((self.sum_plus, self.sum_minus))
885    }
886}
887
/// Period sweep for batch runs, expressed as `(start, end, step)`.
#[derive(Debug, Clone)]
pub struct DmBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for DmBatchRange {
    /// Default sweep: periods from 14 to 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        DmBatchRange { period }
    }
}
900
901#[derive(Clone, Debug, Default)]
902pub struct DmBatchBuilder {
903    range: DmBatchRange,
904    kernel: Kernel,
905}
906
907impl DmBatchBuilder {
908    pub fn new() -> Self {
909        Self::default()
910    }
911    pub fn kernel(mut self, k: Kernel) -> Self {
912        self.kernel = k;
913        self
914    }
915    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
916        self.range.period = (start, end, step);
917        self
918    }
919    pub fn period_static(mut self, p: usize) -> Self {
920        self.range.period = (p, p, 0);
921        self
922    }
923    pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DmBatchOutput, DmError> {
924        dm_batch_with_kernel(high, low, &self.range, self.kernel)
925    }
926    pub fn apply_candles(self, c: &Candles) -> Result<DmBatchOutput, DmError> {
927        let high = c
928            .select_candle_field("high")
929            .map_err(|_| DmError::EmptyInputData)?;
930        let low = c
931            .select_candle_field("low")
932            .map_err(|_| DmError::EmptyInputData)?;
933        self.apply_slices(high, low)
934    }
935    pub fn with_default_candles(c: &Candles) -> Result<DmBatchOutput, DmError> {
936        DmBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
937    }
938}
939
940#[derive(Clone, Debug)]
941pub struct DmBatchOutput {
942    pub plus: Vec<f64>,
943    pub minus: Vec<f64>,
944    pub combos: Vec<DmParams>,
945    pub rows: usize,
946    pub cols: usize,
947}
948impl DmBatchOutput {
949    pub fn row_for_params(&self, p: &DmParams) -> Option<usize> {
950        self.combos
951            .iter()
952            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
953    }
954    pub fn values_for(&self, p: &DmParams) -> Option<(&[f64], &[f64])> {
955        self.row_for_params(p).map(|row| {
956            let start = row * self.cols;
957            (
958                &self.plus[start..start + self.cols],
959                &self.minus[start..start + self.cols],
960            )
961        })
962    }
963}
964
965#[inline(always)]
966fn expand_grid(r: &DmBatchRange) -> Result<Vec<DmParams>, DmError> {
967    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, DmError> {
968        if step == 0 || start == end {
969            return Ok(vec![start]);
970        }
971        if start < end {
972            let mut v = Vec::new();
973            let st = step.max(1);
974            let mut x = start;
975            while x <= end {
976                v.push(x);
977                match x.checked_add(st) {
978                    Some(next) => x = next,
979                    None => break,
980                }
981            }
982            if v.is_empty() {
983                return Err(DmError::InvalidRange { start, end, step });
984            }
985            return Ok(v);
986        }
987
988        let mut v = Vec::new();
989        let st = step.max(1) as isize;
990        let mut x = start as isize;
991        let end_i = end as isize;
992        while x >= end_i {
993            v.push(x as usize);
994            x -= st;
995        }
996        if v.is_empty() {
997            return Err(DmError::InvalidRange { start, end, step });
998        }
999        Ok(v)
1000    }
1001
1002    let periods = axis_usize(r.period)?;
1003    let mut out = Vec::with_capacity(periods.len());
1004    for p in periods {
1005        out.push(DmParams { period: Some(p) });
1006    }
1007    Ok(out)
1008}
1009
1010pub fn dm_batch_with_kernel(
1011    high: &[f64],
1012    low: &[f64],
1013    sweep: &DmBatchRange,
1014    k: Kernel,
1015) -> Result<DmBatchOutput, DmError> {
1016    let kernel = match k {
1017        Kernel::Auto => detect_best_batch_kernel(),
1018        other if other.is_batch() => other,
1019        _ => return Err(DmError::InvalidKernelForBatch(k)),
1020    };
1021    let simd = match kernel {
1022        Kernel::Avx512Batch => Kernel::Avx512,
1023        Kernel::Avx2Batch => Kernel::Avx2,
1024        Kernel::ScalarBatch => Kernel::Scalar,
1025        _ => unreachable!(),
1026    };
1027    dm_batch_par_slice(high, low, sweep, simd)
1028}
1029
1030#[inline(always)]
1031pub fn dm_batch_slice(
1032    high: &[f64],
1033    low: &[f64],
1034    sweep: &DmBatchRange,
1035    kern: Kernel,
1036) -> Result<DmBatchOutput, DmError> {
1037    dm_batch_inner(high, low, sweep, kern, false)
1038}
1039
1040#[inline(always)]
1041pub fn dm_batch_par_slice(
1042    high: &[f64],
1043    low: &[f64],
1044    sweep: &DmBatchRange,
1045    kern: Kernel,
1046) -> Result<DmBatchOutput, DmError> {
1047    dm_batch_inner(high, low, sweep, kern, true)
1048}
1049
1050#[inline(always)]
1051fn dm_batch_inner_into(
1052    high: &[f64],
1053    low: &[f64],
1054    sweep: &DmBatchRange,
1055    kern: Kernel,
1056    parallel: bool,
1057    first: usize,
1058    plus_out: &mut [f64],
1059    minus_out: &mut [f64],
1060) -> Result<Vec<DmParams>, DmError> {
1061    let combos = expand_grid(sweep)?;
1062
1063    let rows = combos.len();
1064    let cols = high.len();
1065
1066    let _total = rows.checked_mul(cols).ok_or(DmError::InvalidRange {
1067        start: sweep.period.0,
1068        end: sweep.period.1,
1069        step: sweep.period.2,
1070    })?;
1071    let chosen = match kern {
1072        Kernel::Auto => detect_best_batch_kernel(),
1073        k => k,
1074    };
1075
1076    let do_row = |row: usize, plus_row: &mut [f64], minus_row: &mut [f64]| {
1077        let p = combos[row].period.unwrap();
1078        dm_compute_into(
1079            high,
1080            low,
1081            p,
1082            first,
1083            match chosen {
1084                Kernel::Avx512Batch => Kernel::Avx512,
1085                Kernel::Avx2Batch => Kernel::Avx2,
1086                Kernel::ScalarBatch => Kernel::Scalar,
1087                k => k,
1088            },
1089            plus_row,
1090            minus_row,
1091        );
1092    };
1093
1094    if parallel {
1095        #[cfg(not(target_arch = "wasm32"))]
1096        {
1097            use rayon::prelude::*;
1098            plus_out
1099                .par_chunks_mut(cols)
1100                .zip(minus_out.par_chunks_mut(cols))
1101                .enumerate()
1102                .for_each(|(r, (pr, mr))| do_row(r, pr, mr));
1103        }
1104        #[cfg(target_arch = "wasm32")]
1105        {
1106            for (r, (pr, mr)) in plus_out
1107                .chunks_mut(cols)
1108                .zip(minus_out.chunks_mut(cols))
1109                .enumerate()
1110            {
1111                do_row(r, pr, mr);
1112            }
1113        }
1114    } else {
1115        for (r, (pr, mr)) in plus_out
1116            .chunks_mut(cols)
1117            .zip(minus_out.chunks_mut(cols))
1118            .enumerate()
1119        {
1120            do_row(r, pr, mr);
1121        }
1122    }
1123
1124    Ok(combos)
1125}
1126
1127#[inline(always)]
1128fn dm_batch_inner(
1129    high: &[f64],
1130    low: &[f64],
1131    sweep: &DmBatchRange,
1132    kern: Kernel,
1133    parallel: bool,
1134) -> Result<DmBatchOutput, DmError> {
1135    let combos = expand_grid(sweep)?;
1136
1137    let first = high
1138        .iter()
1139        .zip(low.iter())
1140        .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1141        .ok_or(DmError::AllValuesNaN)?;
1142
1143    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1144    if high.len() - first < max_p {
1145        return Err(DmError::NotEnoughValidData {
1146            needed: max_p,
1147            valid: high.len() - first,
1148        });
1149    }
1150
1151    let rows = combos.len();
1152    let cols = high.len();
1153
1154    let _total = rows.checked_mul(cols).ok_or(DmError::InvalidRange {
1155        start: sweep.period.0,
1156        end: sweep.period.1,
1157        step: sweep.period.2,
1158    })?;
1159
1160    let mut plus_mu = make_uninit_matrix(rows, cols);
1161    let mut minus_mu = make_uninit_matrix(rows, cols);
1162
1163    let warm: Vec<usize> = combos
1164        .iter()
1165        .map(|c| first + c.period.unwrap() - 1)
1166        .collect();
1167    init_matrix_prefixes(&mut plus_mu, cols, &warm);
1168    init_matrix_prefixes(&mut minus_mu, cols, &warm);
1169
1170    let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
1171    let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
1172    let plus_out: &mut [f64] = unsafe {
1173        core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
1174    };
1175    let minus_out: &mut [f64] = unsafe {
1176        core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
1177    };
1178
1179    let combos = dm_batch_inner_into(high, low, sweep, kern, parallel, first, plus_out, minus_out)?;
1180
1181    let plus = unsafe {
1182        Vec::from_raw_parts(
1183            plus_guard.as_mut_ptr() as *mut f64,
1184            plus_guard.len(),
1185            plus_guard.capacity(),
1186        )
1187    };
1188    let minus = unsafe {
1189        Vec::from_raw_parts(
1190            minus_guard.as_mut_ptr() as *mut f64,
1191            minus_guard.len(),
1192            minus_guard.capacity(),
1193        )
1194    };
1195
1196    Ok(DmBatchOutput {
1197        plus,
1198        minus,
1199        combos,
1200        rows,
1201        cols,
1202    })
1203}
1204
/// Scalar kernel for one output row.
///
/// Sums the raw plus/minus movements of bars `first + 1 ..= first + period - 1`
/// and stores the sums at index `first + period - 1`. For every later bar the
/// running sum becomes `sum - sum / period + movement` and is stored at that
/// bar's index. Elements before `first + period - 1` are left untouched.
///
/// # Safety
/// The body only uses bounds-checked indexing. It expects `period >= 1` and
/// `first + period - 1` to be a valid index into all four slices; otherwise
/// indexing panics.
#[inline(always)]
unsafe fn dm_row_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    /// Raw (plus, minus) movement of one bar: only the larger, strictly
    /// positive move counts; the other side is zero.
    #[inline(always)]
    fn directional_move(up: f64, down: f64) -> (f64, f64) {
        let plus_dm = if up > 0.0 && up > down { up } else { 0.0 };
        let minus_dm = if down > 0.0 && down > up { down } else { 0.0 };
        (plus_dm, minus_dm)
    }

    let seed_end = first + period - 1;
    let mut last_high = high[first];
    let mut last_low = low[first];
    let mut acc_plus = 0.0;
    let mut acc_minus = 0.0;

    // Seed phase: plain sums over the first `period - 1` bar-to-bar moves.
    for i in (first + 1)..=seed_end {
        let (up, down) = directional_move(high[i] - last_high, last_low - low[i]);
        last_high = high[i];
        last_low = low[i];
        acc_plus += up;
        acc_minus += down;
    }
    plus[seed_end] = acc_plus;
    minus[seed_end] = acc_minus;

    // Smoothing phase: decay the running sum by 1/period, then add the new move.
    let inv_period = 1.0 / (period as f64);
    for i in (seed_end + 1)..high.len() {
        let (up, down) = directional_move(high[i] - last_high, last_low - low[i]);
        last_high = high[i];
        last_low = low[i];

        acc_plus = acc_plus - (acc_plus * inv_period) + up;
        acc_minus = acc_minus - (acc_minus * inv_period) + down;

        plus[i] = acc_plus;
        minus[i] = acc_minus;
    }
}
1270
/// AVX2 row entry point; it currently forwards straight to `dm_row_scalar`.
///
/// # Safety
/// Same expectations as `dm_row_scalar`, to which all arguments are passed
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dm_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1283
/// AVX-512 row entry point; it currently forwards straight to `dm_row_scalar`.
///
/// # Safety
/// Same expectations as `dm_row_scalar`, to which all arguments are passed
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dm_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1296
/// "Short" AVX-512 row entry point; it forwards to `dm_row_avx512`.
///
/// # Safety
/// Same expectations as `dm_row_avx512`, to which all arguments are passed
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dm_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_avx512(high, low, first, period, plus, minus);
}
1309
/// "Long" AVX-512 row entry point; it forwards to `dm_row_avx512`.
///
/// # Safety
/// Same expectations as `dm_row_avx512`, to which all arguments are passed
/// unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dm_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_avx512(high, low, first, period, plus, minus);
}
1322
1323#[cfg(test)]
1324mod tests {
1325    use super::*;
1326    use crate::skip_if_unsupported;
1327    use crate::utilities::data_loader::read_candles_from_csv;
1328
1329    fn check_dm_partial_params(
1330        test_name: &str,
1331        kernel: Kernel,
1332    ) -> Result<(), Box<dyn std::error::Error>> {
1333        skip_if_unsupported!(kernel, test_name);
1334        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1335        let candles = read_candles_from_csv(file_path)?;
1336        let default_params = DmParams { period: None };
1337        let input_default = DmInput::from_candles(&candles, default_params);
1338        let output_default = dm_with_kernel(&input_default, kernel)?;
1339        assert_eq!(output_default.plus.len(), candles.high.len());
1340        assert_eq!(output_default.minus.len(), candles.high.len());
1341
1342        let params_custom = DmParams { period: Some(10) };
1343        let input_custom = DmInput::from_candles(&candles, params_custom);
1344        let output_custom = dm_with_kernel(&input_custom, kernel)?;
1345        assert_eq!(output_custom.plus.len(), candles.high.len());
1346        assert_eq!(output_custom.minus.len(), candles.high.len());
1347        Ok(())
1348    }
1349
1350    fn check_dm_default_candles(
1351        test_name: &str,
1352        kernel: Kernel,
1353    ) -> Result<(), Box<dyn std::error::Error>> {
1354        skip_if_unsupported!(kernel, test_name);
1355        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1356        let candles = read_candles_from_csv(file_path)?;
1357        let input = DmInput::with_default_candles(&candles);
1358        let result = dm_with_kernel(&input, kernel)?;
1359        assert_eq!(result.plus.len(), candles.high.len());
1360        assert_eq!(result.minus.len(), candles.high.len());
1361        Ok(())
1362    }
1363
1364    fn check_dm_with_slice_data(
1365        test_name: &str,
1366        kernel: Kernel,
1367    ) -> Result<(), Box<dyn std::error::Error>> {
1368        skip_if_unsupported!(kernel, test_name);
1369        let high_values = [8000.0, 8050.0, 8100.0, 8075.0, 8110.0, 8050.0];
1370        let low_values = [7800.0, 7900.0, 7950.0, 7950.0, 8000.0, 7950.0];
1371        let params = DmParams { period: Some(3) };
1372        let input = DmInput::from_slices(&high_values, &low_values, params);
1373        let result = dm_with_kernel(&input, kernel)?;
1374        assert_eq!(result.plus.len(), 6);
1375        assert_eq!(result.minus.len(), 6);
1376
1377        for i in 0..2 {
1378            assert!(result.plus[i].is_nan());
1379            assert!(result.minus[i].is_nan());
1380        }
1381        Ok(())
1382    }
1383
1384    fn check_dm_zero_period(
1385        test_name: &str,
1386        kernel: Kernel,
1387    ) -> Result<(), Box<dyn std::error::Error>> {
1388        skip_if_unsupported!(kernel, test_name);
1389        let high_values = [100.0, 110.0, 120.0];
1390        let low_values = [90.0, 100.0, 110.0];
1391        let params = DmParams { period: Some(0) };
1392        let input = DmInput::from_slices(&high_values, &low_values, params);
1393        let result = dm_with_kernel(&input, kernel);
1394        assert!(result.is_err());
1395        Ok(())
1396    }
1397
1398    fn check_dm_period_exceeds_data_length(
1399        test_name: &str,
1400        kernel: Kernel,
1401    ) -> Result<(), Box<dyn std::error::Error>> {
1402        skip_if_unsupported!(kernel, test_name);
1403        let high_values = [100.0, 110.0, 120.0];
1404        let low_values = [90.0, 100.0, 110.0];
1405        let params = DmParams { period: Some(10) };
1406        let input = DmInput::from_slices(&high_values, &low_values, params);
1407        let result = dm_with_kernel(&input, kernel);
1408        assert!(result.is_err());
1409        Ok(())
1410    }
1411
1412    fn check_dm_not_enough_valid_data(
1413        test_name: &str,
1414        kernel: Kernel,
1415    ) -> Result<(), Box<dyn std::error::Error>> {
1416        skip_if_unsupported!(kernel, test_name);
1417        let high_values = [f64::NAN, f64::NAN, 100.0, 101.0, 102.0];
1418        let low_values = [f64::NAN, f64::NAN, 90.0, 89.0, 88.0];
1419        let params = DmParams { period: Some(5) };
1420        let input = DmInput::from_slices(&high_values, &low_values, params);
1421        let result = dm_with_kernel(&input, kernel);
1422        assert!(result.is_err());
1423        Ok(())
1424    }
1425
1426    fn check_dm_all_values_nan(
1427        test_name: &str,
1428        kernel: Kernel,
1429    ) -> Result<(), Box<dyn std::error::Error>> {
1430        skip_if_unsupported!(kernel, test_name);
1431        let high_values = [f64::NAN, f64::NAN, f64::NAN];
1432        let low_values = [f64::NAN, f64::NAN, f64::NAN];
1433        let params = DmParams { period: Some(3) };
1434        let input = DmInput::from_slices(&high_values, &low_values, params);
1435        let result = dm_with_kernel(&input, kernel);
1436        assert!(result.is_err());
1437        Ok(())
1438    }
1439
1440    fn check_dm_with_slice_reinput(
1441        test_name: &str,
1442        kernel: Kernel,
1443    ) -> Result<(), Box<dyn std::error::Error>> {
1444        skip_if_unsupported!(kernel, test_name);
1445        let high_values = [9000.0, 9100.0, 9050.0, 9200.0, 9150.0, 9300.0];
1446        let low_values = [8900.0, 9000.0, 8950.0, 9000.0, 9050.0, 9100.0];
1447        let params = DmParams { period: Some(2) };
1448        let input_first = DmInput::from_slices(&high_values, &low_values, params.clone());
1449        let result_first = dm_with_kernel(&input_first, kernel)?;
1450        let input_second = DmInput::from_slices(&result_first.plus, &result_first.minus, params);
1451        let result_second = dm_with_kernel(&input_second, kernel)?;
1452        assert_eq!(result_second.plus.len(), high_values.len());
1453        assert_eq!(result_second.minus.len(), high_values.len());
1454        Ok(())
1455    }
1456
1457    fn check_dm_known_values(
1458        test_name: &str,
1459        kernel: Kernel,
1460    ) -> Result<(), Box<dyn std::error::Error>> {
1461        skip_if_unsupported!(kernel, test_name);
1462        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1463        let candles = read_candles_from_csv(file_path)?;
1464        let params = DmParams { period: Some(14) };
1465        let input = DmInput::from_candles(&candles, params);
1466        let output = dm_with_kernel(&input, kernel)?;
1467
1468        let slice_size = 5;
1469        let last_plus_slice = &output.plus[output.plus.len() - slice_size..];
1470        let last_minus_slice = &output.minus[output.minus.len() - slice_size..];
1471
1472        let expected_plus = [
1473            1410.819956368491,
1474            1384.04710234217,
1475            1285.186595032015,
1476            1199.3875525297283,
1477            1113.7170130633192,
1478        ];
1479        let expected_minus = [
1480            3602.8631384045057,
1481            3345.5157713756125,
1482            3258.5503591344973,
1483            3025.796762053462,
1484            3493.668421906786,
1485        ];
1486
1487        for i in 0..slice_size {
1488            let diff_plus = (last_plus_slice[i] - expected_plus[i]).abs();
1489            let diff_minus = (last_minus_slice[i] - expected_minus[i]).abs();
1490            assert!(diff_plus < 1e-6);
1491            assert!(diff_minus < 1e-6);
1492        }
1493        Ok(())
1494    }
1495
1496    macro_rules! generate_all_dm_tests {
1497        ($($test_fn:ident),*) => {
1498            paste::paste! {
1499                $(
1500                    #[test]
1501                    fn [<$test_fn _scalar_f64>]() {
1502                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1503                    }
1504                )*
1505                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1506                $(
1507                    #[test]
1508                    fn [<$test_fn _avx2_f64>]() {
1509                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1510                    }
1511                    #[test]
1512                    fn [<$test_fn _avx512_f64>]() {
1513                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1514                    }
1515                )*
1516            }
1517        }
1518    }
1519
1520    #[cfg(debug_assertions)]
1521    fn check_dm_no_poison(
1522        test_name: &str,
1523        kernel: Kernel,
1524    ) -> Result<(), Box<dyn std::error::Error>> {
1525        skip_if_unsupported!(kernel, test_name);
1526
1527        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1528        let candles = read_candles_from_csv(file_path)?;
1529
1530        let test_params = vec![
1531            DmParams::default(),
1532            DmParams { period: Some(2) },
1533            DmParams { period: Some(3) },
1534            DmParams { period: Some(5) },
1535            DmParams { period: Some(7) },
1536            DmParams { period: Some(10) },
1537            DmParams { period: Some(14) },
1538            DmParams { period: Some(20) },
1539            DmParams { period: Some(30) },
1540            DmParams { period: Some(50) },
1541            DmParams { period: Some(100) },
1542            DmParams { period: Some(200) },
1543            DmParams { period: Some(25) },
1544        ];
1545
1546        for (param_idx, params) in test_params.iter().enumerate() {
1547            let input = DmInput::from_candles(&candles, params.clone());
1548            let output = dm_with_kernel(&input, kernel)?;
1549
1550            for (i, &val) in output.plus.iter().enumerate() {
1551                if val.is_nan() {
1552                    continue;
1553                }
1554
1555                let bits = val.to_bits();
1556
1557                if bits == 0x11111111_11111111 {
1558                    panic!(
1559						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in plus array \
1560						 with params: period={} (param set {})",
1561						test_name, val, bits, i,
1562						params.period.unwrap_or(14), param_idx
1563					);
1564                }
1565
1566                if bits == 0x22222222_22222222 {
1567                    panic!(
1568						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in plus array \
1569						 with params: period={} (param set {})",
1570						test_name, val, bits, i,
1571						params.period.unwrap_or(14), param_idx
1572					);
1573                }
1574
1575                if bits == 0x33333333_33333333 {
1576                    panic!(
1577						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in plus array \
1578						 with params: period={} (param set {})",
1579						test_name, val, bits, i,
1580						params.period.unwrap_or(14), param_idx
1581					);
1582                }
1583            }
1584
1585            for (i, &val) in output.minus.iter().enumerate() {
1586                if val.is_nan() {
1587                    continue;
1588                }
1589
1590                let bits = val.to_bits();
1591
1592                if bits == 0x11111111_11111111 {
1593                    panic!(
1594						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in minus array \
1595						 with params: period={} (param set {})",
1596						test_name, val, bits, i,
1597						params.period.unwrap_or(14), param_idx
1598					);
1599                }
1600
1601                if bits == 0x22222222_22222222 {
1602                    panic!(
1603						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in minus array \
1604						 with params: period={} (param set {})",
1605						test_name, val, bits, i,
1606						params.period.unwrap_or(14), param_idx
1607					);
1608                }
1609
1610                if bits == 0x33333333_33333333 {
1611                    panic!(
1612						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in minus array \
1613						 with params: period={} (param set {})",
1614						test_name, val, bits, i,
1615						params.period.unwrap_or(14), param_idx
1616					);
1617                }
1618            }
1619        }
1620
1621        Ok(())
1622    }
1623
1624    #[cfg(not(debug_assertions))]
1625    fn check_dm_no_poison(
1626        _test_name: &str,
1627        _kernel: Kernel,
1628    ) -> Result<(), Box<dyn std::error::Error>> {
1629        Ok(())
1630    }
1631
1632    #[cfg(feature = "proptest")]
1633    #[allow(clippy::float_cmp)]
1634    fn check_dm_property(
1635        test_name: &str,
1636        kernel: Kernel,
1637    ) -> Result<(), Box<dyn std::error::Error>> {
1638        use proptest::prelude::*;
1639        skip_if_unsupported!(kernel, test_name);
1640
1641        let strat = (2usize..=50).prop_flat_map(|period| {
1642            (
1643                (100f64..10000f64, 0.01f64..0.05f64, period + 10..400)
1644                    .prop_flat_map(move |(base_price, volatility, data_len)| {
1645                        (
1646                            Just(base_price),
1647                            Just(volatility),
1648                            Just(data_len),
1649                            prop::collection::vec((-1f64..1f64), data_len),
1650                            prop::collection::vec((0f64..2f64), data_len),
1651                        )
1652                    })
1653                    .prop_map(
1654                        move |(base_price, volatility, data_len, changes, spreads)| {
1655                            let mut high = Vec::with_capacity(data_len);
1656                            let mut low = Vec::with_capacity(data_len);
1657                            let mut current_price = base_price;
1658
1659                            for i in 0..data_len {
1660                                let change = changes[i] * volatility * current_price;
1661                                current_price = (current_price + change).max(10.0);
1662
1663                                let spread = current_price * 0.01 * spreads[i];
1664                                let daily_high = current_price + spread;
1665                                let daily_low = current_price - spread;
1666
1667                                high.push(daily_high);
1668                                low.push(daily_low.max(1.0));
1669                            }
1670
1671                            (high, low)
1672                        },
1673                    ),
1674                Just(period),
1675            )
1676        });
1677
1678        proptest::test_runner::TestRunner::default().run(&strat, |((high, low), period)| {
1679            let params = DmParams {
1680                period: Some(period),
1681            };
1682            let input = DmInput::from_slices(&high, &low, params);
1683
1684            let DmOutput {
1685                plus: out_plus,
1686                minus: out_minus,
1687            } = dm_with_kernel(&input, kernel)?;
1688
1689            let DmOutput {
1690                plus: ref_plus,
1691                minus: ref_minus,
1692            } = dm_with_kernel(&input, Kernel::Scalar)?;
1693
1694            prop_assert_eq!(out_plus.len(), high.len());
1695            prop_assert_eq!(out_minus.len(), high.len());
1696
1697            let warmup_period = period - 1;
1698            for i in 0..warmup_period {
1699                prop_assert!(
1700                    out_plus[i].is_nan(),
1701                    "Plus value at index {} should be NaN during warmup",
1702                    i
1703                );
1704                prop_assert!(
1705                    out_minus[i].is_nan(),
1706                    "Minus value at index {} should be NaN during warmup",
1707                    i
1708                );
1709            }
1710
1711            for i in warmup_period..high.len() {
1712                if !out_plus[i].is_nan() {
1713                    prop_assert!(
1714                        out_plus[i] >= -1e-9,
1715                        "Plus DM at index {} is negative: {}",
1716                        i,
1717                        out_plus[i]
1718                    );
1719                }
1720                if !out_minus[i].is_nan() {
1721                    prop_assert!(
1722                        out_minus[i] >= -1e-9,
1723                        "Minus DM at index {} is negative: {}",
1724                        i,
1725                        out_minus[i]
1726                    );
1727                }
1728            }
1729
1730            const MAX_ULP: i64 = 3;
1731            for i in 0..high.len() {
1732                let plus_y = out_plus[i];
1733                let plus_r = ref_plus[i];
1734                let minus_y = out_minus[i];
1735                let minus_r = ref_minus[i];
1736
1737                if plus_y.is_nan() {
1738                    prop_assert!(
1739                        plus_r.is_nan(),
1740                        "Plus kernel mismatch at {}: {} vs NaN",
1741                        i,
1742                        plus_r
1743                    );
1744                } else {
1745                    let plus_y_bits = plus_y.to_bits();
1746                    let plus_r_bits = plus_r.to_bits();
1747                    let plus_ulp_diff = (plus_y_bits as i64).wrapping_sub(plus_r_bits as i64).abs();
1748
1749                    prop_assert!(
1750                        plus_ulp_diff <= MAX_ULP,
1751                        "Plus kernel mismatch at {}: {} vs {} (ULP diff: {})",
1752                        i,
1753                        plus_y,
1754                        plus_r,
1755                        plus_ulp_diff
1756                    );
1757                }
1758
1759                if minus_y.is_nan() {
1760                    prop_assert!(
1761                        minus_r.is_nan(),
1762                        "Minus kernel mismatch at {}: {} vs NaN",
1763                        i,
1764                        minus_r
1765                    );
1766                } else {
1767                    let minus_y_bits = minus_y.to_bits();
1768                    let minus_r_bits = minus_r.to_bits();
1769                    let minus_ulp_diff = (minus_y_bits as i64)
1770                        .wrapping_sub(minus_r_bits as i64)
1771                        .abs();
1772
1773                    prop_assert!(
1774                        minus_ulp_diff <= MAX_ULP,
1775                        "Minus kernel mismatch at {}: {} vs {} (ULP diff: {})",
1776                        i,
1777                        minus_y,
1778                        minus_r,
1779                        minus_ulp_diff
1780                    );
1781                }
1782            }
1783
1784            let all_high_equal = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1785            let all_low_equal = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1786
1787            if all_high_equal && all_low_equal {
1788                for i in (period * 2).min(high.len() - 1)..high.len() {
1789                    if !out_plus[i].is_nan() {
1790                        prop_assert!(
1791                            out_plus[i].abs() < 1e-6,
1792                            "Plus DM should be near zero for constant data at {}: {}",
1793                            i,
1794                            out_plus[i]
1795                        );
1796                    }
1797                    if !out_minus[i].is_nan() {
1798                        prop_assert!(
1799                            out_minus[i].abs() < 1e-6,
1800                            "Minus DM should be near zero for constant data at {}: {}",
1801                            i,
1802                            out_minus[i]
1803                        );
1804                    }
1805                }
1806            }
1807
1808            Ok(())
1809        })?;
1810
1811        Ok(())
1812    }
1813
    // Instantiate every kernel-generic check once per kernel variant the macro emits.
    generate_all_dm_tests!(
        check_dm_partial_params,
        check_dm_default_candles,
        check_dm_with_slice_data,
        check_dm_zero_period,
        check_dm_period_exceeds_data_length,
        check_dm_not_enough_valid_data,
        check_dm_all_values_nan,
        check_dm_with_slice_reinput,
        check_dm_known_values,
        check_dm_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_dm_tests!(check_dm_property);
1829
1830    #[test]
1831    fn test_dm_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1832        let n = 256usize;
1833        let mut high = Vec::with_capacity(n);
1834        let mut low = Vec::with_capacity(n);
1835        let mut price = 100.0f64;
1836        for i in 0..n {
1837            let drift = ((i % 7) as i32 - 3) as f64 * 0.3;
1838            price = (price + drift).max(1.0);
1839            let spread = 0.5 + 0.1 * ((i % 5) as f64);
1840            high.push(price + spread);
1841            low.push((price - spread).max(0.01));
1842        }
1843
1844        let input = DmInput::from_slices(&high, &low, DmParams::default());
1845
1846        let base = dm(&input)?;
1847
1848        let mut plus = vec![0.0; n];
1849        let mut minus = vec![0.0; n];
1850        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1851        dm_into(&input, &mut plus, &mut minus)?;
1852
1853        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1854            a == b || (a.is_nan() && b.is_nan())
1855        }
1856
1857        assert_eq!(base.plus.len(), n);
1858        assert_eq!(base.minus.len(), n);
1859        for i in 0..n {
1860            assert!(
1861                eq_or_both_nan(base.plus[i], plus[i]),
1862                "plus mismatch at {}: base={} into={}",
1863                i,
1864                base.plus[i],
1865                plus[i]
1866            );
1867            assert!(
1868                eq_or_both_nan(base.minus[i], minus[i]),
1869                "minus mismatch at {}: base={} into={}",
1870                i,
1871                base.minus[i],
1872                minus[i]
1873            );
1874        }
1875        Ok(())
1876    }
1877
1878    fn check_batch_default_row(
1879        test: &str,
1880        kernel: Kernel,
1881    ) -> Result<(), Box<dyn std::error::Error>> {
1882        skip_if_unsupported!(kernel, test);
1883
1884        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1885        let c = read_candles_from_csv(file)?;
1886
1887        let output = DmBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1888
1889        let def = DmParams::default();
1890        let (row_plus, row_minus) = output.values_for(&def).expect("default row missing");
1891
1892        assert_eq!(row_plus.len(), c.high.len());
1893        assert_eq!(row_minus.len(), c.high.len());
1894
1895        let expected_plus = [
1896            1410.819956368491,
1897            1384.04710234217,
1898            1285.186595032015,
1899            1199.3875525297283,
1900            1113.7170130633192,
1901        ];
1902        let expected_minus = [
1903            3602.8631384045057,
1904            3345.5157713756125,
1905            3258.5503591344973,
1906            3025.796762053462,
1907            3493.668421906786,
1908        ];
1909        let start = row_plus.len() - 5;
1910        for (i, &v) in row_plus[start..].iter().enumerate() {
1911            assert!((v - expected_plus[i]).abs() < 1e-6);
1912        }
1913        for (i, &v) in row_minus[start..].iter().enumerate() {
1914            assert!((v - expected_minus[i]).abs() < 1e-6);
1915        }
1916        Ok(())
1917    }
1918
1919    #[cfg(debug_assertions)]
1920    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1921        skip_if_unsupported!(kernel, test);
1922
1923        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1924        let c = read_candles_from_csv(file)?;
1925
1926        let test_configs = vec![
1927            (2, 10, 2),
1928            (5, 25, 5),
1929            (30, 60, 15),
1930            (2, 5, 1),
1931            (14, 14, 0),
1932            (10, 100, 10),
1933            (100, 200, 50),
1934        ];
1935
1936        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1937            let output = DmBatchBuilder::new()
1938                .kernel(kernel)
1939                .period_range(p_start, p_end, p_step)
1940                .apply_candles(&c)?;
1941
1942            for (idx, &val) in output.plus.iter().enumerate() {
1943                if val.is_nan() {
1944                    continue;
1945                }
1946
1947                let bits = val.to_bits();
1948                let row = idx / output.cols;
1949                let col = idx % output.cols;
1950                let combo = &output.combos[row];
1951
1952                if bits == 0x11111111_11111111 {
1953                    panic!(
1954						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in plus \
1955						 at row {} col {} (flat index {}) with params: period={}",
1956						test, cfg_idx, val, bits, row, col, idx,
1957						combo.period.unwrap_or(14)
1958					);
1959                }
1960
1961                if bits == 0x22222222_22222222 {
1962                    panic!(
1963						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in plus \
1964						 at row {} col {} (flat index {}) with params: period={}",
1965						test, cfg_idx, val, bits, row, col, idx,
1966						combo.period.unwrap_or(14)
1967					);
1968                }
1969
1970                if bits == 0x33333333_33333333 {
1971                    panic!(
1972						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in plus \
1973						 at row {} col {} (flat index {}) with params: period={}",
1974						test, cfg_idx, val, bits, row, col, idx,
1975						combo.period.unwrap_or(14)
1976					);
1977                }
1978            }
1979
1980            for (idx, &val) in output.minus.iter().enumerate() {
1981                if val.is_nan() {
1982                    continue;
1983                }
1984
1985                let bits = val.to_bits();
1986                let row = idx / output.cols;
1987                let col = idx % output.cols;
1988                let combo = &output.combos[row];
1989
1990                if bits == 0x11111111_11111111 {
1991                    panic!(
1992						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in minus \
1993						 at row {} col {} (flat index {}) with params: period={}",
1994						test, cfg_idx, val, bits, row, col, idx,
1995						combo.period.unwrap_or(14)
1996					);
1997                }
1998
1999                if bits == 0x22222222_22222222 {
2000                    panic!(
2001						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in minus \
2002						 at row {} col {} (flat index {}) with params: period={}",
2003						test, cfg_idx, val, bits, row, col, idx,
2004						combo.period.unwrap_or(14)
2005					);
2006                }
2007
2008                if bits == 0x33333333_33333333 {
2009                    panic!(
2010						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in minus \
2011						 at row {} col {} (flat index {}) with params: period={}",
2012						test, cfg_idx, val, bits, row, col, idx,
2013						combo.period.unwrap_or(14)
2014					);
2015                }
2016            }
2017        }
2018
2019        Ok(())
2020    }
2021
2022    #[cfg(not(debug_assertions))]
2023    fn check_batch_no_poison(
2024        _test: &str,
2025        _kernel: Kernel,
2026    ) -> Result<(), Box<dyn std::error::Error>> {
2027        Ok(())
2028    }
2029
2030    macro_rules! gen_batch_tests {
2031        ($fn_name:ident) => {
2032            paste::paste! {
2033                #[test] fn [<$fn_name _scalar>]()      {
2034                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2035                }
2036                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2037                #[test] fn [<$fn_name _avx2>]()        {
2038                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2039                }
2040                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2041                #[test] fn [<$fn_name _avx512>]()      {
2042                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2043                }
2044                #[test] fn [<$fn_name _auto_detect>]() {
2045                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2046                }
2047            }
2048        };
2049    }
2050    gen_batch_tests!(check_batch_default_row);
2051    gen_batch_tests!(check_batch_no_poison);
2052}
2053
2054#[cfg(feature = "python")]
2055use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
2056#[cfg(feature = "python")]
2057use pyo3::exceptions::PyValueError;
2058#[cfg(feature = "python")]
2059use pyo3::prelude::*;
2060#[cfg(feature = "python")]
2061use pyo3::types::PyDict;
2062
/// Python binding `dm(high, low, period, kernel=None)`.
///
/// Computes the indicator for one `high`/`low` pair and returns the
/// `(plus, minus)` NumPy arrays, each as long as the input. Raises
/// `ValueError` when the inputs differ in length, when the kernel name is
/// rejected, or when the computation itself reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "dm")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn dm_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let high_view = high.as_slice()?;
    let low_view = low.as_slice()?;
    let len = high_view.len();
    if len != low_view.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }

    let input = DmInput::from_slices(
        high_view,
        low_view,
        DmParams {
            period: Some(period),
        },
    );
    let chosen = validate_kernel(kernel, false)?;

    // SAFETY: both arrays are created right here and have not been handed to
    // Python yet, so the mutable views below are the only references to them.
    // NOTE(review): the buffers start uninitialized; this relies on
    // `dm_into_slice` writing every element on success — confirm.
    let plus_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let minus_arr = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let plus_buf = unsafe { plus_arr.as_slice_mut()? };
    let minus_buf = unsafe { minus_arr.as_slice_mut()? };

    // The numeric work runs with the GIL released.
    let computed = py.allow_threads(|| dm_into_slice(plus_buf, minus_buf, &input, chosen));
    if let Err(e) = computed {
        return Err(PyValueError::new_err(e.to_string()));
    }

    Ok((plus_arr, minus_arr))
}
2095
/// Python binding `dm_batch(high, low, period_range, kernel=None)`.
///
/// Runs the indicator for every period in the `(start, end, step)` sweep and
/// returns a dict with:
/// * `"plus"` / `"minus"` — 2-D arrays of shape `(rows, cols)` as reported by
///   the batch output,
/// * `"periods"` — the period used for each row, as unsigned 64-bit integers.
///
/// Raises `ValueError` when the inputs differ in length, when the kernel name
/// is rejected, or when the batch computation reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "dm_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn dm_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }

    let sweep = DmBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    // The numeric work runs with the GIL released.
    let output = py
        .allow_threads(|| dm_batch_with_kernel(h, l, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let rows = output.rows;
    let cols = output.cols;

    // `from_vec` takes ownership of the vector and `reshape` is a checked
    // operation; neither needs an `unsafe` block.
    let plus = PyArray1::from_vec(py, output.plus).reshape((rows, cols))?;
    let minus = PyArray1::from_vec(py, output.minus).reshape((rows, cols))?;

    // NOTE(review): `unwrap` assumes every combo produced by the batch has a
    // concrete period — confirm against the combo expansion.
    let periods: Vec<u64> = output
        .combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("plus", plus)?;
    dict.set_item("minus", minus)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
2139
/// Python binding `dm_cuda_batch_dev(high_f32, low_f32, period_range, device_id=0)`.
///
/// Runs the period sweep through `CudaDm` on the requested device and returns
/// the `(plus, minus)` results as device-array handles. Both handles carry the
/// CUDA context obtained from the `CudaDm` instance and its device id.
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dm_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn dm_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high_host = high_f32.as_slice()?;
    let low_host = low_f32.as_slice()?;
    let sweep = DmBatchRange {
        period: period_range,
    };

    // Device setup and the kernel launch run with the GIL released.
    let (pair, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDm::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        // The second tuple element returned by the wrapper is not needed here.
        let (pair, _) = cuda
            .dm_batch_dev(high_host, low_host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, ctx, dev))
    })?;

    let plus = DeviceArrayF32Py {
        inner: pair.plus,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev),
    };
    let minus = DeviceArrayF32Py {
        inner: pair.minus,
        _ctx: Some(ctx),
        device_id: Some(dev),
    };
    Ok((plus, minus))
}
2180
/// Python binding
/// `dm_cuda_many_series_one_param_dev(high_tm_f32, low_tm_f32, cols, rows, period, device_id=0)`.
///
/// Forwards flattened inputs plus their `cols`/`rows` dimensions to
/// `CudaDm::dm_many_series_one_param_time_major_dev` with a single `period`,
/// and returns the `(plus, minus)` results as device-array handles carrying
/// the CUDA context and device id of the `CudaDm` instance.
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dm_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, period, device_id=0))]
pub fn dm_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high_host = high_tm_f32.as_slice()?;
    let low_host = low_tm_f32.as_slice()?;

    // Device setup and the kernel launch run with the GIL released.
    let (pair, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDm::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let pair = cuda
            .dm_many_series_one_param_time_major_dev(high_host, low_host, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((pair, ctx, dev))
    })?;

    let plus = DeviceArrayF32Py {
        inner: pair.plus,
        _ctx: Some(ctx.clone()),
        device_id: Some(dev),
    };
    let minus = DeviceArrayF32Py {
        inner: pair.minus,
        _ctx: Some(ctx),
        device_id: Some(dev),
    };
    Ok((plus, minus))
}
2220
/// Python-visible wrapper (`DmStream`) around the Rust `DmStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DmStream")]
pub struct DmStreamPy {
    stream: DmStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DmStreamPy {
    /// Builds a stream for the given `period`; raises `ValueError` when
    /// `DmStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = DmParams {
            period: Some(period),
        };
        match DmStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one `high`/`low` observation to the underlying stream and returns
    /// whatever it yields: `None`, or a pair of values.
    fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low)
    }
}
2242
2243#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2244use serde::{Deserialize, Serialize};
2245#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2246use wasm_bindgen::prelude::*;
2247
/// Serialized result of the single-run JS binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct DmJsOutput {
    /// Flattened output: the `plus` series followed by the `minus` series.
    pub values: Vec<f64>,
    /// Number of stacked series in `values`.
    pub rows: usize,
    /// Length of each series.
    pub cols: usize,
}
2255
/// JS binding `dm(high, low, period)`.
///
/// Computes both series and returns a serialized `DmJsOutput` whose `values`
/// hold `plus` followed by `minus` (`rows = 2`, `cols = high.len()`).
/// Errors are returned as JS strings: `"length mismatch"` for unequal inputs,
/// the indicator's own error text, or a `"Serialization error: ..."` message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm)]
pub fn dm_js(high: &[f64], low: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let cols = high.len();
    if cols != low.len() {
        return Err(JsValue::from_str("length mismatch"));
    }

    let params = DmParams {
        period: Some(period),
    };
    let input = DmInput::from_slices(high, low, params);

    // `values` starts out as the plus buffer; minus is appended afterwards.
    let mut values = vec![0.0; cols];
    let mut minus = vec![0.0; cols];
    if let Err(e) = dm_into_slice(&mut values, &mut minus, &input, detect_best_kernel()) {
        return Err(JsValue::from_str(&e.to_string()));
    }
    values.extend_from_slice(&minus);

    let output = DmJsOutput {
        values,
        rows: 2,
        cols,
    };
    match serde_wasm_bindgen::to_value(&output) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2286
/// Configuration object accepted by the JS batch binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct DmBatchConfig {
    /// Period sweep, forwarded unchanged as the `period` field of `DmBatchRange`.
    pub period_range: (usize, usize, usize),
}
2292
/// Serialized result of the JS batch binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct DmBatchJsOutput {
    /// Flattened output: the whole `plus` matrix followed by the whole
    /// `minus` matrix.
    pub values: Vec<f64>,
    /// Total number of rows in `values` (twice the batch's row count).
    pub rows: usize,
    /// Number of columns per row.
    pub cols: usize,
    /// Period used for each parameter combination of the batch.
    pub periods: Vec<usize>,
}
2301
/// JS binding `dm_batch(high, low, config)`.
///
/// `config` must deserialize into `DmBatchConfig`. The batch result is
/// returned as a serialized `DmBatchJsOutput`: the `plus` matrix followed by
/// the `minus` matrix in `values`, `rows` equal to twice the batch row count,
/// and one entry in `periods` per parameter combination.
/// Errors are returned as JS strings: `"length mismatch"`, an
/// `"Invalid config: ..."` message, the batch's own error text, or a
/// `"Serialization error: ..."` message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm_batch)]
pub fn dm_batch_unified_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    if high.len() != low.len() {
        return Err(JsValue::from_str("length mismatch"));
    }

    let cfg: DmBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };
    let sweep = DmBatchRange {
        period: cfg.period_range,
    };

    let batch = match dm_batch_inner(high, low, &sweep, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Stack plus on top of minus in one flat buffer.
    let mut values = Vec::with_capacity(batch.plus.len() + batch.minus.len());
    values.extend_from_slice(&batch.plus);
    values.extend_from_slice(&batch.minus);

    let periods: Vec<usize> = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap())
        .collect();

    let payload = DmBatchJsOutput {
        values,
        rows: batch.rows * 2,
        cols: batch.cols,
        periods,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2336
/// Reserves room for `len` `f64` values and hands the raw pointer to the
/// caller.
///
/// The backing vector is deliberately never dropped here, so the allocation
/// stays alive until the caller releases it (see `dm_free`, which rebuilds a
/// vector from the pointer and `len`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dm_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
/// Releases a buffer previously obtained from `dm_alloc`.
///
/// `ptr` must be a pointer returned by `dm_alloc` and `len` must be the same
/// `len` that was passed to that call; each buffer must be freed at most once.
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dm_free(ptr: *mut f64, len: usize) {
    // Rebuilding a `Vec` from a null pointer is undefined behavior, so treat
    // it as "nothing to free".
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` allocated
    // with capacity `len`. The vector is rebuilt with length 0 so that no
    // element is assumed to be initialized; `f64` has no destructor, so only
    // the allocation itself is released when the vector drops.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2352
/// JS binding `dm_into(high_ptr, low_ptr, plus_ptr, minus_ptr, len, period)`.
///
/// Reads `len` values from each input pointer and writes `len` values to each
/// output pointer. Every pointer must reference at least `len` valid `f64`
/// values inside the module's memory.
///
/// Output buffers may overlap the inputs (in-place use): in that case the
/// results are computed into temporary buffers first and copied out
/// afterwards, so the inputs are never read while being overwritten. If the
/// two output buffers overlap each other, the overlapping region ends up
/// holding the `minus` values, which are copied last.
///
/// Errors are returned as JS strings: `"null pointer"` when any pointer is
/// null, otherwise the indicator's own error text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm_into)]
pub fn dm_into_js(
    high_ptr: *const f64,
    low_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || plus_ptr.is_null() || minus_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte-range overlap test; with `len == 0` the span is empty and nothing
    // overlaps.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: usize, b: usize| a < b.saturating_add(span) && b < a.saturating_add(span);
    let (high_addr, low_addr) = (high_ptr as usize, low_ptr as usize);
    let (plus_addr, minus_addr) = (plus_ptr as usize, minus_ptr as usize);
    let aliased = overlaps(plus_addr, high_addr)
        || overlaps(plus_addr, low_addr)
        || overlaps(minus_addr, high_addr)
        || overlaps(minus_addr, low_addr)
        || overlaps(plus_addr, minus_addr);

    // SAFETY: all pointers are non-null and, per the contract above, valid for
    // `len` elements. Mutable slices over the outputs are only created when
    // they overlap neither the inputs nor each other; otherwise the outputs
    // are written through raw-pointer copies from separate temporaries after
    // the last read of the inputs.
    unsafe {
        let h = std::slice::from_raw_parts(high_ptr, len);
        let l = std::slice::from_raw_parts(low_ptr, len);
        let input = DmInput::from_slices(
            h,
            l,
            DmParams {
                period: Some(period),
            },
        );
        let kern = detect_best_kernel();

        if aliased {
            let mut plus_tmp = vec![0.0f64; len];
            let mut minus_tmp = vec![0.0f64; len];
            dm_into_slice(&mut plus_tmp, &mut minus_tmp, &input, kern)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::ptr::copy_nonoverlapping(plus_tmp.as_ptr(), plus_ptr, len);
            std::ptr::copy_nonoverlapping(minus_tmp.as_ptr(), minus_ptr, len);
            Ok(())
        } else {
            let plus = std::slice::from_raw_parts_mut(plus_ptr, len);
            let minus = std::slice::from_raw_parts_mut(minus_ptr, len);
            dm_into_slice(plus, minus, &input, kern).map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}