Skip to main content

vector_ta/indicators/
wad.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::CudaWad;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use serde::{Deserialize, Serialize};
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use wasm_bindgen::prelude::*;
19
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24    make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33use std::error::Error;
34use std::mem::ManuallyDrop;
35use thiserror::Error;
36
/// Input source for a WAD computation.
///
/// Either a borrowed candle set (the `"high"`, `"low"` and `"close"` series are looked up
/// through `source_type` when the indicator runs) or three caller-supplied price slices.
#[derive(Debug, Clone)]
pub enum WadData<'a> {
    /// Borrowed candle data.
    Candles {
        candles: &'a Candles,
    },
    /// Borrowed high / low / close series.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
48
/// Result of a single-series WAD computation.
#[derive(Debug, Clone)]
pub struct WadOutput {
    /// Running WAD total, one entry per input bar (`wad_with_kernel` allocates it with the
    /// input length).
    pub values: Vec<f64>,
}
53
/// Parameter set for WAD. A unit struct: it carries no fields, so there is nothing to tune.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WadParams;
56
/// Bundles the data source and the (empty) parameter set consumed by the WAD entry points.
#[derive(Debug, Clone)]
pub struct WadInput<'a> {
    /// Where the high / low / close series come from.
    pub data: WadData<'a>,
    /// Indicator parameters (a unit struct).
    pub params: WadParams,
}
62
63impl<'a> WadInput<'a> {
64    #[inline]
65    pub fn from_candles(candles: &'a Candles) -> Self {
66        Self {
67            data: WadData::Candles { candles },
68            params: WadParams::default(),
69        }
70    }
71    #[inline]
72    pub fn from_slices(high: &'a [f64], low: &'a [f64], close: &'a [f64]) -> Self {
73        Self {
74            data: WadData::Slices { high, low, close },
75            params: WadParams::default(),
76        }
77    }
78    #[inline]
79    pub fn with_default_candles(candles: &'a Candles) -> Self {
80        Self::from_candles(candles)
81    }
82}
83
/// Builder for one-shot WAD computations; the only configurable item is the kernel.
#[derive(Copy, Clone, Debug, Default)]
pub struct WadBuilder {
    // Kernel handed to `wad_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
88impl WadBuilder {
89    #[inline(always)]
90    pub fn new() -> Self {
91        Self::default()
92    }
93    #[inline(always)]
94    pub fn kernel(mut self, k: Kernel) -> Self {
95        self.kernel = k;
96        self
97    }
98    #[inline(always)]
99    pub fn apply(self, candles: &Candles) -> Result<WadOutput, WadError> {
100        let i = WadInput::from_candles(candles);
101        wad_with_kernel(&i, self.kernel)
102    }
103    #[inline(always)]
104    pub fn apply_slices(
105        self,
106        high: &[f64],
107        low: &[f64],
108        close: &[f64],
109    ) -> Result<WadOutput, WadError> {
110        let i = WadInput::from_slices(high, low, close);
111        wad_with_kernel(&i, self.kernel)
112    }
113    #[inline(always)]
114    pub fn into_stream(self) -> Result<WadStream, WadError> {
115        WadStream::try_new()
116    }
117}
118
/// Errors returned by the WAD entry points.
#[derive(Debug, Error)]
pub enum WadError {
    /// At least one of the high / low / close inputs is empty.
    #[error("wad: Empty input data.")]
    EmptyInputData,
    /// At least one of the inputs consists solely of NaN values.
    #[error("wad: All values are NaN.")]
    AllValuesNaN,
    /// Not constructed by the code reviewed here; presumably kept for parity with other
    /// indicators — confirm.
    #[error("wad: Invalid period: period = {period}, data length = {data_len}.")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Not constructed by the code reviewed here — confirm whether it is still needed.
    #[error("wad: Not enough valid data: needed = {needed}, valid = {valid}.")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The input series differ in length, or an output buffer's length differs from the
    /// input length. `expected` is the length of `high`.
    #[error("wad: Empty or mismatched lengths: expected = {expected}, got = {got}.")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Not constructed by the code reviewed here — confirm whether it is still needed.
    #[error("wad: Invalid range: start={start}, end={end}, step={step}.")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A kernel that is neither `Auto` nor a batch kernel was passed to
    /// `wad_batch_with_kernel`.
    #[error("wad: Invalid kernel for batch: {0:?}.")]
    InvalidKernelForBatch(Kernel),
    /// Free-form input error; not constructed by the code reviewed here.
    #[error("wad: Invalid input: {msg}.")]
    InvalidInput { msg: String },
}
142
/// Computes WAD for `input`, letting the library choose the kernel (`Kernel::Auto`).
///
/// # Errors
/// Propagates every error of [`wad_with_kernel`].
#[inline]
pub fn wad(input: &WadInput) -> Result<WadOutput, WadError> {
    wad_with_kernel(input, Kernel::Auto)
}
147
148pub fn wad_with_kernel(input: &WadInput, kernel: Kernel) -> Result<WadOutput, WadError> {
149    let (high, low, close): (&[f64], &[f64], &[f64]) = match &input.data {
150        WadData::Candles { candles } => (
151            source_type(candles, "high"),
152            source_type(candles, "low"),
153            source_type(candles, "close"),
154        ),
155        WadData::Slices { high, low, close } => (*high, *low, *close),
156    };
157    if high.is_empty() || low.is_empty() || close.is_empty() {
158        return Err(WadError::EmptyInputData);
159    }
160    let len = high.len();
161    if len != low.len() || len != close.len() {
162        let got = if low.len() != len {
163            low.len()
164        } else {
165            close.len()
166        };
167        return Err(WadError::OutputLengthMismatch { expected: len, got });
168    }
169    if high.iter().all(|x| x.is_nan())
170        || low.iter().all(|x| x.is_nan())
171        || close.iter().all(|x| x.is_nan())
172    {
173        return Err(WadError::AllValuesNaN);
174    }
175    let chosen = match kernel {
176        Kernel::Auto => detect_best_kernel(),
177        other => other,
178    };
179    let mut out = alloc_with_nan_prefix(len, 0);
180    unsafe {
181        match chosen {
182            Kernel::Scalar | Kernel::ScalarBatch => wad_scalar(high, low, close, &mut out),
183            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
184            Kernel::Avx2 | Kernel::Avx2Batch => wad_avx2(high, low, close, &mut out),
185            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
186            Kernel::Avx512 | Kernel::Avx512Batch => wad_avx512(high, low, close, &mut out),
187            _ => unreachable!(),
188        }
189    }
190    Ok(WadOutput { values: out })
191}
192
/// Portable WAD kernel: writes the running accumulation/distribution total into `out`.
///
/// `out[0]` is set to `0.0`. For every later bar the previous close is compared with the
/// current close:
/// * close rose  -> add `close - min(prev_close, low)`
/// * close fell  -> add `close - max(prev_close, high)`
/// * unchanged   -> add `0`
///
/// The number of bars processed is `close.len()`; nothing is written for an empty `close`.
///
/// # Panics
/// Panics if `high`, `low` or `out` is shorter than `close`.
#[inline(always)]
pub fn wad_scalar(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    let mut prev_close = match close {
        [] => return,
        [first, ..] => *first,
    };

    out[0] = 0.0;
    let mut running = 0.0f64;

    for idx in 1..close.len() {
        let bar_high = high[idx];
        let bar_low = low[idx];
        let bar_close = close[idx];

        // "True" range bounds that also take the previous close into account.
        let true_high = prev_close.max(bar_high);
        let true_low = prev_close.min(bar_low);

        // Branch-free selectors: exactly one is 1.0 on a move, both are 0.0 when unchanged.
        let rose = if bar_close > prev_close { 1.0f64 } else { 0.0f64 };
        let fell = if bar_close < prev_close { 1.0f64 } else { 0.0f64 };

        running += rose.mul_add(bar_close - true_low, fell * (bar_close - true_high));
        out[idx] = running;
        prev_close = bar_close;
    }
}
220
/// WAD kernel compiled with the `avx2` and `fma` target features enabled.
///
/// The body uses no explicit vector intrinsics (only `_mm_prefetch`): it is the scalar
/// recurrence unrolled eight bars per iteration, followed by a one-bar-at-a-time tail.
/// The true high / low are selected with plain `>` / `<` comparisons, so a NaN `high` or `low`
/// is selected as-is (the comparison is false), whereas `f64::max` / `f64::min` in
/// `wad_scalar` return the non-NaN operand.
///
/// # Safety
/// * `high`, `low` and `out` must each hold at least `close.len()` elements; all accesses are
///   unchecked.
/// * The CPU must support the target features enabled on the inner function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx2(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // Processes the bar at `idx + $off` against the previous close `$prev`, stores the
        // updated running total and evaluates to that bar's close.
        macro_rules! bar {
            ($off:expr, $prev:expr) => {{
                let c = *close_ptr.add(idx + $off);
                let h = *high_ptr.add(idx + $off);
                let l = *low_ptr.add(idx + $off);
                let trh = if $prev > h { $prev } else { h };
                let trl = if $prev < l { $prev } else { l };
                let up = (c > $prev) as i32 as f64;
                let down = (c < $prev) as i32 as f64;
                total += up.mul_add(c - trl, down * (c - trh));
                *out_ptr.add(idx + $off) = total;
                c
            }};
        }

