1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8use paste::paste;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use thiserror::Error;
14
15impl<'a> AsRef<[f64]> for DpoInput<'a> {
16 #[inline(always)]
17 fn as_ref(&self) -> &[f64] {
18 match &self.data {
19 DpoData::Slice(slice) => slice,
20 DpoData::Candles { candles, source } => source_type(candles, source),
21 }
22 }
23}
24
25#[derive(Debug, Clone)]
26pub enum DpoData<'a> {
27 Candles {
28 candles: &'a Candles,
29 source: &'a str,
30 },
31 Slice(&'a [f64]),
32}
33
/// Result of a single-series DPO run.
#[derive(Clone, Debug)]
pub struct DpoOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
38
/// Tunable parameters of the DPO indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DpoParams {
    /// Averaging window length; `None` is resolved to `5` by the code that
    /// consumes the parameters.
    pub period: Option<usize>,
}
47
48impl Default for DpoParams {
49 fn default() -> Self {
50 Self { period: Some(5) }
51 }
52}
53
54#[derive(Debug, Clone)]
55pub struct DpoInput<'a> {
56 pub data: DpoData<'a>,
57 pub params: DpoParams,
58}
59
60impl<'a> DpoInput<'a> {
61 #[inline]
62 pub fn from_candles(c: &'a Candles, s: &'a str, p: DpoParams) -> Self {
63 Self {
64 data: DpoData::Candles {
65 candles: c,
66 source: s,
67 },
68 params: p,
69 }
70 }
71 #[inline]
72 pub fn from_slice(sl: &'a [f64], p: DpoParams) -> Self {
73 Self {
74 data: DpoData::Slice(sl),
75 params: p,
76 }
77 }
78 #[inline]
79 pub fn with_default_candles(c: &'a Candles) -> Self {
80 Self::from_candles(c, "close", DpoParams::default())
81 }
82 #[inline]
83 pub fn get_period(&self) -> usize {
84 self.params.period.unwrap_or(5)
85 }
86}
87
88#[derive(Copy, Clone, Debug)]
89pub struct DpoBuilder {
90 period: Option<usize>,
91 kernel: Kernel,
92}
93
94impl Default for DpoBuilder {
95 fn default() -> Self {
96 Self {
97 period: None,
98 kernel: Kernel::Auto,
99 }
100 }
101}
102
103impl DpoBuilder {
104 #[inline(always)]
105 pub fn new() -> Self {
106 Self::default()
107 }
108 #[inline(always)]
109 pub fn period(mut self, n: usize) -> Self {
110 self.period = Some(n);
111 self
112 }
113 #[inline(always)]
114 pub fn kernel(mut self, k: Kernel) -> Self {
115 self.kernel = k;
116 self
117 }
118
119 #[inline(always)]
120 pub fn apply(self, c: &Candles) -> Result<DpoOutput, DpoError> {
121 let p = DpoParams {
122 period: self.period,
123 };
124 let i = DpoInput::from_candles(c, "close", p);
125 dpo_with_kernel(&i, self.kernel)
126 }
127
128 #[inline(always)]
129 pub fn apply_slice(self, d: &[f64]) -> Result<DpoOutput, DpoError> {
130 let p = DpoParams {
131 period: self.period,
132 };
133 let i = DpoInput::from_slice(d, p);
134 dpo_with_kernel(&i, self.kernel)
135 }
136
137 #[inline(always)]
138 pub fn into_stream(self) -> Result<DpoStream, DpoError> {
139 let p = DpoParams {
140 period: self.period,
141 };
142 DpoStream::try_new(p)
143 }
144}
145
146#[derive(Debug, Error)]
147pub enum DpoError {
148 #[error("dpo: Input data slice is empty.")]
149 EmptyInputData,
150 #[error("dpo: All values are NaN.")]
151 AllValuesNaN,
152
153 #[error("dpo: Invalid period: period = {period}, data length = {data_len}")]
154 InvalidPeriod { period: usize, data_len: usize },
155
156 #[error("dpo: Not enough valid data: needed = {needed}, valid = {valid}")]
157 NotEnoughValidData { needed: usize, valid: usize },
158}
159
/// Converts a [`DpoError`] into a JavaScript string value holding its message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<DpoError> for JsValue {
    fn from(err: DpoError) -> Self {
        let message = err.to_string();
        JsValue::from_str(&message)
    }
}
166
167#[inline]
168pub fn dpo(input: &DpoInput) -> Result<DpoOutput, DpoError> {
169 dpo_with_kernel(input, Kernel::Auto)
170}
171
172#[inline]
173pub fn dpo_into_slice(dst: &mut [f64], input: &DpoInput, kern: Kernel) -> Result<(), DpoError> {
174 let data: &[f64] = input.as_ref();
175 let len = data.len();
176 if len == 0 {
177 return Err(DpoError::EmptyInputData);
178 }
179
180 let first = data
181 .iter()
182 .position(|x| !x.is_nan())
183 .ok_or(DpoError::AllValuesNaN)?;
184 let period = input.get_period();
185
186 if period == 0 || period > len {
187 return Err(DpoError::InvalidPeriod {
188 period,
189 data_len: len,
190 });
191 }
192 if (len - first) < period {
193 return Err(DpoError::NotEnoughValidData {
194 needed: period,
195 valid: len - first,
196 });
197 }
198 if dst.len() != len {
199 return Err(DpoError::InvalidPeriod {
200 period: dst.len(),
201 data_len: len,
202 });
203 }
204
205 let chosen = match kern {
206 Kernel::Auto => detect_best_kernel(),
207 other => other,
208 };
209
210 unsafe {
211 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
212 if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
213 dpo_simd128(data, period, first, dst);
214 } else {
215 match chosen {
216 Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, dst),
217 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
218 Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, dst),
219 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
220 Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, dst),
221 _ => unreachable!(),
222 }
223 }
224
225 #[cfg(not(all(target_arch = "wasm32", target_feature = "simd128")))]
226 {
227 match chosen {
228 Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, dst),
229 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
230 Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, dst),
231 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
232 Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, dst),
233 _ => unreachable!(),
234 }
235 }
236 }
237
238 let back = period / 2 + 1;
239 let warm = (first + period - 1).max(back);
240
241 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
242 for v in &mut dst[..warm] {
243 *v = qnan;
244 }
245
246 Ok(())
247}
248
249#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
250#[inline]
251pub fn dpo_into(input: &DpoInput, out: &mut [f64]) -> Result<(), DpoError> {
252 dpo_into_slice(out, input, Kernel::Auto)
253}
254
255pub fn dpo_with_kernel(input: &DpoInput, kernel: Kernel) -> Result<DpoOutput, DpoError> {
256 let data: &[f64] = input.as_ref();
257 let len = data.len();
258 if len == 0 {
259 return Err(DpoError::EmptyInputData);
260 }
261
262 let first = data
263 .iter()
264 .position(|x| !x.is_nan())
265 .ok_or(DpoError::AllValuesNaN)?;
266 let period = input.get_period();
267
268 if period == 0 || period > len {
269 return Err(DpoError::InvalidPeriod {
270 period,
271 data_len: len,
272 });
273 }
274 if (len - first) < period {
275 return Err(DpoError::NotEnoughValidData {
276 needed: period,
277 valid: len - first,
278 });
279 }
280
281 let back = period / 2 + 1;
282 let warm = (first + period - 1).max(back);
283
284 let chosen = match kernel {
285 Kernel::Auto => detect_best_kernel(),
286 other => other,
287 };
288
289 let mut out = alloc_with_nan_prefix(len, warm);
290 unsafe {
291 match chosen {
292 Kernel::Scalar | Kernel::ScalarBatch => dpo_scalar(data, period, first, &mut out),
293 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294 Kernel::Avx2 | Kernel::Avx2Batch => dpo_avx2(data, period, first, &mut out),
295 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
296 Kernel::Avx512 | Kernel::Avx512Batch => dpo_avx512(data, period, first, &mut out),
297 _ => unreachable!(),
298 }
299 }
300 Ok(DpoOutput { values: out })
301}
302
/// Scalar DPO kernel.
///
/// For every index `i >= max(first_val + period - 1, back)`, with
/// `back = period / 2 + 1`, writes
/// `out[i] = data[i - back] - mean(data[i + 1 - period ..= i])`.
/// Elements of `out` before that index are left untouched. An empty `data`
/// is a no-op.
///
/// # Panics
/// Panics when `data` is non-empty and any of the following holds, because
/// the body uses unchecked pointer arithmetic that relies on them:
/// * `period == 0`,
/// * `first_val + period > data.len()`,
/// * `out.len() < data.len()`.
#[inline(always)]
pub fn dpo_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let len = data.len();
    if len == 0 {
        return;
    }

    // This is a safe public function, so the invariants the raw-pointer code
    // depends on are enforced here instead of being assumed.
    assert!(period > 0, "dpo_scalar: period must be non-zero");
    assert!(
        first_val
            .checked_add(period)
            .map_or(false, |end| end <= len),
        "dpo_scalar: first_val + period exceeds data length"
    );
    assert!(
        out.len() >= len,
        "dpo_scalar: output buffer shorter than input"
    );

    let back = period / 2 + 1;
    let scale = 1.0f64 / (period as f64);

    // SAFETY: the asserts above guarantee `first_val + period <= len` and
    // `out.len() >= len`. Every read index is `< len`, every write index is
    // `< len <= out.len()`, and `i - back` is only formed once `i >= back`
    // (the unrolled loop cannot run while `i < back`, since then `i == len - 1`).
    unsafe {
        let base = first_val;
        let ptr_d = data.as_ptr();
        let ptr_o = out.as_mut_ptr();

        // Seed the rolling sum with the first full window.
        let mut sum = 0.0f64;
        let mut k = 0usize;
        while k < period {
            sum += *ptr_d.add(base + k);
            k += 1;
        }

        let mut i = base + period - 1;

        // For very short periods the lagged price is not available yet:
        // slide the window forward without emitting anything.
        if i < back {
            let stop = back.min(len.saturating_sub(1));
            while i < stop {
                let next = i + 1;
                sum += *ptr_d.add(next);
                sum -= *ptr_d.add(next - period);
                i = next;
            }
        }

        // Main loop, unrolled four outputs at a time.
        while i + 3 < len {
            let p0 = *ptr_d.add(i - back);
            *ptr_o.add(i) = sum.mul_add(-scale, p0);

            let a1 = *ptr_d.add(i + 1);
            let r1 = *ptr_d.add(i + 1 - period);
            let s1 = (sum + a1) - r1;
            let p1 = *ptr_d.add(i + 1 - back);
            *ptr_o.add(i + 1) = s1.mul_add(-scale, p1);

            let a2 = *ptr_d.add(i + 2);
            let r2 = *ptr_d.add(i + 2 - period);
            let s2 = (s1 + a2) - r2;
            let p2 = *ptr_d.add(i + 2 - back);
            *ptr_o.add(i + 2) = s2.mul_add(-scale, p2);

            let a3 = *ptr_d.add(i + 3);
            let r3 = *ptr_d.add(i + 3 - period);
            let s3 = (s2 + a3) - r3;
            let p3 = *ptr_d.add(i + 3 - back);
            *ptr_o.add(i + 3) = s3.mul_add(-scale, p3);

            i += 4;
            if i >= len {
                return;
            }
            sum = s3 + *ptr_d.add(i);
            sum -= *ptr_d.add(i - period);
        }

        // Tail: at most three remaining outputs.
        while i < len {
            if i >= back {
                let p = *ptr_d.add(i - back);
                *ptr_o.add(i) = sum.mul_add(-scale, p);
            }
            if i + 1 < len {
                let next = i + 1;
                sum += *ptr_d.add(next);
                sum -= *ptr_d.add(next - period);
            }
            i += 1;
        }
    }
}
381
/// DPO kernel for wasm32 with `simd128`, producing two outputs per step.
///
/// Writes `out[i] = data[i - back] - window_sum * (1 / period)` for every
/// `i >= max(first_val + period - 1, back)` with `back = period / 2 + 1`.
/// Returns without writing when `data` is empty or the warm-up index is not
/// below `data.len()`.
///
/// # Safety
/// Each vector store writes `out[i]` and `out[i + 1]` for `i + 1 < data.len()`
/// through a raw pointer, so `out` must hold at least `data.len()` elements.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn dpo_simd128(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    let n = data.len();
    if n == 0 {
        return;
    }

    let lookback = period / 2 + 1;
    let seed_end = first_val + period - 1;
    let warmup = seed_end.max(lookback);
    if warmup >= n {
        return;
    }

    // Sum of the first full window.
    let mut window = 0.0f64;
    for &v in &data[first_val..first_val + period] {
        window += v;
    }

    // Slide the window up to the first index that can be emitted.
    for next in (seed_end + 1)..=warmup {
        window += data[next] - data[next - period];
    }

    let inv = 1.0f64 / (period as f64);
    let inv_v = f64x2_splat(inv);

    let mut i = warmup;
    while i + 1 < n {
        // Lagged prices for lanes i and i + 1.
        let prices = v128_load(&data[i - lookback] as *const f64 as *const v128);

        let lane0 = window;
        let lane1 = lane0 + data[i + 1] - data[i + 1 - period];
        let sums = f64x2(lane0, lane1);

        let detrended = f64x2_sub(prices, f64x2_mul(sums, inv_v));
        v128_store(&mut out[i] as *mut f64 as *mut v128, detrended);

        if i + 2 < n {
            window = lane1 + data[i + 2] - data[i + 2 - period];
        }
        i += 2;
    }

    // Odd element left over after the paired loop.
    if i < n {
        out[i] = data[i - lookback] - (window * inv);
    }
}
435
/// DPO kernel selected for the AVX2 kernels (the body is scalar code,
/// unrolled four outputs at a time).
///
/// For every index `i >= max(first_val + period - 1, back)`, with
/// `back = period / 2 + 1`, writes
/// `out[i] = data[i - back] - mean(data[i + 1 - period ..= i])`.
/// Elements of `out` before that index are left untouched. An empty `data`
/// is a no-op.
///
/// # Panics
/// Panics when `data` is non-empty and any of the following holds, because
/// the body uses unchecked pointer arithmetic that relies on them:
/// * `period == 0`,
/// * `first_val + period > data.len()`,
/// * `out.len() < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx2(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let len = data.len();
    if len == 0 {
        return;
    }

    // This is a safe public function, so the invariants the raw-pointer code
    // depends on are enforced here instead of being assumed.
    assert!(period > 0, "dpo_avx2: period must be non-zero");
    assert!(
        first_val
            .checked_add(period)
            .map_or(false, |end| end <= len),
        "dpo_avx2: first_val + period exceeds data length"
    );
    assert!(
        out.len() >= len,
        "dpo_avx2: output buffer shorter than input"
    );

    let back = period / 2 + 1;
    let scale = 1.0f64 / (period as f64);

    // SAFETY: the asserts above guarantee `first_val + period <= len` and
    // `out.len() >= len`. Every read index is `< len`, every write index is
    // `< len <= out.len()`, and `i - back` is only formed once `i >= back`.
    unsafe {
        let base = first_val;
        let ptr_d = data.as_ptr();
        let ptr_o = out.as_mut_ptr();

        // Seed the rolling sum with the first full window.
        let mut sum = 0.0f64;
        let mut k = 0usize;
        while k < period {
            sum += *ptr_d.add(base + k);
            k += 1;
        }
        let mut i = base + period - 1;

        // Slide forward without emitting until the lagged price exists.
        // `i` stays `<= len - 1` here, so no early exit is needed.
        if i < back {
            let stop = back.min(len.saturating_sub(1));
            while i < stop {
                sum += *ptr_d.add(i + 1) - *ptr_d.add(i + 1 - period);
                i += 1;
            }
        }

        while i + 3 < len {
            let p0 = *ptr_d.add(i - back);
            *ptr_o.add(i) = sum.mul_add(-scale, p0);

            let a1 = *ptr_d.add(i + 1);
            let r1 = *ptr_d.add(i + 1 - period);
            let s1 = sum + (a1 - r1);
            if i + 1 >= back {
                let p1 = *ptr_d.add(i + 1 - back);
                *ptr_o.add(i + 1) = s1.mul_add(-scale, p1);
            }

            let a2 = *ptr_d.add(i + 2);
            let r2 = *ptr_d.add(i + 2 - period);
            let s2 = s1 + (a2 - r2);
            if i + 2 >= back {
                let p2 = *ptr_d.add(i + 2 - back);
                *ptr_o.add(i + 2) = s2.mul_add(-scale, p2);
            }

            let a3 = *ptr_d.add(i + 3);
            let r3 = *ptr_d.add(i + 3 - period);
            let s3 = s2 + (a3 - r3);
            if i + 3 >= back {
                let p3 = *ptr_d.add(i + 3 - back);
                *ptr_o.add(i + 3) = s3.mul_add(-scale, p3);
            }

            i += 4;
            if i >= len {
                return;
            }
            let a4 = *ptr_d.add(i);
            let r4 = *ptr_d.add(i - period);
            sum = s3 + (a4 - r4);
        }

        // Tail: at most three remaining outputs.
        while i < len {
            if i >= back {
                let p = *ptr_d.add(i - back);
                *ptr_o.add(i) = sum.mul_add(-scale, p);
            }
            if i + 1 < len {
                sum += *ptr_d.add(i + 1) - *ptr_d.add(i + 1 - period);
            }
            i += 1;
        }
    }
}
517
/// AVX-512 entry point; currently forwards to [`dpo_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}
523
/// AVX-512 "short" entry point; currently forwards to [`dpo_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512_short(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}
529
/// AVX-512 "long" entry point; currently forwards to [`dpo_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dpo_avx512_long(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first_val, out);
}
535
536#[inline]
537pub fn dpo_batch_with_kernel(
538 data: &[f64],
539 sweep: &DpoBatchRange,
540 k: Kernel,
541) -> Result<DpoBatchOutput, DpoError> {
542 let kernel = match k {
543 Kernel::Auto => detect_best_batch_kernel(),
544 other if other.is_batch() => other,
545 _ => {
546 return Err(DpoError::InvalidPeriod {
547 period: 0,
548 data_len: 0,
549 })
550 }
551 };
552 let simd = match kernel {
553 Kernel::Avx512Batch => Kernel::Avx512,
554 Kernel::Avx2Batch => Kernel::Avx2,
555 Kernel::ScalarBatch => Kernel::Scalar,
556 _ => unreachable!(),
557 };
558 dpo_batch_par_slice(data, sweep, simd)
559}
560
/// Period sweep description for batch runs.
#[derive(Debug, Clone)]
pub struct DpoBatchRange {
    /// `(start, end, step)` of the period axis; `end` is inclusive.
    pub period: (usize, usize, usize),
}
565
566impl Default for DpoBatchRange {
567 fn default() -> Self {
568 Self {
569 period: (5, 254, 1),
570 }
571 }
572}
573
574#[derive(Clone, Debug, Default)]
575pub struct DpoBatchBuilder {
576 range: DpoBatchRange,
577 kernel: Kernel,
578}
579
580impl DpoBatchBuilder {
581 pub fn new() -> Self {
582 Self::default()
583 }
584 pub fn kernel(mut self, k: Kernel) -> Self {
585 self.kernel = k;
586 self
587 }
588 #[inline]
589 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
590 self.range.period = (start, end, step);
591 self
592 }
593 #[inline]
594 pub fn period_static(mut self, p: usize) -> Self {
595 self.range.period = (p, p, 0);
596 self
597 }
598
599 pub fn apply_slice(self, data: &[f64]) -> Result<DpoBatchOutput, DpoError> {
600 dpo_batch_with_kernel(data, &self.range, self.kernel)
601 }
602 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DpoBatchOutput, DpoError> {
603 DpoBatchBuilder::new().kernel(k).apply_slice(data)
604 }
605 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DpoBatchOutput, DpoError> {
606 let slice = source_type(c, src);
607 self.apply_slice(slice)
608 }
609 pub fn with_default_candles(c: &Candles) -> Result<DpoBatchOutput, DpoError> {
610 DpoBatchBuilder::new()
611 .kernel(Kernel::Auto)
612 .apply_candles(c, "close")
613 }
614}
615
616#[derive(Clone, Debug)]
617pub struct DpoBatchOutput {
618 pub values: Vec<f64>,
619 pub combos: Vec<DpoParams>,
620 pub rows: usize,
621 pub cols: usize,
622}
623
624impl DpoBatchOutput {
625 pub fn row_for_params(&self, p: &DpoParams) -> Option<usize> {
626 self.combos
627 .iter()
628 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
629 }
630
631 pub fn values_for(&self, p: &DpoParams) -> Option<&[f64]> {
632 self.row_for_params(p).map(|row| {
633 let start = row * self.cols;
634 &self.values[start..start + self.cols]
635 })
636 }
637}
638
639#[inline(always)]
640fn expand_grid(r: &DpoBatchRange) -> Vec<DpoParams> {
641 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
642 if step == 0 || start == end {
643 return vec![start];
644 }
645 (start..=end).step_by(step).collect()
646 }
647 let periods = axis_usize(r.period);
648 let mut out = Vec::with_capacity(periods.len());
649 for &p in &periods {
650 out.push(DpoParams { period: Some(p) });
651 }
652 out
653}
654
655#[inline(always)]
656pub fn dpo_batch_slice(
657 data: &[f64],
658 sweep: &DpoBatchRange,
659 kern: Kernel,
660) -> Result<DpoBatchOutput, DpoError> {
661 dpo_batch_inner(data, sweep, kern, false)
662}
663
664#[inline(always)]
665pub fn dpo_batch_par_slice(
666 data: &[f64],
667 sweep: &DpoBatchRange,
668 kern: Kernel,
669) -> Result<DpoBatchOutput, DpoError> {
670 dpo_batch_inner(data, sweep, kern, true)
671}
672
673#[inline(always)]
674pub fn dpo_batch_inner_into(
675 data: &[f64],
676 sweep: &DpoBatchRange,
677 kern: Kernel,
678 parallel: bool,
679 out: &mut [f64],
680) -> Result<Vec<DpoParams>, DpoError> {
681 let combos = expand_grid(sweep);
682 if combos.is_empty() {
683 return Err(DpoError::InvalidPeriod {
684 period: 0,
685 data_len: 0,
686 });
687 }
688
689 let len = data.len();
690 if len == 0 {
691 return Err(DpoError::EmptyInputData);
692 }
693
694 let first = data
695 .iter()
696 .position(|x| !x.is_nan())
697 .ok_or(DpoError::AllValuesNaN)?;
698 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
699 if len - first < max_p {
700 return Err(DpoError::NotEnoughValidData {
701 needed: max_p,
702 valid: len - first,
703 });
704 }
705
706 let rows = combos.len();
707 let cols = len;
708 debug_assert_eq!(out.len(), rows * cols);
709
710 let warm: Vec<usize> = combos
711 .iter()
712 .map(|c| {
713 let p = c.period.unwrap();
714 let back = p / 2 + 1;
715 (first + p - 1).max(back)
716 })
717 .collect();
718
719 let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
720 std::slice::from_raw_parts_mut(
721 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
722 out.len(),
723 )
724 };
725 init_matrix_prefixes(out_mu, cols, &warm);
726
727 let mut pfx = vec![0.0f64; cols];
728 if cols > 0 {
729 let mut acc = 0.0f64;
730 for i in 0..cols {
731 if i < first {
732 pfx[i] = 0.0;
733 } else {
734 acc += data[i];
735 pfx[i] = acc;
736 }
737 }
738 }
739
740 let do_row = |row: usize, dst_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
741 let period = combos[row].period.unwrap();
742 let back = period / 2 + 1;
743 let warm = (first + period - 1).max(back);
744 let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols);
745 let scale = 1.0f64 / (period as f64);
746
747 let mut i = warm;
748 while i < cols {
749 let prev = if i >= period { pfx[i - period] } else { 0.0 };
750 let sum = pfx[i] - prev;
751 let avg = sum * scale;
752 let price = data[i - back];
753 dst[i] = price - avg;
754 i += 1;
755 }
756 };
757
758 if parallel {
759 #[cfg(not(target_arch = "wasm32"))]
760 out_mu
761 .par_chunks_mut(cols)
762 .enumerate()
763 .for_each(|(r, s)| do_row(r, s));
764 #[cfg(target_arch = "wasm32")]
765 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
766 do_row(r, s);
767 }
768 } else {
769 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
770 do_row(r, s);
771 }
772 }
773
774 Ok(combos)
775}
776
777#[inline(always)]
778fn dpo_batch_inner(
779 data: &[f64],
780 sweep: &DpoBatchRange,
781 kern: Kernel,
782 parallel: bool,
783) -> Result<DpoBatchOutput, DpoError> {
784 let combos = expand_grid(sweep);
785 if combos.is_empty() {
786 return Err(DpoError::InvalidPeriod {
787 period: 0,
788 data_len: 0,
789 });
790 }
791
792 let len = data.len();
793 if len == 0 {
794 return Err(DpoError::EmptyInputData);
795 }
796
797 let first = data
798 .iter()
799 .position(|x| !x.is_nan())
800 .ok_or(DpoError::AllValuesNaN)?;
801 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
802 if len - first < max_p {
803 return Err(DpoError::NotEnoughValidData {
804 needed: max_p,
805 valid: len - first,
806 });
807 }
808
809 let rows = combos.len();
810 let cols = len;
811
812 let warm: Vec<usize> = combos
813 .iter()
814 .map(|c| {
815 let p = c.period.unwrap();
816 let back = p / 2 + 1;
817 (first + p - 1).max(back)
818 })
819 .collect();
820
821 let mut buf_mu = make_uninit_matrix(rows, cols);
822 init_matrix_prefixes(&mut buf_mu, cols, &warm);
823
824 let mut values = unsafe {
825 let ptr = buf_mu.as_mut_ptr() as *mut f64;
826 let len = buf_mu.len();
827 std::mem::forget(buf_mu);
828 Vec::from_raw_parts(ptr, len, len)
829 };
830
831 let mut pfx = vec![0.0f64; cols];
832 if cols > 0 {
833 let mut acc = 0.0f64;
834 for i in 0..cols {
835 if i < first {
836 pfx[i] = 0.0;
837 } else {
838 acc += data[i];
839 pfx[i] = acc;
840 }
841 }
842 }
843
844 let do_row = |row: usize, out_row: &mut [f64]| {
845 let period = combos[row].period.unwrap();
846 let back = period / 2 + 1;
847 let warm = (first + period - 1).max(back);
848 let scale = 1.0f64 / (period as f64);
849
850 let mut i = warm;
851 while i < cols {
852 let prev = if i >= period { pfx[i - period] } else { 0.0 };
853 let sum = pfx[i] - prev;
854 let avg = sum * scale;
855 let price = data[i - back];
856 out_row[i] = price - avg;
857 i += 1;
858 }
859 };
860
861 if parallel {
862 #[cfg(not(target_arch = "wasm32"))]
863 {
864 values
865 .par_chunks_mut(cols)
866 .enumerate()
867 .for_each(|(row, slice)| do_row(row, slice));
868 }
869
870 #[cfg(target_arch = "wasm32")]
871 {
872 for (row, slice) in values.chunks_mut(cols).enumerate() {
873 do_row(row, slice);
874 }
875 }
876 } else {
877 for (row, slice) in values.chunks_mut(cols).enumerate() {
878 do_row(row, slice);
879 }
880 }
881
882 Ok(DpoBatchOutput {
883 values,
884 combos,
885 rows,
886 cols,
887 })
888}
889
890#[inline(always)]
891unsafe fn dpo_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
892 dpo_scalar(data, period, first, out)
893}
894
/// Row-level AVX2 entry point; forwards to [`dpo_avx2`] with the
/// `first` / `period` arguments swapped into that function's order.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dpo_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first, out);
}
900
/// Row-level AVX-512 entry point: periods up to 32 go to the "short"
/// variant, larger ones to the "long" variant.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn dpo_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => dpo_row_avx512_short(data, first, period, out),
        _ => dpo_row_avx512_long(data, first, period, out),
    }
}
910
/// Row-level AVX-512 "short" variant; currently forwards to [`dpo_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dpo_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first, out);
}
916
/// Row-level AVX-512 "long" variant; currently forwards to [`dpo_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dpo_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dpo_avx2(data, period, first, out);
}
922
/// Incremental DPO calculator fed one value at a time.
#[derive(Clone, Debug)]
pub struct DpoStream {
    /// Averaging window length.
    period: usize,
    /// Lag applied to the price, `period / 2 + 1`.
    back: usize,
    /// Cached `1 / period`.
    inv_period: f64,

