1#[cfg(all(feature = "python", feature = "cuda"))]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::convert::AsRef;
30use std::error::Error;
31use thiserror::Error;
32
/// Input source for the Directional Indicator (+DI / -DI).
#[derive(Debug, Clone)]
pub enum DiData<'a> {
    /// Read the "high", "low" and "close" series from a `Candles` set.
    Candles {
        candles: &'a Candles,
    },
    /// Caller-supplied series. Lengths are not checked here; they are
    /// validated when the indicator runs.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}

/// +DI and -DI series produced by the one-shot DI functions.
#[derive(Debug, Clone)]
pub struct DiOutput {
    /// +DI values.
    pub plus: Vec<f64>,
    /// -DI values.
    pub minus: Vec<f64>,
}

/// DI parameters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DiParams {
    /// Smoothing period; `None` is treated as 14 by the consumers of this type.
    pub period: Option<usize>,
}

impl Default for DiParams {
    /// Period 14.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}

/// A DI data source together with its parameters.
#[derive(Debug, Clone)]
pub struct DiInput<'a> {
    pub data: DiData<'a>,
    pub params: DiParams,
}
71
72impl<'a> DiInput<'a> {
73 #[inline(always)]
74 pub fn from_candles(candles: &'a Candles, params: DiParams) -> Self {
75 Self {
76 data: DiData::Candles { candles },
77 params,
78 }
79 }
80 #[inline(always)]
81 pub fn from_slices(
82 high: &'a [f64],
83 low: &'a [f64],
84 close: &'a [f64],
85 params: DiParams,
86 ) -> Self {
87 Self {
88 data: DiData::Slices { high, low, close },
89 params,
90 }
91 }
92 #[inline(always)]
93 pub fn with_default_candles(candles: &'a Candles) -> Self {
94 Self {
95 data: DiData::Candles { candles },
96 params: DiParams::default(),
97 }
98 }
99 #[inline(always)]
100 pub fn get_period(&self) -> usize {
101 self.params.period.unwrap_or(14)
102 }
103 #[inline(always)]
104 pub fn as_slices(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
105 match &self.data {
106 DiData::Candles { candles } => (
107 source_type(candles, "high"),
108 source_type(candles, "low"),
109 source_type(candles, "close"),
110 ),
111 DiData::Slices { high, low, close } => (*high, *low, *close),
112 }
113 }
114}
115
/// Fluent builder for a one-shot DI computation or a `DiStream`.
///
/// An unset period is passed through as `None`; `DiInput::get_period` and
/// `DiStream::try_new` then fall back to 14.
#[derive(Copy, Clone, Debug)]
pub struct DiBuilder {
    period: Option<usize>,
    kernel: Kernel,
}

impl Default for DiBuilder {
    /// No explicit period and `Kernel::Auto`.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
130
131impl DiBuilder {
132 #[inline(always)]
133 pub fn new() -> Self {
134 Self::default()
135 }
136 #[inline(always)]
137 pub fn period(mut self, n: usize) -> Self {
138 self.period = Some(n);
139 self
140 }
141 #[inline(always)]
142 pub fn kernel(mut self, k: Kernel) -> Self {
143 self.kernel = k;
144 self
145 }
146 #[inline(always)]
147 pub fn apply(self, candles: &Candles) -> Result<DiOutput, DiError> {
148 let params = DiParams {
149 period: self.period,
150 };
151 let input = DiInput::from_candles(candles, params);
152 di_with_kernel(&input, self.kernel)
153 }
154 #[inline(always)]
155 pub fn apply_slices(
156 self,
157 high: &[f64],
158 low: &[f64],
159 close: &[f64],
160 ) -> Result<DiOutput, DiError> {
161 let params = DiParams {
162 period: self.period,
163 };
164 let input = DiInput::from_slices(high, low, close, params);
165 di_with_kernel(&input, self.kernel)
166 }
167 #[inline(always)]
168 pub fn into_stream(self) -> Result<DiStream, DiError> {
169 let params = DiParams {
170 period: self.period,
171 };
172 DiStream::try_new(params)
173 }
174}
175
/// Errors produced by the DI indicator, its batch sweep and its stream.
#[derive(Debug, Error)]
pub enum DiError {
    /// Input series are empty, or high/low/close lengths differ.
    #[error("di: Empty data provided.")]
    EmptyInputData,
    /// NOTE(review): not raised by the functions visible alongside this enum;
    /// confirm whether other code still constructs it.
    #[error("di: Empty data provided for DI.")]
    EmptyData,
    /// `period` is zero or exceeds the data length. The stream constructor
    /// reports `data_len: 0` because it has no data yet.
    #[error("di: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` bars remain from the first bar whose high, low and
    /// close are all non-NaN.
    #[error("di: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// No bar has high, low and close all non-NaN.
    #[error("di: All values are NaN.")]
    AllValuesNaN,
    /// A caller-provided output buffer has the wrong length.
    #[error("di: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// The period sweep expanded to nothing, or `rows * cols` overflowed.
    #[error("di: Invalid range expansion: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to `di_batch_with_kernel`.
    #[error("di: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
199
/// Computes +DI / -DI with automatic kernel selection (`Kernel::Auto`).
///
/// # Errors
/// Any validation error reported by `di_with_kernel`.
#[inline]
pub fn di(input: &DiInput) -> Result<DiOutput, DiError> {
    di_with_kernel(input, Kernel::Auto)
}
204
205#[inline(always)]
206fn di_prepare<'a>(
207 input: &'a DiInput<'a>,
208 kernel: Kernel,
209) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), DiError> {
210 let (high, low, close) = match &input.data {
211 DiData::Candles { candles } => {
212 let h = source_type(candles, "high");
213 let l = source_type(candles, "low");
214 let c = source_type(candles, "close");
215 (h, l, c)
216 }
217 DiData::Slices { high, low, close } => (*high, *low, *close),
218 };
219 let n = high.len();
220 if n == 0 || low.len() != n || close.len() != n {
221 return Err(DiError::EmptyInputData);
222 }
223 let period = input.get_period();
224 if period == 0 || period > n {
225 return Err(DiError::InvalidPeriod {
226 period,
227 data_len: n,
228 });
229 }
230 let first_valid_idx =
231 (0..n).find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()));
232 let first_idx = match first_valid_idx {
233 Some(idx) => idx,
234 None => return Err(DiError::AllValuesNaN),
235 };
236 if (n - first_idx) < period {
237 return Err(DiError::NotEnoughValidData {
238 needed: period,
239 valid: n - first_idx,
240 });
241 }
242
243 let chosen = match kernel {
244 Kernel::Auto => Kernel::Scalar,
245 other => other,
246 };
247
248 Ok((high, low, close, period, first_idx, chosen))
249}
250
251pub fn di_with_kernel(input: &DiInput, kernel: Kernel) -> Result<DiOutput, DiError> {
252 let (high, low, close, period, first_idx, chosen) = di_prepare(input, kernel)?;
253
254 unsafe {
255 match chosen {
256 Kernel::Scalar | Kernel::ScalarBatch => di_scalar(high, low, close, period, first_idx),
257 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
258 Kernel::Avx2 | Kernel::Avx2Batch => di_avx2(high, low, close, period, first_idx),
259 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260 Kernel::Avx512 | Kernel::Avx512Batch => di_avx512(high, low, close, period, first_idx),
261 _ => unreachable!(),
262 }
263 }
264}
265
/// AVX2-named entry point for builds without `nightly-avx` on x86_64.
/// It runs the scalar kernel.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
#[inline(always)]
pub unsafe fn di_avx2_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
}