        // Unrolled main loop: eight bars per iteration.
        while idx + 7 < len {
            // Prefetch 32 elements ahead, only while that stays well inside the inputs.
            if idx + 40 < len {
                _mm_prefetch(close_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
                _mm_prefetch(high_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
                _mm_prefetch(low_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
            }

            let c0 = bar!(0, prev);
            let c1 = bar!(1, c0);
            let c2 = bar!(2, c1);
            let c3 = bar!(3, c2);
            let c4 = bar!(4, c3);
            let c5 = bar!(5, c4);
            let c6 = bar!(6, c5);
            prev = bar!(7, c6);

            idx += 8;
        }

        // Remaining bars, one at a time.
        while idx < len {
            prev = bar!(0, prev);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
360
/// AVX-512 WAD entry point: dispatches on input length.
///
/// Inputs of at most 64 bars (`high.len() <= 64`) go to [`wad_avx512_short`], longer ones to
/// [`wad_avx512_long`].
///
/// # Safety
/// Same contract as the two kernels it forwards to: `high`, `low` and `out` must hold at
/// least `close.len()` elements, and the CPU must support the target features those kernels
/// enable.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    if high.len() <= 64 {
        wad_avx512_short(high, low, close, out);
    } else {
        wad_avx512_long(high, low, close, out);
    }
}
370
/// WAD kernel for short inputs, compiled with the `avx512f` and `fma` target features.
///
/// The scalar recurrence is unrolled eight bars per iteration with no prefetching, followed
/// by a one-bar-at-a-time tail. No explicit vector intrinsics are used.
///
/// # Safety
/// * `high`, `low` and `out` must each hold at least `close.len()` elements; all accesses are
///   unchecked.
/// * The CPU must support the target features enabled on the inner function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512_short(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[target_feature(enable = "avx512f,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // Processes the bar at `idx + $off` against the previous close `$prev`, stores the
        // updated running total and evaluates to that bar's close.
        macro_rules! bar {
            ($off:expr, $prev:expr) => {{
                let c = *close_ptr.add(idx + $off);
                let h = *high_ptr.add(idx + $off);
                let l = *low_ptr.add(idx + $off);
                let trh = if $prev > h { $prev } else { h };
                let trl = if $prev < l { $prev } else { l };
                let up = (c > $prev) as i32 as f64;
                let down = (c < $prev) as i32 as f64;
                total += up.mul_add(c - trl, down * (c - trh));
                *out_ptr.add(idx + $off) = total;
                c
            }};
        }

        // Unrolled main loop: eight bars per iteration.
        while idx + 7 < len {
            let c0 = bar!(0, prev);
            let c1 = bar!(1, c0);
            let c2 = bar!(2, c1);
            let c3 = bar!(3, c2);
            let c4 = bar!(4, c3);
            let c5 = bar!(5, c4);
            let c6 = bar!(6, c5);
            prev = bar!(7, c6);

            idx += 8;
        }

        // Remaining bars, one at a time.
        while idx < len {
            prev = bar!(0, prev);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
/// WAD kernel for long inputs, compiled with the `avx512f` and `fma` target features.
///
/// The scalar recurrence is unrolled sixteen bars per iteration with software prefetching,
/// followed by a one-bar-at-a-time tail. No explicit vector intrinsics are used apart from
/// `_mm_prefetch`.
///
/// # Safety
/// * `high`, `low` and `out` must each hold at least `close.len()` elements; all accesses are
///   unchecked.
/// * The CPU must support the target features enabled on the inner function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512_long(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[target_feature(enable = "avx512f,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // Processes the bar at `idx + $off` against the previous close `$prev`, stores the
        // updated running total and evaluates to that bar's close.
        macro_rules! bar {
            ($off:expr, $prev:expr) => {{
                let c = *close_ptr.add(idx + $off);
                let h = *high_ptr.add(idx + $off);
                let l = *low_ptr.add(idx + $off);
                let trh = if $prev > h { $prev } else { h };
                let trl = if $prev < l { $prev } else { l };
                let up = (c > $prev) as i32 as f64;
                let down = (c < $prev) as i32 as f64;
                total += up.mul_add(c - trl, down * (c - trh));
                *out_ptr.add(idx + $off) = total;
                c
            }};
        }

        // Unrolled main loop: sixteen bars per iteration.
        while idx + 15 < len {
            // Prefetch 64 elements ahead, only while that stays well inside the inputs.
            if idx + 96 < len {
                _mm_prefetch(close_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
                _mm_prefetch(high_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
                _mm_prefetch(low_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
            }

            let c0 = bar!(0, prev);
            let c1 = bar!(1, c0);
            let c2 = bar!(2, c1);
            let c3 = bar!(3, c2);
            let c4 = bar!(4, c3);
            let c5 = bar!(5, c4);
            let c6 = bar!(6, c5);
            let c7 = bar!(7, c6);
            let c8 = bar!(8, c7);
            let c9 = bar!(9, c8);
            let c10 = bar!(10, c9);
            let c11 = bar!(11, c10);
            let c12 = bar!(12, c11);
            let c13 = bar!(13, c12);
            let c14 = bar!(14, c13);
            prev = bar!(15, c14);

            idx += 16;
        }

        // Remaining bars, one at a time.
        while idx < len {
            prev = bar!(0, prev);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
583
/// Row-kernel entry point for the scalar path; forwards to [`wad_scalar`] unchanged.
///
/// # Safety
/// Only calls the safe `wad_scalar`, so no extra invariant is required by this body.
/// Presumably declared `unsafe` to share a calling convention with the AVX row kernels —
/// confirm.
#[inline(always)]
pub unsafe fn wad_row_scalar(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_scalar(high, low, close, out)
}
/// Row-kernel entry point for the AVX2 path; forwards to [`wad_avx2`] unchanged.
///
/// # Safety
/// Same contract as `wad_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx2(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx2(high, low, close, out)
}
/// Row-kernel entry point for the AVX-512 path; forwards to [`wad_avx512`] unchanged.
///
/// # Safety
/// Same contract as `wad_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512(high, low, close, out)
}
/// Row-kernel entry point that forwards to [`wad_avx512_short`] unchanged.
///
/// # Safety
/// Same contract as `wad_avx512_short`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512_short(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512_short(high, low, close, out)
}
/// Row-kernel entry point that forwards to [`wad_avx512_long`] unchanged.
///
/// # Safety
/// Same contract as `wad_avx512_long`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512_long(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512_long(high, low, close, out)
}
608
/// Incremental WAD calculator fed one bar at a time through `update`.
#[derive(Debug, Clone)]
pub struct WadStream {
    // Running WAD total returned by `update`.
    sum: f64,
    // Close of the previous bar; `None` until the first bar has been seen.
    prev_close: Option<f64>,
}
614impl WadStream {
615    pub fn try_new() -> Result<Self, WadError> {
616        Ok(Self {
617            sum: 0.0,
618            prev_close: None,
619        })
620    }
621    #[inline(always)]
622    pub fn update(&mut self, high: f64, low: f64, close: f64) -> f64 {
623        let pc = match self.prev_close {
624            Some(pc) => pc,
625            None => {
626                self.prev_close = Some(close);
627                return self.sum;
628            }
629        };
630
631        let trh = pc.max(high);
632        let trl = pc.min(low);
633
634        let gt = (close > pc) as i32 as f64;
635        let lt = (close < pc) as i32 as f64;
636        let ad = gt.mul_add(close - trl, lt * (close - trh));
637
638        self.sum += ad;
639        self.prev_close = Some(close);
640        self.sum
641    }
642}
643
/// Parameter range for batch WAD runs.
///
/// WAD has no parameters, so the range only carries a placeholder; `expand_grid` ignores it.
/// The default value is `(0, 0, 0)`, now produced by the derived `Default` instead of a
/// hand-written impl.
#[derive(Clone, Debug, Default)]
pub struct WadBatchRange {
    /// Placeholder triple (presumably mirroring the `(start, end, step)` ranges of other
    /// indicators — confirm); not read by the code reviewed here.
    pub dummy: (usize, usize, usize),
}
/// Builder for batch WAD runs.
#[derive(Clone, Debug, Default)]
pub struct WadBatchBuilder {
    // Placeholder parameter range; not read by the builder methods reviewed here.
    range: WadBatchRange,
    // Kernel handed to `wad_batch_with_kernel`.
    kernel: Kernel,
}
658impl WadBatchBuilder {
659    pub fn new() -> Self {
660        Self::default()
661    }
662    pub fn kernel(mut self, k: Kernel) -> Self {
663        self.kernel = k;
664        self
665    }
666    pub fn apply_slices(
667        self,
668        high: &[f64],
669        low: &[f64],
670        close: &[f64],
671    ) -> Result<WadBatchOutput, WadError> {
672        wad_batch_with_kernel(high, low, close, self.kernel)
673    }
674    pub fn with_default_slices(
675        high: &[f64],
676        low: &[f64],
677        close: &[f64],
678        k: Kernel,
679    ) -> Result<WadBatchOutput, WadError> {
680        WadBatchBuilder::new()
681            .kernel(k)
682            .apply_slices(high, low, close)
683    }
684    pub fn apply_candles(self, c: &Candles) -> Result<WadBatchOutput, WadError> {
685        let high = source_type(c, "high");
686        let low = source_type(c, "low");
687        let close = source_type(c, "close");
688        self.apply_slices(high, low, close)
689    }
690    pub fn with_default_candles(c: &Candles) -> Result<WadBatchOutput, WadError> {
691        WadBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
692    }
693}
694
695pub fn wad_batch_with_kernel(
696    high: &[f64],
697    low: &[f64],
698    close: &[f64],
699    k: Kernel,
700) -> Result<WadBatchOutput, WadError> {
701    let kernel = match k {
702        Kernel::Auto => detect_best_batch_kernel(),
703        other if other.is_batch() => other,
704        other => return Err(WadError::InvalidKernelForBatch(other)),
705    };
706    wad_batch_par_slice(high, low, close, kernel)
707}
708
/// Result of a batch WAD run, stored as a flat `rows x cols` matrix.
///
/// `wad_batch_inner` always produces a single row (`rows == 1`) with `cols` equal to the
/// input length.
#[derive(Clone, Debug)]
pub struct WadBatchOutput {
    /// Flat matrix contents.
    pub values: Vec<f64>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
715impl WadBatchOutput {
716    pub fn row_for_params(&self, _: &WadParams) -> Option<usize> {
717        Some(0)
718    }
719    pub fn values_for(&self, _: &WadParams) -> Option<&[f64]> {
720        Some(&self.values)
721    }
722}
723
724#[inline(always)]
725pub fn expand_grid(_r: &WadBatchRange) -> Vec<WadParams> {
726    let mut result = Vec::with_capacity(1);
727    result.push(WadParams);
728    result
729}
730
/// Batch computation over three series; calls `wad_batch_inner` with the parallel flag set
/// to `false`.
///
/// `kern` is passed through as given, without the batch-kernel check that
/// `wad_batch_with_kernel` performs.
///
/// # Errors
/// Propagates every error of `wad_batch_inner`.
#[inline(always)]
pub fn wad_batch_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    kern: Kernel,
) -> Result<WadBatchOutput, WadError> {
    wad_batch_inner(high, low, close, kern, false)
}
/// Batch computation over three series; calls `wad_batch_inner` with the parallel flag set
/// to `true`.
///
/// `wad_batch_inner` names that flag `_parallel` and does not branch on it, so the result is
/// the same as for `wad_batch_slice`.
///
/// # Errors
/// Propagates every error of `wad_batch_inner`.
#[inline(always)]
pub fn wad_batch_par_slice(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    kern: Kernel,
) -> Result<WadBatchOutput, WadError> {
    wad_batch_inner(high, low, close, kern, true)
}
749
750#[inline(always)]
751fn wad_batch_inner(
752    high: &[f64],
753    low: &[f64],
754    close: &[f64],
755    kern: Kernel,
756    _parallel: bool,
757) -> Result<WadBatchOutput, WadError> {
758    if high.is_empty() || low.is_empty() || close.is_empty() {
759        return Err(WadError::EmptyInputData);
760    }
761    let len = high.len();
762    if len != low.len() || len != close.len() {
763        let got = if low.len() != len {
764            low.len()
765        } else {
766            close.len()
767        };
768        return Err(WadError::OutputLengthMismatch { expected: len, got });
769    }
770    if high.iter().all(|x| x.is_nan())
771        || low.iter().all(|x| x.is_nan())
772        || close.iter().all(|x| x.is_nan())
773    {
774        return Err(WadError::AllValuesNaN);
775    }
776
777    let mut buf_mu = make_uninit_matrix(1, len);
778    init_matrix_prefixes(&mut buf_mu, len, &[0]);
779
780    let mut guard = ManuallyDrop::new(buf_mu);
781    let out: &mut [f64] =
782        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
783
784    wad_batch_inner_into(high, low, close, kern, false, out)?;
785
786    let values = unsafe {
787        Vec::from_raw_parts(
788            guard.as_mut_ptr() as *mut f64,
789            guard.len(),
790            guard.capacity(),
791        )
792    };
793
794    Ok(WadBatchOutput {
795        values,
796        rows: 1,
797        cols: len,
798    })
799}
800
801#[inline(always)]
802fn wad_batch_inner_into(
803    high: &[f64],
804    low: &[f64],
805    close: &[f64],
806    kern: Kernel,
807    _parallel: bool,
808    out: &mut [f64],
809) -> Result<(), WadError> {
810    if high.is_empty() || low.is_empty() || close.is_empty() {
811        return Err(WadError::EmptyInputData);
812    }
813    let len = high.len();
814    if len != low.len() || len != close.len() {
815        let got = if low.len() != len {
816            low.len()
817        } else {
818            close.len()
819        };
820        return Err(WadError::OutputLengthMismatch { expected: len, got });
821    }
822    if high.iter().all(|x| x.is_nan())
823        || low.iter().all(|x| x.is_nan())
824        || close.iter().all(|x| x.is_nan())
825    {
826        return Err(WadError::AllValuesNaN);
827    }
828    if out.len() != len {
829        return Err(WadError::OutputLengthMismatch {
830            expected: len,
831            got: out.len(),
832        });
833    }
834
835    let actual = match kern {
836        Kernel::Auto => detect_best_batch_kernel(),
837        k => k,
838    };
839    unsafe {
840        match actual {
841            Kernel::Scalar | Kernel::ScalarBatch => wad_row_scalar(high, low, close, out),
842            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
843            Kernel::Avx2 | Kernel::Avx2Batch => wad_row_avx2(high, low, close, out),
844            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
845            Kernel::Avx512 | Kernel::Avx512Batch => wad_row_avx512(high, low, close, out),
846            _ => unreachable!(),
847        }
848    }
849    Ok(())
850}
851
852#[cfg(test)]
853mod tests {
854    use super::*;
855    use crate::skip_if_unsupported;
856    use crate::utilities::data_loader::read_candles_from_csv;
857    use std::error::Error;
858
859    fn check_wad_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
860        skip_if_unsupported!(kernel, test_name);
861        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
862        let candles = read_candles_from_csv(file_path)?;
863        let input = WadInput::from_candles(&candles);
864        let output = wad_with_kernel(&input, kernel)?;
865        assert_eq!(output.values.len(), candles.close.len());
866        Ok(())
867    }
868
869    fn check_wad_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
870        skip_if_unsupported!(kernel, test_name);
871        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
872        let candles = read_candles_from_csv(file_path)?;
873        let input = WadInput::from_candles(&candles);
874        let output = wad_with_kernel(&input, kernel)?;
875        assert_eq!(output.values.len(), candles.close.len());
876        let expected_last_five_wad = [
877            158503.46790000016,
878            158279.46790000016,
879            158014.46790000016,
880            158186.46790000016,
881            157605.46790000016,
882        ];
883        let start = output.values.len().saturating_sub(5);
884        for (i, &val) in output.values[start..].iter().enumerate() {
885            let exp = expected_last_five_wad[i];
886            assert!(
887                (val - exp).abs() < 1e-4,
888                "[{}] WAD {:?} mismatch at idx {}: got {}, expected {}",
889                test_name,
890                kernel,
891                i,
892                val,
893                exp
894            );
895        }
896        Ok(())
897    }
898
899    fn check_wad_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
900        skip_if_unsupported!(kernel, test_name);
901        let input = WadInput::from_slices(&[], &[], &[]);
902        let result = wad_with_kernel(&input, kernel);
903        assert!(result.is_err());
904        Ok(())
905    }
906
907    fn check_wad_all_values_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
908        skip_if_unsupported!(kernel, test_name);
909        let nan_slice = [f64::NAN, f64::NAN, f64::NAN];
910        let input = WadInput::from_slices(&nan_slice, &nan_slice, &nan_slice);
911        let result = wad_with_kernel(&input, kernel);
912        assert!(result.is_err());
913        Ok(())
914    }
915
916    fn check_wad_basic_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
917        skip_if_unsupported!(kernel, test_name);
918        let high = [10.0, 11.0, 11.0, 12.0];
919        let low = [9.0, 9.0, 10.0, 10.0];
920        let close = [9.5, 10.5, 10.5, 11.5];
921        let input = WadInput::from_slices(&high, &low, &close);
922        let output = wad_with_kernel(&input, kernel)?;
923        assert_eq!(output.values.len(), 4);
924        assert!((output.values[0] - 0.0).abs() < 1e-10);
925        assert!((output.values[1] - 1.5).abs() < 1e-10);
926        assert!((output.values[2] - 1.5).abs() < 1e-10);
927        assert!((output.values[3] - 3.0).abs() < 1e-10);
928        Ok(())
929    }
930
931    fn check_wad_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
932        skip_if_unsupported!(kernel, test_name);
933        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
934        let candles = read_candles_from_csv(file_path)?;
935        let high = source_type(&candles, "high");
936        let low = source_type(&candles, "low");
937        let close = source_type(&candles, "close");
938        let batch_output =
939            wad_with_kernel(&WadInput::from_slices(high, low, close), kernel)?.values;
940        let mut stream = WadStream::try_new()?;
941        let mut stream_values = Vec::with_capacity(close.len());
942        for ((&h, &l), &c) in high.iter().zip(low).zip(close) {
943            stream_values.push(stream.update(h, l, c));
944        }
945        assert_eq!(batch_output.len(), stream_values.len());
946        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
947            let diff = (b - s).abs();
948            assert!(
949                diff < 1e-9,
950                "[{}] WAD streaming mismatch at idx {}: batch={}, stream={}, diff={}",
951                test_name,
952                i,
953                b,
954                s,
955                diff
956            );
957        }
958        Ok(())
959    }
960
961    fn check_wad_small_example(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
962        skip_if_unsupported!(kernel, test_name);
963
964        let high = [10.0, 11.0, 12.0, 11.5, 12.5];
965        let low = [9.0, 9.5, 11.0, 10.5, 11.0];
966        let close = [9.5, 10.5, 11.5, 11.0, 12.0];
967        let expected = [0.0, 1.0, 2.0, 1.5, 2.5];
968
969        let input = WadInput::from_slices(&high, &low, &close);
970        let output = wad_with_kernel(&input, kernel)?;
971
972        assert_eq!(output.values.len(), 5);
973
974        for i in 0..5 {
975            let got = output.values[i];
976            let exp = expected[i];
977            assert!(
978                (got - exp).abs() < 1e-10,
979                "[{}] WAD {:?} small example mismatch at idx {}: got {}, expected {}",
980                test_name,
981                kernel,
982                i,
983                got,
984                exp
985            );
986        }
987
988        Ok(())
989    }
990
991    #[cfg(debug_assertions)]
992    fn check_wad_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
993        skip_if_unsupported!(kernel, test_name);
994
995        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
996        let candles = read_candles_from_csv(file_path)?;
997
998        let test_configs = vec![WadParams::default()];
999
1000        for (param_idx, params) in test_configs.iter().enumerate() {
1001            let input = WadInput {
1002                data: WadData::Candles { candles: &candles },
1003                params: params.clone(),
1004            };
1005            let output = wad_with_kernel(&input, kernel)?;
1006
1007            for (i, &val) in output.values.iter().enumerate() {
1008                if val.is_nan() {
1009                    continue;
1010                }
1011
1012                let bits = val.to_bits();
1013
1014                if bits == 0x11111111_11111111 {
1015                    panic!(
1016                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1017						 with params: {:?} (param set {})",
1018                        test_name, val, bits, i, params, param_idx
1019                    );
1020                }
1021
1022                if bits == 0x22222222_22222222 {
1023                    panic!(
1024                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1025						 with params: {:?} (param set {})",
1026                        test_name, val, bits, i, params, param_idx
1027                    );
1028                }
1029
1030                if bits == 0x33333333_33333333 {
1031                    panic!(
1032                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1033						 with params: {:?} (param set {})",
1034                        test_name, val, bits, i, params, param_idx
1035                    );
1036                }
1037            }
1038        }
1039
1040        Ok(())
1041    }
1042
    /// Release-build stand-in for the poison check: performs no work and always
    /// returns `Ok(())`, so the generated tests still compile without debug assertions.
    #[cfg(not(debug_assertions))]
    fn check_wad_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1047
