1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::mem::ManuallyDrop;
15use thiserror::Error;
16
/// Lookback period used whenever `RocrParams::period` is `None`
/// (see `RocrParams::default`, `RocrInput::get_period` and `RocrStream::try_new`).
const DEFAULT_PERIOD: usize = 10;
18
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::moving_averages::DeviceArrayF32;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use cust::context::Context;
27#[cfg(all(feature = "python", feature = "cuda"))]
28use cust::memory::DeviceBuffer;
29#[cfg(feature = "python")]
30use numpy::{IntoPyArray, PyArray1};
31#[cfg(feature = "python")]
32use pyo3::exceptions::PyValueError;
33#[cfg(feature = "python")]
34use pyo3::prelude::*;
35#[cfg(feature = "python")]
36use pyo3::types::PyDict;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use std::sync::Arc;
39
40#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
41use serde::{Deserialize, Serialize};
42#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
43use wasm_bindgen::prelude::*;
44
45impl<'a> AsRef<[f64]> for RocrInput<'a> {
46 #[inline(always)]
47 fn as_ref(&self) -> &[f64] {
48 match &self.data {
49 RocrData::Slice(slice) => slice,
50 RocrData::Candles { candles, source } => source_type(candles, source),
51 }
52 }
53}
54
/// Where a ROCR computation reads its price series from.
#[derive(Debug, Clone)]
pub enum RocrData<'a> {
    /// A named column of a `Candles` set (e.g. `"close"`), resolved via `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided raw series.
    Slice(&'a [f64]),
}
63
/// Result of the single-series ROCR functions.
#[derive(Debug, Clone)]
pub struct RocrOutput {
    /// One value per input sample. The first `first + period` entries, where `first` is
    /// the index of the first non-NaN input, are the NaN warmup prefix.
    pub values: Vec<f64>,
}
68
/// Parameters of the ROCR indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RocrParams {
    /// Lookback period; `None` resolves to `DEFAULT_PERIOD`. Zero is rejected as invalid.
    pub period: Option<usize>,
}
77
78impl Default for RocrParams {
79 fn default() -> Self {
80 Self {
81 period: Some(DEFAULT_PERIOD),
82 }
83 }
84}
85
/// Input to the ROCR functions: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct RocrInput<'a> {
    /// Series source (candle column or raw slice).
    pub data: RocrData<'a>,
    /// Indicator parameters.
    pub params: RocrParams,
}
91
92impl<'a> RocrInput<'a> {
93 #[inline]
94 pub fn from_candles(c: &'a Candles, s: &'a str, p: RocrParams) -> Self {
95 Self {
96 data: RocrData::Candles {
97 candles: c,
98 source: s,
99 },
100 params: p,
101 }
102 }
103 #[inline]
104 pub fn from_slice(sl: &'a [f64], p: RocrParams) -> Self {
105 Self {
106 data: RocrData::Slice(sl),
107 params: p,
108 }
109 }
110 #[inline]
111 pub fn with_default_candles(c: &'a Candles) -> Self {
112 Self::from_candles(c, "close", RocrParams::default())
113 }
114 #[inline]
115 pub fn get_period(&self) -> usize {
116 self.params.period.unwrap_or(DEFAULT_PERIOD)
117 }
118}
119
120#[derive(Copy, Clone, Debug)]
121pub struct RocrBuilder {
122 period: Option<usize>,
123 kernel: Kernel,
124}
125
126impl Default for RocrBuilder {
127 fn default() -> Self {
128 Self {
129 period: None,
130 kernel: Kernel::Auto,
131 }
132 }
133}
134
135impl RocrBuilder {
136 #[inline(always)]
137 pub fn new() -> Self {
138 Self::default()
139 }
140 #[inline(always)]
141 pub fn period(mut self, n: usize) -> Self {
142 self.period = Some(n);
143 self
144 }
145 #[inline(always)]
146 pub fn kernel(mut self, k: Kernel) -> Self {
147 self.kernel = k;
148 self
149 }
150 #[inline(always)]
151 pub fn apply(self, c: &Candles) -> Result<RocrOutput, RocrError> {
152 let p = RocrParams {
153 period: self.period,
154 };
155 let i = RocrInput::from_candles(c, "close", p);
156 rocr_with_kernel(&i, self.kernel)
157 }
158 #[inline(always)]
159 pub fn apply_slice(self, d: &[f64]) -> Result<RocrOutput, RocrError> {
160 let p = RocrParams {
161 period: self.period,
162 };
163 let i = RocrInput::from_slice(d, p);
164 rocr_with_kernel(&i, self.kernel)
165 }
166 #[inline(always)]
167 pub fn into_stream(self) -> Result<RocrStream, RocrError> {
168 let p = RocrParams {
169 period: self.period,
170 };
171 RocrStream::try_new(p)
172 }
173}
174
/// Errors returned by the ROCR (Rate of Change Ratio) API.
#[derive(Debug, Error)]
pub enum RocrError {
    /// The input series has length zero.
    #[error("rocr: Empty data provided.")]
    EmptyInputData,