/// AVX-512-named entry point for builds without `nightly-avx` on x86_64.
/// It runs the scalar kernel.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
#[inline(always)]
pub unsafe fn di_avx512_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
}
293
294#[inline(always)]
295fn di_compute_into(
296 high: &[f64],
297 low: &[f64],
298 close: &[f64],
299 period: usize,
300 first_idx: usize,
301 kernel: Kernel,
302 out_plus: &mut [f64],
303 out_minus: &mut [f64],
304) {
305 unsafe {
306 match kernel {
307 Kernel::Scalar | Kernel::ScalarBatch => {
308 di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
309 }
310 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
311 Kernel::Avx2 | Kernel::Avx2Batch => {
312 di_avx2_into(high, low, close, period, first_idx, out_plus, out_minus)
313 }
314 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315 Kernel::Avx512 | Kernel::Avx512Batch => {
316 di_avx512_into(high, low, close, period, first_idx, out_plus, out_minus)
317 }
318 _ => unreachable!(),
319 }
320 }
321}
322
323pub fn di_into_slice(
324 dst_plus: &mut [f64],
325 dst_minus: &mut [f64],
326 input: &DiInput,
327 kern: Kernel,
328) -> Result<(), DiError> {
329 let (high, low, close, period, first_idx, chosen) = di_prepare(input, kern)?;
330
331 let n = high.len();
332 if dst_plus.len() != n || dst_minus.len() != n {
333 return Err(DiError::OutputLengthMismatch {
334 expected: n,
335 got: dst_plus.len().min(dst_minus.len()),
336 });
337 }
338
339 di_compute_into(
340 high, low, close, period, first_idx, chosen, dst_plus, dst_minus,
341 );
342
343 let warmup_end = first_idx + period - 1;
344 for v in &mut dst_plus[..warmup_end] {
345 *v = f64::from_bits(0x7ff8_0000_0000_0000);
346 }
347 for v in &mut dst_minus[..warmup_end] {
348 *v = f64::from_bits(0x7ff8_0000_0000_0000);
349 }
350
351 Ok(())
352}
353
354#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
355pub fn di_into(
356 input: &DiInput,
357 out_plus: &mut [f64],
358 out_minus: &mut [f64],
359) -> Result<(), DiError> {
360 let (high, low, close, period, first_idx, chosen) = di_prepare(input, Kernel::Auto)?;
361
362 let n = high.len();
363 if out_plus.len() != n || out_minus.len() != n {
364 return Err(DiError::OutputLengthMismatch {
365 expected: n,
366 got: out_plus.len().min(out_minus.len()),
367 });
368 }
369
370 di_compute_into(
371 high, low, close, period, first_idx, chosen, out_plus, out_minus,
372 );
373
374 let warmup_end = first_idx + period - 1;
375 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
376 for v in &mut out_plus[..warmup_end] {
377 *v = qnan;
378 }
379 for v in &mut out_minus[..warmup_end] {
380 *v = qnan;
381 }
382
383 Ok(())
384}
385
/// Scalar +DI / -DI kernel writing into caller-provided buffers.
///
/// The computation runs in three steps:
/// 1. Seed plain sums of +DM, -DM and true range over bars `first_idx + 1` up
///    to (excluding) `first_idx + period`.
/// 2. Emit the first values at index `first_idx + period - 1`.
/// 3. For every later bar, apply Wilder smoothing
///    (`sum = sum * (1 - 1/period) + x`).
///
/// Each output is `100 * smoothed_dm / smoothed_tr`, or 0 when the smoothed TR
/// is exactly 0. Entries before `first_idx + period - 1` are left untouched.
///
/// # Safety
/// Nothing inside is `unsafe`, and all indexing is bounds-checked. Callers are
/// expected to pass:
/// - `period >= 1` and `first_idx + period <= high.len()`;
/// - `low` and `close` at least as long as `high`;
/// - outputs at least `high.len()` long.
///
/// Otherwise this may panic or produce meaningless values.
#[inline(always)]
pub unsafe fn di_scalar_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    /// +DM, -DM and true range of a bar relative to the previous bar.
    #[inline(always)]
    fn bar_terms(h: f64, l: f64, ph: f64, pl: f64, pc: f64) -> (f64, f64, f64) {
        let up_move = h - ph;
        let down_move = pl - l;
        let plus_dm = if up_move > down_move && up_move > 0.0 {
            up_move
        } else {
            0.0
        };
        let minus_dm = if down_move > up_move && down_move > 0.0 {
            down_move
        } else {
            0.0
        };
        let mut true_range = h - l;
        let gap_high = (h - pc).abs();
        let gap_low = (l - pc).abs();
        if gap_high > true_range {
            true_range = gap_high;
        }
        if gap_low > true_range {
            true_range = gap_low;
        }
        (plus_dm, minus_dm, true_range)
    }

    let len = high.len();
    if len == 0 {
        return;
    }

    let decay = 1.0 - (period as f64).recip();
    let seed_end = first_idx + period;

    let (mut ph, mut pl, mut pc) = (high[first_idx], low[first_idx], close[first_idx]);
    let (mut sm_plus, mut sm_minus, mut sm_tr) = (0.0_f64, 0.0_f64, 0.0_f64);

    // Seed phase: plain sums. Adding a +0.0 term is a no-op because the sums
    // start at +0.0 and only ever grow.
    for j in (first_idx + 1)..seed_end {
        let (h, l, c) = (high[j], low[j], close[j]);
        let (p, m, t) = bar_terms(h, l, ph, pl, pc);
        sm_plus += p;
        sm_minus += m;
        sm_tr += t;
        ph = h;
        pl = l;
        pc = c;
    }

    let seed_idx = seed_end - 1;
    let k = if sm_tr == 0.0 { 0.0 } else { 100.0 / sm_tr };
    out_plus[seed_idx] = sm_plus * k;
    out_minus[seed_idx] = sm_minus * k;

    // Smoothing phase: Wilder recursion for every remaining bar.
    for j in seed_end..len {
        let (h, l, c) = (high[j], low[j], close[j]);
        let (p, m, t) = bar_terms(h, l, ph, pl, pc);
        sm_plus = sm_plus.mul_add(decay, p);
        sm_minus = sm_minus.mul_add(decay, m);
        sm_tr = sm_tr.mul_add(decay, t);
        let k = if sm_tr == 0.0 { 0.0 } else { 100.0 / sm_tr };
        out_plus[j] = sm_plus * k;
        out_minus[j] = sm_minus * k;
        ph = h;
        pl = l;
        pc = c;
    }
}
491
492#[inline(always)]
493pub unsafe fn di_scalar(
494 high: &[f64],
495 low: &[f64],
496 close: &[f64],
497 period: usize,
498 first_idx: usize,
499) -> Result<DiOutput, DiError> {
500 let n = high.len();
501 let mut plus_di = alloc_with_nan_prefix(n, first_idx + period - 1);
502 let mut minus_di = alloc_with_nan_prefix(n, first_idx + period - 1);
503 di_scalar_into(
504 high,
505 low,
506 close,
507 period,
508 first_idx,
509 &mut plus_di,
510 &mut minus_di,
511 );
512 Ok(DiOutput {
513 plus: plus_di,
514 minus: minus_di,
515 })
516}
517
/// AVX2 entry point; currently delegates to the scalar kernel.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx2_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
}

/// AVX2 allocating entry point; currently delegates to `di_scalar`.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    di_scalar(high, low, close, period, first_idx)
}

/// AVX-512 entry point; currently delegates to the scalar kernel.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
    out_plus: &mut [f64],
    out_minus: &mut [f64],
) {
    di_scalar_into(high, low, close, period, first_idx, out_plus, out_minus)
}

/// AVX-512 allocating entry point; currently delegates to `di_scalar`.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    di_scalar(high, low, close, period, first_idx)
}

/// Short-period AVX-512 variant; currently identical to `di_avx512`.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    di_avx512(high, low, close, period, first_idx)
}

/// Long-period AVX-512 variant; currently identical to `di_avx512`.
///
/// # Safety
/// Same preconditions as `di_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn di_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_idx: usize,
) -> Result<DiOutput, DiError> {
    di_avx512(high, low, close, period, first_idx)
}
593
/// Incremental +DI / -DI calculator fed one bar at a time.
///
/// `update` returns `None` for the first bar (which only seeds the previous-bar
/// values) and while fewer than `period` bars have been seen. After that it
/// returns `Some((plus_di, minus_di))` for every bar. A bar containing NaN
/// resets all state.
#[derive(Debug, Clone)]
pub struct DiStream {
    /// Smoothing period; `try_new` rejects 0.
    period: usize,
    /// Wilder decay factor `1 - 1/period`.
    keep: f64,

    // Previous bar's high/low/close; meaningful only once `have_prev` is true.
    prev_h: f64,
    prev_l: f64,
    prev_c: f64,
    have_prev: bool,

    // Plain sums of +DM, -DM and true range accumulated during warmup, and the
    // number of bars seen so far (including the seeding bar).
    warm_plus: f64,
    warm_minus: f64,
    warm_tr: f64,
    warm_count: usize,

    // Wilder-smoothed sums, valid once `warmed` is true.
    cur_plus: f64,
    cur_minus: f64,
    cur_tr: f64,
    warmed: bool,
}
614
615impl DiStream {
616 pub fn try_new(params: DiParams) -> Result<Self, DiError> {
617 let period = params.period.unwrap_or(14);
618 if period == 0 {
619 return Err(DiError::InvalidPeriod {
620 period,
621 data_len: 0,
622 });
623 }
624 let pf = period as f64;
625 Ok(Self {
626 period,
627 keep: 1.0 - pf.recip(),
628
629 prev_h: f64::NAN,
630 prev_l: f64::NAN,
631 prev_c: f64::NAN,
632 have_prev: false,
633
634 warm_plus: 0.0,
635 warm_minus: 0.0,
636 warm_tr: 0.0,
637 warm_count: 0,
638
639 cur_plus: 0.0,
640 cur_minus: 0.0,
641 cur_tr: 0.0,
642 warmed: false,
643 })
644 }
645
646 #[inline(always)]
647 fn reset(&mut self) {
648 self.prev_h = f64::NAN;
649 self.prev_l = f64::NAN;
650 self.prev_c = f64::NAN;
651 self.have_prev = false;
652
653 self.warm_plus = 0.0;
654 self.warm_minus = 0.0;
655 self.warm_tr = 0.0;
656 self.warm_count = 0;
657
658 self.cur_plus = 0.0;
659 self.cur_minus = 0.0;
660 self.cur_tr = 0.0;
661 self.warmed = false;
662 }
663
664 #[inline(always)]
665 fn tr_fast(high: f64, low: f64, prev_close: f64) -> f64 {
666 let hi = if high > prev_close { high } else { prev_close };
667 let lo = if low < prev_close { low } else { prev_close };
668 hi - lo
669 }
670
671 #[inline(always)]
672 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
673 if high.is_nan() || low.is_nan() || close.is_nan() {
674 self.reset();
675 return None;
676 }
677
678 if !self.have_prev {
679 self.prev_h = high;
680 self.prev_l = low;
681 self.prev_c = close;
682 self.have_prev = true;
683 self.warm_count = 1;
684 return None;
685 }
686
687 let dp = high - self.prev_h;
688 let dm = self.prev_l - low;
689
690 let inc_p = if dp > dm && dp > 0.0 { dp } else { 0.0 };
691 let inc_m = if dm > dp && dm > 0.0 { dm } else { 0.0 };
692 let tr = Self::tr_fast(high, low, self.prev_c);
693
694 self.prev_h = high;
695 self.prev_l = low;
696 self.prev_c = close;
697
698 if !self.warmed {
699 self.warm_plus += inc_p;
700 self.warm_minus += inc_m;
701 self.warm_tr += tr;
702 self.warm_count += 1;
703
704 if self.warm_count < self.period {
705 return None;
706 }
707
708 self.cur_plus = self.warm_plus;
709 self.cur_minus = self.warm_minus;
710 self.cur_tr = self.warm_tr;
711 self.warmed = true;
712
713 let scale = if self.cur_tr == 0.0 {
714 0.0
715 } else {
716 100.0 / self.cur_tr
717 };
718 return Some((self.cur_plus * scale, self.cur_minus * scale));
719 }
720
721 self.cur_plus = self.cur_plus.mul_add(self.keep, inc_p);
722 self.cur_minus = self.cur_minus.mul_add(self.keep, inc_m);
723 self.cur_tr = self.cur_tr.mul_add(self.keep, tr);
724
725 let scale = if self.cur_tr == 0.0 {
726 0.0
727 } else {
728 100.0 / self.cur_tr
729 };
730 Some((self.cur_plus * scale, self.cur_minus * scale))
731 }
732}
733
/// Period sweep for batch DI, given as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct DiBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DiBatchRange {
    /// Sweep `(14, 263, 1)`.
    fn default() -> Self {
        Self {
            period: (14, 263, 1),
        }
    }
}