1048    #[cfg(debug_assertions)]
1049    fn check_batch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1050        skip_if_unsupported!(kernel, test_name);
1051
1052        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1053        let candles = read_candles_from_csv(file_path)?;
1054
1055        let test_configs = vec!["high", "low", "close"];
1056
1057        for (cfg_idx, &source) in test_configs.iter().enumerate() {
1058            let output = wad_batch_with_kernel(
1059                source_type(&candles, "high"),
1060                source_type(&candles, "low"),
1061                source_type(&candles, "close"),
1062                kernel,
1063            )?;
1064
1065            for (idx, &val) in output.values.iter().enumerate() {
1066                if val.is_nan() {
1067                    continue;
1068                }
1069
1070                let bits = val.to_bits();
1071                let row = idx / output.cols;
1072                let col = idx % output.cols;
1073
1074                if bits == 0x11111111_11111111 {
1075                    panic!(
1076                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1077						 at row {} col {} (flat index {}) - source: {}",
1078                        test_name, cfg_idx, val, bits, row, col, idx, source
1079                    );
1080                }
1081
1082                if bits == 0x22222222_22222222 {
1083                    panic!(
1084                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1085						 at row {} col {} (flat index {}) - source: {}",
1086                        test_name, cfg_idx, val, bits, row, col, idx, source
1087                    );
1088                }
1089
1090                if bits == 0x33333333_33333333 {
1091                    panic!(
1092                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1093						 at row {} col {} (flat index {}) - source: {}",
1094                        test_name, cfg_idx, val, bits, row, col, idx, source
1095                    );
1096                }
1097            }
1098        }
1099
1100        Ok(())
1101    }
1102
    /// Release-build stand-in for the batch poison check: performs no work and
    /// always returns `Ok(())`.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1107