    /// Every input value is NaN, so there is no first valid sample.
    #[error("rocr: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or larger than the input length. `RocrStream::try_new`
    /// reports `data_len: 0` because it has no input yet.
    #[error("rocr: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` samples remain from the first non-NaN value onward.
    #[error("rocr: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer has the wrong length.
    #[error("rocr: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep expanded to nothing, or a size computation overflowed. For
    /// overflows the fields carry the operands (e.g. rows/cols) rather than a range.
    #[error("rocr: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("rocr: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
202
203#[inline(always)]
204fn rocr_prepare<'a>(
205 input: &'a RocrInput,
206 kern: Kernel,
207) -> Result<(&'a [f64], usize, usize, Kernel), RocrError> {
208 let data: &[f64] = input.as_ref();
209 let len = data.len();
210 if len == 0 {
211 return Err(RocrError::EmptyInputData);
212 }
213
214 let first = data
215 .iter()
216 .position(|x| !x.is_nan())
217 .ok_or(RocrError::AllValuesNaN)?;
218 let period = input.get_period();
219
220 if period == 0 || period > len {
221 return Err(RocrError::InvalidPeriod {
222 period,
223 data_len: len,
224 });
225 }
226 if len - first < period {
227 return Err(RocrError::NotEnoughValidData {
228 needed: period,
229 valid: len - first,
230 });
231 }
232
233 let chosen = match kern {
234 Kernel::Auto => Kernel::Scalar,
235 k => k,
236 };
237 Ok((data, period, first, chosen))
238}
239
240#[inline]
241pub fn rocr(input: &RocrInput) -> Result<RocrOutput, RocrError> {
242 rocr_with_kernel(input, Kernel::Auto)
243}
244
245pub fn rocr_with_kernel(input: &RocrInput, kernel: Kernel) -> Result<RocrOutput, RocrError> {
246 let (data, period, first, chosen) = rocr_prepare(input, kernel)?;
247 let mut out = alloc_with_nan_prefix(data.len(), first + period);
248 unsafe {
249 match chosen {
250 Kernel::Scalar | Kernel::ScalarBatch => rocr_scalar(data, period, first, &mut out),
251 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
252 Kernel::Avx2 | Kernel::Avx2Batch => rocr_avx2(data, period, first, &mut out),
253 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
254 Kernel::Avx512 | Kernel::Avx512Batch => rocr_avx512(data, period, first, &mut out),
255 _ => unreachable!(),
256 }
257 }
258 Ok(RocrOutput { values: out })
259}
260
261#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
262pub fn rocr_into(input: &RocrInput, out: &mut [f64]) -> Result<(), RocrError> {
263 rocr_into_slice(out, input, Kernel::Auto)
264}
265
266pub fn rocr_into_slice(dst: &mut [f64], input: &RocrInput, kern: Kernel) -> Result<(), RocrError> {
267 let (data, period, first, chosen) = rocr_prepare(input, kern)?;
268 if dst.len() != data.len() {
269 return Err(RocrError::OutputLengthMismatch {
270 expected: data.len(),
271 got: dst.len(),
272 });
273 }
274
275 unsafe {
276 match chosen {
277 Kernel::Scalar | Kernel::ScalarBatch => rocr_scalar(data, period, first, dst),
278 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
279 Kernel::Avx2 | Kernel::Avx2Batch => rocr_avx2(data, period, first, dst),
280 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
281 Kernel::Avx512 | Kernel::Avx512Batch => rocr_avx512(data, period, first, dst),
282 _ => unreachable!(),
283 }
284 }
285
286 let warm = first + period;
287 for v in &mut dst[..warm] {
288 *v = f64::NAN;
289 }
290 Ok(())
291}
292
/// Scalar ROCR kernel.
///
/// For every `i` in `first_val + period .. data.len()`, writes
/// `data[i] / data[i - period]`, or `0.0` when the past value is zero (either sign)
/// or NaN. Entries of `out` before `first_val + period` are left untouched.
///
/// # Panics
/// Panics on an out-of-bounds write if `out` is shorter than `data` and there is at
/// least one index to write.
#[inline]
pub fn rocr_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let warm = first_val + period;
    for (idx, &current) in data.iter().enumerate().skip(warm) {
        let previous = data[idx - period];
        let undefined = previous.is_nan() || previous == 0.0;
        out[idx] = if undefined { 0.0 } else { current / previous };
    }
}
304
/// AVX-512 ROCR kernel with the same semantics as [`rocr_scalar`].
///
/// The vector path is compiled only when the build enables the `avx512f` target
/// feature; otherwise this delegates to [`rocr_scalar`]. Lanes whose past value is ±0
/// or NaN are zeroed by the masked divide (`_mm512_maskz_div_pd`).
///
/// # Panics
/// Panics if `out` is shorter than `data` while there is at least one output to write.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn rocr_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let len = data.len();
    // `checked_add` rather than `+`: in release builds a wrapped `start` would make
    // `i - period` wrap too, and the unchecked loads below would read out of bounds.
    let start = match first_valid.checked_add(period) {
        Some(s) if s < len => s,
        _ => return,
    };
    // This is a *safe* function that stores through raw pointers up to index `len - 1`,
    // so the output length must be enforced here rather than assumed.
    assert!(
        out.len() >= len,
        "rocr_avx512: output length {} is shorter than input length {}",
        out.len(),
        len
    );

    #[cfg(target_feature = "avx512f")]
    {
        // SAFETY: `avx512f` is enabled for the whole build (cfg above). Every load reads
        // `data[j]` and `data[j - period]` with `period <= start <= j < len`, and every
        // store writes `out[j]` with `j < len <= out.len()` (asserted above).
        unsafe {
            let zero = _mm512_set1_pd(0.0);
            let mut i = start;
            let end = len;
            let step = 8usize;
            while i + step <= end {
                let cur = _mm512_loadu_pd(data.as_ptr().add(i));
                let pst = _mm512_loadu_pd(data.as_ptr().add(i - period));

                let m0 = _mm512_cmp_pd_mask(pst, zero, _CMP_EQ_OQ);
                let m1 = _mm512_cmp_pd_mask(pst, pst, _CMP_UNORD_Q);
                let good = !(m0 | m1);

                let res = _mm512_maskz_div_pd(good, cur, pst);
                _mm512_storeu_pd(out.as_mut_ptr().add(i), res);
                i += step;
            }
            for j in i..end {
                let p = *data.get_unchecked(j - period);
                let c = *data.get_unchecked(j);
                *out.get_unchecked_mut(j) = if p == 0.0 || p.is_nan() { 0.0 } else { c / p };
            }
        }
        return;
    }

    #[cfg(not(target_feature = "avx512f"))]
    let _ = start;

    rocr_scalar(data, period, first_valid, out)
}
344
/// AVX2 ROCR kernel with the same semantics as [`rocr_scalar`].
///
/// The vector path is compiled only when the build enables the `avx2` target feature;
/// otherwise this delegates to [`rocr_scalar`]. After each 4-lane divide, lanes whose
/// past value is ±0 or NaN are cleared to `+0.0` with an AND-NOT mask.
///
/// # Panics
/// Panics if `out` is shorter than `data` while there is at least one output to write.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn rocr_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let len = data.len();
    // `checked_add` rather than `+`: a wrapped `start` would make `i - period` wrap and
    // the unchecked loads below would read out of bounds.
    let start = match first_valid.checked_add(period) {
        Some(s) if s < len => s,
        _ => return,
    };
    // This is a *safe* function that stores through raw pointers up to index `len - 1`,
    // so the output length must be enforced here rather than assumed.
    assert!(
        out.len() >= len,
        "rocr_avx2: output length {} is shorter than input length {}",
        out.len(),
        len
    );

    #[cfg(target_feature = "avx2")]
    {
        // SAFETY: `avx2` (and therefore AVX) is enabled for the whole build. Loads read
        // `data[i..i + w]` and `data[p..p + w]` with `p = i - period` and `i + w <= len`;
        // stores write `out[i..i + w]` with `i + w <= len <= out.len()` (asserted above).
        unsafe {
            let mut i = start;
            let mut p = i - period;
            let end = len;
            let zero = _mm256_set1_pd(0.0);

            // Main loop: two independent 4-lane groups per iteration.
            while i + 8 <= end {
                let cur0 = _mm256_loadu_pd(data.as_ptr().add(i));
                let pst0 = _mm256_loadu_pd(data.as_ptr().add(p));
                let div0 = _mm256_div_pd(cur0, pst0);
                let m0z = _mm256_cmp_pd(pst0, zero, _CMP_EQ_OQ);
                let m0n = _mm256_cmp_pd(pst0, pst0, _CMP_UNORD_Q);
                let m0 = _mm256_or_pd(m0z, m0n);
                let res0 = _mm256_andnot_pd(m0, div0);
                _mm256_storeu_pd(out.as_mut_ptr().add(i), res0);

                let cur1 = _mm256_loadu_pd(data.as_ptr().add(i + 4));
                let pst1 = _mm256_loadu_pd(data.as_ptr().add(p + 4));
                let div1 = _mm256_div_pd(cur1, pst1);
                let m1z = _mm256_cmp_pd(pst1, zero, _CMP_EQ_OQ);
                let m1n = _mm256_cmp_pd(pst1, pst1, _CMP_UNORD_Q);
                let m1 = _mm256_or_pd(m1z, m1n);
                let res1 = _mm256_andnot_pd(m1, div1);
                _mm256_storeu_pd(out.as_mut_ptr().add(i + 4), res1);

                i += 8;
                p += 8;
            }

            while i + 4 <= end {
                let cur = _mm256_loadu_pd(data.as_ptr().add(i));
                let pst = _mm256_loadu_pd(data.as_ptr().add(p));
                let div = _mm256_div_pd(cur, pst);
                let mz = _mm256_cmp_pd(pst, zero, _CMP_EQ_OQ);
                let mn = _mm256_cmp_pd(pst, pst, _CMP_UNORD_Q);
                let m = _mm256_or_pd(mz, mn);
                let res = _mm256_andnot_pd(m, div);
                _mm256_storeu_pd(out.as_mut_ptr().add(i), res);
                i += 4;
                p += 4;
            }

            // Scalar tail: ±0 / NaN detection on the IEEE-754 bit pattern.
            while i < end {
                let past = *data.get_unchecked(p);
                let cur = *data.get_unchecked(i);
                let b = past.to_bits();
                let is_zero = (b & 0x7fff_ffff_ffff_ffff) == 0;
                let is_nan = (b & 0x7ff0_0000_0000_0000) == 0x7ff0_0000_0000_0000
                    && (b & 0x000f_ffff_ffff_ffff) != 0;
                *out.get_unchecked_mut(i) = if is_zero | is_nan { 0.0 } else { cur / past };
                i += 1;
                p += 1;
            }
        }
        return;
    }

    #[cfg(not(target_feature = "avx2"))]
    let _ = start;

    rocr_scalar(data, period, first_valid, out)
}
415
/// AVX-512 entry point intended for short periods; currently delegates to [`rocr_scalar`].
///
/// # Safety
/// Declared `unsafe`, but the current body performs no unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rocr_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    rocr_scalar(data, period, first_valid, out);
}

