Skip to main content

vector_ta/indicators/
dpo.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8use paste::paste;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use thiserror::Error;
14
15impl<'a> AsRef<[f64]> for DpoInput<'a> {
16    #[inline(always)]
17    fn as_ref(&self) -> &[f64] {
18        match &self.data {
19            DpoData::Slice(slice) => slice,
20            DpoData::Candles { candles, source } => source_type(candles, source),
21        }
22    }
23}
24
/// Where a DPO input's price series comes from.
#[derive(Debug, Clone)]
pub enum DpoData<'a> {
    /// A column of a candle set, selected by name (resolved via `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice used as-is.
    Slice(&'a [f64]),
}
33
/// Result of a single-series DPO computation.
#[derive(Debug, Clone)]
pub struct DpoOutput {
    /// One value per input element; entries before the warm-up index are NaN-filled.
    pub values: Vec<f64>,
}
38
/// Parameters for the Detrended Price Oscillator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DpoParams {
    /// Length of the simple-moving-average window; `None` is treated as 5 by the readers
    /// of this struct in this module.
    pub period: Option<usize>,
}

impl Default for DpoParams {
    /// Parameters with an explicit period of 5.
    fn default() -> Self {
        let period = Some(5);
        DpoParams { period }
    }
}
53
/// Input to the DPO functions: a price source plus its parameters.
#[derive(Debug, Clone)]
pub struct DpoInput<'a> {
    /// Where the price series is read from.
    pub data: DpoData<'a>,
    /// Indicator parameters (see `get_period` for the fallback period).
    pub params: DpoParams,
}
59
60impl<'a> DpoInput<'a> {
61    #[inline]
62    pub fn from_candles(c: &'a Candles, s: &'a str, p: DpoParams) -> Self {
63        Self {
64            data: DpoData::Candles {
65                candles: c,
66                source: s,
67            },
68            params: p,
69        }
70    }
71    #[inline]
72    pub fn from_slice(sl: &'a [f64], p: DpoParams) -> Self {
73        Self {
74            data: DpoData::Slice(sl),
75            params: p,
76        }
77    }
78    #[inline]
79    pub fn with_default_candles(c: &'a Candles) -> Self {
80        Self::from_candles(c, "close", DpoParams::default())
81    }
82    #[inline]
83    pub fn get_period(&self) -> usize {
84        self.params.period.unwrap_or(5)
85    }
86}
87
88#[derive(Copy, Clone, Debug)]
89pub struct DpoBuilder {
90    period: Option<usize>,
91    kernel: Kernel,
92}
93
94impl Default for DpoBuilder {
95    fn default() -> Self {
96        Self {
97            period: None,
98            kernel: Kernel::Auto,
99        }
100    }
101}
102
103impl DpoBuilder {
104    #[inline(always)]
105    pub fn new() -> Self {
106        Self::default()
107    }
108    #[inline(always)]
109    pub fn period(mut self, n: usize) -> Self {
110        self.period = Some(n);
111        self
112    }
113    #[inline(always)]
114    pub fn kernel(mut self, k: Kernel) -> Self {
115        self.kernel = k;
116        self
117    }
118
119    #[inline(always)]
120    pub fn apply(self, c: &Candles) -> Result<DpoOutput, DpoError> {
121        let p = DpoParams {
122            period: self.period,
123        };
124        let i = DpoInput::from_candles(c, "close", p);
125        dpo_with_kernel(&i, self.kernel)
126    }
127
128    #[inline(always)]
129    pub fn apply_slice(self, d: &[f64]) -> Result<DpoOutput, DpoError> {
130        let p = DpoParams {
131            period: self.period,
132        };
133        let i = DpoInput::from_slice(d, p);
134        dpo_with_kernel(&i, self.kernel)
135    }
136
137    #[inline(always)]
138    pub fn into_stream(self) -> Result<DpoStream, DpoError> {
139        let p = DpoParams {
140            period: self.period,
141        };
142        DpoStream::try_new(p)
143    }
144}
145
/// Errors returned by the DPO functions in this module.
#[derive(Debug, Error)]
pub enum DpoError {
    /// The input series has no elements.
    #[error("dpo: Input data slice is empty.")]
    EmptyInputData,
    /// Every value in the input series is NaN.
    #[error("dpo: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or longer than the data.
    ///
    /// NOTE: this variant is also reused for other argument problems in this module.
    /// Examples: an output buffer whose length does not match the input, a sweep that
    /// expands to no periods, or a non-batch kernel given to the batch entry point.
    /// So `period` / `data_len` do not always carry a period and a data length.
    #[error("dpo: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` values exist from the first non-NaN input onward.
    #[error("dpo: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
}
159
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<DpoError> for JsValue {
    /// Converts the error into a JS string carrying its `Display` message.
    fn from(err: DpoError) -> Self {
        let message = err.to_string();
        JsValue::from_str(message.as_str())
    }
}
166
167#[inline]
168pub fn dpo(input: &DpoInput) -> Result<DpoOutput, DpoError> {
169    dpo_with_kernel(input, Kernel::Auto)
170}
171
/// Computes the Detrended Price Oscillator for `input` into the caller-provided buffer `dst`.
///
/// With `back = period / 2 + 1`, each written index `i` holds
/// `data[i - back] - sum(data[i + 1 - period ..= i]) / period`.
/// Indices before the warm-up index `max(first + period - 1, back)` are overwritten with
/// quiet NaN, where `first` is the first non-NaN input.
///
/// # Errors
/// * `EmptyInputData` if the input is empty.
/// * `AllValuesNaN` if every input is NaN.
/// * `InvalidPeriod` if `period == 0` or `period > len`. It is also returned with
///   `period` holding `dst.len()` when `dst.len() != len`.
/// * `NotEnoughValidData` if fewer than `period` values exist from `first` onward.
///
/// # Panics
/// Reaches `unreachable!()` if the resolved kernel is not one of the arms compiled into
/// this build (e.g. an AVX kernel without the `nightly-avx` feature).
#[inline]
pub fn dpo_into_slice(dst: &mut [f64], input: &DpoInput, kern: Kernel) -> Result<(), DpoError> {
    let data: &[f64] = input.as_ref();
    let len = data.len();
    if len == 0 {
        return Err(DpoError::EmptyInputData);
    }

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(DpoError::AllValuesNaN)?;
    let period = input.get_period();

    if period == 0 || period > len {
        return Err(DpoError::InvalidPeriod {
            period,
            data_len: len,
        });
    }
    if (len - first) < period {
        return Err(DpoError::NotEnoughValidData {
            needed: period,
            valid: len - first,
        });
    }
    // No dedicated length-mismatch variant exists, so `InvalidPeriod` is reused here.
    if dst.len() != len {
        return Err(DpoError::InvalidPeriod {
            period: dst.len(),
            data_len: len,
        });
    }

    let chosen = match kern {
        Kernel::Auto => detect_best_kernel(),
        other => other,
    };

    unsafe {
        // On wasm32 with simd128, a scalar request is served by the SIMD128 kernel;
        // any other kernel falls through to the regular dispatch.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
            dpo_simd128(data, period, first, dst);
        } else {
            match chosen {
                Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, dst),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, dst),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, dst),
                _ => unreachable!(),
            }
        }

        // Every other target: plain dispatch on the resolved kernel.
        #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
        {
            match chosen {
                Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, dst),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, dst),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, dst),
                _ => unreachable!(),
            }
        }
    }

    // The kernels only write from the warm-up index on; fill the prefix with quiet NaN
    // so stale contents of the caller's buffer do not leak through.
    let back = period / 2 + 1;
    let warm = (first + period - 1).max(back);

    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
    for v in &mut dst[..warm] {
        *v = qnan;
    }

    Ok(())
}
248
249#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
250#[inline]
251pub fn dpo_into(input: &DpoInput, out: &mut [f64]) -> Result<(), DpoError> {
252    dpo_into_slice(out, input, Kernel::Auto)
253}
254
255pub fn dpo_with_kernel(input: &DpoInput, kernel: Kernel) -> Result<DpoOutput, DpoError> {
256    let data: &[f64] = input.as_ref();
257    let len = data.len();
258    if len == 0 {
259        return Err(DpoError::EmptyInputData);
260    }
261
262    let first = data
263        .iter()
264        .position(|x| !x.is_nan())
265        .ok_or(DpoError::AllValuesNaN)?;
266    let period = input.get_period();
267
268    if period == 0 || period > len {
269        return Err(DpoError::InvalidPeriod {
270            period,
271            data_len: len,
272        });
273    }
274    if (len - first) < period {
275        return Err(DpoError::NotEnoughValidData {
276            needed: period,
277            valid: len - first,
278        });
279    }
280
281    let back = period / 2 + 1;
282    let warm = (first + period - 1).max(back);
283
284    let chosen = match kernel {
285        Kernel::Auto => detect_best_kernel(),
286        other => other,
287    };
288
289    let mut out = alloc_with_nan_prefix(len, warm);
290    unsafe {
291        match chosen {
292            Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, &mut out),
293            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294            Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, &mut out),
295            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
296            Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, &mut out),
297            _ => unreachable!(),
298        }
299    }
300    Ok(DpoOutput { values: out })
301}
302
/// Scalar DPO kernel.
///
/// With `back = period / 2 + 1`, writes
/// `out[i] = data[i - back] - sum(data[i + 1 - period ..= i]) / period`
/// for every `i` from `max(first_val + period - 1, back)` up to `data.len() - 1`.
/// Entries of `out` before that index are left untouched. The window sum is maintained
/// incrementally, and the main loop is unrolled by four.
///
/// # Panics
/// Panics if `period == 0`, if fewer than `period` values exist from `first_val` on,
/// or if `out` is shorter than `data`. These preconditions are what make the unchecked
/// pointer accesses below sound; the public wrappers already validate them.
#[inline(always)]
pub fn dpo_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let len = data.len();
    if len == 0 {
        return;
    }
    // Cheap O(1) checks: this is a safe `pub fn`, so invalid arguments must panic
    // rather than read or write out of bounds.
    assert!(period >= 1, "dpo_scalar: period must be at least 1");
    assert!(
        first_val <= len && period <= len - first_val,
        "dpo_scalar: fewer than `period` values from `first_val` on"
    );
    assert!(
        out.len() >= len,
        "dpo_scalar: output buffer is shorter than the input"
    );

    let back = period / 2 + 1;
    let scale = 1.0f64 / (period as f64);

    // SAFETY: by the asserts above, `first_val + period <= len <= out.len()` and `period >= 1`.
    // Every read index below lies in `[first_val - ..., len)`:
    //   * `i - back` is only formed once `i >= back`;
    //   * `i + k - period` is taken with `i >= first_val + period - 1`;
    //   * `i + 3 < len` bounds the unrolled reads and writes.
    // Every write index lies in `[0, len)`, hence within `out`.
    unsafe {
        let base = first_val;
        let ptr_d = data.as_ptr();
        let ptr_o = out.as_mut_ptr();

        // Seed the rolling sum with the first full window.
        let mut sum = 0.0f64;
        let mut k = 0usize;
        while k < period {
            sum += *ptr_d.add(base + k);
            k += 1;
        }

        let mut i = base + period - 1;

        // Advance the window until the lagged price `data[i - back]` exists.
        if i < back {
            let stop = back.min(len.saturating_sub(1));
            while i < stop {
                let next = i + 1;
                sum += *ptr_d.add(next);
                sum -= *ptr_d.add(next - period);
                i = next;
            }
        }

        // Main loop, unrolled by four outputs per iteration.
        while i + 3 < len {
            let p0 = *ptr_d.add(i - back);
            *ptr_o.add(i) = sum.mul_add(-scale, p0);

            let a1 = *ptr_d.add(i + 1);
            let r1 = *ptr_d.add(i + 1 - period);
            let s1 = (sum + a1) - r1;
            let p1 = *ptr_d.add(i + 1 - back);
            *ptr_o.add(i + 1) = s1.mul_add(-scale, p1);

            let a2 = *ptr_d.add(i + 2);
            let r2 = *ptr_d.add(i + 2 - period);
            let s2 = (s1 + a2) - r2;
            let p2 = *ptr_d.add(i + 2 - back);
            *ptr_o.add(i + 2) = s2.mul_add(-scale, p2);

            let a3 = *ptr_d.add(i + 3);
            let r3 = *ptr_d.add(i + 3 - period);
            let s3 = (s2 + a3) - r3;
            let p3 = *ptr_d.add(i + 3 - back);
            *ptr_o.add(i + 3) = s3.mul_add(-scale, p3);

            i += 4;
            if i >= len {
                return;
            }
            sum = s3 + *ptr_d.add(i);
            sum -= *ptr_d.add(i - period);
        }

        // Tail: remaining outputs one at a time.
        while i < len {
            if i >= back {
                let p = *ptr_d.add(i - back);
                *ptr_o.add(i) = sum.mul_add(-scale, p);
            }
            if i + 1 < len {
                let next = i + 1;
                sum += *ptr_d.add(next);
                sum -= *ptr_d.add(next - period);
            }
            i += 1;
        }
    }
}
381
/// WebAssembly SIMD128 DPO kernel, computing two adjacent outputs per iteration with `f64x2`.
///
/// With `back = period / 2 + 1`, writes `out[i] = data[i - back] - sum(window) / period`
/// for `i` from `max(first_val + period - 1, back)` on. It returns without writing if
/// that index is `>= data.len()`. It uses a separate multiply and subtract rather than
/// the fused multiply-add of `dpo_scalar`, so results can differ in the last bits.
///
/// # Safety
/// * `out.len() >= data.len()` is required: the two-lane store at `&mut out[i]` also
///   writes `out[i + 1]` without a bounds check.
/// * `period >= 1` and `first_val + period <= data.len()` are expected; otherwise the
///   index arithmetic underflows or the indexing panics.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dpo_simd128(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    let len = data.len();
    if len == 0 {
        return;
    }

    let back = period / 2 + 1;
    let start_idx = first_val + period - 1;
    let warm = start_idx.max(back);
    if warm >= len {
        return;
    }

    // Seed the rolling sum with the first full window.
    let mut sum = 0.0f64;
    for j in 0..period {
        sum += data[first_val + j];
    }

    // Slide the window forward until the lagged price exists.
    let mut cur = start_idx;
    while cur < warm {
        let next = cur + 1;
        sum += data[next] - data[next - period];
        cur = next;
    }

    let scale = 1.0f64 / (period as f64);
    let scale_vec = f64x2_splat(scale);

    // Two outputs per iteration: lanes hold indices i and i + 1.
    let mut i = warm;
    while i + 1 < len {
        // Loads data[i - back] and data[i - back + 1] (unaligned 128-bit load).
        let price_vec = v128_load(&data[i - back] as *const f64 as *const v128);

        let sum0 = sum;
        let sum1 = sum0 + data[i + 1] - data[i + 1 - period];
        let sum_vec = f64x2(sum0, sum1);

        let result = f64x2_sub(price_vec, f64x2_mul(sum_vec, scale_vec));
        // Writes out[i] and out[i + 1]; only out[i] is bounds-checked here.
        v128_store(&mut out[i] as *mut f64 as *mut v128, result);

        if i + 2 < len {
            sum = sum1 + data[i + 2] - data[i + 2 - period];
        }
        i += 2;
    }

    // Odd remainder: one last scalar output.
    if i < len {
        out[i] = data[i - back] - (sum * scale);
    }
}
435
/// "AVX2" DPO kernel (currently a scalar, four-way unrolled rolling-sum loop).
///
/// With `back = period / 2 + 1`, writes
/// `out[i] = data[i - back] - sum(data[i + 1 - period ..= i]) / period`
/// for every `i` from `max(first_val + period - 1, back)` on. Earlier entries of `out`
/// are left untouched.
///
/// # Panics
/// Panics if `period == 0`, if fewer than `period` values exist from `first_val` on,
/// or if `out` is shorter than `data`. These preconditions make the unchecked pointer
/// accesses below sound.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx2(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let len = data.len();
    if len == 0 {
        return;
    }
    // This is a safe `pub fn`: reject arguments that would make the raw accesses go out of bounds.
    assert!(period >= 1, "dpo_avx2: period must be at least 1");
    assert!(
        first_val <= len && period <= len - first_val,
        "dpo_avx2: fewer than `period` values from `first_val` on"
    );
    assert!(
        out.len() >= len,
        "dpo_avx2: output buffer is shorter than the input"
    );
    let back = period / 2 + 1;
    let scale = 1.0f64 / (period as f64);
    // SAFETY: the asserts guarantee `first_val + period <= len <= out.len()` and `period >= 1`.
    // Reads of `i - back` are guarded by `i >= back`, reads of `i + k - period` have
    // `i >= first_val + period - 1`, and the unrolled body runs only while `i + 3 < len`.
    unsafe {
        let mut sum = 0.0f64;
        let base = first_val;
        let mut k = 0usize;
        let ptr_d = data.as_ptr();
        while k < period {
            sum += *ptr_d.add(base + k);
            k += 1;
        }
        let mut i = base + period - 1;

        // Advance the window until the lagged price `data[i - back]` exists.
        if i < back {
            let stop = back.min(len.saturating_sub(1));
            while i < stop {
                sum += *ptr_d.add(i + 1) - *ptr_d.add(i + 1 - period);
                i += 1;
            }
            if i >= len {
                return;
            }
        }

        let ptr_o = out.as_mut_ptr();
        while i + 3 < len {
            let p0 = *ptr_d.add(i - back);
            *ptr_o.add(i) = sum.mul_add(-scale, p0);

            let a1 = *ptr_d.add(i + 1);
            let r1 = *ptr_d.add(i + 1 - period);
            let s1 = sum + (a1 - r1);
            if i + 1 >= back {
                let p1 = *ptr_d.add(i + 1 - back);
                *ptr_o.add(i + 1) = s1.mul_add(-scale, p1);
            }

            let a2 = *ptr_d.add(i + 2);
            let r2 = *ptr_d.add(i + 2 - period);
            let s2 = s1 + (a2 - r2);
            if i + 2 >= back {
                let p2 = *ptr_d.add(i + 2 - back);
                *ptr_o.add(i + 2) = s2.mul_add(-scale, p2);
            }

            let a3 = *ptr_d.add(i + 3);
            let r3 = *ptr_d.add(i + 3 - period);
            let s3 = s2 + (a3 - r3);
            if i + 3 >= back {
                let p3 = *ptr_d.add(i + 3 - back);
                *ptr_o.add(i + 3) = s3.mul_add(-scale, p3);
            }

            i += 4;
            if i >= len {
                return;
            }
            let a4 = *ptr_d.add(i);
            let r4 = *ptr_d.add(i - period);
            sum = s3 + (a4 - r4);
        }

        // Tail: remaining outputs one at a time.
        while i < len {
            if i >= back {
                let p = *ptr_d.add(i - back);
                *ptr_o.add(i) = sum.mul_add(-scale, p);
            }
            if i + 1 < len {
                sum += *ptr_d.add(i + 1) - *ptr_d.add(i + 1 - period);
            }
            i += 1;
        }
    }
}
517
/// AVX-512 entry point; currently forwards to `dpo_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}