1108    macro_rules! generate_all_wad_tests {
1109        ($($test_fn:ident),*) => {
1110            paste::paste! {
1111                $(
1112                    #[test]
1113                    fn [<$test_fn _scalar_f64>]() {
1114                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1115                    }
1116                )*
1117                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1118                $(
1119                    #[test]
1120                    fn [<$test_fn _avx2_f64>]() {
1121                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1122                    }
1123                    #[test]
1124                    fn [<$test_fn _avx512_f64>]() {
1125                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1126                    }
1127                )*
1128            }
1129        }
1130    }
1131
    // Instantiate the per-kernel test wrappers for every single-series check.
    generate_all_wad_tests!(
        check_wad_partial_params,
        check_wad_accuracy,
        check_wad_empty_data,
        check_wad_all_values_nan,
        check_wad_basic_slice,
        check_wad_streaming,
        check_wad_small_example,
        check_wad_no_poison
    );
1142
    // Generates `#[test]` wrappers for one batch check function: a scalar-batch
    // variant always, plus AVX2/AVX-512 batch variants that are each gated on
    // x86_64 with the `nightly-avx` feature.
    //
    // NOTE(review): the `Result` returned by the check function is discarded, so
    // an `Err` return (as opposed to a failed assertion) does not fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
            }
        };
    }