    /// Ring buffer of the last `period` values (NaN until overwritten).
    sma_buf: Vec<f64>,
    /// Ring buffer of the last `back + 1` values (NaN until overwritten).
    lag_buf: Vec<f64>,
    /// Running sum of the values held in `sma_buf`.
    sum: f64,

    /// Next write position in `sma_buf`.
    sma_head: usize,
    /// Next write position in `lag_buf`.
    lag_head: usize,
    /// Number of values received so far.
    count: usize,
}
937
938impl DpoStream {
939 #[inline]
940 pub fn try_new(params: DpoParams) -> Result<Self, DpoError> {
941 let period = params.period.unwrap_or(5);
942 if period == 0 {
943 return Err(DpoError::InvalidPeriod {
944 period,
945 data_len: 0,
946 });
947 }
948 let back = period / 2 + 1;
949 let inv_period = 1.0f64 / (period as f64);
950
951 Ok(Self {
952 period,
953 back,
954 inv_period,
955
956 sma_buf: vec![f64::NAN; period],
957
958 lag_buf: vec![f64::NAN; back + 1],
959 sum: 0.0,
960
961 sma_head: 0,
962 lag_head: 0,
963 count: 0,
964 })
965 }
966
967 #[inline(always)]
968 pub fn update(&mut self, value: f64) -> Option<f64> {
969 self.lag_buf[self.lag_head] = value;
970 self.lag_head += 1;
971 if self.lag_head == self.lag_buf.len() {
972 self.lag_head = 0;
973 }
974
975 let old = self.sma_buf[self.sma_head];
976 self.sma_buf[self.sma_head] = value;
977 self.sma_head += 1;
978 if self.sma_head == self.period {
979 self.sma_head = 0;
980 }
981
982 if old.is_nan() {
983 self.sum += value;
984 } else {
985 self.sum += value - old;
986 }
987
988 self.count += 1;
989
990 if self.count < self.period || self.count <= self.back {
991 return None;
992 }
993
994 let lagged_value = self.lag_buf[self.lag_head];
995
996 let dpo = (-self.inv_period).mul_add(self.sum, lagged_value);
997
998 Some(dpo)
999 }
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004 use super::*;
1005 use crate::skip_if_unsupported;
1006 use crate::utilities::data_loader::read_candles_from_csv;
1007
1008 fn check_dpo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1009 skip_if_unsupported!(kernel, test_name);
1010 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1011 let candles = read_candles_from_csv(file_path)?;
1012 let default_params = DpoParams { period: None };
1013 let input = DpoInput::from_candles(&candles, "close", default_params);
1014 let output = dpo_with_kernel(&input, kernel)?;
1015 assert_eq!(output.values.len(), candles.close.len());
1016 Ok(())
1017 }
1018
1019 fn check_dpo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1020 skip_if_unsupported!(kernel, test_name);
1021 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1022 let candles = read_candles_from_csv(file_path)?;
1023 let input = DpoInput::from_candles(&candles, "close", DpoParams { period: Some(5) });
1024 let result = dpo_with_kernel(&input, kernel)?;
1025 let expected_last_five = [
1026 65.3999999999287,
1027 131.3999999999287,
1028 32.599999999925785,
1029 98.3999999999287,
1030 117.99999999992724,
1031 ];
1032 let start = result.values.len().saturating_sub(5);
1033 for (i, &val) in result.values[start..].iter().enumerate() {
1034 let diff = (val - expected_last_five[i]).abs();
1035 assert!(
1036 diff < 1e-1,
1037 "[{}] DPO {:?} mismatch at idx {}: got {}, expected {}",
1038 test_name,
1039 kernel,
1040 i,
1041 val,
1042 expected_last_five[i]
1043 );
1044 }
1045 Ok(())
1046 }
1047
1048 fn check_dpo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1049 skip_if_unsupported!(kernel, test_name);
1050 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1051 let candles = read_candles_from_csv(file_path)?;
1052 let input = DpoInput::with_default_candles(&candles);
1053 match input.data {
1054 DpoData::Candles { source, .. } => assert_eq!(source, "close"),
1055 _ => panic!("Expected DpoData::Candles"),
1056 }
1057 let output = dpo_with_kernel(&input, kernel)?;
1058 assert_eq!(output.values.len(), candles.close.len());
1059 Ok(())
1060 }
1061
1062 fn check_dpo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1063 skip_if_unsupported!(kernel, test_name);
1064 let input_data = [10.0, 20.0, 30.0];
1065 let params = DpoParams { period: Some(0) };
1066 let input = DpoInput::from_slice(&input_data, params);
1067 let res = dpo_with_kernel(&input, kernel);
1068 assert!(
1069 res.is_err(),
1070 "[{}] DPO should fail with zero period",
1071 test_name
1072 );
1073 Ok(())
1074 }
1075
1076 fn check_dpo_period_exceeds_length(
1077 test_name: &str,
1078 kernel: Kernel,
1079 ) -> Result<(), Box<dyn Error>> {
1080 skip_if_unsupported!(kernel, test_name);
1081 let data_small = [10.0, 20.0, 30.0];
1082 let params = DpoParams { period: Some(10) };
1083 let input = DpoInput::from_slice(&data_small, params);
1084 let res = dpo_with_kernel(&input, kernel);
1085 assert!(
1086 res.is_err(),
1087 "[{}] DPO should fail with period exceeding length",
1088 test_name
1089 );
1090 Ok(())
1091 }
1092
1093 fn check_dpo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1094 skip_if_unsupported!(kernel, test_name);
1095 let single_point = [42.0];
1096 let params = DpoParams { period: Some(5) };
1097 let input = DpoInput::from_slice(&single_point, params);
1098 let res = dpo_with_kernel(&input, kernel);
1099 assert!(
1100 res.is_err(),
1101 "[{}] DPO should fail with insufficient data",
1102 test_name
1103 );
1104 Ok(())
1105 }
1106
1107 fn check_dpo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1108 skip_if_unsupported!(kernel, test_name);
1109 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1110 let candles = read_candles_from_csv(file_path)?;
1111 let first_params = DpoParams { period: Some(5) };
1112 let first_input = DpoInput::from_candles(&candles, "close", first_params);
1113 let first_result = dpo_with_kernel(&first_input, kernel)?;
1114 let second_params = DpoParams { period: Some(5) };
1115 let second_input = DpoInput::from_slice(&first_result.values, second_params);
1116 let second_result = dpo_with_kernel(&second_input, kernel)?;
1117 assert_eq!(second_result.values.len(), first_result.values.len());
1118 Ok(())
1119 }
1120
1121 fn check_dpo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1122 skip_if_unsupported!(kernel, test_name);
1123 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1124 let candles = read_candles_from_csv(file_path)?;
1125 let input = DpoInput::from_candles(&candles, "close", DpoParams { period: Some(5) });
1126 let res = dpo_with_kernel(&input, kernel)?;
1127 assert_eq!(res.values.len(), candles.close.len());
1128 if res.values.len() > 20 {
1129 for (i, &val) in res.values[20..].iter().enumerate() {
1130 assert!(
1131 !val.is_nan(),
1132 "[{}] Found unexpected NaN at out-index {}",
1133 test_name,
1134 20 + i
1135 );
1136 }
1137 }
1138 Ok(())
1139 }
1140
1141 fn check_dpo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1142 skip_if_unsupported!(kernel, test_name);
1143 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1144 let candles = read_candles_from_csv(file_path)?;
1145 let period = 5;
1146 let input = DpoInput::from_candles(
1147 &candles,
1148 "close",
1149 DpoParams {
1150 period: Some(period),
1151 },
1152 );
1153 let batch_output = dpo_with_kernel(&input, kernel)?.values;
1154 let mut stream = DpoStream::try_new(DpoParams {
1155 period: Some(period),
1156 })?;
1157 let mut stream_values = Vec::with_capacity(candles.close.len());
1158 for &price in &candles.close {
1159 match stream.update(price) {
1160 Some(dpo_val) => stream_values.push(dpo_val),
1161 None => stream_values.push(f64::NAN),
1162 }
1163 }
1164 assert_eq!(batch_output.len(), stream_values.len());
1165 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1166 if b.is_nan() && s.is_nan() {
1167 continue;
1168 }
1169 let diff = (b - s).abs();
1170 assert!(
1171 diff < 1e-9,
1172 "[{}] DPO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1173 test_name,
1174 i,
1175 b,
1176 s,
1177 diff
1178 );
1179 }
1180 Ok(())
1181 }
1182
1183 #[cfg(debug_assertions)]
1184 fn check_dpo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1185 skip_if_unsupported!(kernel, test_name);
1186
1187 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1188 let candles = read_candles_from_csv(file_path)?;
1189
1190 let test_params = vec![
1191 DpoParams::default(),
1192 DpoParams { period: Some(1) },
1193 DpoParams { period: Some(2) },
1194 DpoParams { period: Some(3) },
1195 DpoParams { period: Some(5) },
1196 DpoParams { period: Some(10) },
1197 DpoParams { period: Some(15) },
1198 DpoParams { period: Some(20) },
1199 DpoParams { period: Some(30) },
1200 DpoParams { period: Some(50) },
1201 DpoParams { period: Some(75) },
1202 DpoParams { period: Some(100) },
1203 DpoParams { period: Some(200) },
1204 DpoParams { period: Some(500) },
1205 DpoParams { period: Some(1000) },
1206 ];
1207
1208 for (param_idx, params) in test_params.iter().enumerate() {
1209 let input = DpoInput::from_candles(&candles, "close", params.clone());
1210 let output = dpo_with_kernel(&input, kernel)?;
1211
1212 for (i, &val) in output.values.iter().enumerate() {
1213 if val.is_nan() {
1214 continue;
1215 }
1216
1217 let bits = val.to_bits();
1218
1219 if bits == 0x11111111_11111111 {
1220 panic!(
1221 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1222 with params: period={} (param set {})",
1223 test_name,
1224 val,
1225 bits,
1226 i,
1227 params.period.unwrap_or(5),
1228 param_idx
1229 );
1230 }
1231
1232 if bits == 0x22222222_22222222 {
1233 panic!(
1234 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1235 with params: period={} (param set {})",
1236 test_name,
1237 val,
1238 bits,
1239 i,
1240 params.period.unwrap_or(5),
1241 param_idx
1242 );
1243 }
1244
1245 if bits == 0x33333333_33333333 {
1246 panic!(
1247 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1248 with params: period={} (param set {})",
1249 test_name,
1250 val,
1251 bits,
1252 i,
1253 params.period.unwrap_or(5),
1254 param_idx
1255 );
1256 }
1257 }
1258 }
1259
1260 Ok(())
1261 }
1262
1263 #[cfg(not(debug_assertions))]
1264 fn check_dpo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1265 Ok(())
1266 }
1267
1268 #[cfg(feature = "proptest")]
1269 #[allow(clippy::float_cmp)]
1270 fn check_dpo_property(
1271 test_name: &str,
1272 kernel: Kernel,
1273 ) -> Result<(), Box<dyn std::error::Error>> {
1274 use proptest::prelude::*;
1275 skip_if_unsupported!(kernel, test_name);
1276
1277 let strat = (1usize..=100).prop_flat_map(|period| {
1278 (
1279 prop::collection::vec(
1280 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1281 period..400,
1282 ),
1283 Just(period),
1284 )
1285 });
1286
1287 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1288 let params = DpoParams {
1289 period: Some(period),
1290 };
1291 let input = DpoInput::from_slice(&data, params);
1292
1293 let DpoOutput { values: out } = dpo_with_kernel(&input, kernel).unwrap();
1294 let DpoOutput { values: ref_out } = dpo_with_kernel(&input, Kernel::Scalar).unwrap();
1295
1296 for i in 0..period.min(data.len()) {
1297 prop_assert!(
1298 out[i].is_nan(),
1299 "[{}] Expected NaN during warmup at index {}, got {}",
1300 test_name,
1301 i,
1302 out[i]
1303 );
1304 }
1305
1306 for i in period..data.len() {
1307 prop_assert!(
1308 out[i].is_finite(),
1309 "[{}] Expected finite value at index {}, got {}",
1310 test_name,
1311 i,
1312 out[i]
1313 );
1314 }
1315
1316 for i in 0..data.len() {
1317 if out[i].is_nan() && ref_out[i].is_nan() {
1318 continue;
1319 }
1320 let y_bits = out[i].to_bits();
1321 let r_bits = ref_out[i].to_bits();
1322 let ulp_diff = if y_bits > r_bits {
1323 y_bits - r_bits
1324 } else {
1325 r_bits - y_bits
1326 };
1327 prop_assert!(
1328 ulp_diff <= 3,
1329 "[{}] Kernel mismatch at idx {}: {} ({:016X}) vs {} ({:016X}), ULP diff: {}",
1330 test_name,
1331 i,
1332 out[i],
1333 y_bits,
1334 ref_out[i],
1335 r_bits,
1336 ulp_diff
1337 );
1338 }
1339
1340 let back = period / 2 + 1;
1341 for i in period..data.len() {
1342 if i >= back {
1343 let sum: f64 = data[i + 1 - period..=i].iter().sum();
1344 let avg = sum / period as f64;
1345 let expected_dpo = data[i - back] - avg;
1346
1347 prop_assert!(
1348 (out[i] - expected_dpo).abs() < 1e-9,
1349 "[{}] DPO formula mismatch at idx {}: got {}, expected {}",
1350 test_name,
1351 i,
1352 out[i],
1353 expected_dpo
1354 );
1355 }
1356 }
1357
1358 if data.len() > period {
1359 let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
1360 let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1361 let range = max_val - min_val;
1362
1363 for i in period..data.len() {
1364 prop_assert!(
1365 out[i].abs() <= 2.0 * range,
1366 "[{}] DPO exceeds reasonable bounds at idx {}: {} (data range: {})",
1367 test_name,
1368 i,
1369 out[i],
1370 range
1371 );
1372 }
1373 }
1374
1375 if period == 1 {
1376 for i in 1..data.len() {
1377 let expected = data[i - 1] - data[i];
1378 prop_assert!(
1379 (out[i] - expected).abs() < 1e-9,
1380 "[{}] Period=1 special case failed at idx {}: got {}, expected {}",
1381 test_name,
1382 i,
1383 out[i],
1384 expected
1385 );
1386 }
1387 }
1388
1389 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > period {
1390 for i in period..data.len() {
1391 prop_assert!(
1392 out[i].abs() < 1e-9,
1393 "[{}] Expected zero DPO for constant data at idx {}, got {}",
1394 test_name,
1395 i,
1396 out[i]
1397 );
1398 }
1399 }
1400
1401 for i in 0..data.len() {
1402 if !out[i].is_nan() {
1403 let bits = out[i].to_bits();
1404 prop_assert!(
1405 bits != 0x11111111_11111111
1406 && bits != 0x22222222_22222222
1407 && bits != 0x33333333_33333333,
1408 "[{}] Found poison value at idx {}: {} ({:016X})",
1409 test_name,
1410 i,
1411 out[i],
1412 bits
1413 );
1414 }
1415 }
1416
1417 Ok(())
1418 })?;
1419
1420 Ok(())
1421 }
1422
1423 macro_rules! generate_all_dpo_tests {
1424 ($($test_fn:ident),*) => {
1425 paste! {
1426 $(
1427 #[test]
1428 fn [<$test_fn _scalar_f64>]() {
1429 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1430 }
1431 )*
1432 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1433 $(
1434 #[test]
1435 fn [<$test_fn _avx2_f64>]() {
1436 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1437 }
1438 #[test]
1439 fn [<$test_fn _avx512_f64>]() {
1440 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1441 }
1442 )*
1443 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1444 $(
1445 #[test]
1446 fn [<$test_fn _simd128_f64>]() {
1447 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1448 }
1449 )*
1450 }
1451 }
1452 }
1453
    // Instantiate the per-kernel test wrappers for every check function listed here.
    generate_all_dpo_tests!(
        check_dpo_partial_params,
        check_dpo_accuracy,
        check_dpo_default_candles,
        check_dpo_zero_period,
        check_dpo_period_exceeds_length,
        check_dpo_very_small_dataset,
        check_dpo_reinput,
        check_dpo_nan_handling,
        check_dpo_streaming,
        check_dpo_no_poison
    );

    // The property-based check is only defined when the `proptest` feature is enabled,
    // so its wrappers are gated on the same feature.
    #[cfg(feature = "proptest")]
    generate_all_dpo_tests!(check_dpo_property);