/// Short-period AVX-512 variant; currently forwards to `dpo_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512_short(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}

/// Long-period AVX-512 variant; currently forwards to `dpo_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512_long(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}
535
536#[inline]
537pub fn dpo_batch_with_kernel(
538    data: &[f64],
539    sweep: &DpoBatchRange,
540    k: Kernel,
541) -> Result<DpoBatchOutput, DpoError> {
542    let kernel = match k {
543        Kernel::Auto => detect_best_batch_kernel(),
544        other if other.is_batch() => other,
545        _ => {
546            return Err(DpoError::InvalidPeriod {
547                period: 0,
548                data_len: 0,
549            })
550        }
551    };
552    let simd = match kernel {
553        Kernel::Avx512Batch => Kernel::Avx512,
554        Kernel::Avx2Batch => Kernel::Avx2,
555        Kernel::ScalarBatch => Kernel::Scalar,
556        _ => unreachable!(),
557    };
558    dpo_batch_par_slice(data, sweep, simd)
559}
560
/// Sweep of DPO periods as `(start, end, step)`, with `end` inclusive.
#[derive(Clone, Debug)]
pub struct DpoBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DpoBatchRange {
    /// Periods 5 through 254 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (5, 254, 1);
        DpoBatchRange {
            period: (start, end, step),
        }
    }
}
573
574#[derive(Clone, Debug, Default)]
575pub struct DpoBatchBuilder {
576    range: DpoBatchRange,
577    kernel: Kernel,
578}
579
580impl DpoBatchBuilder {
581    pub fn new() -> Self {
582        Self::default()
583    }
584    pub fn kernel(mut self, k: Kernel) -> Self {
585        self.kernel = k;
586        self
587    }
588    #[inline]
589    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
590        self.range.period = (start, end, step);
591        self
592    }
593    #[inline]
594    pub fn period_static(mut self, p: usize) -> Self {
595        self.range.period = (p, p, 0);
596        self
597    }
598
599    pub fn apply_slice(self, data: &[f64]) -> Result<DpoBatchOutput, DpoError> {
600        dpo_batch_with_kernel(data, &self.range, self.kernel)
601    }
602    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DpoBatchOutput, DpoError> {
603        DpoBatchBuilder::new().kernel(k).apply_slice(data)
604    }
605    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DpoBatchOutput, DpoError> {
606        let slice = source_type(c, src);
607        self.apply_slice(slice)
608    }
609    pub fn with_default_candles(c: &Candles) -> Result<DpoBatchOutput, DpoError> {
610        DpoBatchBuilder::new()
611            .kernel(Kernel::Auto)
612            .apply_candles(c, "close")
613    }
614}
615
616#[derive(Clone, Debug)]
617pub struct DpoBatchOutput {
618    pub values: Vec<f64>,
619    pub combos: Vec<DpoParams>,
620    pub rows: usize,
621    pub cols: usize,
622}
623
624impl DpoBatchOutput {
625    pub fn row_for_params(&self, p: &DpoParams) -> Option<usize> {
626        self.combos
627            .iter()
628            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
629    }
630
631    pub fn values_for(&self, p: &DpoParams) -> Option<&[f64]> {
632        self.row_for_params(p).map(|row| {
633            let start = row * self.cols;
634            &self.values[start..start + self.cols]
635        })
636    }
637}
638
639#[inline(always)]
640fn expand_grid(r: &DpoBatchRange) -> Vec<DpoParams> {
641    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
642        if step == 0 || start == end {
643            return vec![start];
644        }
645        (start..=end).step_by(step).collect()
646    }
647    let periods = axis_usize(r.period);
648    let mut out = Vec::with_capacity(periods.len());
649    for &p in &periods {
650        out.push(DpoParams { period: Some(p) });
651    }
652    out
653}
654
655#[inline(always)]
656pub fn dpo_batch_slice(
657    data: &[f64],
658    sweep: &DpoBatchRange,
659    kern: Kernel,
660) -> Result<DpoBatchOutput, DpoError> {
661    dpo_batch_inner(data, sweep, kern, false)
662}
663
664#[inline(always)]
665pub fn dpo_batch_par_slice(
666    data: &[f64],
667    sweep: &DpoBatchRange,
668    kern: Kernel,
669) -> Result<DpoBatchOutput, DpoError> {
670    dpo_batch_inner(data, sweep, kern, true)
671}
672
673#[inline(always)]
674pub fn dpo_batch_inner_into(
675    data: &[f64],
676    sweep: &DpoBatchRange,
677    kern: Kernel,
678    parallel: bool,
679    out: &mut [f64],
680) -> Result<Vec<DpoParams>, DpoError> {
681    let combos = expand_grid(sweep);
682    if combos.is_empty() {
683        return Err(DpoError::InvalidPeriod {
684            period: 0,
685            data_len: 0,
686        });
687    }
688
689    let len = data.len();
690    if len == 0 {
691        return Err(DpoError::EmptyInputData);
692    }
693
694    let first = data
695        .iter()
696        .position(|x| !x.is_nan())
697        .ok_or(DpoError::AllValuesNaN)?;
698    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
699    if len - first < max_p {
700        return Err(DpoError::NotEnoughValidData {
701            needed: max_p,
702            valid: len - first,
703        });
704    }
705
706    let rows = combos.len();
707    let cols = len;
708    debug_assert_eq!(out.len(), rows * cols);
709
710    let warm: Vec<usize> = combos
711        .iter()
712        .map(|c| {
713            let p = c.period.unwrap();
714            let back = p / 2 + 1;
715            (first + p - 1).max(back)
716        })
717        .collect();
718
719    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
720        std::slice::from_raw_parts_mut(
721            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
722            out.len(),
723        )
724    };
725    init_matrix_prefixes(out_mu, cols, &warm);
726
727    let mut pfx = vec![0.0f64; cols];
728    if cols > 0 {
729        let mut acc = 0.0f64;
730        for i in 0..cols {
731            if i < first {
732                pfx[i] = 0.0;
733            } else {
734                acc += data[i];
735                pfx[i] = acc;
736            }
737        }
738    }
739
740    let do_row = |row: usize, dst_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
741        let period = combos[row].period.unwrap();
742        let back = period / 2 + 1;
743        let warm = (first + period - 1).max(back);
744        let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols);
745        let scale = 1.0f64 / (period as f64);
746
747        let mut i = warm;
748        while i < cols {
749            let prev = if i >= period { pfx[i - period] } else { 0.0 };
750            let sum = pfx[i] - prev;
751            let avg = sum * scale;
752            let price = data[i - back];
753            dst[i] = price - avg;
754            i += 1;
755        }
756    };
757
758    if parallel {
759        #[cfg(not(target_arch = "wasm32"))]
760        out_mu
761            .par_chunks_mut(cols)
762            .enumerate()
763            .for_each(|(r, s)| do_row(r, s));
764        #[cfg(target_arch = "wasm32")]
765        for (r, s) in out_mu.chunks_mut(cols).enumerate() {
766            do_row(r, s);
767        }
768    } else {
769        for (r, s) in out_mu.chunks_mut(cols).enumerate() {
770            do_row(r, s);
771        }
772    }
773
774    Ok(combos)
775}
776
777#[inline(always)]
778fn dpo_batch_inner(
779    data: &[f64],
780    sweep: &DpoBatchRange,
781    kern: Kernel,
782    parallel: bool,
783) -> Result<DpoBatchOutput, DpoError> {
784    let combos = expand_grid(sweep);
785    if combos.is_empty() {
786        return Err(DpoError::InvalidPeriod {
787            period: 0,
788            data_len: 0,
789        });
790    }
791
792    let len = data.len();
793    if len == 0 {
794        return Err(DpoError::EmptyInputData);
795    }
796
797    let first = data
798        .iter()
799        .position(|x| !x.is_nan())
800        .ok_or(DpoError::AllValuesNaN)?;
801    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
802    if len - first < max_p {
803        return Err(DpoError::NotEnoughValidData {
804            needed: max_p,
805            valid: len - first,
806        });
807    }
808
809    let rows = combos.len();
810    let cols = len;
811
812    let warm: Vec<usize> = combos
813        .iter()
814        .map(|c| {
815            let p = c.period.unwrap();
816            let back = p / 2 + 1;
817            (first + p - 1).max(back)
818        })
819        .collect();
820
821    let mut buf_mu = make_uninit_matrix(rows, cols);
822    init_matrix_prefixes(&mut buf_mu, cols, &warm);
823
824    let mut values = unsafe {
825        let ptr = buf_mu.as_mut_ptr() as *mut f64;
826        let len = buf_mu.len();
827        std::mem::forget(buf_mu);
828        Vec::from_raw_parts(ptr, len, len)
829    };
830
831    let mut pfx = vec![0.0f64; cols];
832    if cols > 0 {
833        let mut acc = 0.0f64;
834        for i in 0..cols {
835            if i < first {
836                pfx[i] = 0.0;
837            } else {
838                acc += data[i];
839                pfx[i] = acc;
840            }
841        }
842    }
843
844    let do_row = |row: usize, out_row: &mut [f64]| {
845        let period = combos[row].period.unwrap();
846        let back = period / 2 + 1;
847        let warm = (first + period - 1).max(back);
848        let scale = 1.0f64 / (period as f64);
849
850        let mut i = warm;
851        while i < cols {
852            let prev = if i >= period { pfx[i - period] } else { 0.0 };
853            let sum = pfx[i] - prev;
854            let avg = sum * scale;
855            let price = data[i - back];
856            out_row[i] = price - avg;
857            i += 1;
858        }
859    };
860
861    if parallel {
862        #[cfg(not(target_arch = "wasm32"))]
863        {
864            values
865                .par_chunks_mut(cols)
866                .enumerate()
867                .for_each(|(row, slice)| do_row(row, slice));
868        }
869
870        #[cfg(target_arch = "wasm32")]
871        {
872            for (row, slice) in values.chunks_mut(cols).enumerate() {
873                do_row(row, slice);
874            }
875        }
876    } else {
877        for (row, slice) in values.chunks_mut(cols).enumerate() {
878            do_row(row, slice);
879        }
880    }
881
882    Ok(DpoBatchOutput {
883        values,
884        combos,
885        rows,
886        cols,
887    })
888}
889
890#[inline(always)]
891unsafe fn dpo_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
892    dpo_scalar(data, period, first, out)
893}
894
895#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
896#[inline(always)]
897unsafe fn dpo_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
898    dpo_avx2(data, period, first, out)
899}
900
901#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
902#[inline(always)]
903pub unsafe fn dpo_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
904    if period <= 32 {
905        dpo_row_avx512_short(data, first, period, out);
906    } else {
907        dpo_row_avx512_long(data, first, period, out);
908    }
909}
910
911#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
912#[inline(always)]
913unsafe fn dpo_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
914    dpo_avx2(data, period, first, out)
915}
916
917#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
918#[inline(always)]
919unsafe fn dpo_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
920    dpo_avx2(data, period, first, out)
921}
922
/// Incremental DPO calculator that consumes one value per `update` call.
#[derive(Debug, Clone)]
pub struct DpoStream {
    /// SMA window length.
    period: usize,
    /// Lag of the price term: `period / 2 + 1`.
    back: usize,
    /// Precomputed `1 / period`.
    inv_period: f64,