/// AVX-512 entry point intended for long periods; currently delegates to [`rocr_scalar`].
///
/// # Safety
/// Declared `unsafe`, but the current body performs no unsafe operations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn rocr_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    rocr_scalar(data, period, first_valid, out);
}
427
428pub struct RocrStream {
429 period: usize,
430 buffer: Vec<f64>,
431 head: usize,
432 filled: bool,
433}
434
435impl RocrStream {
436 pub fn try_new(params: RocrParams) -> Result<Self, RocrError> {
437 let period = params.period.unwrap_or(DEFAULT_PERIOD);
438 if period == 0 {
439 return Err(RocrError::InvalidPeriod {
440 period,
441 data_len: 0,
442 });
443 }
444 Ok(Self {
445 period,
446 buffer: vec![f64::NAN; period],
447 head: 0,
448 filled: false,
449 })
450 }
451
452 #[inline(always)]
453 pub fn update(&mut self, value: f64) -> Option<f64> {
454 const ABS_MASK: u64 = 0x7fff_ffff_ffff_ffff;
455 const EXP_MASK: u64 = 0x7ff0_0000_0000_0000;
456 const MAN_MASK: u64 = 0x000f_ffff_ffff_ffff;
457
458 let past = self.buffer[self.head];
459
460 let out = if self.filled {
461 let y = value / past;
462 let pb = past.to_bits();
463 let is_zero = (pb & ABS_MASK) == 0;
464 let is_nan = (pb & EXP_MASK) == EXP_MASK && (pb & MAN_MASK) != 0;
465 let good_mask: u64 = (!(is_zero | is_nan) as u64).wrapping_neg();
466 Some(f64::from_bits(y.to_bits() & good_mask))
467 } else {
468 None
469 };
470
471 self.buffer[self.head] = value;
472 let next = self.head + 1;
473 if next == self.period {
474 self.head = 0;
475 self.filled = true;
476 } else {
477 self.head = next;
478 }
479
480 out
481 }
482}
483
/// Period sweep for batch ROCR, given as `(start, end, step)`; the expansion is
/// inclusive of `end`.
#[derive(Clone, Debug)]
pub struct RocrBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for RocrBatchRange {
    /// Sweeps periods 9 through 258 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (9, 258, 1);
        RocrBatchRange {
            period: (start, end, step),
        }
    }
}
496
497#[derive(Clone, Debug, Default)]
498pub struct RocrBatchBuilder {
499 range: RocrBatchRange,
500 kernel: Kernel,
501}
502
503impl RocrBatchBuilder {
504 pub fn new() -> Self {
505 Self::default()
506 }
507 pub fn kernel(mut self, k: Kernel) -> Self {
508 self.kernel = k;
509 self
510 }
511
512 #[inline]
513 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
514 self.range.period = (start, end, step);
515 self
516 }
517 #[inline]
518 pub fn period_static(mut self, p: usize) -> Self {
519 self.range.period = (p, p, 0);
520 self
521 }
522
523 pub fn apply_slice(self, data: &[f64]) -> Result<RocrBatchOutput, RocrError> {
524 rocr_batch_with_kernel(data, &self.range, self.kernel)
525 }
526
527 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RocrBatchOutput, RocrError> {
528 RocrBatchBuilder::new().kernel(k).apply_slice(data)
529 }
530
531 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RocrBatchOutput, RocrError> {
532 let slice = source_type(c, src);
533 self.apply_slice(slice)
534 }
535
536 pub fn with_default_candles(c: &Candles) -> Result<RocrBatchOutput, RocrError> {
537 RocrBatchBuilder::new()
538 .kernel(Kernel::Auto)
539 .apply_candles(c, "close")
540 }
541}
542
543pub fn rocr_batch_with_kernel(
544 data: &[f64],
545 sweep: &RocrBatchRange,
546 k: Kernel,
547) -> Result<RocrBatchOutput, RocrError> {
548 let kernel = match k {
549 Kernel::Auto => Kernel::ScalarBatch,
550 other if other.is_batch() => other,
551 _ => {
552 return Err(RocrError::InvalidKernelForBatch(k));
553 }
554 };
555 let simd = match kernel {
556 Kernel::Avx512Batch => Kernel::Avx512,
557 Kernel::Avx2Batch => Kernel::Avx2,
558 Kernel::ScalarBatch => Kernel::Scalar,
559 _ => unreachable!(),
560 };
561 rocr_batch_par_slice(data, sweep, simd)
562}
563
564#[derive(Clone, Debug)]
565pub struct RocrBatchOutput {
566 pub values: Vec<f64>,
567 pub combos: Vec<RocrParams>,
568 pub rows: usize,
569 pub cols: usize,
570}
571impl RocrBatchOutput {
572 pub fn row_for_params(&self, p: &RocrParams) -> Option<usize> {
573 self.combos
574 .iter()
575 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
576 }
577
578 pub fn values_for(&self, p: &RocrParams) -> Option<&[f64]> {
579 self.row_for_params(p).map(|row| {
580 let start = row * self.cols;
581 &self.values[start..start + self.cols]
582 })
583 }
584}
585
586#[inline(always)]
587fn expand_grid(r: &RocrBatchRange) -> Result<Vec<RocrParams>, RocrError> {
588 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RocrError> {
589 if step == 0 || start == end {
590 return Ok(vec![start]);
591 }
592
593 if start < end {
594 let st = step.max(1);
595 let mut v = Vec::new();
596 let mut x = start;
597 while x <= end {
598 v.push(x);
599 x = match x.checked_add(st) {
600 Some(next) => next,
601 None => break,
602 };
603 }
604 if v.is_empty() {
605 return Err(RocrError::InvalidRange { start, end, step });
606 }
607 return Ok(v);
608 }
609
610 let st = step.max(1) as isize;
611 let mut v = Vec::new();
612 let mut x = start as isize;
613 let end_i = end as isize;
614 while x >= end_i {
615 v.push(x as usize);
616 x -= st;
617 }
618 if v.is_empty() {
619 return Err(RocrError::InvalidRange { start, end, step });
620 }
621 Ok(v)
622 }
623
624 let periods = axis_usize(r.period)?;
625 if periods.is_empty() {
626 return Err(RocrError::InvalidRange {
627 start: r.period.0,
628 end: r.period.1,
629 step: r.period.2,
630 });
631 }
632
633 let mut out = Vec::with_capacity(periods.len());
634 for &p in &periods {
635 out.push(RocrParams { period: Some(p) });
636 }
637 Ok(out)
638}
639
640#[inline(always)]
641fn rocr_batch_prepare(data: &[f64], combos: &[RocrParams]) -> Result<(usize, usize), RocrError> {
642 let cols = data.len();
643 if cols == 0 {
644 return Err(RocrError::EmptyInputData);
645 }
646
647 let first = data
648 .iter()
649 .position(|x| !x.is_nan())
650 .ok_or(RocrError::AllValuesNaN)?;
651
652 let mut max_p = 0usize;
653 for c in combos {
654 let p = c.period.unwrap();
655 if p == 0 || p > cols {
656 return Err(RocrError::InvalidPeriod {
657 period: p,
658 data_len: cols,
659 });
660 }
661 if p > max_p {
662 max_p = p;
663 }
664 }
665
666 if cols - first < max_p {
667 return Err(RocrError::NotEnoughValidData {
668 needed: max_p,
669 valid: cols - first,
670 });
671 }
672
673 Ok((first, max_p))
674}
675
676#[inline(always)]
677pub fn rocr_batch_slice(
678 data: &[f64],
679 sweep: &RocrBatchRange,
680 kern: Kernel,
681) -> Result<RocrBatchOutput, RocrError> {
682 rocr_batch_inner(data, sweep, kern, false)
683}
684
685#[inline(always)]
686pub fn rocr_batch_par_slice(
687 data: &[f64],
688 sweep: &RocrBatchRange,
689 kern: Kernel,
690) -> Result<RocrBatchOutput, RocrError> {
691 rocr_batch_inner(data, sweep, kern, true)
692}
693
694#[inline(always)]
695fn rocr_batch_inner(
696 data: &[f64],
697 sweep: &RocrBatchRange,
698 kern: Kernel,
699 parallel: bool,
700) -> Result<RocrBatchOutput, RocrError> {
701 let combos = expand_grid(sweep)?;
702 let (first, _max_p) = rocr_batch_prepare(data, &combos)?;
703 let rows = combos.len();
704 let cols = data.len();
705
706 rows.checked_mul(cols).ok_or(RocrError::InvalidRange {
707 start: rows,
708 end: cols,
709 step: 0,
710 })?;
711
712 let mut values_uninit = make_uninit_matrix(rows, cols);
713
714 let warmup_periods: Vec<usize> = combos
715 .iter()
716 .map(|c| {
717 let p = c.period.unwrap();
718 first.checked_add(p).ok_or(RocrError::InvalidRange {
719 start: first,
720 end: p,
721 step: 0,
722 })
723 })
724 .collect::<Result<_, _>>()?;
725 init_matrix_prefixes(&mut values_uninit, cols, &warmup_periods);
726
727 let mut buf_guard = core::mem::ManuallyDrop::new(values_uninit);
728 let values: &mut [f64] = unsafe {
729 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
730 };
731
732 let inv: Option<AVec<f64>> = match kern {
733 Kernel::Scalar => {
734 let mut buf: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, cols);
735 unsafe { buf.set_len(cols) };
736
737 const ABS_MASK: u64 = 0x7fff_ffff_ffff_ffff;
738 const EXP_MASK: u64 = 0x7ff0_0000_0000_0000;
739 const MAN_MASK: u64 = 0x000f_ffff_ffff_ffff;
740 for j in first..cols {
741 let v = data[j];
742 let b = v.to_bits();
743 let is_zero = (b & ABS_MASK) == 0;
744 let is_nan = (b & EXP_MASK) == EXP_MASK && (b & MAN_MASK) != 0;
745 buf[j] = if is_zero | is_nan { 0.0 } else { 1.0 / v };
746 }
747 Some(buf)
748 }
749 _ => None,
750 };
751
752 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
753 let period = combos[row].period.unwrap();
754
755 match kern {
756 Kernel::Scalar => match &inv {
757 Some(inv) => rocr_row_scalar_mul(data, inv.as_slice(), first, period, out_row),
758 None => rocr_row_scalar(data, first, period, out_row),
759 },
760 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
761 Kernel::Avx2 => rocr_row_avx2(data, first, period, out_row),
762 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
763 Kernel::Avx512 => rocr_row_avx512(data, first, period, out_row),
764 _ => unreachable!(),
765 }
766 };
767
768 if parallel {
769 #[cfg(not(target_arch = "wasm32"))]
770 {
771 values
772 .par_chunks_mut(cols)
773 .enumerate()
774 .for_each(|(row, slice)| do_row(row, slice));
775 }
776
777 #[cfg(target_arch = "wasm32")]
778 {
779 for (row, slice) in values.chunks_mut(cols).enumerate() {
780 do_row(row, slice);
781 }
782 }
783 } else {
784 for (row, slice) in values.chunks_mut(cols).enumerate() {
785 do_row(row, slice);
786 }
787 }
788
789 let values = unsafe {
790 Vec::from_raw_parts(
791 buf_guard.as_mut_ptr() as *mut f64,
792 buf_guard.len(),
793 buf_guard.capacity(),
794 )
795 };
796
797 Ok(RocrBatchOutput {
798 values,
799 combos,
800 rows,
801 cols,
802 })
803}
804
805#[inline(always)]
806fn rocr_batch_inner_into(
807 data: &[f64],
808 combos: &[RocrParams],
809 first: usize,
810 kern: Kernel,
811 parallel: bool,
812 out: &mut [f64],
813) -> Result<(), RocrError> {
814 if combos.is_empty() {
815 return Err(RocrError::InvalidRange {
816 start: 0,
817 end: 0,
818 step: 0,
819 });
820 }
821 let cols = data.len();
822 if cols == 0 {
823 return Err(RocrError::EmptyInputData);
824 }
825 if first >= cols {
826 return Err(RocrError::AllValuesNaN);
827 }
828
829 let expected = combos
830 .len()
831 .checked_mul(cols)
832 .ok_or(RocrError::InvalidRange {
833 start: combos.len(),
834 end: cols,
835 step: 0,
836 })?;
837 if out.len() != expected {
838 return Err(RocrError::OutputLengthMismatch {
839 expected,
840 got: out.len(),
841 });
842 }
843
844 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
845 let period = combos[row].period.unwrap();
846
847 match kern {
848 Kernel::Scalar => rocr_row_scalar(data, first, period, out_row),
849 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
850 Kernel::Avx2 => rocr_row_avx2(data, first, period, out_row),
851 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
852 Kernel::Avx512 => rocr_row_avx512(data, first, period, out_row),
853 _ => unreachable!(),
854 }
855 };
856
857 if parallel {
858 #[cfg(not(target_arch = "wasm32"))]
859 {
860 out.par_chunks_mut(cols)
861 .enumerate()
862 .for_each(|(row, slice)| do_row(row, slice));
863 }
864
865 #[cfg(target_arch = "wasm32")]
866 {
867 for (row, slice) in out.chunks_mut(cols).enumerate() {
868 do_row(row, slice);
869 }
870 }
871 } else {
872 for (row, slice) in out.chunks_mut(cols).enumerate() {
873 do_row(row, slice);
874 }
875 }
876
877 Ok(())
878}
879
/// Scalar row kernel used by the batch paths.
///
/// For `i` in `first + period .. data.len()`, writes `out[i] = data[i] / data[i - period]`.
/// When the past value is ±0 or NaN (tested on its IEEE-754 bit pattern), the quotient
/// is masked to `+0.0`. Earlier entries of `out` are not touched.
///
/// # Safety
/// The body only uses bounds-checked indexing. The function is declared `unsafe` like
/// the SIMD row kernels; a too-short `out` causes a panic, not undefined behaviour.
#[inline(always)]
unsafe fn rocr_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    const ABS_MASK: u64 = 0x7fff_ffff_ffff_ffff;
    const EXP_MASK: u64 = 0x7ff0_0000_0000_0000;
    const MAN_MASK: u64 = 0x000f_ffff_ffff_ffff;

    /// `cur / past`, or `+0.0` when `past` is ±0 or NaN.
    #[inline(always)]
    fn ratio_or_zero(cur: f64, past: f64) -> f64 {
        let bits = past.to_bits();
        let zero = (bits & ABS_MASK) == 0;
        let nan = (bits & EXP_MASK) == EXP_MASK && (bits & MAN_MASK) != 0;
        let keep: u64 = if zero || nan { 0 } else { u64::MAX };
        f64::from_bits((cur / past).to_bits() & keep)
    }

    let len = data.len();
    let start = first + period;
    if start >= len {
        return;
    }

    // Groups of four, then the remainder.
    let mut idx = start;
    while idx + 4 <= len {
        for lane in idx..idx + 4 {
            out[lane] = ratio_or_zero(data[lane], data[lane - period]);
        }
        idx += 4;
    }
    for lane in idx..len {
        out[lane] = ratio_or_zero(data[lane], data[lane - period]);
    }
}
926
/// AVX2 row kernel for the batch paths: writes `out[i] = data[i] / data[i - period]` for
/// `i` in `first + period .. data.len()`. Lanes whose past value is ±0 or NaN are
/// cleared to `+0.0` with an AND-NOT mask. Without the `avx2` target feature this
/// delegates to `rocr_row_scalar`.
///
/// # Safety
/// The vector path loads and stores through unchecked pointers. The caller must ensure
/// `out.len() >= data.len()` and that `first + period` does not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn rocr_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    #[cfg(target_feature = "avx2")]
    {
        let start = first + period;
        let len = data.len();
        if start >= len {
            return;
        }
        // `i` walks the current sample, `p` the sample `period` steps back.
        let mut i = start;
        let mut p = i - period;
        let end = len;
        let zero = _mm256_set1_pd(0.0);

        // Main loop: two independent 4-lane groups per iteration.
        while i + 8 <= end {
            let cur0 = _mm256_loadu_pd(data.as_ptr().add(i));
            let pst0 = _mm256_loadu_pd(data.as_ptr().add(p));
            let div0 = _mm256_div_pd(cur0, pst0);
            // All-ones lanes where the past value is ±0 or NaN (unordered compare).
            let m0 = _mm256_or_pd(
                _mm256_cmp_pd(pst0, zero, _CMP_EQ_OQ),
                _mm256_cmp_pd(pst0, pst0, _CMP_UNORD_Q),
            );
            let res0 = _mm256_andnot_pd(m0, div0);
            _mm256_storeu_pd(out.as_mut_ptr().add(i), res0);

            let cur1 = _mm256_loadu_pd(data.as_ptr().add(i + 4));
            let pst1 = _mm256_loadu_pd(data.as_ptr().add(p + 4));
            let div1 = _mm256_div_pd(cur1, pst1);
            let m1 = _mm256_or_pd(
                _mm256_cmp_pd(pst1, zero, _CMP_EQ_OQ),
                _mm256_cmp_pd(pst1, pst1, _CMP_UNORD_Q),
            );
            let res1 = _mm256_andnot_pd(m1, div1);
            _mm256_storeu_pd(out.as_mut_ptr().add(i + 4), res1);

            i += 8;
            p += 8;
        }

        // At most one remaining 4-lane group.
        while i + 4 <= end {
            let cur = _mm256_loadu_pd(data.as_ptr().add(i));
            let pst = _mm256_loadu_pd(data.as_ptr().add(p));
            let div = _mm256_div_pd(cur, pst);
            let m = _mm256_or_pd(
                _mm256_cmp_pd(pst, zero, _CMP_EQ_OQ),
                _mm256_cmp_pd(pst, pst, _CMP_UNORD_Q),
            );
            let res = _mm256_andnot_pd(m, div);
            _mm256_storeu_pd(out.as_mut_ptr().add(i), res);
            i += 4;
            p += 4;
        }

        // Scalar tail: ±0 / NaN detection on the IEEE-754 bit pattern.
        while i < end {
            let past = *data.get_unchecked(p);
            let cur = *data.get_unchecked(i);
            let b = past.to_bits();
            let is_zero = (b & 0x7fff_ffff_ffff_ffff) == 0;
            let is_nan = (b & 0x7ff0_0000_0000_0000) == 0x7ff0_0000_0000_0000
                && (b & 0x000f_ffff_ffff_ffff) != 0;
            *out.get_unchecked_mut(i) = if is_zero | is_nan { 0.0 } else { cur / past };
            i += 1;
            p += 1;
        }
        return;
    }
    rocr_row_scalar(data, first, period, out)
}
996
/// AVX-512 row kernel for the batch paths: writes `out[i] = data[i] / data[i - period]`
/// for `i` in `first + period .. data.len()`. Lanes whose past value is ±0 or NaN are
/// zeroed by the masked divide. Without the `avx512f` target feature this delegates to
/// `rocr_row_scalar`.
///
/// # Safety
/// The vector path loads and stores through unchecked pointers. The caller must ensure
/// `out.len() >= data.len()` and that `first + period` does not overflow.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn rocr_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    #[cfg(target_feature = "avx512f")]
    {
        let start = first + period;
        let len = data.len();
        if start >= len {
            return;
        }
        let mut i = start;
        let end = len;
        let step = 8usize;
        while i + step <= end {
            let cur = _mm512_loadu_pd(data.as_ptr().add(i));
            let pst = _mm512_loadu_pd(data.as_ptr().add(i - period));
            // Mask bits set where the past value is ±0 (m0) or NaN (m1).
            let m0 = _mm512_cmp_pd_mask(pst, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
            let m1 = _mm512_cmp_pd_mask(pst, pst, _CMP_UNORD_Q);
            let good = !(m0 | m1);
            // Lanes not in `good` are written as zero by the maskz form.
            let res = _mm512_maskz_div_pd(good, cur, pst);
            _mm512_storeu_pd(out.as_mut_ptr().add(i), res);
            i += step;
        }
        // Scalar tail for the last < 8 elements.
        for j in i..end {
            let p = *data.get_unchecked(j - period);
            let c = *data.get_unchecked(j);
            *out.get_unchecked_mut(j) = if p == 0.0 || p.is_nan() { 0.0 } else { c / p };
        }
        return;
    }
    rocr_row_scalar(data, first, period, out)
}
1029
/// AVX-512 row kernel variant for short periods; currently identical to [`rocr_row_avx512`].
///
/// # Safety
/// Same contract as [`rocr_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn rocr_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // SAFETY: the caller upholds `rocr_row_avx512`'s contract, forwarded unchanged.
    unsafe { rocr_row_avx512(data, first, period, out) }
}

