Skip to main content

vector_ta/indicators/
adxr.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::CudaAdxr;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::context::Context;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::PyDict;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use std::sync::Arc;
21
22#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
23use serde::{Deserialize, Serialize};
24#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
25use wasm_bindgen::prelude::*;
26
27use crate::utilities::data_loader::Candles;
28use crate::utilities::enums::Kernel;
29use crate::utilities::helpers::{
30    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
31    make_uninit_matrix,
32};
33#[cfg(feature = "python")]
34use crate::utilities::kernel_validation::validate_kernel;
35#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
36use core::arch::x86_64::*;
37#[cfg(not(target_arch = "wasm32"))]
38use rayon::prelude::*;
39use std::error::Error;
40use thiserror::Error;
41
42#[derive(Debug, Clone)]
43pub enum AdxrData<'a> {
44    Candles {
45        candles: &'a Candles,
46    },
47    Slices {
48        high: &'a [f64],
49        low: &'a [f64],
50        close: &'a [f64],
51    },
52}
53
/// Result of a single-parameter ADXR run.
#[derive(Debug, Clone)]
pub struct AdxrOutput {
    /// One value per input bar; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
58
/// Tunable parameters of the ADXR indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AdxrParams {
    /// Smoothing/look-back length. `None` is resolved to 14 by
    /// `AdxrInput::get_period`.
    pub period: Option<usize>,
}
67
68impl Default for AdxrParams {
69    fn default() -> Self {
70        Self { period: Some(14) }
71    }
72}
73
74#[derive(Debug, Clone)]
75pub struct AdxrInput<'a> {
76    pub data: AdxrData<'a>,
77    pub params: AdxrParams,
78}
79
80impl<'a> AdxrInput<'a> {
81    #[inline]
82    pub fn from_candles(c: &'a Candles, p: AdxrParams) -> Self {
83        Self {
84            data: AdxrData::Candles { candles: c },
85            params: p,
86        }
87    }
88    #[inline]
89    pub fn from_slices(h: &'a [f64], l: &'a [f64], c: &'a [f64], p: AdxrParams) -> Self {
90        Self {
91            data: AdxrData::Slices {
92                high: h,
93                low: l,
94                close: c,
95            },
96            params: p,
97        }
98    }
99    #[inline]
100    pub fn with_default_candles(c: &'a Candles) -> Self {
101        Self::from_candles(c, AdxrParams::default())
102    }
103    #[inline]
104    pub fn get_period(&self) -> usize {
105        self.params.period.unwrap_or(14)
106    }
107}
108
109#[derive(Copy, Clone, Debug)]
110pub struct AdxrBuilder {
111    period: Option<usize>,
112    kernel: Kernel,
113}
114
115impl Default for AdxrBuilder {
116    fn default() -> Self {
117        Self {
118            period: None,
119            kernel: Kernel::Auto,
120        }
121    }
122}
123
124impl AdxrBuilder {
125    #[inline(always)]
126    pub fn new() -> Self {
127        Self::default()
128    }
129    #[inline(always)]
130    pub fn period(mut self, n: usize) -> Self {
131        self.period = Some(n);
132        self
133    }
134    #[inline(always)]
135    pub fn kernel(mut self, k: Kernel) -> Self {
136        self.kernel = k;
137        self
138    }
139    #[inline(always)]
140    pub fn apply(self, c: &Candles) -> Result<AdxrOutput, AdxrError> {
141        let p = AdxrParams {
142            period: self.period,
143        };
144        let i = AdxrInput::from_candles(c, p);
145        adxr_with_kernel(&i, self.kernel)
146    }
147    #[inline(always)]
148    pub fn apply_slices(self, h: &[f64], l: &[f64], c: &[f64]) -> Result<AdxrOutput, AdxrError> {
149        let p = AdxrParams {
150            period: self.period,
151        };
152        let i = AdxrInput::from_slices(h, l, c, p);
153        adxr_with_kernel(&i, self.kernel)
154    }
155    #[inline(always)]
156    pub fn into_stream(self) -> Result<AdxrStream, AdxrError> {
157        let p = AdxrParams {
158            period: self.period,
159        };
160        AdxrStream::try_new(p)
161    }
162}
163
164#[derive(Debug, Error)]
165pub enum AdxrError {
166    #[error("adxr: Candle field error: {0}")]
167    CandleFieldError(String),
168    #[error("adxr: Empty input data (All values are NaN).")]
169    EmptyInputData,
170    #[error("adxr: HLC data length mismatch: high={high_len}, low={low_len}, close={close_len}")]
171    HlcLengthMismatch {
172        high_len: usize,
173        low_len: usize,
174        close_len: usize,
175    },
176    #[error("adxr: All values are NaN.")]
177    AllValuesNaN,
178    #[error("adxr: Invalid period: period = {period}, data length = {data_len}")]
179    InvalidPeriod { period: usize, data_len: usize },
180
181    #[error("adxr: Not enough data: needed = {needed}, valid = {valid}")]
182    NotEnoughValidData { needed: usize, valid: usize },
183    #[error("adxr: Output length mismatch: expected = {expected}, got = {got}")]
184    OutputLengthMismatch { expected: usize, got: usize },
185    #[error("adxr: Invalid kernel type - expected batch kernel, got {kernel:?}")]
186    InvalidKernel { kernel: Kernel },
187    #[error("adxr: Invalid range: start={start}, end={end}, step={step}")]
188    InvalidRange {
189        start: usize,
190        end: usize,
191        step: usize,
192    },
193    #[error("adxr: Invalid kernel for batch: {0:?}")]
194    InvalidKernelForBatch(Kernel),
195}
196
197#[inline]
198pub fn adxr(input: &AdxrInput) -> Result<AdxrOutput, AdxrError> {
199    adxr_with_kernel(input, Kernel::Auto)
200}
201
202pub fn adxr_with_kernel(input: &AdxrInput, kernel: Kernel) -> Result<AdxrOutput, AdxrError> {
203    let (high, low, close, period, first, chosen) = adxr_prepare(input, kernel)?;
204
205    let len = close.len();
206
207    let warmup_period = first + 2 * period;
208    let mut out = alloc_with_nan_prefix(len, warmup_period);
209    unsafe {
210        match chosen {
211            Kernel::Scalar | Kernel::ScalarBatch => {
212                adxr_scalar(high, low, close, period, first, &mut out)
213            }
214            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
215            Kernel::Avx2 | Kernel::Avx2Batch => {
216                adxr_avx2(high, low, close, period, first, &mut out)
217            }
218            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219            Kernel::Avx512 | Kernel::Avx512Batch => {
220                adxr_avx512(high, low, close, period, first, &mut out)
221            }
222            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
223            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
224                adxr_scalar(high, low, close, period, first, &mut out)
225            }
226            _ => unreachable!(),
227        }
228    }
229
230    Ok(AdxrOutput { values: out })
231}
232
233#[inline(always)]
234fn adxr_prepare<'a>(
235    input: &'a AdxrInput,
236    kernel: Kernel,
237) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), AdxrError> {
238    let (high, low, close) = match &input.data {
239        AdxrData::Candles { candles } => (
240            candles
241                .select_candle_field("high")
242                .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
243            candles
244                .select_candle_field("low")
245                .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
246            candles
247                .select_candle_field("close")
248                .map_err(|e| AdxrError::CandleFieldError(e.to_string()))?,
249        ),
250        AdxrData::Slices { high, low, close } => (*high, *low, *close),
251    };
252
253    let len = close.len();
254    if len == 0 {
255        return Err(AdxrError::EmptyInputData);
256    }
257    if high.len() != len || low.len() != len {
258        return Err(AdxrError::HlcLengthMismatch {
259            high_len: high.len(),
260            low_len: low.len(),
261            close_len: len,
262        });
263    }
264
265    let first = close
266        .iter()
267        .position(|x| !x.is_nan())
268        .ok_or(AdxrError::AllValuesNaN)?;
269    let period = input.get_period();
270    if period == 0 || period > len {
271        return Err(AdxrError::InvalidPeriod {
272            period,
273            data_len: len,
274        });
275    }
276
277    if len - first < period + 1 {
278        return Err(AdxrError::NotEnoughValidData {
279            needed: period + 1,
280            valid: len - first,
281        });
282    }
283
284    let mut chosen = match kernel {
285        Kernel::Auto => detect_best_kernel(),
286        other => other,
287    };
288
289    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
290    if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
291        chosen = Kernel::Avx2;
292    }
293
294    Ok((high, low, close, period, first, chosen))
295}
296
297#[inline]
298pub fn adxr_into_slice(dst: &mut [f64], input: &AdxrInput, kern: Kernel) -> Result<(), AdxrError> {
299    let (high, low, close, period, first, chosen) = adxr_prepare(input, kern)?;
300
301    let len = close.len();
302    if dst.len() != len {
303        return Err(AdxrError::OutputLengthMismatch {
304            expected: len,
305            got: dst.len(),
306        });
307    }
308
309    unsafe {
310        match chosen {
311            Kernel::Scalar | Kernel::ScalarBatch => {
312                adxr_scalar(high, low, close, period, first, dst)
313            }
314            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315            Kernel::Avx2 | Kernel::Avx2Batch => adxr_avx2(high, low, close, period, first, dst),
316            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317            Kernel::Avx512 | Kernel::Avx512Batch => {
318                adxr_avx512(high, low, close, period, first, dst)
319            }
320            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
321            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
322                adxr_scalar(high, low, close, period, first, dst)
323            }
324            _ => unreachable!(),
325        }
326    }
327
328    let warmup_end = (first + 2 * period).min(dst.len());
329    for v in &mut dst[..warmup_end] {
330        *v = f64::NAN;
331    }
332
333    Ok(())
334}
335
336#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
337#[inline]
338pub fn adxr_into(input: &AdxrInput, out: &mut [f64]) -> Result<(), AdxrError> {
339    adxr_into_slice(out, input, Kernel::Auto)
340}
341
/// Bounds-checked scalar ADXR kernel.
///
/// Starting at bar `first` (index of the first usable bar) the kernel:
/// 1. sums +DM / -DM over bars `first + 1 ..= first + period` to seed the
///    smoothed directional movement and the first DX;
/// 2. smooths +DM / -DM recursively (`s = s * (1 - 1/period) + dm`) and
///    derives one DX per bar;
/// 3. averages the first `period` DX values into the first ADX, then smooths
///    ADX recursively (`adx = (adx * (period - 1) + dx) / period`);
/// 4. writes `0.5 * (adx[i] + adx[i - period])` into `out[i]` for every
///    `i >= first + 2 * period`, or NaN while the older ADX is not available.
///
/// Entries of `out` below `first + 2 * period` are never written. `close`
/// only determines the series length: the true range would scale +DI and -DI
/// by the same factor and cancels out of DX, so it is not computed.
///
/// A `period` of 0 writes nothing. With `period == 1` the seed DX already is
/// a complete ADX window and is used as the first ADX.
///
/// # Panics
/// Panics on out-of-bounds indexing if `high`, `low` or `out` is shorter than
/// `close` and the missing range is actually reached.
#[inline]
pub fn adxr_scalar(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = close.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    // DX from a pair of smoothed directional-movement values.
    let dx_of = |plus: f64, minus: f64| -> f64 {
        let denom = plus + minus;
        if denom > 0.0 {
            100.0 * (plus - minus).abs() / denom
        } else {
            0.0
        }
    };

    // Seed the smoothed +DM / -DM with plain sums over the first window.
    let mut pdm_s = 0.0_f64;
    let mut mdm_s = 0.0_f64;
    let seed_end = (first + period).min(len - 1);
    for i in (first + 1)..=seed_end {
        let up = high[i] - high[i - 1];
        let down = low[i - 1] - low[i];
        if up > down && up > 0.0 {
            pdm_s += up;
        }
        if down > up && down > 0.0 {
            mdm_s += down;
        }
    }

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; the slot at `head` holds the ADX
    // from `period` bars ago (NaN until the ring has been filled once).
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // With a single-bar window the seed DX is the first ADX. Without this the
    // averaging phase below never completes and no output would be written.
    if period == 1 {
        adx_last = dx_sum * rp;
        adx_ring[0] = adx_last;
        have_adx = true;
    }

    for i in (first + period + 1)..len {
        let up = high[i] - high[i - 1];
        let down = low[i - 1] - low[i];
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        pdm_s = pdm_s.mul_add(om, plus_dm);
        mdm_s = mdm_s.mul_add(om, minus_dm);
        let dx = dx_of(pdm_s, mdm_s);

        // ADX for this bar, if one is available yet.
        let adx_now = if have_adx {
            Some((adx_last * pm1 + dx) * rp)
        } else {
            dx_sum += dx;
            dx_count += 1;
            if dx_count == period {
                Some(dx_sum * rp)
            } else {
                None
            }
        };

        match adx_now {
            Some(adx) => {
                adx_last = adx;
                have_adx = true;

                let prev_adx = core::mem::replace(&mut adx_ring[head], adx);
                head += 1;
                if head == period {
                    head = 0;
                }

                if i >= warmup_start {
                    out[i] = if prev_adx.is_finite() {
                        0.5 * (adx + prev_adx)
                    } else {
                        f64::NAN
                    };
                }
            }
            None => {
                if i >= warmup_start {
                    out[i] = f64::NAN;
                }
            }
        }
    }
}
490
/// AVX-512 entry point: routes to the short variant for `period <= 32`
/// and to the long variant otherwise.
///
/// Both variants are safe functions, so no `unsafe` block is needed here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    if period <= 32 {
        adxr_avx512_short(high, low, close, period, first, out)
    } else {
        adxr_avx512_long(high, low, close, period, first, out)
    }
}
509
510#[inline]
511pub fn adxr_avx2(
512    high: &[f64],
513    low: &[f64],
514    close: &[f64],
515    period: usize,
516    first: usize,
517    out: &mut [f64],
518) {
519    unsafe { adxr_scalar_unchecked(high, low, close, period, first, out) }
520}
521
/// AVX-512 variant used for `period <= 32`.
///
/// Currently a plain delegation to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
534
/// AVX-512 variant used for `period > 32`.
///
/// Currently a plain delegation to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn adxr_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
547
/// Scalar ADXR kernel without per-access bounds checks.
///
/// Seeds smoothed +DM / -DM over bars `first + 1 ..= first + period`,
/// smooths them recursively, averages the first `period` DX values into the
/// first ADX, smooths ADX recursively, and writes
/// `0.5 * (adx[i] + adx[i - period])` to `out[i]` for every
/// `i >= first + 2 * period` (NaN while the older ADX is unavailable).
/// Entries below `first + 2 * period` are never written.
///
/// `close` only determines the series length: the true range cancels out of
/// DX and is therefore not computed. A `period` of 0 writes nothing; with
/// `period == 1` the seed DX is used as the first ADX (previously that case
/// never produced any output).
///
/// # Safety
/// `high`, `low` and `out` must each contain at least `close.len()` elements;
/// every index in `first + 1 .. close.len()` (and its predecessor) is read
/// from `high`/`low`, and indices from `first + 2 * period` are written to
/// `out`, all without bounds checks.
#[inline]
unsafe fn adxr_scalar_unchecked(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = close.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    // DX from a pair of smoothed directional-movement values.
    let dx_of = |plus: f64, minus: f64| -> f64 {
        let denom = plus + minus;
        if denom > 0.0 {
            100.0 * (plus - minus).abs() / denom
        } else {
            0.0
        }
    };

    // Seed the smoothed +DM / -DM with plain sums over the first window.
    let mut pdm_s = 0.0_f64;
    let mut mdm_s = 0.0_f64;
    let seed_end = core::cmp::min(first + period, len - 1);
    let mut i = first + 1;
    while i <= seed_end {
        let up = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let down = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        if up > down && up > 0.0 {
            pdm_s += up;
        }
        if down > up && down > 0.0 {
            mdm_s += down;
        }
        i += 1;
    }

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; `head` always stays below
    // `period`, so unchecked access to the ring is in bounds.
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // With a single-bar window the seed DX is the first ADX.
    if period == 1 {
        adx_last = dx_sum * rp;
        *adx_ring.get_unchecked_mut(0) = adx_last;
        have_adx = true;
    }

    i = first + period + 1;
    while i < len {
        let up = *high.get_unchecked(i) - *high.get_unchecked(i - 1);
        let down = *low.get_unchecked(i - 1) - *low.get_unchecked(i);
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        pdm_s = pdm_s.mul_add(om, plus_dm);
        mdm_s = mdm_s.mul_add(om, minus_dm);
        let dx = dx_of(pdm_s, mdm_s);

        // ADX for this bar, if one is available yet.
        let adx_now = if have_adx {
            Some((adx_last * pm1 + dx) * rp)
        } else {
            dx_sum += dx;
            dx_count += 1;
            if dx_count == period {
                Some(dx_sum * rp)
            } else {
                None
            }
        };

        match adx_now {
            Some(adx) => {
                adx_last = adx;
                have_adx = true;

                let prev_adx = core::mem::replace(adx_ring.get_unchecked_mut(head), adx);
                head += 1;
                if head == period {
                    head = 0;
                }

                if i >= warmup_start {
                    *out.get_unchecked_mut(i) = if prev_adx.is_finite() {
                        0.5 * (adx + prev_adx)
                    } else {
                        f64::NAN
                    };
                }
            }
            None => {
                if i >= warmup_start {
                    *out.get_unchecked_mut(i) = f64::NAN;
                }
            }
        }

        i += 1;
    }
}
698
699#[inline(always)]
700pub fn adxr_batch_with_kernel(
701    h: &[f64],
702    l: &[f64],
703    c: &[f64],
704    sweep: &AdxrBatchRange,
705    k: Kernel,
706) -> Result<AdxrBatchOutput, AdxrError> {
707    let mut kernel = match k {
708        Kernel::Auto => detect_best_batch_kernel(),
709        other if other.is_batch() => other,
710        _ => return Err(AdxrError::InvalidKernelForBatch(k)),
711    };
712    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
713    if matches!(k, Kernel::Auto) && matches!(kernel, Kernel::Avx512Batch) {
714        kernel = Kernel::Avx2Batch;
715    }
716    let simd = match kernel {
717        Kernel::Avx512Batch => Kernel::Avx512,
718        Kernel::Avx2Batch => Kernel::Avx2,
719        Kernel::ScalarBatch => Kernel::Scalar,
720        _ => unreachable!(),
721    };
722
723    let combos = expand_grid(sweep)?;
724    let rows = combos.len();
725    let cols = c.len();
726    let mut buf_mu = make_uninit_matrix(rows, cols);
727    let warm: Vec<usize> = combos
728        .iter()
729        .map(|p| {
730            let first = c.iter().position(|x| !x.is_nan()).unwrap_or(0);
731
732            first + 2 * p.period.unwrap()
733        })
734        .collect();
735    init_matrix_prefixes(&mut buf_mu, cols, &warm);
736
737    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
738    let out: &mut [f64] =
739        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
740    let combos = adxr_batch_inner_into(h, l, c, sweep, simd, true, out)?;
741    let values = unsafe {
742        Vec::from_raw_parts(
743            guard.as_mut_ptr() as *mut f64,
744            guard.len(),
745            guard.capacity(),
746        )
747    };
748    Ok(AdxrBatchOutput {
749        values,
750        combos,
751        rows,
752        cols,
753    })
754}
755
/// Period sweep for batch ADXR runs.
#[derive(Clone, Debug)]
pub struct AdxrBatchRange {
    /// `(start, end, step)`; both ends are inclusive. A step of 0, or
    /// `start == end`, yields the single period `start`.
    pub period: (usize, usize, usize),
}
760
761impl Default for AdxrBatchRange {
762    fn default() -> Self {
763        Self {
764            period: (14, 263, 1),
765        }
766    }
767}
768
769#[derive(Clone, Debug, Default)]
770pub struct AdxrBatchBuilder {
771    range: AdxrBatchRange,
772    kernel: Kernel,
773}
774impl AdxrBatchBuilder {
775    pub fn new() -> Self {
776        Self::default()
777    }
778    pub fn kernel(mut self, k: Kernel) -> Self {
779        self.kernel = k;
780        self
781    }
782    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
783        self.range.period = (start, end, step);
784        self
785    }
786    pub fn period_static(mut self, p: usize) -> Self {
787        self.range.period = (p, p, 0);
788        self
789    }
790    pub fn apply_slices(
791        self,
792        h: &[f64],
793        l: &[f64],
794        c: &[f64],
795    ) -> Result<AdxrBatchOutput, AdxrError> {
796        adxr_batch_with_kernel(h, l, c, &self.range, self.kernel)
797    }
798    pub fn apply_candles(self, candles: &Candles) -> Result<AdxrBatchOutput, AdxrError> {
799        let h = &candles.high;
800        let l = &candles.low;
801        let c = &candles.close;
802        self.apply_slices(h, l, c)
803    }
804}
805
806#[derive(Clone, Debug)]
807pub struct AdxrBatchOutput {
808    pub values: Vec<f64>,
809    pub combos: Vec<AdxrParams>,
810    pub rows: usize,
811    pub cols: usize,
812}
813impl AdxrBatchOutput {
814    pub fn row_for_params(&self, p: &AdxrParams) -> Option<usize> {
815        self.combos
816            .iter()
817            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
818    }
819    pub fn values_for(&self, p: &AdxrParams) -> Option<&[f64]> {
820        self.row_for_params(p).map(|row| {
821            let start = row * self.cols;
822            &self.values[start..start + self.cols]
823        })
824    }
825}
826
827#[inline(always)]
828fn expand_grid(r: &AdxrBatchRange) -> Result<Vec<AdxrParams>, AdxrError> {
829    fn axis((start, end, step): (usize, usize, usize)) -> Option<Vec<usize>> {
830        if step == 0 || start == end {
831            return Some(vec![start]);
832        }
833        if start < end {
834            return Some((start..=end).step_by(step).collect());
835        }
836
837        if step == 0 {
838            return Some(vec![start]);
839        }
840        let mut v = Vec::new();
841        let mut cur = start;
842        while cur >= end {
843            v.push(cur);
844            if let Some(next) = cur.checked_sub(step) {
845                cur = next;
846            } else {
847                break;
848            }
849            if cur == usize::MAX {
850                break;
851            }
852            if cur < end {
853                break;
854            }
855        }
856        Some(v)
857    }
858    let periods = axis(r.period).unwrap_or_default();
859    if periods.is_empty() {
860        return Err(AdxrError::InvalidRange {
861            start: r.period.0,
862            end: r.period.1,
863            step: r.period.2,
864        });
865    }
866    let mut out = Vec::with_capacity(periods.len());
867    for &p in &periods {
868        out.push(AdxrParams { period: Some(p) });
869    }
870    Ok(out)
871}
872
/// Precomputes the per-bar inputs shared by every row of a batch run.
///
/// Returns `(tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm)`:
/// * `tr_all[i]`, `pdm_all[i]`, `mdm_all[i]` hold the true range, +DM and
///   -DM of bar `i` for `i > first`, and 0.0 for `i <= first`;
/// * each `prefix_*` vector has `len - (first + 1) + 1` entries (at least
///   one), with `prefix_*[k]` being the sum of the first `k` bars after
///   `first`, so `prefix_*[0] == 0.0`.
///
/// # Panics
/// Panics on out-of-bounds indexing if `high` or `low` is shorter than
/// `close` and the missing range is reached.
#[inline]
fn shared_precompute_tr_dm(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = close.len();
    let begin = first + 1;
    let span = n.saturating_sub(begin);

    let mut tr_all = vec![0.0_f64; n];
    let mut pdm_all = vec![0.0_f64; n];
    let mut mdm_all = vec![0.0_f64; n];

    // Running sums start with the empty-prefix entry.
    let mut prefix_tr = Vec::with_capacity(span + 1);
    let mut prefix_pdm = Vec::with_capacity(span + 1);
    let mut prefix_mdm = Vec::with_capacity(span + 1);
    prefix_tr.push(0.0_f64);
    prefix_pdm.push(0.0_f64);
    prefix_mdm.push(0.0_f64);

    let mut run_tr = 0.0_f64;
    let mut run_pdm = 0.0_f64;
    let mut run_mdm = 0.0_f64;

    for i in begin..n {
        let prev_close = close[i - 1];
        let (cur_high, cur_low) = (high[i], low[i]);
        let (prev_high, prev_low) = (high[i - 1], low[i - 1]);

        // True range: the largest of the three candidate spans.
        let range = cur_high - cur_low;
        let gap_up = (cur_high - prev_close).abs();
        let gap_down = (cur_low - prev_close).abs();
        let tr = range.max(gap_up).max(gap_down);

        // Directional movement: only the dominant, positive move counts.
        let up = cur_high - prev_high;
        let down = prev_low - cur_low;
        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };

        tr_all[i] = tr;
        pdm_all[i] = plus_dm;
        mdm_all[i] = minus_dm;

        run_tr += tr;
        run_pdm += plus_dm;
        run_mdm += minus_dm;
        prefix_tr.push(run_tr);
        prefix_pdm.push(run_pdm);
        prefix_mdm.push(run_mdm);
    }

    (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm)
}
917
/// Computes one batch row (one `period`) from precomputed +DM / -DM data.
///
/// `pdm_all` / `mdm_all` hold per-bar +DM / -DM; `prefix_pdm[period]` and
/// `prefix_mdm[period]` provide the seed sums over the first `period` bars
/// after `first` (0.0 is used when the prefix vector is too short). The row
/// then follows the single-series kernel: smooth +DM / -DM recursively,
/// average the first `period` DX values into the first ADX, smooth ADX
/// recursively, and write `0.5 * (adx[i] + adx[i - period])` to `out[i]` for
/// every `i >= first + 2 * period` (NaN while the older ADX is unavailable).
/// Entries below `first + 2 * period` are never written.
///
/// `tr_all` only supplies the series length and `prefix_tr` is unused: the
/// true range cancels out of DX, so no ATR is computed.
///
/// A `period` of 0 writes nothing; with `period == 1` the seed DX is used as
/// the first ADX (previously that case never produced any output).
///
/// # Panics
/// Panics on out-of-bounds indexing if `pdm_all`, `mdm_all` or `out` is
/// shorter than `tr_all` and the missing range is reached.
#[inline]
fn adxr_row_from_precomputed(
    tr_all: &[f64],
    pdm_all: &[f64],
    mdm_all: &[f64],
    prefix_tr: &[f64],
    prefix_pdm: &[f64],
    prefix_mdm: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    // Kept in the signature for callers; intentionally unused.
    let _ = prefix_tr;

    let len = tr_all.len();
    if len == 0 || period == 0 {
        return;
    }

    let p = period as f64;
    let rp = 1.0 / p;
    let om = 1.0 - rp;
    let pm1 = p - 1.0;
    let warmup_start = first + 2 * period;

    // DX from a pair of smoothed directional-movement values.
    let dx_of = |plus: f64, minus: f64| -> f64 {
        let denom = plus + minus;
        if denom > 0.0 {
            100.0 * (plus - minus).abs() / denom
        } else {
            0.0
        }
    };

    // Seed sums over the first `period` bars after `first`.
    let mut pdm_s = prefix_pdm.get(period).copied().unwrap_or(0.0);
    let mut mdm_s = prefix_mdm.get(period).copied().unwrap_or(0.0);

    let mut dx_sum = dx_of(pdm_s, mdm_s);
    let mut dx_count: usize = 1;
    let mut adx_last = f64::NAN;
    let mut have_adx = false;

    // Ring of the last `period` ADX values; the slot at `head` holds the ADX
    // from `period` bars ago (NaN until the ring has been filled once).
    let mut adx_ring = vec![f64::NAN; period];
    let mut head = 0usize;

    // With a single-bar window the seed DX is the first ADX.
    if period == 1 {
        adx_last = dx_sum * rp;
        adx_ring[0] = adx_last;
        have_adx = true;
    }

    for i in (first + period + 1)..len {
        pdm_s = pdm_s.mul_add(om, pdm_all[i]);
        mdm_s = mdm_s.mul_add(om, mdm_all[i]);
        let dx = dx_of(pdm_s, mdm_s);

        // ADX for this bar, if one is available yet.
        let adx_now = if have_adx {
            Some((adx_last * pm1 + dx) * rp)
        } else {
            dx_sum += dx;
            dx_count += 1;
            if dx_count == period {
                Some(dx_sum * rp)
            } else {
                None
            }
        };

        match adx_now {
            Some(adx) => {
                adx_last = adx;
                have_adx = true;

                let prev_adx = core::mem::replace(&mut adx_ring[head], adx);
                head += 1;
                if head == period {
                    head = 0;
                }

                if i >= warmup_start {
                    out[i] = if prev_adx.is_finite() {
                        0.5 * (adx + prev_adx)
                    } else {
                        f64::NAN
                    };
                }
            }
            None => {
                if i >= warmup_start {
                    out[i] = f64::NAN;
                }
            }
        }
    }
}
1030
1031#[inline(always)]
1032pub fn adxr_batch_slice(
1033    high: &[f64],
1034    low: &[f64],
1035    close: &[f64],
1036    sweep: &AdxrBatchRange,
1037    kern: Kernel,
1038) -> Result<AdxrBatchOutput, AdxrError> {
1039    adxr_batch_inner(high, low, close, sweep, kern, false)
1040}
1041
1042#[inline(always)]
1043pub fn adxr_batch_par_slice(
1044    high: &[f64],
1045    low: &[f64],
1046    close: &[f64],
1047    sweep: &AdxrBatchRange,
1048    kern: Kernel,
1049) -> Result<AdxrBatchOutput, AdxrError> {
1050    adxr_batch_inner(high, low, close, sweep, kern, true)
1051}
1052
1053#[inline(always)]
1054fn adxr_batch_inner(
1055    high: &[f64],
1056    low: &[f64],
1057    close: &[f64],
1058    sweep: &AdxrBatchRange,
1059    kern: Kernel,
1060    parallel: bool,
1061) -> Result<AdxrBatchOutput, AdxrError> {
1062    let combos = expand_grid(sweep)?;
1063    let len = close.len();
1064    let first = close
1065        .iter()
1066        .position(|x| !x.is_nan())
1067        .ok_or(AdxrError::AllValuesNaN)?;
1068    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1069
1070    if len - first < max_p + 1 {
1071        return Err(AdxrError::NotEnoughValidData {
1072            needed: max_p + 1,
1073            valid: len - first,
1074        });
1075    }
1076    let rows = combos.len();
1077    let cols = len;
1078
1079    let mut buf_mu = make_uninit_matrix(rows, cols);
1080
1081    let warm: Vec<usize> = combos
1082        .iter()
1083        .map(|c| first + 2 * c.period.unwrap())
1084        .collect();
1085
1086    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1087
1088    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
1089    let values: &mut [f64] = unsafe {
1090        std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1091    };
1092
1093    let (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm) =
1094        shared_precompute_tr_dm(high, low, close, first);
1095
1096    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1097        let period = combos[row].period.unwrap();
1098
1099        match kern {
1100            Kernel::Scalar => adxr_row_from_precomputed(
1101                &tr_all,
1102                &pdm_all,
1103                &mdm_all,
1104                &prefix_tr,
1105                &prefix_pdm,
1106                &prefix_mdm,
1107                first,
1108                period,
1109                out_row,
1110            ),
1111            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1112            Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1113                &tr_all,
1114                &pdm_all,
1115                &mdm_all,
1116                &prefix_tr,
1117                &prefix_pdm,
1118                &prefix_mdm,
1119                first,
1120                period,
1121                out_row,
1122            ),
1123            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1124            Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1125                &tr_all,
1126                &pdm_all,
1127                &mdm_all,
1128                &prefix_tr,
1129                &prefix_pdm,
1130                &prefix_mdm,
1131                first,
1132                period,
1133                out_row,
1134            ),
1135            _ => unreachable!(),
1136        }
1137    };
1138
1139    if parallel {
1140        #[cfg(not(target_arch = "wasm32"))]
1141        {
1142            values
1143                .par_chunks_mut(cols)
1144                .enumerate()
1145                .for_each(|(row, slice)| do_row(row, slice));
1146        }
1147
1148        #[cfg(target_arch = "wasm32")]
1149        {
1150            for (row, slice) in values.chunks_mut(cols).enumerate() {
1151                do_row(row, slice);
1152            }
1153        }
1154    } else {
1155        for (row, slice) in values.chunks_mut(cols).enumerate() {
1156            do_row(row, slice);
1157        }
1158    }
1159
1160    let values = unsafe {
1161        Vec::from_raw_parts(
1162            buf_guard.as_mut_ptr() as *mut f64,
1163            buf_guard.len(),
1164            buf_guard.capacity(),
1165        )
1166    };
1167
1168    Ok(AdxrBatchOutput {
1169        values,
1170        combos,
1171        rows,
1172        cols,
1173    })
1174}
1175
1176#[inline(always)]
1177pub fn adxr_batch_inner_into(
1178    high: &[f64],
1179    low: &[f64],
1180    close: &[f64],
1181    sweep: &AdxrBatchRange,
1182    kern: Kernel,
1183    parallel: bool,
1184    out: &mut [f64],
1185) -> Result<Vec<AdxrParams>, AdxrError> {
1186    let combos = expand_grid(sweep)?;
1187
1188    let len = close.len();
1189    if high.len() != len || low.len() != len {
1190        return Err(AdxrError::HlcLengthMismatch {
1191            high_len: high.len(),
1192            low_len: low.len(),
1193            close_len: len,
1194        });
1195    }
1196
1197    let first = close
1198        .iter()
1199        .position(|x| !x.is_nan())
1200        .ok_or(AdxrError::AllValuesNaN)?;
1201    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1202
1203    if len - first < max_p + 1 {
1204        return Err(AdxrError::NotEnoughValidData {
1205            needed: max_p + 1,
1206            valid: len - first,
1207        });
1208    }
1209
1210    let rows = combos.len();
1211    let cols = len;
1212    if let Some(expected) = rows.checked_mul(cols) {
1213        if out.len() != expected {
1214            return Err(AdxrError::OutputLengthMismatch {
1215                expected,
1216                got: out.len(),
1217            });
1218        }
1219    } else {
1220        return Err(AdxrError::InvalidRange {
1221            start: rows,
1222            end: cols,
1223            step: 0,
1224        });
1225    }
1226
1227    let warm: Vec<usize> = combos
1228        .iter()
1229        .map(|c| first + 2 * c.period.unwrap())
1230        .collect();
1231
1232    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
1233        std::slice::from_raw_parts_mut(
1234            out.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
1235            out.len(),
1236        )
1237    };
1238    init_matrix_prefixes(out_mu, cols, &warm);
1239
1240    let (tr_all, pdm_all, mdm_all, prefix_tr, prefix_pdm, prefix_mdm) =
1241        shared_precompute_tr_dm(high, low, close, first);
1242
1243    let do_row = |row: usize, dst_mu: &mut [std::mem::MaybeUninit<f64>]| unsafe {
1244        let period = combos[row].period.unwrap();
1245        let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1246
1247        match kern {
1248            Kernel::Scalar => adxr_row_from_precomputed(
1249                &tr_all,
1250                &pdm_all,
1251                &mdm_all,
1252                &prefix_tr,
1253                &prefix_pdm,
1254                &prefix_mdm,
1255                first,
1256                period,
1257                dst,
1258            ),
1259            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1260            Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1261                &tr_all,
1262                &pdm_all,
1263                &mdm_all,
1264                &prefix_tr,
1265                &prefix_pdm,
1266                &prefix_mdm,
1267                first,
1268                period,
1269                dst,
1270            ),
1271            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1272            Kernel::Avx2 | Kernel::Avx512 => adxr_row_from_precomputed(
1273                &tr_all,
1274                &pdm_all,
1275                &mdm_all,
1276                &prefix_tr,
1277                &prefix_pdm,
1278                &prefix_mdm,
1279                first,
1280                period,
1281                dst,
1282            ),
1283            _ => unreachable!("pass non-batch kernel"),
1284        }
1285    };
1286
1287    if parallel {
1288        #[cfg(not(target_arch = "wasm32"))]
1289        {
1290            use rayon::prelude::*;
1291            out_mu
1292                .par_chunks_mut(cols)
1293                .enumerate()
1294                .for_each(|(row, slice)| do_row(row, slice));
1295        }
1296        #[cfg(target_arch = "wasm32")]
1297        {
1298            for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1299                do_row(row, slice);
1300            }
1301        }
1302    } else {
1303        for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1304            do_row(row, slice);
1305        }
1306    }
1307
1308    Ok(combos)
1309}
1310
1311#[inline(always)]
1312unsafe fn adxr_row_scalar(
1313    high: &[f64],
1314    low: &[f64],
1315    close: &[f64],
1316    first: usize,
1317    period: usize,
1318    out: &mut [f64],
1319) {
1320    adxr_scalar(high, low, close, period, first, out)
1321}
1322
/// Row helper for the AVX2 kernel slot.
///
/// Currently forwards to `adxr_scalar` (not to an AVX2-specific routine), passing
/// `period` before `first` as that function expects.
///
/// # Safety
/// NOTE(review): `adxr_scalar` is defined outside this view; callers must uphold whatever
/// preconditions it places on the slices — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx2(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
1335
/// Row helper for the AVX-512 kernel.
///
/// Forwards to `adxr_avx512`, passing `period` before `first` as that function expects.
///
/// # Safety
/// NOTE(review): `adxr_avx512` is defined outside this view; callers must uphold whatever
/// preconditions it has (presumably CPU feature availability) — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    adxr_avx512(high, low, close, period, first, out);
}
1348
/// Row helper for the "short" AVX-512 kernel slot.
///
/// Currently forwards to `adxr_scalar`, passing `period` before `first` as that function
/// expects.
///
/// # Safety
/// NOTE(review): `adxr_scalar` is defined outside this view; callers must uphold whatever
/// preconditions it places on the slices — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512_short(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
1361
/// Row helper for the "long" AVX-512 kernel slot.
///
/// Currently forwards to `adxr_scalar`, passing `period` before `first` as that function
/// expects.
///
/// # Safety
/// NOTE(review): `adxr_scalar` is defined outside this view; callers must uphold whatever
/// preconditions it places on the slices — confirm there.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn adxr_row_avx512_long(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    adxr_scalar(high, low, close, period, first, out);
}
1374
1375#[derive(Debug, Clone)]
1376pub struct AdxrStream {
1377    period: usize,
1378
1379    rp: f64,
1380    om: f64,
1381    pm1: f64,
1382
1383    atr: f64,
1384    pdm_s: f64,
1385    mdm_s: f64,
1386
1387    dx_sum: f64,
1388    dx_count: usize,
1389    adx_last: f64,
1390    have_adx: bool,
1391
1392    adx_ring: Vec<f64>,
1393    head: usize,
1394
1395    prev_hlc: Option<(f64, f64, f64)>,
1396
1397    seen: usize,
1398}
1399
1400impl AdxrStream {
1401    #[inline(always)]
1402    pub fn try_new(params: AdxrParams) -> Result<Self, AdxrError> {
1403        let period = params.period.unwrap_or(14);
1404        if period == 0 {
1405            return Err(AdxrError::InvalidPeriod {
1406                period,
1407                data_len: 0,
1408            });
1409        }
1410        let p = period as f64;
1411        Ok(Self {
1412            period,
1413            rp: 1.0 / p,
1414            om: 1.0 - 1.0 / p,
1415            pm1: p - 1.0,
1416            atr: 0.0,
1417            pdm_s: 0.0,
1418            mdm_s: 0.0,
1419            dx_sum: 0.0,
1420            dx_count: 0,
1421            adx_last: f64::NAN,
1422            have_adx: false,
1423            adx_ring: vec![f64::NAN; period],
1424            head: 0,
1425            prev_hlc: None,
1426            seen: 0,
1427        })
1428    }
1429
1430    #[inline(always)]
1431    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
1432        if !(high.is_finite() && low.is_finite() && close.is_finite()) {
1433            return None;
1434        }
1435
1436        if self.prev_hlc.is_none() {
1437            self.prev_hlc = Some((high, low, close));
1438            return None;
1439        }
1440
1441        let (ph, pl, pc) = unsafe { self.prev_hlc.unwrap_unchecked() };
1442        self.prev_hlc = Some((high, low, close));
1443        self.seen = self.seen.wrapping_add(1);
1444
1445        let tr = {
1446            let a = high - low;
1447            let b = (high - pc).abs();
1448            let c = (low - pc).abs();
1449            a.max(b).max(c)
1450        };
1451
1452        let up = high - ph;
1453        let down = pl - low;
1454        let plus_dm = if up > down && up > 0.0 { up } else { 0.0 };
1455        let minus_dm = if down > up && down > 0.0 { down } else { 0.0 };
1456
1457        if self.seen <= self.period {
1458            self.atr += tr;
1459            self.pdm_s += plus_dm;
1460            self.mdm_s += minus_dm;
1461
1462            if self.seen == self.period {
1463                let denom = self.pdm_s + self.mdm_s;
1464                let dx0 = if denom > 0.0 {
1465                    100.0 * (self.pdm_s - self.mdm_s).abs() / denom
1466                } else {
1467                    0.0
1468                };
1469                self.dx_sum = dx0;
1470                self.dx_count = 1;
1471            }
1472            return None;
1473        }
1474
1475        self.atr = self.atr.mul_add(self.om, tr);
1476        self.pdm_s = self.pdm_s.mul_add(self.om, plus_dm);
1477        self.mdm_s = self.mdm_s.mul_add(self.om, minus_dm);
1478
1479        let denom = self.pdm_s + self.mdm_s;
1480        let dx = if denom > 0.0 {
1481            100.0 * (self.pdm_s - self.mdm_s).abs() / denom
1482        } else {
1483            0.0
1484        };
1485
1486        if !self.have_adx {
1487            if self.dx_count + 1 < self.period {
1488                self.dx_sum += dx;
1489                self.dx_count += 1;
1490                return None;
1491            } else {
1492                self.dx_sum += dx;
1493                self.adx_last = self.dx_sum * self.rp;
1494                self.have_adx = true;
1495
1496                self.adx_ring[self.head] = self.adx_last;
1497                self.head = (self.head + 1) % self.period;
1498                return None;
1499            }
1500        }
1501
1502        let adx_curr = (self.adx_last.mul_add(self.pm1, dx)) * self.rp;
1503        self.adx_last = adx_curr;
1504
1505        let adx_period_ago = self.adx_ring[self.head];
1506        self.adx_ring[self.head] = adx_curr;
1507        self.head = (self.head + 1) % self.period;
1508
1509        if adx_period_ago.is_finite() {
1510            Some(0.5 * (adx_curr + adx_period_ago))
1511        } else {
1512            None
1513        }
1514    }
1515}
1516
1517#[cfg(test)]
1518mod tests {
1519    use super::*;
1520    use crate::skip_if_unsupported;
1521    use crate::utilities::data_loader::read_candles_from_csv;
1522    use paste::paste;
1523
1524    fn check_adxr_partial_params(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1525        skip_if_unsupported!(kernel, test);
1526        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1527        let candles = read_candles_from_csv(file_path)?;
1528        let input = AdxrInput::from_candles(&candles, AdxrParams { period: None });
1529        let output = adxr_with_kernel(&input, kernel)?;
1530        assert_eq!(output.values.len(), candles.close.len());
1531        Ok(())
1532    }
1533
1534    fn check_adxr_accuracy(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1535        skip_if_unsupported!(kernel, test);
1536        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1537        let candles = read_candles_from_csv(file_path)?;
1538        let input = AdxrInput::from_candles(&candles, AdxrParams::default());
1539        let result = adxr_with_kernel(&input, kernel)?;
1540        let expected = [37.10, 37.3, 37.0, 36.2, 36.3];
1541        let start = result.values.len().saturating_sub(5);
1542        for (i, &val) in result.values[start..].iter().enumerate() {
1543            let diff = (val - expected[i]).abs();
1544            assert!(
1545                diff < 1e-1,
1546                "[{}] ADXR {:?} mismatch at idx {}: got {}, expected {}",
1547                test,
1548                kernel,
1549                i,
1550                val,
1551                expected[i]
1552            );
1553        }
1554        Ok(())
1555    }
1556
1557    fn check_adxr_zero_period(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1558        skip_if_unsupported!(kernel, test);
1559        let high = [10.0, 20.0, 30.0];
1560        let low = [9.0, 19.0, 29.0];
1561        let close = [9.5, 19.5, 29.5];
1562        let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(0) });
1563        let res = adxr_with_kernel(&input, kernel);
1564        assert!(res.is_err(), "[{}] ADXR should fail with zero period", test);
1565        Ok(())
1566    }
1567
1568    fn check_adxr_period_exceeds_length(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1569        skip_if_unsupported!(kernel, test);
1570        let high = [10.0, 20.0];
1571        let low = [9.0, 19.0];
1572        let close = [9.5, 19.5];
1573        let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(10) });
1574        let res = adxr_with_kernel(&input, kernel);
1575        assert!(
1576            res.is_err(),
1577            "[{}] ADXR should fail with period > data.len()",
1578            test
1579        );
1580        Ok(())
1581    }
1582
1583    fn check_adxr_very_small_dataset(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1584        skip_if_unsupported!(kernel, test);
1585        let high = [100.0];
1586        let low = [99.0];
1587        let close = [99.5];
1588        let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams { period: Some(14) });
1589        let res = adxr_with_kernel(&input, kernel);
1590        assert!(
1591            res.is_err(),
1592            "[{}] ADXR should fail with insufficient data",
1593            test
1594        );
1595        Ok(())
1596    }
1597
1598    fn check_adxr_reinput(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1599        skip_if_unsupported!(kernel, test);
1600        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1601        let candles = read_candles_from_csv(file_path)?;
1602        let first_input = AdxrInput::from_candles(&candles, AdxrParams { period: Some(14) });
1603        let first_result = adxr_with_kernel(&first_input, kernel)?;
1604        let high = &candles.high;
1605        let low = &candles.low;
1606        let close = &candles.close;
1607        let second_input = AdxrInput::from_slices(high, low, close, AdxrParams { period: Some(5) });
1608        let second_result = adxr_with_kernel(&second_input, kernel)?;
1609        assert_eq!(second_result.values.len(), candles.close.len());
1610        Ok(())
1611    }
1612
1613    fn check_adxr_nan_handling(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1614        skip_if_unsupported!(kernel, test);
1615        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1616        let candles = read_candles_from_csv(file_path)?;
1617        let input = AdxrInput::from_candles(&candles, AdxrParams { period: Some(14) });
1618        let res = adxr_with_kernel(&input, kernel)?;
1619        assert_eq!(res.values.len(), candles.close.len());
1620        if res.values.len() > 240 {
1621            for (i, &val) in res.values[240..].iter().enumerate() {
1622                assert!(
1623                    !val.is_nan(),
1624                    "[{}] Found unexpected NaN at out-index {}",
1625                    test,
1626                    240 + i
1627                );
1628            }
1629        }
1630        Ok(())
1631    }
1632
1633    macro_rules! generate_all_adxr_tests {
1634        ($($test_fn:ident),*) => {
1635            paste! {
1636                $(
1637                    #[test]
1638                    fn [<$test_fn _scalar_f64>]() {
1639                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1640                    }
1641                )*
1642                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1643                $(
1644                    #[test]
1645                    fn [<$test_fn _avx2_f64>]() {
1646                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1647                    }
1648                    #[test]
1649                    fn [<$test_fn _avx512_f64>]() {
1650                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1651                    }
1652                )*
1653            }
1654        }
1655    }
1656
1657    #[cfg(debug_assertions)]
1658    fn check_adxr_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1659        skip_if_unsupported!(kernel, test_name);
1660
1661        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1662        let candles = read_candles_from_csv(file_path)?;
1663
1664        let test_params = vec![
1665            AdxrParams::default(),
1666            AdxrParams { period: Some(5) },
1667            AdxrParams { period: Some(10) },
1668            AdxrParams { period: Some(14) },
1669            AdxrParams { period: Some(20) },
1670            AdxrParams { period: Some(25) },
1671            AdxrParams { period: Some(30) },
1672            AdxrParams { period: Some(50) },
1673            AdxrParams { period: Some(100) },
1674            AdxrParams { period: Some(2) },
1675        ];
1676
1677        for (param_idx, params) in test_params.iter().enumerate() {
1678            let input = AdxrInput::from_candles(&candles, params.clone());
1679            let output = adxr_with_kernel(&input, kernel)?;
1680
1681            for (i, &val) in output.values.iter().enumerate() {
1682                if val.is_nan() {
1683                    continue;
1684                }
1685
1686                let bits = val.to_bits();
1687
1688                if bits == 0x11111111_11111111 {
1689                    panic!(
1690                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1691                        with params: period={}",
1692                        test_name,
1693                        val,
1694                        bits,
1695                        i,
1696                        params.period.unwrap_or(14)
1697                    );
1698                }
1699
1700                if bits == 0x22222222_22222222 {
1701                    panic!(
1702                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1703                        with params: period={}",
1704                        test_name,
1705                        val,
1706                        bits,
1707                        i,
1708                        params.period.unwrap_or(14)
1709                    );
1710                }
1711
1712                if bits == 0x33333333_33333333 {
1713                    panic!(
1714                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1715                        with params: period={}",
1716                        test_name,
1717                        val,
1718                        bits,
1719                        i,
1720                        params.period.unwrap_or(14)
1721                    );
1722                }
1723            }
1724        }
1725
1726        Ok(())
1727    }
1728
1729    #[cfg(not(debug_assertions))]
1730    fn check_adxr_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1731        Ok(())
1732    }
1733
    /// Property test: generates synthetic high/low/close series for periods 2..=50 and
    /// checks warm-up NaNs, the [0, 100] output range, agreement with the scalar kernel,
    /// and low/stable readings on constant-price data.
    fn check_adxr_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick a period, then a base price, volatility fraction, trend slope,
        // series length and one of three market shapes (0 = sine cycle, 1 = linear trend,
        // 2 = constant price with zero range).
        let strat = (2usize..=50)
            .prop_flat_map(|period| {
                let min_size = (period * 3).max(period + 10);
                let max_size = 400;
                (
                    10.0f64..1000.0f64,
                    0.0f64..0.1f64,
                    -0.01f64..0.01f64,
                    min_size..max_size,
                    Just(period),
                    0u8..3,
                )
            })
            .prop_map(
                |(base_price, volatility_pct, trend, size, period, market_type)| {
                    let mut high_data = Vec::with_capacity(size);
                    let mut low_data = Vec::with_capacity(size);
                    let mut close_data = Vec::with_capacity(size);

                    for i in 0..size {
                        let price = match market_type {
                            0 => {
                                let cycle = (i as f64 * 0.1).sin();
                                base_price * (1.0 + cycle * volatility_pct)
                            }
                            1 => base_price * (1.0 + trend * i as f64),
                            2 => base_price,
                            _ => base_price,
                        };

                        // Constant market: high == low == close; otherwise the bar is
                        // symmetric around the close.
                        let (high, low, close) = if market_type == 2 {
                            (price, price, price)
                        } else {
                            let daily_volatility =
                                price * volatility_pct * (0.5 + 0.5 * (i as f64 * 0.05).cos());
                            let close = price;
                            let high = close + daily_volatility.abs();
                            let low = close - daily_volatility.abs();
                            (high, low, close)
                        };

                        high_data.push(high);
                        low_data.push(low);
                        close_data.push(close);
                    }

                    (high_data, low_data, close_data, period, market_type)
                },
            );

        proptest::test_runner::TestRunner::default().run(
            &strat,
            |(high_data, low_data, close_data, period, market_type)| {
                let params = AdxrParams {
                    period: Some(period),
                };
                let input = AdxrInput::from_slices(&high_data, &low_data, &close_data, params);

                let result = adxr_with_kernel(&input, kernel);
                prop_assert!(result.is_ok(), "ADXR computation failed: {:?}", result);
                let AdxrOutput { values: out } = result.unwrap();

                // Scalar kernel output serves as the reference.
                let ref_result = adxr_with_kernel(&input, Kernel::Scalar);
                prop_assert!(ref_result.is_ok(), "Reference ADXR computation failed");
                let AdxrOutput { values: ref_out } = ref_result.unwrap();

                let first = close_data.iter().position(|x| !x.is_nan()).unwrap_or(0);

                // Every index below `first + 2 * period` is expected to be NaN.
                let warmup_period = first + 2 * period;

                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if i < warmup_period {
                        prop_assert!(
                            y.is_nan(),
                            "Expected NaN during warmup at index {}, got {}",
                            i,
                            y
                        );
                    } else {
                        if !y.is_nan() {
                            prop_assert!(
                                y >= -1e-9 && y <= 100.0 + 1e-9,
                                "ADXR value {} at index {} is outside [0, 100] range",
                                y,
                                i
                            );
                        }

                        if !y.is_nan() && !r.is_nan() {
                            let diff = (y - r).abs();
                            prop_assert!(
                                diff < 1e-6,
                                "Kernel {:?} and Scalar differ by {} at index {}: {} vs {}",
                                kernel,
                                diff,
                                i,
                                y,
                                r
                            );
                        }
                    }
                }

                // Constant-price market: the average of the last (up to 10) non-NaN values
                // must stay below 25.
                if market_type == 2 && out.len() > warmup_period + period {
                    let last_values = &out[out.len().saturating_sub(10)..];
                    let non_nan_values: Vec<f64> = last_values
                        .iter()
                        .filter(|v| !v.is_nan())
                        .copied()
                        .collect();

                    if !non_nan_values.is_empty() {
                        let avg_last =
                            non_nan_values.iter().sum::<f64>() / non_nan_values.len() as f64;
                        prop_assert!(
                            avg_last < 25.0,
                            "ADXR should be low with zero volatility, got average {}",
                            avg_last
                        );
                    }
                }

                if period == 2 {
                    prop_assert!(out.len() == close_data.len());
                }

                // All three series are flat to within 1e-10 between consecutive bars.
                let is_constant = close_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && high_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10)
                    && low_data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

                if is_constant && out.len() > warmup_period {
                    let stable_values = &out[warmup_period..];
                    let non_nan: Vec<f64> = stable_values
                        .iter()
                        .filter(|v| !v.is_nan())
                        .copied()
                        .collect();

                    // Population standard deviation of the post-warm-up values.
                    if non_nan.len() > 10 {
                        let mean = non_nan.iter().sum::<f64>() / non_nan.len() as f64;
                        let variance = non_nan.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
                            / non_nan.len() as f64;
                        let std_dev = variance.sqrt();

                        prop_assert!(
                            std_dev < 5.0,
                            "ADXR should stabilize with constant data, std_dev = {}",
                            std_dev
                        );
                    }
                }

                Ok(())
            },
        )?;

        Ok(())
    }