/// Builder for batch DI over a period sweep.
///
/// Defaults to `DiBatchRange::default()` and `Kernel::default()`.
#[derive(Clone, Debug, Default)]
pub struct DiBatchBuilder {
    range: DiBatchRange,
    kernel: Kernel,
}
752
753impl DiBatchBuilder {
754 pub fn new() -> Self {
755 Self::default()
756 }
757 pub fn kernel(mut self, k: Kernel) -> Self {
758 self.kernel = k;
759 self
760 }
761 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
762 self.range.period = (start, end, step);
763 self
764 }
765 pub fn period_static(mut self, p: usize) -> Self {
766 self.range.period = (p, p, 0);
767 self
768 }
769 pub fn apply_slices(
770 self,
771 high: &[f64],
772 low: &[f64],
773 close: &[f64],
774 ) -> Result<DiBatchOutput, DiError> {
775 di_batch_with_kernel(high, low, close, &self.range, self.kernel)
776 }
777 pub fn apply_candles(self, c: &Candles) -> Result<DiBatchOutput, DiError> {
778 let h = source_type(c, "high");
779 let l = source_type(c, "low");
780 let cl = source_type(c, "close");
781 self.apply_slices(h, l, cl)
782 }
783 pub fn with_default_candles(c: &Candles) -> Result<DiBatchOutput, DiError> {
784 DiBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
785 }
786}
787
788pub fn di_batch_with_kernel(
789 high: &[f64],
790 low: &[f64],
791 close: &[f64],
792 sweep: &DiBatchRange,
793 k: Kernel,
794) -> Result<DiBatchOutput, DiError> {
795 let kernel = match k {
796 Kernel::Auto => detect_best_batch_kernel(),
797 other if other.is_batch() => other,
798 other => return Err(DiError::InvalidKernelForBatch(other)),
799 };
800 let simd = match kernel {
801 Kernel::Avx512Batch => Kernel::Avx512,
802 Kernel::Avx2Batch => Kernel::Avx2,
803 Kernel::ScalarBatch => Kernel::Scalar,
804 _ => unreachable!(),
805 };
806 di_batch_par_slice(high, low, close, sweep, simd)
807}
808
/// Result of a batch DI sweep.
#[derive(Clone, Debug)]
pub struct DiBatchOutput {
    /// +DI values, row-major: row `r` (length `cols`) belongs to `combos[r]`.
    pub plus: Vec<f64>,
    /// -DI values, same layout as `plus`.
    pub minus: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<DiParams>,
    /// Number of rows (one per swept period).
    pub rows: usize,
    /// Number of columns (the input series length).
    pub cols: usize,
}
817impl DiBatchOutput {
818 pub fn row_for_params(&self, p: &DiParams) -> Option<usize> {
819 self.combos
820 .iter()
821 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
822 }
823 pub fn plus_for(&self, p: &DiParams) -> Option<&[f64]> {
824 self.row_for_params(p).map(|row| {
825 let start = row * self.cols;
826 &self.plus[start..start + self.cols]
827 })
828 }
829 pub fn minus_for(&self, p: &DiParams) -> Option<&[f64]> {
830 self.row_for_params(p).map(|row| {
831 let start = row * self.cols;
832 &self.minus[start..start + self.cols]
833 })
834 }
835}
836
837#[inline(always)]
838fn expand_grid(r: &DiBatchRange) -> Vec<DiParams> {
839 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
840 if step == 0 || start == end {
841 return vec![start];
842 }
843 if start < end {
844 return (start..=end).step_by(step.max(1)).collect();
845 }
846
847 let mut v = Vec::new();
848 let mut cur = start;
849 while cur >= end {
850 v.push(cur);
851 if cur == end {
852 break;
853 }
854 cur = cur.saturating_sub(step.max(1));
855 if cur < end {
856 break;
857 }
858 }
859 v
860 }
861 let periods = axis_usize(r.period);
862 let mut out = Vec::with_capacity(periods.len());
863 for &p in &periods {
864 out.push(DiParams { period: Some(p) });
865 }
866 out
867}
868
/// Batch DI over `sweep`, computing rows sequentially
/// (`di_batch_inner` with `parallel = false`).
#[inline(always)]
pub fn di_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &DiBatchRange,
    kern: Kernel,
) -> Result<DiBatchOutput, DiError> {
    di_batch_inner(high, low, close, sweep, kern, false)
}