    /// Ring buffer of the last `period` inputs. NaN marks slots not yet filled, so a
    /// genuine NaN input is indistinguishable from an empty slot.
    sma_buf: Vec<f64>,
    /// Ring buffer of the last `back + 1` inputs, used to read the lagged price.
    lag_buf: Vec<f64>,
    /// Running sum of the values currently held in `sma_buf`.
    sum: f64,

    /// Next write position in `sma_buf`.
    sma_head: usize,
    /// Next write position in `lag_buf` (after a write, the oldest entry).
    lag_head: usize,
    /// Number of values consumed so far.
    count: usize,
}
937
938impl DpoStream {
939    #[inline]
940    pub fn try_new(params: DpoParams) -> Result<Self, DpoError> {
941        let period = params.period.unwrap_or(5);
942        if period == 0 {
943            return Err(DpoError::InvalidPeriod {
944                period,
945                data_len: 0,
946            });
947        }
948        let back = period / 2 + 1;
949        let inv_period = 1.0f64 / (period as f64);
950
951        Ok(Self {
952            period,
953            back,
954            inv_period,
955
956            sma_buf: vec![f64::NAN; period],
957
958            lag_buf: vec![f64::NAN; back + 1],
959            sum: 0.0,
960
961            sma_head: 0,
962            lag_head: 0,
963            count: 0,
964        })
965    }
966
967    #[inline(always)]
968    pub fn update(&mut self, value: f64) -> Option<f64> {
969        self.lag_buf[self.lag_head] = value;
970        self.lag_head += 1;
971        if self.lag_head == self.lag_buf.len() {
972            self.lag_head = 0;
973        }
974
975        let old = self.sma_buf[self.sma_head];
976        self.sma_buf[self.sma_head] = value;
977        self.sma_head += 1;
978        if self.sma_head == self.period {
979            self.sma_head = 0;
980        }
981
982        if old.is_nan() {
983            self.sum += value;
984        } else {
985            self.sum += value - old;
986        }
987
988        self.count += 1;
989
990        if self.count < self.period || self.count <= self.back {
991            return None;
992        }
993
994        let lagged_value = self.lag_buf[self.lag_head];
995
996        let dpo = (-self.inv_period).mul_add(self.sum, lagged_value);
997
998        Some(dpo)
999    }
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004    use super::*;
1005    use crate::skip_if_unsupported;
1006    use crate::utilities::data_loader::read_candles_from_csv;
1007
1008    fn check_dpo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1009        skip_if_unsupported!(kernel, test_name);
1010        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1011        let candles = read_candles_from_csv(file_path)?;
1012        let default_params = DpoParams { period: None };
1013        let input = DpoInput::from_candles(&candles, "close", default_params);
1014        let output = dpo_with_kernel(&input, kernel)?;
1015        assert_eq!(output.values.len(), candles.close.len());
1016        Ok(())
1017    }
1018
1019    fn check_dpo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1020        skip_if_unsupported!(kernel, test_name);
1021        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1022        let candles = read_candles_from_csv(file_path)?;
1023        let input = DpoInput::from_candles(&candles, "close", DpoParams { period: Some(5) });
1024        let result = dpo_with_kernel(&input, kernel)?;
1025        let expected_last_five = [
1026            65.3999999999287,
1027            131.3999999999287,
1028            32.599999999925785,
1029            98.3999999999287,
1030            117.99999999992724,
1031        ];
1032        let start = result.values.len().saturating_sub(5);
1033        for (i, &val) in result.values[start..].iter().enumerate() {
1034            let diff = (val - expected_last_five[i]).abs();
1035            assert!(
1036                diff < 1e-1,
1037                "[{}] DPO {:?} mismatch at idx {}: got {}, expected {}",
1038                test_name,
1039                kernel,
1040                i,
1041                val,
1042                expected_last_five[i]
1043            );
1044        }
1045        Ok(())
1046    }
1047
1048    fn check_dpo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1049        skip_if_unsupported!(kernel, test_name);
1050        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1051        let candles = read_candles_from_csv(file_path)?;
1052        let input = DpoInput::with_default_candles(&candles);
1053        match input.data {
1054            DpoData::Candles { source, .. } => assert_eq!(source, "close"),
1055            _ => panic!("Expected DpoData::Candles"),
1056        }
1057        let output = dpo_with_kernel(&input, kernel)?;
1058        assert_eq!(output.values.len(), candles.close.len());
1059        Ok(())
1060    }
1061
1062    fn check_dpo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1063        skip_if_unsupported!(kernel, test_name);
1064        let input_data = [10.0, 20.0, 30.0];
1065        let params = DpoParams { period: Some(0) };
1066        let input = DpoInput::from_slice(&input_data, params);
1067        let res = dpo_with_kernel(&input, kernel);
1068        assert!(
1069            res.is_err(),
1070            "[{}] DPO should fail with zero period",
1071            test_name
1072        );
1073        Ok(())
1074    }
1075
1076    fn check_dpo_period_exceeds_length(
1077        test_name: &str,
1078        kernel: Kernel,
1079    ) -> Result<(), Box<dyn Error>> {
1080        skip_if_unsupported!(kernel, test_name);
1081        let data_small = [10.0, 20.0, 30.0];
1082        let params = DpoParams { period: Some(10) };
1083        let input = DpoInput::from_slice(&data_small, params);
1084        let res = dpo_with_kernel(&input, kernel);
1085        assert!(
1086            res.is_err(),
1087            "[{}] DPO should fail with period exceeding length",
1088            test_name
1089        );
1090        Ok(())
1091    }
1092
1093    fn check_dpo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1094        skip_if_unsupported!(kernel, test_name);
1095        let single_point = [42.0];
1096        let params = DpoParams { period: Some(5) };
1097        let input = DpoInput::from_slice(&single_point, params);
1098        let res = dpo_with_kernel(&input, kernel);
1099        assert!(
1100            res.is_err(),
1101            "[{}] DPO should fail with insufficient data",
1102            test_name
1103        );
1104        Ok(())
1105    }
1106
1107    fn check_dpo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1108        skip_if_unsupported!(kernel, test_name);
1109        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1110        let candles = read_candles_from_csv(file_path)?;
1111        let first_params = DpoParams { period: Some(5) };
1112        let first_input = DpoInput::from_candles(&candles, "close", first_params);
1113        let first_result = dpo_with_kernel(&first_input, kernel)?;
1114        let second_params = DpoParams { period: Some(5) };
1115        let second_input = DpoInput::from_slice(&first_result.values, second_params);
1116        let second_result = dpo_with_kernel(&second_input, kernel)?;
1117        assert_eq!(second_result.values.len(), first_result.values.len());
1118        Ok(())
1119    }
1120
1121    fn check_dpo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1122        skip_if_unsupported!(kernel, test_name);
1123        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1124        let candles = read_candles_from_csv(file_path)?;
1125        let input = DpoInput::from_candles(&candles, "close", DpoParams { period: Some(5) });
1126        let res = dpo_with_kernel(&input, kernel)?;
1127        assert_eq!(res.values.len(), candles.close.len());
1128        if res.values.len() > 20 {
1129            for (i, &val) in res.values[20..].iter().enumerate() {
1130                assert!(
1131                    !val.is_nan(),
1132                    "[{}] Found unexpected NaN at out-index {}",
1133                    test_name,
1134                    20 + i
1135                );
1136            }
1137        }
1138        Ok(())
1139    }
1140
1141    fn check_dpo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1142        skip_if_unsupported!(kernel, test_name);
1143        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1144        let candles = read_candles_from_csv(file_path)?;
1145        let period = 5;
1146        let input = DpoInput::from_candles(
1147            &candles,
1148            "close",
1149            DpoParams {
1150                period: Some(period),
1151            },
1152        );
1153        let batch_output = dpo_with_kernel(&input, kernel)?.values;
1154        let mut stream = DpoStream::try_new(DpoParams {
1155            period: Some(period),
1156        })?;
1157        let mut stream_values = Vec::with_capacity(candles.close.len());
1158        for &price in &candles.close {
1159            match stream.update(price) {
1160                Some(dpo_val) => stream_values.push(dpo_val),
1161                None => stream_values.push(f64::NAN),
1162            }
1163        }
1164        assert_eq!(batch_output.len(), stream_values.len());
1165        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1166            if b.is_nan() && s.is_nan() {
1167                continue;
1168            }
1169            let diff = (b - s).abs();
1170            assert!(
1171                diff < 1e-9,
1172                "[{}] DPO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1173                test_name,
1174                i,
1175                b,
1176                s,
1177                diff
1178            );
1179        }
1180        Ok(())
1181    }
1182
1183    #[cfg(debug_assertions)]
1184    fn check_dpo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1185        skip_if_unsupported!(kernel, test_name);
1186
1187        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1188        let candles = read_candles_from_csv(file_path)?;
1189
1190        let test_params = vec![
1191            DpoParams::default(),
1192            DpoParams { period: Some(1) },
1193            DpoParams { period: Some(2) },
1194            DpoParams { period: Some(3) },
1195            DpoParams { period: Some(5) },
1196            DpoParams { period: Some(10) },
1197            DpoParams { period: Some(15) },
1198            DpoParams { period: Some(20) },
1199            DpoParams { period: Some(30) },
1200            DpoParams { period: Some(50) },
1201            DpoParams { period: Some(75) },
1202            DpoParams { period: Some(100) },
1203            DpoParams { period: Some(200) },
1204            DpoParams { period: Some(500) },
1205            DpoParams { period: Some(1000) },
1206        ];
1207
1208        for (param_idx, params) in test_params.iter().enumerate() {
1209            let input = DpoInput::from_candles(&candles, "close", params.clone());
1210            let output = dpo_with_kernel(&input, kernel)?;
1211
1212            for (i, &val) in output.values.iter().enumerate() {
1213                if val.is_nan() {
1214                    continue;
1215                }
1216
1217                let bits = val.to_bits();
1218
1219                if bits == 0x11111111_11111111 {
1220                    panic!(
1221                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1222						 with params: period={} (param set {})",
1223                        test_name,
1224                        val,
1225                        bits,
1226                        i,
1227                        params.period.unwrap_or(5),
1228                        param_idx
1229                    );
1230                }
1231
1232                if bits == 0x22222222_22222222 {
1233                    panic!(
1234                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1235						 with params: period={} (param set {})",
1236                        test_name,
1237                        val,
1238                        bits,
1239                        i,
1240                        params.period.unwrap_or(5),
1241                        param_idx
1242                    );
1243                }
1244
1245                if bits == 0x33333333_33333333 {
1246                    panic!(
1247                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1248						 with params: period={} (param set {})",
1249                        test_name,
1250                        val,
1251                        bits,
1252                        i,
1253                        params.period.unwrap_or(5),
1254                        param_idx
1255                    );
1256                }
1257            }
1258        }
1259
1260        Ok(())
1261    }
1262
    // Release-build stand-in: the poison-pattern scan is compiled out under
    // `not(debug_assertions)`, so this always reports success. Presumably the poison
    // fill is only applied in debug builds — NOTE(review): confirm in the helpers.
    #[cfg(not(debug_assertions))]
    fn check_dpo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1267