1899
    // Instantiate the per-kernel `#[test]` functions for each single-series check helper.
    generate_all_adxr_tests!(
        check_adxr_partial_params,
        check_adxr_accuracy,
        check_adxr_zero_period,
        check_adxr_period_exceeds_length,
        check_adxr_very_small_dataset,
        check_adxr_reinput,
        check_adxr_nan_handling,
        check_adxr_no_poison,
        check_adxr_property
    );
1911
1912    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1913        skip_if_unsupported!(kernel, test);
1914        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1915        let c = read_candles_from_csv(file)?;
1916        let output = AdxrBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1917        let def = AdxrParams::default();
1918        let row = output.values_for(&def).expect("default row missing");
1919        assert_eq!(row.len(), c.close.len());
1920        Ok(())
1921    }
1922
1923    macro_rules! gen_batch_tests {
1924        ($fn_name:ident) => {
1925            paste! {
1926                #[test] fn [<$fn_name _scalar>]()      {
1927                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1928                }
1929                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1930                #[test] fn [<$fn_name _avx2>]()        {
1931                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1932                }
1933                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934                #[test] fn [<$fn_name _avx512>]()      {
1935                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1936                }
1937                #[test] fn [<$fn_name _auto_detect>]() {
1938                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1939                }
1940            }
1941        };
1942    }
1943    #[cfg(debug_assertions)]
1944    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1945        skip_if_unsupported!(kernel, test);
1946
1947        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1948        let c = read_candles_from_csv(file)?;
1949
1950        let test_configs = vec![
1951            (2, 10, 2),
1952            (5, 25, 5),
1953            (10, 20, 2),
1954            (14, 50, 6),
1955            (20, 100, 10),
1956            (2, 30, 7),
1957            (8, 40, 8),
1958        ];
1959
1960        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1961            let output = AdxrBatchBuilder::new()
1962                .kernel(kernel)
1963                .period_range(p_start, p_end, p_step)
1964                .apply_candles(&c)?;
1965
1966            for (idx, &val) in output.values.iter().enumerate() {
1967                if val.is_nan() {
1968                    continue;
1969                }
1970
1971                let bits = val.to_bits();
1972                let row = idx / output.cols;
1973                let col = idx % output.cols;
1974                let combo = &output.combos[row];
1975
1976                if bits == 0x11111111_11111111 {
1977                    panic!(
1978                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1979                        at row {} col {} (flat index {}) with params: period={}",
1980                        test,
1981                        cfg_idx,
1982                        val,
1983                        bits,
1984                        row,
1985                        col,
1986                        idx,
1987                        combo.period.unwrap_or(14)
1988                    );
1989                }
1990
1991                if bits == 0x22222222_22222222 {
1992                    panic!(
1993                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1994                        at row {} col {} (flat index {}) with params: period={}",
1995                        test,
1996                        cfg_idx,
1997                        val,
1998                        bits,
1999                        row,
2000                        col,
2001                        idx,
2002                        combo.period.unwrap_or(14)
2003                    );
2004                }
2005
2006                if bits == 0x33333333_33333333 {
2007                    panic!(
2008                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2009                        at row {} col {} (flat index {}) with params: period={}",
2010                        test,
2011                        cfg_idx,
2012                        val,
2013                        bits,
2014                        row,
2015                        col,
2016                        idx,
2017                        combo.period.unwrap_or(14)
2018                    );
2019                }
2020            }
2021        }
2022
2023        Ok(())
2024    }
2025
    // Non-debug counterpart of the poison scan: does no work and always returns
    // `Ok(())`, so `gen_batch_tests!(check_batch_no_poison)` still compiles in
    // release test builds. Presumably the poison patterns are only written in
    // debug builds — TODO confirm against the allocation helpers.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2030
    // Emit the per-kernel `#[test]` wrappers for both batch helpers.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2033