1469
1470 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1471 skip_if_unsupported!(kernel, test);
1472
1473 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1474 let c = read_candles_from_csv(file)?;
1475
1476 let output = DpoBatchBuilder::new()
1477 .kernel(kernel)
1478 .apply_candles(&c, "close")?;
1479
1480 let def = DpoParams::default();
1481 let row = output.values_for(&def).expect("default row missing");
1482
1483 assert_eq!(row.len(), c.close.len());
1484
1485 let expected = [
1486 65.3999999999287,
1487 131.3999999999287,
1488 32.599999999925785,
1489 98.3999999999287,
1490 117.99999999992724,
1491 ];
1492 let start = row.len() - 5;
1493 for (i, &v) in row[start..].iter().enumerate() {
1494 assert!(
1495 (v - expected[i]).abs() < 1e-1,
1496 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1497 );
1498 }
1499 Ok(())
1500 }
1501
1502 macro_rules! gen_batch_tests {
1503 ($fn_name:ident) => {
1504 paste! {
1505 #[test] fn [<$fn_name _scalar>]() {
1506 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1507 }
1508 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1509 #[test] fn [<$fn_name _avx2>]() {
1510 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1511 }
1512 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1513 #[test] fn [<$fn_name _avx512>]() {
1514 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1515 }
1516 #[test] fn [<$fn_name _auto_detect>]() {
1517 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1518 }
1519 }
1520 };
1521 }
1522 #[cfg(debug_assertions)]
1523 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1524 skip_if_unsupported!(kernel, test);
1525
1526 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1527 let c = read_candles_from_csv(file)?;
1528
1529 let test_configs = vec![
1530 (1, 10, 1),
1531 (5, 30, 5),
1532 (10, 50, 10),
1533 (50, 200, 50),
1534 (100, 500, 100),
1535 (5, 5, 0),
1536 (2, 20, 2),
1537 (25, 100, 25),
1538 (200, 1000, 200),
1539 ];
1540
1541 for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1542 let output = DpoBatchBuilder::new()
1543 .kernel(kernel)
1544 .period_range(period_start, period_end, period_step)
1545 .apply_candles(&c, "close")?;
1546
1547 for (idx, &val) in output.values.iter().enumerate() {
1548 if val.is_nan() {
1549 continue;
1550 }
1551
1552 let bits = val.to_bits();
1553 let row = idx / output.cols;
1554 let col = idx % output.cols;
1555 let combo = &output.combos[row];
1556
1557 if bits == 0x11111111_11111111 {
1558 panic!(
1559 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1560 at row {} col {} (flat index {}) with params: period={}",
1561 test,
1562 cfg_idx,
1563 val,
1564 bits,
1565 row,
1566 col,
1567 idx,
1568 combo.period.unwrap_or(5)
1569 );
1570 }
1571
1572 if bits == 0x22222222_22222222 {
1573 panic!(
1574 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1575 at row {} col {} (flat index {}) with params: period={}",
1576 test,
1577 cfg_idx,
1578 val,
1579 bits,
1580 row,
1581 col,
1582 idx,
1583 combo.period.unwrap_or(5)
1584 );
1585 }
1586
1587 if bits == 0x33333333_33333333 {
1588 panic!(
1589 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1590 at row {} col {} (flat index {}) with params: period={}",
1591 test,
1592 cfg_idx,
1593 val,
1594 bits,
1595 row,
1596 col,
1597 idx,
1598 combo.period.unwrap_or(5)
1599 );
1600 }
1601 }
1602 }
1603
1604 Ok(())
1605 }
1606
1607 #[cfg(not(debug_assertions))]
1608 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1609 Ok(())
1610 }
1611
    // Instantiate the per-kernel batch test wrappers for both batch check functions.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1614
