1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::error::Error;
26use std::mem::ManuallyDrop;
27use thiserror::Error;
28
29#[derive(Debug, Clone)]
30pub enum AdxData<'a> {
31 Candles {
32 candles: &'a Candles,
33 },
34 Slices {
35 high: &'a [f64],
36 low: &'a [f64],
37 close: &'a [f64],
38 },
39}
40
/// Result of a single ADX computation.
#[derive(Debug, Clone)]
pub struct AdxOutput {
    /// One value per input bar; bars inside the warm-up window hold `NaN`.
    pub values: Vec<f64>,
}
45
/// Tunable ADX parameters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AdxParams {
    /// Smoothing period; `None` is treated as 14 by `AdxInput::get_period`.
    pub period: Option<usize>,
}

impl Default for AdxParams {
    /// Uses the conventional 14-bar period.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 14;
        AdxParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
60
61#[derive(Debug, Clone)]
62pub struct AdxInput<'a> {
63 pub data: AdxData<'a>,
64 pub params: AdxParams,
65}
66
67impl<'a> AdxInput<'a> {
68 #[inline]
69 pub fn from_candles(c: &'a Candles, p: AdxParams) -> Self {
70 Self {
71 data: AdxData::Candles { candles: c },
72 params: p,
73 }
74 }
75 #[inline]
76 pub fn from_slices(h: &'a [f64], l: &'a [f64], c: &'a [f64], p: AdxParams) -> Self {
77 Self {
78 data: AdxData::Slices {
79 high: h,
80 low: l,
81 close: c,
82 },
83 params: p,
84 }
85 }
86 #[inline]
87 pub fn with_default_candles(c: &'a Candles) -> Self {
88 Self::from_candles(c, AdxParams::default())
89 }
90 #[inline]
91 pub fn get_period(&self) -> usize {
92 self.params.period.unwrap_or(14)
93 }
94}
95
96#[derive(Copy, Clone, Debug)]
97pub struct AdxBuilder {
98 period: Option<usize>,
99 kernel: Kernel,
100}
101
102impl Default for AdxBuilder {
103 fn default() -> Self {
104 Self {
105 period: None,
106 kernel: Kernel::Auto,
107 }
108 }
109}
110
111impl AdxBuilder {
112 #[inline(always)]
113 pub fn new() -> Self {
114 Self::default()
115 }
116 #[inline(always)]
117 pub fn period(mut self, n: usize) -> Self {
118 self.period = Some(n);
119 self
120 }
121 #[inline(always)]
122 pub fn kernel(mut self, k: Kernel) -> Self {
123 self.kernel = k;
124 self
125 }
126 #[inline(always)]
127 pub fn apply(self, candles: &Candles) -> Result<AdxOutput, AdxError> {
128 let p = AdxParams {
129 period: self.period,
130 };
131 let i = AdxInput::from_candles(candles, p);
132 adx_with_kernel(&i, self.kernel)
133 }
134 #[inline(always)]
135 pub fn apply_slices(
136 self,
137 high: &[f64],
138 low: &[f64],
139 close: &[f64],
140 ) -> Result<AdxOutput, AdxError> {
141 let p = AdxParams {
142 period: self.period,
143 };
144 let i = AdxInput::from_slices(high, low, close, p);
145 adx_with_kernel(&i, self.kernel)
146 }
147 #[inline(always)]
148 pub fn into_stream(self) -> Result<AdxStream, AdxError> {
149 let p = AdxParams {
150 period: self.period,
151 };
152 AdxStream::try_new(p)
153 }
154}
155#[derive(Debug, thiserror::Error)]
156pub enum AdxError {
157 #[error("adx: All values are NaN.")]
158 AllValuesNaN,
159
160 #[error("adx: Invalid period: period = {period}, data_len = {data_len}")]
161 InvalidPeriod { period: usize, data_len: usize },
162
163 #[error("adx: Not enough valid data: needed = {needed}, valid = {valid}")]
164 NotEnoughValidData { needed: usize, valid: usize },
165
166 #[error("adx: Candle field error: {field}")]
167 CandleFieldError { field: &'static str },
168
169 #[error("adx: Input arrays must have the same length")]
170 InconsistentLengths,
171
172 #[error("adx: Input data slice is empty.")]
173 EmptyInputData,
174
175 #[error("adx: Output length mismatch: expected = {expected}, got = {got}")]
176 OutputLengthMismatch { expected: usize, got: usize },
177
178 #[error("adx: Invalid range: start = {start}, end = {end}, step = {step}")]
179 InvalidRange {
180 start: usize,
181 end: usize,
182 step: usize,
183 },
184
185 #[error("adx: Invalid kernel for batch: {0:?}")]
186 InvalidKernelForBatch(Kernel),
187}
188
/// Index from which high, low and close have all produced their first non-`NaN` value.
///
/// Computed per series as the number of leading `NaN`s (the series length if it is all
/// `NaN`), then the maximum of the three is taken.
#[inline(always)]
fn first_valid_triple(high: &[f64], low: &[f64], close: &[f64]) -> usize {
    fn leading_nans(series: &[f64]) -> usize {
        series.iter().take_while(|v| v.is_nan()).count()
    }
    leading_nans(high)
        .max(leading_nans(low))
        .max(leading_nans(close))
}
199
200#[inline]
201pub fn adx(input: &AdxInput) -> Result<AdxOutput, AdxError> {
202 adx_with_kernel(input, Kernel::Auto)
203}
204
205#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
206#[inline]
207pub fn adx_into(input: &AdxInput, out: &mut [f64]) -> Result<(), AdxError> {
208 adx_into_slice(out, input, Kernel::Auto)
209}
210
211pub fn adx_with_kernel(input: &AdxInput, kernel: Kernel) -> Result<AdxOutput, AdxError> {
212 let (high, low, close) = match &input.data {
213 AdxData::Candles { candles } => {
214 let h = candles
215 .select_candle_field("high")
216 .map_err(|_| AdxError::CandleFieldError { field: "high" })?;
217 let l = candles
218 .select_candle_field("low")
219 .map_err(|_| AdxError::CandleFieldError { field: "low" })?;
220 let c = candles
221 .select_candle_field("close")
222 .map_err(|_| AdxError::CandleFieldError { field: "close" })?;
223 (h, l, c)
224 }
225 AdxData::Slices { high, low, close } => (*high, *low, *close),
226 };
227
228 if high.len() != low.len() || high.len() != close.len() {
229 return Err(AdxError::InconsistentLengths);
230 }
231 let len = close.len();
232 if len == 0 {
233 return Err(AdxError::EmptyInputData);
234 }
235
236 let period = input.get_period();
237 if period == 0 || period > len {
238 return Err(AdxError::InvalidPeriod {
239 period,
240 data_len: len,
241 });
242 }
243
244 if high.iter().all(|x| x.is_nan())
245 || low.iter().all(|x| x.is_nan())
246 || close.iter().all(|x| x.is_nan())
247 {
248 return Err(AdxError::AllValuesNaN);
249 }
250
251 let first = first_valid_triple(high, low, close);
252 if len - first < period + 1 {
253 return Err(AdxError::NotEnoughValidData {
254 needed: period + 1,
255 valid: len - first,
256 });
257 }
258
259 let warm_end = first + (2 * period - 1);
260 let mut out = alloc_with_nan_prefix(len, warm_end);
261
262 let chosen = match kernel {
263 Kernel::Auto => detect_best_kernel(),
264 k => k,
265 };
266 unsafe {
267 match chosen {
268 Kernel::Scalar | Kernel::ScalarBatch => adx_scalar(
269 &high[first..],
270 &low[first..],
271 &close[first..],
272 period,
273 &mut out[first..],
274 ),
275 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
276 Kernel::Avx2 | Kernel::Avx2Batch => adx_avx2(
277 &high[first..],
278 &low[first..],
279 &close[first..],
280 period,
281 &mut out[first..],
282 ),
283 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284 Kernel::Avx512 | Kernel::Avx512Batch => adx_avx512(
285 &high[first..],
286 &low[first..],
287 &close[first..],
288 period,
289 &mut out[first..],
290 ),
291 _ => unreachable!(),
292 }
293 }
294
295 Ok(AdxOutput { values: out })
296}
297
/// Portable ADX kernel using Wilder smoothing.
///
/// `high`, `low` and `close` should start at the first valid bar. `out` must be at least
/// as long as `close`. The kernel writes `out[2 * period - 1..close.len()]` (index 1
/// onward when `period == 1`) and leaves earlier entries untouched so the caller can
/// pre-fill them. Nothing is written when `period == 0` or `close.len() <= period`.
///
/// Algorithm:
/// - TR, +DM and -DM are summed over bars `1..=period` to seed the smoothed values.
/// - Each later bar updates them as `S = S * (1 - 1/period) + x`.
/// - DX is computed from the resulting +DI/-DI.
/// - The first ADX is the mean of the first `period` DX values. Afterwards
///   `ADX = (ADX_prev * (period - 1) + DX) / period`.
///
/// # Panics
/// Panics if `high`, `low` or `out` is shorter than `close`.
#[inline]
pub fn adx_scalar(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
    /// DX from smoothed +DM/-DM and ATR; zero denominators yield 0.
    #[inline(always)]
    fn dx_of(plus_dm: f64, minus_dm: f64, atr: f64) -> f64 {
        let (plus_di, minus_di) = if atr != 0.0 {
            ((plus_dm / atr) * 100.0, (minus_dm / atr) * 100.0)
        } else {
            (0.0, 0.0)
        };
        let sum_di = plus_di + minus_di;
        if sum_di != 0.0 {
            ((plus_di - minus_di).abs() / sum_di) * 100.0
        } else {
            0.0
        }
    }

    let len = close.len();
    // `adx_with_kernel` rejects period 0, but this kernel is public and would otherwise
    // divide by zero and emit garbage.
    if period == 0 || len <= period {
        return;
    }

    let period_f64 = period as f64;
    let reciprocal_period = 1.0 / period_f64;
    let one_minus_rp = 1.0 - reciprocal_period;
    let period_minus_one = period_f64 - 1.0;

    // Seed window: plain sums over bars 1..=period.
    let mut tr_sum = 0.0f64;
    let mut plus_dm_sum = 0.0f64;
    let mut minus_dm_sum = 0.0f64;

    let mut prev_h = high[0];
    let mut prev_l = low[0];
    let mut prev_c = close[0];

    for i in 1..=period {
        let ch = high[i];
        let cl = low[i];

        let hl = ch - cl;
        let hpc = (ch - prev_c).abs();
        let lpc = (cl - prev_c).abs();
        tr_sum += hl.max(hpc).max(lpc);

        let up = ch - prev_h;
        let down = prev_l - cl;
        if up > down && up > 0.0 {
            plus_dm_sum += up;
        }
        if down > up && down > 0.0 {
            minus_dm_sum += down;
        }

        prev_h = ch;
        prev_l = cl;
        prev_c = close[i];
    }
    // prev_h/prev_l/prev_c now hold bar `period`.

    let mut atr = tr_sum;
    let mut plus_dm_smooth = plus_dm_sum;
    let mut minus_dm_smooth = minus_dm_sum;

    // The DX of the seed window is the first of the `period` DX values averaged into
    // the first ADX.
    let mut dx_sum = dx_of(plus_dm_smooth, minus_dm_smooth, atr);
    let mut dx_count = 1usize;
    let mut last_adx = 0.0f64;

    // With period == 1 the seed DX alone completes the first ADX window, and its slot
    // (index `period` == `2 * period - 1`) is outside the loop below, which starts at
    // `period + 1`.
    if dx_count == period {
        last_adx = dx_sum * reciprocal_period;
        out[period] = last_adx;
    }

    for i in (period + 1)..len {
        let ch = high[i];
        let cl = low[i];

        let hl = ch - cl;
        let hpc = (ch - prev_c).abs();
        let lpc = (cl - prev_c).abs();
        let tr = hl.max(hpc).max(lpc);

        let up = ch - prev_h;
        let down = prev_l - cl;
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        // Wilder smoothing of the running sums.
        atr = atr * one_minus_rp + tr;
        plus_dm_smooth = plus_dm_smooth * one_minus_rp + plus_dm;
        minus_dm_smooth = minus_dm_smooth * one_minus_rp + minus_dm;

        let dx = dx_of(plus_dm_smooth, minus_dm_smooth, atr);

        if dx_count < period {
            // Still collecting the DX values for the first (simple-average) ADX.
            dx_sum += dx;
            dx_count += 1;
            if dx_count == period {
                last_adx = dx_sum * reciprocal_period;
                out[i] = last_adx;
            }
        } else {
            last_adx = (last_adx * period_minus_one + dx) * reciprocal_period;
            out[i] = last_adx;
        }

        prev_h = ch;
        prev_l = cl;
        prev_c = close[i];
    }
}
421
/// AVX-assisted ADX kernel.
///
/// Mirrors `adx_scalar`. Only the seed accumulation of TR/+DM/-DM over bars
/// `1..=period` is vectorised, four bars per step. The recursive Wilder smoothing after
/// it is sequential and runs scalar. Results can differ from the scalar kernel only
/// where `NaN`s reach the true-range max, because `_mm256_max_pd` and the conditional
/// max used here do not ignore `NaN` the way `f64::max` does.
///
/// Writes `out[2 * period - 1..close.len()]` (index 1 onward when `period == 1`) and
/// leaves earlier entries untouched. Nothing is written when `period == 0` or
/// `close.len() <= period`.
///
/// # Panics
/// Panics if `high`, `low` or `out` is shorter than `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adx_avx2(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let len = close.len();
    if period == 0 || len <= period {
        return;
    }
    // Everything below indexes through raw pointers / unchecked writes for indices
    // `< len`; this one check keeps this safe fn sound for mismatched slices.
    assert!(
        high.len() >= len && low.len() >= len && out.len() >= len,
        "adx_avx2: high, low and out must be at least as long as close"
    );

    let period_f64 = period as f64;
    let reciprocal_period = 1.0 / period_f64;
    let one_minus_rp = 1.0 - reciprocal_period;
    let period_minus_one = period_f64 - 1.0;

    // SAFETY: every pointer offset and unchecked index below is `< len`, and the assert
    // above guarantees all four slices hold at least `len` elements.
    // NOTE(review): the `_mm256_*` intrinsics additionally require a CPU with AVX.
    unsafe {
        let hp = high.as_ptr();
        let lp = low.as_ptr();
        let cp = close.as_ptr();

        let mut tr_sum = 0.0f64;
        let mut plus_dm_sum = 0.0f64;
        let mut minus_dm_sum = 0.0f64;

        let mut prev_h_scalar = *hp.add(0);
        let mut prev_l_scalar = *lp.add(0);
        let mut prev_c_scalar = *cp.add(0);

        let zero = _mm256_setzero_pd();
        let sign_mask = _mm256_set1_pd(-0.0f64);

        // Seed window, four bars per step: lane k compares bar i+k with bar i+k-1.
        let mut i = 1usize;
        while i + 3 <= period {
            let ch = _mm256_loadu_pd(hp.add(i));
            let cl = _mm256_loadu_pd(lp.add(i));
            let pch = _mm256_loadu_pd(hp.add(i - 1));
            let pcl = _mm256_loadu_pd(lp.add(i - 1));
            let pcc = _mm256_loadu_pd(cp.add(i - 1));

            let hl = _mm256_sub_pd(ch, cl);
            // andnot with -0.0 clears the sign bit, i.e. abs().
            let hpc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(ch, pcc));
            let lpc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(cl, pcc));
            let t0 = _mm256_max_pd(hl, hpc);
            let trv = _mm256_max_pd(t0, lpc);

            let up = _mm256_sub_pd(ch, pch);
            let down = _mm256_sub_pd(pcl, cl);
            let m_up_gt_down = _mm256_cmp_pd(up, down, _CMP_GT_OQ);
            let m_up_gt_zero = _mm256_cmp_pd(up, zero, _CMP_GT_OQ);
            let m_dn_gt_up = _mm256_cmp_pd(down, up, _CMP_GT_OQ);
            let m_dn_gt_zero = _mm256_cmp_pd(down, zero, _CMP_GT_OQ);
            let plus_mask = _mm256_and_pd(m_up_gt_down, m_up_gt_zero);
            let minus_mask = _mm256_and_pd(m_dn_gt_up, m_dn_gt_zero);
            let plus_v = _mm256_and_pd(plus_mask, up);
            let minus_v = _mm256_and_pd(minus_mask, down);

            let mut buf_tr = [0.0f64; 4];
            let mut buf_p = [0.0f64; 4];
            let mut buf_m = [0.0f64; 4];
            _mm256_storeu_pd(buf_tr.as_mut_ptr(), trv);
            _mm256_storeu_pd(buf_p.as_mut_ptr(), plus_v);
            _mm256_storeu_pd(buf_m.as_mut_ptr(), minus_v);

            // Accumulate lane by lane, in order, so the sums match sequential addition.
            for ((t, p), m) in buf_tr.iter().zip(&buf_p).zip(&buf_m) {
                tr_sum += *t;
                plus_dm_sum += *p;
                minus_dm_sum += *m;
            }

            prev_h_scalar = *hp.add(i + 3);
            prev_l_scalar = *lp.add(i + 3);
            prev_c_scalar = *cp.add(i + 3);

            i += 4;
        }
        // Scalar remainder of the seed window.
        while i <= period {
            let ch = *hp.add(i);
            let cl = *lp.add(i);
            let hl = ch - cl;
            let hpc = (ch - prev_c_scalar).abs();
            let lpc = (cl - prev_c_scalar).abs();
            let t0 = if hl > hpc { hl } else { hpc };
            let tr = if t0 > lpc { t0 } else { lpc };
            let up = ch - prev_h_scalar;
            let down = prev_l_scalar - cl;
            if up > down && up > 0.0 {
                plus_dm_sum += up;
            }
            if down > up && down > 0.0 {
                minus_dm_sum += down;
            }
            tr_sum += tr;
            prev_h_scalar = ch;
            prev_l_scalar = cl;
            prev_c_scalar = *cp.add(i);
            i += 1;
        }

        let mut atr = tr_sum;
        let mut plus_dm_smooth = plus_dm_sum;
        let mut minus_dm_smooth = minus_dm_sum;

        let (plus_di_prev, minus_di_prev) = if atr != 0.0 {
            (
                (plus_dm_smooth / atr) * 100.0,
                (minus_dm_smooth / atr) * 100.0,
            )
        } else {
            (0.0, 0.0)
        };
        let sum_di_prev = plus_di_prev + minus_di_prev;
        let mut dx_sum = if sum_di_prev != 0.0 {
            ((plus_di_prev - minus_di_prev).abs() / sum_di_prev) * 100.0
        } else {
            0.0
        };
        let mut dx_count = 1usize;
        let mut last_adx = 0.0f64;

        // With period == 1 the seed DX alone completes the first ADX window; its slot
        // (index `period`) is not visited by the loop below.
        if dx_count == period {
            last_adx = dx_sum * reciprocal_period;
            *out.get_unchecked_mut(period) = last_adx;
        }

        let mut prev_h = *hp.add(period);
        let mut prev_l = *lp.add(period);
        let mut prev_c = *cp.add(period);

        let mut i = period + 1;
        while i < len {
            let ch = *hp.add(i);
            let cl = *lp.add(i);

            let hl = ch - cl;
            let hpc = (ch - prev_c).abs();
            let lpc = (cl - prev_c).abs();
            let t0 = if hl > hpc { hl } else { hpc };
            let tr = if t0 > lpc { t0 } else { lpc };

            let up = ch - prev_h;
            let down = prev_l - cl;
            let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
            let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

            atr = atr * one_minus_rp + tr;
            plus_dm_smooth = plus_dm_smooth * one_minus_rp + plus_dm;
            minus_dm_smooth = minus_dm_smooth * one_minus_rp + minus_dm;

            let (plus_di, minus_di) = if atr != 0.0 {
                (
                    (plus_dm_smooth / atr) * 100.0,
                    (minus_dm_smooth / atr) * 100.0,
                )
            } else {
                (0.0, 0.0)
            };
            let sum_di = plus_di + minus_di;
            let dx = if sum_di != 0.0 {
                ((plus_di - minus_di).abs() / sum_di) * 100.0
            } else {
                0.0
            };

            if dx_count < period {
                dx_sum += dx;
                dx_count += 1;
                if dx_count == period {
                    last_adx = dx_sum * reciprocal_period;
                    *out.get_unchecked_mut(i) = last_adx;
                }
            } else {
                last_adx = (last_adx * period_minus_one + dx) * reciprocal_period;
                *out.get_unchecked_mut(i) = last_adx;
            }

            prev_h = ch;
            prev_l = cl;
            prev_c = *cp.add(i);
            i += 1;
        }
    }
}
605
/// AVX-512-assisted ADX kernel.
///
/// Mirrors `adx_scalar`. Only the seed accumulation of TR/+DM/-DM over bars
/// `1..=period` is vectorised, eight bars per step. The recursive Wilder smoothing after
/// it runs scalar. Results can differ from the scalar kernel only where `NaN`s reach the
/// true-range max.
///
/// Writes `out[2 * period - 1..close.len()]` (index 1 onward when `period == 1`) and
/// leaves earlier entries untouched. Nothing is written when `period == 0` or
/// `close.len() <= period`.
///
/// # Panics
/// Panics if `high`, `low` or `out` is shorter than `close`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adx_avx512(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let len = close.len();
    if period == 0 || len <= period {
        return;
    }
    // Everything below indexes through raw pointers / unchecked writes for indices
    // `< len`; this one check keeps this safe fn sound for mismatched slices.
    assert!(
        high.len() >= len && low.len() >= len && out.len() >= len,
        "adx_avx512: high, low and out must be at least as long as close"
    );

    let period_f64 = period as f64;
    let reciprocal_period = 1.0 / period_f64;
    let one_minus_rp = 1.0 - reciprocal_period;
    let period_minus_one = period_f64 - 1.0;

    // SAFETY: every pointer offset and unchecked index below is `< len`, and the assert
    // above guarantees all four slices hold at least `len` elements.
    // NOTE(review): the `_mm512_*` intrinsics additionally require matching AVX-512 CPU
    // support.
    unsafe {
        let hp = high.as_ptr();
        let lp = low.as_ptr();
        let cp = close.as_ptr();

        let mut tr_sum = 0.0f64;
        let mut plus_dm_sum = 0.0f64;
        let mut minus_dm_sum = 0.0f64;

        let mut prev_h_scalar = *hp.add(0);
        let mut prev_l_scalar = *lp.add(0);
        let mut prev_c_scalar = *cp.add(0);

        let zero = _mm512_setzero_pd();
        let sign_mask = _mm512_set1_pd(-0.0f64);

        // Seed window, eight bars per step: lane k compares bar i+k with bar i+k-1.
        let mut i = 1usize;
        while i + 7 <= period {
            let ch = _mm512_loadu_pd(hp.add(i));
            let cl = _mm512_loadu_pd(lp.add(i));
            let pch = _mm512_loadu_pd(hp.add(i - 1));
            let pcl = _mm512_loadu_pd(lp.add(i - 1));
            let pcc = _mm512_loadu_pd(cp.add(i - 1));

            let hl = _mm512_sub_pd(ch, cl);
            // andnot with -0.0 clears the sign bit, i.e. abs().
            let hpc = _mm512_andnot_pd(sign_mask, _mm512_sub_pd(ch, pcc));
            let lpc = _mm512_andnot_pd(sign_mask, _mm512_sub_pd(cl, pcc));
            let t0 = _mm512_max_pd(hl, hpc);
            let trv = _mm512_max_pd(t0, lpc);

            let up = _mm512_sub_pd(ch, pch);
            let down = _mm512_sub_pd(pcl, cl);
            let m_up_gt_down = _mm512_cmp_pd_mask(up, down, _CMP_GT_OQ);
            let m_up_gt_zero = _mm512_cmp_pd_mask(up, zero, _CMP_GT_OQ);
            let m_dn_gt_up = _mm512_cmp_pd_mask(down, up, _CMP_GT_OQ);
            let m_dn_gt_zero = _mm512_cmp_pd_mask(down, zero, _CMP_GT_OQ);
            let m_plus = m_up_gt_down & m_up_gt_zero;
            let m_minus = m_dn_gt_up & m_dn_gt_zero;
            let plus_v = _mm512_maskz_mov_pd(m_plus, up);
            let minus_v = _mm512_maskz_mov_pd(m_minus, down);

            let mut buf_tr = [0.0f64; 8];
            let mut buf_p = [0.0f64; 8];
            let mut buf_m = [0.0f64; 8];
            _mm512_storeu_pd(buf_tr.as_mut_ptr(), trv);
            _mm512_storeu_pd(buf_p.as_mut_ptr(), plus_v);
            _mm512_storeu_pd(buf_m.as_mut_ptr(), minus_v);

            // Accumulate lane by lane, in order, so the sums match sequential addition.
            for ((t, p), m) in buf_tr.iter().zip(&buf_p).zip(&buf_m) {
                tr_sum += *t;
                plus_dm_sum += *p;
                minus_dm_sum += *m;
            }

            prev_h_scalar = *hp.add(i + 7);
            prev_l_scalar = *lp.add(i + 7);
            prev_c_scalar = *cp.add(i + 7);

            i += 8;
        }
        // Scalar remainder of the seed window.
        while i <= period {
            let ch = *hp.add(i);
            let cl = *lp.add(i);
            let hl = ch - cl;
            let hpc = (ch - prev_c_scalar).abs();
            let lpc = (cl - prev_c_scalar).abs();
            let t0 = if hl > hpc { hl } else { hpc };
            let tr = if t0 > lpc { t0 } else { lpc };
            let up = ch - prev_h_scalar;
            let down = prev_l_scalar - cl;
            if up > down && up > 0.0 {
                plus_dm_sum += up;
            }
            if down > up && down > 0.0 {
                minus_dm_sum += down;
            }
            tr_sum += tr;
            prev_h_scalar = ch;
            prev_l_scalar = cl;
            prev_c_scalar = *cp.add(i);
            i += 1;
        }

        let mut atr = tr_sum;
        let mut plus_dm_smooth = plus_dm_sum;
        let mut minus_dm_smooth = minus_dm_sum;

        let (plus_di_prev, minus_di_prev) = if atr != 0.0 {
            (
                (plus_dm_smooth / atr) * 100.0,
                (minus_dm_smooth / atr) * 100.0,
            )
        } else {
            (0.0, 0.0)
        };
        let sum_di_prev = plus_di_prev + minus_di_prev;
        let mut dx_sum = if sum_di_prev != 0.0 {
            ((plus_di_prev - minus_di_prev).abs() / sum_di_prev) * 100.0
        } else {
            0.0
        };
        let mut dx_count = 1usize;
        let mut last_adx = 0.0f64;

        // With period == 1 the seed DX alone completes the first ADX window; its slot
        // (index `period`) is not visited by the loop below.
        if dx_count == period {
            last_adx = dx_sum * reciprocal_period;
            *out.get_unchecked_mut(period) = last_adx;
        }

        let mut prev_h = *hp.add(period);
        let mut prev_l = *lp.add(period);
        let mut prev_c = *cp.add(period);

        let mut i = period + 1;
        while i < len {
            let ch = *hp.add(i);
            let cl = *lp.add(i);

            let hl = ch - cl;
            let hpc = (ch - prev_c).abs();
            let lpc = (cl - prev_c).abs();
            let t0 = if hl > hpc { hl } else { hpc };
            let tr = if t0 > lpc { t0 } else { lpc };

            let up = ch - prev_h;
            let down = prev_l - cl;
            let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
            let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

            atr = atr * one_minus_rp + tr;
            plus_dm_smooth = plus_dm_smooth * one_minus_rp + plus_dm;
            minus_dm_smooth = minus_dm_smooth * one_minus_rp + minus_dm;

            let (plus_di, minus_di) = if atr != 0.0 {
                (
                    (plus_dm_smooth / atr) * 100.0,
                    (minus_dm_smooth / atr) * 100.0,
                )
            } else {
                (0.0, 0.0)
            };
            let sum_di = plus_di + minus_di;
            let dx = if sum_di != 0.0 {
                ((plus_di - minus_di).abs() / sum_di) * 100.0
            } else {
                0.0
            };

            if dx_count < period {
                dx_sum += dx;
                dx_count += 1;
                if dx_count == period {
                    last_adx = dx_sum * reciprocal_period;
                    *out.get_unchecked_mut(i) = last_adx;
                }
            } else {
                last_adx = (last_adx * period_minus_one + dx) * reciprocal_period;
                *out.get_unchecked_mut(i) = last_adx;
            }

            prev_h = ch;
            prev_l = cl;
            prev_c = *cp.add(i);
            i += 1;
        }
    }
}
801
/// Short-period AVX-512 entry point; currently forwards unchanged to `adx_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adx_avx512_short(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
    adx_avx512(high, low, close, period, out);
}