/// Batch DI over `sweep` with `parallel = true`. Rows are distributed with
/// rayon on non-wasm32 targets and processed sequentially on wasm32.
#[inline(always)]
pub fn di_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &DiBatchRange,
    kern: Kernel,
) -> Result<DiBatchOutput, DiError> {
    di_batch_inner(high, low, close, sweep, kern, true)
}
890
891#[inline(always)]
892fn di_batch_inner_into(
893 high: &[f64],
894 low: &[f64],
895 close: &[f64],
896 sweep: &DiBatchRange,
897 kern: Kernel,
898 parallel: bool,
899 out_plus: &mut [f64],
900 out_minus: &mut [f64],
901) -> Result<Vec<DiParams>, DiError> {
902 let combos = expand_grid(sweep);
903 if combos.is_empty() {
904 let (s, e, st) = sweep.period;
905 return Err(DiError::InvalidRange {
906 start: s,
907 end: e,
908 step: st,
909 });
910 }
911 let n = high.len();
912 if n == 0 || low.len() != n || close.len() != n {
913 return Err(DiError::EmptyInputData);
914 }
915 let first = (0..n)
916 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
917 .ok_or(DiError::AllValuesNaN)?;
918 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
919 if n - first < max_p {
920 return Err(DiError::NotEnoughValidData {
921 needed: max_p,
922 valid: n - first,
923 });
924 }
925
926 let rows = combos.len();
927 let cols = n;
928
929 unsafe {
930 let total = rows.checked_mul(cols).ok_or(DiError::InvalidRange {
931 start: sweep.period.0,
932 end: sweep.period.1,
933 step: sweep.period.2,
934 })?;
935 let plus_mu = std::slice::from_raw_parts_mut(
936 out_plus.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
937 total,
938 );
939 let minus_mu = std::slice::from_raw_parts_mut(
940 out_minus.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
941 total,
942 );
943 let warm: Vec<usize> = combos
944 .iter()
945 .map(|c| first.saturating_add(c.period.unwrap()).saturating_sub(1))
946 .collect();
947 init_matrix_prefixes(plus_mu, cols, &warm);
948 init_matrix_prefixes(minus_mu, cols, &warm);
949 }
950
951 let mut up = vec![0.0f64; n];
952 let mut dn = vec![0.0f64; n];
953 let mut tr = vec![0.0f64; n];
954 {
955 let mut prev_h = high[first];
956 let mut prev_l = low[first];
957 let mut prev_c = close[first];
958 let mut i = first + 1;
959 while i < n {
960 let ch = high[i];
961 let cl = low[i];
962 let dp = ch - prev_h;
963 let dm = prev_l - cl;
964 if dp > dm && dp > 0.0 {
965 up[i] = dp;
966 }
967 if dm > dp && dm > 0.0 {
968 dn[i] = dm;
969 }
970 let mut t = ch - cl;
971 let t2 = (ch - prev_c).abs();
972 let t3 = (cl - prev_c).abs();
973 if t2 > t {
974 t = t2;
975 }
976 if t3 > t {
977 t = t3;
978 }
979 tr[i] = t;
980 prev_h = ch;
981 prev_l = cl;
982 prev_c = close[i];
983 i += 1;
984 }
985 }
986
987 let do_row = |row: usize, out_plus: &mut [f64], out_minus: &mut [f64]| unsafe {
988 let period = combos[row].period.unwrap();
989 let result = di_row_scalar_precomputed(&up, &dn, &tr, period, first, out_plus, out_minus);
990 debug_assert!(result.is_ok());
991 };
992
993 if parallel {
994 #[cfg(not(target_arch = "wasm32"))]
995 {
996 out_plus
997 .par_chunks_mut(cols)
998 .zip(out_minus.par_chunks_mut(cols))
999 .enumerate()
1000 .for_each(|(row, (pl, mi))| do_row(row, pl, mi));
1001 }
1002
1003 #[cfg(target_arch = "wasm32")]
1004 {
1005 for (row, (pl, mi)) in out_plus
1006 .chunks_mut(cols)
1007 .zip(out_minus.chunks_mut(cols))
1008 .enumerate()
1009 {
1010 do_row(row, pl, mi);
1011 }
1012 }
1013 } else {
1014 for (row, (pl, mi)) in out_plus
1015 .chunks_mut(cols)
1016 .zip(out_minus.chunks_mut(cols))
1017 .enumerate()
1018 {
1019 do_row(row, pl, mi);
1020 }
1021 }
1022
1023 Ok(combos)
1024}
1025
1026#[inline(always)]
1027fn di_batch_inner(
1028 high: &[f64],
1029 low: &[f64],
1030 close: &[f64],
1031 sweep: &DiBatchRange,
1032 kern: Kernel,
1033 parallel: bool,
1034) -> Result<DiBatchOutput, DiError> {
1035 let combos = expand_grid(sweep);
1036 if combos.is_empty() {
1037 let (s, e, st) = sweep.period;
1038 return Err(DiError::InvalidRange {
1039 start: s,
1040 end: e,
1041 step: st,
1042 });
1043 }
1044 let n = high.len();
1045 if n == 0 || low.len() != n || close.len() != n {
1046 return Err(DiError::EmptyInputData);
1047 }
1048 let first = (0..n)
1049 .find(|&i| !(high[i].is_nan() || low[i].is_nan() || close[i].is_nan()))
1050 .ok_or(DiError::AllValuesNaN)?;
1051 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1052 if n - first < max_p {
1053 return Err(DiError::NotEnoughValidData {
1054 needed: max_p,
1055 valid: n - first,
1056 });
1057 }
1058
1059 let rows = combos.len();
1060 let cols = n;
1061
1062 let _ = rows.checked_mul(cols).ok_or(DiError::InvalidRange {
1063 start: sweep.period.0,
1064 end: sweep.period.1,
1065 step: sweep.period.2,
1066 })?;
1067
1068 let warmup_periods: Vec<usize> = combos
1069 .iter()
1070 .map(|c| first.saturating_add(c.period.unwrap()).saturating_sub(1))
1071 .collect();
1072
1073 let mut plus_mu = make_uninit_matrix(rows, cols);
1074 let mut minus_mu = make_uninit_matrix(rows, cols);
1075
1076 init_matrix_prefixes(&mut plus_mu, cols, &warmup_periods);
1077 init_matrix_prefixes(&mut minus_mu, cols, &warmup_periods);
1078
1079 let mut plus_guard = core::mem::ManuallyDrop::new(plus_mu);
1080 let mut minus_guard = core::mem::ManuallyDrop::new(minus_mu);
1081 let plus: &mut [f64] = unsafe {
1082 core::slice::from_raw_parts_mut(plus_guard.as_mut_ptr() as *mut f64, plus_guard.len())
1083 };
1084 let minus: &mut [f64] = unsafe {
1085 core::slice::from_raw_parts_mut(minus_guard.as_mut_ptr() as *mut f64, minus_guard.len())
1086 };
1087
1088 let mut up = vec![0.0f64; n];
1089 let mut dn = vec![0.0f64; n];
1090 let mut tr = vec![0.0f64; n];
1091 {
1092 let mut prev_h = high[first];
1093 let mut prev_l = low[first];
1094 let mut prev_c = close[first];
1095 let mut i = first + 1;
1096 while i < n {
1097 let ch = high[i];
1098 let cl = low[i];
1099 let dp = ch - prev_h;
1100 let dm = prev_l - cl;
1101 if dp > dm && dp > 0.0 {
1102 up[i] = dp;
1103 }
1104 if dm > dp && dm > 0.0 {
1105 dn[i] = dm;
1106 }
1107 let mut t = ch - cl;
1108 let t2 = (ch - prev_c).abs();
1109 let t3 = (cl - prev_c).abs();
1110 if t2 > t {
1111 t = t2;
1112 }
1113 if t3 > t {
1114 t = t3;
1115 }
1116 tr[i] = t;
1117 prev_h = ch;
1118 prev_l = cl;
1119 prev_c = close[i];
1120 i += 1;
1121 }
1122 }
1123
1124 let do_row = |row: usize, out_plus: &mut [f64], out_minus: &mut [f64]| unsafe {
1125 let period = combos[row].period.unwrap();
1126 let result = di_row_scalar_precomputed(&up, &dn, &tr, period, first, out_plus, out_minus);
1127 debug_assert!(result.is_ok());
1128 };
1129
1130 if parallel {
1131 #[cfg(not(target_arch = "wasm32"))]
1132 {
1133 plus.par_chunks_mut(cols)
1134 .zip(minus.par_chunks_mut(cols))
1135 .enumerate()
1136 .for_each(|(row, (pl, mi))| do_row(row, pl, mi));
1137 }
1138
1139 #[cfg(target_arch = "wasm32")]
1140 {
1141 for (row, (pl, mi)) in plus
1142 .chunks_mut(cols)
1143 .zip(minus.chunks_mut(cols))
1144 .enumerate()
1145 {
1146 do_row(row, pl, mi);
1147 }
1148 }
1149 } else {
1150 for (row, (pl, mi)) in plus
1151 .chunks_mut(cols)
1152 .zip(minus.chunks_mut(cols))
1153 .enumerate()
1154 {
1155 do_row(row, pl, mi);
1156 }
1157 }
1158
1159 let plus = unsafe {
1160 Vec::from_raw_parts(
1161 plus_guard.as_mut_ptr() as *mut f64,
1162 plus_guard.len(),
1163 plus_guard.capacity(),
1164 )
1165 };
1166 let minus = unsafe {
1167 Vec::from_raw_parts(
1168 minus_guard.as_mut_ptr() as *mut f64,
1169 minus_guard.len(),
1170 minus_guard.capacity(),
1171 )
1172 };
1173
1174 Ok(DiBatchOutput {
1175 plus,
1176 minus,
1177 combos,
1178 rows,
1179 cols,
1180 })
1181}
1182
1183#[inline(always)]
1184pub unsafe fn di_row_scalar(
1185 high: &[f64],
1186 low: &[f64],
1187 close: &[f64],
1188 period: usize,
1189 first: usize,
1190 out_plus: &mut [f64],
1191 out_minus: &mut [f64],
1192) -> Result<(), DiError> {
1193 let n = high.len();
1194 if n == 0 {
1195 return Ok(());
1196 }
1197
1198 let pf = period as f64;
1199 let invp = pf.recip();
1200 let keep = 1.0 - invp;
1201
1202 let mut prev_h = high[first];
1203 let mut prev_l = low[first];
1204 let mut prev_c = close[first];
1205
1206 let start = first + 1;
1207 let stop = first + period;
1208 let mut plus_dm_sum = 0.0;
1209 let mut minus_dm_sum = 0.0;
1210 let mut tr_sum = 0.0;
1211
1212 let mut i = start;
1213 while i < stop {
1214 let ch = high[i];
1215 let cl = low[i];
1216 let cc = close[i];
1217
1218 let dp = ch - prev_h;
1219 let dm = prev_l - cl;
1220 if dp > dm && dp > 0.0 {
1221 plus_dm_sum += dp;
1222 }
1223 if dm > dp && dm > 0.0 {
1224 minus_dm_sum += dm;
1225 }
1226
1227 let mut tr = ch - cl;
1228 let tr2 = (ch - prev_c).abs();
1229 let tr3 = (cl - prev_c).abs();
1230 if tr2 > tr {
1231 tr = tr2;
1232 }
1233 if tr3 > tr {
1234 tr = tr3;
1235 }
1236 tr_sum += tr;
1237
1238 prev_h = ch;
1239 prev_l = cl;
1240 prev_c = cc;
1241 i += 1;
1242 }
1243
1244 let mut cur_plus = plus_dm_sum;
1245 let mut cur_minus = minus_dm_sum;
1246 let mut cur_tr = tr_sum;
1247
1248 let mut idx = stop - 1;
1249 let mut scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1250 out_plus[idx] = cur_plus * scale;
1251 out_minus[idx] = cur_minus * scale;
1252 idx += 1;
1253
1254 while idx < n {
1255 let ch = high[idx];
1256 let cl = low[idx];
1257 let cc = close[idx];
1258
1259 let dp = ch - prev_h;
1260 let dm = prev_l - cl;
1261 let inc_p = if dp > dm && dp > 0.0 { dp } else { 0.0 };
1262 let inc_m = if dm > dp && dm > 0.0 { dm } else { 0.0 };
1263
1264 cur_plus = cur_plus.mul_add(keep, inc_p);
1265 cur_minus = cur_minus.mul_add(keep, inc_m);
1266
1267 let mut tr = ch - cl;
1268 let tr2 = (ch - prev_c).abs();
1269 let tr3 = (cl - prev_c).abs();
1270 if tr2 > tr {
1271 tr = tr2;
1272 }
1273 if tr3 > tr {
1274 tr = tr3;
1275 }
1276 cur_tr = cur_tr.mul_add(keep, tr);
1277
1278 scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1279 out_plus[idx] = cur_plus * scale;
1280 out_minus[idx] = cur_minus * scale;
1281
1282 prev_h = ch;
1283 prev_l = cl;
1284 prev_c = cc;
1285 idx += 1;
1286 }
1287 Ok(())
1288}
1289
1290#[inline(always)]
1291pub unsafe fn di_row_scalar_precomputed(
1292 up: &[f64],
1293 dn: &[f64],
1294 tr: &[f64],
1295 period: usize,
1296 first: usize,
1297 out_plus: &mut [f64],
1298 out_minus: &mut [f64],
1299) -> Result<(), DiError> {
1300 let n = up.len();
1301 if n == 0 {
1302 return Ok(());
1303 }
1304
1305 let pf = period as f64;
1306 let invp = pf.recip();
1307 let keep = 1.0 - invp;
1308
1309 let start = first + 1;
1310 let stop = first + period;
1311 let mut plus_dm_sum = 0.0;
1312 let mut minus_dm_sum = 0.0;
1313 let mut tr_sum = 0.0;
1314
1315 let mut i = start;
1316 while i < stop {
1317 plus_dm_sum += up[i];
1318 minus_dm_sum += dn[i];
1319 tr_sum += tr[i];
1320 i += 1;
1321 }
1322
1323 let mut cur_plus = plus_dm_sum;
1324 let mut cur_minus = minus_dm_sum;
1325 let mut cur_tr = tr_sum;
1326
1327 let mut idx = stop - 1;
1328 let mut scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1329 out_plus[idx] = cur_plus * scale;
1330 out_minus[idx] = cur_minus * scale;
1331 idx += 1;
1332
1333 while idx < n {
1334 cur_plus = cur_plus.mul_add(keep, up[idx]);
1335 cur_minus = cur_minus.mul_add(keep, dn[idx]);
1336 cur_tr = cur_tr.mul_add(keep, tr[idx]);
1337 scale = if cur_tr == 0.0 { 0.0 } else { 100.0 / cur_tr };
1338 out_plus[idx] = cur_plus * scale;
1339 out_minus[idx] = cur_minus * scale;
1340 idx += 1;
1341 }
1342 Ok(())
1343}
1344
/// Expands to a row kernel with the shared DI row signature whose body forwards every
/// argument unchanged to `$target`.
///
/// None of the SIMD-named kernels generated below contain SIMD code: each one ends in
/// `di_row_scalar`, so they all produce exactly the scalar kernel's output. Every generated
/// function is compiled only with the `nightly-avx` feature on `x86_64`.
macro_rules! di_row_forwarding_kernel {
    ($(#[$attr:meta])* $name:ident => $target:ident) => {
        $(#[$attr])*
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        #[inline(always)]
        pub unsafe fn $name(
            high: &[f64],
            low: &[f64],
            close: &[f64],
            period: usize,
            first: usize,
            out_plus: &mut [f64],
            out_minus: &mut [f64],
        ) -> Result<(), DiError> {
            $target(high, low, close, period, first, out_plus, out_minus)
        }
    };
}

di_row_forwarding_kernel! {
    /// AVX2 slot of the DI row kernels; currently delegates to the scalar kernel.
    ///
    /// # Safety
    ///
    /// Same contract as `di_row_scalar`, which receives the arguments unchanged.
    di_row_avx2 => di_row_scalar
}

di_row_forwarding_kernel! {
    /// AVX-512 slot of the DI row kernels; currently delegates to the scalar kernel.
    ///
    /// # Safety
    ///
    /// Same contract as `di_row_scalar`, which receives the arguments unchanged.
    di_row_avx512 => di_row_scalar
}

di_row_forwarding_kernel! {
    /// `_short` variant of the AVX-512 row kernel; delegates to `di_row_avx512`.
    ///
    /// # Safety
    ///
    /// Same contract as `di_row_avx512`, which receives the arguments unchanged.
    di_row_avx512_short => di_row_avx512
}

di_row_forwarding_kernel! {
    /// `_long` variant of the AVX-512 row kernel; delegates to `di_row_avx512`.
    ///
    /// # Safety
    ///
    /// Same contract as `di_row_avx512`, which receives the arguments unchanged.
    di_row_avx512_long => di_row_avx512
}
1400
/// Wilder's true range of one bar: the largest of the bar's own high-low span and the gaps
/// from the previous close to this bar's high and low.
///
/// Candidates are compared with `>` in the order span, high gap, low gap, so a NaN
/// candidate never replaces the current best, while a NaN span stays the result.
#[inline(always)]
fn true_range(current_high: f64, current_low: f64, prev_close: f64) -> f64 {
    let span = current_high - current_low;
    let gap_up = (current_high - prev_close).abs();
    let gap_down = (current_low - prev_close).abs();

    let widest = if gap_up > span { gap_up } else { span };
    if gap_down > widest {
        gap_down
    } else {
        widest
    }
}
1414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture shared by every file-backed test in this module.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Names the allocation helper associated with a sentinel ("poison") bit pattern, or
    /// returns `None` for an ordinary value. Finding one in an output presumably means the
    /// slot was never overwritten by the kernel.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    /// Returns the first poisoned slot of `values` as `(index, value, bits, helper name)`.
    #[cfg(debug_assertions)]
    fn find_poison(values: &[f64]) -> Option<(usize, f64, u64, &'static str)> {
        values.iter().enumerate().find_map(|(i, &val)| {
            let bits = val.to_bits();
            poison_source(bits).map(|source| (i, val, bits, source))
        })
    }

    /// Asserts that the last five values of a batch row and a single-run series agree.
    fn assert_tail_close(test: &str, series: &str, batch: &[f64], single: &[f64]) {
        let batch_tail = &batch[batch.len() - 5..];
        let single_tail = &single[single.len() - 5..];
        for (i, (&a, &b)) in batch_tail.iter().zip(single_tail).enumerate() {
            assert!(
                (a - b).abs() < 1e-8,
                "[{test}] batch/scalar {series} mismatch idx={i}: {a} vs {b}"
            );
        }
    }

    fn check_di_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        // `None` must fall back to the default period; an explicit period must work too.
        for params in [DiParams { period: None }, DiParams { period: Some(10) }] {
            let input = DiInput::from_candles(&candles, params);
            let output = di_with_kernel(&input, kernel)?;
            assert_eq!(output.plus.len(), candles.close.len());
            assert_eq!(output.minus.len(), candles.close.len());
        }
        Ok(())
    }

    fn check_di_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = DiInput::from_candles(&candles, DiParams { period: Some(14) });
        let di_result = di_with_kernel(&input, kernel)?;
        assert_eq!(di_result.plus.len(), candles.close.len());
        assert_eq!(di_result.minus.len(), candles.close.len());

        // Reference values for the last five bars of the fixture.
        let test_plus = [
            10.99067007335658,
            11.306993269828585,
            10.948661818939213,
            10.683207768215592,
            9.802180952619183,
        ];
        let test_minus = [
            28.06728094177839,
            27.331240567633152,
            27.759989125359493,
            26.951434842917386,
            30.748897303623057,
        ];
        let n = di_result.plus.len();
        assert!(
            n >= test_plus.len(),
            "fixture too short for the tail comparison"
        );
        let plus_tail = &di_result.plus[n - test_plus.len()..];
        let minus_tail = &di_result.minus[n - test_minus.len()..];
        for i in 0..test_plus.len() {
            assert!(
                (plus_tail[i] - test_plus[i]).abs() < 1e-6,
                "Mismatch in +DI at tail index {}: expected {}, got {}",
                i,
                test_plus[i],
                plus_tail[i]
            );
            assert!(
                (minus_tail[i] - test_minus[i]).abs() < 1e-6,
                "Mismatch in -DI at tail index {}: expected {}, got {}",
                i,
                test_minus[i],
                minus_tail[i]
            );
        }
        Ok(())
    }

    fn check_di_with_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0, 11.0, 12.0];
        let low = [9.0, 8.0, 7.0];
        let close = [9.5, 10.0, 11.0];
        let input = DiInput::from_slices(&high, &low, &close, DiParams { period: Some(0) });
        assert!(di_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_di_with_period_exceeding_data_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [10.0, 11.0, 12.0];
        let low = [9.0, 8.0, 7.0];
        let close = [9.5, 10.0, 11.0];
        let input = DiInput::from_slices(&high, &low, &close, DiParams { period: Some(10) });
        assert!(di_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_di_very_small_data_set(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = [42.0];
        let low = [41.0];
        let close = [41.5];
        let input = DiInput::from_slices(&high, &low, &close, DiParams { period: Some(14) });
        assert!(di_with_kernel(&input, kernel).is_err());
        Ok(())
    }

    fn check_di_with_slice_data_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let first_input = DiInput::from_candles(&candles, DiParams { period: Some(14) });
        let first_result = di_with_kernel(&first_input, kernel)?;
        assert_eq!(first_result.plus.len(), candles.close.len());
        assert_eq!(first_result.minus.len(), candles.close.len());

        // Feed the first pass's outputs back in as high/low to exercise NaN-prefixed input.
        let second_input = DiInput::from_slices(
            &first_result.plus,
            &first_result.minus,
            &candles.close,
            DiParams { period: Some(14) },
        );
        let second_result = di_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.plus.len(), first_result.plus.len());
        assert_eq!(second_result.minus.len(), first_result.minus.len());
        Ok(())
    }

    fn check_di_accuracy_nan_check(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = DiInput::from_candles(&candles, DiParams { period: Some(14) });
        let di_result = di_with_kernel(&input, kernel)?;
        assert_eq!(di_result.plus.len(), candles.close.len());
        assert_eq!(di_result.minus.len(), candles.close.len());
        // Well past the warm-up every value must be a number.
        for (i, (&p, &m)) in di_result
            .plus
            .iter()
            .zip(&di_result.minus)
            .enumerate()
            .skip(40)
        {
            assert!(!p.is_nan(), "+DI is NaN at index {i}");
            assert!(!m.is_nan(), "-DI is NaN at index {i}");
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_di_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = vec![
            DiParams::default(),
            DiParams { period: Some(2) },
            DiParams { period: Some(5) },
            DiParams { period: Some(7) },
            DiParams { period: Some(10) },
            DiParams { period: Some(20) },
            DiParams { period: Some(30) },
            DiParams { period: Some(50) },
            DiParams { period: Some(100) },
            DiParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = DiInput::from_candles(&candles, params.clone());
            let output = di_with_kernel(&input, kernel)?;
            for (series, values) in [("plus", &output.plus), ("minus", &output.minus)] {
                if let Some((i, val, bits, source)) = find_poison(values) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         in {} output with params: {:?} (param set {})",
                        test_name, source, val, bits, i, series, params, param_idx
                    );
                }
            }
        }
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_di_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[allow(clippy::float_cmp)]
    fn check_di_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50)
            .prop_flat_map(|period| {
                (
                    100.0f64..5000.0f64,
                    (period + 20)..400,
                    0.001f64..0.05f64,
                    -0.01f64..0.01f64,
                    Just(period),
                )
            })
            .prop_map(|(base_price, data_len, volatility, trend, period)| {
                let mut high = Vec::with_capacity(data_len);
                let mut low = Vec::with_capacity(data_len);
                let mut close = Vec::with_capacity(data_len);

                let mut price = base_price;

                for i in 0..data_len {
                    let trend_component = trend * i as f64;
                    let random_component = ((i * 137 + 11) % 100) as f64 / 100.0 - 0.5;
                    price = price * (1.0 + trend_component + random_component * volatility);

                    price = price.max(1.0);

                    let daily_range = price * volatility * (1.0 + ((i * 73) % 50) as f64 / 100.0);
                    let h = price + daily_range * 0.5;
                    let l = price - daily_range * 0.5;

                    let close_factor = ((i * 29 + 7) % 100) as f64 / 100.0;
                    let c = l + (h - l) * close_factor;

                    high.push(h);
                    low.push(l);
                    close.push(c);
                }

                (high, low, close, period)
            });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(high, low, close, period)| {
                let params = DiParams {
                    period: Some(period),
                };
                let input = DiInput::from_slices(&high, &low, &close, params);

                let output = di_with_kernel(&input, kernel)?;

                let ref_output = di_with_kernel(&input, Kernel::Scalar)?;

                prop_assert_eq!(output.plus.len(), high.len());
                prop_assert_eq!(output.minus.len(), high.len());

                let warmup_end = period - 1;
                for i in 0..warmup_end {
                    prop_assert!(
                        output.plus[i].is_nan(),
                        "Expected NaN at index {} during warmup, got {}",
                        i,
                        output.plus[i]
                    );
                    prop_assert!(
                        output.minus[i].is_nan(),
                        "Expected NaN at index {} during warmup, got {}",
                        i,
                        output.minus[i]
                    );
                }

                for i in warmup_end..high.len() {
                    let plus_val = output.plus[i];
                    let minus_val = output.minus[i];

                    prop_assert!(
                        plus_val.is_finite(),
                        "Expected finite +DI at index {}, got {}",
                        i,
                        plus_val
                    );
                    prop_assert!(
                        minus_val.is_finite(),
                        "Expected finite -DI at index {}, got {}",
                        i,
                        minus_val
                    );

                    prop_assert!(
                        plus_val >= 0.0 && plus_val <= 100.0,
                        "+DI at index {} = {} is out of range [0, 100]",
                        i,
                        plus_val
                    );
                    prop_assert!(
                        minus_val >= 0.0 && minus_val <= 100.0,
                        "-DI at index {} = {} is out of range [0, 100]",
                        i,
                        minus_val
                    );
                }

                for i in 0..high.len() {
                    let plus_val = output.plus[i];
                    let minus_val = output.minus[i];
                    let ref_plus = ref_output.plus[i];
                    let ref_minus = ref_output.minus[i];

                    if plus_val.is_nan() || ref_plus.is_nan() {
                        prop_assert_eq!(
                            plus_val.is_nan(),
                            ref_plus.is_nan(),
                            "NaN mismatch in +DI at index {}",
                            i
                        );
                    } else {
                        prop_assert!(
                            (plus_val - ref_plus).abs() <= 1e-9,
                            "+DI mismatch at index {}: {} vs {} (diff: {})",
                            i,
                            plus_val,
                            ref_plus,
                            (plus_val - ref_plus).abs()
                        );
                    }

                    if minus_val.is_nan() || ref_minus.is_nan() {
                        prop_assert_eq!(
                            minus_val.is_nan(),
                            ref_minus.is_nan(),
                            "NaN mismatch in -DI at index {}",
                            i
                        );
                    } else {
                        prop_assert!(
                            (minus_val - ref_minus).abs() <= 1e-9,
                            "-DI mismatch at index {}: {} vs {} (diff: {})",
                            i,
                            minus_val,
                            ref_minus,
                            (minus_val - ref_minus).abs()
                        );
                    }
                }

                let constant_high = vec![100.0; 50];
                let constant_low = vec![100.0; 50];
                let constant_close = vec![100.0; 50];
                let const_params = DiParams {
                    period: Some(period),
                };
                let const_input = DiInput::from_slices(
                    &constant_high,
                    &constant_low,
                    &constant_close,
                    const_params,
                );

                if let Ok(const_output) = di_with_kernel(&const_input, kernel) {
                    for i in (period + 5)..constant_high.len() {
                        prop_assert!(
                            const_output.plus[i] < 1.0,
                            "Expected near-zero +DI for constant prices, got {} at index {}",
                            const_output.plus[i],
                            i
                        );
                        prop_assert!(
                            const_output.minus[i] < 1.0,
                            "Expected near-zero -DI for constant prices, got {} at index {}",
                            const_output.minus[i],
                            i
                        );
                    }
                }

                let mut spike_high = vec![100.0; 30];
                let mut spike_low = vec![99.0; 30];
                let mut spike_close = vec![99.5; 30];

                spike_high[15] = 120.0;
                spike_low[15] = 80.0;
                spike_close[15] = 100.0;

                let spike_params = DiParams {
                    period: Some(period.min(10)),
                };
                let spike_input =
                    DiInput::from_slices(&spike_high, &spike_low, &spike_close, spike_params);

                if let Ok(spike_output) = di_with_kernel(&spike_input, kernel) {
                    for (i, (&plus, &minus)) in spike_output
                        .plus
                        .iter()
                        .zip(spike_output.minus.iter())
                        .enumerate()
                    {
                        if !plus.is_nan() {
                            prop_assert!(
                                plus >= 0.0 && plus <= 100.0,
                                "Volatility spike caused +DI out of range at {}: {}",
                                i,
                                plus
                            );
                        }
                        if !minus.is_nan() {
                            prop_assert!(
                                minus >= 0.0 && minus <= 100.0,
                                "Volatility spike caused -DI out of range at {}: {}",
                                i,
                                minus
                            );
                        }
                    }
                }

                if period <= 20 {
                    let trend_len = 200;
                    let mut uptrend_high = Vec::with_capacity(trend_len);
                    let mut uptrend_low = Vec::with_capacity(trend_len);
                    let mut uptrend_close = Vec::with_capacity(trend_len);

                    for i in 0..trend_len {
                        let price = 100.0 + i as f64 * 2.0;
                        uptrend_high.push(price + 1.0);
                        uptrend_low.push(price - 1.0);
                        uptrend_close.push(price);
                    }

                    let trend_params = DiParams {
                        period: Some(period),
                    };
                    let trend_input = DiInput::from_slices(
                        &uptrend_high,
                        &uptrend_low,
                        &uptrend_close,
                        trend_params,
                    );

                    if let Ok(trend_output) = di_with_kernel(&trend_input, kernel) {
                        let check_start = trend_len / 2;
                        let mut plus_wins = 0;
                        let mut minus_wins = 0;

                        for i in check_start..trend_len {
                            if trend_output.plus[i] > trend_output.minus[i] {
                                plus_wins += 1;
                            } else {
                                minus_wins += 1;
                            }
                        }

                        if plus_wins + minus_wins > 0 {
                            prop_assert!(
                                plus_wins > minus_wins,
                                "In uptrend, expected +DI > -DI more often. +DI wins: {}, -DI wins: {} (period: {})",
                                plus_wins, minus_wins, period
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    #[test]
    fn test_di_into_matches_api() -> Result<(), Box<dyn Error>> {
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = DiInput::from_candles(&candles, DiParams::default());

        let baseline = di(&input)?;

        let n = candles.close.len();
        let mut plus = vec![0.0_f64; n];
        let mut minus = vec![0.0_f64; n];
        di_into(&input, &mut plus, &mut minus)?;

        assert_eq!(baseline.plus.len(), plus.len());
        assert_eq!(baseline.minus.len(), minus.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..n {
            assert!(
                eq_or_both_nan(baseline.plus[i], plus[i]),
                "+DI mismatch at index {}: baseline={} vs into={}",
                i,
                baseline.plus[i],
                plus[i]
            );
            assert!(
                eq_or_both_nan(baseline.minus[i], minus[i]),
                "-DI mismatch at index {}: baseline={} vs into={}",
                i,
                baseline.minus[i],
                minus[i]
            );
        }
        Ok(())
    }

    // Each generated test returns the check's `Result`, so an `Err` (missing fixture,
    // unexpected kernel error) fails the test instead of being discarded. The `cfg` sits
    // inside the repetition so it gates every generated SIMD test, not only the first.
    macro_rules! generate_all_di_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    generate_all_di_tests!(
        check_di_partial_params,
        check_di_accuracy,
        check_di_with_zero_period,
        check_di_with_period_exceeding_data_length,
        check_di_very_small_data_set,
        check_di_with_slice_data_reinput,
        check_di_accuracy_nan_check,
        check_di_no_poison,
        check_di_property
    );

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = DiBatchBuilder::new().kernel(kernel).apply_candles(&c)?;

        let def = DiParams::default();
        let plus_row = output.plus_for(&def).expect("default row missing");
        let minus_row = output.minus_for(&def).expect("default minus row missing");
        assert_eq!(plus_row.len(), c.close.len());
        assert_eq!(minus_row.len(), c.close.len());

        let scalar = DiInput::from_candles(&c, DiParams::default());
        let scalar_out = di_with_kernel(&scalar, Kernel::Scalar)?;
        assert_tail_close(test, "plus", &plus_row[..], &scalar_out.plus);
        assert_tail_close(test, "minus", &minus_row[..], &scalar_out.minus);
        Ok(())
    }

    fn check_batch_period_range(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = DiBatchBuilder::new()
            .kernel(kernel)
            .period_range(10, 20, 5)
            .apply_candles(&c)?;

        assert_eq!(output.rows, 3);
        assert_eq!(output.cols, c.close.len());

        for p in [10, 15, 20] {
            let param = DiParams { period: Some(p) };
            let plus = output.plus_for(&param).expect("plus missing");
            let minus = output.minus_for(&param).expect("minus missing");
            assert_eq!(plus.len(), c.close.len());
            assert_eq!(minus.len(), c.close.len());

            // Every swept row must agree with a single-period scalar run.
            let single = di_with_kernel(&DiInput::from_candles(&c, param.clone()), Kernel::Scalar)?;
            assert_tail_close(test, "plus", &plus[..], &single.plus);
            assert_tail_close(test, "minus", &minus[..], &single.minus);
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLES_CSV)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 10, 0),
            (14, 14, 0),
            (50, 50, 0),
            (7, 21, 7),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = DiBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c)?;

            for (series, values) in [("plus", &output.plus), ("minus", &output.minus)] {
                if let Some((idx, val, bits, source)) = find_poison(values) {
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    panic!(
                        "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) in {} output with params: period={}",
                        test,
                        cfg_idx,
                        source,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        series,
                        output.combos[row].period.unwrap_or(14)
                    );
                }
            }
        }
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Batch tests also propagate the check's `Result` so failures are not discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_period_range);
    gen_batch_tests!(check_batch_no_poison);
}
2229
/// Python entry point that computes +DI and -DI for a single `period`.
///
/// `kernel` names the compute kernel and is checked by `validate_kernel`. `None`
/// presumably selects one automatically. The calculation runs with the GIL released.
/// The result is `(plus, minus)` as two float64 arrays. Calculation errors are raised
/// as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "di")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn di_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;
    let input = DiInput::from_slices(h, l, c, DiParams { period: Some(period) });

    // Only the numeric work happens without the GIL; the NumPy arrays are built after.
    let DiOutput { plus, minus } = py
        .allow_threads(|| di_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((plus.into_pyarray(py), minus.into_pyarray(py)))
}
2257
/// Python-facing incremental +DI/-DI calculator, exposed to Python as `DiStream`.
///
/// A thin wrapper that hands every bar to the Rust `DiStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DiStream")]
pub struct DiStreamPy {
    /// Streaming state that does the actual work.
    inner: DiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DiStreamPy {
    /// Creates a stream for `period`. Any error from `DiStream::try_new` is raised as
    /// `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        DiStream::try_new(DiParams { period: Some(period) })
            .map(|inner| Self { inner })
            .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Feeds one bar and returns `Some((plus, minus))` once the stream yields a value.
    /// `None` is presumably returned while the stream is still warming up (see
    /// `DiStream::update`).
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let stream = &mut self.inner;
        stream.update(high, low, close)
    }
}
2280
/// Python entry point for a +DI/-DI period sweep over one series.
///
/// `period_range` is `(start, end, step)` and is expanded by `expand_grid`. The result is
/// a dict with:
/// - `"plus"` and `"minus"`, reshaped to `(rows, cols)` with one row per period and
///   `cols` equal to the input length;
/// - `"periods"`, the period of each row as a `u64` array.
///
/// The fill runs with the GIL released.
///
/// Raises `ValueError` in these cases:
/// - `high`, `low` and `close` differ in length;
/// - no bar is free of NaN;
/// - `rows * cols` overflows;
/// - the batch computation fails.
///
/// Kernel-name errors come from `validate_kernel`.
#[cfg(feature = "python")]
#[pyfunction(name = "di_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn di_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let cols = high_slice.len();
    // The NaN scan and the prefix fill below index all three inputs by `0..cols`. A length
    // mismatch must surface as ValueError, not as an out-of-bounds panic.
    if low_slice.len() != cols || close_slice.len() != cols {
        return Err(PyValueError::new_err(
            "di: high, low, and close must have the same length",
        ));
    }

    let sweep = DiBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("di: size overflow in rows*cols"))?;

    let first = (0..cols)
        .find(|&i| !(high_slice[i].is_nan() || low_slice[i].is_nan() || close_slice[i].is_nan()))
        .ok_or_else(|| PyValueError::new_err("di: All values are NaN"))?;

    // Per-row NaN prefix: the first valid bar plus the warm-up. Saturation stops a zero
    // period from underflowing, and the clamp to `cols` stops an oversized period from
    // reaching past its row. Such combos are presumably rejected by `di_batch_inner_into`
    // (the single-series `di` rejects a zero period).
    let warm: Vec<usize> = combos
        .iter()
        .map(|p| {
            first
                .saturating_add(p.period.unwrap())
                .saturating_sub(1)
                .min(cols)
        })
        .collect();

    // SAFETY: the arrays start uninitialized. They are filled below (NaN prefixes by
    // `init_matrix_prefixes`, presumably the remainder by `di_batch_inner_into`) before
    // being returned to Python; on error they are dropped without being exposed.
    let out_plus = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_minus = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // SAFETY: both arrays were just created, are contiguous, and no other view exists yet.
    let pl = unsafe { out_plus.as_slice_mut()? };
    let mi = unsafe { out_minus.as_slice_mut()? };

    // SAFETY: `pl`/`mi` each cover exactly `total` f64 slots, and MaybeUninit<f64> has the
    // same layout as f64.
    unsafe {
        let plus_mu = std::slice::from_raw_parts_mut(
            pl.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        );
        let minus_mu = std::slice::from_raw_parts_mut(
            mi.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
            total,
        );
        init_matrix_prefixes(plus_mu, cols, &warm);
        init_matrix_prefixes(minus_mu, cols, &warm);
    }

    let combos = py
        .allow_threads(|| {
            // The row filler takes single-series kernels; map batch kernels onto them.
            let simd = match kern {
                Kernel::Auto => match detect_best_batch_kernel() {
                    Kernel::Avx512Batch => Kernel::Avx512,
                    Kernel::Avx2Batch => Kernel::Avx2,
                    _ => Kernel::Scalar,
                },
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                k => k,
            };

            di_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                simd,
                true,
                pl,
                mi,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("plus", out_plus.reshape((rows, cols))?)?;
    dict.set_item("minus", out_minus.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2376
2377#[cfg(all(feature = "python", feature = "cuda"))]
2378use crate::cuda::cuda_available;
2379#[cfg(all(feature = "python", feature = "cuda"))]
2380use crate::cuda::moving_averages::DeviceArrayF32;
2381#[cfg(all(feature = "python", feature = "cuda"))]
2382use crate::cuda::CudaDi;
2383#[cfg(all(feature = "python", feature = "cuda"))]
2384use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2385#[cfg(all(feature = "python", feature = "cuda"))]
2386use cust::context::Context;
2387#[cfg(all(feature = "python", feature = "cuda"))]
2388use cust::memory::DeviceBuffer;
2389#[cfg(all(feature = "python", feature = "cuda"))]
2390use pyo3::prelude::*;
2391#[cfg(all(feature = "python", feature = "cuda"))]
2392use std::os::raw::c_void;
2393#[cfg(all(feature = "python", feature = "cuda"))]
2394use std::sync::Arc;
2395
/// Python-visible handle to a device-resident 2-D `f32` matrix returned by the DI CUDA
/// bindings (`di_cuda_batch_dev` and `di_cuda_many_series_one_param_dev`).
///
/// Declared `unsendable`, so pyo3 only permits access from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32DiPy {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Context of the `CudaDi` instance that produced `inner`. It is not read in this part
    /// of the file; presumably it is held so the context outlives the buffer.
    /// NOTE(review): confirm.
    pub(crate) ctx: Arc<Context>,
    /// CUDA device ordinal the buffer lives on, as reported by `CudaDi::device_id`.
    pub(crate) device_id: u32,
}
2403
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32DiPy {
    /// CUDA Array Interface (version 3) description of the row-major `f32` matrix.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        // Row-major: advancing one row skips `cols` elements, one column skips one element.
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        // `false` = the buffer is writable (not read-only).
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(device_type, device_id)`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        // 2 is `kDLCUDA` in DLPack's `DLDeviceType` enumeration.
        const KDL_CUDA: i32 = 2;
        (KDL_CUDA, self.device_id as i32)
    }

    /// Exports the matrix as a DLPack capsule, moving the device buffer out of `self`.
    ///
    /// After a successful export this handle holds an empty 0 x 0 placeholder. A
    /// `dl_device` that differs from this buffer's device is rejected, because
    /// cross-device copies are not implemented.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // Only a `dl_device` that extracts as `(int, int)` is checked; anything else is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        // The stream hint is accepted for protocol compatibility but not acted upon here.
        let _ = stream;

        // Swap in an empty placeholder so `self` stays valid after the buffer is handed off.
        let placeholder = DeviceArrayF32 {
            buf: DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
        };
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2473
/// Computes DI+/DI- on the GPU for every period in `period_range` over a single series.
///
/// `period_range` is `(start, end, step)`. `high_f32`, `low_f32` and `close_f32` must
/// have equal length; non-contiguous arrays are rejected by `as_slice`.
///
/// Returns a dict with device arrays under `"plus"` and `"minus"`, the swept `"periods"`,
/// `"rows"` (one per period) and `"cols"` (the input length).
///
/// Raises `ValueError` if CUDA is unavailable, if the inputs differ in length, or if
/// `CudaDi` reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "di_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn di_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    // `cols` below is reported from `high` alone, so reject mismatched inputs up front.
    // This mirrors the shape check in `di_cuda_many_series_one_param_dev_py`.
    if h.len() != l.len() || l.len() != c.len() {
        return Err(PyValueError::new_err(
            "expected three 1D arrays of same length",
        ));
    }
    let sweep = DiBatchRange {
        period: period_range,
    };
    // Run the device work without holding the GIL.
    let (plus_dev, minus_dev, combos, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let res = cuda
            .di_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((res.0, res.1, res.2, cuda.context_arc(), cuda.device_id()))
    })?;
    let dict = pyo3::types::PyDict::new(py);
    dict.set_item(
        "plus",
        Py::new(
            py,
            DeviceArrayF32DiPy {
                inner: plus_dev,
                ctx: ctx.clone(),
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "minus",
        Py::new(
            py,
            DeviceArrayF32DiPy {
                inner: minus_dev,
                ctx,
                device_id: dev_id,
            },
        )?,
    )?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", combos.len())?;
    dict.set_item("cols", h.len())?;
    Ok(dict)
}
2537
/// Computes DI+/DI- on the GPU for many time-major series that share one `period`.
///
/// The three inputs must be 2-D `f32` arrays of identical shape `(rows, cols)`. They are
/// passed to `CudaDi` as `(cols, rows)`; presumably each column is one series and each
/// row one time step. NOTE(review): confirm against `CudaDi`.
///
/// Returns a dict with device arrays under `"plus"` and `"minus"`, plus `"rows"`,
/// `"cols"` and `"period"`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "di_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
pub fn di_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let shapes_agree = shape == low_tm_f32.shape() && shape == close_tm_f32.shape();
    if !shapes_agree || shape.len() != 2 {
        return Err(PyValueError::new_err(
            "expected three 2D arrays of same shape",
        ));
    }
    let (rows, cols) = (shape[0], shape[1]);

    let (high, low, close) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // Device work runs with the GIL released.
    let (pair, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let outputs = cuda
            .di_many_series_one_param_time_major_dev(high, low, close, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((outputs, cuda.context_arc(), cuda.device_id()))
    })?;

    // Each returned handle shares the producing context.
    let wrap = |inner: DeviceArrayF32| {
        Py::new(
            py,
            DeviceArrayF32DiPy {
                inner,
                ctx: Arc::clone(&ctx),
                device_id: dev_id,
            },
        )
    };

    let result = pyo3::types::PyDict::new(py);
    result.set_item("plus", wrap(pair.plus)?)?;
    result.set_item("minus", wrap(pair.minus)?)?;
    result.set_item("rows", rows)?;
    result.set_item("cols", cols)?;
    result.set_item("period", period)?;
    Ok(result)
}
2600
/// Serialized result of the `di` JS binding (`di_js`).
///
/// `values` is row-major with `rows` rows of `cols` columns. `di_js` always emits two
/// rows: DI+ followed by DI-.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiJsResult {
    /// Flattened row-major output values.
    pub values: Vec<f64>,
    /// Number of output rows.
    pub rows: usize,
    /// Number of columns per row (the input length).
    pub cols: usize,
}
2608
/// JS entry point: computes DI for `high`/`low`/`close` with the given `period`.
///
/// Returns `{ values, rows: 2, cols }`, where `values` is the DI+ output followed by the
/// DI- output and `cols` is `high.len()`. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = di)]
pub fn di_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<JsValue, JsValue> {
    let input = DiInput::from_slices(
        high,
        low,
        close,
        DiParams {
            period: Some(period),
        },
    );
    let DiOutput { plus, minus } =
        di_with_kernel(&input, Kernel::Auto).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = high.len();
    let capacity = match cols.checked_mul(2) {
        Some(n) => n,
        None => return Err(JsValue::from_str("di_js: size overflow")),
    };
    let mut values = Vec::with_capacity(capacity);
    values.extend(plus.iter().chain(minus.iter()).copied());

    serde_wasm_bindgen::to_value(&DiJsResult {
        values,
        rows: 2,
        cols,
    })
    .map_err(|e| JsValue::from_str(&e.to_string()))
}
2632
/// Allocates an uninitialized buffer with room for `len` `f64` values and hands ownership
/// of it to the caller. Release it with `di_free(ptr, len)`, using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2641
/// Releases a buffer previously returned by `di_alloc(len)`; `len` must match.
///
/// Null pointers are ignored. The buffer's contents are never read, so it does not matter
/// whether the caller ever wrote to it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `di_alloc(len)`, which leaked a `Vec::<f64>::with_capacity(len)`.
    // `with_capacity(n)` allocates exactly `n` elements, so `len` is the original capacity.
    // The length is 0 because `from_raw_parts` requires the first `length` elements to be
    // initialized, and this memory may never have been written. Dropping the rebuilt Vec
    // only deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2651
/// Computes DI for `len`-element high/low/close series in wasm memory.
///
/// Writes DI+ into `out[..len]` and DI- into `out[len..2 * len]`. Each input pointer must
/// address at least `len` readable `f64`s, and `out_ptr` must address at least `2 * len`
/// writable `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_input_null = [high_ptr, low_ptr, close_ptr].iter().any(|p| p.is_null());
    if any_input_null || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer to di_into"));
    }

    // SAFETY: pointers are non-null (checked above); the caller guarantees each one
    // addresses `len` initialized f64 values.
    let (h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let input = DiInput::from_slices(
        h,
        l,
        c,
        DiParams {
            period: Some(period),
        },
    );
    let DiOutput { plus, minus } =
        di_with_kernel(&input, Kernel::Auto).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let total = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("di_into: size overflow"))?;
    // SAFETY: `out_ptr` is non-null; the caller guarantees it addresses `2 * len` writable f64s.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    let (plus_out, minus_out) = out.split_at_mut(len);
    plus_out.copy_from_slice(&plus);
    minus_out.copy_from_slice(&minus);
    Ok(())
}
2687
/// Configuration object accepted by the `di_batch` JS binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiBatchConfig {
    /// Period sweep as `(start, end, step)`; copied directly into `DiBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2693
/// Serialized result of the `di_batch` JS binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DiBatchJsOutput {
    /// Row-major values: for each period, a DI+ row immediately followed by a DI- row.
    pub values: Vec<f64>,
    /// Total row count, i.e. twice the number of periods.
    pub rows: usize,
    /// Columns per row (the input length).
    pub cols: usize,
    /// Period for each DI+/DI- row pair, in output order.
    pub periods: Vec<usize>,
}
2702
/// JS entry point: runs a DI period sweep described by `config` (a `DiBatchConfig`).
///
/// Returns `{ values, rows, cols, periods }`. For each period, `values` holds a DI+ row
/// followed by a DI- row, so `rows` is twice the number of periods and `values` has
/// `rows * cols` entries.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = di_batch)]
pub fn di_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let DiBatchConfig { period_range } = serde_wasm_bindgen::from_value::<DiBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = DiBatchRange {
        period: period_range,
    };
    let batch = di_batch_slice(high, low, close, &sweep, Kernel::Scalar)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_rows = batch
        .rows
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("di_batch_unified_js: rows overflow"))?;
    let cols = batch.cols;
    let total = js_rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("di_batch_unified_js: size overflow"))?;

    // Interleave per-period rows: DI+ for period k, then DI- for period k.
    let mut values = Vec::with_capacity(total);
    let mut offset = 0;
    for _ in 0..batch.rows {
        let end = offset + cols;
        values.extend_from_slice(&batch.plus[offset..end]);
        values.extend_from_slice(&batch.minus[offset..end]);
        offset = end;
    }

    let periods: Vec<usize> = batch.combos.iter().map(|c| c.period.unwrap()).collect();

    let payload = DiBatchJsOutput {
        values,
        rows: js_rows,
        cols,
        periods,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2754
/// Computes DI+/DI- for every period in `period_start..=period_end` (stepping by
/// `period_step`) directly into caller-provided wasm memory.
///
/// `high_ptr`, `low_ptr` and `close_ptr` must each address `len` readable `f64`s.
/// `plus_ptr` and `minus_ptr` must each address `rows * len` writable `f64`s, where
/// `rows` is the number of periods `expand_grid` produces for this range. Output
/// buffers must not overlap the inputs or each other; creating overlapping shared and
/// mutable slices is undefined behavior.
///
/// Returns the number of rows written. Warmup cells that the kernel does not fill are
/// left as NaN.
///
/// # Errors
/// Returns a JS string error for null pointers, size overflow, or any error from
/// `di_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn di_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    plus_ptr: *mut f64,
    minus_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || plus_ptr.is_null()
        || minus_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DiBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("di_batch_into: size overflow"))?;

    // SAFETY: pointers are non-null (checked above). The caller guarantees each input
    // addresses `len` initialized f64s and each output addresses `total_size` writable
    // f64s, with no overlap between any of the five regions.
    let (high, low, close, out_plus, out_minus) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts_mut(plus_ptr, total_size),
            std::slice::from_raw_parts_mut(minus_ptr, total_size),
        )
    };

    // These buffers come straight from JS memory (e.g. `di_alloc`, which does not
    // initialize). The Python binding runs `init_matrix_prefixes` before calling
    // `di_batch_inner_into`, so the kernel may not write the warmup cells itself.
    // Start from NaN so those cells are never stale memory.
    out_plus.fill(f64::NAN);
    out_minus.fill(f64::NAN);

    di_batch_inner_into(
        high,
        low,
        close,
        &sweep,
        Kernel::Auto,
        false,
        out_plus,
        out_minus,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}