1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::CudaAd;
3use crate::utilities::data_loader::Candles;
4#[cfg(all(feature = "python", feature = "cuda"))]
5use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
6use crate::utilities::enums::Kernel;
7use crate::utilities::helpers::{
8 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, make_uninit_matrix,
9};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(all(feature = "python", feature = "cuda"))]
13use numpy::PyReadonlyArray1;
14#[cfg(feature = "python")]
15use pyo3::exceptions::PyValueError;
16#[cfg(feature = "python")]
17use pyo3::types::{PyDict, PyList, PyListMethods};
18#[cfg(feature = "python")]
19use pyo3::{pyfunction, Bound, PyResult, Python};
20#[cfg(not(target_arch = "wasm32"))]
21use rayon::prelude::*;
22use thiserror::Error;
23
24#[derive(Debug, Clone)]
25pub enum AdData<'a> {
26 Candles {
27 candles: &'a Candles,
28 },
29 Slices {
30 high: &'a [f64],
31 low: &'a [f64],
32 close: &'a [f64],
33 volume: &'a [f64],
34 },
35}
36
37#[derive(Debug, Clone, Default)]
38pub struct AdParams {}
39
40#[derive(Debug, Clone)]
41pub struct AdInput<'a> {
42 pub data: AdData<'a>,
43 pub params: AdParams,
44}
45
46impl<'a> AdInput<'a> {
47 #[inline]
48 pub fn from_candles(candles: &'a Candles, params: AdParams) -> Self {
49 Self {
50 data: AdData::Candles { candles },
51 params,
52 }
53 }
54
55 #[inline]
56 pub fn from_slices(
57 high: &'a [f64],
58 low: &'a [f64],
59 close: &'a [f64],
60 volume: &'a [f64],
61 params: AdParams,
62 ) -> Self {
63 Self {
64 data: AdData::Slices {
65 high,
66 low,
67 close,
68 volume,
69 },
70 params,
71 }
72 }
73
74 #[inline]
75 pub fn with_default_candles(candles: &'a Candles) -> Self {
76 Self::from_candles(candles, AdParams::default())
77 }
78}
79
80#[derive(Debug, Clone)]
81pub struct AdOutput {
82 pub values: Vec<f64>,
83}
84
85#[derive(Copy, Clone, Debug, Default)]
86pub struct AdBuilder {
87 kernel: Kernel,
88}
89
90impl AdBuilder {
91 #[inline(always)]
92 pub fn new() -> Self {
93 Self {
94 kernel: Kernel::Auto,
95 }
96 }
97
98 #[inline(always)]
99 pub fn kernel(mut self, k: Kernel) -> Self {
100 self.kernel = k;
101 self
102 }
103
104 #[inline(always)]
105 pub fn apply(self, c: &Candles) -> Result<AdOutput, AdError> {
106 let input = AdInput::from_candles(c, AdParams::default());
107 ad_with_kernel(&input, self.kernel)
108 }
109
110 #[inline(always)]
111 pub fn apply_slices(
112 self,
113 high: &[f64],
114 low: &[f64],
115 close: &[f64],
116 volume: &[f64],
117 ) -> Result<AdOutput, AdError> {
118 let input = AdInput::from_slices(high, low, close, volume, AdParams::default());
119 ad_with_kernel(&input, self.kernel)
120 }
121
122 #[inline(always)]
123 pub fn into_stream(self) -> Result<AdStream, AdError> {
124 AdStream::try_new()
125 }
126}
127
/// Errors produced by the Accumulation/Distribution (A/D) functions in this module.
#[derive(Debug, Error)]
pub enum AdError {
    /// A candle field (`high`, `low`, `close` or `volume`) could not be selected;
    /// carries the underlying error message.
    #[error("ad: candle field error: {0}")]
    CandleFieldError(String),
    /// The input series (or batch row lists) do not all have the same length.
    #[error(
        "ad: Data length mismatch: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
    )]
    DataLengthMismatch {
        high_len: usize,
        low_len: usize,
        close_len: usize,
        volume_len: usize,
    },
    /// A period argument is invalid for the data length. A/D itself takes no period.
    #[error("ad: invalid period: period={period}, data_len={data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("ad: output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Fewer valid (non-NaN) points than required.
    #[error("ad: not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The input series are empty.
    #[error("ad: empty input data")]
    EmptyInputData,
    /// Every input value is NaN.
    #[error("ad: all values are NaN")]
    AllValuesNaN,
    /// A parameter-sweep range described by `start`, `end` and `step` is invalid.
    #[error("ad: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: isize,
        end: isize,
        step: isize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("ad: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Other invalid input, e.g. a `rows * cols` overflow in batch mode.
    #[error("ad: invalid input: {0}")]
    InvalidInput(String),
}
162
163#[inline]
164pub fn ad(input: &AdInput) -> Result<AdOutput, AdError> {
165 ad_with_kernel(input, Kernel::Auto)
166}
167
168pub fn ad_with_kernel(input: &AdInput, kernel: Kernel) -> Result<AdOutput, AdError> {
169 let (high, low, close, volume): (&[f64], &[f64], &[f64], &[f64]) = match &input.data {
170 AdData::Candles { candles } => {
171 let high = candles
172 .select_candle_field("high")
173 .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
174 let low = candles
175 .select_candle_field("low")
176 .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
177 let close = candles
178 .select_candle_field("close")
179 .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
180 let volume = candles
181 .select_candle_field("volume")
182 .map_err(|e| AdError::CandleFieldError(e.to_string()))?;
183 (high, low, close, volume)
184 }
185 AdData::Slices {
186 high,
187 low,
188 close,
189 volume,
190 } => (*high, *low, *close, *volume),
191 };
192
193 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
194 return Err(AdError::DataLengthMismatch {
195 high_len: high.len(),
196 low_len: low.len(),
197 close_len: close.len(),
198 volume_len: volume.len(),
199 });
200 }
201
202 let size = high.len();
203 if size == 0 {
204 return Err(AdError::EmptyInputData);
205 }
206
207 let mut chosen = match kernel {
208 Kernel::Auto => detect_best_kernel(),
209 k => k,
210 };
211
212 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
213 if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
214 chosen = Kernel::Avx2;
215 }
216
217 let mut out = alloc_with_nan_prefix(size, 0);
218
219 unsafe {
220 match chosen {
221 Kernel::Scalar | Kernel::ScalarBatch => ad_scalar(high, low, close, volume, &mut out),
222 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
223 Kernel::Avx2 | Kernel::Avx2Batch => ad_avx2(high, low, close, volume, &mut out),
224 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
225 Kernel::Avx512 | Kernel::Avx512Batch => ad_avx512(high, low, close, volume, &mut out),
226 _ => unreachable!(),
227 }
228 }
229 Ok(AdOutput { values: out })
230}
231
232#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
233#[inline]
234pub fn ad_into(input: &AdInput, out: &mut [f64]) -> Result<(), AdError> {
235 ad_into_slice(out, input, Kernel::Auto)
236}
237
238pub fn ad_into_slice(dst: &mut [f64], input: &AdInput, kern: Kernel) -> Result<(), AdError> {
239 let (high, low, close, volume) = match &input.data {
240 AdData::Candles { candles, .. } => (
241 &candles.high[..],
242 &candles.low[..],
243 &candles.close[..],
244 &candles.volume[..],
245 ),
246 AdData::Slices {
247 high,
248 low,
249 close,
250 volume,
251 } => (*high, *low, *close, *volume),
252 };
253
254 if high.is_empty() {
255 return Err(AdError::EmptyInputData);
256 }
257
258 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
259 return Err(AdError::DataLengthMismatch {
260 high_len: high.len(),
261 low_len: low.len(),
262 close_len: close.len(),
263 volume_len: volume.len(),
264 });
265 }
266
267 if dst.len() != high.len() {
268 return Err(AdError::OutputLengthMismatch {
269 expected: high.len(),
270 got: dst.len(),
271 });
272 }
273
274 match kern {
275 Kernel::Auto => {
276 let mut k = detect_best_kernel();
277 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
278 if matches!(k, Kernel::Avx512) {
279 k = Kernel::Avx2;
280 }
281 match k {
282 Kernel::Scalar => ad_scalar(high, low, close, volume, dst),
283 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284 Kernel::Avx2 => ad_avx2(high, low, close, volume, dst),
285 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
286 Kernel::Avx512 => ad_avx512(high, low, close, volume, dst),
287 _ => ad_scalar(high, low, close, volume, dst),
288 }
289 }
290 Kernel::Scalar => ad_scalar(high, low, close, volume, dst),
291 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
292 Kernel::Avx2 => ad_avx2(high, low, close, volume, dst),
293 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
294 Kernel::Avx512 => ad_avx512(high, low, close, volume, dst),
295 _ => ad_scalar(high, low, close, volume, dst),
296 }
297
298 Ok(())
299}
300
/// Scalar Accumulation/Distribution kernel.
///
/// Writes the running sum of `((close - low) - (high - close)) / (high - low) * volume` into
/// `out`. Bars whose high equals their low leave the sum unchanged. Only the common prefix
/// of all five slices is processed; callers are expected to pass equal lengths (checked in
/// debug builds).
#[inline]
pub fn ad_scalar(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    debug_assert_eq!(high.len(), low.len());
    debug_assert_eq!(high.len(), close.len());
    debug_assert_eq!(high.len(), volume.len());
    debug_assert_eq!(high.len(), out.len());

    let bars = high.iter().zip(low).zip(close).zip(volume);
    let mut running = 0.0f64;
    for (slot, (((&hi, &lo), &cl), &vol)) in out.iter_mut().zip(bars) {
        let range = hi - lo;
        if range != 0.0 {
            let multiplier = ((cl - lo) - (hi - cl)) / range;
            running += multiplier * vol;
        }
        *slot = running;
    }
}
324
/// AVX2 Accumulation/Distribution kernel.
///
/// Falls back to [`ad_scalar`] when the running CPU lacks AVX2, so this safe function is
/// sound on any x86_64 machine. Like `ad_scalar`, only the common prefix of all five slices
/// is processed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx2(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    if !is_x86_feature_detected!("avx2") {
        return ad_scalar(high, low, close, volume, out);
    }
    // SAFETY: AVX2 support was verified at runtime just above.
    unsafe { ad_avx2_inner(high, low, close, volume, out) }
}

