1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(feature = "python")]
18use crate::utilities::kernel_validation::validate_kernel;
19#[cfg(feature = "python")]
20use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
21#[cfg(feature = "python")]
22use pyo3::exceptions::PyValueError;
23#[cfg(feature = "python")]
24use pyo3::prelude::*;
25#[cfg(feature = "python")]
26use pyo3::types::{PyDict, PyList};
27
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
31use wasm_bindgen::prelude::*;
32
33#[derive(Debug, Clone)]
34pub enum DxData<'a> {
35 Candles {
36 candles: &'a Candles,
37 },
38 HlcSlices {
39 high: &'a [f64],
40 low: &'a [f64],
41 close: &'a [f64],
42 },
43}
44
/// Result of a single DX run.
#[derive(Debug, Clone)]
pub struct DxOutput {
    /// One DX value per input bar.
    pub values: Vec<f64>,
}
49
/// Parameters of the DX indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DxParams {
    /// Look-back period; `None` means "use the default" (14).
    pub period: Option<usize>,
}

impl Default for DxParams {
    /// Default parameters: `period = Some(14)`.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}

impl DxParams {
    /// Expands an inclusive `(start, end, step)` range into one parameter set
    /// per period: `start, start + step, ...` while the value is `<= end`.
    ///
    /// A `step` of 0 is treated as 1. If `start > end` the result is empty.
    /// The walk stops cleanly when the next period would overflow `usize`
    /// instead of panicking (debug) or wrapping around (release).
    pub fn generate_batch_params(period_range: (usize, usize, usize)) -> Vec<Self> {
        let (start, end, step) = period_range;
        let step = step.max(1);

        std::iter::successors(Some(start), |&p| p.checked_add(step))
            .take_while(|&p| p <= end)
            .map(|p| Self { period: Some(p) })
            .collect()
    }
}
82
83#[derive(Debug, Clone)]
84pub struct DxInput<'a> {
85 pub data: DxData<'a>,
86 pub params: DxParams,
87}
88
89impl<'a> DxInput<'a> {
90 #[inline]
91 pub fn from_candles(candles: &'a Candles, params: DxParams) -> Self {
92 Self {
93 data: DxData::Candles { candles },
94 params,
95 }
96 }
97 #[inline]
98 pub fn from_hlc_slices(
99 high: &'a [f64],
100 low: &'a [f64],
101 close: &'a [f64],
102 params: DxParams,
103 ) -> Self {
104 Self {
105 data: DxData::HlcSlices { high, low, close },
106 params,
107 }
108 }
109 #[inline]
110 pub fn with_default_candles(candles: &'a Candles) -> Self {
111 Self {
112 data: DxData::Candles { candles },
113 params: DxParams::default(),
114 }
115 }
116 #[inline]
117 pub fn get_period(&self) -> usize {
118 self.params.period.unwrap_or(14)
119 }
120}
121
122#[derive(Copy, Clone, Debug)]
123pub struct DxBuilder {
124 period: Option<usize>,
125 kernel: Kernel,
126}
127
128impl Default for DxBuilder {
129 fn default() -> Self {
130 Self {
131 period: None,
132 kernel: Kernel::Auto,
133 }
134 }
135}
136
137impl DxBuilder {
138 #[inline(always)]
139 pub fn new() -> Self {
140 Self::default()
141 }
142 #[inline(always)]
143 pub fn period(mut self, n: usize) -> Self {
144 self.period = Some(n);
145 self
146 }
147 #[inline(always)]
148 pub fn kernel(mut self, k: Kernel) -> Self {
149 self.kernel = k;
150 self
151 }
152 #[inline(always)]
153 pub fn apply(self, c: &Candles) -> Result<DxOutput, DxError> {
154 let p = DxParams {
155 period: self.period,
156 };
157 let i = DxInput::from_candles(c, p);
158 dx_with_kernel(&i, self.kernel)
159 }
160 #[inline(always)]
161 pub fn apply_hlc(self, high: &[f64], low: &[f64], close: &[f64]) -> Result<DxOutput, DxError> {
162 let p = DxParams {
163 period: self.period,
164 };
165 let i = DxInput::from_hlc_slices(high, low, close, p);
166 dx_with_kernel(&i, self.kernel)
167 }
168 #[inline(always)]
169 pub fn into_stream(self) -> Result<DxStream, DxError> {
170 let p = DxParams {
171 period: self.period,
172 };
173 DxStream::try_new(p)
174 }
175}
176
177#[derive(Debug, Error)]
178pub enum DxError {
179 #[error("dx: Empty data provided for DX.")]
180 EmptyInputData,
181 #[error("dx: Could not select candle field: {0}")]
182 SelectCandleFieldError(String),
183 #[error("dx: Invalid period: period = {period}, data length = {data_len}")]
184 InvalidPeriod { period: usize, data_len: usize },
185 #[error("dx: Not enough valid data: needed = {needed}, valid = {valid}")]
186 NotEnoughValidData { needed: usize, valid: usize },
187 #[error("dx: All high, low, and close values are NaN.")]
188 AllValuesNaN,
189 #[error("dx: output length mismatch: expected = {expected}, got = {got}")]
190 OutputLengthMismatch { expected: usize, got: usize },
191 #[error("dx: invalid kernel for batch: {0:?}")]
192 InvalidKernelForBatch(Kernel),
193 #[error("dx: invalid range: start={start}, end={end}, step={step}")]
194 InvalidRange {
195 start: usize,
196 end: usize,
197 step: usize,
198 },
199 #[error("dx: invalid input: {0}")]
200 InvalidInput(String),
201}
202
203#[inline(always)]
204fn dx_prepare<'a>(
205 input: &'a DxInput,
206 kernel: Kernel,
207) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), DxError> {
208 let (high, low, close) = match &input.data {
209 DxData::Candles { candles } => {
210 let h = candles
211 .select_candle_field("high")
212 .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
213 let l = candles
214 .select_candle_field("low")
215 .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
216 let c = candles
217 .select_candle_field("close")
218 .map_err(|e| DxError::SelectCandleFieldError(e.to_string()))?;
219 (h, l, c)
220 }
221 DxData::HlcSlices { high, low, close } => (*high, *low, *close),
222 };
223 let len = high.len().min(low.len()).min(close.len());
224 if len == 0 {
225 return Err(DxError::EmptyInputData);
226 }
227 let period = input.get_period();
228 if period == 0 || period > len {
229 return Err(DxError::InvalidPeriod {
230 period,
231 data_len: len,
232 });
233 }
234
235 let first = (0..len)
236 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
237 .ok_or(DxError::AllValuesNaN)?;
238 if len - first < period {
239 return Err(DxError::NotEnoughValidData {
240 needed: period,
241 valid: len - first,
242 });
243 }
244 let chosen = match kernel {
245 Kernel::Auto => Kernel::Scalar,
246 k => k,
247 };
248 Ok((high, low, close, len, first, chosen))
249}
250
251#[inline]
252pub fn dx(input: &DxInput) -> Result<DxOutput, DxError> {
253 dx_with_kernel(input, Kernel::Auto)
254}
255
256pub fn dx_with_kernel(input: &DxInput, kernel: Kernel) -> Result<DxOutput, DxError> {
257 let (h, l, c, len, first, chosen) = dx_prepare(input, kernel)?;
258 let warm = first + input.get_period() - 1;
259 let mut out = alloc_with_nan_prefix(len, warm);
260 unsafe {
261 match chosen {
262 Kernel::Scalar | Kernel::ScalarBatch => {
263 dx_scalar(h, l, c, input.get_period(), first, &mut out)
264 }
265 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266 Kernel::Avx2 | Kernel::Avx2Batch => {
267 dx_avx2(h, l, c, input.get_period(), first, &mut out)
268 }
269 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
270 Kernel::Avx512 | Kernel::Avx512Batch => {
271 dx_avx512(h, l, c, input.get_period(), first, &mut out)
272 }
273 _ => unreachable!(),
274 }
275 }
276 Ok(DxOutput { values: out })
277}
278
279#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
280#[inline]
281pub fn dx_into(input: &DxInput, out: &mut [f64]) -> Result<(), DxError> {
282 dx_into_slice(out, input, Kernel::Auto)
283}
284
/// Scalar DX kernel.
///
/// Processes the first `n` bars, where `n` is the shortest of `high`, `low`,
/// `close` and `out`. Indices before `first_valid_idx` are left untouched
/// (the caller owns the warm-up prefix); every index from `first_valid_idx`
/// up to `n - 1` is written:
///
/// * `out[first_valid_idx]` and every warm-up bar before the seed are NaN;
/// * the seed value is produced once `period - 1` valid bars after
///   `first_valid_idx` have been accumulated;
/// * later bars apply Wilder-style smoothing
///   (`sum = sum - sum / period + value`);
/// * a bar whose high, low or close is NaN repeats the previous output;
/// * when `+DI + -DI` is zero (including a zero true-range sum) the seed is
///   `0.0` and later bars repeat the previous output.
///
/// Does nothing when `period == 0` or `first_valid_idx` is out of range.
#[inline]
pub fn dx_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid_idx: usize,
    out: &mut [f64],
) {
    /// DX from the running sums; `None` when it is undefined because the
    /// true-range sum or `+DI + -DI` is zero.
    #[inline(always)]
    fn dx_from_sums(plus_dm_sum: f64, minus_dm_sum: f64, tr_sum: f64) -> Option<f64> {
        if tr_sum == 0.0 {
            return None;
        }
        let plus_di = (plus_dm_sum / tr_sum) * 100.0;
        let minus_di = (minus_dm_sum / tr_sum) * 100.0;
        let sum_di = plus_di + minus_di;
        if sum_di != 0.0 {
            Some(100.0 * ((plus_di - minus_di).abs() / sum_di))
        } else {
            None
        }
    }

    // Bounding by `out.len()` as well keeps this safe function sound even if
    // the output buffer is shorter than the inputs.
    let len = high.len().min(low.len()).min(close.len()).min(out.len());
    if period == 0 || first_valid_idx >= len {
        return;
    }

    // Equal-length sub-slices let the compiler hoist the bounds checks.
    let (high, low, close) = (&high[..len], &low[..len], &close[..len]);
    let out = &mut out[..len];

    let p_f64 = period as f64;
    let seed_count = period - 1;

    let mut prev_high = high[first_valid_idx];
    let mut prev_low = low[first_valid_idx];
    let mut prev_close = close[first_valid_idx];

    let mut plus_dm_sum = 0.0f64;
    let mut minus_dm_sum = 0.0f64;
    let mut tr_sum = 0.0f64;
    let mut initial_count: usize = 0;

    // No DX can exist for the very first valid bar. For `period == 1` this
    // slot lies outside the caller's NaN prefix, so write it explicitly.
    out[first_valid_idx] = f64::NAN;

    for i in (first_valid_idx + 1)..len {
        let (h, l, cl) = (high[i], low[i], close[i]);

        if h.is_nan() || l.is_nan() || cl.is_nan() {
            // Carry the previous output forward; the sums are not updated.
            out[i] = out[i - 1];
            prev_high = h;
            prev_low = l;
            prev_close = cl;
            continue;
        }

        let up_move = h - prev_high;
        let down_move = prev_low - l;
        let mut plus_dm = 0.0f64;
        let mut minus_dm = 0.0f64;
        if up_move > 0.0 && up_move > down_move {
            plus_dm = up_move;
        } else if down_move > 0.0 && down_move > up_move {
            minus_dm = down_move;
        }

        let tr = (h - l)
            .max((h - prev_close).abs())
            .max((l - prev_close).abs());

        out[i] = if initial_count < seed_count {
            plus_dm_sum += plus_dm;
            minus_dm_sum += minus_dm;
            tr_sum += tr;
            initial_count += 1;

            if initial_count == seed_count {
                dx_from_sums(plus_dm_sum, minus_dm_sum, tr_sum).unwrap_or(0.0)
            } else {
                // Still warming up. Written explicitly because NaN bars can
                // push the seed past the caller's NaN prefix.
                f64::NAN
            }
        } else {
            plus_dm_sum = plus_dm_sum - (plus_dm_sum / p_f64) + plus_dm;
            minus_dm_sum = minus_dm_sum - (minus_dm_sum / p_f64) + minus_dm;
            tr_sum = tr_sum - (tr_sum / p_f64) + tr;

            match dx_from_sums(plus_dm_sum, minus_dm_sum, tr_sum) {
                Some(value) => value,
                None => out[i - 1],
            }
        };

        prev_high = h;
        prev_low = l;
        prev_close = cl;
    }
}
393
/// AVX-512 entry point: periods up to 32 go to `dx_avx512_short`, longer
/// periods to `dx_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => dx_avx512_short(high, low, close, period, first_valid, out),
        _ => dx_avx512_long(high, low, close, period, first_valid, out),
    }
}
410
411#[inline]
412pub fn dx_avx2(
413 high: &[f64],
414 low: &[f64],
415 close: &[f64],
416 period: usize,
417 first_valid: usize,
418 out: &mut [f64],
419) {
420 dx_scalar(high, low, close, period, first_valid, out)
421}
422
/// AVX-512 variant for short periods. Currently forwards to [`dx_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first_valid;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
435
/// AVX-512 variant for long periods. Currently forwards to [`dx_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn dx_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first_valid;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
448
/// Incremental DX calculator fed one bar at a time.
#[derive(Debug, Clone)]
pub struct DxStream {
    /// Look-back period.
    period: usize,