1160
    // Instantiate the batch-kernel test wrappers for the batch poison check.
    gen_batch_tests!(check_batch_no_poison);
1162
    /// Property test for WAD, compiled only with the `proptest` feature.
    ///
    /// For randomly generated high/low/close series it checks that:
    /// - the first output is exactly 0.0;
    /// - every output equals a running sum recomputed in the test (within 1e-9);
    /// - the requested kernel agrees with `Kernel::Scalar` (within 1e-9 or 4 ULP);
    /// - the output does not move when consecutive closes are equal;
    /// - a single-bar series yields a single 0.0, and constant closes yield all zeros;
    /// - strictly increasing / decreasing closes give a non-decreasing /
    ///   non-increasing output.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_wad_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Series of 1..=200 bars. Each bar picks a base price in [1, 1000), spans
        // +/-10% around it for low/high, and draws the close from inside that span.
        let strat = (1usize..=200).prop_flat_map(|len| {
            prop::collection::vec(
                (1.0f64..1000.0f64).prop_flat_map(|base_price| {
                    let range = base_price * 0.1;
                    let low = base_price - range;
                    let high = base_price + range;

                    (low..=high).prop_map(move |close| {
                        let actual_low = low.min(close);
                        let actual_high = high.max(close);
                        (actual_high, actual_low, close)
                    })
                }),
                len,
            )
        });

        proptest::test_runner::TestRunner::default().run(&strat, |ohlc_data| {
            // Split the (high, low, close) tuples into three parallel vectors.
            let (highs, lows, closes): (Vec<f64>, Vec<f64>, Vec<f64>) =
                ohlc_data.into_iter().map(|(h, l, c)| (h, l, c)).unzip3();

            let input = WadInput::from_slices(&highs, &lows, &closes);

            // `out` is the kernel under test; `ref_out` is the scalar reference.
            let WadOutput { values: out } = wad_with_kernel(&input, kernel).unwrap();
            let WadOutput { values: ref_out } = wad_with_kernel(&input, Kernel::Scalar).unwrap();

            prop_assert_eq!(out[0], 0.0, "First WAD value must be 0.0");
            prop_assert_eq!(ref_out[0], 0.0, "First reference WAD value must be 0.0");

            // Independently recompute the running accumulation/distribution sum.
            let mut expected_sum = 0.0;
            let mut prev_close = closes[0];

            for i in 1..closes.len() {
                // True range high: the larger of the previous close and this bar's high.
                let trh = if prev_close > highs[i] {
                    prev_close
                } else {
                    highs[i]
                };
                // True range low: the smaller of the previous close and this bar's low.
                let trl = if prev_close < lows[i] {
                    prev_close
                } else {
                    lows[i]
                };

                // Up close measures from the true low, down close from the true
                // high, and an unchanged close contributes nothing.
                let ad = if closes[i] > prev_close {
                    closes[i] - trl
                } else if closes[i] < prev_close {
                    closes[i] - trh
                } else {
                    0.0
                };

                expected_sum += ad;

                prop_assert!(
                    (out[i] - expected_sum).abs() <= 1e-9,
                    "WAD mismatch at idx {}: got {}, expected {}",
                    i,
                    out[i],
                    expected_sum
                );

                prev_close = closes[i];
            }

            // Cross-check the kernel under test against the scalar reference.
            for i in 0..out.len() {
                let y = out[i];
                let r = ref_out[i];

                // Non-finite values must match bit for bit.
                if !y.is_finite() || !r.is_finite() {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/Inf mismatch at idx {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                    continue;
                }

                // Finite values pass on either an absolute or a ULP tolerance.
                let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                prop_assert!(
                    (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                    "Kernel mismatch at idx {}: {} vs {} (diff: {}, ulp: {})",
                    i,
                    y,
                    r,
                    (y - r).abs(),
                    ulp_diff
                );
            }

            // An unchanged close must leave the running total unchanged.
            for i in 1..closes.len() {
                if (closes[i] - closes[i - 1]).abs() < f64::EPSILON {
                    let ad_change = if i == 1 {
                        out[i] - 0.0
                    } else {
                        out[i] - out[i - 1]
                    };
                    prop_assert!(
                        ad_change.abs() < 1e-9,
                        "WAD should not change when close[{}] == close[{}], but changed by {}",
                        i,
                        i - 1,
                        ad_change
                    );
                }
            }

            if closes.len() == 1 {
                prop_assert_eq!(out.len(), 1);
                prop_assert_eq!(out[0], 0.0);
            }

            // Constant closes (this also covers a single-bar series) give all zeros.
            if closes
                .windows(2)
                .all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
            {
                for i in 0..out.len() {
                    prop_assert!(
                        out[i].abs() < 1e-9,
                        "WAD should be 0 for constant prices, but got {} at index {}",
                        out[i],
                        i
                    );
                }
            }

            let strictly_increasing = closes.windows(2).all(|w| w[1] > w[0]);
            if strictly_increasing && closes.len() > 1 {
                for i in 1..out.len() {
                    prop_assert!(
							out[i] >= out[i-1] - 1e-9,
							"WAD should increase monotonically for strictly increasing prices, but {} < {} at index {}",
							out[i], out[i-1], i
						);
                }
            }

            let strictly_decreasing = closes.windows(2).all(|w| w[1] < w[0]);
            if strictly_decreasing && closes.len() > 1 {
                for i in 1..out.len() {
                    prop_assert!(
							out[i] <= out[i-1] + 1e-9,
							"WAD should decrease monotonically for strictly decreasing prices, but {} > {} at index {}",
							out[i], out[i-1], i
						);
                }
            }

            Ok(())
        })?;

        Ok(())
    }
1328
1329    trait Unzip3<A, B, C> {
1330        fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>);
1331    }
1332
1333    impl<A, B, C, I> Unzip3<A, B, C> for I
1334    where
1335        I: Iterator<Item = (A, B, C)>,
1336    {
1337        fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>) {
1338            let (mut a_vec, mut b_vec, mut c_vec) = (Vec::new(), Vec::new(), Vec::new());
1339            for (a, b, c) in self {
1340                a_vec.push(a);
1341                b_vec.push(b);
1342                c_vec.push(c);
1343            }
1344            (a_vec, b_vec, c_vec)
1345        }
1346    }
1347
    // The property-test wrappers are only generated when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_wad_tests!(check_wad_property);
