1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::CudaWad;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::PyDict;
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use serde::{Deserialize, Serialize};
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use wasm_bindgen::prelude::*;
19
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24 make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33use std::error::Error;
34use std::mem::ManuallyDrop;
35use thiserror::Error;
36
37#[derive(Debug, Clone)]
38pub enum WadData<'a> {
39 Candles {
40 candles: &'a Candles,
41 },
42 Slices {
43 high: &'a [f64],
44 low: &'a [f64],
45 close: &'a [f64],
46 },
47}
48
/// Result of a single WAD run.
#[derive(Clone, Debug)]
pub struct WadOutput {
    /// Running WAD total, one entry per input bar.
    pub values: Vec<f64>,
}
53
/// Parameter set for WAD. The indicator has no tunable parameters, so this is a unit
/// struct kept for API symmetry with the input/builder types.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct WadParams;
56
57#[derive(Debug, Clone)]
58pub struct WadInput<'a> {
59 pub data: WadData<'a>,
60 pub params: WadParams,
61}
62
63impl<'a> WadInput<'a> {
64 #[inline]
65 pub fn from_candles(candles: &'a Candles) -> Self {
66 Self {
67 data: WadData::Candles { candles },
68 params: WadParams::default(),
69 }
70 }
71 #[inline]
72 pub fn from_slices(high: &'a [f64], low: &'a [f64], close: &'a [f64]) -> Self {
73 Self {
74 data: WadData::Slices { high, low, close },
75 params: WadParams::default(),
76 }
77 }
78 #[inline]
79 pub fn with_default_candles(candles: &'a Candles) -> Self {
80 Self::from_candles(candles)
81 }
82}
83
84#[derive(Copy, Clone, Debug, Default)]
85pub struct WadBuilder {
86 kernel: Kernel,
87}
88impl WadBuilder {
89 #[inline(always)]
90 pub fn new() -> Self {
91 Self::default()
92 }
93 #[inline(always)]
94 pub fn kernel(mut self, k: Kernel) -> Self {
95 self.kernel = k;
96 self
97 }
98 #[inline(always)]
99 pub fn apply(self, candles: &Candles) -> Result<WadOutput, WadError> {
100 let i = WadInput::from_candles(candles);
101 wad_with_kernel(&i, self.kernel)
102 }
103 #[inline(always)]
104 pub fn apply_slices(
105 self,
106 high: &[f64],
107 low: &[f64],
108 close: &[f64],
109 ) -> Result<WadOutput, WadError> {
110 let i = WadInput::from_slices(high, low, close);
111 wad_with_kernel(&i, self.kernel)
112 }
113 #[inline(always)]
114 pub fn into_stream(self) -> Result<WadStream, WadError> {
115 WadStream::try_new()
116 }
117}
118
119#[derive(Debug, Error)]
120pub enum WadError {
121 #[error("wad: Empty input data.")]
122 EmptyInputData,
123 #[error("wad: All values are NaN.")]
124 AllValuesNaN,
125 #[error("wad: Invalid period: period = {period}, data length = {data_len}.")]
126 InvalidPeriod { period: usize, data_len: usize },
127 #[error("wad: Not enough valid data: needed = {needed}, valid = {valid}.")]
128 NotEnoughValidData { needed: usize, valid: usize },
129 #[error("wad: Empty or mismatched lengths: expected = {expected}, got = {got}.")]
130 OutputLengthMismatch { expected: usize, got: usize },
131 #[error("wad: Invalid range: start={start}, end={end}, step={step}.")]
132 InvalidRange {
133 start: usize,
134 end: usize,
135 step: usize,
136 },
137 #[error("wad: Invalid kernel for batch: {0:?}.")]
138 InvalidKernelForBatch(Kernel),
139 #[error("wad: Invalid input: {msg}.")]
140 InvalidInput { msg: String },
141}
142
143#[inline]
144pub fn wad(input: &WadInput) -> Result<WadOutput, WadError> {
145 wad_with_kernel(input, Kernel::Auto)
146}
147
148pub fn wad_with_kernel(input: &WadInput, kernel: Kernel) -> Result<WadOutput, WadError> {
149 let (high, low, close): (&[f64], &[f64], &[f64]) = match &input.data {
150 WadData::Candles { candles } => (
151 source_type(candles, "high"),
152 source_type(candles, "low"),
153 source_type(candles, "close"),
154 ),
155 WadData::Slices { high, low, close } => (*high, *low, *close),
156 };
157 if high.is_empty() || low.is_empty() || close.is_empty() {
158 return Err(WadError::EmptyInputData);
159 }
160 let len = high.len();
161 if len != low.len() || len != close.len() {
162 let got = if low.len() != len {
163 low.len()
164 } else {
165 close.len()
166 };
167 return Err(WadError::OutputLengthMismatch { expected: len, got });
168 }
169 if high.iter().all(|x| x.is_nan())
170 || low.iter().all(|x| x.is_nan())
171 || close.iter().all(|x| x.is_nan())
172 {
173 return Err(WadError::AllValuesNaN);
174 }
175 let chosen = match kernel {
176 Kernel::Auto => detect_best_kernel(),
177 other => other,
178 };
179 let mut out = alloc_with_nan_prefix(len, 0);
180 unsafe {
181 match chosen {
182 Kernel::Scalar | Kernel::ScalarBatch => wad_scalar(high, low, close, &mut out),
183 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
184 Kernel::Avx2 | Kernel::Avx2Batch => wad_avx2(high, low, close, &mut out),
185 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
186 Kernel::Avx512 | Kernel::Avx512Batch => wad_avx512(high, low, close, &mut out),
187 _ => unreachable!(),
188 }
189 }
190 Ok(WadOutput { values: out })
191}
192
/// Portable WAD kernel.
///
/// Writes the running total into `out`: `out[0]` is `0.0`, and every later bar adds
/// `close - min(prev_close, low)` when the close rose, `close - max(prev_close, high)`
/// when it fell, and nothing when it is unchanged. Does nothing when `close` is empty.
///
/// # Panics
/// Panics on an out-of-bounds index if `high`, `low` or `out` is shorter than `close`.
#[inline(always)]
pub fn wad_scalar(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    let bars = close.len();
    if bars == 0 {
        return;
    }

    out[0] = 0.0;
    let mut running = 0.0_f64;
    let mut prev_close = close[0];

    for idx in 1..bars {
        let (hi, lo, cur) = (high[idx], low[idx], close[idx]);
        let true_high = prev_close.max(hi);
        let true_low = prev_close.min(lo);

        // 1.0 / 0.0 selectors keep the update branch-free in the arithmetic itself.
        let rose = if cur > prev_close { 1.0_f64 } else { 0.0 };
        let fell = if cur < prev_close { 1.0_f64 } else { 0.0 };

        running += rose.mul_add(cur - true_low, fell * (cur - true_high));
        out[idx] = running;
        prev_close = cur;
    }
}
220
/// Unrolled WAD kernel compiled with AVX2/FMA code generation enabled.
///
/// No explicit vector intrinsics are used apart from a software prefetch: every bar
/// depends on the previous close and the running total, so the loop simply processes
/// eight bars per iteration and finishes the remainder one bar at a time.
///
/// Unlike [`wad_scalar`], the true high/low are picked with plain comparisons rather
/// than `f64::max`/`f64::min`, so a NaN in `high` or `low` is selected here.
///
/// # Safety
/// - `high`, `low` and `out` must each hold at least `close.len()` elements; every
///   access is unchecked.
/// - The running CPU must support the `avx2` and `fma` target features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx2(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // One bar of the recurrence at position `$k`; advances `prev` and `total`.
        macro_rules! bar {
            ($k:expr) => {{
                let cur = *close_ptr.add($k);
                let hi = *high_ptr.add($k);
                let lo = *low_ptr.add($k);
                let true_high = if prev > hi { prev } else { hi };
                let true_low = if prev < lo { prev } else { lo };
                let up = (cur > prev) as i32 as f64;
                let down = (cur < prev) as i32 as f64;
                total += up.mul_add(cur - true_low, down * (cur - true_high));
                *out_ptr.add($k) = total;
                prev = cur;
            }};
        }

        while idx + 7 < len {
            // Pull data 32 elements ahead into cache while well inside the input.
            if idx + 40 < len {
                _mm_prefetch(close_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
                _mm_prefetch(high_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
                _mm_prefetch(low_ptr.add(idx + 32) as *const i8, _MM_HINT_T0);
            }

            bar!(idx);
            bar!(idx + 1);
            bar!(idx + 2);
            bar!(idx + 3);
            bar!(idx + 4);
            bar!(idx + 5);
            bar!(idx + 6);
            bar!(idx + 7);

            idx += 8;
        }

        // Tail: fewer than eight bars left.
        while idx < len {
            bar!(idx);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
360
/// AVX-512 entry point: routes to the short variant for inputs of at most 64 bars
/// (measured on `high.len()`) and to the long variant otherwise.
///
/// # Safety
/// Same contract as [`wad_avx512_short`] / [`wad_avx512_long`]: `high`, `low` and `out`
/// must hold at least `close.len()` elements and the CPU must support `avx512f` and
/// `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    let is_long = high.len() > 64;
    if is_long {
        wad_avx512_long(high, low, close, out);
    } else {
        wad_avx512_short(high, low, close, out);
    }
}
370
/// WAD kernel for short inputs, compiled with AVX-512F/FMA code generation enabled.
///
/// Processes eight bars per loop iteration without prefetching, then finishes the
/// remainder one bar at a time. The true high/low are picked with plain comparisons
/// (not `f64::max`/`f64::min`), so a NaN in `high` or `low` is selected.
///
/// # Safety
/// - `high`, `low` and `out` must each hold at least `close.len()` elements; every
///   access is unchecked.
/// - The running CPU must support the `avx512f` and `fma` target features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512_short(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[target_feature(enable = "avx512f,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // One bar of the recurrence at position `$k`; advances `prev` and `total`.
        macro_rules! bar {
            ($k:expr) => {{
                let cur = *close_ptr.add($k);
                let hi = *high_ptr.add($k);
                let lo = *low_ptr.add($k);
                let true_high = if prev > hi { prev } else { hi };
                let true_low = if prev < lo { prev } else { lo };
                let up = (cur > prev) as i32 as f64;
                let down = (cur < prev) as i32 as f64;
                total += up.mul_add(cur - true_low, down * (cur - true_high));
                *out_ptr.add($k) = total;
                prev = cur;
            }};
        }

        while idx + 7 < len {
            bar!(idx);
            bar!(idx + 1);
            bar!(idx + 2);
            bar!(idx + 3);
            bar!(idx + 4);
            bar!(idx + 5);
            bar!(idx + 6);
            bar!(idx + 7);

            idx += 8;
        }

        // Tail: fewer than eight bars left.
        while idx < len {
            bar!(idx);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
/// WAD kernel for long inputs, compiled with AVX-512F/FMA code generation enabled.
///
/// Processes sixteen bars per loop iteration with a software prefetch 64 elements
/// ahead, then finishes the remainder one bar at a time. The true high/low are picked
/// with plain comparisons (not `f64::max`/`f64::min`), so a NaN in `high` or `low` is
/// selected.
///
/// # Safety
/// - `high`, `low` and `out` must each hold at least `close.len()` elements; every
///   access is unchecked.
/// - The running CPU must support the `avx512f` and `fma` target features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_avx512_long(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    #[target_feature(enable = "avx512f,fma")]
    unsafe fn inner(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let len = close.len();
        if len == 0 {
            return;
        }
        *out.get_unchecked_mut(0) = 0.0;

        let high_ptr = high.as_ptr();
        let low_ptr = low.as_ptr();
        let close_ptr = close.as_ptr();
        let out_ptr = out.as_mut_ptr();

        let mut total = 0.0f64;
        let mut prev = *close_ptr;
        let mut idx = 1usize;

        // One bar of the recurrence at position `$k`; advances `prev` and `total`.
        macro_rules! bar {
            ($k:expr) => {{
                let cur = *close_ptr.add($k);
                let hi = *high_ptr.add($k);
                let lo = *low_ptr.add($k);
                let true_high = if prev > hi { prev } else { hi };
                let true_low = if prev < lo { prev } else { lo };
                let up = (cur > prev) as i32 as f64;
                let down = (cur < prev) as i32 as f64;
                total += up.mul_add(cur - true_low, down * (cur - true_high));
                *out_ptr.add($k) = total;
                prev = cur;
            }};
        }

        while idx + 15 < len {
            // Pull data 64 elements ahead into cache while well inside the input.
            if idx + 96 < len {
                _mm_prefetch(close_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
                _mm_prefetch(high_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
                _mm_prefetch(low_ptr.add(idx + 64) as *const i8, _MM_HINT_T0);
            }

            bar!(idx);
            bar!(idx + 1);
            bar!(idx + 2);
            bar!(idx + 3);
            bar!(idx + 4);
            bar!(idx + 5);
            bar!(idx + 6);
            bar!(idx + 7);
            bar!(idx + 8);
            bar!(idx + 9);
            bar!(idx + 10);
            bar!(idx + 11);
            bar!(idx + 12);
            bar!(idx + 13);
            bar!(idx + 14);
            bar!(idx + 15);

            idx += 16;
        }

        // Tail: fewer than sixteen bars left.
        while idx < len {
            bar!(idx);
            idx += 1;
        }
    }

    inner(high, low, close, out)
}
583
584#[inline(always)]
585pub unsafe fn wad_row_scalar(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
586 wad_scalar(high, low, close, out)
587}
/// Row-kernel adapter for the batch path; forwards unchanged to [`wad_avx2`].
///
/// # Safety
/// Same contract as [`wad_avx2`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx2(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx2(high, low, close, out);
}
/// Row-kernel adapter for the batch path; forwards unchanged to [`wad_avx512`].
///
/// # Safety
/// Same contract as [`wad_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512(high, low, close, out);
}
/// Row-kernel adapter for the batch path; forwards unchanged to [`wad_avx512_short`].
///
/// # Safety
/// Same contract as [`wad_avx512_short`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512_short(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512_short(high, low, close, out);
}
/// Row-kernel adapter for the batch path; forwards unchanged to [`wad_avx512_long`].
///
/// # Safety
/// Same contract as [`wad_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wad_row_avx512_long(high: &[f64], low: &[f64], close: &[f64], out: &mut [f64]) {
    wad_avx512_long(high, low, close, out);
}
608
/// Incremental WAD calculator fed one bar at a time.
#[derive(Clone, Debug)]
pub struct WadStream {
    /// Running WAD total accumulated so far.
    sum: f64,
    /// Close of the most recent bar; `None` until the first bar has been seen.
    prev_close: Option<f64>,
}
614impl WadStream {
615 pub fn try_new() -> Result<Self, WadError> {
616 Ok(Self {
617 sum: 0.0,
618 prev_close: None,
619 })
620 }
621 #[inline(always)]
622 pub fn update(&mut self, high: f64, low: f64, close: f64) -> f64 {
623 let pc = match self.prev_close {
624 Some(pc) => pc,
625 None => {
626 self.prev_close = Some(close);
627 return self.sum;
628 }
629 };
630
631 let trh = pc.max(high);
632 let trl = pc.min(low);
633
634 let gt = (close > pc) as i32 as f64;
635 let lt = (close < pc) as i32 as f64;
636 let ad = gt.mul_add(close - trl, lt * (close - trh));
637
638 self.sum += ad;
639 self.prev_close = Some(close);
640 self.sum
641 }
642}
643
/// Placeholder sweep range for the batch API. WAD has no parameters, so the single
/// `(start, end, step)`-shaped tuple is unused and defaults to `(0, 0, 0)`.
#[derive(Clone, Debug, Default)]
pub struct WadBatchRange {
    pub dummy: (usize, usize, usize),
}
653#[derive(Clone, Debug, Default)]
654pub struct WadBatchBuilder {
655 range: WadBatchRange,
656 kernel: Kernel,
657}
658impl WadBatchBuilder {
659 pub fn new() -> Self {
660 Self::default()
661 }
662 pub fn kernel(mut self, k: Kernel) -> Self {
663 self.kernel = k;
664 self
665 }
666 pub fn apply_slices(
667 self,
668 high: &[f64],
669 low: &[f64],
670 close: &[f64],
671 ) -> Result<WadBatchOutput, WadError> {
672 wad_batch_with_kernel(high, low, close, self.kernel)
673 }
674 pub fn with_default_slices(
675 high: &[f64],
676 low: &[f64],
677 close: &[f64],
678 k: Kernel,
679 ) -> Result<WadBatchOutput, WadError> {
680 WadBatchBuilder::new()
681 .kernel(k)
682 .apply_slices(high, low, close)
683 }
684 pub fn apply_candles(self, c: &Candles) -> Result<WadBatchOutput, WadError> {
685 let high = source_type(c, "high");
686 let low = source_type(c, "low");
687 let close = source_type(c, "close");
688 self.apply_slices(high, low, close)
689 }
690 pub fn with_default_candles(c: &Candles) -> Result<WadBatchOutput, WadError> {
691 WadBatchBuilder::new().kernel(Kernel::Auto).apply_candles(c)
692 }
693}
694
695pub fn wad_batch_with_kernel(
696 high: &[f64],
697 low: &[f64],
698 close: &[f64],
699 k: Kernel,
700) -> Result<WadBatchOutput, WadError> {
701 let kernel = match k {
702 Kernel::Auto => detect_best_batch_kernel(),
703 other if other.is_batch() => other,
704 other => return Err(WadError::InvalidKernelForBatch(other)),
705 };
706 wad_batch_par_slice(high, low, close, kernel)
707}
708
/// Result of a batch run, stored as a flat `rows * cols` matrix.
#[derive(Clone, Debug)]
pub struct WadBatchOutput {
    /// Flattened output values.
    pub values: Vec<f64>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of columns in `values`.
    pub cols: usize,
}
715impl WadBatchOutput {
716 pub fn row_for_params(&self, _: &WadParams) -> Option<usize> {
717 Some(0)
718 }
719 pub fn values_for(&self, _: &WadParams) -> Option<&[f64]> {
720 Some(&self.values)
721 }
722}
723
724#[inline(always)]
725pub fn expand_grid(_r: &WadBatchRange) -> Vec<WadParams> {
726 let mut result = Vec::with_capacity(1);
727 result.push(WadParams);
728 result
729}
730
731#[inline(always)]
732pub fn wad_batch_slice(
733 high: &[f64],
734 low: &[f64],
735 close: &[f64],
736 kern: Kernel,
737) -> Result<WadBatchOutput, WadError> {
738 wad_batch_inner(high, low, close, kern, false)
739}
740#[inline(always)]
741pub fn wad_batch_par_slice(
742 high: &[f64],
743 low: &[f64],
744 close: &[f64],
745 kern: Kernel,
746) -> Result<WadBatchOutput, WadError> {
747 wad_batch_inner(high, low, close, kern, true)
748}
749
750#[inline(always)]
751fn wad_batch_inner(
752 high: &[f64],
753 low: &[f64],
754 close: &[f64],
755 kern: Kernel,
756 _parallel: bool,
757) -> Result<WadBatchOutput, WadError> {
758 if high.is_empty() || low.is_empty() || close.is_empty() {
759 return Err(WadError::EmptyInputData);
760 }
761 let len = high.len();
762 if len != low.len() || len != close.len() {
763 let got = if low.len() != len {
764 low.len()
765 } else {
766 close.len()
767 };
768 return Err(WadError::OutputLengthMismatch { expected: len, got });
769 }
770 if high.iter().all(|x| x.is_nan())
771 || low.iter().all(|x| x.is_nan())
772 || close.iter().all(|x| x.is_nan())
773 {
774 return Err(WadError::AllValuesNaN);
775 }
776
777 let mut buf_mu = make_uninit_matrix(1, len);
778 init_matrix_prefixes(&mut buf_mu, len, &[0]);
779
780 let mut guard = ManuallyDrop::new(buf_mu);
781 let out: &mut [f64] =
782 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
783
784 wad_batch_inner_into(high, low, close, kern, false, out)?;
785
786 let values = unsafe {
787 Vec::from_raw_parts(
788 guard.as_mut_ptr() as *mut f64,
789 guard.len(),
790 guard.capacity(),
791 )
792 };
793
794 Ok(WadBatchOutput {
795 values,
796 rows: 1,
797 cols: len,
798 })
799}
800
801#[inline(always)]
802fn wad_batch_inner_into(
803 high: &[f64],
804 low: &[f64],
805 close: &[f64],
806 kern: Kernel,
807 _parallel: bool,
808 out: &mut [f64],
809) -> Result<(), WadError> {
810 if high.is_empty() || low.is_empty() || close.is_empty() {
811 return Err(WadError::EmptyInputData);
812 }
813 let len = high.len();
814 if len != low.len() || len != close.len() {
815 let got = if low.len() != len {
816 low.len()
817 } else {
818 close.len()
819 };
820 return Err(WadError::OutputLengthMismatch { expected: len, got });
821 }
822 if high.iter().all(|x| x.is_nan())
823 || low.iter().all(|x| x.is_nan())
824 || close.iter().all(|x| x.is_nan())
825 {
826 return Err(WadError::AllValuesNaN);
827 }
828 if out.len() != len {
829 return Err(WadError::OutputLengthMismatch {
830 expected: len,
831 got: out.len(),
832 });
833 }
834
835 let actual = match kern {
836 Kernel::Auto => detect_best_batch_kernel(),
837 k => k,
838 };
839 unsafe {
840 match actual {
841 Kernel::Scalar | Kernel::ScalarBatch => wad_row_scalar(high, low, close, out),
842 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
843 Kernel::Avx2 | Kernel::Avx2Batch => wad_row_avx2(high, low, close, out),
844 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
845 Kernel::Avx512 | Kernel::Avx512Batch => wad_row_avx512(high, low, close, out),
846 _ => unreachable!(),
847 }
848 }
849 Ok(())
850}
851
852#[cfg(test)]
853mod tests {
854 use super::*;
855 use crate::skip_if_unsupported;
856 use crate::utilities::data_loader::read_candles_from_csv;
857 use std::error::Error;
858
859 fn check_wad_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
860 skip_if_unsupported!(kernel, test_name);
861 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
862 let candles = read_candles_from_csv(file_path)?;
863 let input = WadInput::from_candles(&candles);
864 let output = wad_with_kernel(&input, kernel)?;
865 assert_eq!(output.values.len(), candles.close.len());
866 Ok(())
867 }
868
869 fn check_wad_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
870 skip_if_unsupported!(kernel, test_name);
871 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
872 let candles = read_candles_from_csv(file_path)?;
873 let input = WadInput::from_candles(&candles);
874 let output = wad_with_kernel(&input, kernel)?;
875 assert_eq!(output.values.len(), candles.close.len());
876 let expected_last_five_wad = [
877 158503.46790000016,
878 158279.46790000016,
879 158014.46790000016,
880 158186.46790000016,
881 157605.46790000016,
882 ];
883 let start = output.values.len().saturating_sub(5);
884 for (i, &val) in output.values[start..].iter().enumerate() {
885 let exp = expected_last_five_wad[i];
886 assert!(
887 (val - exp).abs() < 1e-4,
888 "[{}] WAD {:?} mismatch at idx {}: got {}, expected {}",
889 test_name,
890 kernel,
891 i,
892 val,
893 exp
894 );
895 }
896 Ok(())
897 }
898
899 fn check_wad_empty_data(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
900 skip_if_unsupported!(kernel, test_name);
901 let input = WadInput::from_slices(&[], &[], &[]);
902 let result = wad_with_kernel(&input, kernel);
903 assert!(result.is_err());
904 Ok(())
905 }
906
907 fn check_wad_all_values_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
908 skip_if_unsupported!(kernel, test_name);
909 let nan_slice = [f64::NAN, f64::NAN, f64::NAN];
910 let input = WadInput::from_slices(&nan_slice, &nan_slice, &nan_slice);
911 let result = wad_with_kernel(&input, kernel);
912 assert!(result.is_err());
913 Ok(())
914 }
915
916 fn check_wad_basic_slice(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
917 skip_if_unsupported!(kernel, test_name);
918 let high = [10.0, 11.0, 11.0, 12.0];
919 let low = [9.0, 9.0, 10.0, 10.0];
920 let close = [9.5, 10.5, 10.5, 11.5];
921 let input = WadInput::from_slices(&high, &low, &close);
922 let output = wad_with_kernel(&input, kernel)?;
923 assert_eq!(output.values.len(), 4);
924 assert!((output.values[0] - 0.0).abs() < 1e-10);
925 assert!((output.values[1] - 1.5).abs() < 1e-10);
926 assert!((output.values[2] - 1.5).abs() < 1e-10);
927 assert!((output.values[3] - 3.0).abs() < 1e-10);
928 Ok(())
929 }
930
931 fn check_wad_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
932 skip_if_unsupported!(kernel, test_name);
933 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
934 let candles = read_candles_from_csv(file_path)?;
935 let high = source_type(&candles, "high");
936 let low = source_type(&candles, "low");
937 let close = source_type(&candles, "close");
938 let batch_output =
939 wad_with_kernel(&WadInput::from_slices(high, low, close), kernel)?.values;
940 let mut stream = WadStream::try_new()?;
941 let mut stream_values = Vec::with_capacity(close.len());
942 for ((&h, &l), &c) in high.iter().zip(low).zip(close) {
943 stream_values.push(stream.update(h, l, c));
944 }
945 assert_eq!(batch_output.len(), stream_values.len());
946 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
947 let diff = (b - s).abs();
948 assert!(
949 diff < 1e-9,
950 "[{}] WAD streaming mismatch at idx {}: batch={}, stream={}, diff={}",
951 test_name,
952 i,
953 b,
954 s,
955 diff
956 );
957 }
958 Ok(())
959 }
960
961 fn check_wad_small_example(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
962 skip_if_unsupported!(kernel, test_name);
963
964 let high = [10.0, 11.0, 12.0, 11.5, 12.5];
965 let low = [9.0, 9.5, 11.0, 10.5, 11.0];
966 let close = [9.5, 10.5, 11.5, 11.0, 12.0];
967 let expected = [0.0, 1.0, 2.0, 1.5, 2.5];
968
969 let input = WadInput::from_slices(&high, &low, &close);
970 let output = wad_with_kernel(&input, kernel)?;
971
972 assert_eq!(output.values.len(), 5);
973
974 for i in 0..5 {
975 let got = output.values[i];
976 let exp = expected[i];
977 assert!(
978 (got - exp).abs() < 1e-10,
979 "[{}] WAD {:?} small example mismatch at idx {}: got {}, expected {}",
980 test_name,
981 kernel,
982 i,
983 got,
984 exp
985 );
986 }
987
988 Ok(())
989 }
990
991 #[cfg(debug_assertions)]
992 fn check_wad_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
993 skip_if_unsupported!(kernel, test_name);
994
995 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
996 let candles = read_candles_from_csv(file_path)?;
997
998 let test_configs = vec![WadParams::default()];
999
1000 for (param_idx, params) in test_configs.iter().enumerate() {
1001 let input = WadInput {
1002 data: WadData::Candles { candles: &candles },
1003 params: params.clone(),
1004 };
1005 let output = wad_with_kernel(&input, kernel)?;
1006
1007 for (i, &val) in output.values.iter().enumerate() {
1008 if val.is_nan() {
1009 continue;
1010 }
1011
1012 let bits = val.to_bits();
1013
1014 if bits == 0x11111111_11111111 {
1015 panic!(
1016 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1017 with params: {:?} (param set {})",
1018 test_name, val, bits, i, params, param_idx
1019 );
1020 }
1021
1022 if bits == 0x22222222_22222222 {
1023 panic!(
1024 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1025 with params: {:?} (param set {})",
1026 test_name, val, bits, i, params, param_idx
1027 );
1028 }
1029
1030 if bits == 0x33333333_33333333 {
1031 panic!(
1032 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1033 with params: {:?} (param set {})",
1034 test_name, val, bits, i, params, param_idx
1035 );
1036 }
1037 }
1038 }
1039
1040 Ok(())
1041 }
1042
1043 #[cfg(not(debug_assertions))]
1044 fn check_wad_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1045 Ok(())
1046 }
1047
1048 #[cfg(debug_assertions)]
1049 fn check_batch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1050 skip_if_unsupported!(kernel, test_name);
1051
1052 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1053 let candles = read_candles_from_csv(file_path)?;
1054
1055 let test_configs = vec!["high", "low", "close"];
1056
1057 for (cfg_idx, &source) in test_configs.iter().enumerate() {
1058 let output = wad_batch_with_kernel(
1059 source_type(&candles, "high"),
1060 source_type(&candles, "low"),
1061 source_type(&candles, "close"),
1062 kernel,
1063 )?;
1064
1065 for (idx, &val) in output.values.iter().enumerate() {
1066 if val.is_nan() {
1067 continue;
1068 }
1069
1070 let bits = val.to_bits();
1071 let row = idx / output.cols;
1072 let col = idx % output.cols;
1073
1074 if bits == 0x11111111_11111111 {
1075 panic!(
1076 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1077 at row {} col {} (flat index {}) - source: {}",
1078 test_name, cfg_idx, val, bits, row, col, idx, source
1079 );
1080 }
1081
1082 if bits == 0x22222222_22222222 {
1083 panic!(
1084 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1085 at row {} col {} (flat index {}) - source: {}",
1086 test_name, cfg_idx, val, bits, row, col, idx, source
1087 );
1088 }
1089
1090 if bits == 0x33333333_33333333 {
1091 panic!(
1092 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1093 at row {} col {} (flat index {}) - source: {}",
1094 test_name, cfg_idx, val, bits, row, col, idx, source
1095 );
1096 }
1097 }
1098 }
1099
1100 Ok(())
1101 }
1102
    /// Release-build variant of the batch poison check: compiled only when
    /// `debug_assertions` is disabled. Both arguments are ignored and the
    /// function reports success without running the batch computation.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        // No checks are performed in this configuration.
        Ok(())
    }
