Skip to main content

vector_ta/indicators/
ad.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::CudaAd;
3use crate::utilities::data_loader::Candles;
4#[cfg(all(feature = "python", feature = "cuda"))]
5use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
9};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(all(feature = "python", feature = "cuda"))]
13use numpy::PyReadonlyArray1;
14#[cfg(feature = "python")]
15use pyo3::exceptions::PyValueError;
16#[cfg(feature = "python")]
17use pyo3::types::{PyDict, PyList, PyListMethods};
18#[cfg(feature = "python")]
19use pyo3::{pyfunction, Bound, PyResult, Python};
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22use thiserror::Error;
23
24#[derive(Debug, Clone)]
25pub enum AdData<'a> {
26    Candles {
27        candles: &'a Candles,
28    },
29    Slices {
30        high: &'a [f64],
31        low: &'a [f64],
32        close: &'a [f64],
33        volume: &'a [f64],
34    },
35}
36
37#[derive(Debug, Clone, Default)]
38pub struct AdParams {}
39
40#[derive(Debug, Clone)]
41pub struct AdInput<'a> {
42    pub data: AdData<'a>,
43    pub params: AdParams,
44}
45
46impl<'a> AdInput<'a> {
47    #[inline]
48    pub fn from_candles(candles: &'a Candles, params: AdParams) -> Self {
49        Self {
50            data: AdData::Candles { candles },
51            params,
52        }
53    }
54
55    #[inline]
56    pub fn from_slices(
57        high: &'a [f64],
58        low: &'a [f64],
59        close: &'a [f64],
60        volume: &'a [f64],
61        params: AdParams,
62    ) -> Self {
63        Self {
64            data: AdData::Slices {
65                high,
66                low,
67                close,
68                volume,
69            },
70            params,
71        }
72    }
73
74    #[inline]
75    pub fn with_default_candles(candles: &'a Candles) -> Self {
76        Self::from_candles(candles, AdParams::default())
77    }
78}
79
80#[derive(Debug, Clone)]
81pub struct AdOutput {
82    pub values: Vec<f64>,
83}
84
85#[derive(Copy, Clone, Debug, Default)]
86pub struct AdBuilder {
87    kernel: Kernel,
88}
89
90impl AdBuilder {
91    #[inline(always)]
92    pub fn new() -> Self {
93        Self {
94            kernel: Kernel::Auto,
95        }
96    }
97
98    #[inline(always)]
99    pub fn kernel(mut self, k: Kernel) -> Self {
100        self.kernel = k;
101        self
102    }
103
104    #[inline(always)]
105    pub fn apply(self, c: &Candles) -> Result<AdOutput, AdError> {
106        let input = AdInput::from_candles(c, AdParams::default());
107        ad_with_kernel(&input, self.kernel)
108    }
109
110    #[inline(always)]
111    pub fn apply_slices(
112        self,
113        high: &[f64],
114        low: &[f64],
115        close: &[f64],
116        volume: &[f64],
117    ) -> Result<AdOutput, AdError> {
118        let input = AdInput::from_slices(high, low, close, volume, AdParams::default());
119        ad_with_kernel(&input, self.kernel)
120    }
121
122    #[inline(always)]
123    pub fn into_stream(self) -> Result<AdStream, AdError> {
124        AdStream::try_new()
125    }
126}
127
128#[derive(Debug, Error)]
129pub enum AdError {
130    #[error("ad: candle field error: {0}")]
131    CandleFieldError(String),
132    #[error(
133        "ad: Data length mismatch: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
134    )]
135    DataLengthMismatch {
136        high_len: usize,
137        low_len: usize,
138        close_len: usize,
139        volume_len: usize,
140    },
141    #[error("ad: invalid period: period={period}, data_len={data_len}")]
142    InvalidPeriod { period: usize, data_len: usize },
143    #[error("ad: output length mismatch: expected={expected}, got={got}")]
144    OutputLengthMismatch { expected: usize, got: usize },
145    #[error("ad: not enough valid data: needed={needed}, valid={valid}")]
146    NotEnoughValidData { needed: usize, valid: usize },
147    #[error("ad: empty input data")]
148    EmptyInputData,
149    #[error("ad: all values are NaN")]
150    AllValuesNaN,
151    #[error("ad: invalid range: start={start}, end={end}, step={step}")]
152    InvalidRange {
153        start: isize,
154        end: isize,
155        step: isize,
156    },
157    #[error("ad: invalid kernel for batch: {0:?}")]
158    InvalidKernelForBatch(Kernel),
159    #[error("ad: invalid input: {0}")]
160    InvalidInput(String),
161}
162
163#[inline]
164pub fn ad(input: &AdInput) -> Result<AdOutput, AdError> {
165    ad_with_kernel(input, Kernel::Auto)
166}
167
168pub fn ad_with_kernel(input: &AdInput, kernel: Kernel) -> Result<AdOutput, AdError> {
169    let (high, low, close, volume): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
170        AdData::Candles { candles } => {
171            let high = candles
172                .select_candle_field("high")
173                .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
174            let low = candles
175                .select_candle_field("low")
176                .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
177            let close = candles
178                .select_candle_field("close")
179                .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
180            let volume = candles
181                .select_candle_field("volume")
182                .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
183            (high, low, close, volume)
184        }
185        AdData::Slices {
186            high,
187            low,
188            close,
189            volume,
190        } => (*high, *low, *close, *volume),
191    };
192
193    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
194        return Err(AdError::DataLengthMismatch {
195            high_len: high.len(),
196            low_len: low.len(),
197            close_len: close.len(),
198            volume_len: volume.len(),
199        });
200    }
201
202    let size = high.len();
203    if size == 0 {
204        return Err(AdError::EmptyInputData);
205    }
206
207    let mut chosen = match kernel {
208        Kernel::Auto => detect_best_kernel(),
209        k => k,
210    };
211
212    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
213    if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
214        chosen = Kernel::Avx2;
215    }
216
217    let mut out = alloc_with_nan_prefix(size, 0);
218
219    unsafe {
220        match chosen {
221            Kernel::Scalar | Kernel::ScalarBatch => ad_scalar(high, low, close, volume, &mut out),
222            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
223            Kernel::Avx2 | Kernel::Avx2Batch => ad_avx2(high, low, close, volume, &mut out),
224            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
225            Kernel::Avx512 | Kernel::Avx512Batch => ad_avx512(high, low, close, volume, &mut out),
226            _ => unreachable!(),
227        }
228    }
229    Ok(AdOutput { values: out })
230}
231
232#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
233#[inline]
234pub fn ad_into(input: &AdInput, out: &mut [f64]) -> Result<(), AdError> {
235    ad_into_slice(out, input, Kernel::Auto)
236}
237
238pub fn ad_into_slice(dst: &mut [f64], input: &AdInput, kern: Kernel) -> Result<(), AdError> {
239    let (high, low, close, volume) = match &input.data {
240        AdData::Candles { candles, .. } => (
241            &candles.high[..],
242            &candles.low[..],
243            &candles.close[..],
244            &candles.volume[..],
245        ),
246        AdData::Slices {
247            high,
248            low,
249            close,
250            volume,
251        } => (*high, *low, *close, *volume),
252    };
253
254    if high.is_empty() {
255        return Err(AdError::EmptyInputData);
256    }
257
258    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
259        return Err(AdError::DataLengthMismatch {
260            high_len: high.len(),
261            low_len: low.len(),
262            close_len: close.len(),
263            volume_len: volume.len(),
264        });
265    }
266
267    if dst.len() != high.len() {
268        return Err(AdError::OutputLengthMismatch {
269            expected: high.len(),
270            got: dst.len(),
271        });
272    }
273
274    match kern {
275        Kernel::Auto => {
276            let mut k = detect_best_kernel();
277            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
278            if matches!(k, Kernel::Avx512) {
279                k = Kernel::Avx2;
280            }
281            match k {
282                Kernel::Scalar => ad_scalar(high, low, close, volume, dst),
283                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284                Kernel::Avx2 => ad_avx2(high, low, close, volume, dst),
285                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
286                Kernel::Avx512 => ad_avx512(high, low, close, volume, dst),
287                _ => ad_scalar(high, low, close, volume, dst),
288            }
289        }
290        Kernel::Scalar => ad_scalar(high, low, close, volume, dst),
291        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
292        Kernel::Avx2 => ad_avx2(high, low, close, volume, dst),
293        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294        Kernel::Avx512 => ad_avx512(high, low, close, volume, dst),
295        _ => ad_scalar(high, low, close, volume, dst),
296    }
297
298    Ok(())
299}
300
/// Portable AD kernel: writes the running Accumulation/Distribution total into `out`.
///
/// Each bar contributes `((close - low) - (high - close)) / (high - low) * volume`.
/// Bars whose range `high - low` equals zero contribute nothing. Debug builds assert
/// that all slices have equal length. In release builds only the common prefix is
/// processed.
#[inline]
pub fn ad_scalar(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), volume.len());
    debug_assert_eq!(high.len(), out.len());

    let bars = high.iter().zip(low).zip(close).zip(volume);
    let mut running = 0.0f64;
    for (slot, (((&h, &l), &c), &v)) in out.iter_mut().zip(bars) {
        let range = h - l;
        if range != 0.0 {
            running += ((c - l) - (h - c)) / range * v;
        }
        *slot = running;
    }
}
324
/// AVX2 AD kernel: a safe wrapper around `ad_avx2_inner`.
///
/// Only the common prefix of the five slices is processed, mirroring `ad_scalar`'s
/// zip semantics. Previously the inner kernel was driven by `high.len()` alone, so
/// shorter `low`/`close`/`volume`/`out` slices caused out-of-bounds raw-pointer
/// access. Debug builds assert equal lengths, as `ad_scalar` does.
///
/// NOTE(review): nothing here checks that the CPU supports AVX2. Callers must only
/// select this kernel on hardware that does.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx2(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), volume.len());
    debug_assert_eq!(high.len(), out.len());

    let n = high
        .len()
        .min(low.len())
        .min(close.len())
        .min(volume.len())
        .min(out.len());
    // SAFETY: every slice passed below has exactly `n` elements, so the kernel's raw
    // pointer reads and writes stay in bounds. AVX2 support is the caller's obligation.
    unsafe { ad_avx2_inner(&high[..n], &low[..n], &close[..n], &volume[..n], &mut out[..n]) }
}
330
/// AVX2 body of the AD kernel: 4 bars per iteration, then a scalar tail.
///
/// Money-flow volumes are computed in vector registers. The running sum is then
/// accumulated lane by lane, in order, so results match `ad_scalar` exactly.
///
/// # Safety
/// All five slices must have at least `high.len()` elements, and the CPU must
/// support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn ad_avx2_inner(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    use core::arch::x86_64::*;

    let n = high.len();
    let h = high.as_ptr();
    let l = low.as_ptr();
    let c = close.as_ptr();
    let v = volume.as_ptr();
    let o = out.as_mut_ptr();

    let zero = _mm256_setzero_pd();
    let mut base = 0.0f64;
    let mut i = 0usize;

    while i + 4 <= n {
        let hv = _mm256_loadu_pd(h.add(i));
        let lv = _mm256_loadu_pd(l.add(i));
        let cv = _mm256_loadu_pd(c.add(i));
        let vv = _mm256_loadu_pd(v.add(i));

        let hl = _mm256_sub_pd(hv, lv);
        let num = _mm256_sub_pd(_mm256_sub_pd(cv, lv), _mm256_sub_pd(hv, cv));
        let mfv_unmasked = _mm256_mul_pd(_mm256_div_pd(num, hl), vv);

        // Keep a lane iff `hl != 0.0` in the IEEE sense used by the scalar code. The
        // *unordered* predicate matters because `NaN != 0.0` is true. A NaN range must
        // therefore propagate into the sum, as it does in `ad_scalar`, in the scalar
        // tail below, and in the AVX-512 kernel (`_mm512_cmpneq_pd_mask`). The ordered
        // `_CMP_NEQ_OQ` previously dropped such bars in this loop only.
        let mask = _mm256_cmp_pd(hl, zero, _CMP_NEQ_UQ);
        let mfv = _mm256_and_pd(mfv_unmasked, mask);

        let mut lanes = [0.0f64; 4];
        _mm256_storeu_pd(lanes.as_mut_ptr(), mfv);
        // The cumulative sum is inherently sequential, so add the lanes in order.
        for (k, &x) in lanes.iter().enumerate() {
            base += x;
            *o.add(i + k) = base;
        }

        i += 4;
    }

    while i < n {
        let hi = *h.add(i);
        let lo = *l.add(i);
        let cl = *c.add(i);
        let vo = *v.add(i);
        let hl = hi - lo;
        if hl != 0.0 {
            let num = (cl - lo) - (hi - cl);
            base += (num / hl) * vo;
        }
        *o.add(i) = base;
        i += 1;
    }
}
397
/// AVX-512F AD kernel: a safe wrapper around `ad_avx512_inner`.
///
/// Only the common prefix of the five slices is processed, mirroring `ad_scalar`'s
/// zip semantics. This keeps the inner kernel's unchecked pointer accesses in bounds
/// even when lengths disagree. Debug builds assert equal lengths, as `ad_scalar` does.
///
/// NOTE(review): nothing here checks that the CPU supports AVX-512F. Callers must
/// only select this kernel on hardware that does.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), volume.len());
    debug_assert_eq!(high.len(), out.len());

    let n = high
        .len()
        .min(low.len())
        .min(close.len())
        .min(volume.len())
        .min(out.len());
    // SAFETY: every slice passed below has exactly `n` elements, so the kernel's raw
    // pointer reads and writes stay in bounds. AVX-512F support is the caller's obligation.
    unsafe {
        ad_avx512_inner(&high[..n], &low[..n], &close[..n], &volume[..n], &mut out[..n])
    }
}
403
/// AVX-512F body of the AD kernel: 8 bars per iteration, then a scalar tail.
///
/// Bars whose range is zero are zeroed with a mask (`_mm512_cmpneq_pd_mask`). The
/// running total is accumulated lane by lane, in order.
///
/// # Safety
/// All five slices must have at least `high.len()` elements, and the CPU must
/// support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn ad_avx512_inner(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let len = high.len();
    let (hp, lp, cp, vp) = (high.as_ptr(), low.as_ptr(), close.as_ptr(), volume.as_ptr());
    let op = out.as_mut_ptr();
    let zero = _mm512_set1_pd(0.0);

    let mut running = 0.0f64;
    let mut idx = 0usize;

    while idx + LANES <= len {
        let hv = _mm512_loadu_pd(hp.add(idx));
        let lv = _mm512_loadu_pd(lp.add(idx));
        let cv = _mm512_loadu_pd(cp.add(idx));
        let vv = _mm512_loadu_pd(vp.add(idx));

        let range = _mm512_sub_pd(hv, lv);
        let num = _mm512_sub_pd(_mm512_sub_pd(cv, lv), _mm512_sub_pd(hv, cv));
        let flow = _mm512_mul_pd(_mm512_div_pd(num, range), vv);

        let keep = _mm512_cmpneq_pd_mask(range, zero);
        let flow = _mm512_maskz_mov_pd(keep, flow);

        let mut lanes = [0.0f64; LANES];
        _mm512_storeu_pd(lanes.as_mut_ptr(), flow);
        for (k, &x) in lanes.iter().enumerate() {
            running += x;
            *op.add(idx + k) = running;
        }

        idx += LANES;
    }

    for j in idx..len {
        let (hi, lo, cl, vo) = (*hp.add(j), *lp.add(j), *cp.add(j), *vp.add(j));
        let range = hi - lo;
        if range != 0.0 {
            running += (((cl - lo) - (hi - cl)) / range) * vo;
        }
        *op.add(j) = running;
    }
}
493
/// AVX-512 entry point for short series. It currently forwards to [`ad_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512_short(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    ad_avx512(high, low, close, volume, out);
}