    /// `period` as `f64`, cached for the smoothing step.
    p_f64: f64,
    /// The constant 100.0 used to scale the ratios.
    hundred: f64,

    /// Running +DM sum.
    plus_dm_sum: f64,
    /// Running -DM sum.
    minus_dm_sum: f64,
    /// Running true-range sum.
    tr_sum: f64,

    /// High of the previous bar (NaN until a bar has been stored).
    prev_high: f64,
    /// Low of the previous bar.
    prev_low: f64,
    /// Close of the previous bar.
    prev_close: f64,

    /// Number of bars accumulated towards the seed value.
    initial_count: usize,
    /// Set once the seed value has been produced.
    filled: bool,

    /// Most recently produced DX value.
    last_dx: f64,
}
469
470impl DxStream {
471 pub fn try_new(params: DxParams) -> Result<Self, DxError> {
472 let period = params.period.unwrap_or(14);
473 if period == 0 {
474 return Err(DxError::InvalidPeriod {
475 period,
476 data_len: 0,
477 });
478 }
479 Ok(Self {
480 period,
481 p_f64: period as f64,
482 hundred: 100.0,
483
484 plus_dm_sum: 0.0,
485 minus_dm_sum: 0.0,
486 tr_sum: 0.0,
487
488 prev_high: f64::NAN,
489 prev_low: f64::NAN,
490 prev_close: f64::NAN,
491
492 initial_count: 0,
493 filled: false,
494 last_dx: f64::NAN,
495 })
496 }
497
498 #[inline(always)]
499 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
500 if self.prev_high.is_nan() || self.prev_low.is_nan() || self.prev_close.is_nan() {
501 self.prev_high = high;
502 self.prev_low = low;
503 self.prev_close = close;
504 return None;
505 }
506
507 if high.is_nan() || low.is_nan() || close.is_nan() {
508 let carried = if self.filled { self.last_dx } else { f64::NAN };
509 self.prev_high = high;
510 self.prev_low = low;
511 self.prev_close = close;
512 return Some(carried);
513 }
514
515 let up_move = high - self.prev_high;
516 let down_move = self.prev_low - low;
517 let plus_dm = if up_move > 0.0 && up_move > down_move {
518 up_move
519 } else {
520 0.0
521 };
522 let minus_dm = if down_move > 0.0 && down_move > up_move {
523 down_move
524 } else {
525 0.0
526 };
527
528 let tr1 = high - low;
529 let tr2 = (high - self.prev_close).abs();
530 let tr3 = (low - self.prev_close).abs();
531 let tr = tr1.max(tr2).max(tr3);
532
533 let mut out: Option<f64> = None;
534
535 if self.initial_count < (self.period - 1) {
536 self.plus_dm_sum += plus_dm;
537 self.minus_dm_sum += minus_dm;
538 self.tr_sum += tr;
539 self.initial_count += 1;
540
541 if self.initial_count == (self.period - 1) {
542 let plus_di = (self.plus_dm_sum / self.tr_sum) * self.hundred;
543 let minus_di = (self.minus_dm_sum / self.tr_sum) * self.hundred;
544 let sum_di = plus_di + minus_di;
545
546 let dx = if sum_di != 0.0 {
547 self.hundred * ((plus_di - minus_di).abs() / sum_di)
548 } else {
549 0.0
550 };
551 self.filled = true;
552 self.last_dx = dx;
553 out = Some(dx);
554 }
555 } else {
556 self.plus_dm_sum = self.plus_dm_sum - (self.plus_dm_sum / self.p_f64) + plus_dm;
557 self.minus_dm_sum = self.minus_dm_sum - (self.minus_dm_sum / self.p_f64) + minus_dm;
558 self.tr_sum = self.tr_sum - (self.tr_sum / self.p_f64) + tr;
559
560 let plus_di = if self.tr_sum != 0.0 {
561 (self.plus_dm_sum / self.tr_sum) * self.hundred
562 } else {
563 0.0
564 };
565 let minus_di = if self.tr_sum != 0.0 {
566 (self.minus_dm_sum / self.tr_sum) * self.hundred
567 } else {
568 0.0
569 };
570 let sum_di = plus_di + minus_di;
571
572 let dx = if sum_di != 0.0 {
573 self.hundred * ((plus_di - minus_di).abs() / sum_di)
574 } else if self.filled {
575 self.last_dx
576 } else {
577 f64::NAN
578 };
579 self.last_dx = dx;
580 out = Some(dx);
581 }
582
583 self.prev_high = high;
584 self.prev_low = low;
585 self.prev_close = close;
586
587 out
588 }
589}
590
/// Period sweep for batch runs, stored as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct DxBatchRange {
    /// `(start, end, step)` of the period sweep.
    pub period: (usize, usize, usize),
}

impl Default for DxBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        Self::from_tuple((14, 263, 1))
    }
}