/// Long-period AVX-512 entry point; currently forwards unchanged to `adx_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adx_avx512_long(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
    adx_avx512(high, low, close, period, out);
}
813
814#[inline]
815pub fn adx_batch_with_kernel(
816 high: &[f64],
817 low: &[f64],
818 close: &[f64],
819 sweep: &AdxBatchRange,
820 k: Kernel,
821) -> Result<AdxBatchOutput, AdxError> {
822 let kernel = match k {
823 Kernel::Auto => detect_best_batch_kernel(),
824 other if other.is_batch() => other,
825 Kernel::Scalar => Kernel::ScalarBatch,
826 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
827 Kernel::Avx2 => Kernel::Avx2Batch,
828 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829 Kernel::Avx512 => Kernel::Avx512Batch,
830 _ => return Err(AdxError::InvalidKernelForBatch(k)),
831 };
832
833 let simd = match kernel {
834 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
835 Kernel::Avx512Batch => Kernel::Avx512,
836 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
837 Kernel::Avx2Batch => Kernel::Avx2,
838 Kernel::ScalarBatch => Kernel::Scalar,
839 _ => Kernel::Scalar,
840 };
841 adx_batch_par_slice(high, low, close, sweep, simd)
842}
843
/// Period sweep for the batch API, given as `(start, end, step)`.
///
/// `AdxBatchBuilder::period_static` encodes a single period as `(p, p, 0)`.
#[derive(Clone, Debug)]
pub struct AdxBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for AdxBatchRange {
    /// Periods 14 to 263 with step 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        AdxBatchRange {
            period: (start, end, step),
        }
    }
}
856
/// Batch tuning threshold.
///
/// NOTE(review): the use site lives elsewhere. The name suggests the row count at which
/// the TR/+DM/-DM streams (see the `precompute_streams_*` helpers) are computed once and
/// shared across parameter rows — confirm there.
const ADX_SHARED_PRECOMP_THRESHOLD: usize = 16;
858
/// Builds the per-bar true range, +DM and -DM streams for the tail starting at `first`.
///
/// Each returned vector has `high.len() - first` entries. Entry 0 is `0.0` because the
/// first bar has no predecessor. Entry `j` compares bar `first + j` with bar
/// `first + j - 1`.
///
/// # Panics
/// Panics if `first >= high.len()` or if `low`/`close` are shorter than `high`.
#[inline(always)]
fn precompute_streams_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = high.len() - first;
    let mut tr_stream = Vec::with_capacity(n);
    let mut plus_stream = Vec::with_capacity(n);
    let mut minus_stream = Vec::with_capacity(n);

    tr_stream.push(0.0);
    plus_stream.push(0.0);
    minus_stream.push(0.0);

    let (mut last_high, mut last_low, mut last_close) = (high[first], low[first], close[first]);

    for idx in (first + 1)..high.len() {
        let h = high[idx];
        let l = low[idx];

        let range = h - l;
        let gap_up = (h - last_close).abs();
        let gap_down = (l - last_close).abs();
        let up_move = h - last_high;
        let down_move = last_low - l;

        tr_stream.push(range.max(gap_up).max(gap_down));
        plus_stream.push(if up_move > down_move && up_move > 0.0 {
            up_move
        } else {
            0.0
        });
        minus_stream.push(if down_move > up_move && down_move > 0.0 {
            down_move
        } else {
            0.0
        });

        last_high = h;
        last_low = l;
        last_close = close[idx];
    }

    (tr_stream, plus_stream, minus_stream)
}
898
/// AVX variant of `precompute_streams_scalar`: per-bar true range, +DM and -DM streams
/// for the tail starting at `first`.
///
/// Each returned vector has `high.len() - first` entries. Entry 0 is `0.0` because the
/// first bar has no predecessor. Entry `j` compares bar `first + j` with bar
/// `first + j - 1`. Four bars are handled per SIMD step and the remainder one at a
/// time. With `NaN` inputs, `_mm256_max_pd` and the conditional max used in the tail can
/// pick a different operand than `f64::max` in the scalar version.
///
/// # Safety
/// NOTE(review): the function reads `high`, `low` and `close` through raw pointers for
/// indices in `first..high.len()` without bounds checks. Callers must therefore ensure
/// `first < high.len()` and that `low` and `close` are at least as long as `high`. The
/// `_mm256_*` intrinsics also require a CPU with AVX.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn precompute_streams_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    use core::arch::x86_64::*;
    let tail_len = high.len() - first;
    let mut tr = Vec::with_capacity(tail_len);
    let mut pdm = Vec::with_capacity(tail_len);
    let mut mdm = Vec::with_capacity(tail_len);
    // Slot 0: no previous bar to compare against.
    tr.push(0.0);
    pdm.push(0.0);
    mdm.push(0.0);

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    // -0.0 has only the sign bit set; andnot with it yields abs().
    let sign_mask = _mm256_set1_pd(-0.0f64);
    let zero = _mm256_setzero_pd();

    let mut prev_h_scalar = *hp.add(first);
    let mut prev_l_scalar = *lp.add(first);
    let mut prev_c_scalar = *cp.add(first);

    let mut j = 1usize;
    // `j + 3 < tail_len` keeps the highest lane (`first + j + 3`) inside the slice.
    while j + 3 < tail_len {
        let base = first + j;
        let ch = _mm256_loadu_pd(hp.add(base));
        let cl = _mm256_loadu_pd(lp.add(base));
        let pch = _mm256_loadu_pd(hp.add(base - 1));
        let pcl = _mm256_loadu_pd(lp.add(base - 1));
        let pcc = _mm256_loadu_pd(cp.add(base - 1));

        let hl = _mm256_sub_pd(ch, cl);
        let hpc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(ch, pcc));
        let lpc = _mm256_andnot_pd(sign_mask, _mm256_sub_pd(cl, pcc));
        let t0 = _mm256_max_pd(hl, hpc);
        let trv = _mm256_max_pd(t0, lpc);

        let up = _mm256_sub_pd(ch, pch);
        let down = _mm256_sub_pd(pcl, cl);
        let m_up_gt_down = _mm256_cmp_pd(up, down, _CMP_GT_OQ);
        let m_up_gt_zero = _mm256_cmp_pd(up, zero, _CMP_GT_OQ);
        let m_dn_gt_up = _mm256_cmp_pd(down, up, _CMP_GT_OQ);
        let m_dn_gt_zero = _mm256_cmp_pd(down, zero, _CMP_GT_OQ);
        let plus_mask = _mm256_and_pd(m_up_gt_down, m_up_gt_zero);
        let minus_mask = _mm256_and_pd(m_dn_gt_up, m_dn_gt_zero);
        // All-ones lanes keep the move, all-zeros lanes become 0.0.
        let plus_v = _mm256_and_pd(plus_mask, up);
        let minus_v = _mm256_and_pd(minus_mask, down);

        let mut buf_tr = [0.0f64; 4];
        let mut buf_p = [0.0f64; 4];
        let mut buf_m = [0.0f64; 4];
        _mm256_storeu_pd(buf_tr.as_mut_ptr(), trv);
        _mm256_storeu_pd(buf_p.as_mut_ptr(), plus_v);
        _mm256_storeu_pd(buf_m.as_mut_ptr(), minus_v);

        // Append lanes in bar order.
        tr.push(buf_tr[0]);
        pdm.push(buf_p[0]);
        mdm.push(buf_m[0]);
        tr.push(buf_tr[1]);
        pdm.push(buf_p[1]);
        mdm.push(buf_m[1]);
        tr.push(buf_tr[2]);
        pdm.push(buf_p[2]);
        mdm.push(buf_m[2]);
        tr.push(buf_tr[3]);
        pdm.push(buf_p[3]);
        mdm.push(buf_m[3]);

        // The scalar tail continues from the last bar this step consumed.
        prev_h_scalar = *hp.add(base + 3);
        prev_l_scalar = *lp.add(base + 3);
        prev_c_scalar = *cp.add(base + 3);
        j += 4;
    }
    // Scalar remainder (fewer than four bars left).
    while j < tail_len {
        let ch = *hp.add(first + j);
        let cl = *lp.add(first + j);
        let hl = ch - cl;
        let hpc = (ch - prev_c_scalar).abs();
        let lpc = (cl - prev_c_scalar).abs();
        // Parsed as `(if .. { hl } else { hpc }).max(lpc)`.
        let trj = if hl > hpc { hl } else { hpc }.max(lpc);
        let up = ch - prev_h_scalar;
        let down = prev_l_scalar - cl;
        let plus = if up > down && up > 0.0 { up } else { 0.0 };
        let minus = if down > up && down > 0.0 { down } else { 0.0 };
        tr.push(trj);
        pdm.push(plus);
        mdm.push(minus);
        prev_h_scalar = ch;
        prev_l_scalar = cl;
        prev_c_scalar = *cp.add(first + j);
        j += 1;
    }
    (tr, pdm, mdm)
}
999
/// AVX-512 variant of `precompute_streams_scalar`: per-bar true range, +DM and -DM
/// streams for the tail starting at `first`.
///
/// Each returned vector has `high.len() - first` entries. Entry 0 is `0.0` because the
/// first bar has no predecessor. Entry `j` compares bar `first + j` with bar
/// `first + j - 1`. Eight bars are handled per SIMD step and the remainder one at a
/// time.
///
/// # Safety
/// NOTE(review): the function reads `high`, `low` and `close` through raw pointers for
/// indices in `first..high.len()` without bounds checks. Callers must therefore ensure
/// `first < high.len()` and that `low` and `close` are at least as long as `high`. The
/// `_mm512_*` intrinsics also require matching AVX-512 CPU support.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn precompute_streams_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    use core::arch::x86_64::*;
    let tail_len = high.len() - first;
    let mut tr = Vec::with_capacity(tail_len);
    let mut pdm = Vec::with_capacity(tail_len);
    let mut mdm = Vec::with_capacity(tail_len);
    // Slot 0: no previous bar to compare against.
    tr.push(0.0);
    pdm.push(0.0);
    mdm.push(0.0);

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let cp = close.as_ptr();

    // -0.0 has only the sign bit set; andnot with it yields abs().
    let sign_mask = _mm512_set1_pd(-0.0f64);
    let zero = _mm512_setzero_pd();

    let mut prev_h_scalar = *hp.add(first);
    let mut prev_l_scalar = *lp.add(first);
    let mut prev_c_scalar = *cp.add(first);

    let mut j = 1usize;
    // `j + 7 < tail_len` keeps the highest lane (`first + j + 7`) inside the slice.
    while j + 7 < tail_len {
        let base = first + j;
        let ch = _mm512_loadu_pd(hp.add(base));
        let cl = _mm512_loadu_pd(lp.add(base));
        let pch = _mm512_loadu_pd(hp.add(base - 1));
        let pcl = _mm512_loadu_pd(lp.add(base - 1));
        let pcc = _mm512_loadu_pd(cp.add(base - 1));

        let hl = _mm512_sub_pd(ch, cl);
        let hpc = _mm512_andnot_pd(sign_mask, _mm512_sub_pd(ch, pcc));
        let lpc = _mm512_andnot_pd(sign_mask, _mm512_sub_pd(cl, pcc));
        let t0 = _mm512_max_pd(hl, hpc);
        let trv = _mm512_max_pd(t0, lpc);

        let up = _mm512_sub_pd(ch, pch);
        let down = _mm512_sub_pd(pcl, cl);
        let m_up_gt_down = _mm512_cmp_pd_mask(up, down, _CMP_GT_OQ);
        let m_up_gt_zero = _mm512_cmp_pd_mask(up, zero, _CMP_GT_OQ);
        let m_dn_gt_up = _mm512_cmp_pd_mask(down, up, _CMP_GT_OQ);
        let m_dn_gt_zero = _mm512_cmp_pd_mask(down, zero, _CMP_GT_OQ);
        let m_plus = m_up_gt_down & m_up_gt_zero;
        let m_minus = m_dn_gt_up & m_dn_gt_zero;
        // Zero-masking move: lanes outside the mask become 0.0.
        let plus_v = _mm512_maskz_mov_pd(m_plus, up);
        let minus_v = _mm512_maskz_mov_pd(m_minus, down);

        let mut buf_tr = [0.0f64; 8];
        let mut buf_p = [0.0f64; 8];
        let mut buf_m = [0.0f64; 8];
        _mm512_storeu_pd(buf_tr.as_mut_ptr(), trv);
        _mm512_storeu_pd(buf_p.as_mut_ptr(), plus_v);
        _mm512_storeu_pd(buf_m.as_mut_ptr(), minus_v);

        // Append lanes in bar order.
        for k in 0..8 {
            tr.push(buf_tr[k]);
            pdm.push(buf_p[k]);
            mdm.push(buf_m[k]);
        }
        // The scalar tail continues from the last bar this step consumed.
        prev_h_scalar = *hp.add(base + 7);
        prev_l_scalar = *lp.add(base + 7);
        prev_c_scalar = *cp.add(base + 7);
        j += 8;
    }
    // Scalar remainder (fewer than eight bars left).
    while j < tail_len {
        let ch = *hp.add(first + j);
        let cl = *lp.add(first + j);
        let hl = ch - cl;
        let hpc = (ch - prev_c_scalar).abs();
        let lpc = (cl - prev_c_scalar).abs();
        // Parsed as `(if .. { hl } else { hpc }).max(lpc)`.
        let trj = if hl > hpc { hl } else { hpc }.max(lpc);
        let up = ch - prev_h_scalar;
        let down = prev_l_scalar - cl;
        let plus = if up > down && up > 0.0 { up } else { 0.0 };
        let minus = if down > up && down > 0.0 { down } else { 0.0 };
        tr.push(trj);
        pdm.push(plus);
        mdm.push(minus);
        prev_h_scalar = ch;
        prev_l_scalar = cl;
        prev_c_scalar = *cp.add(first + j);
        j += 1;
    }
    (tr, pdm, mdm)
}
1092
1093#[derive(Clone, Debug, Default)]
1094pub struct AdxBatchBuilder {
1095 range: AdxBatchRange,
1096 kernel: Kernel,
1097}
1098
1099impl AdxBatchBuilder {
1100 pub fn new() -> Self {
1101 Self::default()
1102 }
1103 pub fn kernel(mut self, k: Kernel) -> Self {
1104 self.kernel = k;
1105 self
1106 }
1107 #[inline]
1108 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1109 self.range.period = (start, end, step);
1110 self
1111 }
1112 #[inline]
1113 pub fn period_static(mut self, p: usize) -> Self {
1114 self.range.period = (p, p, 0);
1115 self
1116 }
1117 pub fn apply_slices(
1118 self,
1119 high: &[f64],
1120 low: &[f64],
1121 close: &[f64],
1122 ) -> Result<AdxBatchOutput, AdxError> {
1123 adx_batch_with_kernel(high, low, close, &self.range, self.kernel)
1124 }
1125 pub fn apply_candles(self, candles: &Candles) -> Result<AdxBatchOutput, AdxError> {
1126 let high = candles
1127 .select_candle_field("high")
1128 .map_err(|_| AdxError::CandleFieldError { field: "high" })?;
1129 let low = candles
1130 .select_candle_field("low")
1131 .map_err(|_| AdxError::CandleFieldError { field: "low" })?;
1132 let close = candles
1133 .select_candle_field("close")
1134 .map_err(|_| AdxError::CandleFieldError { field: "close" })?;
1135 self.apply_slices(high, low, close)
1136 }
1137 pub fn with_default_candles(c: &Candles) -> Result<AdxBatchOutput, AdxError> {
1138 AdxBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
1139 }
1140}
1141
1142#[derive(Clone, Debug)]
1143pub struct AdxBatchOutput {
1144 pub values: Vec<f64>,
1145 pub combos: Vec<AdxParams>,
1146 pub rows: usize,
1147 pub cols: usize,
1148}
1149
1150impl AdxBatchOutput {
1151 pub fn row_for_params(&self, p: &AdxParams) -> Option<usize> {
1152 self.combos
1153 .iter()
1154 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1155 }
1156
1157 pub fn values_for(&self, p: &AdxParams) -> Option<&[f64]> {
1158 self.row_for_params(p).map(|row| {
1159 let start = row * self.cols;
1160 &self.values[start..start + self.cols]
1161 })
1162 }
1163}
1164
1165#[cfg(all(feature = "python", feature = "cuda"))]
1166use crate::cuda::adx_wrapper::CudaAdx;
1167#[cfg(all(feature = "python", feature = "cuda"))]
1168use crate::cuda::moving_averages::DeviceArrayF32;
1169#[cfg(all(feature = "python", feature = "cuda"))]
1170use cust::context::Context;
1171#[cfg(all(feature = "python", feature = "cuda"))]
1172use cust::memory::DeviceBuffer;
1173#[cfg(all(feature = "python", feature = "cuda"))]
1174use std::sync::Arc;
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adx_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
// Runs the ADX period sweep on the GPU and returns the device-resident result together with a
// dict whose "periods" entry lists the period of each output row (as u64).
// Plain `//` comments are used on purpose: `///` would change the Python-visible `__doc__`.
pub fn adx_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32AdxPy, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;

    // Maps any stringifiable CUDA-side error to a Python ValueError.
    fn cuda_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let sweep = AdxBatchRange {
        period: period_range,
    };

    // The GPU work runs with the GIL released; the context handle and device id are captured
    // before launching so the returned wrapper can keep them alongside the buffer.
    let (device_array, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaAdx::new(device_id).map_err(cuda_err)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let (arr, row_params) = cuda
            .adx_batch_dev(high, low, close, &sweep)
            .map_err(cuda_err)?;
        Ok((arr, row_params, ctx, dev_id))
    })?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    Ok((DeviceArrayF32AdxPy::new(device_array, ctx, dev_id), meta))
}
1216
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adx_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, period, device_id=0))]
// Runs ADX with a single `period` over `cols` series laid out in the flat time-major inputs
// (dimensions `cols` × `rows` are forwarded to the CUDA wrapper unchanged) and returns the
// device-resident result.
// Plain `//` comments are used on purpose: `///` would change the Python-visible `__doc__`.
pub fn adx_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32AdxPy> {
    use crate::cuda::cuda_available;

    // Maps any stringifiable CUDA-side error to a Python ValueError.
    fn cuda_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // GPU work runs with the GIL released.
    let (device_array, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaAdx::new(device_id).map_err(cuda_err)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let arr = cuda
            .adx_many_series_one_param_time_major_dev(high, low, close, cols, rows, period)
            .map_err(cuda_err)?;
        Ok((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32AdxPy::new(device_array, ctx, dev_id))
}
1248
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Adx", unsendable)]
// Python handle for an ADX result living in CUDA device memory (exposed via
// `__cuda_array_interface__` / `__dlpack__`). Plain `//` comments keep the class `__doc__` unchanged.
pub struct DeviceArrayF32AdxPy {
    // Device buffer plus its row/column shape.
    pub(crate) inner: DeviceArrayF32,
    // Clone of the CUDA context handle the buffer was created under. Declared after `inner`, so
    // Rust drops the buffer before releasing this handle.
    _ctx_guard: Arc<Context>,
    // Device id reported by the CUDA wrapper that produced `inner`.
    _device_id: u32,
}
1256
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32AdxPy {
    // Instances are only created by the CUDA factory functions; direct construction from Python
    // is rejected. (Plain `//` comments are used throughout so Python `__doc__`s are unchanged.)
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // CUDA Array Interface (version 3) describing a C-contiguous (rows, cols) float32 matrix.
    // An empty array reports a null data pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = &self.inner;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let size = inner.rows.saturating_mul(inner.cols);
        let ptr = if size == 0 {
            0usize
        } else {
            inner.device_ptr() as usize
        };
        // Second tuple element is the read-only flag.
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: (2 == kDLCUDA, device ordinal).
    //
    // The ordinal is taken from the driver's pointer attributes for the allocation. If the
    // driver cannot resolve the pointer (presumably the case for the empty placeholder buffer
    // left behind by `__dlpack__`), fall back to the device this wrapper was created for rather
    // than whatever context happens to be current on the calling thread.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        // Device ordinals are small, so the u32 -> i32 cast cannot truncate in practice.
        let mut device_ordinal: i32 = self._device_id as i32;
        unsafe {
            let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
            let mut value = std::mem::MaybeUninit::<i32>::uninit();
            let err = cust::sys::cuPointerGetAttribute(
                value.as_mut_ptr() as *mut std::ffi::c_void,
                attr,
                self.inner.buf.as_device_ptr().as_raw(),
            );
            if err == cust::sys::CUresult::CUDA_SUCCESS {
                // SAFETY: the driver reported success, so it wrote the ordinal into `value`.
                device_ordinal = value.assume_init();
            }
        }
        Ok((2, device_ordinal))
    }

    // Exports the buffer as a DLPack capsule, moving ownership out of this wrapper (which is left
    // holding an empty 0x0 buffer). Cross-device export and copies are not supported; `stream` is
    // accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty buffer so the real allocation can be handed to the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
1367
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32AdxPy {
    /// Wraps `inner`, storing `ctx_guard` (a clone of the CUDA context handle) and `device_id`
    /// alongside it for the lifetime of the wrapper.
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1378
1379#[inline(always)]
1380fn expand_grid(r: &AdxBatchRange) -> Vec<AdxParams> {
1381 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1382 if start == end || step == 0 {
1383 return vec![start];
1384 }
1385 if start < end {
1386 return (start..=end).step_by(step.max(1)).collect();
1387 }
1388
1389 let mut v = Vec::new();
1390 let mut cur = start;
1391 let s = step.max(1);
1392 while cur >= end {
1393 v.push(cur);
1394 if cur < s {
1395 break;
1396 }
1397 cur -= s;
1398 if cur == usize::MAX {
1399 break;
1400 }
1401 }
1402 v
1403 }
1404 let periods = axis_usize(r.period);
1405 let mut out = Vec::with_capacity(periods.len());
1406 for &p in &periods {
1407 out.push(AdxParams { period: Some(p) });
1408 }
1409 out
1410}
1411
1412#[inline(always)]
1413fn adx_batch_inner_into(
1414 high: &[f64],
1415 low: &[f64],
1416 close: &[f64],
1417 sweep: &AdxBatchRange,
1418 kern: Kernel,
1419 parallel: bool,
1420 out: &mut [f64],
1421) -> Result<Vec<AdxParams>, AdxError> {
1422 if high.len() != low.len() || high.len() != close.len() {
1423 return Err(AdxError::InconsistentLengths);
1424 }
1425 let combos = expand_grid(sweep);
1426 if combos.is_empty() {
1427 return Err(AdxError::InvalidRange {
1428 start: sweep.period.0,
1429 end: sweep.period.1,
1430 step: sweep.period.2,
1431 });
1432 }
1433
1434 let rows = combos.len();
1435 let cols = close.len();
1436 let expected = rows.checked_mul(cols).ok_or(AdxError::InvalidRange {
1437 start: sweep.period.0,
1438 end: sweep.period.1,
1439 step: sweep.period.2,
1440 })?;
1441 if out.len() != expected {
1442 return Err(AdxError::OutputLengthMismatch {
1443 expected,
1444 got: out.len(),
1445 });
1446 }
1447
1448 let first = first_valid_triple(high, low, close);
1449 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1450 if cols - first < max_p + 1 {
1451 return Err(AdxError::NotEnoughValidData {
1452 needed: max_p + 1,
1453 valid: cols - first,
1454 });
1455 }
1456
1457 let mut warms: Vec<usize> = Vec::with_capacity(combos.len());
1458 for c in &combos {
1459 let p = c.period.unwrap();
1460 let two_p = p.checked_mul(2).ok_or(AdxError::InvalidRange {
1461 start: sweep.period.0,
1462 end: sweep.period.1,
1463 step: sweep.period.2,
1464 })?;
1465 let warm = first
1466 .checked_add(two_p.saturating_sub(1))
1467 .ok_or(AdxError::InvalidRange {
1468 start: sweep.period.0,
1469 end: sweep.period.1,
1470 step: sweep.period.2,
1471 })?;
1472 warms.push(warm);
1473 }
1474 let out_mu = unsafe {
1475 std::slice::from_raw_parts_mut(
1476 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1477 out.len(),
1478 )
1479 };
1480 init_matrix_prefixes(&mut { out_mu }, cols, &warms);
1481
1482 let use_shared = combos.len() >= ADX_SHARED_PRECOMP_THRESHOLD;
1483
1484 if use_shared {
1485 let (tr_stream, plus_stream, minus_stream) = {
1486 match kern {
1487 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1488 Kernel::Avx512 => unsafe { precompute_streams_avx512(high, low, close, first) },
1489 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1490 Kernel::Avx2 => unsafe { precompute_streams_avx2(high, low, close, first) },
1491 _ => precompute_streams_scalar(high, low, close, first),
1492 }
1493 };
1494
1495 let do_row_shared = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
1496 let p = combos[row].period.unwrap();
1497 let row_f64 =
1498 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
1499 let dst_tail = &mut row_f64[first..];
1500
1501 let pf = p as f64;
1502 let rp = 1.0 / pf;
1503 let one_minus_rp = 1.0 - rp;
1504 let pm1 = pf - 1.0;
1505
1506 let mut atr = 0.0f64;
1507 let mut plus_s = 0.0f64;
1508 let mut minus_s = 0.0f64;
1509 let mut j = 1usize;
1510 while j <= p {
1511 atr += tr_stream[j];
1512 plus_s += plus_stream[j];
1513 minus_s += minus_stream[j];
1514 j += 1;
1515 }
1516 let (plus_di_prev, minus_di_prev) = if atr != 0.0 {
1517 ((plus_s / atr) * 100.0, (minus_s / atr) * 100.0)
1518 } else {
1519 (0.0, 0.0)
1520 };
1521 let sum_di_prev = plus_di_prev + minus_di_prev;
1522 let mut dx_sum = if sum_di_prev != 0.0 {
1523 ((plus_di_prev - minus_di_prev).abs() / sum_di_prev) * 100.0
1524 } else {
1525 0.0
1526 };
1527 let mut dx_count = 1usize;
1528 let mut last_adx = 0.0f64;
1529
1530 let tail_len = tr_stream.len();
1531 let mut j = p + 1;
1532 while j < tail_len {
1533 atr = atr * one_minus_rp + tr_stream[j];
1534 plus_s = plus_s * one_minus_rp + plus_stream[j];
1535 minus_s = minus_s * one_minus_rp + minus_stream[j];
1536
1537 let (plus_di, minus_di) = if atr != 0.0 {
1538 ((plus_s / atr) * 100.0, (minus_s / atr) * 100.0)
1539 } else {
1540 (0.0, 0.0)
1541 };
1542 let sum_di = plus_di + minus_di;
1543 let dx = if sum_di != 0.0 {
1544 ((plus_di - minus_di).abs() / sum_di) * 100.0
1545 } else {
1546 0.0
1547 };
1548
1549 if dx_count < p {
1550 dx_sum += dx;
1551 dx_count += 1;
1552 if dx_count == p {
1553 last_adx = dx_sum * rp;
1554 dst_tail[j] = last_adx;
1555 }
1556 } else {
1557 last_adx = (last_adx * pm1 + dx) * rp;
1558 dst_tail[j] = last_adx;
1559 }
1560 j += 1;
1561 }
1562 };
1563
1564 let out_mu2 = unsafe {
1565 std::slice::from_raw_parts_mut(
1566 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1567 out.len(),
1568 )
1569 };
1570 let rows_iter = (0..rows).zip(out_mu2.chunks_mut(cols));
1571 if parallel {
1572 #[cfg(not(target_arch = "wasm32"))]
1573 rows_iter
1574 .par_bridge()
1575 .for_each(|(r, s)| do_row_shared(r, s));
1576 #[cfg(target_arch = "wasm32")]
1577 for (r, s) in rows_iter {
1578 do_row_shared(r, s);
1579 }
1580 } else {
1581 for (r, s) in rows_iter {
1582 do_row_shared(r, s);
1583 }
1584 }
1585 return Ok(combos);
1586 }
1587
1588 let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
1589 let p = combos[row].period.unwrap();
1590 let row_f64 =
1591 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
1592 let dst_tail = &mut row_f64[first..];
1593 match kern {
1594 Kernel::Scalar => {
1595 adx_row_scalar(&high[first..], &low[first..], &close[first..], p, dst_tail)
1596 }
1597 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1598 Kernel::Avx2 => {
1599 adx_row_avx2(&high[first..], &low[first..], &close[first..], p, dst_tail)
1600 }
1601 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1602 Kernel::Avx512 => {
1603 adx_row_avx512(&high[first..], &low[first..], &close[first..], p, dst_tail)
1604 }
1605 _ => adx_row_scalar(&high[first..], &low[first..], &close[first..], p, dst_tail),
1606 }
1607 };
1608
1609 let out_mu2 = unsafe {
1610 std::slice::from_raw_parts_mut(
1611 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1612 out.len(),
1613 )
1614 };
1615 let rows_iter = (0..rows).zip(out_mu2.chunks_mut(cols));
1616 if parallel {
1617 #[cfg(not(target_arch = "wasm32"))]
1618 rows_iter.par_bridge().for_each(|(r, s)| do_row(r, s));
1619 #[cfg(target_arch = "wasm32")]
1620 for (r, s) in rows_iter {
1621 do_row(r, s);
1622 }
1623 } else {
1624 for (r, s) in rows_iter {
1625 do_row(r, s);
1626 }
1627 }
1628
1629 Ok(combos)
1630}
1631
1632#[inline(always)]
1633pub fn adx_batch_slice(
1634 high: &[f64],
1635 low: &[f64],
1636 close: &[f64],
1637 sweep: &AdxBatchRange,
1638 kern: Kernel,
1639) -> Result<AdxBatchOutput, AdxError> {
1640 let simd_kern = match kern {
1641 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1642 Kernel::Avx512Batch => Kernel::Avx512,
1643 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1644 Kernel::Avx2Batch => Kernel::Avx2,
1645 Kernel::ScalarBatch => Kernel::Scalar,
1646 Kernel::Auto => detect_best_kernel(),
1647 other => other,
1648 };
1649 adx_batch_inner(high, low, close, sweep, simd_kern, false)
1650}
1651
1652#[inline(always)]
1653pub fn adx_batch_par_slice(
1654 high: &[f64],
1655 low: &[f64],
1656 close: &[f64],
1657 sweep: &AdxBatchRange,
1658 kern: Kernel,
1659) -> Result<AdxBatchOutput, AdxError> {
1660 adx_batch_inner(high, low, close, sweep, kern, true)
1661}
1662
1663#[inline(always)]
1664fn adx_batch_inner(
1665 high: &[f64],
1666 low: &[f64],
1667 close: &[f64],
1668 sweep: &AdxBatchRange,
1669 kern: Kernel,
1670 parallel: bool,
1671) -> Result<AdxBatchOutput, AdxError> {
1672 if high.len() != low.len() || high.len() != close.len() {
1673 return Err(AdxError::InconsistentLengths);
1674 }
1675 let combos = expand_grid(sweep);
1676 if combos.is_empty() {
1677 return Err(AdxError::InvalidRange {
1678 start: sweep.period.0,
1679 end: sweep.period.1,
1680 step: sweep.period.2,
1681 });
1682 }
1683
1684 let rows = combos.len();
1685 let cols = close.len();
1686 if cols == 0 {
1687 return Err(AdxError::EmptyInputData);
1688 }
1689
1690 let first = first_valid_triple(high, low, close);
1691 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1692 if cols - first < max_p + 1 {
1693 return Err(AdxError::NotEnoughValidData {
1694 needed: max_p + 1,
1695 valid: cols - first,
1696 });
1697 }
1698
1699 let _cap = rows.checked_mul(cols).ok_or(AdxError::InvalidRange {
1700 start: sweep.period.0,
1701 end: sweep.period.1,
1702 step: sweep.period.2,
1703 })?;
1704 let mut buf_mu = make_uninit_matrix(rows, cols);
1705
1706 let mut warm: Vec<usize> = Vec::with_capacity(combos.len());
1707 for c in &combos {
1708 let p = c.period.unwrap();
1709 let two_p = p.checked_mul(2).ok_or(AdxError::InvalidRange {
1710 start: sweep.period.0,
1711 end: sweep.period.1,
1712 step: sweep.period.2,
1713 })?;
1714 let w = first
1715 .checked_add(two_p.saturating_sub(1))
1716 .ok_or(AdxError::InvalidRange {
1717 start: sweep.period.0,
1718 end: sweep.period.1,
1719 step: sweep.period.2,
1720 })?;
1721 warm.push(w);
1722 }
1723 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1724
1725 let mut guard = ManuallyDrop::new(buf_mu);
1726 let values: &mut [f64] =
1727 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1728
1729 let use_shared = combos.len() >= ADX_SHARED_PRECOMP_THRESHOLD;
1730
1731 if use_shared {
1732 let (tr_stream, plus_stream, minus_stream) = {
1733 match kern {
1734 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1735 Kernel::Avx512 => unsafe { precompute_streams_avx512(high, low, close, first) },
1736 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1737 Kernel::Avx2 => unsafe { precompute_streams_avx2(high, low, close, first) },
1738 _ => precompute_streams_scalar(high, low, close, first),
1739 }
1740 };
1741
1742 let do_row = |row: usize, out_row: &mut [f64]| {
1743 let p = combos[row].period.unwrap();
1744 let pf = p as f64;
1745 let rp = 1.0 / pf;
1746 let one_minus_rp = 1.0 - rp;
1747 let pm1 = pf - 1.0;
1748 let dst_tail = &mut out_row[first..];
1749
1750 let mut atr = 0.0f64;
1751 let mut plus_s = 0.0f64;
1752 let mut minus_s = 0.0f64;
1753 let mut j = 1usize;
1754 while j <= p {
1755 atr += tr_stream[j];
1756 plus_s += plus_stream[j];
1757 minus_s += minus_stream[j];
1758 j += 1;
1759 }
1760 let (plus_di_prev, minus_di_prev) = if atr != 0.0 {
1761 ((plus_s / atr) * 100.0, (minus_s / atr) * 100.0)
1762 } else {
1763 (0.0, 0.0)
1764 };
1765 let sum_di_prev = plus_di_prev + minus_di_prev;
1766 let mut dx_sum = if sum_di_prev != 0.0 {
1767 ((plus_di_prev - minus_di_prev).abs() / sum_di_prev) * 100.0
1768 } else {
1769 0.0
1770 };
1771 let mut dx_count = 1usize;
1772 let mut last_adx = 0.0f64;
1773
1774 let tail_len = tr_stream.len();
1775 let mut j = p + 1;
1776 while j < tail_len {
1777 atr = atr * one_minus_rp + tr_stream[j];
1778 plus_s = plus_s * one_minus_rp + plus_stream[j];
1779 minus_s = minus_s * one_minus_rp + minus_stream[j];
1780
1781 let (plus_di, minus_di) = if atr != 0.0 {
1782 ((plus_s / atr) * 100.0, (minus_s / atr) * 100.0)
1783 } else {
1784 (0.0, 0.0)
1785 };
1786 let sum_di = plus_di + minus_di;
1787 let dx = if sum_di != 0.0 {
1788 ((plus_di - minus_di).abs() / sum_di) * 100.0
1789 } else {
1790 0.0
1791 };
1792
1793 if dx_count < p {
1794 dx_sum += dx;
1795 dx_count += 1;
1796 if dx_count == p {
1797 last_adx = dx_sum * rp;
1798 dst_tail[j] = last_adx;
1799 }
1800 } else {
1801 last_adx = (last_adx * pm1 + dx) * rp;
1802 dst_tail[j] = last_adx;
1803 }
1804 j += 1;
1805 }
1806 };
1807
1808 if parallel {
1809 #[cfg(not(target_arch = "wasm32"))]
1810 values
1811 .par_chunks_mut(cols)
1812 .enumerate()
1813 .for_each(|(r, s)| do_row(r, s));
1814 #[cfg(target_arch = "wasm32")]
1815 for (r, s) in values.chunks_mut(cols).enumerate() {
1816 do_row(r, s);
1817 }
1818 } else {
1819 for (r, s) in values.chunks_mut(cols).enumerate() {
1820 do_row(r, s);
1821 }
1822 }
1823
1824 let values = unsafe {
1825 Vec::from_raw_parts(
1826 guard.as_mut_ptr() as *mut f64,
1827 guard.len(),
1828 guard.capacity(),
1829 )
1830 };
1831
1832 return Ok(AdxBatchOutput {
1833 values,
1834 combos,
1835 rows,
1836 cols,
1837 });
1838 }
1839
1840 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1841 let p = combos[row].period.unwrap();
1842 let pf = p as f64;
1843 let rp = 1.0 / pf;
1844 let one_minus_rp = 1.0 - rp;
1845 let pm1 = pf - 1.0;
1846 let dst_tail = &mut out_row[first..];
1847 match kern {
1848 Kernel::Scalar => {
1849 adx_row_scalar(&high[first..], &low[first..], &close[first..], p, dst_tail)
1850 }
1851 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1852 Kernel::Avx2 => {
1853 adx_row_avx2(&high[first..], &low[first..], &close[first..], p, dst_tail)
1854 }
1855 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1856 Kernel::Avx512 => {
1857 adx_row_avx512(&high[first..], &low[first..], &close[first..], p, dst_tail)
1858 }
1859 _ => adx_row_scalar(&high[first..], &low[first..], &close[first..], p, dst_tail),
1860 }
1861 };
1862
1863 if parallel {
1864 #[cfg(not(target_arch = "wasm32"))]
1865 values
1866 .par_chunks_mut(cols)
1867 .enumerate()
1868 .for_each(|(r, s)| do_row(r, s));
1869 #[cfg(target_arch = "wasm32")]
1870 for (r, s) in values.chunks_mut(cols).enumerate() {
1871 do_row(r, s);
1872 }
1873 } else {
1874 for (r, s) in values.chunks_mut(cols).enumerate() {
1875 do_row(r, s);
1876 }
1877 }
1878
1879 let values = unsafe {
1880 Vec::from_raw_parts(
1881 guard.as_mut_ptr() as *mut f64,
1882 guard.len(),
1883 guard.capacity(),
1884 )
1885 };
1886
1887 Ok(AdxBatchOutput {
1888 values,
1889 combos,
1890 rows,
1891 cols,
1892 })
1893}
1894
1895#[inline(always)]
1896unsafe fn adx_row_scalar(high: &[f64], low: &[f64], close: &[f64], period: usize, out: &mut [f64]) {
1897 adx_scalar(high, low, close, period, out)
1898}
1899
/// Batch-row adapter for the AVX2 kernel: forwards its arguments to `adx_avx2` unchanged.
///
/// # Safety
/// Same contract as `adx_avx2`. Presumably the caller must ensure the CPU supports AVX2
/// (TODO confirm against `adx_avx2`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn adx_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    out: &mut [f64],
) {
    adx_avx2(high, low, close, period, out)
}
1911
/// Batch-row adapter for the AVX-512 kernel: forwards its arguments to `adx_avx512` unchanged.
///
/// # Safety
/// Same contract as `adx_avx512`. Presumably the caller must ensure the CPU supports AVX-512
/// (TODO confirm against `adx_avx512`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn adx_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    out: &mut [f64],
) {
    adx_avx512(high, low, close, period, out)
}
1923
/// AVX-512 row entry point named for short periods. It has no dedicated implementation and
/// behaves exactly like [`adx_row_avx512`], to which it delegates.
///
/// # Safety
/// Same contract as [`adx_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn adx_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    out: &mut [f64],
) {
    adx_row_avx512(high, low, close, period, out)
}
1935
/// AVX-512 row entry point named for long periods. It has no dedicated implementation and
/// behaves exactly like [`adx_row_avx512`], to which it delegates.
///
/// # Safety
/// Same contract as [`adx_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn adx_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    out: &mut [f64],
) {
    adx_row_avx512(high, low, close, period, out)
}
1947
/// Incremental ADX calculator fed one high/low/close bar at a time.
#[derive(Clone, Debug)]
pub struct AdxStream {
    /// Smoothing period (non-zero).
    period: usize,
    /// Running true-range total: a plain sum during warm-up, Wilder-smoothed afterwards.
    atr: f64,
    /// Running +DM total, accumulated the same way as `atr`.
    plus_dm_smooth: f64,
    /// Running -DM total, accumulated the same way as `atr`.
    minus_dm_smooth: f64,
    /// Sum of the DX values gathered for the first ADX average.
    dx_sum: f64,
    /// Number of DX values gathered so far (saturates at `period`).
    dx_count: usize,
    /// Most recent ADX value.
    last_adx: f64,
    /// Number of bars consumed.
    count: usize,
    /// High of the previous bar (NaN before the first bar).
    prev_high: f64,
    /// Low of the previous bar (NaN before the first bar).
    prev_low: f64,
    /// Close of the previous bar (NaN before the first bar).
    prev_close: f64,
}
1962
1963impl AdxStream {
1964 pub fn try_new(params: AdxParams) -> Result<Self, AdxError> {
1965 let period = params.period.unwrap_or(14);
1966 if period == 0 {
1967 return Err(AdxError::InvalidPeriod {
1968 period,
1969 data_len: 0,
1970 });
1971 }
1972 Ok(Self {
1973 period,
1974 atr: 0.0,
1975 plus_dm_smooth: 0.0,
1976 minus_dm_smooth: 0.0,
1977 dx_sum: 0.0,
1978 dx_count: 0,
1979 last_adx: 0.0,
1980 count: 0,
1981 prev_high: f64::NAN,
1982 prev_low: f64::NAN,
1983 prev_close: f64::NAN,
1984 })
1985 }
1986
1987 #[inline(always)]
1988 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
1989 if self.count == 0 {
1990 self.prev_high = high;
1991 self.prev_low = low;
1992 self.prev_close = close;
1993 self.count = 1;
1994 return None;
1995 }
1996
1997 let prev_c = self.prev_close;
1998 let tr = high.max(prev_c) - low.min(prev_c);
1999
2000 let up_move = high - self.prev_high;
2001 let down_move = self.prev_low - low;
2002 let plus_dm = if up_move > down_move && up_move > 0.0 {
2003 up_move
2004 } else {
2005 0.0
2006 };
2007 let minus_dm = if down_move > up_move && down_move > 0.0 {
2008 down_move
2009 } else {
2010 0.0
2011 };
2012
2013 self.count += 1;
2014
2015 if self.count <= self.period + 1 {
2016 self.atr += tr;
2017 self.plus_dm_smooth += plus_dm;
2018 self.minus_dm_smooth += minus_dm;
2019
2020 if self.count == self.period + 1 {
2021 let inv_atr100 = if self.atr != 0.0 {
2022 100.0 / self.atr
2023 } else {
2024 0.0
2025 };
2026 let plus_di = self.plus_dm_smooth * inv_atr100;
2027 let minus_di = self.minus_dm_smooth * inv_atr100;
2028 let sum_di = plus_di + minus_di;
2029
2030 self.dx_sum = if sum_di != 0.0 {
2031 ((plus_di - minus_di).abs() / sum_di) * 100.0
2032 } else {
2033 0.0
2034 };
2035 self.dx_count = 1;
2036 }
2037
2038 self.prev_high = high;
2039 self.prev_low = low;
2040 self.prev_close = close;
2041 return None;
2042 }
2043
2044 let rp = 1.0 / (self.period as f64);
2045 let one_minus_rp = 1.0 - rp;
2046 let period_minus_one = (self.period as f64) - 1.0;
2047
2048 self.atr = self.atr * one_minus_rp + tr;
2049 self.plus_dm_smooth = self.plus_dm_smooth * one_minus_rp + plus_dm;
2050 self.minus_dm_smooth = self.minus_dm_smooth * one_minus_rp + minus_dm;
2051
2052 let inv_atr100 = if self.atr != 0.0 {
2053 100.0 / self.atr
2054 } else {
2055 0.0
2056 };
2057 let plus_di = self.plus_dm_smooth * inv_atr100;
2058 let minus_di = self.minus_dm_smooth * inv_atr100;
2059 let sum_di = plus_di + minus_di;
2060
2061 let dx = if sum_di != 0.0 {
2062 ((plus_di - minus_di).abs() / sum_di) * 100.0
2063 } else {
2064 0.0
2065 };
2066
2067 let out = if self.dx_count < self.period {
2068 self.dx_sum += dx;
2069 self.dx_count += 1;
2070 if self.dx_count == self.period {
2071 self.last_adx = self.dx_sum * rp;
2072 Some(self.last_adx)
2073 } else {
2074 None
2075 }
2076 } else {
2077 self.last_adx = (self.last_adx * period_minus_one + dx) * rp;
2078 Some(self.last_adx)
2079 };
2080
2081 self.prev_high = high;
2082 self.prev_low = low;
2083 self.prev_close = close;
2084
2085 out
2086 }
2087}
2088
2089#[cfg(test)]
2090mod tests {
2091 use super::*;
2092 use crate::skip_if_unsupported;
2093 use crate::utilities::data_loader::read_candles_from_csv;
2094
2095 #[test]
2096 fn test_adx_into_matches_api() -> Result<(), Box<dyn Error>> {
2097 let n = 256usize;
2098 let mut high = Vec::with_capacity(n);
2099 let mut low = Vec::with_capacity(n);
2100 let mut close = Vec::with_capacity(n);
2101 for i in 0..n {
2102 let t = i as f64;
2103 let base = 100.0 + 0.5 * t + (t * 0.1).sin() * 0.7;
2104 let c = base;
2105 let h = c + 0.6 + (t * 0.05).cos() * 0.1;
2106 let l = c - 0.6 - (t * 0.07).sin() * 0.1;
2107 high.push(h);
2108 low.push(l);
2109 close.push(c);
2110 }
2111
2112 let input = AdxInput::from_slices(&high, &low, &close, AdxParams::default());
2113
2114 let AdxOutput { values: expected } = adx(&input)?;
2115
2116 let mut got = vec![0.0; n];
2117 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2118 {
2119 adx_into(&input, &mut got)?;
2120 }
2121 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2122 {
2123 return Ok(());
2124 }
2125
2126 assert_eq!(expected.len(), got.len());
2127
2128 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2129 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
2130 }
2131 for i in 0..n {
2132 assert!(
2133 eq_or_both_nan(expected[i], got[i]),
2134 "mismatch at {}: expected {:?}, got {:?}",
2135 i,
2136 expected[i],
2137 got[i]
2138 );
2139 }
2140 Ok(())
2141 }
2142
2143 fn check_adx_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2144 skip_if_unsupported!(kernel, test_name);
2145 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2146 let candles = read_candles_from_csv(file_path)?;
2147
2148 let default_params = AdxParams { period: None };
2149 let input = AdxInput::from_candles(&candles, default_params);
2150 let output = adx_with_kernel(&input, kernel)?;
2151 assert_eq!(output.values.len(), candles.close.len());
2152
2153 Ok(())
2154 }
2155
2156 fn check_adx_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2157 skip_if_unsupported!(kernel, test_name);
2158 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2159 let candles = read_candles_from_csv(file_path)?;
2160
2161 let input = AdxInput::from_candles(&candles, AdxParams::default());
2162 let result = adx_with_kernel(&input, kernel)?;
2163 let expected_last_five = [36.14, 36.52, 37.01, 37.46, 38.47];
2164 let start = result.values.len().saturating_sub(5);
2165 for (i, &val) in result.values[start..].iter().enumerate() {
2166 let diff = (val - expected_last_five[i]).abs();
2167 assert!(
2168 diff < 1e-1,
2169 "[{}] ADX {:?} mismatch at idx {}: got {}, expected {}",
2170 test_name,
2171 kernel,
2172 i,
2173 val,
2174 expected_last_five[i]
2175 );
2176 }
2177 Ok(())
2178 }
2179
2180 fn check_adx_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2181 skip_if_unsupported!(kernel, test_name);
2182 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2183 let candles = read_candles_from_csv(file_path)?;
2184
2185 let input = AdxInput::with_default_candles(&candles);
2186 let output = adx_with_kernel(&input, kernel)?;
2187 assert_eq!(output.values.len(), candles.close.len());
2188
2189 Ok(())
2190 }
2191
2192 fn check_adx_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2193 skip_if_unsupported!(kernel, test_name);
2194 let high = [10.0, 20.0, 30.0];
2195 let low = [5.0, 15.0, 25.0];
2196 let close = [9.0, 19.0, 29.0];
2197 let params = AdxParams { period: Some(0) };
2198 let input = AdxInput::from_slices(&high, &low, &close, params);
2199 let res = adx_with_kernel(&input, kernel);
2200 assert!(
2201 res.is_err(),
2202 "[{}] ADX should fail with zero period",
2203 test_name
2204 );
2205 Ok(())
2206 }
2207
2208 fn check_adx_period_exceeds_length(
2209 test_name: &str,
2210 kernel: Kernel,
2211 ) -> Result<(), Box<dyn Error>> {
2212 skip_if_unsupported!(kernel, test_name);
2213 let high = [10.0, 20.0, 30.0];
2214 let low = [5.0, 15.0, 25.0];
2215 let close = [9.0, 19.0, 29.0];
2216 let params = AdxParams { period: Some(10) };
2217 let input = AdxInput::from_slices(&high, &low, &close, params);
2218 let res = adx_with_kernel(&input, kernel);
2219 assert!(
2220 res.is_err(),
2221 "[{}] ADX should fail with period exceeding length",
2222 test_name
2223 );
2224 Ok(())
2225 }
2226
2227 fn check_adx_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2228 skip_if_unsupported!(kernel, test_name);
2229 let high = [42.0];
2230 let low = [41.0];
2231 let close = [40.5];
2232 let params = AdxParams { period: Some(14) };
2233 let input = AdxInput::from_slices(&high, &low, &close, params);
2234 let res = adx_with_kernel(&input, kernel);
2235 assert!(
2236 res.is_err(),
2237 "[{}] ADX should fail with insufficient data",
2238 test_name
2239 );
2240 Ok(())
2241 }
2242
2243 fn check_adx_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2244 skip_if_unsupported!(kernel, test_name);
2245 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2246 let candles = read_candles_from_csv(file_path)?;
2247
2248 let first_params = AdxParams { period: Some(14) };
2249 let first_input = AdxInput::from_candles(&candles, first_params);
2250 let first_result = adx_with_kernel(&first_input, kernel)?;
2251
2252 let second_params = AdxParams { period: Some(5) };
2253 let second_input = AdxInput::from_slices(
2254 &candles.high,
2255 &candles.low,
2256 &first_result.values,
2257 second_params,
2258 );
2259 let second_result = adx_with_kernel(&second_input, kernel)?;
2260
2261 assert_eq!(second_result.values.len(), candles.close.len());
2262 for i in 100..second_result.values.len() {
2263 assert!(!second_result.values[i].is_nan());
2264 }
2265 Ok(())
2266 }
2267
2268 fn check_adx_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2269 skip_if_unsupported!(kernel, test_name);
2270 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2271 let candles = read_candles_from_csv(file_path)?;
2272
2273 let input = AdxInput::from_candles(&candles, AdxParams { period: Some(14) });
2274 let res = adx_with_kernel(&input, kernel)?;
2275 assert_eq!(res.values.len(), candles.close.len());
2276 if res.values.len() > 100 {
2277 for (i, &val) in res.values[100..].iter().enumerate() {
2278 assert!(
2279 !val.is_nan(),
2280 "[{}] Found unexpected NaN at out-index {}",
2281 test_name,
2282 100 + i
2283 );
2284 }
2285 }
2286 Ok(())
2287 }
2288
2289 fn check_adx_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2290 skip_if_unsupported!(kernel, test_name);
2291 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2292 let candles = read_candles_from_csv(file_path)?;
2293
2294 let period = 14;
2295
2296 let input = AdxInput::from_candles(
2297 &candles,
2298 AdxParams {
2299 period: Some(period),
2300 },
2301 );
2302 let batch_output = adx_with_kernel(&input, kernel)?.values;
2303
2304 let mut stream = AdxStream::try_new(AdxParams {
2305 period: Some(period),
2306 })?;
2307 let mut stream_values = Vec::with_capacity(candles.close.len());
2308 for ((&h, &l), &c) in candles.high.iter().zip(&candles.low).zip(&candles.close) {
2309 match stream.update(h, l, c) {
2310 Some(adx_val) => stream_values.push(adx_val),
2311 None => stream_values.push(f64::NAN),
2312 }
2313 }
2314 assert_eq!(batch_output.len(), stream_values.len());
2315 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2316 if b.is_nan() && s.is_nan() {
2317 continue;
2318 }
2319 let diff = (b - s).abs();
2320 assert!(
2321 diff < 1e-8,
2322 "[{}] ADX streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2323 test_name,
2324 i,
2325 b,
2326 s,
2327 diff
2328 );
2329 }
2330 Ok(())
2331 }
2332
2333 #[cfg(debug_assertions)]
2334 fn check_adx_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2335 skip_if_unsupported!(kernel, test_name);
2336
2337 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2338 let candles = read_candles_from_csv(file_path)?;
2339
2340 let test_params = vec![
2341 AdxParams::default(),
2342 AdxParams { period: Some(5) },
2343 AdxParams { period: Some(10) },
2344 AdxParams { period: Some(20) },
2345 AdxParams { period: Some(50) },
2346 ];
2347
2348 for params in test_params {
2349 let input =
2350 AdxInput::from_slices(&candles.high, &candles.low, &candles.close, params.clone());
2351 let output = adx_with_kernel(&input, kernel)?;
2352
2353 for (idx, &val) in output.values.iter().enumerate() {
2354 if val.is_nan() || val.is_infinite() {
2355 continue;
2356 }
2357
2358 let bits = val.to_bits();
2359
2360 if bits == 0x11111111_11111111 {
2361 panic!(
2362 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2363 with params: period={}",
2364 test_name,
2365 val,
2366 bits,
2367 idx,
2368 params.period.unwrap_or(14)
2369 );
2370 }
2371
2372 if bits == 0x22222222_22222222 {
2373 panic!(
2374 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2375 with params: period={}",
2376 test_name,
2377 val,
2378 bits,
2379 idx,
2380 params.period.unwrap_or(14)
2381 );
2382 }
2383
2384 if bits == 0x33333333_33333333 {
2385 panic!(
2386 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2387 with params: period={}",
2388 test_name,
2389 val,
2390 bits,
2391 idx,
2392 params.period.unwrap_or(14)
2393 );
2394 }
2395 }
2396 }
2397
2398 Ok(())
2399 }
2400
2401 #[cfg(not(debug_assertions))]
2402 fn check_adx_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2403 Ok(())
2404 }
2405
2406 #[cfg(feature = "proptest")]
2407 #[allow(clippy::float_cmp)]
2408 fn check_adx_property(
2409 test_name: &str,
2410 kernel: Kernel,
2411 ) -> Result<(), Box<dyn std::error::Error>> {
2412 use proptest::prelude::*;
2413 skip_if_unsupported!(kernel, test_name);
2414
2415 let strat = (2usize..=100).prop_flat_map(|period| {
2416 (
2417 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2418 (0.01f64..0.2f64),
2419 period + 1..400,
2420 )
2421 .prop_flat_map(move |(base_price, volatility, len)| {
2422 prop::collection::vec(
2423 (0f64..1f64).prop_map(move |rand| {
2424 let change = (rand - 0.5) * volatility * base_price.abs();
2425 let open = base_price + change;
2426 let close = open + (rand - 0.5) * volatility * base_price.abs() * 0.5;
2427 let high = open.max(close) + rand * volatility * base_price.abs() * 0.3;
2428 let low = open.min(close) - rand * volatility * base_price.abs() * 0.3;
2429 (high, low, close)
2430 }),
2431 len,
2432 )
2433 .prop_map(move |bars| (bars, period))
2434 })
2435 });
2436
2437 proptest::test_runner::TestRunner::default()
2438 .run(&strat, |(bars, period)| {
2439 let mut highs = Vec::with_capacity(bars.len());
2440 let mut lows = Vec::with_capacity(bars.len());
2441 let mut closes = Vec::with_capacity(bars.len());
2442
2443 for &(h, l, c) in &bars {
2444 highs.push(h);
2445 lows.push(l);
2446 closes.push(c);
2447 }
2448
2449 let params = AdxParams {
2450 period: Some(period),
2451 };
2452 let input = AdxInput::from_slices(&highs, &lows, &closes, params.clone());
2453
2454 let AdxOutput { values: out } = adx_with_kernel(&input, kernel).unwrap();
2455 let AdxOutput { values: ref_out } =
2456 adx_with_kernel(&input, Kernel::Scalar).unwrap();
2457
2458 let warmup_period = 2 * period - 1;
2459 for i in 0..warmup_period.min(out.len()) {
2460 prop_assert!(
2461 out[i].is_nan(),
2462 "[{}] Property 1: Expected NaN during warmup at index {}, got {}",
2463 test_name,
2464 i,
2465 out[i]
2466 );
2467 }
2468
2469 if out.len() > warmup_period + 10 {
2470 for i in (warmup_period + 10)..out.len() {
2471 prop_assert!(
2472 !out[i].is_nan(),
2473 "[{}] Property 2: Unexpected NaN after warmup at index {}",
2474 test_name,
2475 i
2476 );
2477 }
2478 }
2479
2480 for (i, &val) in out.iter().enumerate() {
2481 if !val.is_nan() {
2482 prop_assert!(
2483 val >= 0.0 && val <= 100.0,
2484 "[{}] Property 3: ADX value {} at index {} outside [0, 100] range",
2485 test_name,
2486 val,
2487 i
2488 );
2489 }
2490 }
2491
2492 let const_price = 100.0;
2493 let const_highs = vec![const_price; closes.len()];
2494 let const_lows = vec![const_price; closes.len()];
2495 let const_closes = vec![const_price; closes.len()];
2496 let const_input =
2497 AdxInput::from_slices(&const_highs, &const_lows, &const_closes, params.clone());
2498
2499 if let Ok(AdxOutput { values: const_out }) = adx_with_kernel(&const_input, kernel) {
2500 for i in warmup_period..const_out.len() {
2501 if !const_out[i].is_nan() {
2502 prop_assert!(
2503 const_out[i] <= 1.0,
2504 "[{}] Property 4: ADX should be near 0 for constant prices, got {} at index {}",
2505 test_name, const_out[i], i
2506 );
2507 }
2508 }
2509 }
2510
2511 prop_assert_eq!(
2512 out.len(),
2513 ref_out.len(),
2514 "[{}] Property 5: Kernel output length mismatch",
2515 test_name
2516 );
2517
2518 for i in 0..out.len() {
2519 let y = out[i];
2520 let r = ref_out[i];
2521
2522 if !y.is_finite() || !r.is_finite() {
2523 prop_assert!(
2524 y.to_bits() == r.to_bits(),
2525 "[{}] Property 5: NaN/Inf mismatch at index {}: {} vs {}",
2526 test_name,
2527 i,
2528 y,
2529 r
2530 );
2531 continue;
2532 }
2533
2534 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
2535 prop_assert!(
2536 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2537 "[{}] Property 5: Kernel mismatch at index {}: {} vs {} (ULP={})",
2538 test_name,
2539 i,
2540 y,
2541 r,
2542 ulp_diff
2543 );
2544 }
2545
2546 if period == 2 {
2547 prop_assert!(
2548 out.len() == closes.len(),
2549 "[{}] Property 6: Output length mismatch with period=2",
2550 test_name
2551 );
2552
2553 if out.len() > 3 {
2554 prop_assert!(
2555 !out[3].is_nan(),
2556 "[{}] Property 6: Should have valid ADX at index 3 with period=2",
2557 test_name
2558 );
2559 }
2560 }
2561
2562 let trend_len = closes.len();
2563 let mut trend_highs = Vec::with_capacity(trend_len);
2564 let mut trend_lows = Vec::with_capacity(trend_len);
2565 let mut trend_closes = Vec::with_capacity(trend_len);
2566
2567 for i in 0..trend_len {
2568 let base = 100.0 + (i as f64) * 2.0;
2569 trend_lows.push(base - 0.5);
2570 trend_highs.push(base + 0.5);
2571 trend_closes.push(base);
2572 }
2573
2574 let trend_input =
2575 AdxInput::from_slices(&trend_highs, &trend_lows, &trend_closes, params.clone());
2576
2577 if let Ok(AdxOutput { values: trend_out }) = adx_with_kernel(&trend_input, kernel) {
2578 let last_valid_adx = trend_out
2579 .iter()
2580 .rposition(|&v| !v.is_nan())
2581 .and_then(|i| Some(trend_out[i]));
2582
2583 if let Some(adx_val) = last_valid_adx {
2584 prop_assert!(
2585 adx_val > 20.0,
2586 "[{}] Property 7: Strong trend should produce high ADX, got {}",
2587 test_name,
2588 adx_val
2589 );
2590 }
2591 }
2592
2593 let doji_price = 100.0;
2594 let mut doji_highs = Vec::with_capacity(closes.len());
2595 let mut doji_lows = Vec::with_capacity(closes.len());
2596 let mut doji_closes = Vec::with_capacity(closes.len());
2597
2598 for _ in 0..closes.len() {
2599 doji_highs.push(doji_price + 0.01);
2600 doji_lows.push(doji_price - 0.01);
2601 doji_closes.push(doji_price);
2602 }
2603
2604 let doji_input =
2605 AdxInput::from_slices(&doji_highs, &doji_lows, &doji_closes, params.clone());
2606
2607 if let Ok(AdxOutput { values: doji_out }) = adx_with_kernel(&doji_input, kernel) {
2608 for i in warmup_period..doji_out.len() {
2609 if !doji_out[i].is_nan() {
2610 prop_assert!(
2611 doji_out[i] <= 30.0,
2612 "[{}] Property 8: Low movement should produce low ADX, got {} at index {}",
2613 test_name, doji_out[i], i
2614 );
2615 }
2616 }
2617 }
2618
2619 if out.len() > warmup_period {
2620 prop_assert!(
2621 !out[warmup_period].is_nan(),
2622 "[{}] Property 9: Should have valid ADX at index {} (warmup_period)",
2623 test_name,
2624 warmup_period
2625 );
2626 if warmup_period > 0 {
2627 prop_assert!(
2628 out[warmup_period - 1].is_nan(),
2629 "[{}] Property 9: Should have NaN at index {} (before warmup_period)",
2630 test_name,
2631 warmup_period - 1
2632 );
2633 }
2634 }
2635
2636 #[cfg(debug_assertions)]
2637 {
2638 for (i, &val) in out.iter().enumerate() {
2639 if val.is_finite() {
2640 let bits = val.to_bits();
2641 prop_assert!(
2642 bits != 0x11111111_11111111
2643 && bits != 0x22222222_22222222
2644 && bits != 0x33333333_33333333,
2645 "[{}] Property 10: Found poison value {} (0x{:016X}) at index {}",
2646 test_name,
2647 val,
2648 bits,
2649 i
2650 );
2651 }
2652 }
2653 }
2654
2655 for (i, &(h, l, c)) in bars.iter().enumerate() {
2656 prop_assert!(
2657 h >= l,
2658 "[{}] Property 11: High {} < Low {} at index {}",
2659 test_name,
2660 h,
2661 l,
2662 i
2663 );
2664 prop_assert!(
2665 c >= l && c <= h,
2666 "[{}] Property 11: Close {} outside [Low {}, High {}] at index {}",
2667 test_name,
2668 c,
2669 l,
2670 h,
2671 i
2672 );
2673 }
2674
2675 Ok(())
2676 })
2677 .unwrap();
2678
2679 Ok(())
2680 }
2681
2682 macro_rules! generate_all_adx_tests {
2683 ($($test_fn:ident),*) => {
2684 paste::paste! {
2685 $(
2686 #[test]
2687 fn [<$test_fn _scalar_f64>]() {
2688 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2689 }
2690 )*
2691 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2692 $(
2693 #[test]
2694 fn [<$test_fn _avx2_f64>]() {
2695 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2696 }
2697 #[test]
2698 fn [<$test_fn _avx512_f64>]() {
2699 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2700 }
2701 )*
2702 }
2703 }
2704 }
2705
2706 generate_all_adx_tests!(
2707 check_adx_partial_params,
2708 check_adx_accuracy,
2709 check_adx_default_candles,
2710 check_adx_zero_period,
2711 check_adx_period_exceeds_length,
2712 check_adx_very_small_dataset,
2713 check_adx_reinput,
2714 check_adx_nan_handling,
2715 check_adx_streaming,
2716 check_adx_no_poison
2717 );
2718
2719 #[cfg(feature = "proptest")]
2720 generate_all_adx_tests!(check_adx_property);
2721
2722 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2723 skip_if_unsupported!(kernel, test);
2724
2725 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2726 let c = read_candles_from_csv(file)?;
2727
2728 let output = AdxBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
2729
2730 let def = AdxParams::default();
2731 let row = output
2732 .combos
2733 .iter()
2734 .position(|p| p.period == def.period)
2735 .expect("default row missing");
2736 let slice = &output.values[row * output.cols..][..output.cols];
2737
2738 assert_eq!(slice.len(), c.close.len());
2739 let expected = [36.14, 36.52, 37.01, 37.46, 38.47];
2740 let start = slice.len().saturating_sub(5);
2741 for (i, &v) in slice[start..].iter().enumerate() {
2742 assert!(
2743 (v - expected[i]).abs() < 1e-1,
2744 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2745 );
2746 }
2747 Ok(())
2748 }
2749
2750 #[cfg(debug_assertions)]
2751 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2752 skip_if_unsupported!(kernel, test);
2753
2754 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2755 let c = read_candles_from_csv(file)?;
2756
2757 let test_configs = vec![
2758 (5, 20, 5),
2759 (10, 30, 10),
2760 (14, 14, 1),
2761 (20, 50, 15),
2762 (2, 10, 2),
2763 ];
2764
2765 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2766 let output = AdxBatchBuilder::new()
2767 .kernel(kernel)
2768 .period_range(p_start, p_end, p_step)
2769 .apply_candles(&c)?;
2770
2771 for (idx, &val) in output.values.iter().enumerate() {
2772 if val.is_nan() || val.is_infinite() {
2773 continue;
2774 }
2775
2776 let bits = val.to_bits();
2777 let row = idx / output.cols;
2778 let col = idx % output.cols;
2779 let combo = &output.combos[row];
2780
2781 if bits == 0x11111111_11111111 {
2782 panic!(
2783 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2784 at row {} col {} (flat index {}) with params: period={}",
2785 test,
2786 cfg_idx,
2787 val,
2788 bits,
2789 row,
2790 col,
2791 idx,
2792 combo.period.unwrap_or(14)
2793 );
2794 }
2795
2796 if bits == 0x22222222_22222222 {
2797 panic!(
2798 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2799 at row {} col {} (flat index {}) with params: period={}",
2800 test,
2801 cfg_idx,
2802 val,
2803 bits,
2804 row,
2805 col,
2806 idx,
2807 combo.period.unwrap_or(14)
2808 );
2809 }
2810
2811 if bits == 0x33333333_33333333 {
2812 panic!(
2813 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2814 at row {} col {} (flat index {}) with params: period={}",
2815 test,
2816 cfg_idx,
2817 val,
2818 bits,
2819 row,
2820 col,
2821 idx,
2822 combo.period.unwrap_or(14)
2823 );
2824 }
2825 }
2826
2827 let params = expand_grid(&AdxBatchRange {
2828 period: (p_start, p_end, p_step),
2829 });
2830
2831 for p in ¶ms {
2832 if let Some(slice) = output.values_for(p) {
2833 for (idx, &val) in slice.iter().enumerate() {
2834 if val.is_nan() || val.is_infinite() {
2835 continue;
2836 }
2837
2838 let bits = val.to_bits();
2839 if bits == 0x11111111_11111111
2840 || bits == 0x22222222_22222222
2841 || bits == 0x33333333_33333333
2842 {
2843 panic!(
2844 "[{}] Config {}: Found poison value {} (0x{:016X}) in sliced output \
2845 at index {} with params: period={}",
2846 test, cfg_idx, val, bits, idx, p.period.unwrap_or(14)
2847 );
2848 }
2849 }
2850 }
2851 }
2852 }
2853
2854 Ok(())
2855 }
2856
2857 #[cfg(not(debug_assertions))]
2858 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2859 Ok(())
2860 }
2861
2862 macro_rules! gen_batch_tests {
2863 ($fn_name:ident) => {
2864 paste::paste! {
2865 #[test] fn [<$fn_name _scalar>]() {
2866 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2867 }
2868 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2869 #[test] fn [<$fn_name _avx2>]() {
2870 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2871 }
2872 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2873 #[test] fn [<$fn_name _avx512>]() {
2874 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2875 }
2876 #[test] fn [<$fn_name _auto_detect>]() {
2877 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2878 }
2879 }
2880 };
2881 }
2882 gen_batch_tests!(check_batch_default_row);
2883 gen_batch_tests!(check_batch_no_poison);
2884}
2885
/// Python binding: computes ADX over equal-length `high`/`low`/`close`
/// float64 arrays and returns a new float64 array of the same length.
///
/// Raises `ValueError` if the arrays differ in length, if the kernel name is
/// invalid, or if the indicator itself fails. The computation runs with the
/// GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "adx")]
#[pyo3(signature = (high, low, close, period, kernel=None))]
pub fn adx_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let same_len = h.len() == l.len() && h.len() == c.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let input = AdxInput::from_slices(h, l, c, AdxParams { period: Some(period) });

    let values: Vec<f64> = py
        .allow_threads(|| adx_with_kernel(&input, kern).map(|out| out.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
2922
/// Python wrapper (exposed as `AdxStream`) around the incremental `AdxStream`.
#[cfg(feature = "python")]
#[pyclass(name = "AdxStream")]
pub struct AdxStreamPy {
    stream: AdxStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AdxStreamPy {
    /// Creates a stream for `period`; raises `ValueError` if the parameters are rejected.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        AdxStream::try_new(AdxParams { period: Some(period) })
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one bar; returns the current ADX, or `None` while the stream has no value yet.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2946
/// Python binding: computes ADX for every period in `period_range`
/// (`(start, end, step)`) and returns a dict with:
/// - `"values"`: a `(rows, cols)` float64 matrix, one row per period;
/// - `"periods"`: the period of each row as uint64.
///
/// Raises `ValueError` on mismatched input lengths, `rows * cols` overflow,
/// an invalid kernel name, or an indicator error.
#[cfg(feature = "python")]
#[pyfunction(name = "adx_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn adx_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    if !(h.len() == l.len() && h.len() == c.len()) {
        return Err(PyValueError::new_err(
            "Input arrays must have the same length",
        ));
    }

    let sweep = AdxBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);
    let (rows, cols) = (combos.len(), c.len());
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The output array is created uninitialised and filled in place by the batch routine.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            other => other,
        };
        // Translate the batch kernel into the per-row kernel; anything else runs scalar.
        let row_kernel = match batch_kernel {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512Batch => Kernel::Avx512,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        };
        adx_batch_inner_into(h, l, c, &sweep, row_kernel, true, out_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
3013
/// Allocates an uninitialised buffer with room for `len` f64s and hands its
/// pointer to JS. Release it with `adx_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop leaks the Vec so the allocation outlives this call.
    let mut buf = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
3022
/// Releases a buffer previously returned by `adx_alloc(len)`.
///
/// A null `ptr` is ignored. The `len` argument must be the same value that
/// was passed to `adx_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from a `Vec<f64>` with capacity `len` that was leaked
    // by `adx_alloc`. The length is 0 so we never claim that possibly
    // uninitialised slots are valid f64s. f64 has no destructor, so this only
    // frees the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3030
/// WASM entry point: computes ADX for `len` bars read from `h_ptr`/`l_ptr`/`c_ptr`
/// and writes `len` values to `out_ptr`.
///
/// `out_ptr` may be one of the input buffers (in-place use). In that case the
/// result is computed into scratch memory first and then copied, so a shared
/// and a mutable slice never overlap.
///
/// # Errors
/// Returns a JS string error for null pointers or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte-range overlap test between the output buffer and an input buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(span);
    let overlaps = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(span)
    };
    let aliased = overlaps(h_ptr) || overlaps(l_ptr) || overlaps(c_ptr);

    let kernel = detect_best_kernel();
    // SAFETY: the caller guarantees every pointer addresses `len` valid f64s.
    // The mutable output slice is only created when it cannot overlap a live
    // input slice, or after the inputs are no longer used (aliased path).
    unsafe {
        let h = std::slice::from_raw_parts(h_ptr, len);
        let l = std::slice::from_raw_parts(l_ptr, len);
        let c = std::slice::from_raw_parts(c_ptr, len);
        let input = AdxInput::from_slices(h, l, c, AdxParams { period: Some(period) });

        if aliased {
            let mut scratch = vec![0.0; len];
            adx_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
            Ok(())
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            adx_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
3065
/// WASM binding: computes ADX over three equal-length arrays and returns a new
/// vector of the same length. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("Input arrays must have the same length"));
    }

    let input = AdxInput::from_slices(high, low, close, AdxParams { period: Some(period) });
    let mut output = vec![0.0; n];

    // This export is compiled only for wasm32, where the scalar kernel is used.
    match adx_into_slice(&mut output, &input, Kernel::Scalar) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
3093
/// WASM entry point: computes ADX for every period in
/// `(period_start, period_end, period_step)` and writes a row-major
/// `rows x cols` matrix to `out_ptr`. Returns `rows` on success.
///
/// `cols` must equal `len`, and `rows` must equal the number of generated
/// periods. If `out_ptr` overlaps an input buffer, the matrix is computed into
/// scratch memory and then copied, so a shared and a mutable slice never overlap.
///
/// # Errors
/// Returns a JS string error for:
/// - null pointers,
/// - a `cols`/`len` mismatch,
/// - a row-count mismatch,
/// - `rows * cols` overflow,
/// - any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_batch_into(
    h_ptr: *const f64,
    l_ptr: *const f64,
    c_ptr: *const f64,
    len: usize,
    out_ptr: *mut f64,
    rows: usize,
    cols: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if h_ptr.is_null() || l_ptr.is_null() || c_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    if cols != len {
        return Err(JsValue::from_str("cols must equal len"));
    }

    let sweep = AdxBatchRange {
        period: (period_start, period_end, period_step),
    };
    if expand_grid(&sweep).len() != rows {
        return Err(JsValue::from_str("rows mismatch"));
    }
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between the output matrix and an input buffer.
    let elem = std::mem::size_of::<f64>();
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let in_span = len.saturating_mul(elem);
    let overlaps = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(in_span)
    };
    let aliased = overlaps(h_ptr) || overlaps(l_ptr) || overlaps(c_ptr);

    let kernel = detect_best_kernel();
    // SAFETY: the caller guarantees the inputs address `len` f64s and the
    // output addresses `rows * cols` f64s. The mutable output slice is only
    // created when it cannot overlap a live input slice, or after the inputs
    // are no longer used (aliased path).
    unsafe {
        let h = std::slice::from_raw_parts(h_ptr, len);
        let l = std::slice::from_raw_parts(l_ptr, len);
        let c = std::slice::from_raw_parts(c_ptr, len);

        if aliased {
            let mut scratch = vec![0.0; total];
            adx_batch_inner_into(h, l, c, &sweep, kernel, false, &mut scratch)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            adx_batch_inner_into(h, l, c, &sweep, kernel, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}
3139
/// WASM binding: computes ADX for every period in
/// `(period_start, period_end, period_step)` and returns the flattened
/// row-major value matrix. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("Input arrays must have the same length"));
    }

    let sweep = AdxBatchRange {
        period: (period_start, period_end, period_step),
    };

    // wasm32-only export: scalar kernel, no parallelism.
    match adx_batch_inner(high, low, close, &sweep, Kernel::Scalar, false) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
3167
/// WASM binding: lists the period of each batch row for the given sweep, as f64s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adx_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let combos = expand_grid(&AdxBatchRange {
        period: (period_start, period_end, period_step),
    });

    let mut metadata = Vec::with_capacity(combos.len());
    for combo in combos {
        metadata.push(combo.period.unwrap() as f64);
    }
    Ok(metadata)
}
3187
/// Configuration object accepted by the `adx_batch` JS export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct AdxBatchConfig {
    /// Period sweep as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}