1107
1108 macro_rules! generate_all_wad_tests {
1109 ($($test_fn:ident),*) => {
1110 paste::paste! {
1111 $(
1112 #[test]
1113 fn [<$test_fn _scalar_f64>]() {
1114 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1115 }
1116 )*
1117 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1118 $(
1119 #[test]
1120 fn [<$test_fn _avx2_f64>]() {
1121 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1122 }
1123 #[test]
1124 fn [<$test_fn _avx512_f64>]() {
1125 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1126 }
1127 )*
1128 }
1129 }
1130 }
1131
    // Instantiate the per-kernel `#[test]` wrappers for every single-series
    // check function listed below.
    generate_all_wad_tests!(
        check_wad_partial_params,
        check_wad_accuracy,
        check_wad_empty_data,
        check_wad_all_values_nan,
        check_wad_basic_slice,
        check_wad_streaming,
        check_wad_small_example,
        check_wad_no_poison
    );
1142
1143 macro_rules! gen_batch_tests {
1144 ($fn_name:ident) => {
1145 paste::paste! {
1146 #[test] fn [<$fn_name _scalar>]() {
1147 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1148 }
1149 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1150 #[test] fn [<$fn_name _avx2>]() {
1151 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1152 }
1153 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1154 #[test] fn [<$fn_name _avx512>]() {
1155 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1156 }
1157 }
1158 };
1159 }
1160
1161 gen_batch_tests!(check_batch_no_poison);
1162
1163 #[cfg(feature = "proptest")]
1164 #[allow(clippy::float_cmp)]
1165 fn check_wad_property(
1166 test_name: &str,
1167 kernel: Kernel,
1168 ) -> Result<(), Box<dyn std::error::Error>> {
1169 use proptest::prelude::*;
1170 skip_if_unsupported!(kernel, test_name);
1171
1172 let strat = (1usize..=200).prop_flat_map(|len| {
1173 prop::collection::vec(
1174 (1.0f64..1000.0f64).prop_flat_map(|base_price| {
1175 let range = base_price * 0.1;
1176 let low = base_price - range;
1177 let high = base_price + range;
1178
1179 (low..=high).prop_map(move |close| {
1180 let actual_low = low.min(close);
1181 let actual_high = high.max(close);
1182 (actual_high, actual_low, close)
1183 })
1184 }),
1185 len,
1186 )
1187 });
1188
1189 proptest::test_runner::TestRunner::default().run(&strat, |ohlc_data| {
1190 let (highs, lows, closes): (Vec<f64>, Vec<f64>, Vec<f64>) =
1191 ohlc_data.into_iter().map(|(h, l, c)| (h, l, c)).unzip3();
1192
1193 let input = WadInput::from_slices(&highs, &lows, &closes);
1194
1195 let WadOutput { values: out } = wad_with_kernel(&input, kernel).unwrap();
1196 let WadOutput { values: ref_out } = wad_with_kernel(&input, Kernel::Scalar).unwrap();
1197
1198 prop_assert_eq!(out[0], 0.0, "First WAD value must be 0.0");
1199 prop_assert_eq!(ref_out[0], 0.0, "First reference WAD value must be 0.0");
1200
1201 let mut expected_sum = 0.0;
1202 let mut prev_close = closes[0];
1203
1204 for i in 1..closes.len() {
1205 let trh = if prev_close > highs[i] {
1206 prev_close
1207 } else {
1208 highs[i]
1209 };
1210 let trl = if prev_close < lows[i] {
1211 prev_close
1212 } else {
1213 lows[i]
1214 };
1215
1216 let ad = if closes[i] > prev_close {
1217 closes[i] - trl
1218 } else if closes[i] < prev_close {
1219 closes[i] - trh
1220 } else {
1221 0.0
1222 };
1223
1224 expected_sum += ad;
1225
1226 prop_assert!(
1227 (out[i] - expected_sum).abs() <= 1e-9,
1228 "WAD mismatch at idx {}: got {}, expected {}",
1229 i,
1230 out[i],
1231 expected_sum
1232 );
1233
1234 prev_close = closes[i];
1235 }
1236
1237 for i in 0..out.len() {
1238 let y = out[i];
1239 let r = ref_out[i];
1240
1241 if !y.is_finite() || !r.is_finite() {
1242 prop_assert_eq!(
1243 y.to_bits(),
1244 r.to_bits(),
1245 "NaN/Inf mismatch at idx {}: {} vs {}",
1246 i,
1247 y,
1248 r
1249 );
1250 continue;
1251 }
1252
1253 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1254 prop_assert!(
1255 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1256 "Kernel mismatch at idx {}: {} vs {} (diff: {}, ulp: {})",
1257 i,
1258 y,
1259 r,
1260 (y - r).abs(),
1261 ulp_diff
1262 );
1263 }
1264
1265 for i in 1..closes.len() {
1266 if (closes[i] - closes[i - 1]).abs() < f64::EPSILON {
1267 let ad_change = if i == 1 {
1268 out[i] - 0.0
1269 } else {
1270 out[i] - out[i - 1]
1271 };
1272 prop_assert!(
1273 ad_change.abs() < 1e-9,
1274 "WAD should not change when close[{}] == close[{}], but changed by {}",
1275 i,
1276 i - 1,
1277 ad_change
1278 );
1279 }
1280 }
1281
1282 if closes.len() == 1 {
1283 prop_assert_eq!(out.len(), 1);
1284 prop_assert_eq!(out[0], 0.0);
1285 }
1286
1287 if closes
1288 .windows(2)
1289 .all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
1290 {
1291 for i in 0..out.len() {
1292 prop_assert!(
1293 out[i].abs() < 1e-9,
1294 "WAD should be 0 for constant prices, but got {} at index {}",
1295 out[i],
1296 i
1297 );
1298 }
1299 }
1300
1301 let strictly_increasing = closes.windows(2).all(|w| w[1] > w[0]);
1302 if strictly_increasing && closes.len() > 1 {
1303 for i in 1..out.len() {
1304 prop_assert!(
1305 out[i] >= out[i-1] - 1e-9,
1306 "WAD should increase monotonically for strictly increasing prices, but {} < {} at index {}",
1307 out[i], out[i-1], i
1308 );
1309 }
1310 }
1311
1312 let strictly_decreasing = closes.windows(2).all(|w| w[1] < w[0]);
1313 if strictly_decreasing && closes.len() > 1 {
1314 for i in 1..out.len() {
1315 prop_assert!(
1316 out[i] <= out[i-1] + 1e-9,
1317 "WAD should decrease monotonically for strictly decreasing prices, but {} > {} at index {}",
1318 out[i], out[i-1], i
1319 );
1320 }
1321 }
1322
1323 Ok(())
1324 })?;
1325
1326 Ok(())
1327 }
1328
1329 trait Unzip3<A, B, C> {
1330 fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>);
1331 }
1332
1333 impl<A, B, C, I> Unzip3<A, B, C> for I
1334 where
1335 I: Iterator<Item = (A, B, C)>,
1336 {
1337 fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>) {
1338 let (mut a_vec, mut b_vec, mut c_vec) = (Vec::new(), Vec::new(), Vec::new());
1339 for (a, b, c) in self {
1340 a_vec.push(a);
1341 b_vec.push(b);
1342 c_vec.push(c);
1343 }
1344 (a_vec, b_vec, c_vec)
1345 }
1346 }
1347
    // Per-kernel `#[test]` wrappers for the property test; only compiled when
    // the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_wad_tests!(check_wad_property);