/// AVX-512 entry point for long series. It currently forwards to [`ad_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512_long(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    ad_avx512(high, low, close, volume, out);
}
505
506#[inline]
507pub fn ad_batch_with_kernel(data: &AdBatchInput, k: Kernel) -> Result<AdBatchOutput, AdError> {
508    let mut kernel = match k {
509        Kernel::Auto => detect_best_batch_kernel(),
510        other if other.is_batch() => other,
511        other => return Err(AdError::InvalidKernelForBatch(other)),
512    };
513    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
514    if matches!(k, Kernel::Auto) && matches!(kernel, Kernel::Avx512Batch) {
515        kernel = Kernel::Avx2Batch;
516    }
517
518    let simd = match kernel {
519        Kernel::Avx512Batch => Kernel::Avx512,
520        Kernel::Avx2Batch => Kernel::Avx2,
521        Kernel::ScalarBatch => Kernel::Scalar,
522        _ => unreachable!(),
523    };
524    ad_batch_par_slice(data, simd)
525}
526
/// Several independent OHLCV series to run AD over in one call.
///
/// Index `r` of each list belongs to series `r`. All lists must have the same number
/// of entries, and every series must have the same length.
#[derive(Clone, Debug)]
pub struct AdBatchInput<'a> {
    pub highs: &'a [&'a [f64]],
    pub lows: &'a [&'a [f64]],
    pub closes: &'a [&'a [f64]],
    pub volumes: &'a [&'a [f64]],
}