/// Result object returned by the `adx_batch` JS export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Deserialize, Serialize)]
pub struct AdxBatchJsOutput {
    /// Flattened `rows x cols` value matrix.
    pub values: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<AdxParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of bars per row.
    pub cols: usize,
}
3202
/// WASM binding (exported as `adx_batch`): runs a period sweep described by a
/// JS config object and returns `{ values, combos, rows, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = adx_batch)]
pub fn adx_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let n = high.len();
    if low.len() != n || close.len() != n {
        return Err(JsValue::from_str("Input arrays must have the same length"));
    }

    let parsed: AdxBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = AdxBatchRange {
        period: parsed.period_range,
    };

    // wasm32-only export: scalar batch kernel, no parallelism.
    let batch = adx_batch_inner(high, low, close, &sweep, Kernel::ScalarBatch, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = AdxBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
3240
3241#[inline]
3242pub fn adx_into_slice(dst: &mut [f64], input: &AdxInput, kern: Kernel) -> Result<(), AdxError> {
3243 let (high, low, close) = match &input.data {
3244 AdxData::Candles { candles } => {
3245 let h = candles
3246 .select_candle_field("high")
3247 .map_err(|_| AdxError::CandleFieldError { field: "high" })?;
3248 let l = candles
3249 .select_candle_field("low")
3250 .map_err(|_| AdxError::CandleFieldError { field: "low" })?;
3251 let c = candles
3252 .select_candle_field("close")
3253 .map_err(|_| AdxError::CandleFieldError { field: "close" })?;
3254 (h, l, c)
3255 }
3256 AdxData::Slices { high, low, close } => (*high, *low, *close),
3257 };
3258
3259 if high.len() != low.len() || high.len() != close.len() {
3260 return Err(AdxError::InconsistentLengths);
3261 }
3262 let len = close.len();
3263 if dst.len() != len {
3264 return Err(AdxError::OutputLengthMismatch {
3265 expected: len,
3266 got: dst.len(),
3267 });
3268 }
3269 if len == 0 {
3270 return Err(AdxError::EmptyInputData);
3271 }
3272
3273 let period = input.get_period();
3274 if period == 0 || period > len {
3275 return Err(AdxError::InvalidPeriod {
3276 period,
3277 data_len: len,
3278 });
3279 }
3280 if high.iter().all(|x| x.is_nan())
3281 || low.iter().all(|x| x.is_nan())
3282 || close.iter().all(|x| x.is_nan())
3283 {
3284 return Err(AdxError::AllValuesNaN);
3285 }
3286
3287 let first = first_valid_triple(high, low, close);
3288 if len - first < period + 1 {
3289 return Err(AdxError::NotEnoughValidData {
3290 needed: period + 1,
3291 valid: len - first,
3292 });
3293 }
3294
3295 let warm_end = first + (2 * period - 1);
3296 for v in &mut dst[..warm_end.min(len)] {
3297 *v = f64::NAN;
3298 }
3299
3300 let chosen = match kern {
3301 Kernel::Auto => detect_best_kernel(),
3302 k => k,
3303 };
3304 unsafe {
3305 match chosen {
3306 Kernel::Scalar | Kernel::ScalarBatch => adx_scalar(
3307 &high[first..],
3308 &low[first..],
3309 &close[first..],
3310 period,
3311 &mut dst[first..],
3312 ),
3313 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3314 Kernel::Avx2 | Kernel::Avx2Batch => adx_avx2(
3315 &high[first..],
3316 &low[first..],
3317 &close[first..],
3318 period,
3319 &mut dst[first..],
3320 ),
3321 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3322 Kernel::Avx512 | Kernel::Avx512Batch => adx_avx512(
3323 &high[first..],
3324 &low[first..],
3325 &close[first..],
3326 period,
3327 &mut dst[first..],
3328 ),
3329 _ => unreachable!(),
3330 }
3331 }
3332 Ok(())
3333}