1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::CudaAdxr;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::Candles;
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31 make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
36use core::arch::x86_64::*;
37#[cfg(not(target_arch = "wasm32"))]
38use rayon::prelude::*;
39use std::error::Error;
40use thiserror::Error;
41
42#[derive(Debug, Clone)]
43pub enum AdxrData<'a> {
44 Candles {
45 candles: &'a Candles,
46 },
47 Slices {
48 high: &'a [f64],
49 low: &'a [f64],
50 close: &'a [f64],
51 },
52}
53
/// Result of a single-series ADXR computation.
#[derive(Clone, Debug)]
pub struct AdxrOutput {
    /// Output series, allocated with the same length as the `close` input;
    /// the leading warm-up region holds `NaN`.
    pub values: Vec<f64>,
}
58
/// Tunable parameters of the ADXR indicator.
///
/// Serde support is derived only for wasm32 builds with the `wasm` feature.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AdxrParams {
    /// Look-back length in bars; `None` means "use the default" (14).
    pub period: Option<usize>,
}
67
68impl Default for AdxrParams {
69 fn default() -> Self {
70 Self { period: Some(14) }
71 }
72}
73
74#[derive(Debug, Clone)]
75pub struct AdxrInput<'a> {
76 pub data: AdxrData<'a>,
77 pub params: AdxrParams,
78}
79
80impl<'a> AdxrInput<'a> {
81 #[inline]
82 pub fn from_candles(c: &'a Candles, p: AdxrParams) -> Self {
83 Self {
84 data: AdxrData::Candles { candles: c },
85 params: p,
86 }
87 }
88 #[inline]
89 pub fn from_slices(h: &'a [f64], l: &'a [f64], c: &'a [f64], p: AdxrParams) -> Self {
90 Self {
91 data: AdxrData::Slices {
92 high: h,
93 low: l,
94 close: c,
95 },
96 params: p,
97 }
98 }
99 #[inline]
100 pub fn with_default_candles(c: &'a Candles) -> Self {
101 Self::from_candles(c, AdxrParams::default())
102 }
103 #[inline]
104 pub fn get_period(&self) -> usize {
105 self.params.period.unwrap_or(14)
106 }
107}
108
109#[derive(Copy, Clone, Debug)]
110pub struct AdxrBuilder {
111 period: Option<usize>,
112 kernel: Kernel,
113}
114
115impl Default for AdxrBuilder {
116 fn default() -> Self {
117 Self {
118 period: None,
119 kernel: Kernel::Auto,
120 }
121 }
122}
123
124impl AdxrBuilder {
125 #[inline(always)]
126 pub fn new() -> Self {
127 Self::default()
128 }
129 #[inline(always)]
130 pub fn period(mut self, n: usize) -> Self {
131 self.period = Some(n);
132 self
133 }
134 #[inline(always)]
135 pub fn kernel(mut self, k: Kernel) -> Self {
136 self.kernel = k;
137 self
138 }
139 #[inline(always)]
140 pub fn apply(self, c: &Candles) -> Result<AdxrOutput, AdxrError> {
141 let p = AdxrParams {
142 period: self.period,
143 };
144 let i = AdxrInput::from_candles(c, p);
145 adxr_with_kernel(&i, self.kernel)
146 }
147 #[inline(always)]
148 pub fn apply_slices(self, h: &[f64], l: &[f64], c: &[f64]) -> Result<AdxrOutput, AdxrError> {
149 let p = AdxrParams {
150 period: self.period,
151 };
152 let i = AdxrInput::from_slices(h, l, c, p);
153 adxr_with_kernel(&i, self.kernel)
154 }
155 #[inline(always)]
156 pub fn into_stream(self) -> Result<AdxrStream, AdxrError> {
157 let p = AdxrParams {
158 period: self.period,
159 };
160 AdxrStream::try_new(p)
161 }
162}
163
164#[derive(Debug, Error)]
165pub enum AdxrError {
166 #[error("adxr: Candle field error: {0}")]
167 CandleFieldError(String),
168 #[error("adxr: Empty input data (All values are NaN).")]
169 EmptyInputData,
170 #[error("adxr: HLC data length mismatch: high={high_len}, low={low_len}, close={close_len}")]
171 HlcLengthMismatch {
172 high_len: usize,
173 low_len: usize,
174 close_len: usize,
175 },
176 #[error("adxr: All values are NaN.")]
177 AllValuesNaN,
178 #[error("adxr: Invalid period: period = {period}, data length = {data_len}")]
179 InvalidPeriod { period: usize, data_len: usize },
180
181 #[error("adxr: Not enough data: needed = {needed}, valid = {valid}")]
182 NotEnoughValidData { needed: usize, valid: usize },
183 #[error("adxr: Output length mismatch: expected = {expected}, got = {got}")]
184 OutputLengthMismatch { expected: usize, got: usize },
185 #[error("adxr: Invalid kernel type - expected batch kernel, got {kernel:?}")]
186 InvalidKernel { kernel: Kernel },
187 #[error("adxr: Invalid range: start={start}, end={end}, step={step}")]
188 InvalidRange {
189 start: usize,
190 end: usize,
191 step: usize,
192 },
193 #[error("adxr: Invalid kernel for batch: {0:?}")]
194 InvalidKernelForBatch(Kernel),
195}
196
197#[inline]
198pub fn adxr(input: &AdxrInput) -> Result<AdxrOutput, AdxrError> {
199 adxr_with_kernel(input, Kernel::Auto)
200}
201
202pub fn adxr_with_kernel(input: &AdxrInput, kernel: Kernel) -> Result<AdxrOutput, AdxrError> {
203 let (high, low, close, period, first, chosen) = adxr_prepare(input, kernel)?;
204
205 let len = close.len();
206
207 let warmup_period = first + 2 * period;
208 let mut out = alloc_with_nan_prefix(len, warmup_period);
209 unsafe {
210 match chosen {
211 Kernel::Scalar | Kernel::ScalarBatch => {
212 adxr_scalar(high, low, close, period, first, &mut out)
213 }
214 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
215 Kernel::Avx2 | Kernel::Avx2Batch => {
216 adxr_avx2(high, low, close, period, first, &mut out)
217 }
218 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219 Kernel::Avx512 | Kernel::Avx512Batch => {
220 adxr_avx512(high, low, close, period, first, &mut out)
221 }
222 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
223 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
224 adxr_scalar(high, low, close, period, first, &mut out)
225 }
226 _ => unreachable!(),
227 }
228 }
229
230 Ok(AdxrOutput { values: out })
231}
232
233#[inline(always)]
234fn adxr_prepare<'a>(
235 input: &'a AdxrInput,
236 kernel: Kernel,
237) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), AdxrError> {
238 let (high, low, close) = match &input.data {
239 AdxrData::Candles { candles } => (
240 candles
241 .select_candle_field("high")
242 .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
243 candles
244 .select_candle_field("low")
245 .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
246 candles
247 .select_candle_field("close")
248 .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
249 ),
250 AdxrData::Slices { high, low, close } => (*high, *low, *close),
251 };
252
253 let len = close.len();
254 if len == 0 {
255 return Err(AdxrError::EmptyInputData);
256 }
257 if high.len() != len || low.len() != len {
258 return Err(AdxrError::HlcLengthMismatch {
259 high_len: high.len(),
260 low_len: low.len(),
261 close_len: len,
262 });
263 }
264
265 let first = close
266 .iter()
267 .position(|x| !x.is_nan())
268 .ok_or(AdxrError::AllValuesNaN)?;
269 let period = input.get_period();
270 if period == 0 || period > len {
271 return Err(AdxrError::InvalidPeriod {
272 period,
273 data_len: len,
274 });
275 }
276
277 if len - first < period + 1 {
278 return Err(AdxrError::NotEnoughValidData {
279 needed: period + 1,
280 valid: len - first,
281 });
282 }
283
284 let mut chosen = match kernel {
285 Kernel::Auto => detect_best_kernel(),
286 other => other,
287 };
288
289 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
290 if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
291 chosen = Kernel::Avx2;
292 }
293
294 Ok((high, low, close, period, first, chosen))
295}
296
297#[inline]
298pub fn adxr_into_slice(dst: &mut [f64], input: &AdxrInput, kern: Kernel) -> Result<(), AdxrError> {
299 let (high, low, close, period, first, chosen) = adxr_prepare(input, kern)?;
300
301 let len = close.len();
302 if dst.len() != len {
303 return Err(AdxrError::OutputLengthMismatch {
304 expected: len,
305 got: dst.len(),
306 });
307 }
308
309 unsafe {
310 match chosen {
311 Kernel::Scalar | Kernel::ScalarBatch => {
312 adxr_scalar(high, low, close, period, first, dst)
313 }
314 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315 Kernel::Avx2 | Kernel::Avx2Batch => adxr_avx2(high, low, close, period, first, dst),
316 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317 Kernel::Avx512 | Kernel::Avx512Batch => {
318 adxr_avx512(high, low, close, period, first, dst)
319 }
320 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
321 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
322 adxr_scalar(high, low, close, period, first, dst)
323 }
324 _ => unreachable!(),
325 }
326 }
327
328 let warmup_end = (first + 2 * period).min(dst.len());
329 for v in &mut dst[..warmup_end] {
330 *v = f64::NAN;
331 }
332
333 Ok(())
334}
335
336#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
337#[inline]
338pub fn adxr_into(input: &AdxrInput, out: &mut [f64]) -> Result<(), AdxrError> {
339 adxr_into_slice(out, input, Kernel::Auto)
340}
341
/// Bounds-checked scalar ADXR kernel.
///
/// Steps:
/// 1. Sum +DM / -DM over bars `first + 1 ..= first + period` and derive the
///    first DX from those sums.
/// 2. From bar `first + period + 1` on, smooth both sums with
///    `s = s * (1 - 1/period) + dm` and derive a DX per bar.
/// 3. The first ADX is the mean of the first `period` DX values; later ADX
///    values are `(prev * (period - 1) + dx) / period`.
/// 4. The output is `0.5 * (adx_now + adx_from_period_updates_ago)`, or
///    `NaN` while the older ADX is not finite.
///
/// Only indices `>= first + 2 * period` are written; the warm-up prefix is
/// the caller's responsibility. `close` is used only for its length: true
/// range cancels out of DX and is therefore not computed.
///
/// `period == 0` is a no-op. For `period == 1` the seed DX already fills the
/// averaging window, so ADX is seeded before the main loop; without that the
/// kernel would never store anything past the warm-up.
///
/// # Panics
/// Panics on out-of-bounds indexing if `high`, `low` or `out` is shorter
/// than `close` and such an index is reached.
#[inline]
pub fn adxr_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // DX from the two smoothed directional-movement sums (0 when both are 0).
    #[inline(always)]
    fn dx_of(pdm: f64, mdm: f64) -> f64 {
        let denom = pdm + mdm;
        if denom > 0.0 {
            100.0 * (pdm - mdm).abs() / denom
        } else {
            0.0
        }
    }

    let len = close.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    // Seed the directional-movement sums over the first `period` bars.
    let mut pdm_s = 0.0_f64;
    let mut mdm_s = 0.0_f64;
    let stop = (first + period).min(len - 1);
    for i in (first + 1)..=stop {
        let up = high[i] - high[i - 1];
        let down = low[i - 1] - low[i];
        if up > down && up > 0.0 {
            pdm_s += up;
        }
        if down > up && down > 0.0 {
            mdm_s += down;
        }
    }

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; the slot at `head` holds the ADX
    // from `period` updates ago (NaN until the ring has wrapped once).
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // period == 1: the window is already complete after the seed DX.
    if dx_count == period {
        adx_last = dx_sum * rp;
        have_adx = true;
        adx_ring[head] = adx_last;
        head += 1;
        if head == period {
            head = 0;
        }
    }

    for i in (first + period + 1)..len {
        let up = high[i] - high[i - 1];
        let down = low[i - 1] - low[i];
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        pdm_s = pdm_s.mul_add(om, plus_dm);
        mdm_s = mdm_s.mul_add(om, minus_dm);
        let dx = dx_of(pdm_s, mdm_s);

        let adx_curr = if have_adx {
            (adx_last * pm1 + dx) * rp
        } else {
            // Still collecting the first `period` DX values. This phase ends
            // at index `first + 2 * period - 1`, before any output index.
            dx_sum += dx;
            dx_count += 1;
            if dx_count < period {
                continue;
            }
            have_adx = true;
            dx_sum * rp
        };
        adx_last = adx_curr;

        let prev_adx = adx_ring[head];
        adx_ring[head] = adx_curr;
        head += 1;
        if head == period {
            head = 0;
        }

        if i >= warmup_start {
            out[i] = if prev_adx.is_finite() {
                0.5 * (adx_curr + prev_adx)
            } else {
                f64::NAN
            };
        }
    }
}
490
/// AVX-512 entry point: dispatches on the period length.
///
/// Periods up to 32 go to `adxr_avx512_short`, longer ones to
/// `adxr_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => adxr_avx512_short(high, low, close, period, first, out),
        _ => adxr_avx512_long(high, low, close, period, first, out),
    }
}
509
510#[inline]
511pub fn adxr_avx2(
512 high: &[f64],
513 low: &[f64],
514 close: &[f64],
515 period: usize,
516 first: usize,
517 out: &mut [f64],
518) {
519 unsafe { adxr_scalar_unchecked(high, low, close, period, first, out) }
520}
521
/// AVX-512 variant for short periods.
///
/// Currently forwards to [`adxr_scalar`] unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
534
/// AVX-512 variant for long periods.
///
/// Currently forwards to [`adxr_scalar`] unchanged.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
547
/// Scalar ADXR kernel using unchecked indexing.
///
/// The algorithm: seed +DM / -DM sums over the first `period` bars after
/// `first`, smooth them with `s = s * (1 - 1/period) + dm`, turn them into
/// DX, average the first `period` DX values into the first ADX, update ADX
/// with `(prev * (period - 1) + dx) / period`, and store
/// `0.5 * (adx_now + adx_from_period_updates_ago)` for every index
/// `>= first + 2 * period` (NaN while the older ADX is not finite).
///
/// `close` is used only for its length: true range cancels out of DX.
/// `period == 0` is a no-op. For `period == 1` the seed DX already fills the
/// averaging window, so ADX is seeded before the main loop; without that the
/// kernel would never store anything past the warm-up.
///
/// # Safety
/// `high`, `low` and `out` must each contain at least `close.len()`
/// elements; every access uses an index below `close.len()`.
#[inline]
unsafe fn adxr_scalar_unchecked(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    // DX from the two smoothed directional-movement sums (0 when both are 0).
    #[inline(always)]
    fn dx_of(pdm: f64, mdm: f64) -> f64 {
        let denom = pdm + mdm;
        if denom > 0.0 {
            100.0 * (pdm - mdm).abs() / denom
        } else {
            0.0
        }
    }

    let len = close.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    // Seed sums. Indices used: `i` and `i - 1` with `first + 1 <= i <= len - 1`.
    let mut pdm_s = 0.0_f64;
    let mut mdm_s = 0.0_f64;
    let stop = core::cmp::min(first + period, len - 1);
    let mut i = first + 1;
    while i <= stop {
        let up = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let down = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if up > down && up > 0.0 {
            pdm_s += up;
        }
        if down > up && down > 0.0 {
            mdm_s += down;
        }
        i += 1;
    }

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; `head` always stays below
    // `period`, which is the ring length.
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // period == 1: the window is already complete after the seed DX.
    if dx_count == period {
        adx_last = dx_sum * rp;
        have_adx = true;
        *adx_ring.get_unchecked_mut(head) = adx_last;
        head += 1;
        if head == period {
            head = 0;
        }
    }

    // Main pass. Indices used: `i` and `i - 1` with `i < len`.
    i = first + period + 1;
    while i < len {
        let up = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let down = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        pdm_s = pdm_s.mul_add(om, plus_dm);
        mdm_s = mdm_s.mul_add(om, minus_dm);
        let dx = dx_of(pdm_s, mdm_s);

        let adx_curr = if have_adx {
            (adx_last * pm1 + dx) * rp
        } else {
            // Still collecting the first `period` DX values. This phase ends
            // at index `first + 2 * period - 1`, before any output index.
            dx_sum += dx;
            dx_count += 1;
            if dx_count < period {
                i += 1;
                continue;
            }
            have_adx = true;
            dx_sum * rp
        };
        adx_last = adx_curr;

        let prev_adx = *adx_ring.get_unchecked(head);
        *adx_ring.get_unchecked_mut(head) = adx_curr;
        head += 1;
        if head == period {
            head = 0;
        }

        if i >= warmup_start {
            *out.get_unchecked_mut(i) = if prev_adx.is_finite() {
                0.5 * (adx_curr + prev_adx)
            } else {
                f64::NAN
            };
        }

        i += 1;
    }
}
698
699#[inline(always)]
700pub fn adxr_batch_with_kernel(
701 h: &[f64],
702 l: &[f64],
703 c: &[f64],
704 sweep: &AdxrBatchRange,
705 k: Kernel,
706) -> Result<AdxrBatchOutput, AdxrError> {
707 let mut kernel = match k {
708 Kernel::Auto => detect_best_batch_kernel(),
709 other if other.is_batch() => other,
710 _ => return Err(AdxrError::InvalidKernelForBatch(k)),
711 };
712 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
713 if matches!(k, Kernel::Auto) && matches!(kernel, Kernel::Avx512Batch) {
714 kernel = Kernel::Avx2Batch;
715 }
716 let simd = match kernel {
717 Kernel::Avx512Batch => Kernel::Avx512,
718 Kernel::Avx2Batch => Kernel::Avx2,
719 Kernel::ScalarBatch => Kernel::Scalar,
720 _ => unreachable!(),
721 };
722
723 let combos = expand_grid(sweep)?;
724 let rows = combos.len();
725 let cols = c.len();
726 let mut buf_mu = make_uninit_matrix(rows, cols);
727 let warm: Vec<usize> = combos
728 .iter()
729 .map(|p| {
730 let first = c.iter().position(|x| !x.is_nan()).unwrap_or(0);
731
732 first + 2 * p.period.unwrap()
733 })
734 .collect();
735 init_matrix_prefixes(&mut buf_mu, cols, &warm);
736
737 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
738 let out: &mut [f64] =
739 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
740 let combos = adxr_batch_inner_into(h, l, c, sweep, simd, true, out)?;
741 let values = unsafe {
742 Vec::from_raw_parts(
743 guard.as_mut_ptr() as *mut f64,
744 guard.len(),
745 guard.capacity(),
746 )
747 };
748 Ok(AdxrBatchOutput {
749 values,
750 combos,
751 rows,
752 cols,
753 })
754}
755
/// Period sweep description for batch runs.
#[derive(Debug, Clone)]
pub struct AdxrBatchRange {
    /// `(start, end, step)`. A zero step or `start == end` yields the single
    /// period `start`; `start > end` is walked downwards.
    pub period: (usize, usize, usize),
}
760
761impl Default for AdxrBatchRange {
762 fn default() -> Self {
763 Self {
764 period: (14, 263, 1),
765 }
766 }
767}
768
769#[derive(Clone, Debug, Default)]
770pub struct AdxrBatchBuilder {
771 range: AdxrBatchRange,
772 kernel: Kernel,
773}
774impl AdxrBatchBuilder {
775 pub fn new() -> Self {
776 Self::default()
777 }
778 pub fn kernel(mut self, k: Kernel) -> Self {
779 self.kernel = k;
780 self
781 }
782 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
783 self.range.period = (start, end, step);
784 self
785 }
786 pub fn period_static(mut self, p: usize) -> Self {
787 self.range.period = (p, p, 0);
788 self
789 }
790 pub fn apply_slices(
791 self,
792 h: &[f64],
793 l: &[f64],
794 c: &[f64],
795 ) -> Result<AdxrBatchOutput, AdxrError> {
796 adxr_batch_with_kernel(h, l, c, &self.range, self.kernel)
797 }
798 pub fn apply_candles(self, candles: &Candles) -> Result<AdxrBatchOutput, AdxrError> {
799 let h = &candles.high;
800 let l = &candles.low;
801 let c = &candles.close;
802 self.apply_slices(h, l, c)
803 }
804}
805
806#[derive(Clone, Debug)]
807pub struct AdxrBatchOutput {
808 pub values: Vec<f64>,
809 pub combos: Vec<AdxrParams>,
810 pub rows: usize,
811 pub cols: usize,
812}
813impl AdxrBatchOutput {
814 pub fn row_for_params(&self, p: &AdxrParams) -> Option<usize> {
815 self.combos
816 .iter()
817 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
818 }
819 pub fn values_for(&self, p: &AdxrParams) -> Option<&[f64]> {
820 self.row_for_params(p).map(|row| {
821 let start = row * self.cols;
822 &self.values[start..start + self.cols]
823 })
824 }
825}
826
827#[inline(always)]
828fn expand_grid(r: &AdxrBatchRange) -> Result<Vec<AdxrParams>, AdxrError> {
829 fn axis((start, end, step): (usize, usize, usize)) -> Option<Vec<usize>> {
830 if step == 0 || start == end {
831 return Some(vec![start]);
832 }
833 if start < end {
834 return Some((start..=end).step_by(step).collect());
835 }
836
837 if step == 0 {
838 return Some(vec![start]);
839 }
840 let mut v = Vec::new();
841 let mut cur = start;
842 while cur >= end {
843 v.push(cur);
844 if let Some(next) = cur.checked_sub(step) {
845 cur = next;
846 } else {
847 break;
848 }
849 if cur == usize::MAX {
850 break;
851 }
852 if cur < end {
853 break;
854 }
855 }
856 Some(v)
857 }
858 let periods = axis(r.period).unwrap_or_default();
859 if periods.is_empty() {
860 return Err(AdxrError::InvalidRange {
861 start: r.period.0,
862 end: r.period.1,
863 step: r.period.2,
864 });
865 }
866 let mut out = Vec::with_capacity(periods.len());
867 for &p in &periods {
868 out.push(AdxrParams { period: Some(p) });
869 }
870 Ok(out)
871}
872
/// Precomputes the per-bar series shared by every row of a batch sweep.
///
/// Returns `(tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm)`:
/// * `tr_all[i]`, `pdm_all[i]`, `mdm_all[i]` – true range, +DM and -DM of bar
///   `i` for `i > first`; zero elsewhere.
/// * `prefix_*[k]` – sum of the first `k` entries starting at bar `first + 1`
///   (so `prefix_*[0] == 0`).
#[inline]
fn shared_precompute_tr_dm(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
    let len = close.len();
    let mut tr_all = vec![0.0; len];
    let mut pdm_all = vec![0.0; len];
    let mut mdm_all = vec![0.0; len];

    for idx in (first + 1)..len {
        let prev_cl = close[idx - 1];
        let hi = high[idx];
        let lo = low[idx];
        let prev_hi = high[idx - 1];
        let prev_lo = low[idx - 1];

        let range = hi - lo;
        let gap_up = (hi - prev_cl).abs();
        let gap_down = (lo - prev_cl).abs();
        tr_all[idx] = range.max(gap_up).max(gap_down);

        // The buffers start zeroed, so only the non-zero cases need a store.
        let up = hi - prev_hi;
        let down = prev_lo - lo;
        if up > down && up > 0.0 {
            pdm_all[idx] = up;
        }
        if down > up && down > 0.0 {
            mdm_all[idx] = down;
        }
    }

    let start = first + 1;
    let pre_len = len.saturating_sub(start);

    // Running sums over `src[start..]`, with a leading zero entry.
    let running_sums = |src: &[f64]| -> Vec<f64> {
        let mut sums = Vec::with_capacity(pre_len + 1);
        let mut acc = 0.0;
        sums.push(acc);
        for &x in src.iter().skip(start) {
            acc += x;
            sums.push(acc);
        }
        sums
    };

    let prefix_tr = running_sums(&tr_all);
    let prefix_pdm = running_sums(&pdm_all);
    let prefix_mdm = running_sums(&mdm_all);

    (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm)
}
917
/// Computes one batch row (one period) from the shared per-bar series.
///
/// `pdm_all` / `mdm_all` hold +DM / -DM per bar and `prefix_pdm` /
/// `prefix_mdm` their running sums starting at bar `first + 1`; the seed sums
/// for this period are read from index `period` of those prefix arrays
/// (0 if absent). From bar `first + period + 1` on, the sums are smoothed
/// with `s = s * (1 - 1/period) + dm`, turned into DX, averaged into the
/// first ADX after `period` DX values, and updated with
/// `(prev * (period - 1) + dx) / period`. For every index
/// `>= first + 2 * period` the row receives
/// `0.5 * (adx_now + adx_from_period_updates_ago)` (NaN while the older ADX
/// is not finite).
///
/// `tr_all` only provides the row length and `prefix_tr` is unused: true
/// range cancels out of DX. Both stay in the signature for the callers.
///
/// `period == 0` is a no-op. For `period == 1` the seed DX already fills the
/// averaging window, so ADX is seeded before the main loop; without that the
/// row would never be written past its warm-up.
#[inline]
fn adxr_row_from_precomputed(
    tr_all: &[f64],
    pdm_all: &[f64],
    mdm_all: &[f64],
    prefix_tr: &[f64],
    prefix_pdm: &[f64],
    prefix_mdm: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // DX from the two smoothed directional-movement sums (0 when both are 0).
    #[inline(always)]
    fn dx_of(pdm: f64, mdm: f64) -> f64 {
        let denom = pdm + mdm;
        if denom > 0.0 {
            100.0 * (pdm - mdm).abs() / denom
        } else {
            0.0
        }
    }

    // Kept for signature compatibility only.
    let _ = prefix_tr;

    let len = tr_all.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    let mut pdm_s = prefix_pdm.get(period).copied().unwrap_or(0.0);
    let mut mdm_s = prefix_mdm.get(period).copied().unwrap_or(0.0);

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; the slot at `head` holds the ADX
    // from `period` updates ago (NaN until the ring has wrapped once).
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // period == 1: the window is already complete after the seed DX.
    if dx_count == period {
        adx_last = dx_sum * rp;
        have_adx = true;
        adx_ring[head] = adx_last;
        head += 1;
        if head == period {
            head = 0;
        }
    }

    for i in (first + period + 1)..len {
        pdm_s = pdm_s.mul_add(om, pdm_all[i]);
        mdm_s = mdm_s.mul_add(om, mdm_all[i]);
        let dx = dx_of(pdm_s, mdm_s);

        let adx_curr = if have_adx {
            (adx_last * pm1 + dx) * rp
        } else {
            // Still collecting the first `period` DX values. This phase ends
            // at index `first + 2 * period - 1`, before any output index.
            dx_sum += dx;
            dx_count += 1;
            if dx_count < period {
                continue;
            }
            have_adx = true;
            dx_sum * rp
        };
        adx_last = adx_curr;

        let prev_adx = adx_ring[head];
        adx_ring[head] = adx_curr;
        head += 1;
        if head == period {
            head = 0;
        }

        if i >= warmup_start {
            out[i] = if prev_adx.is_finite() {
                0.5 * (adx_curr + prev_adx)
            } else {
                f64::NAN
            };
        }
    }
}
1030
1031#[inline(always)]
1032pub fn adxr_batch_slice(
1033 high: &[f64],
1034 low: &[f64],
1035 close: &[f64],
1036 sweep: &AdxrBatchRange,
1037 kern: Kernel,
1038) -> Result<AdxrBatchOutput, AdxrError> {
1039 adxr_batch_inner(high, low, close, sweep, kern, false)
1040}
1041
1042#[inline(always)]
1043pub fn adxr_batch_par_slice(
1044 high: &[f64],
1045 low: &[f64],
1046 close: &[f64],
1047 sweep: &AdxrBatchRange,
1048 kern: Kernel,
1049) -> Result<AdxrBatchOutput, AdxrError> {
1050 adxr_batch_inner(high, low, close, sweep, kern, true)
1051}
1052
1053#[inline(always)]
1054fn adxr_batch_inner(
1055 high: &[f64],
1056 low: &[f64],
1057 close: &[f64],
1058 sweep: &AdxrBatchRange,
1059 kern: Kernel,
1060 parallel: bool,
1061) -> Result<AdxrBatchOutput, AdxrError> {
1062 let combos = expand_grid(sweep)?;
1063 let len = close.len();
1064 let first = close
1065 .iter()
1066 .position(|x| !x.is_nan())
1067 .ok_or(AdxrError::AllValuesNaN)?;
1068 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1069
1070 if len - first < max_p + 1 {
1071 return Err(AdxrError::NotEnoughValidData {
1072 needed: max_p + 1,
1073 valid: len - first,
1074 });
1075 }
1076 let rows = combos.len();
1077 let cols = len;
1078
1079 let mut buf_mu = make_uninit_matrix(rows, cols);
1080
1081 let warm: Vec<usize> = combos
1082 .iter()
1083 .map(|c| first + 2 * c.period.unwrap())
1084 .collect();
1085
1086 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1087
1088 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
1089 let values: &mut [f64] = unsafe {
1090 std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1091 };
1092
1093 let (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm) =
1094 shared_precompute_tr_dm(high, low, close, first);
1095
1096 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1097 let period = combos[row].period.unwrap();
1098
1099 match kern {
1100 Kernel::Scalar => adxr_row_from_precomputed(
1101 &tr_all,
1102 &pdm_all,
1103 &mdm_all,
1104 &prefix_tr,
1105 &prefix_pdm,
1106 &prefix_mdm,
1107 first,
1108 period,
1109 out_row,
1110 ),
1111 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1112 Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1113 &tr_all,
1114 &pdm_all,
1115 &mdm_all,
1116 &prefix_tr,
1117 &prefix_pdm,
1118 &prefix_mdm,
1119 first,
1120 period,
1121 out_row,
1122 ),
1123 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1124 Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1125 &tr_all,
1126 &pdm_all,
1127 &mdm_all,
1128 &prefix_tr,
1129 &prefix_pdm,
1130 &prefix_mdm,
1131 first,
1132 period,
1133 out_row,
1134 ),
1135 _ => unreachable!(),
1136 }
1137 };
1138
1139 if parallel {
1140 #[cfg(not(target_arch = "wasm32"))]
1141 {
1142 values
1143 .par_chunks_mut(cols)
1144 .enumerate()
1145 .for_each(|(row, slice)| do_row(row, slice));
1146 }
1147
1148 #[cfg(target_arch = "wasm32")]
1149 {
1150 for (row, slice) in values.chunks_mut(cols).enumerate() {
1151 do_row(row, slice);
1152 }
1153 }
1154 } else {
1155 for (row, slice) in values.chunks_mut(cols).enumerate() {
1156 do_row(row, slice);
1157 }
1158 }
1159
1160 let values = unsafe {
1161 Vec::from_raw_parts(
1162 buf_guard.as_mut_ptr() as *mut f64,
1163 buf_guard.len(),
1164 buf_guard.capacity(),
1165 )
1166 };
1167
1168 Ok(AdxrBatchOutput {
1169 values,
1170 combos,
1171 rows,
1172 cols,
1173 })
1174}
1175
1176#[inline(always)]
1177pub fn adxr_batch_inner_into(
1178 high: &[f64],
1179 low: &[f64],
1180 close: &[f64],
1181 sweep: &AdxrBatchRange,
1182 kern: Kernel,
1183 parallel: bool,
1184 out: &mut [f64],
1185) -> Result<Vec<AdxrParams>, AdxrError> {
1186 let combos = expand_grid(sweep)?;
1187
1188 let len = close.len();
1189 if high.len() != len || low.len() != len {
1190 return Err(AdxrError::HlcLengthMismatch {
1191 high_len: high.len(),
1192 low_len: low.len(),
1193 close_len: len,
1194 });
1195 }
1196
1197 let first = close
1198 .iter()
1199 .position(|x| !x.is_nan())
1200 .ok_or(AdxrError::AllValuesNaN)?;
1201 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1202
1203 if len - first < max_p + 1 {
1204 return Err(AdxrError::NotEnoughValidData {
1205 needed: max_p + 1,
1206 valid: len - first,
1207 });
1208 }
1209
1210 let rows = combos.len();
1211 let cols = len;
1212 if let Some(expected) = rows.checked_mul(cols) {
1213 if out.len() != expected {
1214 return Err(AdxrError::OutputLengthMismatch {
1215 expected,
1216 got: out.len(),
1217 });
1218 }
1219 } else {
1220 return Err(AdxrError::InvalidRange {
1221 start: rows,
1222 end: cols,
1223 step: 0,
1224 });
1225 }
1226
1227 let warm: Vec<usize> = combos
1228 .iter()
1229 .map(|c| first + 2 * c.period.unwrap())
1230 .collect();
1231
1232 let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
1233 std::slice::from_raw_parts_mut(
1234 out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1235 out.len(),
1236 )
1237 };
1238 init_matrix_prefixes(out_mu, cols, &warm);
1239
1240 let (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm) =
1241 shared_precompute_tr_dm(high, low, close, first);
1242
1243 let do_row = |row: usize, dst_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
1244 let period = combos[row].period.unwrap();
1245 let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1246
1247 match kern {
1248 Kernel::Scalar => adxr_row_from_precomputed(
1249 &tr_all,
1250 &pdm_all,
1251 &mdm_all,
1252 &prefix_tr,
1253 &prefix_pdm,
1254 &prefix_mdm,
1255 first,
1256 period,
1257 dst,
1258 ),
1259 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1260 Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1261 &tr_all,
1262 &pdm_all,
1263 &mdm_all,
1264 &prefix_tr,
1265 &prefix_pdm,
1266 &prefix_mdm,
1267 first,
1268 period,
1269 dst,
1270 ),
1271 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1272 Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1273 &tr_all,
1274 &pdm_all,
1275 &mdm_all,
1276 &prefix_tr,
1277 &prefix_pdm,
1278 &prefix_mdm,
1279 first,
1280 period,
1281 dst,
1282 ),
1283 _ => unreachable!("pass non-batch kernel"),
1284 }
1285 };
1286
1287 if parallel {
1288 #[cfg(not(target_arch = "wasm32"))]
1289 {
1290 use rayon::prelude::*;
1291 out_mu
1292 .par_chunks_mut(cols)
1293 .enumerate()
1294 .for_each(|(row, slice)| do_row(row, slice));
1295 }
1296 #[cfg(target_arch = "wasm32")]
1297 {
1298 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1299 do_row(row, slice);
1300 }
1301 }
1302 } else {
1303 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1304 do_row(row, slice);
1305 }
1306 }
1307
1308 Ok(combos)
1309}
1310
1311#[inline(always)]
1312unsafe fn adxr_row_scalar(
1313 high: &[f64],
1314 low: &[f64],
1315 close: &[f64],
1316 first: usize,
1317 period: usize,
1318 out: &mut [f64],
1319) {
1320 adxr_scalar(high, low, close, period, first, out)
1321}
1322
/// AVX2 entry point for one batch row.
///
/// There is no dedicated AVX2 code path here: the call is forwarded to
/// `adxr_scalar`, which takes `period` before `first`.
///
/// # Safety
///
/// Same requirements as `adxr_scalar` (NOTE(review): confirm against its
/// definition).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // Scalar fallback; note the (period, first) argument order of the callee.
    adxr_scalar(high, low, close, period, first, out);
}
1335
/// AVX-512 entry point for one batch row.
///
/// Forwards to `adxr_avx512`, which takes `period` before `first`.
///
/// # Safety
///
/// Same requirements as `adxr_avx512` (NOTE(review): presumably the CPU must
/// support the instructions that kernel uses — confirm against its definition).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // Callee expects (period, first), the reverse of our own parameter order.
    adxr_avx512(high, low, close, period, first, out);
}
1348
/// AVX-512 "short" variant for one batch row.
///
/// No specialised implementation exists here: the call is forwarded to
/// `adxr_scalar`, which takes `period` before `first`.
///
/// # Safety
///
/// Same requirements as `adxr_scalar` (NOTE(review): confirm against its
/// definition).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // Scalar fallback; note the (period, first) argument order of the callee.
    adxr_scalar(high, low, close, period, first, out);
}
1361
/// AVX-512 "long" variant for one batch row.
///
/// No specialised implementation exists here: the call is forwarded to
/// `adxr_scalar`, which takes `period` before `first`.
///
/// # Safety
///
/// Same requirements as `adxr_scalar` (NOTE(review): confirm against its
/// definition).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // Scalar fallback; note the (period, first) argument order of the callee.
    adxr_scalar(high, low, close, period, first, out);
}
1374
1375#[derive(Debug, Clone)]
1376pub struct AdxrStream {
1377 period: usize,
1378
1379 rp: f64,
1380 om: f64,
1381 pm1: f64,
1382
1383 atr: f64,
1384 pdm_s: f64,
1385 mdm_s: f64,
1386
1387 dx_sum: f64,
1388 dx_count: usize,
1389 adx_last: f64,
1390 have_adx: bool,
1391
1392 adx_ring: Vec<f64>,
1393 head: usize,
1394
1395 prev_hlc: Option<(f64, f64, f64)>,
1396
1397 seen: usize,
1398}
1399
1400impl AdxrStream {
1401 #[inline(always)]
1402 pub fn try_new(params: AdxrParams) -> Result<Self, AdxrError> {
1403 let period = params.period.unwrap_or(14);
1404 if period == 0 {
1405 return Err(AdxrError::InvalidPeriod {
1406 period,
1407 data_len: 0,
1408 });
1409 }
1410 let p = period as f64;
1411 Ok(Self {
1412 period,
1413 rp: 1.0 / p,
1414 om: 1.0 - 1.0 / p,
1415 pm1: p - 1.0,
1416 atr: 0.0,
1417 pdm_s: 0.0,
1418 mdm_s: 0.0,
1419 dx_sum: 0.0,
1420 dx_count: 0,
1421 adx_last: f64::NAN,
1422 have_adx: false,
1423 adx_ring: vec![f64::NAN; period],
1424 head: 0,
1425 prev_hlc: None,
1426 seen: 0,
1427 })
1428 }
1429
1430 #[inline(always)]
1431 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
1432 if !(high.is_finite() && low.is_finite() && close.is_finite()) {
1433 return None;
1434 }
1435
1436 if self.prev_hlc.is_none() {
1437 self.prev_hlc = Some((high, low, close));
1438 return None;
1439 }
1440
1441 let (ph, pl, pc) = unsafe { self.prev_hlc.unwrap_unchecked() };
1442 self.prev_hlc = Some((high, low, close));
1443 self.seen = self.seen.wrapping_add(1);
1444
1445 let tr = {
1446 let a = high - low;
1447 let b = (high - pc).abs();
1448 let c = (low - pc).abs();
1449 a.max(b).max(c)
1450 };
1451
1452 let up = high - ph;
1453 let down = pl - low;
1454 let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
1455 let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };
1456
1457 if self.seen <= self.period {
1458 self.atr += tr;
1459 self.pdm_s += plus_dm;
1460 self.mdm_s += minus_dm;
1461
1462 if self.seen == self.period {
1463 let denom = self.pdm_s + self.mdm_s;
1464 let dx0 = if denom > 0.0 {
1465 100.0 * (self.pdm_s - self.mdm_s).abs() / denom
1466 } else {
1467 0.0
1468 };
1469 self.dx_sum = dx0;
1470 self.dx_count = 1;
1471 }
1472 return None;
1473 }
1474
1475 self.atr = self.atr.mul_add(self.om, tr);
1476 self.pdm_s = self.pdm_s.mul_add(self.om, plus_dm);
1477 self.mdm_s = self.mdm_s.mul_add(self.om, minus_dm);
1478
1479 let denom = self.pdm_s + self.mdm_s;
1480 let dx = if denom > 0.0 {
1481 100.0 * (self.pdm_s - self.mdm_s).abs() / denom
1482 } else {
1483 0.0
1484 };
1485
1486 if !self.have_adx {
1487 if self.dx_count + 1 < self.period {
1488 self.dx_sum += dx;
1489 self.dx_count += 1;
1490 return None;
1491 } else {
1492 self.dx_sum += dx;
1493 self.adx_last = self.dx_sum * self.rp;
1494 self.have_adx = true;
1495
1496 self.adx_ring[self.head] = self.adx_last;
1497 self.head = (self.head + 1) % self.period;
1498 return None;
1499 }
1500 }
1501
1502 let adx_curr = (self.adx_last.mul_add(self.pm1, dx)) * self.rp;
1503 self.adx_last = adx_curr;
1504
1505 let adx_period_ago = self.adx_ring[self.head];
1506 self.adx_ring[self.head] = adx_curr;
1507 self.head = (self.head + 1) % self.period;
1508
1509 if adx_period_ago.is_finite() {
1510 Some(0.5 * (adx_curr + adx_period_ago))
1511 } else {
1512 None
1513 }
1514 }
1515}
1516
1517#[cfg(test)]
1518mod tests {
1519 use super::*;
1520 use crate::skip_if_unsupported;
1521 use crate::utilities::data_loader::read_candles_from_csv;
1522 use paste::paste;
1523
1524 fn check_adxr_partial_params(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1525 skip_if_unsupported!(kernel, test);
1526 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1527 let candles = read_candles_from_csv(file_path)?;
1528 let input = AdxrInput::from_candles(&candles, AdxrParams { period: None });
1529 let output = adxr_with_kernel(&input, kernel)?;
1530 assert_eq!(output.values.len(), candles.close.len());
1531 Ok(())
1532 }
1533
1534 fn check_adxr_accuracy(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1535 skip_if_unsupported!(kernel, test);
1536 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1537 let candles = read_candles_from_csv(file_path)?;
1538 let input = AdxrInput::from_candles(&candles, AdxrParams::default());
1539 let result = adxr_with_kernel(&input, kernel)?;
1540 let expected = [37.10, 37.3, 37.0, 36.2, 36.3];
1541 let start = result.values.len().saturating_sub(5);
1542 for (i, &val) in result.values[start..].iter().enumerate() {
1543 let diff = (val - expected[i]).abs();
1544 assert!(
1545 diff < 1e-1,
1546 "[{}] ADXR {:?} mismatch at idx {}: got {}, expected {}",
1547 test,
1548 kernel,
1549 i,
1550 val,
1551 expected[i]
1552 );
1553 }
1554 Ok(())
1555 }
1556
1557 fn check_adxr_zero_period(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1558 skip_if_unsupported!(kernel, test);
1559 let high = [10.0, 20.0, 30.0];
1560 let low = [9.0, 19.0, 29.0];
1561 let close = [9.5, 19.5, 29.5];
1562 let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(0) });
1563 let res = adxr_with_kernel(&input, kernel);
1564 assert!(res.is_err(), "[{}] ADXR should fail with zero period", test);
1565 Ok(())
1566 }
1567
1568 fn check_adxr_period_exceeds_length(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1569 skip_if_unsupported!(kernel, test);
1570 let high = [10.0, 20.0];
1571 let low = [9.0, 19.0];
1572 let close = [9.5, 19.5];
1573 let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(10) });
1574 let res = adxr_with_kernel(&input, kernel);
1575 assert!(
1576 res.is_err(),
1577 "[{}] ADXR should fail with period > data.len()",
1578 test
1579 );
1580 Ok(())
1581 }
1582
1583 fn check_adxr_very_small_dataset(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1584 skip_if_unsupported!(kernel, test);
1585 let high = [100.0];
1586 let low = [99.0];
1587 let close = [99.5];
1588 let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(14) });
1589 let res = adxr_with_kernel(&input, kernel);
1590 assert!(
1591 res.is_err(),
1592 "[{}] ADXR should fail with insufficient data",
1593 test
1594 );
1595 Ok(())
1596 }
1597
1598 fn check_adxr_reinput(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1599 skip_if_unsupported!(kernel, test);
1600 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1601 let candles = read_candles_from_csv(file_path)?;
1602 let first_input = AdxrInput::from_candles(&candles, AdxrParams { period: Some(14) });
1603 let first_result = adxr_with_kernel(&first_input, kernel)?;
1604 let high = &candles.high;
1605 let low = &candles.low;
1606 let close = &candles.close;
1607 let second_input = AdxrInput::from_slices(high, low, close, AdxrParams { period: Some(5) });
1608 let second_result = adxr_with_kernel(&second_input, kernel)?;
1609 assert_eq!(second_result.values.len(), candles.close.len());
1610 Ok(())
1611 }
1612
1613 fn check_adxr_nan_handling(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1614 skip_if_unsupported!(kernel, test);
1615 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1616 let candles = read_candles_from_csv(file_path)?;
1617 let input = AdxrInput::from_candles(&candles, AdxrParams { period: Some(14) });
1618 let res = adxr_with_kernel(&input, kernel)?;
1619 assert_eq!(res.values.len(), candles.close.len());
1620 if res.values.len() > 240 {
1621 for (i, &val) in res.values[240..].iter().enumerate() {
1622 assert!(
1623 !val.is_nan(),
1624 "[{}] Found unexpected NaN at out-index {}",
1625 test,
1626 240 + i
1627 );
1628 }
1629 }
1630 Ok(())
1631 }
1632
1633 macro_rules! generate_all_adxr_tests {
1634 ($($test_fn:ident),*) => {
1635 paste! {
1636 $(
1637 #[test]
1638 fn [<$test_fn _scalar_f64>]() {
1639 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1640 }
1641 )*
1642 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1643 $(
1644 #[test]
1645 fn [<$test_fn _avx2_f64>]() {
1646 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1647 }
1648 #[test]
1649 fn [<$test_fn _avx512_f64>]() {
1650 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1651 }
1652 )*
1653 }
1654 }
1655 }
1656
1657 #[cfg(debug_assertions)]
1658 fn check_adxr_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1659 skip_if_unsupported!(kernel, test_name);
1660
1661 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1662 let candles = read_candles_from_csv(file_path)?;
1663
1664 let test_params = vec![
1665 AdxrParams::default(),
1666 AdxrParams { period: Some(5) },
1667 AdxrParams { period: Some(10) },
1668 AdxrParams { period: Some(14) },
1669 AdxrParams { period: Some(20) },
1670 AdxrParams { period: Some(25) },
1671 AdxrParams { period: Some(30) },
1672 AdxrParams { period: Some(50) },
1673 AdxrParams { period: Some(100) },
1674 AdxrParams { period: Some(2) },
1675 ];
1676
1677 for (param_idx, params) in test_params.iter().enumerate() {
1678 let input = AdxrInput::from_candles(&candles, params.clone());
1679 let output = adxr_with_kernel(&input, kernel)?;
1680
1681 for (i, &val) in output.values.iter().enumerate() {
1682 if val.is_nan() {
1683 continue;
1684 }
1685
1686 let bits = val.to_bits();
1687
1688 if bits == 0x11111111_11111111 {
1689 panic!(
1690 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1691 with params: period={}",
1692 test_name,
1693 val,
1694 bits,
1695 i,
1696 params.period.unwrap_or(14)
1697 );
1698 }
1699
1700 if bits == 0x22222222_22222222 {
1701 panic!(
1702 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1703 with params: period={}",
1704 test_name,
1705 val,
1706 bits,
1707 i,
1708 params.period.unwrap_or(14)
1709 );
1710 }
1711
1712 if bits == 0x33333333_33333333 {
1713 panic!(
1714 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1715 with params: period={}",
1716 test_name,
1717 val,
1718 bits,
1719 i,
1720 params.period.unwrap_or(14)
1721 );
1722 }
1723 }
1724 }
1725
1726 Ok(())
1727 }
1728
1729 #[cfg(not(debug_assertions))]
1730 fn check_adxr_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1731 Ok(())
1732 }
1733
1734 fn check_adxr_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1735 use proptest::prelude::*;
1736 skip_if_unsupported!(kernel, test_name);
1737
1738 let strat = (2usize..=50)
1739 .prop_flat_map(|period| {
1740 let min_size = (period * 3).max(period + 10);
1741 let max_size = 400;
1742 (
1743 10.0f64..1000.0f64,
1744 0.0f64..0.1f64,
1745 -0.01f64..0.01f64,
1746 min_size..max_size,
1747 Just(period),
1748 0u8..3,
1749 )
1750 })
1751 .prop_map(
1752 |(base_price, volatility_pct, trend, size, period, market_type)| {
1753 let mut high_data = Vec::with_capacity(size);
1754 let mut low_data = Vec::with_capacity(size);
1755 let mut close_data = Vec::with_capacity(size);
1756
1757 for i in 0..size {
1758 let price = match market_type {
1759 0 => {
1760 let cycle = (i as f64 * 0.1).sin();
1761 base_price * (1.0 + cycle * volatility_pct)
1762 }
1763 1 => base_price * (1.0 + trend * i as f64),
1764 2 => base_price,
1765 _ => base_price,
1766 };
1767
1768 let (high, low, close) = if market_type == 2 {
1769 (price, price, price)
1770 } else {
1771 let daily_volatility =
1772 price * volatility_pct * (0.5 + 0.5 * (i as f64 * 0.05).cos());
1773 let close = price;
1774 let high = close + daily_volatility.abs();
1775 let low = close - daily_volatility.abs();
1776 (high, low, close)
1777 };
1778
1779 high_data.push(high);
1780 low_data.push(low);
1781 close_data.push(close);
1782 }
1783
1784 (high_data, low_data, close_data, period, market_type)
1785 },
1786 );
1787
1788 proptest::test_runner::TestRunner::default().run(
1789 &strat,
1790 |(high_data, low_data, close_data, period, market_type)| {
1791 let params = AdxrParams {
1792 period: Some(period),
1793 };
1794 let input = AdxrInput::from_slices(&high_data, &low_data, &close_data, params);
1795
1796 let result = adxr_with_kernel(&input, kernel);
1797 prop_assert!(result.is_ok(), "ADXR computation failed: {:?}", result);
1798 let AdxrOutput { values: out } = result.unwrap();
1799
1800 let ref_result = adxr_with_kernel(&input, Kernel::Scalar);
1801 prop_assert!(ref_result.is_ok(), "Reference ADXR computation failed");
1802 let AdxrOutput { values: ref_out } = ref_result.unwrap();
1803
1804 let first = close_data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1805
1806 let warmup_period = first + 2 * period;
1807
1808 for i in 0..out.len() {
1809 let y = out[i];
1810 let r = ref_out[i];
1811
1812 if i < warmup_period {
1813 prop_assert!(
1814 y.is_nan(),
1815 "Expected NaN during warmup at index {}, got {}",
1816 i,
1817 y
1818 );
1819 } else {
1820 if !y.is_nan() {
1821 prop_assert!(
1822 y >= -1e-9 && y <= 100.0 + 1e-9,
1823 "ADXR value {} at index {} is outside [0, 100] range",
1824 y,
1825 i
1826 );
1827 }
1828
1829 if !y.is_nan() && !r.is_nan() {
1830 let diff = (y - r).abs();
1831 prop_assert!(
1832 diff < 1e-6,
1833 "Kernel {:?} and Scalar differ by {} at index {}: {} vs {}",
1834 kernel,
1835 diff,
1836 i,
1837 y,
1838 r
1839 );
1840 }
1841 }
1842 }
1843
1844 if market_type == 2 && out.len() > warmup_period + period {
1845 let last_values = &out[out.len().saturating_sub(10)..];
1846 let non_nan_values: Vec<f64> = last_values
1847 .iter()
1848 .filter(|v| !v.is_nan())
1849 .copied()
1850 .collect();
1851
1852 if !non_nan_values.is_empty() {
1853 let avg_last =
1854 non_nan_values.iter().sum::<f64>() / non_nan_values.len() as f64;
1855 prop_assert!(
1856 avg_last < 25.0,
1857 "ADXR should be low with zero volatility, got average {}",
1858 avg_last
1859 );
1860 }
1861 }
1862
1863 if period == 2 {
1864 prop_assert!(out.len() == close_data.len());
1865 }
1866
1867 let is_constant = close_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
1868 && high_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
1869 && low_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1870
1871 if is_constant && out.len() > warmup_period {
1872 let stable_values = &out[warmup_period..];
1873 let non_nan: Vec<f64> = stable_values
1874 .iter()
1875 .filter(|v| !v.is_nan())
1876 .copied()
1877 .collect();
1878
1879 if non_nan.len() > 10 {
1880 let mean = non_nan.iter().sum::<f64>() / non_nan.len() as f64;
1881 let variance = non_nan.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
1882 / non_nan.len() as f64;
1883 let std_dev = variance.sqrt();
1884
1885 prop_assert!(
1886 std_dev < 5.0,
1887 "ADXR should stabilize with constant data, std_dev = {}",
1888 std_dev
1889 );
1890 }
1891 }
1892
1893 Ok(())
1894 },
1895 )?;
1896
1897 Ok(())
1898 }
1899
1900 generate_all_adxr_tests!(
1901 check_adxr_partial_params,
1902 check_adxr_accuracy,
1903 check_adxr_zero_period,
1904 check_adxr_period_exceeds_length,
1905 check_adxr_very_small_dataset,
1906 check_adxr_reinput,
1907 check_adxr_nan_handling,
1908 check_adxr_no_poison,
1909 check_adxr_property
1910 );
1911
1912 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1913 skip_if_unsupported!(kernel, test);
1914 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1915 let c = read_candles_from_csv(file)?;
1916 let output = AdxrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1917 let def = AdxrParams::default();
1918 let row = output.values_for(&def).expect("default row missing");
1919 assert_eq!(row.len(), c.close.len());
1920 Ok(())
1921 }
1922
1923 macro_rules! gen_batch_tests {
1924 ($fn_name:ident) => {
1925 paste! {
1926 #[test] fn [<$fn_name _scalar>]() {
1927 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1928 }
1929 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1930 #[test] fn [<$fn_name _avx2>]() {
1931 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1932 }
1933 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934 #[test] fn [<$fn_name _avx512>]() {
1935 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1936 }
1937 #[test] fn [<$fn_name _auto_detect>]() {
1938 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1939 }
1940 }
1941 };
1942 }
1943 #[cfg(debug_assertions)]
1944 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1945 skip_if_unsupported!(kernel, test);
1946
1947 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1948 let c = read_candles_from_csv(file)?;
1949
1950 let test_configs = vec![
1951 (2, 10, 2),
1952 (5, 25, 5),
1953 (10, 20, 2),
1954 (14, 50, 6),
1955 (20, 100, 10),
1956 (2, 30, 7),
1957 (8, 40, 8),
1958 ];
1959
1960 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1961 let output = AdxrBatchBuilder::new()
1962 .kernel(kernel)
1963 .period_range(p_start, p_end, p_step)
1964 .apply_candles(&c)?;
1965
1966 for (idx, &val) in output.values.iter().enumerate() {
1967 if val.is_nan() {
1968 continue;
1969 }
1970
1971 let bits = val.to_bits();
1972 let row = idx / output.cols;
1973 let col = idx % output.cols;
1974 let combo = &output.combos[row];
1975
1976 if bits == 0x11111111_11111111 {
1977 panic!(
1978 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1979 at row {} col {} (flat index {}) with params: period={}",
1980 test,
1981 cfg_idx,
1982 val,
1983 bits,
1984 row,
1985 col,
1986 idx,
1987 combo.period.unwrap_or(14)
1988 );
1989 }
1990
1991 if bits == 0x22222222_22222222 {
1992 panic!(
1993 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1994 at row {} col {} (flat index {}) with params: period={}",
1995 test,
1996 cfg_idx,
1997 val,
1998 bits,
1999 row,
2000 col,
2001 idx,
2002 combo.period.unwrap_or(14)
2003 );
2004 }
2005
2006 if bits == 0x33333333_33333333 {
2007 panic!(
2008 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2009 at row {} col {} (flat index {}) with params: period={}",
2010 test,
2011 cfg_idx,
2012 val,
2013 bits,
2014 row,
2015 col,
2016 idx,
2017 combo.period.unwrap_or(14)
2018 );
2019 }
2020 }
2021 }
2022
2023 Ok(())
2024 }
2025
2026 #[cfg(not(debug_assertions))]
2027 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2028 Ok(())
2029 }
2030
2031 gen_batch_tests!(check_batch_default_row);
2032 gen_batch_tests!(check_batch_no_poison);
2033
2034 #[test]
2035 fn test_adxr_into_matches_api() -> Result<(), Box<dyn Error>> {
2036 let len = 256usize;
2037 let mut high = vec![0.0f64; len];
2038 let mut low = vec![0.0f64; len];
2039 let mut close = vec![0.0f64; len];
2040 for i in 0..len {
2041 let base = 100.0 + (i as f64) * 0.1 + (i as f64 * 0.07).sin();
2042 low[i] = base - 1.0;
2043 close[i] = base - 0.3;
2044 high[i] = base + 0.8;
2045 }
2046
2047 let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams::default());
2048
2049 let baseline = adxr(&input)?.values;
2050
2051 let mut out = vec![0.0f64; len];
2052 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2053 {
2054 adxr_into(&input, &mut out)?;
2055 }
2056 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2057 {
2058 adxr_into_slice(&mut out, &input, Kernel::Auto)?;
2059 }
2060
2061 assert_eq!(baseline.len(), out.len());
2062 for (a, b) in baseline.iter().zip(out.iter()) {
2063 let equal = (a.is_nan() && b.is_nan()) || (a == b);
2064 assert!(equal, "Mismatch: baseline={} out={}", a, b);
2065 }
2066
2067 Ok(())
2068 }
2069}
2070
/// Python binding `adxr(high, low, close, period=None, kernel=None)`.
///
/// Validates that the three arrays have equal length, resolves the kernel
/// name, runs the computation with the GIL released and returns the result as
/// a NumPy array. `period` falls back to 14 when omitted.
///
/// # Errors
///
/// Raises `ValueError` on a length mismatch or when the computation fails;
/// errors from `as_slice` and `validate_kernel` are propagated unchanged.
#[cfg(feature = "python")]
#[pyfunction(name = "adxr")]
#[pyo3(signature = (high, low, close, period=None, kernel=None))]
pub fn adxr_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let lengths_match = h.len() == l.len() && h.len() == c.len();
    if !lengths_match {
        let msg = format!(
            "HLC data length mismatch: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        );
        return Err(PyValueError::new_err(msg));
    }

    let kern = validate_kernel(kernel, false)?;

    let input = AdxrInput::from_slices(
        h,
        l,
        c,
        AdxrParams {
            period: period.or(Some(14)),
        },
    );

    // Release the GIL for the duration of the numeric work.
    let computed = py.allow_threads(|| adxr_with_kernel(&input, kern));
    let values: Vec<f64> = match computed {
        Ok(output) => output.values,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
2110
/// Python-visible wrapper (`AdxrStream`) around the Rust [`AdxrStream`].
#[cfg(feature = "python")]
#[pyclass(name = "AdxrStream")]
pub struct AdxrStreamPy {
    stream: AdxrStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AdxrStreamPy {
    /// Creates a stream; `period` falls back to 14 when omitted.
    /// Raises `ValueError` when the underlying constructor rejects the period.
    #[new]
    #[pyo3(signature = (period=None))]
    fn new(period: Option<usize>) -> PyResult<Self> {
        let params = AdxrParams {
            period: period.or(Some(14)),
        };
        match AdxrStream::try_new(params) {
            Ok(stream) => Ok(AdxrStreamPy { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one bar; returns the ADXR value, or `None` while it is not yet
    /// available.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2135
/// Python binding `adxr_batch(high, low, close, period_range, kernel=None)`.
///
/// Expands `period_range = (start, end, step)` into parameter combinations,
/// computes one row per combination directly into a freshly allocated NumPy
/// buffer, and returns a dict with:
///
/// * `"values"`  – the buffer reshaped to `(rows, cols)`;
/// * `"periods"` – the period of each row as `u64`.
///
/// # Errors
///
/// Raises `ValueError` on an input length mismatch, an invalid range, a
/// `rows * cols` overflow, or a failure inside the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "adxr_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn adxr_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;

    let lengths_match = h.len() == l.len() && h.len() == c.len();
    if !lengths_match {
        let msg = format!(
            "HLC data length mismatch: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        );
        return Err(PyValueError::new_err(msg));
    }

    let sweep = AdxrBatchRange {
        period: period_range,
    };

    // Expand the grid up front only to learn the output dimensions.
    let rows = match expand_grid(&sweep) {
        Ok(combos) => combos.len(),
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let cols = c.len();

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The array is created without initialisation and handed to the batch
    // routine as its output buffer.
    // NOTE(review): relies on `adxr_batch_inner_into` writing every element — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let requested = crate::utilities::kernel_validation::validate_kernel(kernel, true)?;

    // Translate batch kernel identifiers into the per-row kernel to run.
    let simd = match requested {
        Kernel::Auto => match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        },
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => other,
    };

    // Release the GIL while the rows are computed (in parallel).
    let combos = py
        .allow_threads(|| adxr_batch_inner_into(h, l, c, &sweep, simd, true, out_slice))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
2206
/// Python binding `adxr_cuda_batch_dev`: runs the ADXR period sweep through
/// `CudaAdxr` and returns a handle to the resulting device array.
///
/// The CUDA work happens with the GIL released. The returned object also
/// holds the context handle and device id reported by `CudaAdxr`.
///
/// # Errors
///
/// Raises `ValueError` when CUDA is not available, or when creating the CUDA
/// helper or running the batch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adxr_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn adxr_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<AdxrDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    let sweep = AdxrBatchRange {
        period: period_range,
    };

    let launch = || -> PyResult<_> {
        let cuda = CudaAdxr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameter combinations are not needed by this binding.
        let (dev, _combos) = cuda
            .adxr_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(launch)?;

    Ok(AdxrDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2240
/// Python binding `adxr_cuda_many_series_one_param_dev`: computes ADXR with a
/// single `period` for many series stored in 2-D arrays and returns a handle
/// to the resulting device array.
///
/// The three inputs must share one 2-D shape. `shape[0]` is forwarded as
/// `rows` and `shape[1]` as `cols`; the callee receives them in the order
/// `(cols, rows)`.
///
/// # Errors
///
/// Raises `ValueError` when CUDA is not available, when the shapes differ, or
/// when creating the CUDA helper or running the kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adxr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
pub fn adxr_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<AdxrDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let shapes_ok =
        shape.len() == 2 && low_tm_f32.shape() == shape && close_tm_f32.shape() == shape;
    if !shapes_ok {
        return Err(PyValueError::new_err("expected three matching 2D arrays"));
    }
    let (rows, cols) = (shape[0], shape[1]);

    let h = high_tm_f32.as_slice()?;
    let l = low_tm_f32.as_slice()?;
    let c = close_tm_f32.as_slice()?;

    let launch = || -> PyResult<_> {
        let cuda = CudaAdxr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .adxr_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(launch)?;

    Ok(AdxrDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2277
/// Python handle (`AdxrDeviceArrayF32`) to a 2-D `f32` result living on a
/// CUDA device.
///
/// `inner` becomes `None` once the buffer has been handed out through
/// `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "AdxrDeviceArrayF32", unsendable)]
pub struct AdxrDeviceArrayF32Py {
    pub(crate) inner: Option<DeviceArrayF32>,
    // Stored alongside the buffer.
    // NOTE(review): presumably keeps the CUDA context alive for the buffer's lifetime — confirm.
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AdxrDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dict (version 3): shape
    /// `(rows, cols)`, little-endian `f32`, row-major byte strides and the
    /// device pointer flagged as not read-only.
    ///
    /// Raises `ValueError` if the buffer was already exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * elem, elem))?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(2, device_id)` as the `(device_type, device_id)` pair.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule; this consumes the buffer, so it
    /// can succeed only once per object.
    ///
    /// A `dl_device` that parses as `(type, id)` and differs from this array's
    /// device is rejected (no cross-device copy is implemented). `stream` is
    /// accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is absent or not an `(i32, i32)` tuple is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        let _ = stream;

        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2360
/// wasm binding: computes ADXR for one series and returns the values.
///
/// The output has `close.len()` elements and is filled by `adxr_into_slice`
/// with `Kernel::Auto`. Errors are returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = AdxrInput::from_slices(
        high,
        low,
        close,
        AdxrParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; close.len()];

    match adxr_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2381
/// wasm binding: runs the ADXR period sweep `(start, end, step)` with the
/// scalar kernel, without parallelism, and returns the flat `values` vector of
/// the batch output. Errors are returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AdxrBatchRange {
        period: (period_start, period_end, period_step),
    };

    match adxr_batch_inner(high, low, close, &sweep, Kernel::Scalar, false) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2400
/// wasm binding: returns the periods produced by expanding
/// `(start, end, step)`, one `f64` per parameter combination, in grid order.
/// Errors from the expansion are returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AdxrBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let metadata: Vec<f64> = combos
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(metadata)
}
2421
/// Configuration object deserialized from the JS side by the unified
/// `adxr_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AdxrBatchConfig {
    /// Period sweep forwarded unchanged as `AdxrBatchRange::period`; the other
    /// batch entry points build that tuple as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
2427
/// Serializable mirror of the batch result handed back to JS by the unified
/// `adxr_batch` entry point; every field is copied from the output of
/// `adxr_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AdxrBatchJsOutput {
    /// Flat buffer of computed values.
    /// NOTE(review): presumably row-major with `rows * cols` entries — confirm
    /// against `adxr_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combinations reported by the batch result.
    pub combos: Vec<AdxrParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2436
/// Unified batch entry point exported to JS as `adxr_batch`.
///
/// `config` must deserialize into [`AdxrBatchConfig`]. The batch is computed
/// with the scalar kernel and returned as a serialized [`AdxrBatchJsOutput`].
///
/// # Errors
/// Returns a `JsValue` string when `config` cannot be deserialized
/// ("Invalid config: ..."), when `adxr_batch_inner` fails, or when the result
/// cannot be serialized ("Serialization error: ...").
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = adxr_batch)]
pub fn adxr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: AdxrBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let grid = AdxrBatchRange {
        period: parsed.period_range,
    };

    let batch = match adxr_batch_inner(high, low, close, &grid, Kernel::Scalar, false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = AdxrBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2465
/// Reserves room for `len` `f64` values and returns a raw pointer to it.
///
/// The backing `Vec` is deliberately never dropped, so the allocation stays
/// alive until it is handed back to `adxr_free` with the same `len`. The
/// memory is not initialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2474
/// Releases a buffer previously obtained from `adxr_alloc`.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly the
/// pointer returned by `adxr_alloc` and the length passed to it; anything else
/// is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: relies on the caller passing a pointer/length pair that came
    // from `adxr_alloc`; rebuilding the `Vec` and dropping it frees the block.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2484
/// Computes ADXR from raw high/low/close buffers and writes `len` values to
/// `out_ptr`.
///
/// The output buffer may overlap any of the input buffers, fully or
/// partially: in that case the result is computed into a temporary vector and
/// copied out afterwards, so the inputs are never mutated while being read.
///
/// # Errors
/// Returns a `JsValue` string when
/// - any pointer is null ("Null pointer provided"),
/// - `period == 0` or `period > len` ("Invalid period"),
/// - `adxr_into_slice` reports an error (its `Display` message).
///
/// # Safety
/// The caller must ensure that each input pointer is valid for reading `len`
/// `f64` values and that `out_ptr` is valid for writing `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Detect ANY overlap between the output range and an input range, not
    // just identical base pointers: a partially overlapping output would
    // otherwise be mutably aliased with a live shared slice.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(byte_len);
    let overlaps_out = |p: *const f64| {
        let lo = p as usize;
        lo < out_hi && out_lo < lo.saturating_add(byte_len)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: all pointers are non-null and, per the caller contract, valid
    // for `len` elements. A `&mut` slice over the output is only created when
    // it is disjoint from every input slice.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let params = AdxrParams {
            period: Some(period),
        };
        let input = AdxrInput::from_slices(high, low, close, params);

        if aliased {
            let mut temp = vec![0.0_f64; len];
            adxr_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // `temp` is a fresh allocation, so it cannot overlap `out_ptr`;
            // the inputs are not read again after this point.
            std::ptr::copy_nonoverlapping(temp.as_ptr(), out_ptr, len);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            adxr_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2531
/// Computes an ADXR batch over the period sweep
/// `(period_start, period_end, period_step)` from raw high/low/close buffers
/// and writes `rows * len` values to `out_ptr`, where `rows` is the number of
/// parameter combinations produced by `expand_grid`.
///
/// Returns `rows` on success. If the output region overlaps any input buffer,
/// the batch is computed into a scratch vector and copied out afterwards, so
/// the inputs are never mutated while being read.
///
/// # Errors
/// Returns a `JsValue` string when
/// - any pointer is null ("Null pointer provided"),
/// - `expand_grid` rejects the sweep,
/// - `rows * len` overflows `usize` ("rows*cols overflow"),
/// - `adxr_batch_inner_into` reports an error.
///
/// # Safety
/// The caller must ensure that each input pointer is valid for reading `len`
/// `f64` values and that `out_ptr` is valid for writing `rows * len` `f64`
/// values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = AdxrBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between the output region (`total` elements)
    // and each input region (`len` elements).
    let elem = std::mem::size_of::<f64>();
    let in_bytes = len.saturating_mul(elem);
    let out_lo = out_ptr as usize;
    let out_hi = out_lo.saturating_add(total.saturating_mul(elem));
    let overlaps_out = |p: *const f64| {
        let lo = p as usize;
        lo < out_hi && out_lo < lo.saturating_add(in_bytes)
    };
    let aliased = overlaps_out(high_ptr) || overlaps_out(low_ptr) || overlaps_out(close_ptr);

    // SAFETY: all pointers are non-null and, per the caller contract, valid
    // for the stated element counts. A `&mut` slice over the output is only
    // created when it is disjoint from every input slice.
    unsafe {
        let h = std::slice::from_raw_parts(high_ptr, len);
        let l = std::slice::from_raw_parts(low_ptr, len);
        let c = std::slice::from_raw_parts(close_ptr, len);

        if aliased {
            let mut scratch = vec![0.0_f64; total];
            adxr_batch_inner_into(
                h,
                l,
                c,
                &sweep,
                Kernel::Scalar,
                false,
                scratch.as_mut_slice(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // `scratch` is a fresh allocation, so it cannot overlap `out_ptr`;
            // the inputs are not read again after this point.
            std::ptr::copy_nonoverlapping(scratch.as_ptr(), out_ptr, total);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            adxr_batch_inner_into(h, l, c, &sweep, Kernel::Scalar, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}