1350
1351    #[test]
1352    fn test_wad_into_matches_api() -> Result<(), Box<dyn Error>> {
1353        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1354        let candles = read_candles_from_csv(file_path)?;
1355        let input = WadInput::from_candles(&candles);
1356
1357        let baseline = wad(&input)?.values;
1358
1359        let mut out = vec![0.0; baseline.len()];
1360        #[allow(unused_variables)]
1361        {
1362            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1363            {
1364                wad_into(&input, &mut out)?;
1365            }
1366            #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1367            {
1368                wad_into_slice(&mut out, &input, Kernel::Auto)?;
1369            }
1370        }
1371
1372        assert_eq!(baseline.len(), out.len());
1373
1374        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1375            (a.is_nan() && b.is_nan()) || (a == b)
1376        }
1377
1378        for i in 0..baseline.len() {
1379            assert!(
1380                eq_or_both_nan(baseline[i], out[i]),
1381                "Mismatch at index {}: baseline={}, into={}",
1382                i,
1383                baseline[i],
1384                out[i]
1385            );
1386        }
1387
1388        Ok(())
1389    }
1390}
1391
1392#[inline(always)]
1393fn wad_prepare<'a>(
1394    input: &'a WadInput,
1395    _kernel: Kernel,
1396) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, Kernel), WadError> {
1397    let (high, low, close): (&[f64], &[f64], &[f64]) = match &input.data {
1398        WadData::Candles { candles } => (
1399            source_type(candles, "high"),
1400            source_type(candles, "low"),
1401            source_type(candles, "close"),
1402        ),
1403        WadData::Slices { high, low, close } => (*high, *low, *close),
1404    };
1405
1406    if high.is_empty() || low.is_empty() || close.is_empty() {
1407        return Err(WadError::EmptyInputData);
1408    }
1409    let len = high.len();
1410    if len != low.len() || len != close.len() {
1411        let got = if low.len() != len {
1412            low.len()
1413        } else {
1414            close.len()
1415        };
1416        return Err(WadError::OutputLengthMismatch { expected: len, got });
1417    }
1418    if high.iter().all(|x| x.is_nan())
1419        || low.iter().all(|x| x.is_nan())
1420        || close.iter().all(|x| x.is_nan())
1421    {
1422        return Err(WadError::AllValuesNaN);
1423    }
1424
1425    let chosen = match _kernel {
1426        Kernel::Auto => detect_best_kernel(),
1427        other => other,
1428    };
1429
1430    Ok((high, low, close, len, chosen))
1431}
1432
1433#[inline]
1434pub fn wad_into_slice(dst: &mut [f64], input: &WadInput, kern: Kernel) -> Result<(), WadError> {
1435    let (high, low, close, len, chosen) = wad_prepare(input, kern)?;
1436
1437    if dst.len() != len {
1438        return Err(WadError::OutputLengthMismatch {
1439            expected: len,
1440            got: dst.len(),
1441        });
1442    }
1443
1444    unsafe {
1445        match chosen {
1446            Kernel::Scalar | Kernel::ScalarBatch => wad_scalar(high, low, close, dst),
1447            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1448            Kernel::Avx2 | Kernel::Avx2Batch => wad_avx2(high, low, close, dst),
1449            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1450            Kernel::Avx512 | Kernel::Avx512Batch => wad_avx512(high, low, close, dst),
1451            _ => unreachable!(),
1452        }
1453    }
1454
1455    Ok(())
1456}
1457
/// Computes WAD for `input` into the caller-provided buffer `out`, letting
/// `Kernel::Auto` pick the kernel.
///
/// Thin wrapper around `wad_into_slice`; it returns whatever error that function
/// returns. Not compiled for wasm builds, where a pointer-based `wad_into` is
/// exported under the same name.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]