/// Vectorised body of [`ad_avx2`].
///
/// Computes four money-flow volumes per iteration, then adds them to the running sum one at a
/// time, so rounding matches the scalar kernel.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn ad_avx2_inner(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    use core::arch::x86_64::*;

    // Bound every raw-pointer access by the shortest slice (same coverage as `ad_scalar`'s zip).
    let n = high
        .len()
        .min(low.len())
        .min(close.len())
        .min(volume.len())
        .min(out.len());
    let h = high.as_ptr();
    let l = low.as_ptr();
    let c = close.as_ptr();
    let v = volume.as_ptr();
    let o = out.as_mut_ptr();

    let zero = _mm256_setzero_pd();
    let mut lanes = [0.0f64; 4];
    let mut base = 0.0f64;
    let mut i = 0usize;

    while i + 4 <= n {
        let hv = _mm256_loadu_pd(h.add(i));
        let lv = _mm256_loadu_pd(l.add(i));
        let cv = _mm256_loadu_pd(c.add(i));
        let vv = _mm256_loadu_pd(v.add(i));

        let hl = _mm256_sub_pd(hv, lv);
        let num = _mm256_sub_pd(_mm256_sub_pd(cv, lv), _mm256_sub_pd(hv, cv));
        let mfm = _mm256_div_pd(num, hl);
        let mfv_unmasked = _mm256_mul_pd(mfm, vv);

        // Unordered not-equal is true for NaN ranges, matching the scalar `hl != 0.0` test
        // (the previous ordered predicate silently dropped NaN bars). Lanes with a zero range
        // are cleared to +0.0 and leave the sum unchanged.
        let mask = _mm256_cmp_pd(hl, zero, _CMP_NEQ_UQ);
        let mfv = _mm256_and_pd(mfv_unmasked, mask);

        _mm256_storeu_pd(lanes.as_mut_ptr(), mfv);
        for (k, &x) in lanes.iter().enumerate() {
            base += x;
            *o.add(i + k) = base;
        }

        i += 4;
    }

    while i < n {
        let hi = *h.add(i);
        let lo = *l.add(i);
        let cl = *c.add(i);
        let vo = *v.add(i);
        let hl = hi - lo;
        if hl != 0.0 {
            let num = (cl - lo) - (hi - cl);
            base += (num / hl) * vo;
        }
        *o.add(i) = base;
        i += 1;
    }
}
397
/// AVX-512 Accumulation/Distribution kernel.
///
/// Falls back to [`ad_scalar`] when the running CPU lacks AVX-512F, so this safe function is
/// sound on any x86_64 machine. Like `ad_scalar`, only the common prefix of all five slices
/// is processed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    if !is_x86_feature_detected!("avx512f") {
        return ad_scalar(high, low, close, volume, out);
    }
    // SAFETY: AVX-512F support was verified at runtime just above.
    unsafe { ad_avx512_inner(high, low, close, volume, out) }
}

/// Vectorised body of [`ad_avx512`].
///
/// Computes eight money-flow volumes per iteration, then adds them to the running sum one at
/// a time, so rounding matches the scalar kernel.
///
/// # Safety
/// The caller must ensure the CPU supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn ad_avx512_inner(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    // Bound every raw-pointer access by the shortest slice (same coverage as `ad_scalar`'s zip).
    let n = high
        .len()
        .min(low.len())
        .min(close.len())
        .min(volume.len())
        .min(out.len());
    let h = high.as_ptr();
    let l = low.as_ptr();
    let c = close.as_ptr();
    let v = volume.as_ptr();
    let o = out.as_mut_ptr();

    let zero = _mm512_setzero_pd();
    let mut lanes = [0.0f64; 8];
    let mut base = 0.0f64;
    let mut i = 0usize;

    while i + 8 <= n {
        let hv = _mm512_loadu_pd(h.add(i));
        let lv = _mm512_loadu_pd(l.add(i));
        let cv = _mm512_loadu_pd(c.add(i));
        let vv = _mm512_loadu_pd(v.add(i));

        let hl = _mm512_sub_pd(hv, lv);
        let num = _mm512_sub_pd(_mm512_sub_pd(cv, lv), _mm512_sub_pd(hv, cv));
        let mfm = _mm512_div_pd(num, hl);
        let mfv_unmasked = _mm512_mul_pd(mfm, vv);

        // `cmpneq` is the unordered not-equal predicate (true for NaN), matching the scalar
        // `hl != 0.0` test. Zero-range lanes become +0.0 and leave the sum unchanged.
        let mask = _mm512_cmpneq_pd_mask(hl, zero);
        let mfv = _mm512_maskz_mov_pd(mask, mfv_unmasked);

        _mm512_storeu_pd(lanes.as_mut_ptr(), mfv);
        for (k, &x) in lanes.iter().enumerate() {
            base += x;
            *o.add(i + k) = base;
        }

        i += 8;
    }

    while i < n {
        let hi = *h.add(i);
        let lo = *l.add(i);
        let cl = *c.add(i);
        let vo = *v.add(i);
        let hl = hi - lo;
        if hl != 0.0 {
            let num = (cl - lo) - (hi - cl);
            base += (num / hl) * vo;
        }
        *o.add(i) = base;
        i += 1;
    }
}
493
/// AVX-512 entry point for short series; currently identical to [`ad_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512_short(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    ad_avx512(high, low, close, volume, out);
}