impl DxBatchRange {
    /// Wraps a `(start, end, step)` tuple.
    pub fn from_tuple(period: (usize, usize, usize)) -> Self {
        DxBatchRange { period }
    }
}
609
610#[derive(Clone, Debug, Default)]
611pub struct DxBatchBuilder {
612 range: DxBatchRange,
613 kernel: Kernel,
614}
615
616impl DxBatchBuilder {
617 pub fn new() -> Self {
618 Self::default()
619 }
620 pub fn kernel(mut self, k: Kernel) -> Self {
621 self.kernel = k;
622 self
623 }
624
625 #[inline]
626 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
627 self.range.period = (start, end, step);
628 self
629 }
630 #[inline]
631 pub fn period_static(mut self, p: usize) -> Self {
632 self.range.period = (p, p, 0);
633 self
634 }
635
636 pub fn apply_hlc(
637 self,
638 high: &[f64],
639 low: &[f64],
640 close: &[f64],
641 ) -> Result<DxBatchOutput, DxError> {
642 dx_batch_with_kernel(high, low, close, &self.range, self.kernel)
643 }
644
645 pub fn apply_candles(self, c: &Candles) -> Result<DxBatchOutput, DxError> {
646 let high = source_type(c, "high");
647 let low = source_type(c, "low");
648 let close = source_type(c, "close");
649 self.apply_hlc(high, low, close)
650 }
651}
652
653pub struct DxBatchOutput {
654 pub values: Vec<f64>,
655 pub combos: Vec<DxParams>,
656 pub rows: usize,
657 pub cols: usize,
658}
659
660impl DxBatchOutput {
661 pub fn row_for_params(&self, p: &DxParams) -> Option<usize> {
662 self.combos
663 .iter()
664 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
665 }
666 pub fn values_for(&self, p: &DxParams) -> Option<&[f64]> {
667 self.row_for_params(p).map(|row| {
668 let start = row * self.cols;
669 &self.values[start..start + self.cols]
670 })
671 }
672}
673
674#[inline(always)]
675fn expand_grid_checked(r: &DxBatchRange) -> Result<Vec<DxParams>, DxError> {
676 let (start, end, step) = r.period;
677
678 if step == 0 || start == end {
679 return Ok(vec![DxParams {
680 period: Some(start),
681 }]);
682 }
683
684 let mut out: Vec<usize> = Vec::new();
685 if start < end {
686 let mut v = start;
687 while v <= end {
688 out.push(v);
689 match v.checked_add(step) {
690 Some(next) if next != v => v = next,
691 _ => break,
692 }
693 }
694 } else {
695 let mut v = start;
696
697 loop {
698 out.push(v);
699 if v <= end {
700 break;
701 }
702 let dec = v.saturating_sub(step);
703 if dec == v {
704 break;
705 }
706 v = dec;
707 }
708
709 out.sort_unstable();
710 }
711 if out.is_empty() {
712 return Err(DxError::InvalidRange { start, end, step });
713 }
714 Ok(out
715 .into_iter()
716 .map(|p| DxParams { period: Some(p) })
717 .collect())
718}
719
720pub fn dx_batch_with_kernel(
721 high: &[f64],
722 low: &[f64],
723 close: &[f64],
724 sweep: &DxBatchRange,
725 k: Kernel,
726) -> Result<DxBatchOutput, DxError> {
727 let kernel = match k {
728 Kernel::Auto => Kernel::ScalarBatch,
729 other if other.is_batch() => other,
730 other => return Err(DxError::InvalidKernelForBatch(other)),
731 };
732
733 let simd = match kernel {
734 Kernel::Avx512Batch => Kernel::Avx512,
735 Kernel::Avx2Batch => Kernel::Avx2,
736 Kernel::ScalarBatch => Kernel::Scalar,
737 _ => unreachable!(),
738 };
739 dx_batch_par_slice(high, low, close, sweep, simd)
740}
741
742#[inline(always)]
743pub fn dx_batch_slice(
744 high: &[f64],
745 low: &[f64],
746 close: &[f64],
747 sweep: &DxBatchRange,
748 kern: Kernel,
749) -> Result<DxBatchOutput, DxError> {
750 dx_batch_inner(high, low, close, sweep, kern, false)
751}
752
753#[inline(always)]
754pub fn dx_batch_par_slice(
755 high: &[f64],
756 low: &[f64],
757 close: &[f64],
758 sweep: &DxBatchRange,
759 kern: Kernel,
760) -> Result<DxBatchOutput, DxError> {
761 dx_batch_inner(high, low, close, sweep, kern, true)
762}
763
764#[inline(always)]
765fn dx_batch_inner_into(
766 high: &[f64],
767 low: &[f64],
768 close: &[f64],
769 sweep: &DxBatchRange,
770 kern: Kernel,
771 parallel: bool,
772 out: &mut [f64],
773) -> Result<Vec<DxParams>, DxError> {
774 let combos_vec = expand_grid_checked(sweep)?;
775 let combos = combos_vec;
776
777 let len = high.len().min(low.len()).min(close.len());
778 let first = (0..len)
779 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
780 .ok_or(DxError::AllValuesNaN)?;
781 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
782 if len - first < max_p {
783 return Err(DxError::NotEnoughValidData {
784 needed: max_p,
785 valid: len - first,
786 });
787 }
788
789 let rows = combos.len();
790 let total = rows
791 .checked_mul(len)
792 .ok_or_else(|| DxError::InvalidInput("rows*cols overflow".into()))?;
793 if out.len() != total {
794 return Err(DxError::OutputLengthMismatch {
795 expected: total,
796 got: out.len(),
797 });
798 }
799
800 let actual = match kern {
801 Kernel::Auto => detect_best_batch_kernel(),
802 k => k,
803 };
804 let simd = match actual {
805 Kernel::ScalarBatch | Kernel::Scalar => Kernel::Scalar,
806 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
807 Kernel::Avx2Batch | Kernel::Avx2 => Kernel::Avx2,
808 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
809 Kernel::Avx512Batch | Kernel::Avx512 => Kernel::Avx512,
810 _ => unreachable!(),
811 };
812
813 let (plus_dm, minus_dm, tr, carry) = dx_precompute_terms(high, low, close, first, len);
814
815 let do_row = |row: usize, dst_row: &mut [f64]| unsafe {
816 let p = combos[row].period.unwrap();
817 dx_row_scalar_precomputed(&plus_dm, &minus_dm, &tr, &carry, first, p, dst_row);
818 };
819
820 if parallel {
821 #[cfg(not(target_arch = "wasm32"))]
822 out.par_chunks_mut(len)
823 .enumerate()
824 .for_each(|(r, s)| do_row(r, s));
825 #[cfg(target_arch = "wasm32")]
826 for (r, s) in out.chunks_mut(len).enumerate() {
827 do_row(r, s);
828 }
829 } else {
830 for (r, s) in out.chunks_mut(len).enumerate() {
831 do_row(r, s);
832 }
833 }
834 Ok(combos)
835}
836
837#[inline(always)]
838fn dx_precompute_terms(
839 high: &[f64],
840 low: &[f64],
841 close: &[f64],
842 first: usize,
843 len: usize,
844) -> (AVec<f64>, AVec<f64>, AVec<f64>, Vec<u8>) {
845 let mut plus_dm: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
846 let mut minus_dm: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
847 let mut tr: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, len);
848
849 let mut carry: Vec<u8> = vec![0; len];
850
851 for _ in 0..len {
852 plus_dm.push(0.0);
853 }
854 for _ in 0..len {
855 minus_dm.push(0.0);
856 }
857 for _ in 0..len {
858 tr.push(0.0);
859 }
860
861 if len == 0 || first + 1 >= len {
862 return (plus_dm, minus_dm, tr, carry);
863 }
864
865 for i in (first + 1)..len {
866 let h = high[i];
867 let l = low[i];
868 let c = close[i];
869 if h.is_nan() || l.is_nan() || c.is_nan() {
870 carry[i] = 1;
871 continue;
872 }
873
874 let up_move = h - high[i - 1];
875 let down_move = low[i - 1] - l;
876 let pdm = if up_move > 0.0 && up_move > down_move {
877 up_move
878 } else {
879 0.0
880 };
881 let mdm = if down_move > 0.0 && down_move > up_move {
882 down_move
883 } else {
884 0.0
885 };
886
887 let tr1 = h - l;
888 let tr2 = (h - close[i - 1]).abs();
889 let tr3 = (l - close[i - 1]).abs();
890 let t = tr1.max(tr2).max(tr3);
891
892 plus_dm[i] = pdm;
893 minus_dm[i] = mdm;
894 tr[i] = t;
895 }
896
897 (plus_dm, minus_dm, tr, carry)
898}
899
/// Fills one batch row from precomputed +DM / -DM / TR terms.
///
/// Indices before `first` are left untouched (the caller owns the warm-up
/// prefix); every index from `first` to the end of `out` is written:
///
/// * `out[first]` and every warm-up bar before the seed are NaN, so no slot of
///   a freshly allocated, uninitialised row stays unwritten when NaN bars
///   delay the seed or when `period == 1`;
/// * a bar flagged in `carry` repeats the previous output;
/// * the seed is produced after `period - 1` unflagged bars and is `0.0` when
///   `+DI + -DI` is zero, including a zero true-range sum (which previously
///   produced NaN through `0.0 / 0.0`);
/// * later bars use Wilder-style smoothing and repeat the previous output when
///   `+DI + -DI` is zero.
///
/// # Safety
/// `plus_dm`, `minus_dm`, `tr` and `carry` must each hold at least
/// `out.len()` elements, and `period` must be at least 1.
#[inline(always)]
unsafe fn dx_row_scalar_precomputed(
    plus_dm: &[f64],
    minus_dm: &[f64],
    tr: &[f64],
    carry: &[u8],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    /// DX from the running sums; `None` when it is undefined because the
    /// true-range sum or `+DI + -DI` is zero.
    #[inline(always)]
    fn dx_from_sums(plus_dm_sum: f64, minus_dm_sum: f64, tr_sum: f64) -> Option<f64> {
        if tr_sum == 0.0 {
            return None;
        }
        let plus_di = (plus_dm_sum / tr_sum) * 100.0;
        let minus_di = (minus_dm_sum / tr_sum) * 100.0;
        let sum_di = plus_di + minus_di;
        if sum_di != 0.0 {
            Some(100.0 * ((plus_di - minus_di).abs() / sum_di))
        } else {
            None
        }
    }

    let len = out.len();
    if first >= len {
        return;
    }
    debug_assert!(period >= 1, "dx: period must be at least 1");
    debug_assert!(
        plus_dm.len() >= len && minus_dm.len() >= len && tr.len() >= len && carry.len() >= len,
        "dx: precomputed arrays shorter than the output row"
    );

    let p_f64 = period as f64;
    let seed_count = period - 1;

    let mut plus_dm_sum = 0.0f64;
    let mut minus_dm_sum = 0.0f64;
    let mut tr_sum = 0.0f64;
    let mut initial_count: usize = 0;

    // No DX exists for the first valid bar; for `period == 1` this slot lies
    // outside the caller's NaN prefix.
    *out.get_unchecked_mut(first) = f64::NAN;

    for i in (first + 1)..len {
        if *carry.get_unchecked(i) != 0 {
            *out.get_unchecked_mut(i) = *out.get_unchecked(i - 1);
            continue;
        }

        let pdm = *plus_dm.get_unchecked(i);
        let mdm = *minus_dm.get_unchecked(i);
        let t = *tr.get_unchecked(i);

        let value = if initial_count < seed_count {
            plus_dm_sum += pdm;
            minus_dm_sum += mdm;
            tr_sum += t;
            initial_count += 1;

            if initial_count == seed_count {
                dx_from_sums(plus_dm_sum, minus_dm_sum, tr_sum).unwrap_or(0.0)
            } else {
                f64::NAN
            }
        } else {
            plus_dm_sum = plus_dm_sum - (plus_dm_sum / p_f64) + pdm;
            minus_dm_sum = minus_dm_sum - (minus_dm_sum / p_f64) + mdm;
            tr_sum = tr_sum - (tr_sum / p_f64) + t;

            match dx_from_sums(plus_dm_sum, minus_dm_sum, tr_sum) {
                Some(v) => v,
                None => *out.get_unchecked(i - 1),
            }
        };

        *out.get_unchecked_mut(i) = value;
    }
}
979
980fn dx_batch_inner(
981 high: &[f64],
982 low: &[f64],
983 close: &[f64],
984 sweep: &DxBatchRange,
985 kern: Kernel,
986 parallel: bool,
987) -> Result<DxBatchOutput, DxError> {
988 let combos = expand_grid_checked(sweep)?;
989 let rows = combos.len();
990 let cols = high.len().min(low.len()).min(close.len());
991 if cols == 0 {
992 return Err(DxError::EmptyInputData);
993 }
994 let _ = rows
995 .checked_mul(cols)
996 .ok_or_else(|| DxError::InvalidInput("rows*cols overflow".into()))?;
997
998 let first = (0..cols)
999 .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
1000 .ok_or(DxError::AllValuesNaN)?;
1001 let warm: Vec<usize> = combos
1002 .iter()
1003 .map(|c| first + c.period.unwrap() - 1)
1004 .collect();
1005
1006 let mut buf_mu = make_uninit_matrix(rows, cols);
1007 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1008
1009 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1010 let out_slice: &mut [f64] =
1011 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1012
1013 let _ = dx_batch_inner_into(high, low, close, sweep, kern, parallel, out_slice)?;
1014
1015 let values = unsafe {
1016 Vec::from_raw_parts(
1017 guard.as_mut_ptr() as *mut f64,
1018 guard.len(),
1019 guard.capacity(),
1020 )
1021 };
1022 Ok(DxBatchOutput {
1023 values,
1024 combos,
1025 rows,
1026 cols,
1027 })
1028}
1029
1030#[inline(always)]
1031unsafe fn dx_row_scalar(
1032 high: &[f64],
1033 low: &[f64],
1034 close: &[f64],
1035 first: usize,
1036 period: usize,
1037 out: &mut [f64],
1038) {
1039 dx_scalar(high, low, close, period, first, out)
1040}
1041
/// Row-level AVX2 kernel. Currently forwards to [`dx_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
1054
/// Row-level AVX-512 kernel: periods up to 32 go to `dx_row_avx512_short`,
/// longer periods to `dx_row_avx512_long`.
///
/// # Safety
/// The callees are `unsafe fn`s; the visible implementations only forward to
/// the safe [`dx_scalar`], so no additional precondition is evident here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn dx_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => dx_row_avx512_short(high, low, close, first, period, out),
        _ => dx_row_avx512_long(high, low, close, first, period, out),
    }
}
1071
/// Row-level AVX-512 kernel for short periods. Currently forwards to
/// [`dx_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
1084
/// Row-level AVX-512 kernel for long periods. Currently forwards to
/// [`dx_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dx_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let first_valid_idx = first;
    dx_scalar(high, low, close, period, first_valid_idx, out)
}
1097
1098#[inline(always)]
1099pub fn expand_grid_dx(r: &DxBatchRange) -> Vec<DxParams> {
1100 expand_grid_checked(r).unwrap_or_else(|_| vec![])
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105 use super::*;
1106 use crate::skip_if_unsupported;
1107 use crate::utilities::data_loader::read_candles_from_csv;
1108
1109 fn check_dx_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1110 skip_if_unsupported!(kernel, test_name);
1111 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1112 let candles = read_candles_from_csv(file_path)?;
1113
1114 let default_params = DxParams { period: None };
1115 let input = DxInput::from_candles(&candles, default_params);
1116 let output = dx_with_kernel(&input, kernel)?;
1117 assert_eq!(output.values.len(), candles.close.len());
1118 Ok(())
1119 }
1120
1121 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1122 #[test]
1123 fn test_dx_into_matches_api() -> Result<(), Box<dyn Error>> {
1124 let n = 512usize;
1125 let mut close = vec![0.0f64; n];
1126 for i in 0..n {
1127 let t = i as f64;
1128 close[i] = 100.0 + 0.1 * t + (t * 0.2).sin() * 2.0;
1129 }
1130 let mut high = vec![0.0f64; n];
1131 let mut low = vec![0.0f64; n];
1132 for i in 0..n {
1133 let t = i as f64;
1134
1135 high[i] = close[i] + 0.6 + 0.05 * (t * 0.3).sin();
1136 low[i] = close[i] - 0.6 - 0.05 * (t * 0.3).cos();
1137 if low[i] > high[i] {
1138 core::mem::swap(&mut low[i], &mut high[i]);
1139 }
1140 }
1141
1142 let params = DxParams { period: Some(14) };
1143 let input = DxInput::from_hlc_slices(&high, &low, &close, params);
1144
1145 let base = dx(&input)?.values;
1146
1147 let mut into_out = vec![0.0f64; n];
1148 dx_into(&input, &mut into_out)?;
1149
1150 #[inline]
1151 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1152 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1153 }
1154
1155 assert_eq!(base.len(), into_out.len());
1156 for i in 0..n {
1157 assert!(
1158 eq_or_both_nan(base[i], into_out[i]),
1159 "dx_into mismatch at {}: base={}, into={}",
1160 i,
1161 base[i],
1162 into_out[i]
1163 );
1164 }
1165 Ok(())
1166 }
1167
1168 fn check_dx_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1169 skip_if_unsupported!(kernel, test_name);
1170 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1171 let candles = read_candles_from_csv(file_path)?;
1172
1173 let input = DxInput::from_candles(&candles, DxParams::default());
1174 let result = dx_with_kernel(&input, kernel)?;
1175 let expected_last_five = [
1176 43.72121533411883,
1177 41.47251493226443,
1178 43.43041386436222,
1179 43.22673458811955,
1180 51.65514026197179,
1181 ];
1182 let start = result.values.len().saturating_sub(5);
1183 for (i, &val) in result.values[start..].iter().enumerate() {
1184 let diff = (val - expected_last_five[i]).abs();
1185 assert!(
1186 diff < 1e-4,
1187 "[{}] DX {:?} mismatch at idx {}: got {}, expected {}",
1188 test_name,
1189 kernel,
1190 i,
1191 val,
1192 expected_last_five[i]
1193 );
1194 }
1195 Ok(())
1196 }
1197
1198 fn check_dx_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1199 skip_if_unsupported!(kernel, test_name);
1200 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1201 let candles = read_candles_from_csv(file_path)?;
1202
1203 let input = DxInput::with_default_candles(&candles);
1204 match input.data {
1205 DxData::Candles { .. } => {}
1206 _ => panic!("Expected DxData::Candles"),
1207 }
1208 let output = dx_with_kernel(&input, kernel)?;
1209 assert_eq!(output.values.len(), candles.close.len());
1210 Ok(())
1211 }
1212
1213 fn check_dx_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1214 skip_if_unsupported!(kernel, test_name);
1215 let high = [2.0, 2.5, 3.0];
1216 let low = [1.0, 1.2, 2.1];
1217 let close = [1.5, 2.3, 2.2];
1218 let params = DxParams { period: Some(0) };
1219 let input = DxInput::from_hlc_slices(&high, &low, &close, params);
1220 let res = dx_with_kernel(&input, kernel);
1221 assert!(
1222 res.is_err(),
1223 "[{}] DX should fail with zero period",
1224 test_name
1225 );
1226 Ok(())
1227 }
1228
1229 fn check_dx_period_exceeds_length(
1230 test_name: &str,
1231 kernel: Kernel,
1232 ) -> Result<(), Box<dyn Error>> {
1233 skip_if_unsupported!(kernel, test_name);
1234 let high = [3.0, 4.0];
1235 let low = [2.0, 3.0];
1236 let close = [2.5, 3.5];
1237 let params = DxParams { period: Some(14) };
1238 let input = DxInput::from_hlc_slices(&high, &low, &close, params);
1239 let res = dx_with_kernel(&input, kernel);
1240 assert!(
1241 res.is_err(),
1242 "[{}] DX should fail with period exceeding length",
1243 test_name
1244 );
1245 Ok(())
1246 }
1247
1248 fn check_dx_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1249 skip_if_unsupported!(kernel, test_name);
1250 let high = [3.0];
1251 let low = [2.0];
1252 let close = [2.5];
1253 let params = DxParams { period: Some(14) };
1254 let input = DxInput::from_hlc_slices(&high, &low, &close, params);
1255 let res = dx_with_kernel(&input, kernel);
1256 assert!(
1257 res.is_err(),
1258 "[{}] DX should fail with insufficient data",
1259 test_name
1260 );
1261 Ok(())
1262 }
1263
1264 fn check_dx_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1265 skip_if_unsupported!(kernel, test_name);
1266 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1267 let candles = read_candles_from_csv(file_path)?;
1268
1269 let first_params = DxParams { period: Some(14) };
1270 let first_input = DxInput::from_candles(&candles, first_params);
1271 let first_result = dx_with_kernel(&first_input, kernel)?;
1272
1273 let second_params = DxParams { period: Some(14) };
1274 let second_input = DxInput::from_hlc_slices(
1275 &first_result.values,
1276 &first_result.values,
1277 &first_result.values,
1278 second_params,
1279 );
1280 let second_result = dx_with_kernel(&second_input, kernel)?;
1281 assert_eq!(second_result.values.len(), first_result.values.len());
1282 for i in 28..second_result.values.len() {
1283 assert!(
1284 !second_result.values[i].is_nan(),
1285 "[{}] Expected no NaN after index 28, found NaN at idx {}",
1286 test_name,
1287 i
1288 );
1289 }
1290 Ok(())
1291 }
1292
1293 fn check_dx_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1294 skip_if_unsupported!(kernel, test_name);
1295 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1296 let candles = read_candles_from_csv(file_path)?;
1297 let input = DxInput::from_candles(&candles, DxParams { period: Some(14) });
1298 let res = dx_with_kernel(&input, kernel)?;
1299 assert_eq!(res.values.len(), candles.close.len());
1300 if res.values.len() > 50 {
1301 for (i, &val) in res.values[50..].iter().enumerate() {
1302 assert!(
1303 !val.is_nan(),
1304 "[{}] Found unexpected NaN at out-index {}",
1305 test_name,
1306 50 + i
1307 );
1308 }
1309 }
1310 Ok(())
1311 }
1312
1313 fn check_dx_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1314 skip_if_unsupported!(kernel, test_name);
1315 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1316 let candles = read_candles_from_csv(file_path)?;
1317 let high = source_type(&candles, "high");
1318 let low = source_type(&candles, "low");
1319 let close = source_type(&candles, "close");
1320 let period = 14;
1321
1322 let input = DxInput::from_candles(
1323 &candles,
1324 DxParams {
1325 period: Some(period),
1326 },
1327 );
1328 let batch_output = dx_with_kernel(&input, kernel)?.values;
1329
1330 let mut stream = DxStream::try_new(DxParams {
1331 period: Some(period),
1332 })?;
1333 let mut stream_values = Vec::with_capacity(candles.close.len());
1334 for ((&h, &l), &c) in high.iter().zip(low).zip(close) {
1335 match stream.update(h, l, c) {
1336 Some(dx_val) => stream_values.push(dx_val),
1337 None => stream_values.push(f64::NAN),
1338 }
1339 }
1340
1341 assert_eq!(batch_output.len(), stream_values.len());
1342 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1343 if b.is_nan() && s.is_nan() {
1344 continue;
1345 }
1346 let diff = (b - s).abs();
1347 assert!(
1348 diff < 1e-9,
1349 "[{}] DX streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1350 test_name,
1351 i,
1352 b,
1353 s,
1354 diff
1355 );
1356 }
1357 Ok(())
1358 }
1359
/// Generates one `#[test]` wrapper per kernel for every listed check function.
///
/// For each `check_fn` this emits `check_fn_scalar_f64` and, when built with
/// the `nightly-avx` feature on x86_64, `check_fn_avx2_f64` and
/// `check_fn_avx512_f64`.
///
/// The `#[cfg]` gate is attached to each AVX function individually: an
/// attribute written in front of a `$( ... )*` repetition only applies to the
/// first item produced by the expansion.
///
/// Each wrapper returns the check's `Result`, so an `Err` (for example a
/// fixture that could not be read) fails the test instead of being discarded.
macro_rules! generate_all_dx_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    }
}
1383 #[cfg(debug_assertions)]
1384 fn check_dx_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1385 skip_if_unsupported!(kernel, test_name);
1386
1387 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1388 let candles = read_candles_from_csv(file_path)?;
1389
1390 let test_params = vec![
1391 DxParams::default(),
1392 DxParams { period: Some(2) },
1393 DxParams { period: Some(5) },
1394 DxParams { period: Some(7) },
1395 DxParams { period: Some(10) },
1396 DxParams { period: Some(14) },
1397 DxParams { period: Some(20) },
1398 DxParams { period: Some(30) },
1399 DxParams { period: Some(50) },
1400 DxParams { period: Some(100) },
1401 DxParams { period: Some(200) },
1402 ];
1403
1404 for (param_idx, params) in test_params.iter().enumerate() {
1405 let input = DxInput::from_candles(&candles, params.clone());
1406 let output = dx_with_kernel(&input, kernel)?;
1407
1408 for (i, &val) in output.values.iter().enumerate() {
1409 if val.is_nan() {
1410 continue;
1411 }
1412
1413 let bits = val.to_bits();
1414
1415 if bits == 0x11111111_11111111 {
1416 panic!(
1417 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1418 with params: period={} (param set {})",
1419 test_name,
1420 val,
1421 bits,
1422 i,
1423 params.period.unwrap_or(14),
1424 param_idx
1425 );
1426 }
1427
1428 if bits == 0x22222222_22222222 {
1429 panic!(
1430 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1431 with params: period={} (param set {})",
1432 test_name,
1433 val,
1434 bits,
1435 i,
1436 params.period.unwrap_or(14),
1437 param_idx
1438 );
1439 }
1440
1441 if bits == 0x33333333_33333333 {
1442 panic!(
1443 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1444 with params: period={} (param set {})",
1445 test_name,
1446 val,
1447 bits,
1448 i,
1449 params.period.unwrap_or(14),
1450 param_idx
1451 );
1452 }
1453 }
1454 }
1455
1456 Ok(())
1457 }
1458
1459 #[cfg(not(debug_assertions))]
1460 fn check_dx_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1461 Ok(())
1462 }
1463
/// Property-based checks of DX against deterministically generated series.
///
/// For each generated `(high, low, close, period)` case this verifies:
/// output bounds, NaN warmup prefix, absence of NaN after the warmup window,
/// agreement with the scalar kernel, and several coarse behavioural
/// expectations on synthetic trending and ranging data.
#[cfg(test)]
#[allow(clippy::float_cmp)]
fn check_dx_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Draw a period in 2..=50, then dependent parameters: base price, series
    // length (at least period + 20), volatility and a per-bar trend slope.
    let strat = (2usize..=50)
        .prop_flat_map(|period| {
            (
                100.0f64..5000.0f64,
                (period + 20)..400,
                0.001f64..0.05f64,
                -0.01f64..0.01f64,
                Just(period),
            )
        })
        .prop_map(|(base_price, data_len, volatility, trend, period)| {
            let mut high = Vec::with_capacity(data_len);
            let mut low = Vec::with_capacity(data_len);
            let mut close = Vec::with_capacity(data_len);

            // NOTE(review): this initial value is overwritten on the first
            // loop iteration before it is read.
            let mut price = base_price;
            for i in 0..data_len {
                let trend_component = trend * i as f64;
                // Deterministic pseudo-noise derived from the index, in [-0.5, 0.5).
                let random_component = ((i * 7 + 13) % 17) as f64 / 17.0 - 0.5;
                price =
                    base_price + trend_component + random_component * volatility * base_price;

                // Bar range around `price`; close sits inside [low, high].
                let daily_volatility = volatility * price;
                let h = price + daily_volatility * (0.5 + ((i * 3) % 7) as f64 / 14.0);
                let l = price - daily_volatility * (0.5 + ((i * 5) % 7) as f64 / 14.0);
                let c = l + (h - l) * (0.3 + ((i * 11) % 7) as f64 / 10.0);

                high.push(h);
                low.push(l);
                close.push(c);
            }

            (high, low, close, period)
        });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(high, low, close, period)| {
            let params = DxParams { period: Some(period) };
            let input = DxInput::from_hlc_slices(&high, &low, &close, params.clone());

            // Output of the kernel under test and of the scalar reference.
            let DxOutput { values: out } = dx_with_kernel(&input, kernel).unwrap();
            let DxOutput { values: ref_out } = dx_with_kernel(&input, Kernel::Scalar).unwrap();

            // Property: every non-NaN value lies in [0, 100] (1e-9 tolerance).
            for (i, &val) in out.iter().enumerate() {
                if !val.is_nan() {
                    prop_assert!(
                        val >= -1e-9 && val <= 100.0 + 1e-9,
                        "[{}] DX value {} at index {} is outside [0, 100] range",
                        test_name, val, i
                    );
                }
            }

            // Property: the first `period - 1` outputs are NaN.
            let warmup = period - 1;
            for i in 0..warmup {
                prop_assert!(
                    out[i].is_nan(),
                    "[{}] Expected NaN during warmup at index {}, got {}",
                    test_name, i, out[i]
                );
            }

            // Property: no NaN from `warmup + 10` onwards.
            if out.len() > warmup + 10 {
                for i in (warmup + 10)..out.len() {
                    prop_assert!(
                        !out[i].is_nan(),
                        "[{}] Unexpected NaN after warmup at index {}",
                        test_name, i
                    );
                }
            }

            // Property: the kernel under test matches the scalar kernel
            // within 1e-9 (positions where both are NaN are skipped).
            for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
                if val.is_nan() && ref_val.is_nan() {
                    continue;
                }

                let diff = (val - ref_val).abs();
                prop_assert!(
                    diff < 1e-9,
                    "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
                    test_name, i, val, ref_val, diff
                );
            }

            // Property: for (near-)constant inputs DX stays below 1.0 once stable.
            let all_same_high = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            let all_same_low = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            let all_same_close = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

            if all_same_high && all_same_low && all_same_close {

                if out.len() > warmup + 10 {
                    let stable_vals = &out[warmup + 10..];
                    for (i, &val) in stable_vals.iter().enumerate() {
                        if !val.is_nan() {
                            prop_assert!(
                                val < 1.0,
                                "[{}] With constant prices, expected DX < 1.0, got {} at index {}",
                                test_name, val, warmup + 10 + i
                            );
                        }
                    }
                }
            }

            // Property: when the mean close moved more than 5% between the two
            // halves of the series, at least one half averages DX above 20.
            if period <= 20 && out.len() > 100 {

                let mid = out.len() / 2;
                let first_half_avg_price = close[..mid].iter().sum::<f64>() / mid as f64;
                let second_half_avg_price = close[mid..].iter().sum::<f64>() / (out.len() - mid) as f64;
                let price_change = ((second_half_avg_price - first_half_avg_price) / first_half_avg_price).abs();


                if price_change > 0.05 {

                    let first_half_dx = &out[warmup..mid];
                    let second_half_dx = &out[mid..];

                    // NOTE(review): NaN values are dropped from the sum but the
                    // divisor is the full slice length, not the non-NaN count.
                    let first_avg = first_half_dx.iter()
                        .filter(|v| !v.is_nan())
                        .sum::<f64>() / first_half_dx.len() as f64;
                    let second_avg = second_half_dx.iter()
                        .filter(|v| !v.is_nan())
                        .sum::<f64>() / second_half_dx.len() as f64;


                    prop_assert!(
                        second_avg > 20.0 || first_avg > 20.0,
                        "[{}] Expected higher average DX in trending market. First half avg: {}, Second half avg: {}",
                        test_name, first_avg, second_avg
                    );
                }
            }

            // Property: a synthetic 50-bar series rising 1% of the base price
            // per bar yields an average stable DX above 50.
            if period <= 14 && out.len() > 50 {

                let trend_base = close[0];
                let perfect_trend = (0..50)
                    .map(|i| {
                        let price = trend_base + (i as f64 * trend_base * 0.01);
                        let h = price * 1.005;
                        let l = price * 0.995;
                        let c = price;
                        (h, l, c)
                    })
                    .collect::<Vec<_>>();

                let perfect_high: Vec<f64> = perfect_trend.iter().map(|&(h, _, _)| h).collect();
                let perfect_low: Vec<f64> = perfect_trend.iter().map(|&(_, l, _)| l).collect();
                let perfect_close: Vec<f64> = perfect_trend.iter().map(|&(_, _, c)| c).collect();

                let perfect_input = DxInput::from_hlc_slices(&perfect_high, &perfect_low, &perfect_close, params.clone());
                let DxOutput { values: perfect_out } = dx_with_kernel(&perfect_input, kernel).unwrap();


                if perfect_out.len() > warmup + 10 {
                    let stable_dx = &perfect_out[warmup + 10..];
                    // Same averaging convention as above: divisor is the slice length.
                    let avg_dx = stable_dx.iter()
                        .filter(|v| !v.is_nan())
                        .sum::<f64>() / stable_dx.len() as f64;

                    prop_assert!(
                        avg_dx > 50.0,
                        "[{}] Expected high DX (>50) in perfect trend, got avg {}",
                        test_name, avg_dx
                    );
                }
            }

            // Property: a synthetic 50-bar series alternating between +1% and
            // -1% of the base price (two bars each) averages DX below 65.
            if period <= 14 && out.len() > 50 {

                let range_base = close[0];
                let ranging_data = (0..50)
                    .map(|i| {

                        let price = if i % 4 < 2 {
                            range_base * 1.01
                        } else {
                            range_base * 0.99
                        };
                        let h = price * 1.002;
                        let l = price * 0.998;
                        let c = price;
                        (h, l, c)
                    })
                    .collect::<Vec<_>>();

                let ranging_high: Vec<f64> = ranging_data.iter().map(|&(h, _, _)| h).collect();
                let ranging_low: Vec<f64> = ranging_data.iter().map(|&(_, l, _)| l).collect();
                let ranging_close: Vec<f64> = ranging_data.iter().map(|&(_, _, c)| c).collect();

                let ranging_input = DxInput::from_hlc_slices(&ranging_high, &ranging_low, &ranging_close, params.clone());
                let DxOutput { values: ranging_out } = dx_with_kernel(&ranging_input, kernel).unwrap();


                if ranging_out.len() > warmup + 10 {
                    let stable_dx = &ranging_out[warmup + 10..];
                    let avg_dx = stable_dx.iter()
                        .filter(|v| !v.is_nan())
                        .sum::<f64>() / stable_dx.len() as f64;



                    prop_assert!(
                        avg_dx < 65.0,
                        "[{}] Expected moderate DX (<65) in ranging market, got avg {}",
                        test_name, avg_dx
                    );
                }
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
1702
// Instantiate the per-kernel `#[test]` wrappers for every single-series check.
generate_all_dx_tests!(
    check_dx_partial_params,
    check_dx_accuracy,
    check_dx_default_candles,
    check_dx_zero_period,
    check_dx_period_exceeds_length,
    check_dx_very_small_dataset,
    check_dx_reinput,
    check_dx_nan_handling,
    check_dx_streaming,
    check_dx_no_poison
);

// The property-based check is instantiated separately under `cfg(test)`.
#[cfg(test)]
generate_all_dx_tests!(check_dx_property);
1718
1719 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1720 skip_if_unsupported!(kernel, test);
1721
1722 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1723 let c = read_candles_from_csv(file)?;
1724
1725 let output = DxBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1726
1727 let def = DxParams::default();
1728 let row = output.values_for(&def).expect("default row missing");
1729
1730 assert_eq!(row.len(), c.close.len());
1731
1732 let expected = [
1733 43.72121533411883,
1734 41.47251493226443,
1735 43.43041386436222,
1736 43.22673458811955,
1737 51.65514026197179,
1738 ];
1739 let start = row.len() - 5;
1740 for (i, &v) in row[start..].iter().enumerate() {
1741 assert!(
1742 (v - expected[i]).abs() < 1e-4,
1743 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1744 );
1745 }
1746 Ok(())
1747 }
1748
/// Generates `#[test]` wrappers that run a batch check function against the
/// scalar batch kernel, the AVX2/AVX512 batch kernels (only with the
/// `nightly-avx` feature on x86_64) and `Kernel::Auto`.
///
/// Each wrapper returns the check's `Result`, so an `Err` coming out of the
/// check fails the test instead of being silently dropped.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }
            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