/// AVX-512 row kernel variant for long periods; currently identical to [`rocr_row_avx512`].
///
/// # Safety
/// Same contract as [`rocr_row_avx512`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn rocr_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    // SAFETY: the caller upholds `rocr_row_avx512`'s contract, forwarded unchanged.
    unsafe { rocr_row_avx512(data, first, period, out) }
}
1041
1042#[inline(always)]
1043pub fn expand_grid_rocr(r: &RocrBatchRange) -> Vec<RocrParams> {
1044 expand_grid(r).unwrap_or_else(|_| Vec::new())
1045}
1046
/// Row kernel that multiplies by a precomputed reciprocal table:
/// `out[i] = data[i] * inv[i - period]` for `i` in `first + period .. data.len()`.
///
/// No zero/NaN masking is applied; whatever `inv` holds is used as-is. As a result, a
/// NaN or infinite `data[i]` times a zero reciprocal yields NaN. This differs from the
/// division kernels, which report `0.0` when the past value is ±0 or NaN.
///
/// # Safety
/// Uses unchecked indexing. The caller must ensure `out.len() >= data.len()`,
/// `inv.len() >= data.len() - period` whenever `first + period < data.len()`, and that
/// `first + period` does not overflow.
#[inline(always)]
unsafe fn rocr_row_scalar_mul(
    data: &[f64],
    inv: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = data.len();
    for i in (first + period)..len {
        // SAFETY: `i < len`, so `data[i]` is in bounds; `i - period < len - period` and
        // `i < len <= out.len()` hold by the caller's contract.
        unsafe {
            let ratio = *data.get_unchecked(i) * *inv.get_unchecked(i - period);
            *out.get_unchecked_mut(i) = ratio;
        }
    }
}
1062
// Python binding `rocr(data, period=DEFAULT_PERIOD, kernel=None)`.
//
// Returns a 1-D float64 NumPy array. The computation runs inside `allow_threads`,
// and validation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "rocr")]
#[pyo3(signature = (data, period=DEFAULT_PERIOD, kernel=None))]
pub fn rocr_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = RocrInput::from_slice(
        series,
        RocrParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| rocr_with_kernel(&input, kern).map(|o| o.values));
    let values: Vec<f64> = match computed {
        Ok(v) => v,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1087
// Python-facing wrapper around `RocrStream`, exported as `RocrStream`.
#[cfg(feature = "python")]
#[pyclass(name = "RocrStream")]
pub struct RocrStreamPy {
    stream: RocrStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RocrStreamPy {
    // Builds a stream for `period`; invalid periods surface as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = RocrParams {
            period: Some(period),
        };
        RocrStream::try_new(params)
            .map(|stream| RocrStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value; `None` until the window has filled.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1111
// Python handle for a ROCR result stored in a CUDA device buffer (`DeviceArrayF32`).
// The buffer is exported as row-major f32 via `__cuda_array_interface__` / DLPack.
// Fields drop in declaration order, so `inner` (the device buffer) is released before
// `_ctx_guard`, keeping the CUDA context alive for the buffer's whole lifetime.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "RocrDeviceArrayF32", unsendable)]
pub struct RocrDeviceArrayF32Py {
    pub inner: DeviceArrayF32,
    _ctx_guard: Arc<Context>,
    // CUDA device ordinal reported through `__dlpack_device__`.
    device_id: i32,
}
1119
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl RocrDeviceArrayF32Py {
    // CUDA Array Interface (v3) descriptor. It reports:
    // - shape `(rows, cols)`;
    // - little-endian f32 (`<f4`);
    // - explicit row-major byte strides;
    // - a writable device pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let inner = &self.inner;
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        let ptr_val = inner.buf.as_device_ptr().as_raw() as usize;
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: (kDLCUDA = 2, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    // Exports the device buffer as a DLPack capsule, moving it out of this object. It
    // is replaced with an empty 0x0 buffer, so later exports see no data.
    //
    // The keyword is `copy`, not `_copy`: the DLPack / array-API protocol passes `copy=`
    // by name, and pyo3 takes Python keyword names from the Rust parameter names, so
    // `_copy` made such calls fail with TypeError. The value is accepted but not acted
    // on. The export never copies, but it relinquishes the buffer, so the consumer
    // ends up with sole ownership either way.
    //
    // `stream=0` is rejected as ambiguous for CUDA. A `dl_device` that parses as
    // `(type, id)` must equal `(2, self.device_id)`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let _ = copy;

        if let Some(dev) = dl_device {
            if let Ok((dtype, did)) = dev.extract::<(i32, i32)>(py) {
                if dtype != 2 || did != self.device_id {
                    return Err(PyValueError::new_err("dl_device mismatch for ROCR buffer"));
                }
            }
        }

        if let Some(s) = stream {
            if let Ok(v) = s.extract::<i64>(py) {
                if v == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__(stream=0) is invalid for CUDA",
                    ));
                }
            }
        }

        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, self.device_id, max_version_bound)
    }
}
1189
#[cfg(all(feature = "python", feature = "cuda"))]
impl RocrDeviceArrayF32Py {
    /// Wrap a device array together with the context handle that must outlive it.
    ///
    /// The `u32` device ordinal is narrowed with an `as` cast to the `i32` used by DLPack.
    fn new_from_cuda(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        let dlpack_ordinal = device_id as i32;
        Self {
            device_id: dlpack_ordinal,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1200
/// Python entry point: ROCR for every period in `period_range = (start, end, step)`.
///
/// Returns a dict with:
/// * `"values"`: a `(rows, cols)` float64 array, one row per period, whose warmup
///   prefixes are written by `init_matrix_prefixes` (presumably NaN);
/// * `"periods"`: a uint64 array with the period of each row.
///
/// # Errors
/// `ValueError` for an invalid sweep or input, an unsupported kernel, a size or
/// warmup overflow, or a failure inside the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "rocr_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn rocr_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;
    use std::mem::MaybeUninit;

    let slice_in = data.as_slice()?;
    let sweep = RocrBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let (first, _max_p) =
        rocr_batch_prepare(slice_in, &combos).map_err(|e| PyValueError::new_err(e.to_string()))?;

    let kern = validate_kernel(kernel, true)?;
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // Batch kernels are executed through their per-row SIMD counterparts. Any other
    // variant is reported to Python instead of panicking across the FFI boundary.
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => {
            return Err(PyValueError::new_err(format!(
                "rocr_batch: unsupported batch kernel {:?}",
                other
            )))
        }
    };

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rocr_batch_py: rows*cols overflow"))?;
    // SAFETY: the array is allocated uninitialized; every cell is written below, either
    // by `init_matrix_prefixes` (warmup prefixes) or by `rocr_batch_inner_into` (the rest)
    // — assumed from those helpers' contracts.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here and is not shared with any other view.
    let slice_out = unsafe { out_arr.as_slice_mut()? };
    let warms: Vec<usize> = combos
        .iter()
        .map(|c| {
            let p = c.period.unwrap();
            first
                .checked_add(p)
                .ok_or_else(|| PyValueError::new_err("rocr_batch_py: warmup overflow"))
        })
        .collect::<Result<_, _>>()?;

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the view covers exactly
    // `slice_out`, which is not otherwise accessed while `mu` is alive.
    unsafe {
        let mu: &mut [MaybeUninit<f64>] = std::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            slice_out.len(),
        );
        init_matrix_prefixes(mu, cols, &warms);
    }