1268    #[cfg(feature = "proptest")]
1269    #[allow(clippy::float_cmp)]
1270    fn check_dpo_property(
1271        test_name: &str,
1272        kernel: Kernel,
1273    ) -> Result<(), Box<dyn std::error::Error>> {
1274        use proptest::prelude::*;
1275        skip_if_unsupported!(kernel, test_name);
1276
1277        let strat = (1usize..=100).prop_flat_map(|period| {
1278            (
1279                prop::collection::vec(
1280                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1281                    period..400,
1282                ),
1283                Just(period),
1284            )
1285        });
1286
1287        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1288            let params = DpoParams {
1289                period: Some(period),
1290            };
1291            let input = DpoInput::from_slice(&data, params);
1292
1293            let DpoOutput { values: out } = dpo_with_kernel(&input, kernel).unwrap();
1294            let DpoOutput { values: ref_out } = dpo_with_kernel(&input, Kernel::Scalar).unwrap();
1295
1296            for i in 0..period.min(data.len()) {
1297                prop_assert!(
1298                    out[i].is_nan(),
1299                    "[{}] Expected NaN during warmup at index {}, got {}",
1300                    test_name,
1301                    i,
1302                    out[i]
1303                );
1304            }
1305
1306            for i in period..data.len() {
1307                prop_assert!(
1308                    out[i].is_finite(),
1309                    "[{}] Expected finite value at index {}, got {}",
1310                    test_name,
1311                    i,
1312                    out[i]
1313                );
1314            }
1315
1316            for i in 0..data.len() {
1317                if out[i].is_nan() && ref_out[i].is_nan() {
1318                    continue;
1319                }
1320                let y_bits = out[i].to_bits();
1321                let r_bits = ref_out[i].to_bits();
1322                let ulp_diff = if y_bits > r_bits {
1323                    y_bits - r_bits
1324                } else {
1325                    r_bits - y_bits
1326                };
1327                prop_assert!(
1328                    ulp_diff <= 3,
1329                    "[{}] Kernel mismatch at idx {}: {} ({:016X}) vs {} ({:016X}), ULP diff: {}",
1330                    test_name,
1331                    i,
1332                    out[i],
1333                    y_bits,
1334                    ref_out[i],
1335                    r_bits,
1336                    ulp_diff
1337                );
1338            }
1339
1340            let back = period / 2 + 1;
1341            for i in period..data.len() {
1342                if i >= back {
1343                    let sum: f64 = data[i + 1 - period..=i].iter().sum();
1344                    let avg = sum / period as f64;
1345                    let expected_dpo = data[i - back] - avg;
1346
1347                    prop_assert!(
1348                        (out[i] - expected_dpo).abs() < 1e-9,
1349                        "[{}] DPO formula mismatch at idx {}: got {}, expected {}",
1350                        test_name,
1351                        i,
1352                        out[i],
1353                        expected_dpo
1354                    );
1355                }
1356            }
1357
1358            if data.len() > period {
1359                let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
1360                let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1361                let range = max_val - min_val;
1362
1363                for i in period..data.len() {
1364                    prop_assert!(
1365                        out[i].abs() <= 2.0 * range,
1366                        "[{}] DPO exceeds reasonable bounds at idx {}: {} (data range: {})",
1367                        test_name,
1368                        i,
1369                        out[i],
1370                        range
1371                    );
1372                }
1373            }
1374
1375            if period == 1 {
1376                for i in 1..data.len() {
1377                    let expected = data[i - 1] - data[i];
1378                    prop_assert!(
1379                        (out[i] - expected).abs() < 1e-9,
1380                        "[{}] Period=1 special case failed at idx {}: got {}, expected {}",
1381                        test_name,
1382                        i,
1383                        out[i],
1384                        expected
1385                    );
1386                }
1387            }
1388
1389            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > period {
1390                for i in period..data.len() {
1391                    prop_assert!(
1392                        out[i].abs() < 1e-9,
1393                        "[{}] Expected zero DPO for constant data at idx {}, got {}",
1394                        test_name,
1395                        i,
1396                        out[i]
1397                    );
1398                }
1399            }
1400
1401            for i in 0..data.len() {
1402                if !out[i].is_nan() {
1403                    let bits = out[i].to_bits();
1404                    prop_assert!(
1405                        bits != 0x11111111_11111111
1406                            && bits != 0x22222222_22222222
1407                            && bits != 0x33333333_33333333,
1408                        "[{}] Found poison value at idx {}: {} ({:016X})",
1409                        test_name,
1410                        i,
1411                        out[i],
1412                        bits
1413                    );
1414                }
1415            }
1416
1417            Ok(())
1418        })?;
1419
1420        Ok(())
1421    }
1422
1423    macro_rules! generate_all_dpo_tests {
1424        ($($test_fn:ident),*) => {
1425            paste! {
1426                $(
1427                    #[test]
1428                    fn [<$test_fn _scalar_f64>]() {
1429                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1430                    }
1431                )*
1432                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1433                $(
1434                    #[test]
1435                    fn [<$test_fn _avx2_f64>]() {
1436                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1437                    }
1438                    #[test]
1439                    fn [<$test_fn _avx512_f64>]() {
1440                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1441                    }
1442                )*
1443                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1444                $(
1445                    #[test]
1446                    fn [<$test_fn _simd128_f64>]() {
1447                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1448                    }
1449                )*
1450            }
1451        }
1452    }
1453
1454    generate_all_dpo_tests!(
1455        check_dpo_partial_params,
1456        check_dpo_accuracy,
1457        check_dpo_default_candles,
1458        check_dpo_zero_period,
1459        check_dpo_period_exceeds_length,
1460        check_dpo_very_small_dataset,
1461        check_dpo_reinput,
1462        check_dpo_nan_handling,
1463        check_dpo_streaming,
1464        check_dpo_no_poison
1465    );
1466
1467    #[cfg(feature = "proptest")]
1468    generate_all_dpo_tests!(check_dpo_property);
1469
1470    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1471        skip_if_unsupported!(kernel, test);
1472
1473        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1474        let c = read_candles_from_csv(file)?;
1475
1476        let output = DpoBatchBuilder::new()
1477            .kernel(kernel)
1478            .apply_candles(&c, "close")?;
1479
1480        let def = DpoParams::default();
1481        let row = output.values_for(&def).expect("default row missing");
1482
1483        assert_eq!(row.len(), c.close.len());
1484
1485        let expected = [
1486            65.3999999999287,
1487            131.3999999999287,
1488            32.599999999925785,
1489            98.3999999999287,
1490            117.99999999992724,
1491        ];
1492        let start = row.len() - 5;
1493        for (i, &v) in row[start..].iter().enumerate() {
1494            assert!(
1495                (v - expected[i]).abs() < 1e-1,
1496                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1497            );
1498        }
1499        Ok(())
1500    }
1501
1502    macro_rules! gen_batch_tests {
1503        ($fn_name:ident) => {
1504            paste! {
1505                #[test] fn [<$fn_name _scalar>]()      {
1506                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1507                }
1508                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1509                #[test] fn [<$fn_name _avx2>]()        {
1510                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1511                }
1512                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1513                #[test] fn [<$fn_name _avx512>]()      {
1514                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1515                }
1516                #[test] fn [<$fn_name _auto_detect>]() {
1517                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1518                }
1519            }
1520        };
1521    }
1522    #[cfg(debug_assertions)]
1523    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1524        skip_if_unsupported!(kernel, test);
1525
1526        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1527        let c = read_candles_from_csv(file)?;
1528
1529        let test_configs = vec![
1530            (1, 10, 1),
1531            (5, 30, 5),
1532            (10, 50, 10),
1533            (50, 200, 50),
1534            (100, 500, 100),
1535            (5, 5, 0),
1536            (2, 20, 2),
1537            (25, 100, 25),
1538            (200, 1000, 200),
1539        ];
1540
1541        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1542            let output = DpoBatchBuilder::new()
1543                .kernel(kernel)
1544                .period_range(period_start, period_end, period_step)
1545                .apply_candles(&c, "close")?;
1546
1547            for (idx, &val) in output.values.iter().enumerate() {
1548                if val.is_nan() {
1549                    continue;
1550                }
1551
1552                let bits = val.to_bits();
1553                let row = idx / output.cols;
1554                let col = idx % output.cols;
1555                let combo = &output.combos[row];
1556
1557                if bits == 0x11111111_11111111 {
1558                    panic!(
1559                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1560						at row {} col {} (flat index {}) with params: period={}",
1561                        test,
1562                        cfg_idx,
1563                        val,
1564                        bits,
1565                        row,
1566                        col,
1567                        idx,
1568                        combo.period.unwrap_or(5)
1569                    );
1570                }
1571
1572                if bits == 0x22222222_22222222 {
1573                    panic!(
1574                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1575						at row {} col {} (flat index {}) with params: period={}",
1576                        test,
1577                        cfg_idx,
1578                        val,
1579                        bits,
1580                        row,
1581                        col,
1582                        idx,
1583                        combo.period.unwrap_or(5)
1584                    );
1585                }
1586
1587                if bits == 0x33333333_33333333 {
1588                    panic!(
1589                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1590						at row {} col {} (flat index {}) with params: period={}",
1591                        test,
1592                        cfg_idx,
1593                        val,
1594                        bits,
1595                        row,
1596                        col,
1597                        idx,
1598                        combo.period.unwrap_or(5)
1599                    );
1600                }
1601            }
1602        }
1603
1604        Ok(())
1605    }
1606
    // Release-build stand-in: the batch poison-pattern scan is compiled out under
    // `not(debug_assertions)`, so this always reports success.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1611
    // Instantiate the batch checks for the scalar/AVX2/AVX-512 batch kernels and
    // `Kernel::Auto` via `gen_batch_tests!`.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1614