/// Batch AD result: a row-major `rows x cols` matrix with one row per input series.
#[derive(Clone, Debug)]
pub struct AdBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Number of series.
    pub rows: usize,
    /// Length of each series.
    pub cols: usize,
}
541
542#[inline(always)]
543pub fn ad_batch_slice(data: &AdBatchInput, kern: Kernel) -> Result<AdBatchOutput, AdError> {
544    ad_batch_inner(data, kern, false)
545}
546
547#[inline(always)]
548pub fn ad_batch_par_slice(data: &AdBatchInput, kern: Kernel) -> Result<AdBatchOutput, AdError> {
549    ad_batch_inner(data, kern, true)
550}
551
552fn ad_batch_inner(
553    data: &AdBatchInput,
554    kern: Kernel,
555    parallel: bool,
556) -> Result<AdBatchOutput, AdError> {
557    let rows = data.highs.len();
558    let cols = if rows > 0 { data.highs[0].len() } else { 0 };
559    let len = rows
560        .checked_mul(cols)
561        .ok_or_else(|| AdError::InvalidInput("rows*cols overflow".into()))?;
562
563    let mut buf_mu = make_uninit_matrix(rows, cols);
564    let values = unsafe {
565        let ptr = buf_mu.as_mut_ptr() as *mut f64;
566        let slice = std::slice::from_raw_parts_mut(ptr, len);
567
568        ad_batch_inner_into(data, kern, parallel, slice)?;
569
570        Vec::from_raw_parts(ptr, len, len)
571    };
572    std::mem::forget(buf_mu);
573
574    Ok(AdBatchOutput { values, rows, cols })
575}
576
577fn ad_batch_inner_into(
578    data: &AdBatchInput,
579    kern: Kernel,
580    parallel: bool,
581    out: &mut [f64],
582) -> Result<(), AdError> {
583    let rows = data.highs.len();
584    let cols = if rows > 0 { data.highs[0].len() } else { 0 };
585
586    if data.lows.len() != rows || data.closes.len() != rows || data.volumes.len() != rows {
587        return Err(AdError::DataLengthMismatch {
588            high_len: data.highs.len(),
589            low_len: data.lows.len(),
590            close_len: data.closes.len(),
591            volume_len: data.volumes.len(),
592        });
593    }
594
595    for row in 0..rows {
596        let h_len = data.highs[row].len();
597        let l_len = data.lows[row].len();
598        let c_len = data.closes[row].len();
599        let v_len = data.volumes[row].len();
600
601        if h_len != cols || l_len != cols || c_len != cols || v_len != cols {
602            return Err(AdError::DataLengthMismatch {
603                high_len: h_len,
604                low_len: l_len,
605                close_len: c_len,
606                volume_len: v_len,
607            });
608        }
609    }
610
611    let expected = rows
612        .checked_mul(cols)
613        .ok_or_else(|| AdError::InvalidInput("rows*cols overflow".into()))?;
614    if out.len() != expected {
615        return Err(AdError::OutputLengthMismatch {
616            expected,
617            got: out.len(),
618        });
619    }
620
621    let mut actual = match kern {
622        Kernel::Auto => detect_best_batch_kernel(),
623        k => k,
624    };
625    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
626    if matches!(kern, Kernel::Auto) && matches!(actual, Kernel::Avx512Batch) {
627        actual = Kernel::Avx2Batch;
628    }
629
630    let do_row = |row: usize, dst: &mut [f64]| unsafe {
631        match actual {
632            Kernel::Scalar | Kernel::ScalarBatch => ad_row_scalar(
633                data.highs[row],
634                data.lows[row],
635                data.closes[row],
636                data.volumes[row],
637                dst,
638            ),
639            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
640            Kernel::Avx2 | Kernel::Avx2Batch => ad_row_avx2(
641                data.highs[row],
642                data.lows[row],
643                data.closes[row],
644                data.volumes[row],
645                dst,
646            ),
647            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
648            Kernel::Avx512 | Kernel::Avx512Batch => ad_row_avx512(
649                data.highs[row],
650                data.lows[row],
651                data.closes[row],
652                data.volumes[row],
653                dst,
654            ),
655            _ => ad_row_scalar(
656                data.highs[row],
657                data.lows[row],
658                data.closes[row],
659                data.volumes[row],
660                dst,
661            ),
662        }
663    };
664
665    if parallel {
666        #[cfg(not(target_arch = "wasm32"))]
667        {
668            use rayon::prelude::*;
669            out.par_chunks_mut(cols)
670                .enumerate()
671                .for_each(|(r, s)| do_row(r, s));
672        }
673        #[cfg(target_arch = "wasm32")]
674        {
675            for (r, s) in out.chunks_mut(cols).enumerate() {
676                do_row(r, s);
677            }
678        }
679    } else {
680        for (r, s) in out.chunks_mut(cols).enumerate() {
681            do_row(r, s);
682        }
683    }
684
685    Ok(())
686}
687
688#[inline(always)]
689pub unsafe fn ad_row_scalar(
690    high: &[f64],
691    low: &[f64],
692    close: &[f64],
693    volume: &[f64],
694    out: &mut [f64],
695) {
696    ad_scalar(high, low, close, volume, out)
697}
698
699#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
700#[inline(always)]
701pub unsafe fn ad_row_avx2(
702    high: &[f64],
703    low: &[f64],
704    close: &[f64],
705    volume: &[f64],
706    out: &mut [f64],
707) {
708    ad_avx2(high, low, close, volume, out)
709}
710
711#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
712#[inline(always)]
713pub unsafe fn ad_row_avx512(
714    high: &[f64],
715    low: &[f64],
716    close: &[f64],
717    volume: &[f64],
718    out: &mut [f64],
719) {
720    ad_avx512(high, low, close, volume, out)
721}
722
723#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
724#[inline(always)]
725pub unsafe fn ad_row_avx512_short(
726    high: &[f64],
727    low: &[f64],
728    close: &[f64],
729    volume: &[f64],
730    out: &mut [f64],
731) {
732    ad_avx512(high, low, close, volume, out)
733}
734
735#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
736#[inline(always)]
737pub unsafe fn ad_row_avx512_long(
738    high: &[f64],
739    low: &[f64],
740    close: &[f64],
741    volume: &[f64],
742    out: &mut [f64],
743) {
744    ad_avx512(high, low, close, volume, out)
745}
746
747#[derive(Debug, Clone)]
748pub struct AdStream {
749    sum: f64,
750}
751
752impl AdStream {
753    #[inline(always)]
754    pub fn try_new() -> Result<Self, AdError> {
755        Ok(Self { sum: 0.0 })
756    }
757
758    #[inline(always)]
759    pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> f64 {
760        if volume == 0.0 {
761            return self.sum;
762        }
763
764        let hl = high - low;
765        if hl != 0.0 {
766            let num = (close - low) - (high - close);
767
768            self.sum += (num / hl) * volume;
769        }
770        self.sum
771    }
772}
773
774#[cfg(all(feature = "python", feature = "cuda"))]
775use cust::context::Context;
776#[cfg(all(feature = "python", feature = "cuda"))]
777use cust::memory::DeviceBuffer;
778#[cfg(all(feature = "python", feature = "cuda"))]
779use std::sync::Arc;
/// Python handle to an f32 AD result matrix resident on a CUDA device.
///
/// Exposes the buffer via `__cuda_array_interface__` (borrow) and `__dlpack__`
/// (one-shot ownership transfer).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AdDeviceArrayF32Py {
    /// Device results. `None` once ownership has been moved out by `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    /// CUDA context handle stored with the buffer. It is declared after `buf`, so it
    /// is dropped after it; keep this field order.
    /// NOTE(review): presumably held to keep the context alive while `buf` exists. Confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AdDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) for a C-contiguous `rows x cols` f32 matrix.
    ///
    /// Raises `ValueError` if the buffer was already handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem_bytes, elem_bytes))?;

        let buffer = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let device_ptr = buffer.as_device_ptr().as_raw() as usize;
        // The second tuple element is the interface's read-only flag.
        iface.set_item("data", (device_ptr, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple: device type 2 (CUDA in DLPack's enum) and this buffer's ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving ownership out of this object.
    ///
    /// Only the producing device is supported. A different `dl_device` raises
    /// `ValueError`, with a specific message when `copy=True` was requested. A second
    /// call raises `ValueError` because the buffer has already been taken. `stream`
    /// is accepted but ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, self.rows, self.cols, alloc_dev, max_version_bound)
    }
}
/// Python: computes AD for one f32 series on a CUDA device and returns a device handle.
///
/// Raises `ValueError` if CUDA is unavailable or the CUDA computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ad_cuda_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, volume_f32, device_id=0))]
pub fn ad_cuda_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    volume_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<AdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close, volume) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
        volume_f32.as_slice()?,
    );

    // The GPU work runs with the GIL released.
    py.allow_threads(|| {
        let cuda = CudaAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let result = cuda
            .ad_series_dev(high, low, close, volume)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>(AdDeviceArrayF32Py {
            buf: Some(result.buf),
            rows: result.rows,
            cols: result.cols,
            _ctx: cuda.context_arc(),
            device_id: cuda.device_id(),
        })
    })
}
900
/// Python: computes AD for many time-major f32 series on a CUDA device.
///
/// The `*_tm_f32` inputs are flat arrays interpreted together with `cols` and `rows`
/// by the CUDA wrapper (NOTE(review): time-major layout per the name; confirm against
/// `CudaAd`). Raises `ValueError` if CUDA is unavailable or the computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, volume_tm_f32, cols, rows, device_id=0))]
pub fn ad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    volume_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    device_id: usize,
) -> PyResult<AdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high_tm, low_tm, close_tm, volume_tm) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
        volume_tm_f32.as_slice()?,
    );

    // The GPU work runs with the GIL released.
    py.allow_threads(|| {
        let cuda = CudaAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let result = cuda
            .ad_many_series_one_param_time_major_dev(
                high_tm, low_tm, close_tm, volume_tm, cols, rows,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>(AdDeviceArrayF32Py {
            buf: Some(result.buf),
            rows: result.rows,
            cols: result.cols,
            _ctx: cuda.context_arc(),
            device_id: cuda.device_id(),
        })
    })
}
942
943#[derive(Clone, Debug, Default)]
944pub struct AdBatchBuilder {
945    pub kernel: Kernel,
946}
947
948impl AdBatchBuilder {
949    pub fn new() -> Self {
950        Self {
951            kernel: Kernel::Auto,
952        }
953    }
954    pub fn kernel(mut self, k: Kernel) -> Self {
955        self.kernel = k;
956        self
957    }
958
959    pub fn apply_slices(
960        self,
961        highs: &[&[f64]],
962        lows: &[&[f64]],
963        closes: &[&[f64]],
964        volumes: &[&[f64]],
965    ) -> Result<AdBatchOutput, AdError> {
966        let batch = AdBatchInput {
967            highs,
968            lows,
969            closes,
970            volumes,
971        };
972        ad_batch_with_kernel(&batch, self.kernel)
973    }
974}
975
976#[cfg(feature = "python")]
977use numpy::{IntoPyArray, PyArray1};
978#[cfg(feature = "python")]
979use pyo3::prelude::*;
980#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
981use wasm_bindgen::prelude::*;
982
/// Python: computes the AD line for four equal-length float64 arrays.
///
/// Raises `ValueError("Not enough data")` if any array is empty, and `ValueError`
/// for an invalid kernel name or any indicator error (for example mismatched lengths).
#[cfg(feature = "python")]
#[pyfunction(name = "ad")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn ad_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    volume: numpy::PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let volume_slice = volume.as_slice()?;

    let any_empty = [high_slice, low_slice, close_slice, volume_slice]
        .iter()
        .any(|series| series.is_empty());
    if any_empty {
        return Err(PyValueError::new_err("Not enough data"));
    }

    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, false)?;

    let input = AdInput::from_slices(
        high_slice,
        low_slice,
        close_slice,
        volume_slice,
        AdParams::default(),
    );

    // Compute with the GIL released, then hand the Vec to NumPy.
    let values: Vec<f64> = py
        .allow_threads(|| ad_with_kernel(&input, kern).map(|output| output.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1026
#[cfg(feature = "python")]
#[pyclass(name = "AdStream")]
// Python wrapper around the incremental `AdStream` state.
pub struct AdStreamPy {
    stream: AdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AdStreamPy {
    // Builds a fresh stream; a construction error becomes a Python `ValueError`.
    #[new]
    fn new() -> PyResult<Self> {
        AdStream::try_new()
            .map(|stream| AdStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar into the stream and returns the value `AdStream::update` yields for it.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> f64 {
        let stream = &mut self.stream;
        stream.update(high, low, close, volume)
    }
}
1041
1042    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> f64 {
1043        self.stream.update(high, low, close, volume)
1044    }
1045}
1046
#[cfg(feature = "python")]
#[pyfunction(name = "ad_batch")]
#[pyo3(signature = (highs, lows, closes, volumes, kernel=None))]
// Batch Accumulation/Distribution over a list of OHLCV rows.
//
// `highs`, `lows`, `closes` and `volumes` are equally long Python lists of 1-D float64
// arrays; entry `i` of each list forms one series. Every row must have the same length,
// because the result is a `(rows, cols)` matrix returned as a dict with keys
// "values", "rows" and "cols".
//
// Errors are raised as Python exceptions for mismatched list or row lengths,
// non-contiguous arrays, an invalid `kernel` name, or a failure in the batch kernel.
pub fn ad_batch_py<'py>(
    py: Python<'py>,
    highs: &Bound<'py, PyList>,
    lows: &Bound<'py, PyList>,
    closes: &Bound<'py, PyList>,
    volumes: &Bound<'py, PyList>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
    use pyo3::types::PyDict;

    let rows = highs.len();
    if lows.len() != rows || closes.len() != rows || volumes.len() != rows {
        return Err(PyValueError::new_err(
            "All input lists must have the same length",
        ));
    }

    let mut high_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut low_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut close_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut volume_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);

    for i in 0..rows {
        let h = highs.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let l = lows.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let c = closes.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let v = volumes.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;

        let n = h.len()?;
        if l.len()? != n || c.len()? != n || v.len()? != n {
            return Err(PyValueError::new_err(
                "Rows must have equal lengths across OHLCV arrays",
            ));
        }
        high_arrays.push(h);
        low_arrays.push(l);
        close_arrays.push(c);
        volume_arrays.push(v);
    }

    // `as_slice` fails for non-contiguous arrays (e.g. strided views). Propagate that as a
    // Python exception instead of panicking inside the extension module.
    let high_slices = high_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let low_slices = low_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let close_slices = close_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let volume_slices = volume_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;

    // The output is one flat `rows * cols` buffer reshaped to `(rows, cols)`, so every row
    // must share the first row's length. Per-row OHLCV lengths were already checked, so
    // comparing the highs is sufficient.
    let cols = high_slices.first().map_or(0, |s| s.len());
    if high_slices.iter().any(|s| s.len() != cols) {
        return Err(PyValueError::new_err(
            "All rows must have the same length in ad_batch",
        ));
    }

    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, true)?;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in ad_batch"))?;
    // SAFETY: the array is freshly allocated, so it is contiguous and not yet visible to
    // any other Python code. It is fully written by `ad_batch_inner_into` before being
    // returned; on error it is simply dropped.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    py.allow_threads(|| -> Result<(), AdError> {
        let batch_input = AdBatchInput {
            highs: &high_slices,
            lows: &low_slices,
            closes: &close_slices,
            volumes: &volume_slices,
        };

        let actual = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        ad_batch_inner_into(&batch_input, actual, true, out_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1131
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes the Accumulation/Distribution line for one OHLCV series and returns it as a new
/// vector of the same length as `high`.
///
/// # Errors
/// Fails with "Not enough data" if any input is empty, or with the message of any error
/// reported by `ad_into_slice`.
pub fn ad_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
) -> Result<Vec<f64>, JsValue> {
    let has_empty = [high, low, close, volume]
        .iter()
        .any(|series| series.is_empty());
    if has_empty {
        return Err(JsValue::from_str("Not enough data"));
    }

    let input = AdInput::from_slices(high, low, close, volume, AdParams::default());
    let mut values = vec![0.0; high.len()];
    match ad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1152
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Batch AD over `rows` series packed back-to-back in flat buffers, using the scalar batch
/// kernel.
///
/// Every flat buffer must hold exactly `rows * cols` values, where `cols` is derived from
/// `highs_flat.len() / rows`.
pub fn ad_batch_js(
    highs_flat: &[f64],
    lows_flat: &[f64],
    closes_flat: &[f64],
    volumes_flat: &[f64],
    rows: usize,
) -> Result<Vec<f64>, JsValue> {
    if rows == 0 || highs_flat.is_empty() {
        return Err(JsValue::from_str("Empty input data"));
    }

    let cols = highs_flat.len() / rows;
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let lengths_ok = [highs_flat, lows_flat, closes_flat, volumes_flat]
        .iter()
        .all(|flat| flat.len() == expected);
    if !lengths_ok {
        return Err(JsValue::from_str(
            "Input arrays must have rows*cols elements",
        ));
    }

    // Here `cols >= 1` (a non-empty buffer of exactly rows*cols elements), so each buffer
    // splits into exactly `rows` windows of `cols` samples.
    let high_rows: Vec<&[f64]> = highs_flat.chunks_exact(cols).collect();
    let low_rows: Vec<&[f64]> = lows_flat.chunks_exact(cols).collect();
    let close_rows: Vec<&[f64]> = closes_flat.chunks_exact(cols).collect();
    let volume_rows: Vec<&[f64]> = volumes_flat.chunks_exact(cols).collect();

    let batch_input = AdBatchInput {
        highs: &high_rows,
        lows: &low_rows,
        closes: &close_rows,
        volumes: &volume_rows,
    };

    match ad_batch_with_kernel(&batch_input, Kernel::ScalarBatch) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1205
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Returns `[rows, cols]` as floats so JS callers can reshape a flat batch result.
pub fn ad_batch_metadata_js(rows: usize, cols: usize) -> Vec<f64> {
    [rows, cols].iter().map(|&dim| dim as f64).collect()
}
1211
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Reserves an uninitialized buffer with room for `len` f64 values in WASM memory and hands
/// its pointer to the caller. The buffer is intentionally leaked; release it with `ad_free`
/// using the same `len`.
pub fn ad_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1220
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `ad_alloc(len)`; a null `ptr` is ignored.
///
/// `len` must be the same value that was passed to `ad_alloc`.
pub fn ad_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` comes from `ad_alloc(len)`, which leaked a `Vec<f64>` with capacity
        // `len`. The Vec is rebuilt with length 0 because `from_raw_parts` requires the first
        // `length` elements to be initialized, and the caller may never have written them.
        // `f64` has no destructor, so dropping only frees the allocation.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
1230
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Computes AD for one series from raw WASM-memory pointers, writing `len` values to
/// `out_ptr`.
///
/// All five pointers must reference at least `len` f64 values. If the output buffer overlaps
/// any input buffer, fully or partially, the result is computed into a scratch buffer first
/// and then copied out. This makes in-place use (e.g. `out_ptr == close_ptr`) safe.
///
/// # Errors
/// Returns a JS error for null pointers or for any error reported by `ad_into_slice`.
pub fn ad_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte-range overlap test. Exact pointer equality would miss partially overlapping
    // buffers, which would otherwise yield a `&mut` slice aliasing live `&` slices.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(span);
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(span)
    };
    let aliased = overlaps_out(high_ptr)
        || overlaps_out(low_ptr)
        || overlaps_out(close_ptr)
        || overlaps_out(volume_ptr);

    // SAFETY: all pointers are non-null, and the JS caller guarantees each references at
    // least `len` initialized f64 values that stay valid for the duration of this call.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let input = AdInput::from_slices(high, low, close, volume, AdParams::default());

    if aliased {
        let mut temp = vec![0.0; len];
        ad_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and valid for `len` writes; the input slices are not
        // used past this point, so the mutable view does not race any live read.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is non-null, valid for `len` writes, and (checked above) does not
        // overlap any input buffer.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ad_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1277
1278#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1279use serde::{Deserialize, Serialize};
1280
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Object returned to JavaScript by the unified `ad_batch` export (`ad_batch_unified_js`).
pub struct AdBatchJsOutput {
    /// Flattened AD values for all rows, copied from the batch output; presumably row-major
    /// with `rows * cols` entries (TODO confirm against `ad_batch_with_kernel`).
    pub values: Vec<f64>,
    /// Number of series in the batch.
    pub rows: usize,
    /// Number of samples per series.
    pub cols: usize,
}
1288
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "ad_batch")]
/// Batch AD over `rows` series packed back-to-back in flat buffers, returned to JS as an
/// object `{ values, rows, cols }`.
///
/// `cols` is derived as `highs_flat.len() / rows`. All four buffers, highs included, must
/// hold exactly `rows * cols` values.
///
/// # Errors
/// Returns a JS error for `rows == 0`, empty input, mismatched buffer lengths, a kernel
/// error, or a serialization failure.
pub fn ad_batch_unified_js(
    highs_flat: &[f64],
    lows_flat: &[f64],
    closes_flat: &[f64],
    volumes_flat: &[f64],
    rows: usize,
) -> Result<JsValue, JsValue> {
    if rows == 0 {
        return Err(JsValue::from_str("rows must be > 0"));
    }
    if highs_flat.is_empty() {
        return Err(JsValue::from_str("empty inputs"));
    }
    let cols = highs_flat.len() / rows;
    let check = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    // `highs_flat` must be validated too: when its length is not a multiple of `rows`, the
    // division above rounds down and its trailing values would be silently ignored.
    if highs_flat.len() != check
        || lows_flat.len() != check
        || closes_flat.len() != check
        || volumes_flat.len() != check
    {
        return Err(JsValue::from_str(
            "Input arrays must have rows*cols elements",
        ));
    }

    let mut highs = Vec::with_capacity(rows);
    let mut lows = Vec::with_capacity(rows);
    let mut closes = Vec::with_capacity(rows);
    let mut volumes = Vec::with_capacity(rows);
    for r in 0..rows {
        let s = r * cols;
        let e = s + cols;
        highs.push(&highs_flat[s..e]);
        lows.push(&lows_flat[s..e]);
        closes.push(&closes_flat[s..e]);
        volumes.push(&volumes_flat[s..e]);
    }

    let batch = AdBatchInput {
        highs: &highs,
        lows: &lows,
        closes: &closes,
        volumes: &volumes,
    };
    let out = ad_batch_with_kernel(&batch, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let packed = AdBatchJsOutput {
        values: out.values,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&packed)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1344
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Writes the batch AD result for `rows` series of `cols` samples each into `out_ptr`.
///
/// The four input pointers and `out_ptr` must each reference `rows * cols` f64 values, laid
/// out one row after another. If `out_ptr` overlaps any input buffer, the result is computed
/// into a scratch buffer and copied out afterwards, so in-place use is safe.
///
/// # Errors
/// Returns a JS error for null pointers, `rows * cols` overflow, or any error reported by the
/// batch kernel.
pub fn ad_batch_into(
    highs_ptr: *const f64,
    lows_ptr: *const f64,
    closes_ptr: *const f64,
    volumes_ptr: *const f64,
    out_ptr: *mut f64,
    rows: usize,
    cols: usize,
) -> Result<(), JsValue> {
    if highs_ptr.is_null()
        || lows_ptr.is_null()
        || closes_ptr.is_null()
        || volumes_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }
    let check = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between the output and each input buffer. Without it, an
    // aliased output would give a `&mut` slice overlapping the input `&` slices.
    let span = check.saturating_mul(std::mem::size_of::<f64>());
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(span);
    let overlaps_out = |p: *const f64| {
        let start = p as usize;
        start < out_end && out_start < start.saturating_add(span)
    };
    let aliased = overlaps_out(highs_ptr)
        || overlaps_out(lows_ptr)
        || overlaps_out(closes_ptr)
        || overlaps_out(volumes_ptr);

    // SAFETY: all pointers are non-null, and the JS caller guarantees each references at
    // least `rows * cols` initialized f64 values that stay valid for this call.
    let (highs_flat, lows_flat, closes_flat, volumes_flat) = unsafe {
        (
            std::slice::from_raw_parts(highs_ptr, check),
            std::slice::from_raw_parts(lows_ptr, check),
            std::slice::from_raw_parts(closes_ptr, check),
            std::slice::from_raw_parts(volumes_ptr, check),
        )
    };

    let mut highs = Vec::with_capacity(rows);
    let mut lows = Vec::with_capacity(rows);
    let mut closes = Vec::with_capacity(rows);
    let mut volumes = Vec::with_capacity(rows);
    for r in 0..rows {
        let s = r * cols;
        let e = s + cols;
        highs.push(&highs_flat[s..e]);
        lows.push(&lows_flat[s..e]);
        closes.push(&closes_flat[s..e]);
        volumes.push(&volumes_flat[s..e]);
    }
    let batch = AdBatchInput {
        highs: &highs,
        lows: &lows,
        closes: &closes,
        volumes: &volumes,
    };

    if aliased {
        let mut temp = vec![0.0; check];
        ad_batch_inner_into(&batch, detect_best_batch_kernel(), false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and valid for `check` writes; the input slices are no
        // longer read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, check) };
        out.copy_from_slice(&temp);
        Ok(())
    } else {
        // SAFETY: `out_ptr` is non-null, valid for `check` writes, and (checked above) does
        // not overlap any input buffer.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, check) };
        ad_batch_inner_into(&batch, detect_best_batch_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1397
1398#[cfg(test)]
1399mod tests {
1400    use super::*;
1401    use crate::skip_if_unsupported;
1402    use crate::utilities::data_loader::{read_candles_from_csv, Candles};
1403    use crate::utilities::enums::Kernel;
1404
1405    fn check_ad_partial_params(
1406        test_name: &str,
1407        kernel: Kernel,
1408    ) -> Result<(), Box<dyn std::error::Error>> {
1409        skip_if_unsupported!(kernel, test_name);
1410        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1411        let candles = read_candles_from_csv(file_path)?;
1412        let default_params = AdParams::default();
1413        let input = AdInput::from_candles(&candles, default_params);
1414        let output = ad_with_kernel(&input, kernel)?;
1415        assert_eq!(output.values.len(), candles.close.len());
1416        Ok(())
1417    }
1418
1419    fn check_ad_accuracy(
1420        test_name: &str,
1421        kernel: Kernel,
1422    ) -> Result<(), Box<dyn std::error::Error>> {
1423        skip_if_unsupported!(kernel, test_name);
1424        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1425        let candles = read_candles_from_csv(file_path)?;
1426        let input = AdInput::with_default_candles(&candles);
1427        let ad_result = ad_with_kernel(&input, kernel)?;
1428        assert_eq!(ad_result.values.len(), candles.close.len());
1429        let expected_last_five = [1645918.16, 1645876.11, 1645824.27, 1645828.87, 1645728.78];
1430        let start = ad_result.values.len() - 5;
1431        let actual = &ad_result.values[start..];
1432        for (i, &val) in actual.iter().enumerate() {
1433            assert!(
1434                (val - expected_last_five[i]).abs() < 1e-1,
1435                "[{}] AD mismatch at idx {}: got {}, expected {}",
1436                test_name,
1437                i,
1438                val,
1439                expected_last_five[i]
1440            );
1441        }
1442        Ok(())
1443    }
1444
1445    fn check_ad_with_slice_data_reinput(
1446        test_name: &str,
1447        kernel: Kernel,
1448    ) -> Result<(), Box<dyn std::error::Error>> {
1449        skip_if_unsupported!(kernel, test_name);
1450        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1451        let candles = read_candles_from_csv(file_path)?;
1452        let first_input = AdInput::with_default_candles(&candles);
1453        let first_result = ad_with_kernel(&first_input, kernel)?;
1454        let second_input = AdInput::from_slices(
1455            &first_result.values,
1456            &first_result.values,
1457            &first_result.values,
1458            &first_result.values,
1459            AdParams::default(),
1460        );
1461        let second_result = ad_with_kernel(&second_input, kernel)?;
1462        assert_eq!(second_result.values.len(), first_result.values.len());
1463        for i in 50..second_result.values.len() {
1464            assert!(!second_result.values[i].is_nan());
1465        }
1466        Ok(())
1467    }
1468
1469    fn check_ad_input_with_default_candles(
1470        test_name: &str,
1471        kernel: Kernel,
1472    ) -> Result<(), Box<dyn std::error::Error>> {
1473        skip_if_unsupported!(kernel, test_name);
1474        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1475        let candles = read_candles_from_csv(file_path)?;
1476        let input = AdInput::with_default_candles(&candles);
1477        match input.data {
1478            AdData::Candles { .. } => {}
1479            _ => panic!("Expected AdData::Candles variant"),
1480        }
1481        Ok(())
1482    }
1483
1484    fn check_ad_accuracy_nan_check(
1485        test_name: &str,
1486        kernel: Kernel,
1487    ) -> Result<(), Box<dyn std::error::Error>> {
1488        skip_if_unsupported!(kernel, test_name);
1489        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1490        let candles = read_candles_from_csv(file_path)?;
1491        let input = AdInput::with_default_candles(&candles);
1492        let ad_result = ad_with_kernel(&input, kernel)?;
1493        assert_eq!(ad_result.values.len(), candles.close.len());
1494        if ad_result.values.len() > 50 {
1495            for i in 50..ad_result.values.len() {
1496                assert!(
1497                    !ad_result.values[i].is_nan(),
1498                    "[{}] Expected no NaN after index 50, but found NaN at index {}",
1499                    test_name,
1500                    i
1501                );
1502            }
1503        }
1504        Ok(())
1505    }
1506
1507    fn check_ad_streaming(
1508        test_name: &str,
1509        kernel: Kernel,
1510    ) -> Result<(), Box<dyn std::error::Error>> {
1511        skip_if_unsupported!(kernel, test_name);
1512        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1513        let candles = read_candles_from_csv(file_path)?;
1514        let input = AdInput::with_default_candles(&candles);
1515        let batch = ad_with_kernel(&input, kernel)?.values;
1516        let mut stream = AdStream::try_new()?;
1517        let mut stream_values = Vec::with_capacity(candles.close.len());
1518        for i in 0..candles.close.len() {
1519            let val = stream.update(
1520                candles.high[i],
1521                candles.low[i],
1522                candles.close[i],
1523                candles.volume[i],
1524            );
1525            stream_values.push(val);
1526        }
1527        assert_eq!(batch.len(), stream_values.len());
1528        for (b, s) in batch.iter().zip(stream_values.iter()) {
1529            if b.is_nan() && s.is_nan() {
1530                continue;
1531            }
1532            assert!(
1533                (b - s).abs() < 1e-9,
1534                "[{}] AD streaming mismatch",
1535                test_name
1536            );
1537        }
1538        Ok(())
1539    }
1540
1541    #[cfg(debug_assertions)]
1542    fn check_ad_no_poison(
1543        test_name: &str,
1544        kernel: Kernel,
1545    ) -> Result<(), Box<dyn std::error::Error>> {
1546        skip_if_unsupported!(kernel, test_name);
1547
1548        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1549        let candles = read_candles_from_csv(file_path)?;
1550
1551        let input = AdInput::with_default_candles(&candles);
1552        let output = ad_with_kernel(&input, kernel)?;
1553
1554        for (i, &val) in output.values.iter().enumerate() {
1555            if val.is_nan() {
1556                continue;
1557            }
1558
1559            let bits = val.to_bits();
1560
1561            if bits == 0x11111111_11111111 {
1562                panic!(
1563                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
1564                    test_name, val, bits, i
1565                );
1566            }
1567
1568            if bits == 0x22222222_22222222 {
1569                panic!(
1570                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
1571                    test_name, val, bits, i
1572                );
1573            }
1574
1575            if bits == 0x33333333_33333333 {
1576                panic!(
1577                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
1578                    test_name, val, bits, i
1579                );
1580            }
1581        }
1582
1583        let slice_input = AdInput::from_slices(
1584            &candles.high,
1585            &candles.low,
1586            &candles.close,
1587            &candles.volume,
1588            AdParams::default(),
1589        );
1590        let slice_output = ad_with_kernel(&slice_input, kernel)?;
1591
1592        for (i, &val) in slice_output.values.iter().enumerate() {
1593            if val.is_nan() {
1594                continue;
1595            }
1596
1597            let bits = val.to_bits();
1598
1599            if bits == 0x11111111_11111111 {
1600                panic!(
1601					"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (slice test)",
1602					test_name, val, bits, i
1603				);
1604            }
1605
1606            if bits == 0x22222222_22222222 {
1607                panic!(
1608					"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (slice test)",
1609					test_name, val, bits, i
1610				);
1611            }
1612
1613            if bits == 0x33333333_33333333 {
1614                panic!(
1615					"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (slice test)",
1616					test_name, val, bits, i
1617				);
1618            }
1619        }
1620
1621        Ok(())
1622    }
1623
1624    #[cfg(not(debug_assertions))]
1625    fn check_ad_no_poison(
1626        _test_name: &str,
1627        _kernel: Kernel,
1628    ) -> Result<(), Box<dyn std::error::Error>> {
1629        Ok(())
1630    }
1631
1632    macro_rules! generate_all_ad_tests {
1633        ($($test_fn:ident),*) => {
1634            paste::paste! {
1635                $(#[test] fn [<$test_fn _scalar_f64>]() {
1636                    let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1637                })*
1638                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1639                $(#[test] fn [<$test_fn _avx2_f64>]() {
1640                    let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1641                })*
1642                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1643                $(#[test] fn [<$test_fn _avx512_f64>]() {
1644                    let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1645                })*
1646            }
1647        }
1648    }
1649
1650    #[cfg(feature = "proptest")]
1651    #[allow(clippy::float_cmp)]
1652    fn check_ad_property(
1653        test_name: &str,
1654        kernel: Kernel,
1655    ) -> Result<(), Box<dyn std::error::Error>> {
1656        use proptest::prelude::*;
1657        skip_if_unsupported!(kernel, test_name);
1658
1659        let strat = (10usize..400).prop_flat_map(|len| {
1660            prop::collection::vec(
1661                (
1662                    1.0f64..1000.0f64,
1663                    0.0f64..500.0f64,
1664                    0.0f64..1.0f64,
1665                    0.0f64..1e6f64,
1666                )
1667                    .prop_filter("finite values", |(l, hd, cr, v)| {
1668                        l.is_finite()
1669                            && hd.is_finite()
1670                            && cr.is_finite()
1671                            && v.is_finite()
1672                            && *v >= 0.0
1673                    })
1674                    .prop_map(|(low, high_delta, close_ratio, volume)| {
1675                        let high = low + high_delta;
1676                        let close = if high_delta == 0.0 {
1677                            low
1678                        } else {
1679                            low + high_delta * close_ratio
1680                        };
1681                        (high, low, close, volume)
1682                    }),
1683                len,
1684            )
1685            .prop_map(|data| {
1686                let (highs, lows, closes, volumes): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
1687                    data.into_iter().map(|(h, l, c, v)| (h, l, c, v)).unzip4();
1688                (highs, lows, closes, volumes)
1689            })
1690        });
1691
1692        trait Unzip4<A, B, C, D> {
1693            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>);
1694        }
1695
1696        impl<I, A, B, C, D> Unzip4<A, B, C, D> for I
1697        where
1698            I: Iterator<Item = (A, B, C, D)>,
1699        {
1700            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>) {
1701                let (mut a, mut b, mut c, mut d) = (Vec::new(), Vec::new(), Vec::new(), Vec::new());
1702                for (av, bv, cv, dv) in self {
1703                    a.push(av);
1704                    b.push(bv);
1705                    c.push(cv);
1706                    d.push(dv);
1707                }
1708                (a, b, c, d)
1709            }
1710        }
1711
1712        proptest::test_runner::TestRunner::default()
1713            .run(&strat, |(highs, lows, closes, volumes)| {
1714                let input =
1715                    AdInput::from_slices(&highs, &lows, &closes, &volumes, AdParams::default());
1716
1717                let AdOutput { values: out } = ad_with_kernel(&input, kernel).unwrap();
1718
1719                let AdOutput { values: ref_out } = ad_with_kernel(&input, Kernel::Scalar).unwrap();
1720
1721                prop_assert_eq!(out.len(), highs.len(), "Output length mismatch");
1722
1723                for (i, &val) in out.iter().enumerate() {
1724                    prop_assert!(
1725                        !val.is_nan(),
1726                        "Unexpected NaN at index {}: AD should not have NaN values",
1727                        i
1728                    );
1729                }
1730
1731                for i in 0..out.len() {
1732                    let y = out[i];
1733                    let r = ref_out[i];
1734
1735                    let y_bits = y.to_bits();
1736                    let r_bits = r.to_bits();
1737
1738                    if !y.is_finite() || !r.is_finite() {
1739                        prop_assert_eq!(
1740                            y_bits,
1741                            r_bits,
1742                            "Special value mismatch at idx {}: {} vs {}",
1743                            i,
1744                            y,
1745                            r
1746                        );
1747                    } else {
1748                        let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1749                        prop_assert!(
1750                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1751                            "Value mismatch at idx {}: {} vs {} (ULP={})",
1752                            i,
1753                            y,
1754                            r,
1755                            ulp_diff
1756                        );
1757                    }
1758                }
1759
1760                for i in 1..volumes.len() {
1761                    if volumes[i] == 0.0 {
1762                        prop_assert!(
1763                            (out[i] - out[i - 1]).abs() < 1e-10,
1764                            "AD should not change when volume is 0 at index {}",
1765                            i
1766                        );
1767                    }
1768                }
1769
1770                for i in 0..highs.len() {
1771                    if (highs[i] - lows[i]).abs() < 1e-10 {
1772                        if i == 0 {
1773                            prop_assert!(
1774                                out[i].abs() < 1e-10,
1775                                "When high=low, first AD value should be 0, got {}",
1776                                out[i]
1777                            );
1778                        } else {
1779                            prop_assert!(
1780                                (out[i] - out[i - 1]).abs() < 1e-10,
1781                                "When high=low at index {}, AD should remain unchanged",
1782                                i
1783                            );
1784                        }
1785                    }
1786                }
1787
1788                let mut expected_ad = 0.0;
1789                for i in 0..highs.len() {
1790                    let hl = highs[i] - lows[i];
1791                    if hl != 0.0 {
1792                        let mfm = ((closes[i] - lows[i]) - (highs[i] - closes[i])) / hl;
1793                        let mfv = mfm * volumes[i];
1794                        expected_ad += mfv;
1795                    }
1796                    prop_assert!(
1797                        (out[i] - expected_ad).abs() < 1e-9,
1798                        "Cumulative property violation at index {}: expected {}, got {}",
1799                        i,
1800                        expected_ad,
1801                        out[i]
1802                    );
1803                }
1804
1805                if !highs.is_empty() {
1806                    let hl = highs[0] - lows[0];
1807                    let expected_first = if hl != 0.0 {
1808                        ((closes[0] - lows[0]) - (highs[0] - closes[0])) / hl * volumes[0]
1809                    } else {
1810                        0.0
1811                    };
1812                    prop_assert!(
1813                        (out[0] - expected_first).abs() < 1e-10,
1814                        "First value mismatch: expected {}, got {}",
1815                        expected_first,
1816                        out[0]
1817                    );
1818                }
1819
1820                for i in 0..highs.len() {
1821                    prop_assert!(
1822                        lows[i] <= closes[i] + 1e-10 && closes[i] <= highs[i] + 1e-10,
1823                        "Price constraint violation at index {}: low={}, close={}, high={}",
1824                        i,
1825                        lows[i],
1826                        closes[i],
1827                        highs[i]
1828                    );
1829                }
1830
1831                let all_equal = highs
1832                    .iter()
1833                    .zip(lows.iter())
1834                    .zip(closes.iter())
1835                    .all(|((&h, &l), &c)| (h - l).abs() < 1e-10 && (l - c).abs() < 1e-10);
1836
1837                if all_equal {
1838                    for (i, &val) in out.iter().enumerate() {
1839                        prop_assert!(
1840                            val.abs() < 1e-10,
1841                            "When all prices are equal, AD should be 0 at index {}, got {}",
1842                            i,
1843                            val
1844                        );
1845                    }
1846                }
1847
1848                Ok(())
1849            })
1850            .unwrap();
1851
1852        Ok(())
1853    }
1854
    // Register the single-series AD checks with `generate_all_ad_tests!`. That macro is
    // defined outside this excerpt; presumably it expands each check into one `#[test]`
    // per kernel variant (TODO confirm against its definition).
    generate_all_ad_tests!(
        check_ad_partial_params,
        check_ad_accuracy,
        check_ad_input_with_default_candles,
        check_ad_with_slice_data_reinput,
        check_ad_accuracy_nan_check,
        check_ad_streaming,
        check_ad_no_poison
    );

    // The property-based check is only compiled when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_ad_tests!(check_ad_property);
1867
1868    fn check_batch_single_row(
1869        test: &str,
1870        kernel: Kernel,
1871    ) -> Result<(), Box<dyn std::error::Error>> {
1872        skip_if_unsupported!(kernel, test);
1873        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1874        let candles = read_candles_from_csv(file_path)?;
1875
1876        let highs: Vec<&[f64]> = vec![&candles.high];
1877        let lows: Vec<&[f64]> = vec![&candles.low];
1878        let closes: Vec<&[f64]> = vec![&candles.close];
1879        let volumes: Vec<&[f64]> = vec![&candles.volume];
1880
1881        let single = ad_with_kernel(
1882            &AdInput::from_candles(&candles, AdParams::default()),
1883            kernel,
1884        )?
1885        .values;
1886
1887        let batch = AdBatchBuilder::new()
1888            .kernel(kernel)
1889            .apply_slices(&highs, &lows, &closes, &volumes)?;
1890
1891        assert_eq!(batch.rows, 1);
1892        assert_eq!(batch.cols, candles.close.len());
1893        assert_eq!(batch.values.len(), candles.close.len());
1894
1895        for (i, (a, b)) in single.iter().zip(&batch.values).enumerate() {
1896            assert!(
1897                (a - b).abs() < 1e-8,
1898                "[{}] AD batch single row mismatch at {}: {} vs {}",
1899                test,
1900                i,
1901                a,
1902                b
1903            );
1904        }
1905        Ok(())
1906    }
1907
1908    fn check_batch_multi_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1909        skip_if_unsupported!(kernel, test);
1910        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1911        let candles = read_candles_from_csv(file_path)?;
1912
1913        let highs: Vec<&[f64]> = vec![&candles.high, &candles.high, &candles.high];
1914        let lows: Vec<&[f64]> = vec![&candles.low, &candles.low, &candles.low];
1915        let closes: Vec<&[f64]> = vec![&candles.close, &candles.close, &candles.close];
1916        let volumes: Vec<&[f64]> = vec![&candles.volume, &candles.volume, &candles.volume];
1917
1918        let single = ad_with_kernel(
1919            &AdInput::from_candles(&candles, AdParams::default()),
1920            kernel,
1921        )?
1922        .values;
1923
1924        let batch = AdBatchBuilder::new()
1925            .kernel(kernel)
1926            .apply_slices(&highs, &lows, &closes, &volumes)?;
1927
1928        assert_eq!(batch.rows, 3);
1929        assert_eq!(batch.cols, candles.close.len());
1930        assert_eq!(batch.values.len(), 3 * candles.close.len());
1931
1932        for row in 0..3 {
1933            let row_slice = &batch.values[row * batch.cols..(row + 1) * batch.cols];
1934            for (i, (a, b)) in single.iter().zip(row_slice.iter()).enumerate() {
1935                assert!(
1936                    (a - b).abs() < 1e-8,
1937                    "[{}] AD batch multi row mismatch row {} idx {}: {} vs {}",
1938                    test,
1939                    row,
1940                    i,
1941                    a,
1942                    b
1943                );
1944            }
1945        }
1946        Ok(())
1947    }
1948
1949    macro_rules! gen_batch_tests {
1950        ($fn_name:ident) => {
1951            paste::paste! {
1952                #[test] fn [<$fn_name _scalar>]()      {
1953                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1954                }
1955                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1956                #[test] fn [<$fn_name _avx2>]()        {
1957                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1958                }
1959                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1960                #[test] fn [<$fn_name _avx512>]()      {
1961                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1962                }
1963                #[test] fn [<$fn_name _auto_detect>]() {
1964                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1965                }
1966            }
1967        };
1968    }
1969
1970    #[cfg(debug_assertions)]
1971    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1972        skip_if_unsupported!(kernel, test);
1973
1974        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1975        let c = read_candles_from_csv(file)?;
1976
1977        let mut highs: Vec<&[f64]> = vec![];
1978        let mut lows: Vec<&[f64]> = vec![];
1979        let mut closes: Vec<&[f64]> = vec![];
1980        let mut volumes: Vec<&[f64]> = vec![];
1981
1982        highs.push(&c.high);
1983        lows.push(&c.low);
1984        closes.push(&c.close);
1985        volumes.push(&c.volume);
1986
1987        let high_rev: Vec<f64> = c.high.iter().rev().copied().collect();
1988        let low_rev: Vec<f64> = c.low.iter().rev().copied().collect();
1989        let close_rev: Vec<f64> = c.close.iter().rev().copied().collect();
1990        let volume_rev: Vec<f64> = c.volume.iter().rev().copied().collect();
1991
1992        highs.push(&high_rev);
1993        lows.push(&low_rev);
1994        closes.push(&close_rev);
1995        volumes.push(&volume_rev);
1996
1997        if c.high.len() > 100 {
1998            highs.push(&c.high[50..]);
1999            lows.push(&c.low[50..]);
2000            closes.push(&c.close[50..]);
2001            volumes.push(&c.volume[50..]);
2002        }
2003
2004        let batch = AdBatchBuilder::new()
2005            .kernel(kernel)
2006            .apply_slices(&highs, &lows, &closes, &volumes)?;
2007
2008        for (idx, &val) in batch.values.iter().enumerate() {
2009            if val.is_nan() {
2010                continue;
2011            }
2012
2013            let bits = val.to_bits();
2014            let row = idx / batch.cols;
2015            let col = idx % batch.cols;
2016
2017            if bits == 0x11111111_11111111 {
2018                panic!(
2019					"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2020					test, val, bits, row, col, idx
2021				);
2022            }
2023
2024            if bits == 0x22222222_22222222 {
2025                panic!(
2026					"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2027					test, val, bits, row, col, idx
2028				);
2029            }
2030
2031            if bits == 0x33333333_33333333 {
2032                panic!(
2033					"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
2034					test, val, bits, row, col, idx
2035				);
2036            }
2037        }
2038
2039        Ok(())
2040    }
2041
2042    #[cfg(not(debug_assertions))]
2043    fn check_batch_no_poison(
2044        _test: &str,
2045        _kernel: Kernel,
2046    ) -> Result<(), Box<dyn std::error::Error>> {
2047        Ok(())
2048    }
2049
    // Instantiate each batch check once per kernel variant listed in `gen_batch_tests!`
    // (scalar, AVX2, AVX-512 and auto-detect).
    gen_batch_tests!(check_batch_single_row);
    gen_batch_tests!(check_batch_multi_row);
    gen_batch_tests!(check_batch_no_poison);