    // The numeric work does not touch Python objects, so release the GIL for it.
    py.allow_threads(|| rocr_batch_inner_into(slice_in, &combos, first, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1275
/// Python entry point: ROCR period sweep computed on a CUDA device.
///
/// Fails with `ValueError` when CUDA is unavailable or the device computation errors.
/// The GIL is released while the device work runs.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rocr_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn rocr_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<RocrDeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaRocr};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series = data_f32.as_slice()?;
    let sweep = RocrBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev_id) =
        py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
            let engine =
                CudaRocr::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
            let context = engine.context_arc();
            let ordinal = engine.device_id();
            let device_array = engine
                .rocr_batch_dev(host_series, &sweep)
                .map_err(|err| PyValueError::new_err(err.to_string()))?;
            Ok((device_array, context, ordinal))
        })?;

    Ok(RocrDeviceArrayF32Py::new_from_cuda(inner, ctx, dev_id))
}
1309
/// Python entry point: ROCR with one period over many series stored time-major.
///
/// `data_tm_f32` is handed to the CUDA wrapper together with `cols`, `rows`, and
/// `period`. Fails with `ValueError` when CUDA is unavailable or the device
/// computation errors.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "rocr_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn rocr_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<RocrDeviceArrayF32Py> {
    use crate::cuda::{cuda_available, CudaRocr};

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let series_tm = data_tm_f32.as_slice()?;

    let launched = py.allow_threads(|| -> PyResult<(DeviceArrayF32, Arc<Context>, u32)> {
        let engine =
            CudaRocr::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        engine
            .rocr_many_series_one_param_time_major_dev(series_tm, cols, rows, period)
            .map(|device_array| (device_array, context, ordinal))
            .map_err(|err| PyValueError::new_err(err.to_string()))
    });

    let (inner, ctx, dev_id) = launched?;
    Ok(RocrDeviceArrayF32Py::new_from_cuda(inner, ctx, dev_id))
}
1340
/// JS entry point: ROCR of `data` with the given `period`, returned as a new array.
///
/// Uses automatic kernel selection; errors are surfaced as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rocr_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RocrInput::from_slice(
        data,
        RocrParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match rocr_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1355