1615    #[test]
1616    fn test_dpo_into_matches_api() -> Result<(), Box<dyn Error>> {
1617        let mut data = Vec::with_capacity(512);
1618        for _ in 0..4 {
1619            data.push(f64::NAN);
1620        }
1621        for i in 0..508 {
1622            let x = i as f64;
1623            data.push((0.1 * x).sin() * 50.0 + 0.01 * x);
1624        }
1625
1626        let params = DpoParams { period: Some(5) };
1627        let input = DpoInput::from_slice(&data, params);
1628
1629        let baseline = dpo(&input)?.values;
1630
1631        let mut out = vec![0.0; data.len()];
1632
1633        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1634        {
1635            dpo_into(&input, &mut out)?;
1636            assert_eq!(out.len(), baseline.len());
1637            for (idx, (a, b)) in out.iter().zip(baseline.iter()).enumerate() {
1638                let equal = (a.is_nan() && b.is_nan()) || (a == b);
1639                assert!(
1640                    equal,
1641                    "dpo_into parity mismatch at idx {}: got {}, expected {}",
1642                    idx, a, b
1643                );
1644            }
1645        }
1646
1647        Ok(())
1648    }
1649}
1650
1651#[cfg(all(feature = "python", feature = "cuda"))]
1652use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
1653#[cfg(feature = "python")]
1654use crate::utilities::kernel_validation::validate_kernel;
1655#[cfg(all(feature = "python", feature = "cuda"))]
1656use numpy::PyUntypedArrayMethods;
1657#[cfg(feature = "python")]
1658use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1659#[cfg(feature = "python")]
1660use pyo3::exceptions::PyValueError;
1661#[cfg(feature = "python")]
1662use pyo3::prelude::*;
1663#[cfg(feature = "python")]
1664use pyo3::types::PyDict;
1665
1666#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1667use serde::{Deserialize, Serialize};
1668#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1669use wasm_bindgen::prelude::*;
1670
/// Device-array handle returned by the DPO CUDA Python bindings; an alias of the shared
/// `DeviceArrayF32Py` type.
#[cfg(all(feature = "python", feature = "cuda"))]
pub type DpoDeviceArrayF32Py = DeviceArrayF32Py;
1673
/// Python `dpo(data, period, kernel=None)`: computes DPO over a 1-D float64 array and
/// returns a new array of the same length. Indicator errors are raised as `ValueError`.
/// The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "dpo")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn dpo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = DpoInput::from_slice(
        prices,
        DpoParams {
            period: Some(period),
        },
    );

    // Only the numeric work runs without the GIL; converting back to NumPy needs it.
    let values = py
        .allow_threads(|| dpo_with_kernel(&input, chosen_kernel))
        .map(|output| output.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1699
/// Python class `DpoStream`: a thin wrapper around the Rust `DpoStream` for
/// sample-by-sample DPO updates.
#[cfg(feature = "python")]
#[pyclass(name = "DpoStream")]
pub struct DpoStreamPy {
    // The wrapped incremental DPO state.
    stream: DpoStream,
}
1705
#[cfg(feature = "python")]
#[pymethods]
impl DpoStreamPy {
    /// Creates a stream with the given `period`; construction errors raise `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        DpoStream::try_new(DpoParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one sample and returns the current DPO value, or `None` when no value is
    /// available yet (e.g. during warmup).
    fn update(&mut self, value: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(value)
    }
}
1723
/// Python `dpo_batch(data, period_range, kernel=None)`: evaluates DPO for every period of
/// the `(start, end, step)` sweep. Returns a dict with:
/// - `values`: a `(rows, cols)` float64 array, one row per swept period;
/// - `periods`: a uint64 array giving each row's period.
#[cfg(feature = "python")]
#[pyfunction(name = "dpo_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn dpo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;
    let sweep = DpoBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep).len();
    let cols = prices.len();

    // The NumPy buffer is created uninitialised and filled in place by
    // `dpo_batch_inner_into`. NOTE(review): this relies on that function writing every
    // element, including warmup prefixes — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if let Kernel::Auto = requested {
                detect_best_batch_kernel()
            } else {
                requested
            };
            // Each batch kernel maps to the single-series kernel used per row.
            let row_kernel = match batch_kernel {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            dpo_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1779
/// Python `dpo_cuda_batch_dev(data_f32, period_range, device_id=0)`: runs the DPO period
/// sweep on a CUDA device. Returns `(device_array_handle, {"periods": uint64 array})`.
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dpo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn dpo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DpoDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = DpoBatchRange {
        period: period_range,
    };
    // Expanded on the host only to report the period list alongside the device buffer.
    let combos = expand_grid(&sweep);

    let device_out = py.allow_threads(|| {
        let engine = crate::cuda::oscillators::dpo_wrapper::CudaDpo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .dpo_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let meta = PyDict::new(py);
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    meta.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, device_out)?;
    Ok((handle, meta))
}
1819
/// Python `dpo_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`: runs DPO
/// with a single period over many series on a CUDA device and returns a device-array handle.
/// The 2-D input is passed to the `*_time_major_dev` wrapper, presumably laid out as
/// rows = time steps and cols = series — NOTE(review): confirm.
/// Raises `ValueError` if CUDA is unavailable, the input is not 2-D, or the CUDA wrapper
/// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dpo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn dpo_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DpoDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        [r, c] => (*r, *c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;
    let params = DpoParams {
        period: Some(period),
    };

    let device_out = py.allow_threads(|| {
        let engine = crate::cuda::oscillators::dpo_wrapper::CudaDpo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .dpo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    make_device_array_py(device_id, device_out)
}
1852
/// WASM `dpo_js(data, period)`: returns the DPO series for `data` as a new array of the
/// same length. Errors are returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = DpoInput::from_slice(
        data,
        DpoParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    match dpo_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1868
/// WASM pointer API: reads `len` f64 values from `in_ptr` and writes `len` DPO results to
/// `out_ptr`.
///
/// The buffers may alias. If the input and output byte ranges overlap at all (not only
/// when the pointers are equal), the result is computed into a scratch buffer and then
/// copied out. A shared slice and a mutable slice over the same memory are therefore
/// never live together, which would be undefined behaviour.
///
/// # Errors
/// Returns a string `JsValue` if either pointer is null or the DPO computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Half-open byte ranges [start, start + bytes) of the two buffers.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    let params = DpoParams {
        period: Some(period),
    };

    if overlaps {
        let mut scratch = vec![0.0; len];
        {
            // SAFETY: the caller guarantees `in_ptr` is valid for reads of `len` f64s. This
            // slice is only read, and it is dead before `out_ptr` is written below.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = DpoInput::from_slice(data, params);
            dpo_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` is valid for writes of `len` f64s; no
        // other reference into that memory is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: both pointers are valid for `len` f64s (caller contract). The ranges
        // were checked above not to overlap, so the shared and mutable slices don't alias.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = DpoInput::from_slice(data, params);
        dpo_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1902
/// Reserves an uninitialized buffer with room for `len` `f64` values and
/// hands ownership of it to the caller as a raw pointer.
///
/// The buffer must later be released with `dpo_free(ptr, len)`, passing the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec destructor, so the allocation stays
    // alive after this function returns (equivalent to `mem::forget`).
    std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len)).as_mut_ptr()
}
1911
/// Releases a buffer previously obtained from `dpo_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must be the same value that was passed to
/// `dpo_alloc`, because it is used as the allocation's capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `dpo_alloc`,
    // so the capacity is exactly `len`. The length is 0 because that buffer
    // was never guaranteed to be initialized; claiming `len` initialized
    // elements would violate `Vec::from_raw_parts`' contract. For `f64`,
    // dropping only frees the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1921
/// Configuration object deserialized from JavaScript by `dpo_batch`.
///
/// `period_range` is `(start, end, step)` and is forwarded unchanged as the
/// `period` sweep of `DpoBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DpoBatchConfig {
    pub period_range: (usize, usize, usize),
}
1927
/// Result object serialized back to JavaScript by `dpo_batch`.
///
/// - `values`: flat output buffer of the batch run. Presumably `rows * cols`
///   values laid out one row per parameter combo; confirm against
///   `dpo_batch_inner`.
/// - `combos`: the parameter set used for each row.
/// - `rows` / `cols`: dimensions reported by the batch routine.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DpoBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<DpoParams>,
    pub rows: usize,
    pub cols: usize,
}
1936
/// JavaScript entry point (exported as `dpo_batch`): runs DPO for every period
/// in the configured sweep and returns the serialized batch result.
///
/// Errors are reported as string `JsValue`s:
/// - a malformed `config` yields `"Invalid config: ..."`;
/// - a failure in `dpo_batch_inner` is forwarded as its message;
/// - a failure to serialize the output yields `"Serialization error: ..."`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dpo_batch)]
pub fn dpo_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let DpoBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<DpoBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let batch = dpo_batch_inner(
        data,
        &DpoBatchRange {
            period: period_range,
        },
        Kernel::Auto,
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = DpoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1960
/// Raw-pointer JavaScript entry point for a batch sweep.
///
/// Runs DPO over `len` input values at `in_ptr` for every period in
/// `(period_start, period_end, period_step)`. It writes `rows * len` values to
/// `out_ptr` and returns `rows`, the number of parameter combos.
///
/// If the output region overlaps the input, the input is first copied into an
/// owned buffer. This avoids overlapping shared and mutable slices.
///
/// # Errors
/// - `"Null pointer provided"` if either pointer is null.
/// - `"rows * cols overflow"` if the output size does not fit in `usize`.
/// - Any error from `dpo_batch_inner_into`, forwarded as its message.
///
/// # Safety (caller contract)
/// `in_ptr` must be valid for reads of `len` `f64`s. `out_ptr` must be valid
/// for writes of `rows * len` `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DpoBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;
    // On wasm32 `usize` is 32-bit, so an unchecked product could silently wrap
    // and describe a wrongly-sized output slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // Byte-range overlap test; saturating arithmetic can only over-report.
    let elem = std::mem::size_of::<f64>();
    let in_bytes = len.saturating_mul(elem);
    let out_bytes = total.saturating_mul(elem);
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_bytes != 0
        && out_bytes != 0
        && in_start < out_start.saturating_add(out_bytes)
        && out_start < in_start.saturating_add(in_bytes);

    // SAFETY: the caller guarantees `in_ptr` is readable for `len` f64s and
    // `out_ptr` is writable for `total` f64s. When the regions overlap, the
    // raw input slice is last used by `to_vec()` before the mutable output
    // slice is created, so no aliasing `&`/`&mut` pair is ever used.
    unsafe {
        let raw_in = std::slice::from_raw_parts(in_ptr, len);
        let owned_in;
        let data: &[f64] = if overlaps {
            owned_in = raw_in.to_vec();
            &owned_in
        } else {
            raw_in
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        dpo_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}