2034    #[test]
2035    fn test_adxr_into_matches_api() -> Result<(), Box<dyn Error>> {
2036        let len = 256usize;
2037        let mut high = vec![0.0f64; len];
2038        let mut low = vec![0.0f64; len];
2039        let mut close = vec![0.0f64; len];
2040        for i in 0..len {
2041            let base = 100.0 + (i as f64) * 0.1 + (i as f64 * 0.07).sin();
2042            low[i] = base - 1.0;
2043            close[i] = base - 0.3;
2044            high[i] = base + 0.8;
2045        }
2046
2047        let input = AdxrInput::from_slices(&high, &low, &close, AdxrParams::default());
2048
2049        let baseline = adxr(&input)?.values;
2050
2051        let mut out = vec![0.0f64; len];
2052        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2053        {
2054            adxr_into(&input, &mut out)?;
2055        }
2056        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2057        {
2058            adxr_into_slice(&mut out, &input, Kernel::Auto)?;
2059        }
2060
2061        assert_eq!(baseline.len(), out.len());
2062        for (a, b) in baseline.iter().zip(out.iter()) {
2063            let equal = (a.is_nan() && b.is_nan()) || (a == b);
2064            assert!(equal, "Mismatch: baseline={} out={}", a, b);
2065        }
2066
2067        Ok(())
2068    }
2069}
2070
// Python entry point `adxr(high, low, close, period=None, kernel=None)`.
//
// Validates that the three input arrays are contiguous and equally long,
// resolves the kernel name via `validate_kernel`, runs the computation with
// the GIL released, and returns the result as a new 1-D NumPy array.
// `period` falls back to 14 when omitted. Any failure surfaces as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "adxr")]
#[pyo3(signature = (high, low, close, period=None, kernel=None))]
pub fn adxr_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let lengths_agree = h.len() == l.len() && h.len() == c.len();
    if !lengths_agree {
        return Err(PyValueError::new_err(format!(
            "HLC data length mismatch: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        )));
    }

    let kern = validate_kernel(kernel, false)?;

    let input = AdxrInput::from_slices(
        h,
        l,
        c,
        AdxrParams {
            period: Some(period.unwrap_or(14)),
        },
    );

    // The heavy lifting happens without the GIL; only the final conversion
    // back to a NumPy array needs `py`.
    let computed = py.allow_threads(|| adxr_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2110
// Python class `AdxrStream`: a thin wrapper that owns an `AdxrStream` and
// forwards bar-by-bar updates to it.
#[cfg(feature = "python")]
#[pyclass(name = "AdxrStream")]
pub struct AdxrStreamPy {
    stream: AdxrStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AdxrStreamPy {
    // Constructor `AdxrStream(period=None)`; `period` defaults to 14.
    // Errors from `AdxrStream::try_new` are raised as `ValueError`.
    #[new]
    #[pyo3(signature = (period=None))]
    fn new(period: Option<usize>) -> PyResult<Self> {
        let params = AdxrParams {
            period: Some(period.unwrap_or(14)),
        };
        match AdxrStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one high/low/close bar to the wrapped stream and returns whatever
    // it yields (`None` maps to Python `None`).
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<f64> {
        self.stream.update(high, low, close)
    }
}
2135
// Python entry point `adxr_batch(high, low, close, period_range, kernel=None)`.
//
// `period_range` is stored as `AdxrBatchRange::period` and expanded by
// `expand_grid`; the number of expanded combos is the row count. Results are
// written directly into a freshly allocated NumPy buffer and returned in a
// dict with:
//   "values"  -> 2-D array of shape (rows, cols), cols = len(close)
//   "periods" -> 1-D u64 array, one period per row
// All failures are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "adxr_batch")]
#[pyo3(signature = (high, low, close, period_range, kernel=None))]
pub fn adxr_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let lengths_agree = h.len() == l.len() && h.len() == c.len();
    if !lengths_agree {
        return Err(PyValueError::new_err(format!(
            "HLC data length mismatch: high={}, low={}, close={}",
            h.len(),
            l.len(),
            c.len()
        )));
    }

    let sweep = AdxrBatchRange {
        period: period_range,
    };

    // Expand once up front purely to learn how many rows to allocate.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = c.len();

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The array is allocated uninitialized and exposed as a mutable slice so
    // the batch routine can write into it in place. NOTE(review): this relies
    // on `adxr_batch_inner_into` writing every element — confirm there.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    // Translate the (batch) kernel selection into the per-row kernel that
    // `adxr_batch_inner_into` is called with.
    let requested = crate::utilities::kernel_validation::validate_kernel(kernel, true)?;
    let simd = match requested {
        Kernel::Auto => match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        },
        Kernel::ScalarBatch => Kernel::Scalar,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::Avx512Batch => Kernel::Avx512,
        other => other,
    };

    let computed =
        py.allow_threads(|| adxr_batch_inner_into(h, l, c, &sweep, simd, true, out_slice));
    let combos = match computed {
        Ok(combos) => combos,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2206
// Python entry point `adxr_cuda_batch_dev(high_f32, low_f32, close_f32,
// period_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise creates a `CudaAdxr` for `device_id`, runs the batch sweep
// with the GIL released, and returns the device-resident result wrapped in
// `AdxrDeviceArrayF32Py`. The wrapper also holds a clone of the context `Arc`
// (presumably to keep the CUDA context alive while the buffer exists).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adxr_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, period_range, device_id=0))]
pub fn adxr_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    high_f32: numpy::PyReadonlyArray1<'py, f32>,
    low_f32: numpy::PyReadonlyArray1<'py, f32>,
    close_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<AdxrDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, c) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );
    let sweep = AdxrBatchRange {
        period: period_range,
    };

    let launch = || -> PyResult<_> {
        let cuda = CudaAdxr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, _combos) = cuda
            .adxr_batch_dev(h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(launch)?;

    Ok(AdxrDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2240
// Python entry point `adxr_cuda_many_series_one_param_dev(high_tm_f32,
// low_tm_f32, close_tm_f32, period, device_id=0)`.
//
// Requires CUDA and three contiguous 2-D f32 arrays of identical shape. With
// shape `(rows, cols)` the kernel wrapper is invoked as `(…, cols, rows, period)`;
// the `_tm` suffix and the callee name suggest a time-major layout (rows = time,
// cols = series) — TODO confirm against `CudaAdxr`. The result stays on the
// device and is returned as `AdxrDeviceArrayF32Py`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "adxr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, period, device_id=0))]
pub fn adxr_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    high_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<AdxrDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = high_tm_f32.shape();
    let shapes_agree =
        shape.len() == 2 && low_tm_f32.shape() == shape && close_tm_f32.shape() == shape;
    if !shapes_agree {
        return Err(PyValueError::new_err("expected three matching 2D arrays"));
    }
    let (rows, cols) = (shape[0], shape[1]);

    let (h, l, c) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    let launch = || -> PyResult<_> {
        let cuda = CudaAdxr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .adxr_many_series_one_param_time_major_dev(h, l, c, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(launch)?;

    Ok(AdxrDeviceArrayF32Py {
        inner: Some(inner),
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
2277
// Python class `AdxrDeviceArrayF32`: handle to a 2-D f32 result living on a
// CUDA device. `inner` becomes `None` once the buffer has been handed out via
// `__dlpack__`; `_ctx` holds the context `Arc` alongside the buffer and
// `device_id` is the ordinal reported by `__dlpack_device__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "AdxrDeviceArrayF32", unsendable)]
pub struct AdxrDeviceArrayF32Py {
    pub(crate) inner: Option<DeviceArrayF32>,
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AdxrDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict (version 3) describing the
    // buffer as a row-contiguous `<f4` array of shape (rows, cols). Raises
    // `ValueError` if the buffer was already moved out by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_size = std::mem::size_of::<f32>();

        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * elem_size, elem_size))?;
        // Second tuple element is the interface's read-only flag.
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`; the device type is hard-coded to 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer through DLPack, transferring ownership out of this
    // object (so it succeeds at most once).
    //
    // If `dl_device` parses as an `(i32, i32)` pair that differs from this
    // buffer's device, a `ValueError` is raised; the message depends on whether
    // `copy` parses as `True`. Unparseable `dl_device` / `copy` values are
    // ignored. `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested_device = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());

        if let Some((dev_ty, dev_id)) = requested_device {
            let same_device = dev_ty == kdl && dev_id == alloc_dev;
            if !same_device {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => return Err(PyValueError::new_err("__dlpack__ may only be called once")),
        };

        let (rows, cols) = (inner.rows, inner.cols);
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2360
/// WASM export: computes ADXR for the given high/low/close slices and `period`
/// using `Kernel::Auto`, returning a new vector with one value per `close`
/// element. Errors are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = AdxrInput::from_slices(
        high,
        low,
        close,
        AdxrParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; close.len()];

    match adxr_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2381
/// WASM export: runs the batch sweep described by
/// `(period_start, period_end, period_step)` with the scalar kernel and the
/// final `adxr_batch_inner` flag set to `false`, returning only the flat
/// `values` buffer of the batch output. Errors become a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AdxrBatchRange {
        period: (period_start, period_end, period_step),
    };

    match adxr_batch_inner(high, low, close, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2400
/// WASM export: expands the period sweep with `expand_grid` and returns the
/// period of every resulting combo, in order, as `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AdxrBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods = combos
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(periods)
}
2421
/// Configuration object deserialized from the `config` argument of the
/// unified `adxr_batch` WASM entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AdxrBatchConfig {
    /// Period sweep, copied verbatim into `AdxrBatchRange::period`; the other
    /// WASM entry points build that tuple as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
2427
/// Serializable result of the unified `adxr_batch` WASM entry point; every
/// field is copied from the batch output of `adxr_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AdxrBatchJsOutput {
    /// Flat buffer of computed values (the tests in this file index the batch
    /// output as `row = idx / cols`, `col = idx % cols`).
    pub values: Vec<f64>,
    /// Parameter combinations of the batch output.
    pub combos: Vec<AdxrParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
2436
/// WASM export (JS name `adxr_batch`): deserializes `config` into
/// `AdxrBatchConfig`, runs the scalar batch sweep, and returns the values,
/// combos, rows and cols serialized as a JS value.
///
/// Errors are JS strings: `"Invalid config: …"` for a bad `config`, the batch
/// error text for computation failures, and `"Serialization error: …"` if the
/// result cannot be converted.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = adxr_batch)]
pub fn adxr_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: AdxrBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = AdxrBatchRange {
        period: parsed.period_range,
    };

    let batch = match adxr_batch_inner(high, low, close, &sweep, Kernel::Scalar, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = AdxrBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2465
/// WASM export: reserves uninitialized space for `len` `f64` values and
/// returns the raw pointer without freeing it. The allocation is intentionally
/// leaked; ownership passes to the caller, who releases it with `adxr_free`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the vector's destructor so the buffer outlives
    // this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2474
/// WASM export: releases a buffer previously obtained from `adxr_alloc`.
///
/// `ptr` must be a pointer returned by `adxr_alloc(len)` with the same `len`,
/// and must not be used or freed again afterwards. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` originates from a `Vec<f64>` whose
    // capacity was requested as `len`. The vector is rebuilt with length 0
    // rather than `len`: `adxr_alloc` hands out uninitialized memory, and
    // `Vec::from_raw_parts` requires the first `length` elements to be
    // initialized. Dropping the empty vector still frees the whole allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2484
/// WASM export: raw-pointer variant that computes ADXR over `len` elements
/// read from `high_ptr` / `low_ptr` / `close_ptr` and writes `len` results to
/// `out_ptr`.
///
/// Returns `"Null pointer provided"` if any pointer is null and
/// `"Invalid period"` if `period` is 0 or greater than `len`. When `out_ptr`
/// is identical to one of the input pointers the result is computed into a
/// temporary buffer first and then copied over, so in-place use is supported.
/// The caller must guarantee every pointer is valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if !(1..=len).contains(&period) {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = AdxrInput::from_slices(
            high,
            low,
            close,
            AdxrParams {
                period: Some(period),
            },
        );

        // Only exact pointer equality is treated as aliasing.
        let out_as_const = out_ptr as *const f64;
        let aliases_input = [high_ptr, low_ptr, close_ptr]
            .iter()
            .any(|&p| p == out_as_const);

        if aliases_input {
            // Compute into scratch space so the inputs are not overwritten
            // while they are still being read.
            let mut scratch = vec![0.0; len];
            adxr_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            adxr_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
2531
/// WASM export: raw-pointer batch variant. Reads `len` values from each input
/// pointer, expands the `(period_start, period_end, period_step)` sweep, and
/// writes `rows * len` results to `out_ptr` using the scalar kernel, where
/// `rows` is the number of expanded combos. Returns `rows` on success.
///
/// Errors are JS strings: `"Null pointer provided"`, `"rows*cols overflow"`,
/// or the text of the underlying grid/batch error. The caller must guarantee
/// the inputs are valid for `len` values and `out_ptr` for `rows * len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn adxr_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    let any_null =
        high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let sweep = AdxrBatchRange {
            period: (period_start, period_end, period_step),
        };

        // The grid is expanded here only to size the output buffer.
        let rows = expand_grid(&sweep)
            .map_err(|e| JsValue::from_str(&e.to_string()))?
            .len();
        let cols = len;

        let total = match rows.checked_mul(cols) {
            Some(n) => n,
            None => return Err(JsValue::from_str("rows*cols overflow")),
        };

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        match adxr_batch_inner_into(high, low, close, &sweep, Kernel::Scalar, false, out) {
            Ok(_) => Ok(rows),
            Err(e) => Err(JsValue::from_str(&e.to_string())),
        }
    }
}