1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::dm_wrapper::CudaDm;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
5use crate::utilities::data_loader::Candles;
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
9 make_uninit_matrix,
10};
11#[cfg(feature = "python")]
12use crate::utilities::kernel_validation::validate_kernel;
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(not(target_arch = "wasm32"))]
16use rayon::prelude::*;
17use std::mem::MaybeUninit;
18use thiserror::Error;
19
20#[derive(Debug, Clone)]
21pub enum DmData<'a> {
22 Candles { candles: &'a Candles },
23 Slices { high: &'a [f64], low: &'a [f64] },
24}
25
/// Result of a single-series DM run.
///
/// Both vectors hold one value per input bar; warm-up positions are NaN.
#[derive(Debug, Clone)]
pub struct DmOutput {
    /// Plus directional movement (+DM) per bar.
    pub plus: Vec<f64>,
    /// Minus directional movement (-DM) per bar.
    pub minus: Vec<f64>,
}
31
/// Tunable parameters of the DM indicator.
#[derive(Debug, Clone)]
pub struct DmParams {
    /// Smoothing period; `None` falls back to the default of 14.
    pub period: Option<usize>,
}

impl Default for DmParams {
    /// A period of 14.
    fn default() -> Self {
        let period = Some(14);
        DmParams { period }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct DmInput<'a> {
45 pub data: DmData<'a>,
46 pub params: DmParams,
47}
48
49impl<'a> DmInput<'a> {
50 #[inline]
51 pub fn from_candles(candles: &'a Candles, params: DmParams) -> Self {
52 Self {
53 data: DmData::Candles { candles },
54 params,
55 }
56 }
57 #[inline]
58 pub fn from_slices(high: &'a [f64], low: &'a [f64], params: DmParams) -> Self {
59 Self {
60 data: DmData::Slices { high, low },
61 params,
62 }
63 }
64 #[inline]
65 pub fn with_default_candles(candles: &'a Candles) -> Self {
66 Self {
67 data: DmData::Candles { candles },
68 params: DmParams::default(),
69 }
70 }
71 #[inline]
72 pub fn get_period(&self) -> usize {
73 self.params
74 .period
75 .unwrap_or_else(|| DmParams::default().period.unwrap())
76 }
77}
78
79#[derive(Copy, Clone, Debug)]
80pub struct DmBuilder {
81 period: Option<usize>,
82 kernel: Kernel,
83}
84
85impl Default for DmBuilder {
86 fn default() -> Self {
87 Self {
88 period: None,
89 kernel: Kernel::Auto,
90 }
91 }
92}
93
94impl DmBuilder {
95 #[inline(always)]
96 pub fn new() -> Self {
97 Self::default()
98 }
99 #[inline(always)]
100 pub fn period(mut self, n: usize) -> Self {
101 self.period = Some(n);
102 self
103 }
104 #[inline(always)]
105 pub fn kernel(mut self, k: Kernel) -> Self {
106 self.kernel = k;
107 self
108 }
109
110 #[inline(always)]
111 pub fn apply(self, candles: &Candles) -> Result<DmOutput, DmError> {
112 let p = DmParams {
113 period: self.period,
114 };
115 let i = DmInput::from_candles(candles, p);
116 dm_with_kernel(&i, self.kernel)
117 }
118
119 #[inline(always)]
120 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DmOutput, DmError> {
121 let p = DmParams {
122 period: self.period,
123 };
124 let i = DmInput::from_slices(high, low, p);
125 dm_with_kernel(&i, self.kernel)
126 }
127
128 #[inline(always)]
129 pub fn into_stream(self) -> Result<DmStream, DmError> {
130 let p = DmParams {
131 period: self.period,
132 };
133 DmStream::try_new(p)
134 }
135}
136
137#[derive(Debug, Error)]
138pub enum DmError {
139 #[error("dm: Empty data provided (or high/low length mismatch).")]
140 EmptyInputData,
141 #[error("dm: Invalid period: period = {period}, data length = {data_len}")]
142 InvalidPeriod { period: usize, data_len: usize },
143 #[error("dm: Not enough valid data: needed = {needed}, valid = {valid}")]
144 NotEnoughValidData { needed: usize, valid: usize },
145 #[error("dm: All values are NaN.")]
146 AllValuesNaN,
147 #[error("dm: output length mismatch: expected = {expected}, got = {got}")]
148 OutputLengthMismatch { expected: usize, got: usize },
149 #[error("dm: invalid range: start={start}, end={end}, step={step}")]
150 InvalidRange {
151 start: usize,
152 end: usize,
153 step: usize,
154 },
155 #[error("dm: invalid kernel for batch: {0:?}")]
156 InvalidKernelForBatch(Kernel),
157}
158
159#[inline]
160pub fn dm(input: &DmInput) -> Result<DmOutput, DmError> {
161 dm_with_kernel(input, Kernel::Auto)
162}
163
164#[inline(always)]
165fn dm_prepare<'a>(
166 input: &'a DmInput,
167 kernel: Kernel,
168) -> Result<(&'a [f64], &'a [f64], usize, usize, Kernel), DmError> {
169 let (high, low) = match &input.data {
170 DmData::Candles { candles } => {
171 let h = candles
172 .select_candle_field("high")
173 .map_err(|_| DmError::EmptyInputData)?;
174 let l = candles
175 .select_candle_field("low")
176 .map_err(|_| DmError::EmptyInputData)?;
177 (h, l)
178 }
179 DmData::Slices { high, low } => (*high, *low),
180 };
181
182 if high.is_empty() || low.is_empty() || high.len() != low.len() {
183 return Err(DmError::EmptyInputData);
184 }
185
186 let period = input.get_period();
187 if period == 0 || period > high.len() {
188 return Err(DmError::InvalidPeriod {
189 period,
190 data_len: high.len(),
191 });
192 }
193
194 let first = high
195 .iter()
196 .zip(low.iter())
197 .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
198 .ok_or(DmError::AllValuesNaN)?;
199
200 if high.len() - first < period {
201 return Err(DmError::NotEnoughValidData {
202 needed: period,
203 valid: high.len() - first,
204 });
205 }
206
207 let chosen = match kernel {
208 Kernel::Auto => Kernel::Scalar,
209 k => k,
210 };
211 Ok((high, low, period, first, chosen))
212}
213
/// Scalar DM kernel that writes into caller-provided buffers.
///
/// The kernel starts at `first`, the first bar whose high and low are both
/// usable. It classifies the `period - 1` bar-to-bar moves after `first` into
/// +DM / -DM and sums them. That raw sum is stored at `first + period - 1`.
/// Every later bar applies Wilder smoothing: `s = s - s / period + x`.
///
/// A move counts as +DM when the high's rise is positive and larger than the
/// low's fall. It counts as -DM symmetrically. Otherwise both are zero.
///
/// Entries of `plus_out` / `minus_out` before `first + period - 1` are never
/// written, so callers can pre-fill the warm-up prefix.
///
/// # Panics
/// For non-empty `high`, this panics in any of these cases:
/// - `period == 0`;
/// - `first + period > high.len()`;
/// - `low`, `plus_out` or `minus_out` is shorter than `high`.
#[inline(always)]
fn dm_compute_into_scalar(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    debug_assert_eq!(high.len(), low.len());
    let n = high.len();
    if n == 0 {
        return;
    }

    // The unchecked accesses below read `high`/`low` at `first..n` and write
    // the outputs at `first + period - 1..n`. These checks are O(1) per call.
    // They turn a caller bug into a panic instead of out-of-bounds memory access.
    assert!(
        period >= 1 && first < n && period <= n - first,
        "dm: period/first out of range for input length"
    );
    assert!(
        low.len() >= n && plus_out.len() >= n && minus_out.len() >= n,
        "dm: low/output slices shorter than high"
    );

    let end_init = first + period - 1;

    // SAFETY: every index used below lies in `first..n`. The asserts above
    // prove that range is in bounds for `high`, `low`, `plus_out` and `minus_out`.
    unsafe {
        let mut sum_plus = 0.0f64;
        let mut sum_minus = 0.0f64;

        let mut i = first + 1;
        let warm_stop = end_init + 1;

        let mut prev_high = *high.get_unchecked(first);
        let mut prev_low = *low.get_unchecked(first);

        // Warm-up: plain sum of the classified moves for bars first+1..=end_init.
        while i < warm_stop {
            let hi = *high.get_unchecked(i);
            let lo = *low.get_unchecked(i);
            let diff_p = hi - prev_high;
            let diff_m = prev_low - lo;
            prev_high = hi;
            prev_low = lo;

            if diff_p > 0.0 && diff_p > diff_m {
                sum_plus += diff_p;
            } else if diff_m > 0.0 && diff_m > diff_p {
                sum_minus += diff_m;
            }
            i += 1;
        }

        *plus_out.get_unchecked_mut(end_init) = sum_plus;
        *minus_out.get_unchecked_mut(end_init) = sum_minus;

        if end_init + 1 >= n {
            return;
        }
        let inv_p = 1.0 / (period as f64);

        // Wilder smoothing for every bar after the seed.
        let mut j = end_init + 1;
        while j < n {
            let hi = *high.get_unchecked(j);
            let lo = *low.get_unchecked(j);
            let diff_p = hi - prev_high;
            let diff_m = prev_low - lo;
            prev_high = hi;
            prev_low = lo;

            let (p, m) = if diff_p > 0.0 && diff_p > diff_m {
                (diff_p, 0.0)
            } else if diff_m > 0.0 && diff_m > diff_p {
                (0.0, diff_m)
            } else {
                (0.0, 0.0)
            };

            // Compile-time choice: both branches evaluate s - s * inv_p + x.
            #[cfg(target_feature = "fma")]
            {
                sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p);
                sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m);
            }
            #[cfg(not(target_feature = "fma"))]
            {
                sum_plus = sum_plus - (sum_plus * inv_p) + p;
                sum_minus = sum_minus - (sum_minus * inv_p) + m;
            }

            *plus_out.get_unchecked_mut(j) = sum_plus;
            *minus_out.get_unchecked_mut(j) = sum_minus;
            j += 1;
        }
    }
}
299
300#[inline(always)]
301fn dm_compute_into(
302 high: &[f64],
303 low: &[f64],
304 period: usize,
305 first: usize,
306 kernel: Kernel,
307 plus_out: &mut [f64],
308 minus_out: &mut [f64],
309) {
310 match kernel {
311 Kernel::Scalar | Kernel::ScalarBatch => {
312 dm_compute_into_scalar(high, low, period, first, plus_out, minus_out)
313 }
314 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315 Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
316 dm_compute_into_avx2(high, low, period, first, plus_out, minus_out)
317 },
318 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
319 Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
320 dm_compute_into_avx512(high, low, period, first, plus_out, minus_out)
321 },
322 _ => unreachable!(),
323 }
324}
325
/// AVX2 variant of the DM kernel. It writes the same layout as the scalar
/// kernel: the raw seed sum at `first + period - 1`, Wilder-smoothed values
/// after it, and earlier entries untouched.
///
/// The +DM/-DM classification of bar-to-bar moves is vectorised four bars at
/// a time. The Wilder recursion depends on the previous sum, so the classified
/// lanes are spilled to a stack buffer and folded into the running sums in
/// order.
///
/// During warm-up, each group of four classified moves is summed on its own
/// before being added to the running total. The seed can therefore differ from
/// the scalar kernel in the last bits.
///
/// # Safety
/// - The CPU must support AVX2 (the function is compiled with
///   `#[target_feature(enable = "avx2")]`).
/// - `period >= 1` and `first + period <= high.len()`.
/// - `low`, `plus_out` and `minus_out` must each be at least `high.len()`
///   long; every load and store below is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dm_compute_into_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    debug_assert_eq!(high.len(), low.len());
    let n = high.len();
    if n == 0 {
        return;
    }

    // Index of the first output value (the un-smoothed seed sum).
    let end_init = first + period - 1;
    let inv_p = 1.0 / (period as f64);
    let zero = _mm256_setzero_pd();

    let mut sum_plus = 0.0f64;
    let mut sum_minus = 0.0f64;
    let mut i = first + 1;
    let warm_stop = end_init + 1;
    // Warm-up: sum the classified moves for bars first+1..=end_init, four at a time.
    while i + 4 <= warm_stop {
        let hc = _mm256_loadu_pd(high.as_ptr().add(i));
        let hp = _mm256_loadu_pd(high.as_ptr().add(i - 1));
        let dp = _mm256_sub_pd(hc, hp);

        let lp = _mm256_loadu_pd(low.as_ptr().add(i - 1));
        let lc = _mm256_loadu_pd(low.as_ptr().add(i));
        let dm = _mm256_sub_pd(lp, lc);

        // max(d, 0) followed by a strict `>` reproduces the scalar rule
        // `d > 0 && d > other` for non-NaN moves.
        // NOTE(review): for NaN moves `_mm256_max_pd` may yield 0 rather than
        // NaN, so classification can differ from the scalar kernel. Confirm
        // whether NaNs can occur after `first`.
        let dp_pos = _mm256_max_pd(dp, zero);
        let dm_pos = _mm256_max_pd(dm, zero);

        // Keep each move only in lanes where it wins; other lanes become +0.0.
        let p_mask = _mm256_cmp_pd(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm256_cmp_pd(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm256_and_pd(dp_pos, p_mask);
        let m_vec = _mm256_and_pd(dm_pos, m_mask);

        let mut p_buf = [0.0f64; 4];
        let mut m_buf = [0.0f64; 4];
        _mm256_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm256_storeu_pd(m_buf.as_mut_ptr(), m_vec);
        sum_plus += p_buf.iter().sum::<f64>();
        sum_minus += m_buf.iter().sum::<f64>();
        i += 4;
    }
    // Warm-up tail (fewer than four bars left): scalar classification.
    while i < warm_stop {
        let dp = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let dm = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if dp > 0.0 && dp > dm {
            sum_plus += dp;
        } else if dm > 0.0 && dm > dp {
            sum_minus += dm;
        }
        i += 1;
    }

    *plus_out.get_unchecked_mut(end_init) = sum_plus;
    *minus_out.get_unchecked_mut(end_init) = sum_minus;

    if end_init + 1 >= n {
        return;
    }

    // Main phase: vectorised classification, sequential Wilder recursion.
    let mut j = end_init + 1;
    while j + 4 <= n {
        let hc = _mm256_loadu_pd(high.as_ptr().add(j));
        let hp = _mm256_loadu_pd(high.as_ptr().add(j - 1));
        let dp = _mm256_sub_pd(hc, hp);

        let lp = _mm256_loadu_pd(low.as_ptr().add(j - 1));
        let lc = _mm256_loadu_pd(low.as_ptr().add(j));
        let dm = _mm256_sub_pd(lp, lc);

        let dp_pos = _mm256_max_pd(dp, zero);
        let dm_pos = _mm256_max_pd(dm, zero);

        let p_mask = _mm256_cmp_pd(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm256_cmp_pd(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm256_and_pd(dp_pos, p_mask);
        let m_vec = _mm256_and_pd(dm_pos, m_mask);

        let mut p_buf = [0.0f64; 4];
        let mut m_buf = [0.0f64; 4];
        _mm256_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm256_storeu_pd(m_buf.as_mut_ptr(), m_vec);

        // Compile-time switch on the crate's target features: both branches
        // evaluate s - s * inv_p + x; only the rounding of the FMA form differs.
        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[0]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[0]);
            *plus_out.get_unchecked_mut(j) = sum_plus;
            *minus_out.get_unchecked_mut(j) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[1]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[1]);
            *plus_out.get_unchecked_mut(j + 1) = sum_plus;
            *minus_out.get_unchecked_mut(j + 1) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[2]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[2]);
            *plus_out.get_unchecked_mut(j + 2) = sum_plus;
            *minus_out.get_unchecked_mut(j + 2) = sum_minus;

            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[3]);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[3]);
            *plus_out.get_unchecked_mut(j + 3) = sum_plus;
            *minus_out.get_unchecked_mut(j + 3) = sum_minus;
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[0];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[0];
            *plus_out.get_unchecked_mut(j) = sum_plus;
            *minus_out.get_unchecked_mut(j) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[1];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[1];
            *plus_out.get_unchecked_mut(j + 1) = sum_plus;
            *minus_out.get_unchecked_mut(j + 1) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[2];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[2];
            *plus_out.get_unchecked_mut(j + 2) = sum_plus;
            *minus_out.get_unchecked_mut(j + 2) = sum_minus;

            sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[3];
            sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[3];
            *plus_out.get_unchecked_mut(j + 3) = sum_plus;
            *minus_out.get_unchecked_mut(j + 3) = sum_minus;
        }
        j += 4;
    }

    // Tail (fewer than four bars left): scalar classification and recursion.
    while j < n {
        let dp = *high.get_unchecked(j) - *high.get_unchecked(j - 1);
        let dm = *low.get_unchecked(j - 1) - *low.get_unchecked(j);

        let (p, m) = if dp > 0.0 && dp > dm {
            (dp, 0.0)
        } else if dm > 0.0 && dm > dp {
            (0.0, dm)
        } else {
            (0.0, 0.0)
        };

        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m);
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p;
            sum_minus = sum_minus - (sum_minus * inv_p) + m;
        }
        *plus_out.get_unchecked_mut(j) = sum_plus;
        *minus_out.get_unchecked_mut(j) = sum_minus;
        j += 1;
    }
}
491
/// AVX-512 variant of the DM kernel. It writes the same layout as the scalar
/// kernel: the raw seed sum at `first + period - 1`, Wilder-smoothed values
/// after it, and earlier entries untouched.
///
/// The +DM/-DM classification is vectorised eight bars at a time using mask
/// registers. The Wilder recursion stays sequential over the spilled lanes.
///
/// During warm-up, each group of eight classified moves is folded in lane
/// order, the same as the scalar sequence. Results can still differ from the
/// scalar kernel for NaN moves (see the note inside).
///
/// # Safety
/// - The CPU must support AVX-512F (the function is compiled with
///   `#[target_feature(enable = "avx512f")]`).
/// - `period >= 1` and `first + period <= high.len()`.
/// - `low`, `plus_out` and `minus_out` must each be at least `high.len()`
///   long; every load and store below is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn dm_compute_into_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first: usize,
    plus_out: &mut [f64],
    minus_out: &mut [f64],
) {
    use core::arch::x86_64::*;
    debug_assert_eq!(high.len(), low.len());
    let n = high.len();
    if n == 0 {
        return;
    }

    // Index of the first output value (the un-smoothed seed sum).
    let end_init = first + period - 1;
    let inv_p = 1.0 / (period as f64);
    let zero = _mm512_set1_pd(0.0);

    let mut sum_plus = 0.0f64;
    let mut sum_minus = 0.0f64;
    let mut i = first + 1;
    let warm_stop = end_init + 1;
    // Warm-up: sum the classified moves for bars first+1..=end_init, eight at a time.
    while i + 8 <= warm_stop {
        let hc = _mm512_loadu_pd(high.as_ptr().add(i));
        let hp = _mm512_loadu_pd(high.as_ptr().add(i - 1));
        let dp = _mm512_sub_pd(hc, hp);

        let lp = _mm512_loadu_pd(low.as_ptr().add(i - 1));
        let lc = _mm512_loadu_pd(low.as_ptr().add(i));
        let dm = _mm512_sub_pd(lp, lc);

        // max(d, 0) followed by a strict `>` reproduces the scalar rule
        // `d > 0 && d > other` for non-NaN moves.
        // NOTE(review): NaN moves may be clamped to 0 by `_mm512_max_pd`
        // instead of propagating. Confirm whether NaNs can occur after `first`.
        let dp_pos = _mm512_max_pd(dp, zero);
        let dm_pos = _mm512_max_pd(dm, zero);

        // Zero-masked moves: keep the winning move per lane, 0.0 elsewhere.
        let p_mask = _mm512_cmp_pd_mask(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm512_cmp_pd_mask(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm512_maskz_mov_pd(p_mask, dp_pos);
        let m_vec = _mm512_maskz_mov_pd(m_mask, dm_pos);

        let mut p_buf = [0.0f64; 8];
        let mut m_buf = [0.0f64; 8];
        _mm512_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm512_storeu_pd(m_buf.as_mut_ptr(), m_vec);
        for k in 0..8 {
            sum_plus += p_buf[k];
            sum_minus += m_buf[k];
        }
        i += 8;
    }
    // Warm-up tail (fewer than eight bars left): scalar classification.
    while i < warm_stop {
        let dp = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let dm = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if dp > 0.0 && dp > dm {
            sum_plus += dp;
        } else if dm > 0.0 && dm > dp {
            sum_minus += dm;
        }
        i += 1;
    }
    *plus_out.get_unchecked_mut(end_init) = sum_plus;
    *minus_out.get_unchecked_mut(end_init) = sum_minus;

    if end_init + 1 >= n {
        return;
    }

    // Main phase: vectorised classification, sequential Wilder recursion.
    let mut j = end_init + 1;
    while j + 8 <= n {
        let hc = _mm512_loadu_pd(high.as_ptr().add(j));
        let hp = _mm512_loadu_pd(high.as_ptr().add(j - 1));
        let dp = _mm512_sub_pd(hc, hp);

        let lp = _mm512_loadu_pd(low.as_ptr().add(j - 1));
        let lc = _mm512_loadu_pd(low.as_ptr().add(j));
        let dm = _mm512_sub_pd(lp, lc);

        let dp_pos = _mm512_max_pd(dp, zero);
        let dm_pos = _mm512_max_pd(dm, zero);

        let p_mask = _mm512_cmp_pd_mask(dp_pos, dm_pos, _CMP_GT_OQ);
        let m_mask = _mm512_cmp_pd_mask(dm_pos, dp_pos, _CMP_GT_OQ);
        let p_vec = _mm512_maskz_mov_pd(p_mask, dp_pos);
        let m_vec = _mm512_maskz_mov_pd(m_mask, dm_pos);

        let mut p_buf = [0.0f64; 8];
        let mut m_buf = [0.0f64; 8];
        _mm512_storeu_pd(p_buf.as_mut_ptr(), p_vec);
        _mm512_storeu_pd(m_buf.as_mut_ptr(), m_vec);

        // Compile-time switch on the crate's target features: both branches
        // evaluate s - s * inv_p + x; only the rounding of the FMA form differs.
        #[cfg(target_feature = "fma")]
        {
            for t in 0..8 {
                sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p_buf[t]);
                sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m_buf[t]);
                *plus_out.get_unchecked_mut(j + t) = sum_plus;
                *minus_out.get_unchecked_mut(j + t) = sum_minus;
            }
        }
        #[cfg(not(target_feature = "fma"))]
        {
            for t in 0..8 {
                sum_plus = sum_plus - (sum_plus * inv_p) + p_buf[t];
                sum_minus = sum_minus - (sum_minus * inv_p) + m_buf[t];
                *plus_out.get_unchecked_mut(j + t) = sum_plus;
                *minus_out.get_unchecked_mut(j + t) = sum_minus;
            }
        }
        j += 8;
    }
    // Tail (fewer than eight bars left): scalar classification and recursion.
    while j < n {
        let dp = *high.get_unchecked(j) - *high.get_unchecked(j - 1);
        let dm = *low.get_unchecked(j - 1) - *low.get_unchecked(j);

        let (p, m) = if dp > 0.0 && dp > dm {
            (dp, 0.0)
        } else if dm > 0.0 && dm > dp {
            (0.0, dm)
        } else {
            (0.0, 0.0)
        };

        #[cfg(target_feature = "fma")]
        {
            sum_plus = (-inv_p).mul_add(sum_plus, sum_plus + p);
            sum_minus = (-inv_p).mul_add(sum_minus, sum_minus + m);
        }
        #[cfg(not(target_feature = "fma"))]
        {
            sum_plus = sum_plus - (sum_plus * inv_p) + p;
            sum_minus = sum_minus - (sum_minus * inv_p) + m;
        }
        *plus_out.get_unchecked_mut(j) = sum_plus;
        *minus_out.get_unchecked_mut(j) = sum_minus;
        j += 1;
    }
}
631
632pub fn dm_with_kernel(input: &DmInput, kernel: Kernel) -> Result<DmOutput, DmError> {
633 let (high, low, period, first, chosen) = dm_prepare(input, kernel)?;
634 let warm = first + period - 1;
635
636 let mut plus = alloc_with_nan_prefix(high.len(), warm);
637 let mut minus = alloc_with_nan_prefix(high.len(), warm);
638
639 dm_compute_into(high, low, period, first, chosen, &mut plus, &mut minus);
640 Ok(DmOutput { plus, minus })
641}
642
643#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
644#[inline]
645pub fn dm_into(
646 input: &DmInput,
647 plus_out: &mut [f64],
648 minus_out: &mut [f64],
649) -> Result<(), DmError> {
650 let (high, low, period, first, chosen) = dm_prepare(input, Kernel::Auto)?;
651
652 if plus_out.len() != high.len() {
653 return Err(DmError::OutputLengthMismatch {
654 expected: high.len(),
655 got: plus_out.len(),
656 });
657 }
658 if minus_out.len() != high.len() {
659 return Err(DmError::OutputLengthMismatch {
660 expected: high.len(),
661 got: minus_out.len(),
662 });
663 }
664
665 let warm = first + period - 1;
666 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
667 let warm_end = warm.min(high.len());
668 for v in &mut plus_out[..warm_end] {
669 *v = qnan;
670 }
671 for v in &mut minus_out[..warm_end] {
672 *v = qnan;
673 }
674
675 dm_compute_into(high, low, period, first, chosen, plus_out, minus_out);
676 Ok(())
677}
678
679#[inline]
680pub fn dm_into_slice(
681 plus_dst: &mut [f64],
682 minus_dst: &mut [f64],
683 input: &DmInput,
684 kernel: Kernel,
685) -> Result<(), DmError> {
686 let (high, low, period, first, chosen) = dm_prepare(input, kernel)?;
687 if plus_dst.len() != high.len() {
688 return Err(DmError::OutputLengthMismatch {
689 expected: high.len(),
690 got: plus_dst.len(),
691 });
692 }
693 if minus_dst.len() != high.len() {
694 return Err(DmError::OutputLengthMismatch {
695 expected: high.len(),
696 got: minus_dst.len(),
697 });
698 }
699
700 dm_compute_into(high, low, period, first, chosen, plus_dst, minus_dst);
701
702 let warm = first + period - 1;
703 for v in &mut plus_dst[..warm] {
704 *v = f64::NAN;
705 }
706 for v in &mut minus_dst[..warm] {
707 *v = f64::NAN;
708 }
709 Ok(())
710}
711
712#[inline]
713pub unsafe fn dm_scalar(
714 high: &[f64],
715 low: &[f64],
716 period: usize,
717 first_valid_idx: usize,
718) -> Result<DmOutput, DmError> {
719 let warm = first_valid_idx + period - 1;
720 let mut plus_dm = alloc_with_nan_prefix(high.len(), warm);
721 let mut minus_dm = alloc_with_nan_prefix(high.len(), warm);
722
723 dm_compute_into_scalar(
724 high,
725 low,
726 period,
727 first_valid_idx,
728 &mut plus_dm,
729 &mut minus_dm,
730 );
731
732 Ok(DmOutput {
733 plus: plus_dm,
734 minus: minus_dm,
735 })
736}
737
/// Allocating AVX2 entry point for inputs that are already validated.
///
/// # Safety
/// - The CPU must support AVX2.
/// - The caller must guarantee `period >= 1`,
///   `first_valid_idx + period <= high.len()` and `low.len() >= high.len()`.
///
/// The kernel performs unchecked loads and stores.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx2(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let len = high.len();
    let warm = first_valid_idx + period - 1;

    let mut plus = alloc_with_nan_prefix(len, warm);
    let mut minus = alloc_with_nan_prefix(len, warm);
    // SAFETY: forwarded from this function's `# Safety` contract.
    unsafe {
        dm_compute_into_avx2(high, low, period, first_valid_idx, &mut plus, &mut minus);
    }

    Ok(DmOutput { plus, minus })
}
762
/// Allocating AVX-512 entry point for inputs that are already validated.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - The caller must guarantee `period >= 1`,
///   `first_valid_idx + period <= high.len()` and `low.len() >= high.len()`.
///
/// The kernel performs unchecked loads and stores.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    let len = high.len();
    let warm = first_valid_idx + period - 1;

    let mut plus = alloc_with_nan_prefix(len, warm);
    let mut minus = alloc_with_nan_prefix(len, warm);
    // SAFETY: forwarded from this function's `# Safety` contract.
    unsafe {
        dm_compute_into_avx512(high, low, period, first_valid_idx, &mut plus, &mut minus);
    }

    Ok(DmOutput { plus, minus })
}
787
/// Alias of `dm_avx512`; this file has no separate short-period path.
///
/// # Safety
/// Same contract as `dm_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512_short(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    // SAFETY: the caller upholds `dm_avx512`'s contract.
    unsafe { dm_avx512(high, low, period, first_valid_idx) }
}