1615 #[test]
1616 fn test_dpo_into_matches_api() -> Result<(), Box<dyn Error>> {
1617 let mut data = Vec::with_capacity(512);
1618 for _ in 0..4 {
1619 data.push(f64::NAN);
1620 }
1621 for i in 0..508 {
1622 let x = i as f64;
1623 data.push((0.1 * x).sin() * 50.0 + 0.01 * x);
1624 }
1625
1626 let params = DpoParams { period: Some(5) };
1627 let input = DpoInput::from_slice(&data, params);
1628
1629 let baseline = dpo(&input)?.values;
1630
1631 let mut out = vec![0.0; data.len()];
1632
1633 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1634 {
1635 dpo_into(&input, &mut out)?;
1636 assert_eq!(out.len(), baseline.len());
1637 for (idx, (a, b)) in out.iter().zip(baseline.iter()).enumerate() {
1638 let equal = (a.is_nan() && b.is_nan()) || (a == b);
1639 assert!(
1640 equal,
1641 "dpo_into parity mismatch at idx {}: got {}, expected {}",
1642 idx, a, b
1643 );
1644 }
1645 }
1646
1647 Ok(())
1648 }
1649}
1650
1651#[cfg(all(feature = "python", feature = "cuda"))]
1652use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
1653#[cfg(feature = "python")]
1654use crate::utilities::kernel_validation::validate_kernel;
1655#[cfg(all(feature = "python", feature = "cuda"))]
1656use numpy::PyUntypedArrayMethods;
1657#[cfg(feature = "python")]
1658use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1659#[cfg(feature = "python")]
1660use pyo3::exceptions::PyValueError;
1661#[cfg(feature = "python")]
1662use pyo3::prelude::*;
1663#[cfg(feature = "python")]
1664use pyo3::types::PyDict;
1665
1666#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1667use serde::{Deserialize, Serialize};
1668#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1669use wasm_bindgen::prelude::*;
1670
/// DPO-specific name for the shared `DeviceArrayF32Py` wrapper; this is the handle type
/// returned by the DPO CUDA Python bindings in this file. Only available when both the
/// `python` and `cuda` features are enabled.
#[cfg(all(feature = "python", feature = "cuda"))]
pub type DpoDeviceArrayF32Py = DeviceArrayF32Py;
1673
// Python binding `dpo(data, period, kernel=None)`.
//
// Borrows the NumPy input as a slice, validates the optional kernel name, runs
// `dpo_with_kernel` with the GIL released, and returns the result as a new 1-D array.
// Any computation error is surfaced to Python as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dpo")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn dpo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = DpoInput::from_slice(
        slice_in,
        DpoParams {
            period: Some(period),
        },
    );

    // The numeric work does not touch Python objects, so it runs without the GIL.
    let computed = py.allow_threads(|| dpo_with_kernel(&input, kern));

    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1699