pub fn wad_into(input: &WadInput, out: &mut [f64]) -> Result<(), WadError> {
    wad_into_slice(out, input, Kernel::Auto)
}
1463
// Python binding `wad_cuda_dev`: runs the single-series CUDA WAD on `device_id`
// and returns a handle to the device-side f32 result.
// Fails with ValueError when CUDA is unavailable, when an input array cannot be
// viewed as a contiguous slice, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, device_id=0))]
pub fn wad_cuda_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );

    // The GIL is released while the CUDA context is created and the kernel runs.
    let device_result = py.allow_threads(|| {
        let runner = CudaWad::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        runner
            .wad_series_dev(high, low, close)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let device_array = make_device_array_py(device_id, device_result)?;
    Ok(device_array)
}
1491
// Python binding `wad_cuda_batch_dev`: runs the CUDA batch WAD entry point on
// `device_id` and returns a handle to the device-side f32 result.
// Fails with ValueError when CUDA is unavailable, when an input array cannot be
// viewed as a contiguous slice, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, device_id=0))]
pub fn wad_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );

    // The GIL is released while the CUDA context is created and the kernel runs.
    let device_result = py.allow_threads(|| {
        let runner = CudaWad::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        runner
            .wad_batch_dev(high, low, close)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let device_array = make_device_array_py(device_id, device_result)?;
    Ok(device_array)
}
1516
// Python binding `wad_cuda_many_series_one_param_dev`: runs CUDA WAD over 2-D
// high/low/close arrays of identical shape and returns a device-side f32 handle.
// The arrays are handed to the time-major CUDA entry point as flat slices with
// `cols` and `rows` taken from the shape of `high_tm_f32`.
// Fails with ValueError when CUDA is unavailable, the shapes differ, an array is
// not contiguous, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, device_id=0))]
pub fn wad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let rows = high_tm_f32.shape()[0];
    let cols = high_tm_f32.shape()[1];
    let expected_shape = [rows, cols];
    let shapes_match =
        low_tm_f32.shape() == expected_shape && close_tm_f32.shape() == expected_shape;
    if !shapes_match {
        return Err(PyValueError::new_err("high/low/close shapes must match"));
    }

    let (high, low, close) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // The GIL is released while the CUDA context is created and the kernel runs.
    let device_result = py.allow_threads(|| {
        let runner = CudaWad::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        runner
            .wad_many_series_one_param_time_major_dev(high, low, close, cols, rows)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let device_array = make_device_array_py(device_id, device_result)?;
    Ok(device_array)
}
1547
// Python binding `wad`: computes WAD for contiguous f64 high/low/close arrays
// and returns a new 1-D NumPy array.
// `kernel` is an optional kernel name checked by `validate_kernel`; any failure
// from slicing, kernel validation or the computation surfaces as a Python error.
#[cfg(feature = "python")]
#[pyfunction(name = "wad")]
#[pyo3(signature = (high, low, close, kernel=None))]
pub fn wad_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let input = WadInput::from_slices(high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    // The computation runs with the GIL released.
    let values: Vec<f64> = py
        .allow_threads(|| wad_with_kernel(&input, kern).map(|output| output.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1571
// Python-visible wrapper (exposed as `WadStream`) around the Rust `WadStream`.
#[cfg(feature = "python")]
#[pyclass(name = "WadStream")]
pub struct WadStreamPy {
    // The wrapped streaming state; every `update` call is forwarded to it.
    stream: WadStream,
}
1577
#[cfg(feature = "python")]
#[pymethods]
impl WadStreamPy {
    // Builds a fresh stream; a construction failure is surfaced as ValueError.
    #[new]
    fn new() -> PyResult<Self> {
        WadStream::try_new()
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar into the wrapped stream and returns the value it reports.
    fn update(&mut self, high: f64, low: f64, close: f64) -> f64 {
        self.stream.update(high, low, close)
    }
}
1591
// Python binding `wad_batch`: runs the batch entry point and returns a dict whose
// "values" entry is a (1, len(high)) NumPy array.
// The result is written straight into a freshly allocated NumPy buffer, so no
// intermediate Vec is created.
#[cfg(feature = "python")]
#[pyfunction(name = "wad_batch")]
#[pyo3(signature = (high, low, close, kernel=None))]
pub fn wad_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;

    // WAD has a single output row; the column count follows `high`.
    let rows = 1usize;
    let cols = high_slice.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "wad_batch: size overflow in rows*cols",
            ))
        }
    };

    // The buffer is allocated uninitialised and handed to the batch routine to fill.
    let out_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    py.allow_threads(|| {
        wad_batch_inner_into(high_slice, low_slice, close_slice, kern, true, out_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("values", out_arr.reshape((rows, cols))?)?;
    Ok(result)
}
1628
// wasm export: computes WAD for the given slices and returns a newly allocated
// array; any `WadError` is converted to a JS string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_js(high: &[f64], low: &[f64], close: &[f64]) -> Result<Vec<f64>, JsValue> {
    let input = WadInput::from_slices(high, low, close);
    let mut values = vec![0.0; high.len()];

    match wad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1641
// wasm export: reserves room for `len` f64 values and returns the raw pointer.
// The allocation is deliberately not dropped here; it is meant to be released
// later through `wad_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor so the buffer outlives this call.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1650
// wasm export: releases a buffer obtained from `wad_alloc`. A null pointer is
// ignored. `len` must be the value that was passed to the matching `wad_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller promises `ptr` came from `wad_alloc(len)`, i.e. from a
    // `Vec<f64>` with that capacity. The vector is rebuilt with length 0 rather
    // than `len`: `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised, and nothing guarantees the caller wrote the whole buffer.
    // Dropping it frees the allocation either way, since f64 has no destructor.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1660
// wasm export: computes WAD from raw input pointers into `out_ptr`; all four
// buffers are treated as `len` f64 values.
// Returns a JS string error for a null pointer or for any `WadError`.
// When `out_ptr` is identical to one of the input pointers the result is first
// computed into a scratch buffer and then copied out.
// NOTE(review): only exact pointer equality is detected; partially overlapping
// buffers are not handled.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    let input_ptrs = [high_ptr, low_ptr, close_ptr];
    if input_ptrs.iter().any(|p| p.is_null()) || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wad_into"));
    }

    let out_aliases_input = input_ptrs.iter().any(|&p| p == out_ptr as *const f64);

    unsafe {
        let input = WadInput::from_slices(
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        );

        if out_aliases_input {
            let mut scratch = vec![0.0; len];
            wad_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            wad_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1699
// wasm export: runs the batch WAD routine from raw input pointers into
// `out_ptr`; all four buffers are treated as `len` f64 values. Returns the
// number of rows written, which is always 1.
// Returns a JS string error for a null pointer or for any error from the batch
// routine.
// As in `wad_into`, an output pointer identical to one of the input pointers is
// served through a scratch buffer, so a mutable output slice is never created
// over memory that is being read as input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wad_batch_into"));
    }

    let out_aliases_input = [high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&p| p == out_ptr as *const f64);

    // SAFETY: the caller guarantees each pointer addresses `len` readable (and,
    // for `out_ptr`, writable) f64 values. In the aliasing case the output slice
    // is only created after the inputs have been fully consumed.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if out_aliases_input {
            let mut scratch = vec![0.0; len];
            wad_batch_inner_into(
                high,
                low,
                close,
                detect_best_kernel(),
                false,
                scratch.as_mut_slice(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            wad_batch_inner_into(high, low, close, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(1)
}
1722
/// Serializable batch configuration for the wasm API.
///
/// NOTE(review): the only field is named `dummy` and nothing in the visible
/// code reads it; it looks like a placeholder kept for API symmetry — confirm.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WadBatchConfig {
    /// Placeholder triple; not read by any code visible in this file section.
    pub dummy: (usize, usize, usize),
}
1728
/// Serializable result of the wasm `wad_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WadBatchJsOutput {
    /// Output values copied from the batch result.
    pub values: Vec<f64>,
    /// Row count copied from the batch result.
    pub rows: usize,
    /// Column count copied from the batch result.
    pub cols: usize,
}
1736
// wasm export `wad_batch`: runs the batch routine and returns
// `{ values, rows, cols }` serialized through serde_wasm_bindgen.
// `_config` is accepted but not read. Computation and serialization failures
// are both reported as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wad_batch)]
pub fn wad_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    _config: JsValue,
) -> Result<JsValue, JsValue> {
    let batch = match wad_batch_inner(high, low, close, detect_best_kernel(), false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = WadBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}