/// AVX-512 entry point for long series; currently identical to [`ad_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ad_avx512_long(high: &[f64], low: &[f64], close: &[f64], volume: &[f64], out: &mut [f64]) {
    ad_avx512(high, low, close, volume, out);
}
505
506#[inline]
507pub fn ad_batch_with_kernel(data: &AdBatchInput, k: Kernel) -> Result<AdBatchOutput, AdError> {
508 let mut kernel = match k {
509 Kernel::Auto => detect_best_batch_kernel(),
510 other if other.is_batch() => other,
511 other => return Err(AdError::InvalidKernelForBatch(other)),
512 };
513 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
514 if matches!(k, Kernel::Auto) && matches!(kernel, Kernel::Avx512Batch) {
515 kernel = Kernel::Avx2Batch;
516 }
517
518 let simd = match kernel {
519 Kernel::Avx512Batch => Kernel::Avx512,
520 Kernel::Avx2Batch => Kernel::Avx2,
521 Kernel::ScalarBatch => Kernel::Scalar,
522 _ => unreachable!(),
523 };
524 ad_batch_par_slice(data, simd)
525}
526
/// A batch of OHLCV rows: index `r` of each list is one series, and every row is validated
/// to share the same length.
#[derive(Clone, Debug)]
pub struct AdBatchInput<'a> {
    pub highs: &'a [&'a [f64]],
    pub lows: &'a [&'a [f64]],
    pub closes: &'a [&'a [f64]],
    pub volumes: &'a [&'a [f64]],
}