1769 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1770 skip_if_unsupported!(kernel, test);
1771
1772 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1773 let c = read_candles_from_csv(file)?;
1774
1775 let output = DxBatchBuilder::new()
1776 .kernel(kernel)
1777 .period_range(10, 30, 5)
1778 .apply_candles(&c)?;
1779
1780 let expected_combos = 5;
1781 assert_eq!(output.combos.len(), expected_combos);
1782 assert_eq!(output.rows, expected_combos);
1783 assert_eq!(output.cols, c.close.len());
1784
1785 Ok(())
1786 }
1787
1788 #[cfg(debug_assertions)]
1789 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1790 skip_if_unsupported!(kernel, test);
1791
1792 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1793 let c = read_candles_from_csv(file)?;
1794
1795 let test_configs = vec![
1796 (2, 10, 2),
1797 (5, 25, 5),
1798 (30, 60, 15),
1799 (2, 5, 1),
1800 (10, 20, 2),
1801 (14, 14, 0),
1802 (5, 50, 15),
1803 (100, 200, 50),
1804 ];
1805
1806 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1807 let output = DxBatchBuilder::new()
1808 .kernel(kernel)
1809 .period_range(p_start, p_end, p_step)
1810 .apply_candles(&c)?;
1811
1812 for (idx, &val) in output.values.iter().enumerate() {
1813 if val.is_nan() {
1814 continue;
1815 }
1816
1817 let bits = val.to_bits();
1818 let row = idx / output.cols;
1819 let col = idx % output.cols;
1820 let combo = &output.combos[row];
1821
1822 if bits == 0x11111111_11111111 {
1823 panic!(
1824 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1825 at row {} col {} (flat index {}) with params: period={}",
1826 test,
1827 cfg_idx,
1828 val,
1829 bits,
1830 row,
1831 col,
1832 idx,
1833 combo.period.unwrap_or(14)
1834 );
1835 }
1836
1837 if bits == 0x22222222_22222222 {
1838 panic!(
1839 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1840 at row {} col {} (flat index {}) with params: period={}",
1841 test,
1842 cfg_idx,
1843 val,
1844 bits,
1845 row,
1846 col,
1847 idx,
1848 combo.period.unwrap_or(14)
1849 );
1850 }
1851
1852 if bits == 0x33333333_33333333 {
1853 panic!(
1854 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1855 at row {} col {} (flat index {}) with params: period={}",
1856 test,
1857 cfg_idx,
1858 val,
1859 bits,
1860 row,
1861 col,
1862 idx,
1863 combo.period.unwrap_or(14)
1864 );
1865 }
1866 }
1867 }
1868
1869 Ok(())
1870 }
1871
1872 #[cfg(not(debug_assertions))]
1873 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1874 Ok(())
1875 }
1876
// Instantiate the per-kernel `#[test]` wrappers for each batch check.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_sweep);
gen_batch_tests!(check_batch_no_poison);
1880}
1881
/// Python entry point `dx(high, low, close, period, kernel=None)`.
///
/// Computes DX over the three input arrays and returns a new 1-D NumPy array.
/// The computation itself runs inside `allow_threads`.
///
/// # Errors
/// Returns an error if an input array cannot be viewed as a slice, if kernel
/// validation fails, or (as `ValueError`) if the DX computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "dx")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn dx_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<f64>,
    low: PyReadonlyArray1<f64>,
    close: PyReadonlyArray1<f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let inp = DxInput::from_hlc_slices(h, l, c, DxParams { period: Some(period) });
    match py.allow_threads(|| dx_with_kernel(&inp, kern)) {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1906
/// Python entry point `dx_batch(high, low, close, period_range, kernel=None)`.
///
/// Returns a dict with:
/// - `"values"`: the batch output reshaped to `(rows, cols)`, where `rows` is
///   the number of expanded parameter sets and `cols` the shortest input
///   length;
/// - `"periods"`: the period of each row as `u64`.
///
/// # Errors
/// Returns an error if an input cannot be viewed as a slice, the period range
/// cannot be expanded, kernel validation fails, the batch computation fails,
/// or the reshape fails.
#[cfg(feature = "python")]
#[pyfunction(name = "dx_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn dx_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<f64>,
    low: PyReadonlyArray1<f64>,
    close: PyReadonlyArray1<f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{PyArray1, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let sweep = DxBatchRange::from_tuple(period_range);
    let combos =
        expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let shape = (combos.len(), h.len().min(l.len()).min(c.len()));

    let kern = validate_kernel(kernel, true)?;
    let batch = py
        .allow_threads(|| dx_batch_with_kernel(h, l, c, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let flat = PyArray1::from_vec(py, batch.values);

    let dict = PyDict::new(py);
    dict.set_item("values", flat.reshape(shape)?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1946
1947#[cfg(all(feature = "python", feature = "cuda"))]
1948use crate::cuda::dx_wrapper::CudaDx;
1949#[cfg(all(feature = "python", feature = "cuda"))]
1950use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32 as DeviceArrayF32Cuda;
1951#[cfg(all(feature = "python", feature = "cuda"))]
1952use cust::context::Context as CudaContext;
1953#[cfg(all(feature = "python", feature = "cuda"))]
1954use std::sync::Arc;
1955
/// Python entry point `dx_cuda_batch_dev`: runs the DX period sweep through
/// `CudaDx` and returns the device array wrapper plus a dict whose
/// `"periods"` entry lists the period of each parameter set.
///
/// # Errors
/// Returns `ValueError("CUDA not available")` when `cuda_available()` is
/// false, and converts `CudaDx` construction/launch failures to `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dx_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn dx_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DxDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    let sweep = DxBatchRange::from_tuple(period_range);

    // The context handle and device id are captured before the launch and
    // returned alongside the result.
    let (inner, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let (arr, combos) = cuda
            .dx_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, combos, ctx, dev_id))
    })?;

    let dict = PyDict::new(py);
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    let device_array = DxDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((device_array, dict))
}
2001
/// Python entry point `dx_cuda_many_series_one_param_dev`: forwards the three
/// flat arrays, `cols`, `rows` and a single `period` to
/// `CudaDx::dx_many_series_one_param_time_major_dev` and wraps the resulting
/// device array.
///
/// # Errors
/// Returns `ValueError("CUDA not available")` when `cuda_available()` is
/// false, and converts `CudaDx` construction/launch failures to `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dx_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, period, device_id=0))]
pub fn dx_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DxDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let c = close_tm_f32.as_slice()?;

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDx::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .dx_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(DxDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
2036
/// Python-visible wrapper around a 2-D `f32` device array produced by the
/// CUDA DX kernels.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DxDeviceArrayF32Py {
    /// Device buffer together with its `rows` and `cols`.
    pub(crate) inner: DeviceArrayF32Cuda,
    /// Shared handle to the CUDA context; not read by this type's methods.
    /// NOTE(review): presumably held so the context outlives the buffer — confirm.
    pub(crate) _ctx: Arc<CudaContext>,
    /// Device id recorded when the array was created; used as a fallback in
    /// `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DxDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dict: `shape`, `typestr`
    /// (`"<f4"`), row-major byte `strides`, `data` as `(pointer, false)` and
    /// `version` 3. The pointer is reported as 0 when the array has no
    /// elements.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);

        d.set_item("shape", (inner.rows, inner.cols))?;

        d.set_item("typestr", "<f4")?;

        // Byte strides: one row is `cols` f32 values, one element is one f32.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // Avoid querying the device pointer of an empty array.
        let size = inner.rows.saturating_mul(inner.cols);
        let ptr = if size == 0 {
            0usize
        } else {
            inner.device_ptr() as usize
        };
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(device_type, device_ordinal)` with device type 2.
    ///
    /// The ordinal is queried from the driver for this buffer's pointer; if
    /// the query fails or yields a negative value, the stored `device_id` is
    /// returned instead.
    /// NOTE(review): 2 is presumably DLPack's `kDLCUDA` — confirm against the spec.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        unsafe {
            use cust::sys::cuPointerGetAttribute;
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            // -1 marks "not reported"; only non-negative results are accepted.
            let mut dev_ordinal: i32 = -1;
            let res = cuPointerGetAttribute(
                &mut dev_ordinal as *mut _ as *mut std::ffi::c_void,
                attr,
                self.inner.device_ptr(),
            );
            if res == cust::sys::CUresult::CUDA_SUCCESS && dev_ordinal >= 0 {
                return Ok((2, dev_ordinal));
            }
            Ok((2, self.device_id as i32))
        }
    }

    /// Exports the buffer through `export_f32_cuda_dlpack_2d`.
    ///
    /// The buffer is moved out of `self`, which is left holding an empty
    /// buffer with `rows == 0` and `cols == 0`. `stream` is accepted but
    /// ignored.
    ///
    /// # Errors
    /// Returns `ValueError` when `dl_device` is a `(type, id)` pair that does
    /// not match this array's device (the message differs depending on
    /// whether `copy` is true), or when the empty placeholder buffer cannot
    /// be created.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        // A `dl_device` that is not an (i32, i32) tuple is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap an empty buffer into `self` so the real one can be handed to
        // the exporter by value.
        let dummy = cust::memory::DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Cuda {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
2149
/// Python class `DxStream`: a thin wrapper over the Rust `DxStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DxStream")]
pub struct DxStreamPy {
    inner: DxStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DxStreamPy {
    /// Creates a stream for the given period.
    ///
    /// # Errors
    /// Returns `ValueError` when `DxStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        DxStream::try_new(DxParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }

    /// Forwards one bar to the underlying stream and returns its result
    /// (`None` when the stream produces no value for this bar).
    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        let DxStreamPy { inner } = self;
        inner.update(high, low, close)
    }
}
2173
2174#[inline]
2175pub fn dx_into_slice(dst: &mut [f64], input: &DxInput, kern: Kernel) -> Result<(), DxError> {
2176 let (h, l, c, len, first, chosen) = dx_prepare(input, kern)?;
2177 if dst.len() != len {
2178 return Err(DxError::OutputLengthMismatch {
2179 expected: len,
2180 got: dst.len(),
2181 });
2182 }
2183 unsafe {
2184 match chosen {
2185 Kernel::Scalar | Kernel::ScalarBatch => {
2186 dx_scalar(h, l, c, input.get_period(), first, dst)
2187 }
2188 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2189 Kernel::Avx2 | Kernel::Avx2Batch => dx_avx2(h, l, c, input.get_period(), first, dst),
2190 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2191 Kernel::Avx512 | Kernel::Avx512Batch => {
2192 dx_avx512(h, l, c, input.get_period(), first, dst)
2193 }
2194 _ => unreachable!(),
2195 }
2196 }
2197 let warm = first + input.get_period() - 1;
2198 for v in &mut dst[..warm] {
2199 *v = f64::NAN;
2200 }
2201 Ok(())
2202}
2203
/// WASM entry point: computes DX for the given period and returns a new
/// vector whose length is the shortest of the three inputs.
///
/// # Errors
/// Returns the DX error message as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_js(high: &[f64], low: &[f64], close: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let params = DxParams {
        period: Some(period),
    };
    let input = DxInput::from_hlc_slices(high, low, close, params);

    let out_len = high.len().min(low.len()).min(close.len());
    let mut out = vec![0.0; out_len];
    match dx_into_slice(&mut out, &input, detect_best_kernel()) {
        Ok(()) => Ok(out),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2220
/// WASM entry point: computes DX from raw input pointers into `out_ptr`.
///
/// All four pointers are treated as buffers of `len` `f64` values. When
/// `out_ptr` equals one of the input pointers the result is first computed
/// into a temporary vector and then copied to the output.
///
/// # Errors
/// Returns `"null pointer"` if any pointer is null, otherwise the DX error
/// message as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // NOTE(review): relies on the caller passing pointers that are valid for
    // `len` f64 values each — nothing here can verify that.
    unsafe {
        let inp = DxInput::from_hlc_slices(
            core::slice::from_raw_parts(h_ptr, len),
            core::slice::from_raw_parts(l_ptr, len),
            core::slice::from_raw_parts(c_ptr, len),
            DxParams {
                period: Some(period),
            },
        );

        // Only exact pointer equality is treated as aliasing.
        let out_const = out_ptr as *const f64;
        let aliases_input = [h_ptr, l_ptr, c_ptr].iter().any(|&p| p == out_const);

        if aliases_input {
            let mut scratch = vec![0.0; len];
            dx_into_slice(&mut scratch, &inp, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            core::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let direct = core::slice::from_raw_parts_mut(out_ptr, len);
            dx_into_slice(direct, &inp, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2272
/// WASM helper: allocates room for `len` `f64` values and returns the raw
/// pointer. The allocation is deliberately not dropped here; ownership passes
/// to the caller.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2281
/// WASM helper: releases a buffer of capacity `len` previously handed out as
/// a raw pointer. A null pointer is ignored.
///
/// The vector is rebuilt with length 0: `Vec::from_raw_parts` requires the
/// first `length` elements to be initialized, which cannot be assumed for
/// memory the caller may never have written. Only the capacity matters for
/// deallocation, and `f64` has no destructor to run.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY (caller contract): `ptr` must come from an allocation of exactly
    // `len` f64 values made by this module's allocator and must not be used
    // after this call.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2291
/// Configuration deserialized from the JS `config` argument of `dx_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DxBatchConfig {
    /// Period sweep tuple, passed unchanged to `DxBatchRange::from_tuple`.
    /// NOTE(review): presumably `(start, end, step)` — confirm against `DxBatchRange`.
    pub period_range: (usize, usize, usize),
}
2297
/// Result object serialized back to JS by `dx_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DxBatchJsOutput {
    /// Flat output buffer of `rows * cols` values.
    pub values: Vec<f64>,
    /// Parameter sets returned by the batch computation.
    pub combos: Vec<DxParams>,
    /// Number of expanded parameter sets.
    pub rows: usize,
    /// Shortest length among the high/low/close inputs.
    pub cols: usize,
}
2306
/// WASM entry point `dx_batch`: runs a DX period sweep and returns a
/// serialized [`DxBatchJsOutput`].
///
/// `config` is deserialized into [`DxBatchConfig`]. The output matrix has one
/// row per expanded parameter set and one column per input bar (the shortest
/// of the three inputs).
///
/// # Errors
/// Returns a `JsValue` string when the config cannot be deserialized, the
/// period range cannot be expanded, every bar contains a NaN
/// (`"AllValuesNaN"`), the batch computation fails, or the result cannot be
/// serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "dx_batch")]
pub fn dx_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: DxBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = DxBatchRange::from_tuple(cfg.period_range);

    // Expand the grid once; it determines both the row count and the
    // per-row warmup lengths.
    let grid = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = grid.len();
    let cols = high.len().min(low.len()).min(close.len());
    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Index of the first bar where high, low and close are all non-NaN.
    let first = (0..cols)
        .find(|&i| !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan())
        .ok_or_else(|| JsValue::from_str("AllValuesNaN"))?;
    let warm: Vec<usize> = grid
        .iter()
        .map(|p| first + p.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // The buffer is viewed as `f64` while the kernel fills it, so automatic
    // dropping is suspended until ownership is settled below.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    let out_slice: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
    let combos = match dx_batch_inner_into(
        high,
        low,
        close,
        &sweep,
        detect_best_batch_kernel(),
        false,
        out_slice,
    ) {
        Ok(combos) => combos,
        Err(e) => {
            // Release the buffer on the error path; otherwise the
            // `ManuallyDrop` wrapper would leak it.
            // SAFETY: `guard` is not accessed again after this point.
            unsafe { core::mem::ManuallyDrop::drop(&mut guard) };
            return Err(JsValue::from_str(&e.to_string()));
        }
    };

    // Take ownership of the filled buffer as a `Vec<f64>`.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    let js = DxBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2364
/// Pointer-based batch entry point: computes one output row per period in the
/// sweep `(period_start, period_end, period_step)` and writes the rows
/// back-to-back into `out_ptr`.
///
/// # Arguments
/// * `high_ptr`, `low_ptr`, `close_ptr` - input buffers, each read as `len` `f64`s.
/// * `out_ptr` - output buffer, written as `len * n_combos` `f64`s, where
///   `n_combos` is the number of entries produced by
///   `DxParams::generate_batch_params` for the sweep.
/// * `len` - number of samples in each input buffer.
///
/// # Errors
/// Returns a string `JsValue` when any pointer is null, when the output size
/// overflows `usize`, when no index has a non-NaN high/low/close
/// (`"All values are NaN"`, non-aliased path), when `dx_batch_with_kernel`
/// fails (aliased path), or when its result size differs from the expected
/// output size.
///
/// # Safety
/// Not declared `unsafe`, but the caller must guarantee that every pointer is
/// valid for the sizes listed above.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dx_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let combos = DxParams::generate_batch_params((period_start, period_end, period_step));
    // `len` and the sweep come straight from the caller; refuse sizes that
    // would wrap instead of building a too-short output slice.
    let total = combos
        .len()
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("Output size overflow"))?;

    // SAFETY: relies on the caller contract documented above (non-null was
    // checked; validity for `len` / `total` elements is the caller's duty).
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if high_ptr == out_ptr || low_ptr == out_ptr || close_ptr == out_ptr {
            // The output starts at an input buffer: compute into an owned
            // result first, then copy it over.
            let batch_range = DxBatchRange::from_tuple((period_start, period_end, period_step));
            let result = dx_batch_with_kernel(high, low, close, &batch_range, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // `copy_from_slice` panics on a length mismatch; report it instead.
            if result.values.len() != total {
                return Err(JsValue::from_str(&format!(
                    "Output size mismatch: expected {} values, got {}",
                    total,
                    result.values.len()
                )));
            }
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            out.copy_from_slice(&result.values);
        } else {
            // First index at which high, low and close are all non-NaN.
            let first = high
                .iter()
                .zip(low)
                .zip(close)
                .position(|((&h, &l), &c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
                .ok_or_else(|| JsValue::from_str("All values are NaN"))?;

            // Scratch matrix with the warm-up prefix of each row pre-filled.
            // It stays owned by `scratch`, so it is freed when this branch
            // ends (previously it was forgotten and leaked on every call).
            let mut scratch = make_uninit_matrix(combos.len(), len);
            let warmup_periods: Vec<usize> = combos
                .iter()
                .map(|p| first + p.period.unwrap() - 1)
                .collect();
            init_matrix_prefixes(&mut scratch, len, &warmup_periods);

            // `f64` view over the scratch allocation, valid while `scratch`
            // is alive. `len > 0` here because `first` was found above, so
            // `chunks_exact_mut` cannot be given a zero chunk size.
            let rows = std::slice::from_raw_parts_mut(scratch.as_mut_ptr() as *mut f64, total);
            for (row, param) in rows.chunks_exact_mut(len).zip(&combos) {
                dx_scalar(high, low, close, param.period.unwrap(), first, row);
            }

            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            out.copy_from_slice(rows);
        }
        Ok(())
    }
}