// Python-visible `DpoStream` class: a thin wrapper that owns a native `DpoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DpoStream")]
pub struct DpoStreamPy {
    stream: DpoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DpoStreamPy {
    // Constructor: builds the underlying stream for `period`; a construction error is
    // reported to Python as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = DpoParams {
            period: Some(period),
        };
        match DpoStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value to the stream and forwards whatever the native `update` returns
    // (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1723
// Python binding `dpo_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into parameter combinations, computes every row directly into
// a freshly allocated NumPy buffer with the GIL released, and returns a dict with
//   "values":  2-D array of shape (rows, cols), one row per combination
//   "periods": 1-D array with the period of each row
//
// Errors are surfaced as `ValueError`: computation failures, an output size that
// overflows `usize`, or a kernel that is not one of the supported batch kernels.
#[cfg(feature = "python")]
#[pyfunction(name = "dpo_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn dpo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    // Resolve the batch kernel to the per-row kernel up front so that an unexpected
    // kernel becomes a Python error instead of a panic inside the GIL-released section.
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => {
            return Err(PyValueError::new_err(format!(
                "dpo_batch: unsupported batch kernel {:?}",
                other
            )))
        }
    };

    let sweep = DpoBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // The product sizes an uninitialised allocation, so it must not wrap.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("dpo_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is created here and not yet visible to any other code, so the
    // mutable slice below is the only access path. Its contents are uninitialised until
    // `dpo_batch_inner_into` writes them — presumably it fills every cell, including the
    // warmup prefixes; TODO confirm against its implementation.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| dpo_batch_inner_into(slice_in, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1779