/// Alias of `dm_avx512`; this file has no separate long-period path.
///
/// # Safety
/// Same contract as `dm_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn dm_avx512_long(
    high: &[f64],
    low: &[f64],
    period: usize,
    first_valid_idx: usize,
) -> Result<DmOutput, DmError> {
    // SAFETY: the caller upholds `dm_avx512`'s contract.
    unsafe { dm_avx512(high, low, period, first_valid_idx) }
}
809
/// Incremental DM calculator fed one bar at a time.
#[derive(Debug, Clone)]
pub struct DmStream {
    /// Smoothing period (always >= 1).
    period: usize,
    /// Cached `1.0 / period`.
    inv_period: f64,
    /// Running +DM: a raw sum during warm-up, Wilder-smoothed afterwards.
    sum_plus: f64,
    /// Running -DM: a raw sum during warm-up, Wilder-smoothed afterwards.
    sum_minus: f64,
    /// High of the previous bar.
    prev_high: f64,
    /// Low of the previous bar.
    prev_low: f64,
    /// Warm-up bars consumed so far; saturates at `period`.
    count: usize,
}
820
821impl DmStream {
822 pub fn try_new(params: DmParams) -> Result<Self, DmError> {
823 let period = params.period.unwrap_or(14);
824 if period == 0 {
825 return Err(DmError::InvalidPeriod {
826 period,
827 data_len: 0,
828 });
829 }
830 let inv = 1.0 / (period as f64);
831 Ok(Self {
832 period,
833 inv_period: inv,
834 sum_plus: 0.0,
835 sum_minus: 0.0,
836 prev_high: f64::NAN,
837 prev_low: f64::NAN,
838 count: 0,
839 })
840 }
841
842 #[inline(always)]
843 pub fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
844 if self.count == 0 {
845 self.prev_high = high;
846 self.prev_low = low;
847 }
848
849 let dp = high - self.prev_high;
850 let dm = self.prev_low - low;
851
852 self.prev_high = high;
853 self.prev_low = low;
854
855 let dp_pos = dp.max(0.0);
856 let dm_pos = dm.max(0.0);
857
858 let plus_val = if dp_pos > dm_pos { dp_pos } else { 0.0 };
859 let minus_val = if dm_pos > dp_pos { dm_pos } else { 0.0 };
860
861 if self.count < self.period - 1 {
862 self.sum_plus += plus_val;
863 self.sum_minus += minus_val;
864 self.count += 1;
865 return None;
866 } else if self.count == self.period - 1 {
867 self.sum_plus += plus_val;
868 self.sum_minus += minus_val;
869 self.count += 1;
870 return Some((self.sum_plus, self.sum_minus));
871 }
872
873 #[cfg(target_feature = "fma")]
874 {
875 self.sum_plus = (-self.inv_period).mul_add(self.sum_plus, self.sum_plus + plus_val);
876 self.sum_minus = (-self.inv_period).mul_add(self.sum_minus, self.sum_minus + minus_val);
877 }
878 #[cfg(not(target_feature = "fma"))]
879 {
880 self.sum_plus = self.sum_plus - (self.sum_plus * self.inv_period) + plus_val;
881 self.sum_minus = self.sum_minus - (self.sum_minus * self.inv_period) + minus_val;
882 }
883
884 Some((self.sum_plus, self.sum_minus))
885 }
886}
887
/// Inclusive `(start, end, step)` sweep of periods for batch runs.
///
/// A `step` of 0, or `start == end`, yields the single period `start`.
#[derive(Clone, Debug)]
pub struct DmBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DmBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        DmBatchRange {
            period: (start, end, step),
        }
    }
}
900
901#[derive(Clone, Debug, Default)]
902pub struct DmBatchBuilder {
903 range: DmBatchRange,
904 kernel: Kernel,
905}
906
907impl DmBatchBuilder {
908 pub fn new() -> Self {
909 Self::default()
910 }
911 pub fn kernel(mut self, k: Kernel) -> Self {
912 self.kernel = k;
913 self
914 }
915 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
916 self.range.period = (start, end, step);
917 self
918 }
919 pub fn period_static(mut self, p: usize) -> Self {
920 self.range.period = (p, p, 0);
921 self
922 }
923 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<DmBatchOutput, DmError> {
924 dm_batch_with_kernel(high, low, &self.range, self.kernel)
925 }
926 pub fn apply_candles(self, c: &Candles) -> Result<DmBatchOutput, DmError> {
927 let high = c
928 .select_candle_field("high")
929 .map_err(|_| DmError::EmptyInputData)?;
930 let low = c
931 .select_candle_field("low")
932 .map_err(|_| DmError::EmptyInputData)?;
933 self.apply_slices(high, low)
934 }
935 pub fn with_default_candles(c: &Candles) -> Result<DmBatchOutput, DmError> {
936 DmBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
937 }
938}
939
940#[derive(Clone, Debug)]
941pub struct DmBatchOutput {
942 pub plus: Vec<f64>,
943 pub minus: Vec<f64>,
944 pub combos: Vec<DmParams>,
945 pub rows: usize,
946 pub cols: usize,
947}
948impl DmBatchOutput {
949 pub fn row_for_params(&self, p: &DmParams) -> Option<usize> {
950 self.combos
951 .iter()
952 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
953 }
954 pub fn values_for(&self, p: &DmParams) -> Option<(&[f64], &[f64])> {
955 self.row_for_params(p).map(|row| {
956 let start = row * self.cols;
957 (
958 &self.plus[start..start + self.cols],
959 &self.minus[start..start + self.cols],
960 )
961 })
962 }
963}
964
965#[inline(always)]
966fn expand_grid(r: &DmBatchRange) -> Result<Vec<DmParams>, DmError> {
967 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, DmError> {
968 if step == 0 || start == end {
969 return Ok(vec![start]);
970 }
971 if start < end {
972 let mut v = Vec::new();
973 let st = step.max(1);
974 let mut x = start;
975 while x <= end {
976 v.push(x);
977 match x.checked_add(st) {
978 Some(next) => x = next,
979 None => break,
980 }
981 }
982 if v.is_empty() {
983 return Err(DmError::InvalidRange { start, end, step });
984 }
985 return Ok(v);
986 }
987
988 let mut v = Vec::new();
989 let st = step.max(1) as isize;
990 let mut x = start as isize;
991 let end_i = end as isize;
992 while x >= end_i {
993 v.push(x as usize);
994 x -= st;
995 }
996 if v.is_empty() {
997 return Err(DmError::InvalidRange { start, end, step });
998 }
999 Ok(v)
1000 }
1001
1002 let periods = axis_usize(r.period)?;
1003 let mut out = Vec::with_capacity(periods.len());
1004 for p in periods {
1005 out.push(DmParams { period: Some(p) });
1006 }
1007 Ok(out)
1008}
1009
1010pub fn dm_batch_with_kernel(
1011 high: &[f64],
1012 low: &[f64],
1013 sweep: &DmBatchRange,
1014 k: Kernel,
1015) -> Result<DmBatchOutput, DmError> {
1016 let kernel = match k {
1017 Kernel::Auto => detect_best_batch_kernel(),
1018 other if other.is_batch() => other,
1019 _ => return Err(DmError::InvalidKernelForBatch(k)),
1020 };
1021 let simd = match kernel {
1022 Kernel::Avx512Batch => Kernel::Avx512,
1023 Kernel::Avx2Batch => Kernel::Avx2,
1024 Kernel::ScalarBatch => Kernel::Scalar,
1025 _ => unreachable!(),
1026 };
1027 dm_batch_par_slice(high, low, sweep, simd)
1028}
1029
1030#[inline(always)]
1031pub fn dm_batch_slice(
1032 high: &[f64],
1033 low: &[f64],
1034 sweep: &DmBatchRange,
1035 kern: Kernel,
1036) -> Result<DmBatchOutput, DmError> {
1037 dm_batch_inner(high, low, sweep, kern, false)
1038}
1039
1040#[inline(always)]
1041pub fn dm_batch_par_slice(
1042 high: &[f64],
1043 low: &[f64],
1044 sweep: &DmBatchRange,
1045 kern: Kernel,
1046) -> Result<DmBatchOutput, DmError> {
1047 dm_batch_inner(high, low, sweep, kern, true)
1048}
1049
1050#[inline(always)]
1051fn dm_batch_inner_into(
1052 high: &[f64],
1053 low: &[f64],
1054 sweep: &DmBatchRange,
1055 kern: Kernel,
1056 parallel: bool,
1057 first: usize,
1058 plus_out: &mut [f64],
1059 minus_out: &mut [f64],
1060) -> Result<Vec<DmParams>, DmError> {
1061 let combos = expand_grid(sweep)?;
1062
1063 let rows = combos.len();
1064 let cols = high.len();
1065
1066 let _total = rows.checked_mul(cols).ok_or(DmError::InvalidRange {
1067 start: sweep.period.0,
1068 end: sweep.period.1,
1069 step: sweep.period.2,
1070 })?;
1071 let chosen = match kern {
1072 Kernel::Auto => detect_best_batch_kernel(),
1073 k => k,
1074 };
1075
1076 let do_row = |row: usize, plus_row: &mut [f64], minus_row: &mut [f64]| {
1077 let p = combos[row].period.unwrap();
1078 dm_compute_into(
1079 high,
1080 low,
1081 p,
1082 first,
1083 match chosen {
1084 Kernel::Avx512Batch => Kernel::Avx512,
1085 Kernel::Avx2Batch => Kernel::Avx2,
1086 Kernel::ScalarBatch => Kernel::Scalar,
1087 k => k,
1088 },
1089 plus_row,
1090 minus_row,
1091 );
1092 };
1093
1094 if parallel {
1095 #[cfg(not(target_arch = "wasm32"))]
1096 {
1097 use rayon::prelude::*;
1098 plus_out
1099 .par_chunks_mut(cols)
1100 .zip(minus_out.par_chunks_mut(cols))
1101 .enumerate()
1102 .for_each(|(r, (pr, mr))| do_row(r, pr, mr));
1103 }
1104 #[cfg(target_arch = "wasm32")]
1105 {
1106 for (r, (pr, mr)) in plus_out
1107 .chunks_mut(cols)
1108 .zip(minus_out.chunks_mut(cols))
1109 .enumerate()
1110 {
1111 do_row(r, pr, mr);
1112 }
1113 }
1114 } else {
1115 for (r, (pr, mr)) in plus_out
1116 .chunks_mut(cols)
1117 .zip(minus_out.chunks_mut(cols))
1118 .enumerate()
1119 {
1120 do_row(r, pr, mr);
1121 }
1122 }
1123
1124 Ok(combos)
1125}
1126
1127#[inline(always)]
1128fn dm_batch_inner(
1129 high: &[f64],
1130 low: &[f64],
1131 sweep: &DmBatchRange,
1132 kern: Kernel,
1133 parallel: bool,
1134) -> Result<DmBatchOutput, DmError> {
1135 let combos = expand_grid(sweep)?;
1136
1137 let first = high
1138 .iter()
1139 .zip(low.iter())
1140 .position(|(&h, &l)| !h.is_nan() && !l.is_nan())
1141 .ok_or(DmError::AllValuesNaN)?;
1142
1143 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1144 if high.len() - first < max_p {
1145 return Err(DmError::NotEnoughValidData {
1146 needed: max_p,
1147 valid: high.len() - first,
1148 });
1149 }
1150
1151 let rows = combos.len();
1152 let cols = high.len();
1153
1154 let _total = rows.checked_mul(cols).ok_or(DmError::InvalidRange {
1155 start: sweep.period.0,
1156 end: sweep.period.1,
1157 step: sweep.period.2,
1158 })?;
1159
1160 let mut plus_mu = make_uninit_matrix(rows, cols);
1161 let mut minus_mu = make_uninit_matrix(rows, cols);
1162
1163 let warm: Vec<usize> = combos
1164 .iter()
1165 .map(|c| first + c.period.unwrap() - 1)
1166 .collect();
1167 init_matrix_prefixes(&mut plus_mu, cols, &warm);
1168 init_matrix_prefixes(&mut minus_mu, cols, &warm);
1169
1170 let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
1171 let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
1172 let plus_out: &mut [f64] = unsafe {
1173 core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
1174 };
1175 let minus_out: &mut [f64] = unsafe {
1176 core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
1177 };
1178
1179 let combos = dm_batch_inner_into(high, low, sweep, kern, parallel, first, plus_out, minus_out)?;
1180
1181 let plus = unsafe {
1182 Vec::from_raw_parts(
1183 plus_guard.as_mut_ptr() as *mut f64,
1184 plus_guard.len(),
1185 plus_guard.capacity(),
1186 )
1187 };
1188 let minus = unsafe {
1189 Vec::from_raw_parts(
1190 minus_guard.as_mut_ptr() as *mut f64,
1191 minus_guard.len(),
1192 minus_guard.capacity(),
1193 )
1194 };
1195
1196 Ok(DmBatchOutput {
1197 plus,
1198 minus,
1199 combos,
1200 rows,
1201 cols,
1202 })
1203}
1204
/// Scalar row kernel: writes one period's Wilder-smoothed +DM / -DM series.
///
/// For each bar `i > first` the up-move is `high[i] - high[i-1]` and the
/// down-move is `low[i-1] - low[i]`. A move counts toward +DM (resp. -DM) only
/// when it is positive and strictly larger than the opposite move.
///
/// The first `period - 1` contributions after `first` are summed and stored at
/// index `first + period - 1`. Each later index holds
/// `prev - prev / period + contribution`. Cells before `first + period - 1` are
/// not written.
///
/// # Safety
/// The body performs no unsafe operations; indexing is bounds-checked and
/// panics if `first`, or any index up to `high.len() - 1`, is out of range for
/// `low`, `plus` or `minus`.
#[inline(always)]
unsafe fn dm_row_scalar(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    // Split one bar's (up, down) move into its (+DM, -DM) contributions.
    let classify = |up: f64, down: f64| -> (f64, f64) {
        let p = if up > 0.0 && up > down { up } else { 0.0 };
        let m = if down > 0.0 && down > up { down } else { 0.0 };
        (p, m)
    };

    let mut last_h = high[first];
    let mut last_l = low[first];

    // Step to bar `i`: measure its move against the previous bar, then remember it.
    let mut advance = |i: usize| -> (f64, f64) {
        let h = high[i];
        let up = h - last_h;
        let l = low[i];
        let down = last_l - l;
        last_h = h;
        last_l = l;
        classify(up, down)
    };

    let seed = first + period - 1;
    let mut acc_p = 0.0;
    let mut acc_m = 0.0;
    for i in (first + 1)..=seed {
        let (p, m) = advance(i);
        acc_p += p;
        acc_m += m;
    }
    plus[seed] = acc_p;
    minus[seed] = acc_m;

    let inv = 1.0 / (period as f64);
    for i in (seed + 1)..high.len() {
        let (p, m) = advance(i);
        acc_p = acc_p - (acc_p * inv) + p;
        acc_m = acc_m - (acc_m * inv) + m;
        plus[i] = acc_p;
        minus[i] = acc_m;
    }
}
1270
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX2 row kernel for the batch path.
///
/// There is no vectorised implementation: the smoothing recurrence depends on
/// the previous output, so this simply runs [`dm_row_scalar`] and produces
/// bit-identical results.
unsafe fn dm_row_avx2(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1283
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 row kernel for the batch path.
///
/// Currently a pass-through to [`dm_row_scalar`]; output is bit-identical to
/// the scalar kernel.
unsafe fn dm_row_avx512(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1296
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 row kernel variant for short periods.
///
/// No specialised code path exists; it evaluates exactly what the AVX-512
/// kernel does, namely the scalar recurrence in [`dm_row_scalar`].
unsafe fn dm_row_avx512_short(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1309
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// AVX-512 row kernel variant for long periods.
///
/// No specialised code path exists; it evaluates exactly what the AVX-512
/// kernel does, namely the scalar recurrence in [`dm_row_scalar`].
unsafe fn dm_row_avx512_long(
    high: &[f64],
    low: &[f64],
    first: usize,
    period: usize,
    plus: &mut [f64],
    minus: &mut [f64],
) {
    dm_row_scalar(high, low, first, period, plus, minus);
}
1322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// OHLC fixture shared by the candle-based tests.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference +DM values (period 14) for the last five bars of `CANDLES_CSV`.
    const EXPECTED_PLUS_LAST5: [f64; 5] = [
        1410.819956368491,
        1384.04710234217,
        1285.186595032015,
        1199.3875525297283,
        1113.7170130633192,
    ];

    /// Reference -DM values (period 14) for the last five bars of `CANDLES_CSV`.
    const EXPECTED_MINUS_LAST5: [f64; 5] = [
        3602.8631384045057,
        3345.5157713756125,
        3258.5503591344973,
        3025.796762053462,
        3493.668421906786,
    ];

    /// Asserts that the last five entries of `plus` / `minus` are within 1e-6 of
    /// the period-14 reference values for `CANDLES_CSV`.
    fn assert_reference_tail(test_name: &str, plus: &[f64], minus: &[f64]) {
        let n = EXPECTED_PLUS_LAST5.len();
        assert!(
            plus.len() >= n && minus.len() >= n,
            "[{}] output has fewer than {} values",
            test_name,
            n
        );
        let tail_plus = &plus[plus.len() - n..];
        let tail_minus = &minus[minus.len() - n..];
        for i in 0..n {
            assert!(
                (tail_plus[i] - EXPECTED_PLUS_LAST5[i]).abs() < 1e-6,
                "[{}] plus tail[{}]: got {}, expected {}",
                test_name,
                i,
                tail_plus[i],
                EXPECTED_PLUS_LAST5[i]
            );
            assert!(
                (tail_minus[i] - EXPECTED_MINUS_LAST5[i]).abs() < 1e-6,
                "[{}] minus tail[{}]: got {}, expected {}",
                test_name,
                i,
                tail_minus[i],
                EXPECTED_MINUS_LAST5[i]
            );
        }
    }

    /// Names the debug-build allocation helper that writes `bits` as an
    /// uninitialised-memory poison pattern, or returns `None` for an ordinary value.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    fn check_dm_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let default_params = DmParams { period: None };
        let input_default = DmInput::from_candles(&candles, default_params);
        let output_default = dm_with_kernel(&input_default, kernel)?;
        assert_eq!(output_default.plus.len(), candles.high.len());
        assert_eq!(output_default.minus.len(), candles.high.len());

        let params_custom = DmParams { period: Some(10) };
        let input_custom = DmInput::from_candles(&candles, params_custom);
        let output_custom = dm_with_kernel(&input_custom, kernel)?;
        assert_eq!(output_custom.plus.len(), candles.high.len());
        assert_eq!(output_custom.minus.len(), candles.high.len());
        Ok(())
    }

    fn check_dm_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = DmInput::with_default_candles(&candles);
        let result = dm_with_kernel(&input, kernel)?;
        assert_eq!(result.plus.len(), candles.high.len());
        assert_eq!(result.minus.len(), candles.high.len());
        Ok(())
    }

    fn check_dm_with_slice_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [8000.0, 8050.0, 8100.0, 8075.0, 8110.0, 8050.0];
        let low_values = [7800.0, 7900.0, 7950.0, 7950.0, 8000.0, 7950.0];
        let params = DmParams { period: Some(3) };
        let input = DmInput::from_slices(&high_values, &low_values, params);
        let result = dm_with_kernel(&input, kernel)?;
        assert_eq!(result.plus.len(), 6);
        assert_eq!(result.minus.len(), 6);

        // Warm-up for period 3 with no leading NaNs covers indices 0 and 1.
        for i in 0..2 {
            assert!(result.plus[i].is_nan());
            assert!(result.minus[i].is_nan());
        }
        // Inputs contain no NaNs, so everything after the warm-up is finite.
        assert!(result.plus[2..].iter().all(|v| v.is_finite()));
        assert!(result.minus[2..].iter().all(|v| v.is_finite()));
        Ok(())
    }

    fn check_dm_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [100.0, 110.0, 120.0];
        let low_values = [90.0, 100.0, 110.0];
        let params = DmParams { period: Some(0) };
        let input = DmInput::from_slices(&high_values, &low_values, params);
        assert!(dm_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_dm_period_exceeds_data_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [100.0, 110.0, 120.0];
        let low_values = [90.0, 100.0, 110.0];
        let params = DmParams { period: Some(10) };
        let input = DmInput::from_slices(&high_values, &low_values, params);
        assert!(dm_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_dm_not_enough_valid_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [f64::NAN, f64::NAN, 100.0, 101.0, 102.0];
        let low_values = [f64::NAN, f64::NAN, 90.0, 89.0, 88.0];
        let params = DmParams { period: Some(5) };
        let input = DmInput::from_slices(&high_values, &low_values, params);
        assert!(dm_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_dm_all_values_nan(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [f64::NAN, f64::NAN, f64::NAN];
        let low_values = [f64::NAN, f64::NAN, f64::NAN];
        let params = DmParams { period: Some(3) };
        let input = DmInput::from_slices(&high_values, &low_values, params);
        assert!(dm_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_dm_with_slice_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high_values = [9000.0, 9100.0, 9050.0, 9200.0, 9150.0, 9300.0];
        let low_values = [8900.0, 9000.0, 8950.0, 9000.0, 9050.0, 9100.0];
        let params = DmParams { period: Some(2) };
        let input_first = DmInput::from_slices(&high_values, &low_values, params.clone());
        let result_first = dm_with_kernel(&input_first, kernel)?;
        let input_second = DmInput::from_slices(&result_first.plus, &result_first.minus, params);
        let result_second = dm_with_kernel(&input_second, kernel)?;
        assert_eq!(result_second.plus.len(), high_values.len());
        assert_eq!(result_second.minus.len(), high_values.len());
        Ok(())
    }

    fn check_dm_known_values(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let params = DmParams { period: Some(14) };
        let input = DmInput::from_candles(&candles, params);
        let output = dm_with_kernel(&input, kernel)?;
        assert_reference_tail(test_name, &output.plus, &output.minus);
        Ok(())
    }

    // Generates one `#[test]` per kernel for each checker. The tests return the
    // checker's `Result` so that `?`-propagated errors fail the test instead of
    // being discarded. The SIMD variants carry their own `#[cfg]`: an attribute
    // placed before a `$(...)*` repetition would gate only the first emitted item.
    macro_rules! generate_all_dm_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    #[cfg(debug_assertions)]
    fn check_dm_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = [
            DmParams::default(),
            DmParams { period: Some(2) },
            DmParams { period: Some(3) },
            DmParams { period: Some(5) },
            DmParams { period: Some(7) },
            DmParams { period: Some(10) },
            DmParams { period: Some(14) },
            DmParams { period: Some(20) },
            DmParams { period: Some(30) },
            DmParams { period: Some(50) },
            DmParams { period: Some(100) },
            DmParams { period: Some(200) },
            DmParams { period: Some(25) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = DmInput::from_candles(&candles, params.clone());
            let output = dm_with_kernel(&input, kernel)?;

            for (label, series) in [("plus", &output.plus), ("minus", &output.minus)] {
                for (i, &val) in series.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        panic!(
                            "[{}] Found {} poison value {} (0x{:016X}) at index {} in {} array \
                             with params: period={} (param set {})",
                            test_name,
                            source,
                            val,
                            bits,
                            i,
                            label,
                            params.period.unwrap_or(14),
                            param_idx
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_dm_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_dm_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                (100f64..10000f64, 0.01f64..0.05f64, period + 10..400)
                    .prop_flat_map(move |(base_price, volatility, data_len)| {
                        (
                            Just(base_price),
                            Just(volatility),
                            Just(data_len),
                            prop::collection::vec(-1f64..1f64, data_len),
                            prop::collection::vec(0f64..2f64, data_len),
                        )
                    })
                    .prop_map(
                        move |(base_price, volatility, data_len, changes, spreads)| {
                            let mut high = Vec::with_capacity(data_len);
                            let mut low = Vec::with_capacity(data_len);
                            let mut current_price = base_price;

                            for i in 0..data_len {
                                let change = changes[i] * volatility * current_price;
                                current_price = (current_price + change).max(10.0);

                                let spread = current_price * 0.01 * spreads[i];
                                let daily_high = current_price + spread;
                                let daily_low = current_price - spread;

                                high.push(daily_high);
                                low.push(daily_low.max(1.0));
                            }

                            (high, low)
                        },
                    ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |((high, low), period)| {
            let params = DmParams {
                period: Some(period),
            };
            let input = DmInput::from_slices(&high, &low, params);

            let DmOutput {
                plus: out_plus,
                minus: out_minus,
            } = dm_with_kernel(&input, kernel)?;

            let DmOutput {
                plus: ref_plus,
                minus: ref_minus,
            } = dm_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(out_plus.len(), high.len());
            prop_assert_eq!(out_minus.len(), high.len());

            // Generated data has no NaNs, so the warm-up is exactly `period - 1` bars.
            let warmup_period = period - 1;
            for i in 0..warmup_period {
                prop_assert!(
                    out_plus[i].is_nan(),
                    "Plus value at index {} should be NaN during warmup",
                    i
                );
                prop_assert!(
                    out_minus[i].is_nan(),
                    "Minus value at index {} should be NaN during warmup",
                    i
                );
            }

            for i in warmup_period..high.len() {
                if !out_plus[i].is_nan() {
                    prop_assert!(
                        out_plus[i] >= -1e-9,
                        "Plus DM at index {} is negative: {}",
                        i,
                        out_plus[i]
                    );
                }
                if !out_minus[i].is_nan() {
                    prop_assert!(
                        out_minus[i] >= -1e-9,
                        "Minus DM at index {} is negative: {}",
                        i,
                        out_minus[i]
                    );
                }
            }

            const MAX_ULP: i64 = 3;
            for i in 0..high.len() {
                let plus_y = out_plus[i];
                let plus_r = ref_plus[i];
                let minus_y = out_minus[i];
                let minus_r = ref_minus[i];

                if plus_y.is_nan() {
                    prop_assert!(
                        plus_r.is_nan(),
                        "Plus kernel mismatch at {}: {} vs NaN",
                        i,
                        plus_r
                    );
                } else {
                    let plus_ulp_diff = (plus_y.to_bits() as i64)
                        .wrapping_sub(plus_r.to_bits() as i64)
                        .abs();
                    prop_assert!(
                        plus_ulp_diff <= MAX_ULP,
                        "Plus kernel mismatch at {}: {} vs {} (ULP diff: {})",
                        i,
                        plus_y,
                        plus_r,
                        plus_ulp_diff
                    );
                }

                if minus_y.is_nan() {
                    prop_assert!(
                        minus_r.is_nan(),
                        "Minus kernel mismatch at {}: {} vs NaN",
                        i,
                        minus_r
                    );
                } else {
                    let minus_ulp_diff = (minus_y.to_bits() as i64)
                        .wrapping_sub(minus_r.to_bits() as i64)
                        .abs();
                    prop_assert!(
                        minus_ulp_diff <= MAX_ULP,
                        "Minus kernel mismatch at {}: {} vs {} (ULP diff: {})",
                        i,
                        minus_y,
                        minus_r,
                        minus_ulp_diff
                    );
                }
            }

            let all_high_equal = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            let all_low_equal = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

            if all_high_equal && all_low_equal {
                for i in (period * 2).min(high.len() - 1)..high.len() {
                    if !out_plus[i].is_nan() {
                        prop_assert!(
                            out_plus[i].abs() < 1e-6,
                            "Plus DM should be near zero for constant data at {}: {}",
                            i,
                            out_plus[i]
                        );
                    }
                    if !out_minus[i].is_nan() {
                        prop_assert!(
                            out_minus[i].abs() < 1e-6,
                            "Minus DM should be near zero for constant data at {}: {}",
                            i,
                            out_minus[i]
                        );
                    }
                }
            }

            Ok(())
        })?;

        Ok(())
    }

    generate_all_dm_tests!(
        check_dm_partial_params,
        check_dm_default_candles,
        check_dm_with_slice_data,
        check_dm_zero_period,
        check_dm_period_exceeds_data_length,
        check_dm_not_enough_valid_data,
        check_dm_all_values_nan,
        check_dm_with_slice_reinput,
        check_dm_known_values,
        check_dm_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_dm_tests!(check_dm_property);

    // `dm_into` is unavailable on wasm builds; gate the whole test rather than
    // only the call, which would leave it comparing against zero-filled buffers.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_dm_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let n = 256usize;
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut price = 100.0f64;
        for i in 0..n {
            let drift = ((i % 7) as i32 - 3) as f64 * 0.3;
            price = (price + drift).max(1.0);
            let spread = 0.5 + 0.1 * ((i % 5) as f64);
            high.push(price + spread);
            low.push((price - spread).max(0.01));
        }

        let input = DmInput::from_slices(&high, &low, DmParams::default());

        let base = dm(&input)?;

        let mut plus = vec![0.0; n];
        let mut minus = vec![0.0; n];
        dm_into(&input, &mut plus, &mut minus)?;

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            a == b || (a.is_nan() && b.is_nan())
        }

        assert_eq!(base.plus.len(), n);
        assert_eq!(base.minus.len(), n);
        for i in 0..n {
            assert!(
                eq_or_both_nan(base.plus[i], plus[i]),
                "plus mismatch at {}: base={} into={}",
                i,
                base.plus[i],
                plus[i]
            );
            assert!(
                eq_or_both_nan(base.minus[i], minus[i]),
                "minus mismatch at {}: base={} into={}",
                i,
                base.minus[i],
                minus[i]
            );
        }
        Ok(())
    }

    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = DmBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = DmParams::default();
        let (row_plus, row_minus) = output.values_for(&def).expect("default row missing");

        assert_eq!(row_plus.len(), c.high.len());
        assert_eq!(row_minus.len(), c.high.len());

        assert_reference_tail(test, row_plus, row_minus);
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        // (start, end, step) period sweeps, including a single-period sweep (step 0).
        let test_configs = [
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (14, 14, 0),
            (10, 100, 10),
            (100, 200, 50),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = DmBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c)?;

            for (label, matrix) in [("plus", &output.plus), ("minus", &output.minus)] {
                for (idx, &val) in matrix.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }
                    let bits = val.to_bits();
                    if let Some(source) = poison_source(bits) {
                        let row = idx / output.cols;
                        let col = idx % output.cols;
                        let combo = &output.combos[row];
                        panic!(
                            "[{}] Config {}: Found {} poison value {} (0x{:016X}) in {} \
                             at row {} col {} (flat index {}) with params: period={}",
                            test,
                            cfg_idx,
                            source,
                            val,
                            bits,
                            label,
                            row,
                            col,
                            idx,
                            combo.period.unwrap_or(14)
                        );
                    }
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Batch counterpart of `generate_all_dm_tests!`; the generated tests likewise
    // return the checker's `Result` so errors are not silently discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
2053
2054#[cfg(feature = "python")]
2055use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
2056#[cfg(feature = "python")]
2057use pyo3::exceptions::PyValueError;
2058#[cfg(feature = "python")]
2059use pyo3::prelude::*;
2060#[cfg(feature = "python")]
2061use pyo3::types::PyDict;
2062
// Python binding for single-period DM: returns `(plus, minus)` as new 1-D
// float64 NumPy arrays of the input length. The GIL is released while the
// kernel fills the arrays.
#[cfg(feature = "python")]
#[pyfunction(name = "dm")]
#[pyo3(signature = (high, low, period, kernel=None))]
pub fn dm_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (h, l) = (high.as_slice()?, low.as_slice()?);
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }
    let len = h.len();

    let input = DmInput::from_slices(h, l, DmParams { period: Some(period) });
    let kern = validate_kernel(kernel, false)?;

    // The arrays start uninitialised and are handed back only if
    // `dm_into_slice` succeeds. NOTE(review): this relies on it writing every
    // element, warm-up prefix included — confirm.
    let out_plus = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let out_minus = unsafe { PyArray1::<f64>::new(py, [len], false) };
    let plus_view = unsafe { out_plus.as_slice_mut()? };
    let minus_view = unsafe { out_minus.as_slice_mut()? };

    match py.allow_threads(|| dm_into_slice(plus_view, minus_view, &input, kern)) {
        Ok(_) => Ok((out_plus, out_minus)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2095
/// Python binding for the DM period sweep.
///
/// Returns a dict with `plus` and `minus` as `(rows, cols)` float64 arrays (one
/// row per period, one column per bar) and `periods` as a uint64 array giving
/// each row's period. Raises `ValueError` on a high/low length mismatch, an
/// invalid kernel name, or any computation error.
#[cfg(feature = "python")]
#[pyfunction(name = "dm_batch")]
#[pyo3(signature = (high, low, period_range, kernel=None))]
pub fn dm_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }

    let sweep = DmBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let output = py
        .allow_threads(|| dm_batch_with_kernel(h, l, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // `from_vec` takes ownership of the row-major buffer without copying and
    // `reshape` views it as (rows, cols). Both are safe calls, so no `unsafe`
    // block is needed.
    let plus = PyArray1::from_vec(py, output.plus).reshape((output.rows, output.cols))?;
    let minus = PyArray1::from_vec(py, output.minus).reshape((output.rows, output.cols))?;

    let dict = PyDict::new(py);
    dict.set_item("plus", plus)?;
    dict.set_item("minus", minus)?;
    dict.set_item(
        "periods",
        output
            .combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2139
/// Python binding: runs the DM period sweep on a CUDA device and returns the
/// `(plus, minus)` results as device-resident f32 arrays that share one CUDA
/// context.
///
/// Raises `ValueError` when CUDA is unavailable, when `high_f32` and `low_f32`
/// differ in length, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dm_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, period_range, device_id=0))]
pub fn dm_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    // Same validation as the CPU bindings: reject mismatched inputs before
    // creating a context or uploading anything to the device.
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }
    let sweep = DmBatchRange {
        period: period_range,
    };
    let (pair, ctx, dev) = py.allow_threads(|| {
        let cuda = CudaDm::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        cuda.dm_batch_dev(h, l, &sweep)
            .map(|(pair, _)| (pair, ctx, dev))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    // Both handles hold the context so the device buffers outlive this call.
    Ok((
        DeviceArrayF32Py {
            inner: pair.plus,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev),
        },
        DeviceArrayF32Py {
            inner: pair.minus,
            _ctx: Some(ctx),
            device_id: Some(dev),
        },
    ))
}
2180
/// Python binding: computes DM for one `period` across many series laid out
/// time-major (`cols` series x `rows` bars, as interpreted by `CudaDm`) and
/// returns `(plus, minus)` as device-resident f32 arrays sharing one CUDA
/// context.
///
/// Raises `ValueError` when CUDA is unavailable, when `high_tm_f32` and
/// `low_tm_f32` differ in length, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dm_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, cols, rows, period, device_id=0))]
pub fn dm_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    // Same validation as the CPU bindings: reject mismatched inputs before
    // creating a context or uploading anything to the device.
    if h.len() != l.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }
    let (pair, ctx, dev) = py.allow_threads(|| {
        let cuda = CudaDm::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        cuda.dm_many_series_one_param_time_major_dev(h, l, cols, rows, period)
            .map(|pair| (pair, ctx, dev))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    // Both handles hold the context so the device buffers outlive this call.
    Ok((
        DeviceArrayF32Py {
            inner: pair.plus,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev),
        },
        DeviceArrayF32Py {
            inner: pair.minus,
            _ctx: Some(ctx),
            device_id: Some(dev),
        },
    ))
}
2220
// Python-facing wrapper (exposed as `DmStream`) around the incremental
// `DmStream` calculator; the methods are defined in the `#[pymethods]` impl.
#[cfg(feature = "python")]
#[pyclass(name = "DmStream")]
pub struct DmStreamPy {
    // Underlying Rust streaming state, updated one (high, low) bar at a time.
    stream: DmStream,
}
2226
#[cfg(feature = "python")]
#[pymethods]
impl DmStreamPy {
    // Python constructor: `DmStream(period)`. Construction errors from
    // `DmStream::try_new` are surfaced as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = DmParams {
            period: Some(period),
        };
        match DmStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one bar; returns whatever `DmStream::update` yields, i.e. `None`
    // or a `(plus, minus)` pair.
    fn update(&mut self, high: f64, low: f64) -> Option<(f64, f64)> {
        let stream = &mut self.stream;
        stream.update(high, low)
    }
}
2242
2243#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2244use serde::{Deserialize, Serialize};
2245#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2246use wasm_bindgen::prelude::*;
2247
/// Serialized result of the single-period WASM `dm` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmJsOutput {
    /// Flattened row-major matrix: the +DM row followed by the -DM row.
    pub values: Vec<f64>,
    /// Number of rows in `values` (set to 2 by `dm_js`).
    pub rows: usize,
    /// Number of columns, i.e. the input length.
    pub cols: usize,
}
2255
// WASM binding for single-period DM. Returns a serialized `DmJsOutput` whose
// `values` holds the +DM row followed by the -DM row (2 x input length).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm)]
pub fn dm_js(high: &[f64], low: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let len = high.len();
    if len != low.len() {
        return Err(JsValue::from_str("length mismatch"));
    }
    let params = DmParams {
        period: Some(period),
    };
    let input = DmInput::from_slices(high, low, params);

    let mut plus = vec![0.0; len];
    let mut minus = vec![0.0; len];
    if let Err(e) = dm_into_slice(&mut plus, &mut minus, &input, detect_best_kernel()) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let output = DmJsOutput {
        values: [plus, minus].concat(),
        rows: 2,
        cols: len,
    };
    serde_wasm_bindgen::to_value(&output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2286
/// JS-supplied configuration for the `dm_batch` wasm binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmBatchConfig {
    // Copied verbatim into `DmBatchRange::period`; presumably (start, end, step) — confirm
    // against the batch grid expansion.
    pub period_range: (usize, usize, usize),
}
2292
/// Serialized result of the `dm_batch` wasm binding.
///
/// `values` is a flat, row-major matrix of `rows x cols`: every +DM row (one per
/// parameter combination) comes first, followed by every -DM row in the same order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DmBatchJsOutput {
    // All +DM rows, then all -DM rows.
    pub values: Vec<f64>,
    // Total row count: twice the batch's row count (the +DM block plus the -DM block).
    pub rows: usize,
    // Values per row (the input series length, as reported by the batch output).
    pub cols: usize,
    // Period of each parameter combination, in row order within each block.
    pub periods: Vec<usize>,
}
2301
/// wasm binding: evaluates DM for every period in `config.period_range`.
///
/// `config` must deserialize into [`DmBatchConfig`]. The result is a serialized
/// [`DmBatchJsOutput`] whose `values` stack all +DM rows above all -DM rows.
///
/// # Errors
/// Returns a JS string error on a `high`/`low` length mismatch, an undecodable
/// config, a failure inside `dm_batch_inner`, or a serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm_batch)]
pub fn dm_batch_unified_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    if high.len() != low.len() {
        return Err(JsValue::from_str("length mismatch"));
    }
    let DmBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = DmBatchRange {
        period: period_range,
    };
    let batch = dm_batch_inner(high, low, &range, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // +DM block first, then -DM block, in one exactly-sized allocation.
    let values: Vec<f64> = [&batch.plus[..], &batch.minus[..]].concat();
    let periods: Vec<usize> = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap())
        .collect();

    let payload = DmBatchJsOutput {
        values,
        rows: batch.rows * 2,
        cols: batch.cols,
        periods,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2336
/// Allocates an uninitialized buffer with capacity for `len` `f64`s in wasm memory
/// and hands its pointer to JS.
///
/// The buffer is intentionally leaked here; release it with `dm_free` using the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dm_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
/// Releases a buffer previously obtained from `dm_alloc`.
///
/// `len` must be the exact value passed to `dm_alloc` for this pointer, and the
/// buffer must not have been freed already. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dm_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `Vec::<f64>::with_capacity(len)`
    // (which allocates exactly `len` slots), so rebuilding a Vec with capacity `len`
    // matches the original allocation. The length is 0 rather than `len` because the JS
    // side may never have written every slot; claiming `len` initialized elements would
    // be undefined behavior. `f64` has no destructor, so only the capacity matters here.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2352
/// wasm binding: computes +DM/-DM directly into caller-provided buffers.
///
/// All four pointers must address `len` `f64`s in wasm memory (e.g. from `dm_alloc`).
/// An output buffer may alias an input buffer (in-place use). In that case the result
/// is computed into temporaries first and then copied out, so the inputs are never
/// read after being overwritten. The two output buffers must not overlap each other.
///
/// # Errors
/// Returns a JS string error for any null pointer, for overlapping `plus`/`minus`
/// buffers, or when `dm_into_slice` rejects the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dm_into)]
pub fn dm_into_js(
    high_ptr: *const f64,
    low_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    /// True when two `len`-element f64 buffers starting at `a` and `b` share any byte.
    fn overlaps(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        bytes != 0 && a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if high_ptr.is_null() || low_ptr.is_null() || plus_ptr.is_null() || minus_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    // Two overlapping `&mut` slices would be undefined behavior, and there is no
    // meaningful result to leave in a shared buffer.
    if overlaps(plus_ptr, minus_ptr, len) {
        return Err(JsValue::from_str("plus and minus output buffers overlap"));
    }
    // In-place use (an output aliasing an input) must not create a `&mut` slice over
    // memory that is also borrowed as `&[f64]`.
    let aliased = overlaps(plus_ptr, high_ptr, len)
        || overlaps(plus_ptr, low_ptr, len)
        || overlaps(minus_ptr, high_ptr, len)
        || overlaps(minus_ptr, low_ptr, len);

    // SAFETY: the JS caller guarantees each (non-null, checked above) pointer addresses
    // `len` initialized f64s that remain valid for this call. Only shared borrows of the
    // inputs are created here.
    let (h, l) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
        )
    };
    let input = DmInput::from_slices(
        h,
        l,
        DmParams {
            period: Some(period),
        },
    );

    if aliased {
        let mut plus_tmp = vec![0.0; len];
        let mut minus_tmp = vec![0.0; len];
        dm_into_slice(&mut plus_tmp, &mut minus_tmp, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the outputs are valid for `len` writes (caller contract). The temporaries
        // are fresh allocations, so they cannot overlap the destinations. The input
        // borrows `h`/`l` are not used past the computation above, so overwriting aliased
        // memory through the raw pointers does not invalidate any later read.
        unsafe {
            std::ptr::copy_nonoverlapping(plus_tmp.as_ptr(), plus_ptr, len);
            std::ptr::copy_nonoverlapping(minus_tmp.as_ptr(), minus_ptr, len);
        }
        Ok(())
    } else {
        // SAFETY: neither output overlaps an input or the other output (checked above), so
        // these exclusive borrows do not alias any other live reference.
        let (plus, minus) = unsafe {
            (
                std::slice::from_raw_parts_mut(plus_ptr, len),
                std::slice::from_raw_parts_mut(minus_ptr, len),
            )
        };
        dm_into_slice(plus, minus, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}