1350
1351 #[test]
1352 fn test_wad_into_matches_api() -> Result<(), Box<dyn Error>> {
1353 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1354 let candles = read_candles_from_csv(file_path)?;
1355 let input = WadInput::from_candles(&candles);
1356
1357 let baseline = wad(&input)?.values;
1358
1359 let mut out = vec![0.0; baseline.len()];
1360 #[allow(unused_variables)]
1361 {
1362 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1363 {
1364 wad_into(&input, &mut out)?;
1365 }
1366 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1367 {
1368 wad_into_slice(&mut out, &input, Kernel::Auto)?;
1369 }
1370 }
1371
1372 assert_eq!(baseline.len(), out.len());
1373
1374 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1375 (a.is_nan() && b.is_nan()) || (a == b)
1376 }
1377
1378 for i in 0..baseline.len() {
1379 assert!(
1380 eq_or_both_nan(baseline[i], out[i]),
1381 "Mismatch at index {}: baseline={}, into={}",
1382 i,
1383 baseline[i],
1384 out[i]
1385 );
1386 }
1387
1388 Ok(())
1389 }
1390}
1391
1392#[inline(always)]
1393fn wad_prepare<'a>(
1394 input: &'a WadInput,
1395 _kernel: Kernel,
1396) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, Kernel), WadError> {
1397 let (high, low, close): (&[f64], &[f64], &[f64]) = match &input.data {
1398 WadData::Candles { candles } => (
1399 source_type(candles, "high"),
1400 source_type(candles, "low"),
1401 source_type(candles, "close"),
1402 ),
1403 WadData::Slices { high, low, close } => (*high, *low, *close),
1404 };
1405
1406 if high.is_empty() || low.is_empty() || close.is_empty() {
1407 return Err(WadError::EmptyInputData);
1408 }
1409 let len = high.len();
1410 if len != low.len() || len != close.len() {
1411 let got = if low.len() != len {
1412 low.len()
1413 } else {
1414 close.len()
1415 };
1416 return Err(WadError::OutputLengthMismatch { expected: len, got });
1417 }
1418 if high.iter().all(|x| x.is_nan())
1419 || low.iter().all(|x| x.is_nan())
1420 || close.iter().all(|x| x.is_nan())
1421 {
1422 return Err(WadError::AllValuesNaN);
1423 }
1424
1425 let chosen = match _kernel {
1426 Kernel::Auto => detect_best_kernel(),
1427 other => other,
1428 };
1429
1430 Ok((high, low, close, len, chosen))
1431}
1432
1433#[inline]
1434pub fn wad_into_slice(dst: &mut [f64], input: &WadInput, kern: Kernel) -> Result<(), WadError> {
1435 let (high, low, close, len, chosen) = wad_prepare(input, kern)?;
1436
1437 if dst.len() != len {
1438 return Err(WadError::OutputLengthMismatch {
1439 expected: len,
1440 got: dst.len(),
1441 });
1442 }
1443
1444 unsafe {
1445 match chosen {
1446 Kernel::Scalar | Kernel::ScalarBatch => wad_scalar(high, low, close, dst),
1447 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1448 Kernel::Avx2 | Kernel::Avx2Batch => wad_avx2(high, low, close, dst),
1449 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1450 Kernel::Avx512 | Kernel::Avx512Batch => wad_avx512(high, low, close, dst),
1451 _ => unreachable!(),
1452 }
1453 }
1454
1455 Ok(())
1456}
1457
/// Writes the WAD series for `input` into the caller-provided buffer `out`.
///
/// Thin wrapper that forwards to `wad_into_slice` with `Kernel::Auto`; the
/// result and every error come from that function unchanged.
///
/// Not compiled for wasm builds (`target_arch = "wasm32"` with the `wasm`
/// feature), where the name `wad_into` is used by the pointer-based export.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
pub fn wad_into(input: &WadInput, out: &mut [f64]) -> Result<(), WadError> {
    wad_into_slice(out, input, Kernel::Auto)
}
1463
/// Python binding `wad_cuda_dev`: runs the single-series CUDA WAD routine on
/// `f32` inputs and returns a device-array handle.
///
/// Raises `ValueError("CUDA not available")` when `cuda_available()` is
/// false, and `ValueError` with the underlying message when constructing
/// `CudaWad` or running `wad_series_dev` fails. Errors from viewing the
/// arrays as slices are propagated as-is.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, device_id=0))]
pub fn wad_cuda_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );

    // The device work runs inside `allow_threads`.
    let device_buf = py.allow_threads(|| {
        CudaWad::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .wad_series_dev(high, low, close)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, device_buf)?)
}
1491
/// Python binding `wad_cuda_batch_dev`: runs the CUDA batch WAD routine on
/// `f32` inputs and returns a device-array handle.
///
/// Raises `ValueError("CUDA not available")` when `cuda_available()` is
/// false, and `ValueError` with the underlying message when constructing
/// `CudaWad` or running `wad_batch_dev` fails. Errors from viewing the arrays
/// as slices are propagated as-is.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, device_id=0))]
pub fn wad_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (high, low, close) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        close_f32.as_slice()?,
    );

    // The device work runs inside `allow_threads`.
    let device_buf = py.allow_threads(|| {
        CudaWad::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .wad_batch_dev(high, low, close)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, device_buf)?)
}
1516
/// Python binding `wad_cuda_many_series_one_param_dev`: runs the CUDA
/// many-series WAD routine on three 2-D `f32` arrays of identical shape and
/// returns a device-array handle.
///
/// The array shape is read as `(rows, cols)` and forwarded to
/// `wad_many_series_one_param_time_major_dev` in the order `cols, rows`.
/// NOTE(review): presumably rows are time steps and columns are series
/// (time-major layout, per the `_tm` names) — confirm against `CudaWad`.
///
/// Raises `ValueError` when CUDA is unavailable, when the three shapes
/// differ, or when the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, device_id=0))]
pub fn wad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    close_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let dims = high_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);

    let shapes_match = low_tm_f32.shape() == dims && close_tm_f32.shape() == dims;
    if !shapes_match {
        return Err(PyValueError::new_err("high/low/close shapes must match"));
    }

    let (high, low, close) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    // The device work runs inside `allow_threads`.
    let device_buf = py.allow_threads(|| {
        CudaWad::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .wad_many_series_one_param_time_major_dev(high, low, close, cols, rows)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, device_buf)?)
}
1547
/// Python binding `wad`: computes the WAD series for three `f64` arrays and
/// returns it as a new 1-D NumPy array.
///
/// `kernel` is an optional kernel name checked by `validate_kernel`. Any
/// error from `wad_with_kernel` is raised as `ValueError` carrying the error's
/// message; errors from viewing the arrays as slices are propagated as-is.
#[cfg(feature = "python")]
#[pyfunction(name = "wad")]
#[pyo3(signature = (high, low, close, kernel=None))]
pub fn wad_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let input = WadInput::from_slices(high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    // The computation itself runs inside `allow_threads`.
    let computed = py.allow_threads(|| wad_with_kernel(&input, kern));

    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1571
/// Python class `WadStream`: wraps a `WadStream` so it can be fed one bar at
/// a time from Python.
#[cfg(feature = "python")]
#[pyclass(name = "WadStream")]
pub struct WadStreamPy {
    stream: WadStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl WadStreamPy {
    /// Creates a stream via `WadStream::try_new`; a construction error is
    /// raised as `ValueError` with the error's message.
    #[new]
    fn new() -> PyResult<Self> {
        match WadStream::try_new() {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Forwards one `(high, low, close)` bar to the wrapped stream and
    /// returns the value it produces.
    fn update(&mut self, high: f64, low: f64, close: f64) -> f64 {
        self.stream.update(high, low, close)
    }
}
1591
/// Python binding `wad_batch`: computes the batch WAD output and returns a
/// dict whose `"values"` entry is a `(1, len(high))` NumPy array.
///
/// The output array is allocated without initialisation and filled by
/// `wad_batch_inner_into`; if that call fails the error is raised as
/// `ValueError` and the array is never handed to Python.
#[cfg(feature = "python")]
#[pyfunction(name = "wad_batch")]
#[pyo3(signature = (high, low, close, kernel=None))]
pub fn wad_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use pyo3::types::PyDict;

    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    // A single output row with one column per input bar.
    let rows = 1usize;
    let cols = high_slice.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => {
            return Err(PyValueError::new_err(
                "wad_batch: size overflow in rows*cols",
            ))
        }
    };

    let out_arr = unsafe { numpy::PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    let filled = py.allow_threads(|| {
        wad_batch_inner_into(high_slice, low_slice, close_slice, kern, true, out_slice)
    });
    if let Err(e) = filled {
        return Err(PyValueError::new_err(e.to_string()));
    }

    let result = PyDict::new(py);
    result.set_item("values", out_arr.reshape((rows, cols))?)?;
    Ok(result)
}
1628
/// WASM export: computes the WAD series for `high`/`low`/`close` and returns
/// it as a freshly allocated vector of `high.len()` values.
///
/// Any error from `wad_into_slice` is returned as a `JsValue` string holding
/// the error's message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_js(high: &[f64], low: &[f64], close: &[f64]) -> Result<Vec<f64>, JsValue> {
    let input = WadInput::from_slices(high, low, close);
    let mut values = vec![0.0; high.len()];

    match wad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1641
/// WASM export: allocates an uninitialised buffer with capacity for `len`
/// `f64` values and returns its pointer without freeing it.
///
/// The allocation is intentionally leaked here; release it by passing the
/// pointer and the same `len` to `wad_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the vector's destructor so the allocation
    // outlives this call.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1650
/// WASM export: releases a buffer previously obtained from `wad_alloc`.
///
/// A null `ptr` is ignored.
///
/// # Safety contract (for the JS caller)
/// `ptr` must come from `wad_alloc(len)` with the same `len`, and must not be
/// used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the contract above, `ptr` owns an allocation of capacity
    // `len` made by `Vec<f64>`. `wad_alloc` never initialises the elements,
    // so the vector is rebuilt with length 0: `Vec::from_raw_parts` requires
    // the first `length` elements to be initialised, and the length does not
    // affect deallocation, which depends on the capacity only.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1660
/// WASM export: computes WAD from raw input pointers into `out_ptr`.
///
/// All four pointers must be non-null and valid for `len` `f64` values.
/// The output buffer may overlap any input buffer, fully or partially: in
/// that case the result is computed into a temporary vector and copied to
/// `out_ptr` afterwards, so no `&mut` slice over the output coexists with the
/// input slices while the kernel runs.
///
/// # Errors
/// Returns a `JsValue` string for a null pointer or for any error reported by
/// `wad_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wad_into"));
    }

    // SAFETY: the caller guarantees each pointer is valid for `len` elements
    // (see the function documentation); null pointers were rejected above.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        let input = WadInput::from_slices(high, low, close);

        // Half-open address range covered by the output buffer.
        let out_start = out_ptr as *const f64;
        let out_end = out_start.wrapping_add(len);
        let overlaps_out = |src: &[f64]| {
            let range = src.as_ptr_range();
            range.start < out_end && out_start < range.end
        };

        if overlaps_out(high) || overlaps_out(low) || overlaps_out(close) {
            let mut temp = vec![0.0; len];
            wad_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input slices are no longer read past this point.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            wad_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1699
/// WASM export: computes the batch WAD output from raw input pointers into
/// `out_ptr` and returns `1` on success.
///
/// All four pointers must be non-null and valid for `len` `f64` values.
/// As in `wad_into`, the output buffer may overlap an input buffer: the
/// result is then computed into a temporary vector and copied out afterwards,
/// so the kernel never writes through memory it is still reading.
///
/// # Errors
/// Returns a `JsValue` string for a null pointer or for any error reported by
/// `wad_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wad_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null() || low_ptr.is_null() || close_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to wad_batch_into"));
    }

    // SAFETY: the caller guarantees each pointer is valid for `len` elements
    // (see the function documentation); null pointers were rejected above.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        // Half-open address range covered by the output buffer.
        let out_start = out_ptr as *const f64;
        let out_end = out_start.wrapping_add(len);
        let overlaps_out = |src: &[f64]| {
            let range = src.as_ptr_range();
            range.start < out_end && out_start < range.end
        };

        if overlaps_out(high) || overlaps_out(low) || overlaps_out(close) {
            let mut scratch = vec![0.0; len];
            wad_batch_inner_into(
                high,
                low,
                close,
                detect_best_kernel(),
                false,
                scratch.as_mut_slice(),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input slices are no longer read past this point.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            wad_batch_inner_into(high, low, close, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(1)
    }
}
1722
/// Serde-serialisable batch configuration type for the WASM build.
///
/// NOTE(review): the only field is named `dummy`, and the visible batch
/// export ignores its `_config` argument, so this looks like a placeholder —
/// confirm before relying on it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WadBatchConfig {
    // Triple of `usize`; its meaning is not established by the visible code.
    pub dummy: (usize, usize, usize),
}
1728
/// Serde-serialisable result returned to JavaScript by the `wad_batch`
/// export: the batch output's values together with its row and column counts.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WadBatchJsOutput {
    // Output values copied from the batch result.
    // NOTE(review): presumably a row-major `rows * cols` matrix — confirm.
    pub values: Vec<f64>,
    // Row count reported by the batch result.
    pub rows: usize,
    // Column count reported by the batch result.
    pub cols: usize,
}
1736
/// WASM export `wad_batch`: runs the batch computation on the given series
/// and returns a `WadBatchJsOutput` converted to a `JsValue`.
///
/// `_config` is accepted but not read.
///
/// # Errors
/// Returns a `JsValue` string holding the message of a `wad_batch_inner`
/// error, or `"Serialization error: ..."` if the conversion to `JsValue`
/// fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wad_batch)]
pub fn wad_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    _config: JsValue,
) -> Result<JsValue, JsValue> {
    let batch = match wad_batch_inner(high, low, close, detect_best_kernel(), false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = WadBatchJsOutput {
        values: batch.values,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}