// Python binding `dpo_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the CUDA batch DPO for every period in the sweep and returns a pair:
// the device-array handle holding the result, and a dict whose "periods" entry lists
// the period of each combination. Fails with `ValueError` when CUDA is unavailable or
// when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dpo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn dpo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DpoDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = DpoBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);

    // Device setup and the kernel launch happen with the GIL released.
    let inner = py.allow_threads(|| {
        crate::cuda::oscillators::dpo_wrapper::CudaDpo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.dpo_batch_dev(slice_in, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("periods", periods.into_pyarray(py))?;

    let handle = make_device_array_py(device_id, inner)?;
    Ok((handle, dict))
}
1819
// Python binding `dpo_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Runs the CUDA DPO with a single period over a 2-D array and returns the device-array
// handle holding the result. The first array dimension is passed to the CUDA wrapper as
// `rows` and the second as `cols`; judging by the parameter and method names the layout
// is time-major (rows = time steps, cols = series) — confirm against `CudaDpo`.
// Fails with `ValueError` when CUDA is unavailable or the wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dpo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn dpo_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DpoDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let flat = data_tm_f32.as_slice()?;
    let params = DpoParams {
        period: Some(period),
    };

    // Device setup and the kernel launch happen with the GIL released.
    let inner = py.allow_threads(|| {
        let cuda = crate::cuda::oscillators::dpo_wrapper::CudaDpo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.dpo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    make_device_array_py(device_id, inner)
}
1852
// wasm export `dpo_js(data, period)`: computes DPO for `data` with `Kernel::Auto` into
// a new vector of the same length and returns it; errors are converted to a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = DpoInput::from_slice(
        data,
        DpoParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];

    match dpo_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1868
// wasm export `dpo_into(in_ptr, out_ptr, len, period)`: raw-pointer variant that reads
// `len` values from `in_ptr` and writes `len` results to `out_ptr`.
//
// Null pointers are rejected. When both pointers are identical the result is computed
// into a temporary buffer first and then copied over the input, so in-place use works.
// The caller is responsible for both regions being valid for `len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = DpoInput::from_slice(
            data,
            DpoParams {
                period: Some(period),
            },
        );

        let in_place = std::ptr::eq(in_ptr, out_ptr as *const f64);
        if in_place {
            // Compute into scratch space so the input is not overwritten while it is read.
            let mut scratch = vec![0.0; len];
            dpo_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            dpo_into_slice(dst, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1902
// wasm export `dpo_alloc(len)`: reserves capacity for `len` f64 values and hands the raw
// pointer to the caller. The backing vector is deliberately never dropped here; the
// memory stays allocated until it is released through `dpo_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the vector's destructor, leaking the buffer on purpose.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1911
// wasm export `dpo_free(ptr, len)`: releases a buffer previously obtained from
// `dpo_alloc(len)`. A null `ptr` is ignored.
//
// The vector is rebuilt with length 0 and capacity `len`: `dpo_alloc` only reserves
// capacity and never initialises elements, and `Vec::from_raw_parts` requires the first
// `length` elements to be initialised. Only the capacity matters for deallocation, and
// `f64` has no destructor, so nothing is lost by using a zero length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer returned by `dpo_alloc` together with the
    // same `len`, so `ptr` refers to a live allocation with capacity `len`.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1921
/// Batch configuration deserialized from the JS `config` argument of `dpo_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DpoBatchConfig {
    /// Period sweep, forwarded unchanged as `DpoBatchRange::period`; presumably
    /// `(start, end, step)`, matching the argument order of `dpo_batch_into`.
    pub period_range: (usize, usize, usize),
}
1927
/// Serializable result handed back to JS by `dpo_batch`; every field is copied from the
/// native batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DpoBatchJsOutput {
    /// Flat output buffer; presumably row-major with `rows * cols` entries — confirm
    /// against the batch implementation.
    pub values: Vec<f64>,
    /// Parameter combinations of the batch output.
    pub combos: Vec<DpoParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1936
// wasm export `dpo_batch(data, config)`: deserializes `config` into a `DpoBatchConfig`,
// runs the batch computation with `Kernel::Auto`, and returns the result serialized as
// a `DpoBatchJsOutput`. Every failure is reported as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dpo_batch)]
pub fn dpo_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<DpoBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = DpoBatchRange {
        period: parsed.period_range,
    };

    let batch = match dpo_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Move the native result into the serializable mirror struct.
    let js_output = DpoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1960
// wasm export `dpo_batch_into(in_ptr, out_ptr, len, period_start, period_end, period_step)`.
//
// Raw-pointer batch variant: reads `len` values from `in_ptr`, computes one output row
// per period in the sweep, writes `rows * len` values to `out_ptr`, and returns `rows`.
//
// Null pointers are rejected, and an output size that overflows `usize` is reported as
// an error instead of wrapping. If the input and output regions overlap, the input is
// copied first so that a shared and a mutable slice never alias the same memory.
// The caller is responsible for `in_ptr` being valid for `len` values and `out_ptr`
// being valid for `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dpo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DpoBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;

    // The product sizes a raw slice, so it must not wrap.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("dpo_batch_into: rows * cols overflows usize"))?;

    // Byte-address ranges of both regions; saturating arithmetic keeps the comparison
    // well-defined even for nonsensical sizes.
    let elem = std::mem::size_of::<f64>();
    let in_lo = in_ptr as usize;
    let in_hi = in_lo.saturating_add(len.saturating_mul(elem));
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(total.saturating_mul(elem));
    let overlaps = in_lo < out_hi && out_lo < in_hi;

    // SAFETY: both pointers are non-null and the caller guarantees they are valid for
    // `len` and `rows * len` f64 values respectively. When the regions overlap, the
    // input is copied out before the mutable output slice is created.
    unsafe {
        let owned_input: Vec<f64>;
        let data: &[f64] = if overlaps {
            owned_input = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &owned_input
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        dpo_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}