/// Batch result stored as a flat row-major `rows x cols` matrix.
#[derive(Clone, Debug)]
pub struct AdBatchOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
541
542#[inline(always)]
543pub fn ad_batch_slice(data: &AdBatchInput, kern: Kernel) -> Result<AdBatchOutput, AdError> {
544 ad_batch_inner(data, kern, false)
545}
546
547#[inline(always)]
548pub fn ad_batch_par_slice(data: &AdBatchInput, kern: Kernel) -> Result<AdBatchOutput, AdError> {
549 ad_batch_inner(data, kern, true)
550}
551
552fn ad_batch_inner(
553 data: &AdBatchInput,
554 kern: Kernel,
555 parallel: bool,
556) -> Result<AdBatchOutput, AdError> {
557 let rows = data.highs.len();
558 let cols = if rows > 0 { data.highs[0].len() } else { 0 };
559 let len = rows
560 .checked_mul(cols)
561 .ok_or_else(|| AdError::InvalidInput("rows*cols overflow".into()))?;
562
563 let mut buf_mu = make_uninit_matrix(rows, cols);
564 let values = unsafe {
565 let ptr = buf_mu.as_mut_ptr() as *mut f64;
566 let slice = std::slice::from_raw_parts_mut(ptr, len);
567
568 ad_batch_inner_into(data, kern, parallel, slice)?;
569
570 Vec::from_raw_parts(ptr, len, len)
571 };
572 std::mem::forget(buf_mu);
573
574 Ok(AdBatchOutput { values, rows, cols })
575}
576
577fn ad_batch_inner_into(
578 data: &AdBatchInput,
579 kern: Kernel,
580 parallel: bool,
581 out: &mut [f64],
582) -> Result<(), AdError> {
583 let rows = data.highs.len();
584 let cols = if rows > 0 { data.highs[0].len() } else { 0 };
585
586 if data.lows.len() != rows || data.closes.len() != rows || data.volumes.len() != rows {
587 return Err(AdError::DataLengthMismatch {
588 high_len: data.highs.len(),
589 low_len: data.lows.len(),
590 close_len: data.closes.len(),
591 volume_len: data.volumes.len(),
592 });
593 }
594
595 for row in 0..rows {
596 let h_len = data.highs[row].len();
597 let l_len = data.lows[row].len();
598 let c_len = data.closes[row].len();
599 let v_len = data.volumes[row].len();
600
601 if h_len != cols || l_len != cols || c_len != cols || v_len != cols {
602 return Err(AdError::DataLengthMismatch {
603 high_len: h_len,
604 low_len: l_len,
605 close_len: c_len,
606 volume_len: v_len,
607 });
608 }
609 }
610
611 let expected = rows
612 .checked_mul(cols)
613 .ok_or_else(|| AdError::InvalidInput("rows*cols overflow".into()))?;
614 if out.len() != expected {
615 return Err(AdError::OutputLengthMismatch {
616 expected,
617 got: out.len(),
618 });
619 }
620
621 let mut actual = match kern {
622 Kernel::Auto => detect_best_batch_kernel(),
623 k => k,
624 };
625 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
626 if matches!(kern, Kernel::Auto) && matches!(actual, Kernel::Avx512Batch) {
627 actual = Kernel::Avx2Batch;
628 }
629
630 let do_row = |row: usize, dst: &mut [f64]| unsafe {
631 match actual {
632 Kernel::Scalar | Kernel::ScalarBatch => ad_row_scalar(
633 data.highs[row],
634 data.lows[row],
635 data.closes[row],
636 data.volumes[row],
637 dst,
638 ),
639 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
640 Kernel::Avx2 | Kernel::Avx2Batch => ad_row_avx2(
641 data.highs[row],
642 data.lows[row],
643 data.closes[row],
644 data.volumes[row],
645 dst,
646 ),
647 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
648 Kernel::Avx512 | Kernel::Avx512Batch => ad_row_avx512(
649 data.highs[row],
650 data.lows[row],
651 data.closes[row],
652 data.volumes[row],
653 dst,
654 ),
655 _ => ad_row_scalar(
656 data.highs[row],
657 data.lows[row],
658 data.closes[row],
659 data.volumes[row],
660 dst,
661 ),
662 }
663 };
664
665 if parallel {
666 #[cfg(not(target_arch = "wasm32"))]
667 {
668 use rayon::prelude::*;
669 out.par_chunks_mut(cols)
670 .enumerate()
671 .for_each(|(r, s)| do_row(r, s));
672 }
673 #[cfg(target_arch = "wasm32")]
674 {
675 for (r, s) in out.chunks_mut(cols).enumerate() {
676 do_row(r, s);
677 }
678 }
679 } else {
680 for (r, s) in out.chunks_mut(cols).enumerate() {
681 do_row(r, s);
682 }
683 }
684
685 Ok(())
686}
687
688#[inline(always)]
689pub unsafe fn ad_row_scalar(
690 high: &[f64],
691 low: &[f64],
692 close: &[f64],
693 volume: &[f64],
694 out: &mut [f64],
695) {
696 ad_scalar(high, low, close, volume, out)
697}
698
699#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
700#[inline(always)]
701pub unsafe fn ad_row_avx2(
702 high: &[f64],
703 low: &[f64],
704 close: &[f64],
705 volume: &[f64],
706 out: &mut [f64],
707) {
708 ad_avx2(high, low, close, volume, out)
709}
710
711#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
712#[inline(always)]
713pub unsafe fn ad_row_avx512(
714 high: &[f64],
715 low: &[f64],
716 close: &[f64],
717 volume: &[f64],
718 out: &mut [f64],
719) {
720 ad_avx512(high, low, close, volume, out)
721}
722
723#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
724#[inline(always)]
725pub unsafe fn ad_row_avx512_short(
726 high: &[f64],
727 low: &[f64],
728 close: &[f64],
729 volume: &[f64],
730 out: &mut [f64],
731) {
732 ad_avx512(high, low, close, volume, out)
733}
734
735#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
736#[inline(always)]
737pub unsafe fn ad_row_avx512_long(
738 high: &[f64],
739 low: &[f64],
740 close: &[f64],
741 volume: &[f64],
742 out: &mut [f64],
743) {
744 ad_avx512(high, low, close, volume, out)
745}
746
747#[derive(Debug, Clone)]
748pub struct AdStream {
749 sum: f64,
750}
751
752impl AdStream {
753 #[inline(always)]
754 pub fn try_new() -> Result<Self, AdError> {
755 Ok(Self { sum: 0.0 })
756 }
757
758 #[inline(always)]
759 pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> f64 {
760 if volume == 0.0 {
761 return self.sum;
762 }
763
764 let hl = high - low;
765 if hl != 0.0 {
766 let num = (close - low) - (high - close);
767
768 self.sum += (num / hl) * volume;
769 }
770 self.sum
771 }
772}
773
774#[cfg(all(feature = "python", feature = "cuda"))]
775use cust::context::Context;
776#[cfg(all(feature = "python", feature = "cuda"))]
777use cust::memory::DeviceBuffer;
778#[cfg(all(feature = "python", feature = "cuda"))]
779use std::sync::Arc;
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python handle to a row-major `rows x cols` f32 buffer that lives on a CUDA device.
///
/// The buffer can be exposed through `__cuda_array_interface__`, which borrows it, or through
/// `__dlpack__`, which moves it out; after that `buf` is `None`.
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AdDeviceArrayF32Py {
    /// Device allocation; `None` once it has been handed off via `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    /// Never read here. Presumably held so the CUDA context outlives `buf` — confirm.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AdDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the buffer.
    ///
    /// Reports:
    /// - shape `(rows, cols)`;
    /// - little-endian f32 (`"<f4"`);
    /// - C-contiguous byte strides;
    /// - `data` as `(device pointer, read-only = false)`.
    ///
    /// Raises `ValueError` if the buffer was already exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        // Row-major strides in bytes: one row = `cols` f32s, one element = one f32.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`; 2 is DLPack's CUDA device type.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership out of this handle.
    ///
    /// - If `dl_device` is given as an `(int, int)` tuple and differs from this buffer's
    ///   device, a `ValueError` is raised, since cross-device copies are not implemented.
    ///   A `dl_device` that does not extract as such a tuple is ignored.
    /// - `stream` is accepted but ignored.
    /// - Calling this a second time raises `ValueError`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // No stream synchronisation is performed.
        let _ = stream;

        // Move the buffer out so it can only be exported once.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
#[cfg(all(feature = "python", feature = "cuda"))]
/// Computes the A/D line for one f32 series on a CUDA device.
///
/// The result is returned as a device-resident array handle; it is not copied back to the host.
///
/// Raises `ValueError` in these cases:
/// - CUDA is unavailable;
/// - an input array cannot be viewed as a contiguous slice;
/// - `CudaAd` construction fails;
/// - the device computation fails.
#[pyfunction(name = "ad_cuda_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, volume_f32, device_id=0))]
pub fn ad_cuda_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    volume_f32: PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<AdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    // `as_slice` fails for non-contiguous arrays; the error propagates to Python.
    let high = high_f32.as_slice()?;
    let low = low_f32.as_slice()?;
    let close = close_f32.as_slice()?;
    let volume = volume_f32.as_slice()?;

    // Device setup and the kernel run with the GIL released.
    let (buf, rows, cols, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .ad_series_dev(high, low, close, volume)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The context handle is stored alongside the buffer in the returned object.
        let ctx = cuda.context_arc();
        Ok::<_, pyo3::PyErr>((out.buf, out.rows, out.cols, ctx, cuda.device_id()))
    })?;

    Ok(AdDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev_id,
    })
}
900
#[cfg(all(feature = "python", feature = "cuda"))]
/// Computes A/D for many series at once on a CUDA device.
///
/// Inputs are flattened f32 buffers described by `cols` and `rows`. The `_tm` naming suggests
/// a time-major layout (presumably index = t * cols + series) — confirm against
/// `CudaAd::ad_many_series_one_param_time_major_dev`.
///
/// The result stays on the device and is wrapped in an [`AdDeviceArrayF32Py`] whose shape is
/// reported by the CUDA wrapper.
///
/// Raises `ValueError` if CUDA is unavailable, an input is not contiguous, or the device call
/// fails.
#[pyfunction(name = "ad_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, volume_tm_f32, cols, rows, device_id=0))]
pub fn ad_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    volume_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    device_id: usize,
) -> PyResult<AdDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    // `as_slice` fails for non-contiguous arrays; the error propagates to Python.
    let high_tm = high_tm_f32.as_slice()?;
    let low_tm = low_tm_f32.as_slice()?;
    let close_tm = close_tm_f32.as_slice()?;
    let volume_tm = volume_tm_f32.as_slice()?;

    // Device setup and the kernel run with the GIL released.
    let (buf, r_out, c_out, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaAd::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .ad_many_series_one_param_time_major_dev(
                high_tm, low_tm, close_tm, volume_tm, cols, rows,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        Ok::<_, pyo3::PyErr>((out.buf, out.rows, out.cols, ctx, cuda.device_id()))
    })?;

    // Output dimensions come from the device result, not from the `rows`/`cols` arguments.
    Ok(AdDeviceArrayF32Py {
        buf: Some(buf),
        rows: r_out,
        cols: c_out,
        _ctx: ctx,
        device_id: dev_id,
    })
}
942
943#[derive(Clone, Debug, Default)]
944pub struct AdBatchBuilder {
945 pub kernel: Kernel,
946}
947
948impl AdBatchBuilder {
949 pub fn new() -> Self {
950 Self {
951 kernel: Kernel::Auto,
952 }
953 }
954 pub fn kernel(mut self, k: Kernel) -> Self {
955 self.kernel = k;
956 self
957 }
958
959 pub fn apply_slices(
960 self,
961 highs: &[&[f64]],
962 lows: &[&[f64]],
963 closes: &[&[f64]],
964 volumes: &[&[f64]],
965 ) -> Result<AdBatchOutput, AdError> {
966 let batch = AdBatchInput {
967 highs,
968 lows,
969 closes,
970 volumes,
971 };
972 ad_batch_with_kernel(&batch, self.kernel)
973 }
974}
975
976#[cfg(feature = "python")]
977use numpy::{IntoPyArray, PyArray1};
978#[cfg(feature = "python")]
979use pyo3::prelude::*;
980#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
981use wasm_bindgen::prelude::*;
982
/// Python binding: computes the A/D line for one series and returns a NumPy array.
///
/// Raises `ValueError` if any input is empty, the kernel name is invalid, or the computation
/// fails. The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "ad")]
#[pyo3(signature = (high, low, close, volume, kernel=None))]
pub fn ad_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    volume: numpy::PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (h, l, c, v) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );

    let any_empty = [h, l, c, v].iter().any(|series| series.is_empty());
    if any_empty {
        return Err(PyValueError::new_err("Not enough data"));
    }

    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, false)?;
    let input = AdInput::from_slices(h, l, c, v, AdParams::default());

    let values: Vec<f64> = py
        .allow_threads(|| ad_with_kernel(&input, kern).map(|output| output.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1026
/// Python wrapper around [`AdStream`] for bar-by-bar A/D updates.
#[cfg(feature = "python")]
#[pyclass(name = "AdStream")]
pub struct AdStreamPy {
    stream: AdStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AdStreamPy {
    /// Creates a stream whose running sum starts at zero.
    #[new]
    fn new() -> PyResult<Self> {
        AdStream::try_new()
            .map(|stream| AdStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns the updated cumulative A/D value.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> f64 {
        let inner = &mut self.stream;
        inner.update(high, low, close, volume)
    }
}
1046
/// Python binding: batch A/D over lists of equally sized NumPy arrays.
///
/// Returns a dict with:
/// - `values`: a `(rows, cols)` array;
/// - `rows`;
/// - `cols`.
///
/// Raises `ValueError` when:
/// - the list lengths differ, or a row's arrays differ in length;
/// - an array is not contiguous (previously this panicked);
/// - the batch is empty (previously this panicked);
/// - the kernel name is invalid;
/// - the computation fails.
///
/// The kernel is passed through unresolved, so `Auto` gets the same AVX-512 to AVX2 downgrade
/// as the Rust batch API.
#[cfg(feature = "python")]
#[pyfunction(name = "ad_batch")]
#[pyo3(signature = (highs, lows, closes, volumes, kernel=None))]
pub fn ad_batch_py<'py>(
    py: Python<'py>,
    highs: &Bound<'py, PyList>,
    lows: &Bound<'py, PyList>,
    closes: &Bound<'py, PyList>,
    volumes: &Bound<'py, PyList>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
    use pyo3::types::PyDict;

    let rows = highs.len();
    if lows.len() != rows || closes.len() != rows || volumes.len() != rows {
        return Err(PyValueError::new_err(
            "All input lists must have the same length",
        ));
    }

    let mut high_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut low_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut close_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);
    let mut volume_arrays: Vec<PyReadonlyArray1<f64>> = Vec::with_capacity(rows);

    for i in 0..rows {
        let h = highs.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let l = lows.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let c = closes.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;
        let v = volumes.get_item(i)?.extract::<PyReadonlyArray1<f64>>()?;

        let n = h.len()?;
        if l.len()? != n || c.len()? != n || v.len()? != n {
            return Err(PyValueError::new_err(
                "Rows must have equal lengths across OHLCV arrays",
            ));
        }
        high_arrays.push(h);
        low_arrays.push(l);
        close_arrays.push(c);
        volume_arrays.push(v);
    }

    // `as_slice` fails for non-contiguous (e.g. strided) arrays; raise instead of panicking.
    let high_slices = high_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let low_slices = low_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let close_slices = close_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;
    let volume_slices = volume_arrays
        .iter()
        .map(|a| a.as_slice())
        .collect::<Result<Vec<&[f64]>, _>>()?;

    let cols = high_slices.first().map_or(0, |s| s.len());
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in ad_batch"))?;
    // An empty batch cannot be split into rows (zero-sized chunks panic downstream).
    if total == 0 {
        return Err(PyValueError::new_err(AdError::EmptyInputData.to_string()));
    }
    // SAFETY: the uninitialised array is fully written by `ad_batch_inner_into` before it is
    // exposed. If that call fails, the array is dropped without being returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = crate::utilities::kernel_validation::validate_kernel(kernel, true)?;

    py.allow_threads(|| -> Result<(), AdError> {
        let batch_input = AdBatchInput {
            highs: &high_slices,
            lows: &low_slices,
            closes: &close_slices,
            volumes: &volume_slices,
        };
        // `ad_batch_inner_into` resolves `Kernel::Auto` itself, including the AVX-512
        // downgrade, so the validated kernel is passed through unchanged.
        ad_batch_inner_into(&batch_input, kern, true, out_slice)
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1131
/// wasm binding: computes the A/D line for one series.
///
/// Errors with `"Not enough data"` if any input is empty; otherwise forwards the library
/// error message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
) -> Result<Vec<f64>, JsValue> {
    let any_empty = [high, low, close, volume]
        .iter()
        .any(|series| series.is_empty());
    if any_empty {
        return Err(JsValue::from_str("Not enough data"));
    }

    let input = AdInput::from_slices(high, low, close, volume, AdParams::default());
    let mut values = vec![0.0; high.len()];

    match ad_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1152
/// wasm binding: batch A/D over flat row-major `rows x cols` buffers.
///
/// `cols` is derived as `highs_flat.len() / rows`, and every buffer must hold exactly
/// `rows * cols` values. Uses the scalar batch kernel and returns the flat result.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_batch_js(
    highs_flat: &[f64],
    lows_flat: &[f64],
    closes_flat: &[f64],
    volumes_flat: &[f64],
    rows: usize,
) -> Result<Vec<f64>, JsValue> {
    if rows == 0 || highs_flat.is_empty() {
        return Err(JsValue::from_str("Empty input data"));
    }

    let cols = highs_flat.len() / rows;
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    let all_sized = [highs_flat, lows_flat, closes_flat, volumes_flat]
        .iter()
        .all(|flat| flat.len() == expected);
    if !all_sized {
        return Err(JsValue::from_str(
            "Input arrays must have rows*cols elements",
        ));
    }

    // `cols >= 1` here: `highs_flat` is non-empty and exactly `rows * cols` long, so each
    // buffer splits into exactly `rows` chunks.
    let high_rows: Vec<&[f64]> = highs_flat.chunks_exact(cols).collect();
    let low_rows: Vec<&[f64]> = lows_flat.chunks_exact(cols).collect();
    let close_rows: Vec<&[f64]> = closes_flat.chunks_exact(cols).collect();
    let volume_rows: Vec<&[f64]> = volumes_flat.chunks_exact(cols).collect();

    let batch_input = AdBatchInput {
        highs: &high_rows,
        lows: &low_rows,
        closes: &close_rows,
        volumes: &volume_rows,
    };

    ad_batch_with_kernel(&batch_input, Kernel::ScalarBatch)
        .map(|o| o.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1205
/// Returns the batch shape as `[rows, cols]`, converted to `f64` for JS
/// (presumably so callers can reshape a flat batch result — confirm with JS side).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_batch_metadata_js(rows: usize, cols: usize) -> Vec<f64> {
    [rows, cols].iter().map(|&dim| dim as f64).collect()
}
1211
/// Allocates a buffer of `len` `f64` values in WASM linear memory and returns
/// a pointer to it; release it with `ad_free(ptr, len)`.
///
/// The buffer is zero-filled. `ad_free` rebuilds the allocation as a `Vec`, and
/// JS callers may read the buffer before writing it. Handing out uninitialized
/// memory would violate `Vec::from_raw_parts`'s initialization requirement and
/// expose indeterminate values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_alloc(len: usize) -> *mut f64 {
    let mut buf = vec![0.0f64; len];
    let ptr = buf.as_mut_ptr();
    // Ownership passes to the caller; `ad_free` reclaims it.
    std::mem::forget(buf);
    ptr
}
1220
/// Releases a buffer previously returned by `ad_alloc`.
///
/// `len` must be the same value that was passed to `ad_alloc`. Null pointers
/// are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `ad_alloc`, which leaked a `Vec<f64>` with
    // capacity `len`. Rebuilding it with length 0 frees exactly that allocation
    // without asserting that any element is initialized (`f64` has no drop glue,
    // so nothing needs to be read).
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1230
/// Zero-copy WASM entry point: writes the AD line for `len` samples to `out_ptr`.
///
/// Every pointer must reference at least `len` contiguous, initialized `f64`
/// values in WASM linear memory, valid for the duration of the call.
///
/// The output buffer may overlap any input, partially or fully. Overlap is
/// detected by address range, not just pointer equality. In that case the result
/// is computed into a scratch buffer first, so an exclusive output slice never
/// coexists with a shared slice over the same bytes.
///
/// # Errors
/// Returns `"Null pointer provided"` when any pointer is null, otherwise the
/// `Display` text of the error reported by `ad_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte range [start, end) covered by a `len`-element f64 buffer at `p`.
    let span_bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let byte_range = |p: *const f64| {
        let start = p as usize;
        (start, start.saturating_add(span_bytes))
    };
    let (out_start, out_end) = byte_range(out_ptr as *const f64);
    let aliased = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| {
            let (start, end) = byte_range(p);
            start < out_end && out_start < end
        });

    // SAFETY: all pointers are non-null (checked above) and, per the caller
    // contract, each references `len` initialized f64 values.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let input = AdInput::from_slices(high, low, close, volume, AdParams::default());

    if aliased {
        let mut scratch = vec![0.0; len];
        ad_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the input views are not used past this point, and `out_ptr` is
        // non-null and valid for `len` writes per the caller contract.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null, valid for `len` writes, and its byte range
        // is disjoint from every input range, so this exclusive slice aliases
        // nothing that is still borrowed.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ad_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1277
1278#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1279use serde::{Deserialize, Serialize};
1280
/// Serializable result returned by the WASM `ad_batch` export.
///
/// `values` holds the AD output for every row, flattened row-major
/// (`rows * cols` entries). Field names and order form the JS-visible shape,
/// so changing them changes the serialized object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AdBatchJsOutput {
    /// Flattened output matrix, row-major.
    pub values: Vec<f64>,
    /// Number of input series.
    pub rows: usize,
    /// Number of samples per series.
    pub cols: usize,
}
1288
/// WASM `ad_batch` export: runs AD over `rows` series packed row-major in flat
/// arrays and returns a serialized `{ values, rows, cols }` object.
///
/// `cols` is derived as `highs_flat.len() / rows`. Every flat input, including
/// `highs_flat` itself, must then contain exactly `rows * cols` values.
///
/// # Errors
/// Rejects `rows == 0`, an empty `highs_flat`, and any length mismatch.
/// Batch and serialization failures are forwarded as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "ad_batch")]
pub fn ad_batch_unified_js(
    highs_flat: &[f64],
    lows_flat: &[f64],
    closes_flat: &[f64],
    volumes_flat: &[f64],
    rows: usize,
) -> Result<JsValue, JsValue> {
    if rows == 0 {
        return Err(JsValue::from_str("rows must be > 0"));
    }
    if highs_flat.is_empty() {
        return Err(JsValue::from_str("empty inputs"));
    }
    let cols = highs_flat.len() / rows;
    let check = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
    // `highs_flat` must be validated as well. When its length is not a multiple
    // of `rows`, the floor division above would silently drop its trailing
    // samples while the other arrays still passed the check.
    if highs_flat.len() != check
        || lows_flat.len() != check
        || closes_flat.len() != check
        || volumes_flat.len() != check
    {
        return Err(JsValue::from_str(
            "Input arrays must have rows*cols elements",
        ));
    }

    let mut highs = Vec::with_capacity(rows);
    let mut lows = Vec::with_capacity(rows);
    let mut closes = Vec::with_capacity(rows);
    let mut volumes = Vec::with_capacity(rows);
    for r in 0..rows {
        let s = r * cols;
        let e = s + cols;
        highs.push(&highs_flat[s..e]);
        lows.push(&lows_flat[s..e]);
        closes.push(&closes_flat[s..e]);
        volumes.push(&volumes_flat[s..e]);
    }

    let batch = AdBatchInput {
        highs: &highs,
        lows: &lows,
        closes: &closes,
        volumes: &volumes,
    };
    let out = ad_batch_with_kernel(&batch, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let packed = AdBatchJsOutput {
        values: out.values,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&packed)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1344
/// Zero-copy WASM batch entry point: writes AD for `rows` series of `cols`
/// samples each (row-major, `rows * cols` values per buffer) to `out_ptr`.
///
/// Every pointer must reference at least `rows * cols` contiguous, initialized
/// `f64` values in WASM linear memory, valid for the duration of the call.
///
/// The output may overlap any input buffer. Overlap is detected by address range
/// and routed through a scratch buffer, so an exclusive output slice never
/// coexists with a shared input slice over the same bytes.
///
/// # Errors
/// - `"null pointer"` when any pointer is null.
/// - `"rows*cols overflow"` when the element count does not fit in `usize`.
/// - Otherwise, the `Display` text of the error reported by `ad_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ad_batch_into(
    highs_ptr: *const f64,
    lows_ptr: *const f64,
    closes_ptr: *const f64,
    volumes_ptr: *const f64,
    out_ptr: *mut f64,
    rows: usize,
    cols: usize,
) -> Result<(), JsValue> {
    if highs_ptr.is_null()
        || lows_ptr.is_null()
        || closes_ptr.is_null()
        || volumes_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte range [start, end) covered by a `total`-element f64 buffer at `p`.
    let span_bytes = total.saturating_mul(std::mem::size_of::<f64>());
    let byte_range = |p: *const f64| {
        let start = p as usize;
        (start, start.saturating_add(span_bytes))
    };
    let (out_start, out_end) = byte_range(out_ptr as *const f64);
    let aliased = [highs_ptr, lows_ptr, closes_ptr, volumes_ptr]
        .iter()
        .any(|&p| {
            let (start, end) = byte_range(p);
            start < out_end && out_start < end
        });

    // SAFETY: all pointers are non-null (checked above) and, per the caller
    // contract, each references `total` initialized f64 values.
    let (highs_flat, lows_flat, closes_flat, volumes_flat) = unsafe {
        (
            std::slice::from_raw_parts(highs_ptr, total),
            std::slice::from_raw_parts(lows_ptr, total),
            std::slice::from_raw_parts(closes_ptr, total),
            std::slice::from_raw_parts(volumes_ptr, total),
        )
    };

    /// Splits a row-major flat buffer into `rows` slices of `cols` values each.
    fn split_rows(flat: &[f64], rows: usize, cols: usize) -> Vec<&[f64]> {
        (0..rows).map(|r| &flat[r * cols..(r + 1) * cols]).collect()
    }

    let highs = split_rows(highs_flat, rows, cols);
    let lows = split_rows(lows_flat, rows, cols);
    let closes = split_rows(closes_flat, rows, cols);
    let volumes = split_rows(volumes_flat, rows, cols);
    let batch = AdBatchInput {
        highs: &highs,
        lows: &lows,
        closes: &closes,
        volumes: &volumes,
    };
    let kernel = detect_best_batch_kernel();

    if aliased {
        let mut scratch = vec![0.0; total];
        ad_batch_inner_into(&batch, kernel, false, &mut scratch)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the input views are not used past this point, and `out_ptr` is
        // non-null and valid for `total` writes per the caller contract.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
        Ok(())
    } else {
        // SAFETY: `out_ptr` is non-null, valid for `total` writes, and its byte
        // range is disjoint from every input range checked above.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        ad_batch_inner_into(&batch, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::{read_candles_from_csv, Candles};
    use crate::utilities::enums::Kernel;

    /// Poison bit patterns that must never survive into a finished output. Each
    /// is paired with the allocation helper named in its failure message.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Returns `(index, bits, helper_name)` for the first value in `values` whose
    /// bit pattern matches an entry of `POISON_PATTERNS`, or `None` if clean.
    #[cfg(debug_assertions)]
    fn first_poison(values: &[f64]) -> Option<(usize, u64, &'static str)> {
        values.iter().enumerate().find_map(|(i, &val)| {
            let bits = val.to_bits();
            POISON_PATTERNS
                .iter()
                .find(|&&(pattern, _)| pattern == bits)
                .map(|&(_, source)| (i, bits, source))
        })
    }

    /// Runs one `check_*` routine and fails the test if it returns `Err`.
    ///
    /// Errors propagated with `?` inside a check must fail the generated test.
    /// That covers a missing data file, an indicator error or a builder error.
    /// Discarding the `Result` would let such a test pass vacuously.
    fn run_check(
        name: &str,
        kernel: Kernel,
        check: fn(&str, Kernel) -> Result<(), Box<dyn std::error::Error>>,
    ) {
        if let Err(err) = check(name, kernel) {
            panic!("[{}] check returned an error: {}", name, err);
        }
    }

    /// Default parameters on candle input yield one output per candle.
    fn check_ad_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = AdParams::default();
        let input = AdInput::from_candles(&candles, default_params);
        let output = ad_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// The last five AD values on the reference dataset match known values (±0.1).
    fn check_ad_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AdInput::with_default_candles(&candles);
        let ad_result = ad_with_kernel(&input, kernel)?;
        assert_eq!(ad_result.values.len(), candles.close.len());
        let expected_last_five = [1645918.16, 1645876.11, 1645824.27, 1645828.87, 1645728.78];
        // Guard the subtraction below instead of letting it underflow.
        assert!(
            ad_result.values.len() >= expected_last_five.len(),
            "[{}] need at least {} AD values, got {}",
            test_name,
            expected_last_five.len(),
            ad_result.values.len()
        );
        let start = ad_result.values.len() - expected_last_five.len();
        let actual = &ad_result.values[start..];
        for (i, (&val, &expected)) in actual.iter().zip(expected_last_five.iter()).enumerate() {
            assert!(
                (val - expected).abs() < 1e-1,
                "[{}] AD mismatch at idx {}: got {}, expected {}",
                test_name,
                i,
                val,
                expected
            );
        }
        Ok(())
    }

    /// Feeding AD output back in as every input series still yields non-NaN
    /// values past index 50.
    fn check_ad_with_slice_data_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_input = AdInput::with_default_candles(&candles);
        let first_result = ad_with_kernel(&first_input, kernel)?;
        let second_input = AdInput::from_slices(
            &first_result.values,
            &first_result.values,
            &first_result.values,
            &first_result.values,
            AdParams::default(),
        );
        let second_result = ad_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for (i, val) in second_result.values.iter().enumerate().skip(50) {
            assert!(!val.is_nan(), "[{}] unexpected NaN at index {}", test_name, i);
        }
        Ok(())
    }

    /// `with_default_candles` produces the `AdData::Candles` variant.
    fn check_ad_input_with_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AdInput::with_default_candles(&candles);
        match input.data {
            AdData::Candles { .. } => {}
            _ => panic!("Expected AdData::Candles variant"),
        }
        Ok(())
    }

    /// No NaN appears in the candle-based output after index 50.
    fn check_ad_accuracy_nan_check(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AdInput::with_default_candles(&candles);
        let ad_result = ad_with_kernel(&input, kernel)?;
        assert_eq!(ad_result.values.len(), candles.close.len());
        for (i, val) in ad_result.values.iter().enumerate().skip(50) {
            assert!(
                !val.is_nan(),
                "[{}] Expected no NaN after index 50, but found NaN at index {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    /// `AdStream`, fed one bar at a time, matches the batch output to 1e-9.
    fn check_ad_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = AdInput::with_default_candles(&candles);
        let batch = ad_with_kernel(&input, kernel)?.values;
        let mut stream = AdStream::try_new()?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for i in 0..candles.close.len() {
            let val = stream.update(
                candles.high[i],
                candles.low[i],
                candles.close[i],
                candles.volume[i],
            );
            stream_values.push(val);
        }
        assert_eq!(batch.len(), stream_values.len());
        for (b, s) in batch.iter().zip(stream_values.iter()) {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            assert!(
                (b - s).abs() < 1e-9,
                "[{}] AD streaming mismatch",
                test_name
            );
        }
        Ok(())
    }

    /// Debug builds only: no poison pattern survives in candle or slice output.
    #[cfg(debug_assertions)]
    fn check_ad_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = AdInput::with_default_candles(&candles);
        let output = ad_with_kernel(&input, kernel)?;
        if let Some((i, bits, source)) = first_poison(&output.values) {
            panic!(
                "[{}] Found {} poison value {} (0x{:016X}) at index {}",
                test_name,
                source,
                f64::from_bits(bits),
                bits,
                i
            );
        }

        let slice_input = AdInput::from_slices(
            &candles.high,
            &candles.low,
            &candles.close,
            &candles.volume,
            AdParams::default(),
        );
        let slice_output = ad_with_kernel(&slice_input, kernel)?;
        if let Some((i, bits, source)) = first_poison(&slice_output.values) {
            panic!(
                "[{}] Found {} poison value {} (0x{:016X}) at index {} (slice test)",
                test_name,
                source,
                f64::from_bits(bits),
                bits,
                i
            );
        }

        Ok(())
    }

    /// Release builds: poison values are not written, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_ad_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Generates one `#[test]` per check and kernel: scalar always, plus AVX2 and
    /// AVX-512 when `nightly-avx` is enabled on x86_64.
    ///
    /// The `#[cfg]` gate sits inside each `$(...)*` repetition. That way it
    /// applies to every generated SIMD test, not only the first one in the expansion.
    macro_rules! generate_all_ad_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        run_check(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar, $test_fn);
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        run_check(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2, $test_fn);
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        run_check(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512, $test_fn);
                    }
                )*
            }
        };
    }

    /// Property test on random well-formed bars (low <= close <= high, volume >= 0).
    ///
    /// Checks that:
    /// - the tested kernel agrees with the scalar kernel;
    /// - AD is flat on zero-volume and zero-range bars;
    /// - AD equals the running sum of money-flow volume.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_ad_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Each bar is built as `low + delta`, so low <= close <= high holds by
        // construction. All ranges are finite and volume is non-negative, so no
        // extra filtering is needed.
        let strat = (10usize..400).prop_flat_map(|len| {
            prop::collection::vec(
                (
                    1.0f64..1000.0f64,
                    0.0f64..500.0f64,
                    0.0f64..1.0f64,
                    0.0f64..1e6f64,
                )
                    .prop_map(|(low, high_delta, close_ratio, volume)| {
                        let high = low + high_delta;
                        let close = if high_delta == 0.0 {
                            low
                        } else {
                            low + high_delta * close_ratio
                        };
                        (high, low, close, volume)
                    }),
                len,
            )
            .prop_map(|data: Vec<(f64, f64, f64, f64)>| data.into_iter().unzip4())
        });

        /// Splits an iterator of 4-tuples into four vectors.
        trait Unzip4<A, B, C, D> {
            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>);
        }

        impl<I, A, B, C, D> Unzip4<A, B, C, D> for I
        where
            I: Iterator<Item = (A, B, C, D)>,
        {
            fn unzip4(self) -> (Vec<A>, Vec<B>, Vec<C>, Vec<D>) {
                let (mut a, mut b, mut c, mut d) = (Vec::new(), Vec::new(), Vec::new(), Vec::new());
                for (av, bv, cv, dv) in self {
                    a.push(av);
                    b.push(bv);
                    c.push(cv);
                    d.push(dv);
                }
                (a, b, c, d)
            }
        }

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(highs, lows, closes, volumes)| {
                let input =
                    AdInput::from_slices(&highs, &lows, &closes, &volumes, AdParams::default());

                let AdOutput { values: out } = ad_with_kernel(&input, kernel).unwrap();

                let AdOutput { values: ref_out } = ad_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), highs.len(), "Output length mismatch");

                for (i, &val) in out.iter().enumerate() {
                    prop_assert!(
                        !val.is_nan(),
                        "Unexpected NaN at index {}: AD should not have NaN values",
                        i
                    );
                }

                // Kernel parity: exact for special values, else within 1e-9 or 4 ULP.
                for i in 0..out.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y_bits,
                            r_bits,
                            "Special value mismatch at idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Value mismatch at idx {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                for i in 1..volumes.len() {
                    if volumes[i] == 0.0 {
                        prop_assert!(
                            (out[i] - out[i - 1]).abs() < 1e-10,
                            "AD should not change when volume is 0 at index {}",
                            i
                        );
                    }
                }

                // Zero-range bars contribute nothing. Use the same exact `hl == 0.0`
                // condition as the reference accumulation below: a tiny but non-zero
                // range still contributes money-flow volume.
                for i in 0..highs.len() {
                    if highs[i] - lows[i] == 0.0 {
                        if i == 0 {
                            prop_assert!(
                                out[i].abs() < 1e-10,
                                "When high=low, first AD value should be 0, got {}",
                                out[i]
                            );
                        } else {
                            prop_assert!(
                                (out[i] - out[i - 1]).abs() < 1e-10,
                                "When high=low at index {}, AD should remain unchanged",
                                i
                            );
                        }
                    }
                }

                let mut expected_ad = 0.0f64;
                for i in 0..highs.len() {
                    let hl = highs[i] - lows[i];
                    if hl != 0.0 {
                        let mfm = ((closes[i] - lows[i]) - (highs[i] - closes[i])) / hl;
                        let mfv = mfm * volumes[i];
                        expected_ad += mfv;
                    }
                    // Relative tolerance: the running sum can reach ~1e8 here, where
                    // an absolute 1e-9 bound is finer than one ULP of the value.
                    let tol = 1e-9 * expected_ad.abs().max(1.0);
                    prop_assert!(
                        (out[i] - expected_ad).abs() <= tol,
                        "Cumulative property violation at index {}: expected {}, got {}",
                        i,
                        expected_ad,
                        out[i]
                    );
                }

                if !highs.is_empty() {
                    let hl = highs[0] - lows[0];
                    let expected_first: f64 = if hl != 0.0 {
                        ((closes[0] - lows[0]) - (highs[0] - closes[0])) / hl * volumes[0]
                    } else {
                        0.0
                    };
                    let tol = 1e-10 * expected_first.abs().max(1.0);
                    prop_assert!(
                        (out[0] - expected_first).abs() <= tol,
                        "First value mismatch: expected {}, got {}",
                        expected_first,
                        out[0]
                    );
                }

                // Sanity check on the generator itself.
                for i in 0..highs.len() {
                    prop_assert!(
                        lows[i] <= closes[i] + 1e-10 && closes[i] <= highs[i] + 1e-10,
                        "Price constraint violation at index {}: low={}, close={}, high={}",
                        i,
                        lows[i],
                        closes[i],
                        highs[i]
                    );
                }

                let all_equal = highs
                    .iter()
                    .zip(lows.iter())
                    .zip(closes.iter())
                    .all(|((&h, &l), &c)| (h - l).abs() < 1e-10 && (l - c).abs() < 1e-10);

                if all_equal {
                    for (i, &val) in out.iter().enumerate() {
                        prop_assert!(
                            val.abs() < 1e-10,
                            "When all prices are equal, AD should be 0 at index {}, got {}",
                            i,
                            val
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_ad_tests!(
        check_ad_partial_params,
        check_ad_accuracy,
        check_ad_input_with_default_candles,
        check_ad_with_slice_data_reinput,
        check_ad_accuracy_nan_check,
        check_ad_streaming,
        check_ad_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_ad_tests!(check_ad_property);

    /// A one-row batch reproduces the single-series output (within 1e-8).
    fn check_batch_single_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let highs: Vec<&[f64]> = vec![&candles.high];
        let lows: Vec<&[f64]> = vec![&candles.low];
        let closes: Vec<&[f64]> = vec![&candles.close];
        let volumes: Vec<&[f64]> = vec![&candles.volume];

        let single = ad_with_kernel(
            &AdInput::from_candles(&candles, AdParams::default()),
            kernel,
        )?
        .values;

        let batch = AdBatchBuilder::new()
            .kernel(kernel)
            .apply_slices(&highs, &lows, &closes, &volumes)?;

        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, candles.close.len());
        assert_eq!(batch.values.len(), candles.close.len());

        for (i, (a, b)) in single.iter().zip(&batch.values).enumerate() {
            assert!(
                (a - b).abs() < 1e-8,
                "[{}] AD batch single row mismatch at {}: {} vs {}",
                test,
                i,
                a,
                b
            );
        }
        Ok(())
    }

    /// Three identical rows each reproduce the single-series output, laid out
    /// row-major in `batch.values`.
    fn check_batch_multi_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let highs: Vec<&[f64]> = vec![&candles.high, &candles.high, &candles.high];
        let lows: Vec<&[f64]> = vec![&candles.low, &candles.low, &candles.low];
        let closes: Vec<&[f64]> = vec![&candles.close, &candles.close, &candles.close];
        let volumes: Vec<&[f64]> = vec![&candles.volume, &candles.volume, &candles.volume];

        let single = ad_with_kernel(
            &AdInput::from_candles(&candles, AdParams::default()),
            kernel,
        )?
        .values;

        let batch = AdBatchBuilder::new()
            .kernel(kernel)
            .apply_slices(&highs, &lows, &closes, &volumes)?;

        assert_eq!(batch.rows, 3);
        assert_eq!(batch.cols, candles.close.len());
        assert_eq!(batch.values.len(), 3 * candles.close.len());

        for row in 0..3 {
            let row_slice = &batch.values[row * batch.cols..(row + 1) * batch.cols];
            for (i, (a, b)) in single.iter().zip(row_slice.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-8,
                    "[{}] AD batch multi row mismatch row {} idx {}: {} vs {}",
                    test,
                    row,
                    i,
                    a,
                    b
                );
            }
        }
        Ok(())
    }

    /// Generates scalar, AVX2, AVX-512 and auto-detect batch tests for one check.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    run_check(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch, $fn_name);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    run_check(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch, $fn_name);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    run_check(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch, $fn_name);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    run_check(stringify!([<$fn_name _auto_detect>]), Kernel::Auto, $fn_name);
                }
            }
        };
    }

    /// Debug builds only: no poison pattern survives in a batch whose rows have
    /// different contents (forward, reversed, and offset-by-50 series).
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let mut highs: Vec<&[f64]> = vec![];
        let mut lows: Vec<&[f64]> = vec![];
        let mut closes: Vec<&[f64]> = vec![];
        let mut volumes: Vec<&[f64]> = vec![];

        highs.push(&c.high);
        lows.push(&c.low);
        closes.push(&c.close);
        volumes.push(&c.volume);

        let high_rev: Vec<f64> = c.high.iter().rev().copied().collect();
        let low_rev: Vec<f64> = c.low.iter().rev().copied().collect();
        let close_rev: Vec<f64> = c.close.iter().rev().copied().collect();
        let volume_rev: Vec<f64> = c.volume.iter().rev().copied().collect();

        highs.push(&high_rev);
        lows.push(&low_rev);
        closes.push(&close_rev);
        volumes.push(&volume_rev);

        if c.high.len() > 100 {
            highs.push(&c.high[50..]);
            lows.push(&c.low[50..]);
            closes.push(&c.close[50..]);
            volumes.push(&c.volume[50..]);
        }

        let batch = AdBatchBuilder::new()
            .kernel(kernel)
            .apply_slices(&highs, &lows, &closes, &volumes)?;

        if let Some((idx, bits, source)) = first_poison(&batch.values) {
            let row = idx / batch.cols;
            let col = idx % batch.cols;
            panic!(
                "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                test,
                source,
                f64::from_bits(bits),
                bits,
                row,
                col,
                idx
            );
        }

        Ok(())
    }

    /// Release builds: poison values are not written, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_single_row);
    gen_batch_tests!(check_batch_multi_row);
    gen_batch_tests!(check_batch_no_poison);

    /// `ad_into` on synthetic candles matches the allocating `ad` API element-wise.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_ad_into_matches_api() {
        let n = 256usize;
        let mut ts = Vec::with_capacity(n);
        let mut open = Vec::with_capacity(n);
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        let mut volume = Vec::with_capacity(n);

        for i in 0..n {
            let i_f = i as f64;
            ts.push(i as i64);
            let o = 100.0 + (i % 13) as f64 * 0.75;
            let l = o - 2.0;
            let h = o + 2.0 + ((i % 3) as f64) * 0.1;
            let c = l + ((i % 5) as f64) * 0.5;
            let v = 1000.0 + 10.0 * i_f;
            open.push(o);
            low.push(l);
            high.push(h);
            close.push(c);
            volume.push(v);
        }

        let candles = Candles::new(
            ts,
            open,
            high.clone(),
            low.clone(),
            close.clone(),
            volume.clone(),
        );
        let input = AdInput::with_default_candles(&candles);

        let baseline = ad(&input).expect("ad() should succeed").values;

        let mut out = vec![0.0; baseline.len()];
        ad_into(&input, &mut out).expect("ad_into() should succeed");

        assert_eq!(out.len(), baseline.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for (i, (a, b)) in out
            .iter()
            .copied()
            .zip(baseline.iter().copied())
            .enumerate()
        {
            assert!(
                eq_or_both_nan(a, b),
                "ad_into parity failed at index {}: {} vs {}",
                i,
                a,
                b
            );
        }
    }
}