2053
2054    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2055    #[test]
2056    fn test_ad_into_matches_api() {
2057        let n = 256usize;
2058        let mut ts = Vec::with_capacity(n);
2059        let mut open = Vec::with_capacity(n);
2060        let mut high = Vec::with_capacity(n);
2061        let mut low = Vec::with_capacity(n);
2062        let mut close = Vec::with_capacity(n);
2063        let mut volume = Vec::with_capacity(n);
2064
2065        for i in 0..n {
2066            let i_f = i as f64;
2067            ts.push(i as i64);
2068            let o = 100.0 + (i % 13) as f64 * 0.75;
2069            let l = o - 2.0;
2070            let h = o + 2.0 + ((i % 3) as f64) * 0.1;
2071            let c = l + ((i % 5) as f64) * 0.5;
2072            let v = 1000.0 + 10.0 * i_f;
2073            open.push(o);
2074            low.push(l);
2075            high.push(h);
2076            close.push(c);
2077            volume.push(v);
2078        }
2079
2080        let candles = Candles::new(
2081            ts,
2082            open,
2083            high.clone(),
2084            low.clone(),
2085            close.clone(),
2086            volume.clone(),
2087        );
2088        let input = AdInput::with_default_candles(&candles);
2089
2090        let baseline = ad(&input).expect("ad() should succeed").values;
2091
2092        let mut out = vec![0.0; baseline.len()];
2093        ad_into(&input, &mut out).expect("ad_into() should succeed");
2094
2095        assert_eq!(out.len(), baseline.len());
2096
2097        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2098            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
2099        }
2100
2101        for (i, (a, b)) in out
2102            .iter()
2103            .copied()
2104            .zip(baseline.iter().copied())
2105            .enumerate()
2106        {
2107            assert!(
2108                eq_or_both_nan(a, b),
2109                "ad_into parity failed at index {}: {} vs {}",
2110                i,
2111                a,
2112                b
2113            );
2114        }
2115    }
2116}