/// Allocate room for `len` f64 values in wasm linear memory and leak it to the caller.
///
/// The contents are uninitialized. Release with `rocr_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rocr_alloc(len: usize) -> *mut f64 {
    // Wrapping in ManuallyDrop leaks the allocation, exactly like `mem::forget`.
    let mut storage = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
1364
/// Release a buffer previously returned by `rocr_alloc(len)`.
///
/// `len` must be the value passed to `rocr_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rocr_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer; treat null as "nothing to free".
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `rocr_alloc(len)`, i.e. a leaked `Vec<f64>` created with
    // capacity `len`. Rebuilding it with length 0 frees the allocation without claiming
    // that any element was ever initialized (f64 has no drop glue either way).
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1372
/// JS entry point: write ROCR of the `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// Both pointers must address `len` f64 values in wasm memory (for example from
/// `rocr_alloc`). Input and output may overlap, fully or partially. In that case the
/// result is computed into a temporary buffer first and then copied out.
///
/// # Errors
/// A string `JsValue` for null pointers or any ROCR validation/computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rocr_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = RocrParams {
        period: Some(period),
    };
    let input = RocrInput::from_slice(data, params);

    // Any byte overlap (not only identical start pointers) would alias the input
    // `&[f64]` with the output `&mut [f64]` and let writes clobber unread input.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(byte_len)
        && out_start < in_start.saturating_add(byte_len);

    if overlaps {
        let mut temp = vec![0.0; len];
        rocr_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values;
        // the input view is not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values,
        // and the range is disjoint from the input (checked above).
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        rocr_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1407
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Configuration object accepted by the JS `rocr_batch` export.
#[derive(Serialize, Deserialize)]
pub struct RocrBatchConfig {
    /// Period sweep as `(start, end, step)`; forwarded as `RocrBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1413
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialized result returned by the JS `rocr_batch` export.
#[derive(Serialize, Deserialize)]
pub struct RocrBatchJsOutput {
    /// Flattened output matrix, row-major: one row of `cols` values per combo.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<RocrParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (length of the input series).
    pub cols: usize,
}
1422
/// JS entry point `rocr_batch`: run a period sweep and return a `RocrBatchJsOutput`.
///
/// `config` must deserialize into `RocrBatchConfig`. Errors are surfaced as string
/// `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = rocr_batch)]
pub fn rocr_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let RocrBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<RocrBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let batch = rocr_batch_with_kernel(
        data,
        &RocrBatchRange {
            period: period_range,
        },
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&RocrBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1446
/// JS entry point: ROCR for a period sweep, written as a row-major `rows x len` matrix.
///
/// `in_ptr` must address `len` f64 inputs. `out_ptr` must address `rows * len` writable
/// f64 slots, where `rows` is the number of periods in the sweep. The buffers may
/// overlap; if they do, the matrix is built in a temporary buffer and copied out.
///
/// # Returns
/// The number of rows (parameter combinations) written.
///
/// # Errors
/// A string `JsValue` for null pointers, an invalid sweep/input, size or warmup
/// overflow, or a computation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn rocr_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    use std::mem::MaybeUninit;

    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to rocr_batch_into"));
    }

    let sweep = RocrBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let period_count = combos.len();

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let (first, _max_p) =
        rocr_batch_prepare(data, &combos).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let total_elements = period_count
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rocr_batch_into: rows*cols overflow"))?;

    let warms: Vec<usize> = combos
        .iter()
        .map(|c| {
            let p = c.period.unwrap();
            first
                .checked_add(p)
                .ok_or_else(|| JsValue::from_str("rocr_batch_into: warmup overflow"))
        })
        .collect::<Result<_, _>>()?;

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // The output (`rows * len` values) is larger than the input (`len` values), so a
    // partial overlap is easy to produce. Any byte overlap, not only identical start
    // pointers, must go through the temporary buffer.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total_elements.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    if overlaps {
        let mut temp = vec![0.0; total_elements];
        {
            // SAFETY: `temp` holds `total_elements` initialized f64s; `MaybeUninit<f64>`
            // has the same layout and `temp` is not otherwise accessed while `mu` lives.
            let mu: &mut [MaybeUninit<f64>] = unsafe {
                std::slice::from_raw_parts_mut(
                    temp.as_mut_ptr() as *mut MaybeUninit<f64>,
                    total_elements,
                )
            };
            init_matrix_prefixes(mu, len, &warms);
        }
        rocr_batch_inner_into(data, &combos, first, simd, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: the caller guarantees `out_ptr` addresses `total_elements` writable f64
        // slots; the input view is not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_elements) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `total_elements` writable
        // slots, disjoint from the input (checked above). The memory may be
        // uninitialized, hence the `MaybeUninit` view for the prefix pass.
        let mu: &mut [MaybeUninit<f64>] = unsafe {
            std::slice::from_raw_parts_mut(out_ptr as *mut MaybeUninit<f64>, total_elements)
        };
        init_matrix_prefixes(mu, len, &warms);
        // SAFETY: same region as `mu`, which is no longer used.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_elements) };
        rocr_batch_inner_into(data, &combos, first, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(period_count)
}
1548
#[cfg(test)]
mod tests {
    //! Tests for ROCR: the single-series API, streaming, batch builder, property checks,
    //! and debug-only checks that allocation poison patterns never leak into outputs.
    //!
    //! Each `check_*` helper returns a `Result`; the generated `#[test]` functions return
    //! it unchanged, so an `Err` (e.g. a failed CSV load or computation) fails the test
    //! instead of being silently discarded.

    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_rocr_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocrInput::from_candles(&candles, "close", RocrParams::default());
        let baseline = rocr(&input)?.values;

        let mut out = vec![0.0; candles.close.len()];
        rocr_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
        }
        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(a, b),
                "mismatch at index {}: api={} into={}",
                i,
                a,
                b
            );
        }
        Ok(())
    }

    fn check_rocr_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = RocrParams { period: None };
        let input = RocrInput::from_candles(&candles, "close", default_params);
        let output = rocr_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_rocr_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocrInput::from_candles(&candles, "close", RocrParams { period: Some(10) });
        let result = rocr_with_kernel(&input, kernel)?;
        let expected_last_five = [
            0.9977448290950706,
            0.9944380965183492,
            0.9967247986764135,
            0.9950545846019277,
            0.984954072979463,
        ];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-8,
                "[{}] ROCR {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    fn check_rocr_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocrInput::with_default_candles(&candles);
        match input.data {
            RocrData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected RocrData::Candles"),
        }
        let output = rocr_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_rocr_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = RocrParams { period: Some(0) };
        let input = RocrInput::from_slice(&input_data, params);
        let res = rocr_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ROCR should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_rocr_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = RocrParams { period: Some(10) };
        let input = RocrInput::from_slice(&data_small, params);
        let res = rocr_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ROCR should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_rocr_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = RocrParams { period: Some(9) };
        let input = RocrInput::from_slice(&single_point, params);
        let res = rocr_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ROCR should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_rocr_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let first_params = RocrParams { period: Some(14) };
        let first_input = RocrInput::from_candles(&candles, "close", first_params);
        let first_result = rocr_with_kernel(&first_input, kernel)?;

        let second_params = RocrParams { period: Some(14) };
        let second_input = RocrInput::from_slice(&first_result.values, second_params);
        let second_result = rocr_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        // Two chained period-14 passes: beyond 2 * 14 every value must be defined.
        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] ROCR Slice Reinput {:?} unexpected NaN at idx {}",
                test_name,
                kernel,
                i,
            );
        }
        Ok(())
    }

    fn check_rocr_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocrInput::from_candles(&candles, "close", RocrParams { period: Some(9) });
        let res = rocr_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    fn check_rocr_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let period = 9;

        let input = RocrInput::from_candles(
            &candles,
            "close",
            RocrParams {
                period: Some(period),
            },
        );
        let batch_output = rocr_with_kernel(&input, kernel)?.values;

        let mut stream = RocrStream::try_new(RocrParams {
            period: Some(period),
        })?;

        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(rocr_val) => stream_values.push(rocr_val),
                None => stream_values.push(f64::NAN),
            }
        }

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] ROCR streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_rocr_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            RocrParams::default(),
            RocrParams { period: Some(1) },
            RocrParams { period: Some(5) },
            RocrParams { period: Some(20) },
            RocrParams { period: Some(50) },
            RocrParams { period: Some(100) },
            RocrParams { period: Some(2) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = RocrInput::from_candles(&candles, "close", params.clone());
            let output = rocr_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_rocr_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_rocr_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    prop::strategy::Union::new(vec![
                        (0.1f64..10000f64).boxed(),
                        prop::strategy::Just(0.0).boxed(),
                        (1e-10f64..1e-5f64).boxed(),
                        (1e5f64..1e8f64).boxed(),
                    ])
                    .prop_filter("finite values", |x| x.is_finite() && *x >= 0.0),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = RocrParams {
                    period: Some(period),
                };
                let input = RocrInput::from_slice(&data, params);

                let RocrOutput { values: out } = rocr_with_kernel(&input, kernel).unwrap();
                let RocrOutput { values: ref_out } =
                    rocr_with_kernel(&input, Kernel::Scalar).unwrap();

                for i in 0..period {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in period..data.len() {
                    let current = data[i];
                    let past = data[i - period];

                    let expected = if past == 0.0 || past.is_nan() {
                        0.0
                    } else {
                        current / past
                    };

                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_nan() {
                        let tolerance = if expected.abs() > 1.0 {
                            expected.abs() * 1e-9
                        } else {
                            1e-9
                        };

                        prop_assert!(
                            (y - expected).abs() <= tolerance,
                            "ROCR formula mismatch at idx {}: got {}, expected {} (current={}, past={})",
                            i, y, expected, current, past
                        );

                        prop_assert!(
                            y >= 0.0,
                            "ROCR should be non-negative at idx {}: got {}",
                            i,
                            y
                        );

                        if current == 0.0 && past != 0.0 {
                            prop_assert!(
                                y == 0.0,
                                "ROCR should be 0 when current=0 at idx {}: got {}",
                                i,
                                y
                            );
                        }

                        if past == 0.0 {
                            prop_assert!(
                                y == 0.0,
                                "ROCR should be 0 when past=0 at idx {}: got {}",
                                i,
                                y
                            );
                        }

                        prop_assert!(
                            y.is_finite(),
                            "ROCR should be finite at idx {}: got {} (current={}, past={})",
                            i,
                            y,
                            current,
                            past
                        );
                    }

                    if period == 1 && i > 0 && data[i - 1] != 0.0 {
                        let expected_simple = data[i] / data[i - 1];
                        if !y.is_nan() {
                            let tolerance = if expected_simple.abs() > 1.0 {
                                expected_simple.abs() * 1e-9
                            } else {
                                1e-9
                            };
                            prop_assert!(
                                (y - expected_simple).abs() <= tolerance,
                                "Period=1 mismatch at idx {}: got {}, expected {}",
                                i,
                                y,
                                expected_simple
                            );
                        }
                    }

                    // ROCR at `i` is data[i] / data[i - period], so the window must include
                    // the divisor. Use a purely relative tolerance so "constant" really
                    // implies a ratio of ~1 even for tiny magnitudes.
                    if i >= period && period > 1 {
                        let window = &data[i - period..=i];

                        let first_val = window[0];
                        let is_constant = first_val != 0.0
                            && window
                                .iter()
                                .all(|&v| (v - first_val).abs() <= 1e-10 * first_val.abs());

                        if is_constant {
                            prop_assert!(
                                (y - 1.0).abs() <= 1e-9,
                                "Constant data should yield ROCR=1.0 at idx {}: got {}",
                                i,
                                y
                            );
                        }
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    let max_ulp = if y.abs() > 1e6 || y.abs() < 1e-6 {
                        8
                    } else {
                        4
                    };

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= max_ulp,
                        "Kernel mismatch idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }
                Ok(())
            })
            .unwrap();

        Ok(())
    }

    // Generates scalar tests for every helper plus AVX2/AVX-512 variants. The `cfg` gate
    // sits on each AVX function, because an attribute placed before a `$(...)*` repetition
    // would only attach to the first generated item.
    macro_rules! generate_all_rocr_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    generate_all_rocr_tests!(
        check_rocr_partial_params,
        check_rocr_accuracy,
        check_rocr_default_candles,
        check_rocr_zero_period,
        check_rocr_period_exceeds_length,
        check_rocr_very_small_dataset,
        check_rocr_reinput,
        check_rocr_nan_handling,
        check_rocr_streaming,
        check_rocr_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_rocr_tests!(check_rocr_property);

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = RocrBatchBuilder::new()
            .period_static(10)
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = RocrParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [
            0.9977448290950706,
            0.9944380965183492,
            0.9967247986764135,
            0.9950545846019277,
            0.984954072979463,
        ];
        let start = row.len().saturating_sub(5);
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-8,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (10, 50, 10),
            (1, 3, 1),
            (100, 200, 50),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = RocrBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(10)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(10)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(10)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates batch tests for each kernel; the helper's `Result` is returned so errors
    // fail the test rather than being discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}