1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
18#[cfg(feature = "python")]
19use crate::utilities::kernel_validation::validate_kernel;
20#[cfg(feature = "python")]
21use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
22#[cfg(feature = "python")]
23use pyo3::exceptions::PyValueError;
24#[cfg(feature = "python")]
25use pyo3::prelude::*;
26#[cfg(feature = "python")]
27use pyo3::types::PyDict;
28
29#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
30use serde::{Deserialize, Serialize};
31#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
32use wasm_bindgen::prelude::*;
33
/// Input source for the ROC indicator: either one column of a [`Candles`] set or a raw slice.
#[derive(Debug, Clone)]
pub enum RocData<'a> {
    /// A named column of `candles` (for example `"close"`), resolved through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice of values used as-is.
    Slice(&'a [f64]),
}
42
/// Result of a single-series ROC computation.
#[derive(Debug, Clone)]
pub struct RocOutput {
    /// One value per input sample; the warm-up prefix (`first_valid + period` entries) is NaN.
    pub values: Vec<f64>,
}
47
/// Tunable parameters for the ROC indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RocParams {
    /// Look-back distance in samples; `None` is treated as 9 by the consumers in this module.
    pub period: Option<usize>,
}

/// Default parameters: `period = Some(9)`.
impl Default for RocParams {
    fn default() -> Self {
        Self { period: Some(9) }
    }
}
62
/// A ROC input: the data source plus the parameters to compute it with.
#[derive(Debug, Clone)]
pub struct RocInput<'a> {
    /// Where the values come from (candle column or raw slice).
    pub data: RocData<'a>,
    /// Indicator parameters (`period`).
    pub params: RocParams,
}
68
69impl<'a> AsRef<[f64]> for RocInput<'a> {
70 #[inline(always)]
71 fn as_ref(&self) -> &[f64] {
72 match &self.data {
73 RocData::Slice(slice) => slice,
74 RocData::Candles { candles, source } => source_type(candles, source),
75 }
76 }
77}
78
79impl<'a> RocInput<'a> {
80 #[inline]
81 pub fn from_candles(candles: &'a Candles, source: &'a str, params: RocParams) -> Self {
82 Self {
83 data: RocData::Candles { candles, source },
84 params,
85 }
86 }
87 #[inline]
88 pub fn from_slice(slice: &'a [f64], params: RocParams) -> Self {
89 Self {
90 data: RocData::Slice(slice),
91 params,
92 }
93 }
94 #[inline]
95 pub fn with_default_candles(candles: &'a Candles) -> Self {
96 Self::from_candles(candles, "close", RocParams::default())
97 }
98 #[inline]
99 pub fn get_period(&self) -> usize {
100 self.params.period.unwrap_or(9)
101 }
102}
103
/// Fluent builder for single-series ROC computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct RocBuilder {
    /// Explicit period; `None` lets the input fall back to its default of 9.
    period: Option<usize>,
    /// Kernel passed to `roc_with_kernel`.
    kernel: Kernel,
}
109
110impl Default for RocBuilder {
111 fn default() -> Self {
112 Self {
113 period: None,
114 kernel: Kernel::Auto,
115 }
116 }
117}
118
119impl RocBuilder {
120 #[inline(always)]
121 pub fn new() -> Self {
122 Self::default()
123 }
124 #[inline(always)]
125 pub fn period(mut self, n: usize) -> Self {
126 self.period = Some(n);
127 self
128 }
129 #[inline(always)]
130 pub fn kernel(mut self, k: Kernel) -> Self {
131 self.kernel = k;
132 self
133 }
134 #[inline(always)]
135 pub fn apply(self, c: &Candles) -> Result<RocOutput, RocError> {
136 let p = RocParams {
137 period: self.period,
138 };
139 let i = RocInput::from_candles(c, "close", p);
140 roc_with_kernel(&i, self.kernel)
141 }
142 #[inline(always)]
143 pub fn apply_slice(self, d: &[f64]) -> Result<RocOutput, RocError> {
144 let p = RocParams {
145 period: self.period,
146 };
147 let i = RocInput::from_slice(d, p);
148 roc_with_kernel(&i, self.kernel)
149 }
150 #[inline(always)]
151 pub fn into_stream(self) -> Result<RocStream, RocError> {
152 let p = RocParams {
153 period: self.period,
154 };
155 RocStream::try_new(p)
156 }
157}
158
/// Errors produced by the ROC indicator, its batch variants, and its stream.
#[derive(Debug, Error)]
pub enum RocError {
    /// The input series has no elements.
    #[error("roc: Input data slice is empty.")]
    EmptyInputData,
    /// NOTE(review): same message as `EmptyInputData` and no construction site is visible in
    /// this section — confirm whether this variant is still needed.
    #[error("roc: Input data slice is empty.")]
    EmptyData,
    /// Every input value is NaN, so there is no first valid index.
    #[error("roc: All values are NaN.")]
    AllValuesNaN,
    /// `period` is 0 or exceeds the series length (`RocStream::try_new` reports `data_len: 0`).
    #[error("roc: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` samples remain after the leading NaNs.
    #[error("roc: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer does not have the required length.
    #[error("roc: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep expanded to no periods, or its `rows * cols` size overflowed.
    #[error("roc: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to `roc_batch_with_kernel`.
    #[error("roc: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
182
183#[inline]
184pub fn roc(input: &RocInput) -> Result<RocOutput, RocError> {
185 roc_with_kernel(input, Kernel::Auto)
186}
187
188#[inline(always)]
189fn roc_prepare<'a>(
190 input: &'a RocInput,
191 kernel: Kernel,
192) -> Result<(&'a [f64], usize, usize, Kernel), RocError> {
193 let data: &[f64] = input.as_ref();
194 let len = data.len();
195 if len == 0 {
196 return Err(RocError::EmptyInputData);
197 }
198 let first = data
199 .iter()
200 .position(|x| !x.is_nan())
201 .ok_or(RocError::AllValuesNaN)?;
202 let period = input.get_period();
203 if period == 0 || period > len {
204 return Err(RocError::InvalidPeriod {
205 period,
206 data_len: len,
207 });
208 }
209 if len - first < period {
210 return Err(RocError::NotEnoughValidData {
211 needed: period,
212 valid: len - first,
213 });
214 }
215
216 let chosen = match kernel {
217 Kernel::Auto => Kernel::Scalar,
218 k => k,
219 };
220 Ok((data, period, first, chosen))
221}
222
223#[inline(always)]
224fn roc_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
225 unsafe {
226 match kernel {
227 Kernel::Scalar => roc_scalar(data, period, first, out),
228 Kernel::ScalarBatch => roc_scalar(data, period, first, out),
229 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
230 Kernel::Avx2 => roc_avx2(data, period, first, out),
231 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
232 Kernel::Avx2Batch => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, out),
233 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
234 Kernel::Avx512 => roc_avx512(data, period, first, out),
235 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
236 Kernel::Avx512Batch => {
237 roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, out)
238 }
239 _ => unreachable!(),
240 }
241 }
242}
243
244pub fn roc_with_kernel(input: &RocInput, kernel: Kernel) -> Result<RocOutput, RocError> {
245 let (data, period, first, chosen) = roc_prepare(input, kernel)?;
246
247 let mut out = alloc_with_nan_prefix(data.len(), first + period);
248 roc_compute_into(data, period, first, chosen, &mut out);
249 Ok(RocOutput { values: out })
250}
251
252#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
253#[inline]
254pub fn roc_into(input: &RocInput, out: &mut [f64]) -> Result<(), RocError> {
255 roc_into_slice(out, input, Kernel::Auto)
256}
257
258#[inline(always)]
259pub unsafe fn roc_indicator(input: &RocInput) -> Result<RocOutput, RocError> {
260 roc_with_kernel(input, Kernel::Auto)
261}
262#[inline(always)]
263pub unsafe fn roc_indicator_with_kernel(
264 input: &RocInput,
265 k: Kernel,
266) -> Result<RocOutput, RocError> {
267 roc_with_kernel(input, k)
268}
269#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
270#[inline(always)]
271pub unsafe fn roc_indicator_avx512(input: &RocInput) -> Result<RocOutput, RocError> {
272 roc_with_kernel(input, Kernel::Avx512)
273}
274#[inline(always)]
275pub unsafe fn roc_indicator_scalar(input: &RocInput) -> Result<RocOutput, RocError> {
276 roc_with_kernel(input, Kernel::Scalar)
277}
278#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
279#[inline(always)]
280pub unsafe fn roc_indicator_avx2(input: &RocInput) -> Result<RocOutput, RocError> {
281 roc_with_kernel(input, Kernel::Avx2)
282}
283#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284#[inline(always)]
285pub unsafe fn roc_indicator_avx512_short(input: &RocInput) -> Result<RocOutput, RocError> {
286 roc_with_kernel(input, Kernel::Avx512)
287}
288#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289#[inline(always)]
290pub unsafe fn roc_indicator_avx512_long(input: &RocInput) -> Result<RocOutput, RocError> {
291 roc_with_kernel(input, Kernel::Avx512)
292}
293
/// Portable single-series ROC kernel.
///
/// For every `i` in `first + period .. data.len()` writes
/// `out[i] = 100 * data[i] / data[i - period] - 100` (fused multiply-add). It writes `0.0`
/// instead when the earlier value is `0.0` or NaN. Entries before `first + period` are left
/// untouched.
///
/// # Panics
/// Panics if `first + period` exceeds `data.len()` or `out.len()`.
#[inline(always)]
pub fn roc_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let start = first + period;
    let earlier = &data[first..data.len() - period];
    let later = &data[start..];
    for (slot, (&now, &then)) in out[start..].iter_mut().zip(later.iter().zip(earlier)) {
        *slot = if then == 0.0 || then.is_nan() {
            0.0
        } else {
            (now / then).mul_add(100.0, -100.0)
        };
    }
}
310
/// AVX2/FMA single-series ROC kernel.
///
/// For every `i` in `first + period .. data.len()` writes
/// `out[i] = data[i] / data[i - period] * 100 - 100`. It writes `0.0` where the earlier value
/// is `0.0` or NaN. Four lanes are processed per iteration and the remainder by a scalar loop.
/// Entries before `first + period` are not written.
///
/// # Safety
/// - The CPU must support AVX2 and FMA (the function is compiled with those target features).
/// - `out.len()` must be at least `data.len()`: stores go through raw pointers up to index
///   `data.len() - 1` with no bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn roc_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = first + period;
    // Warm-up covers the whole series: nothing to compute.
    if start >= len {
        return;
    }

    // Output element `start + i` compares data[start + i] with data[first + i], which is
    // exactly `period` samples earlier.
    let n = len - start;
    let base_curr = data.as_ptr().add(start);
    let base_prev = data.as_ptr().add(first);
    let base_out = out.as_mut_ptr().add(start);

    let v_zero = _mm256_set1_pd(0.0);
    let v_m100 = _mm256_set1_pd(-100.0);
    let v_100 = _mm256_set1_pd(100.0);

    let mut i = 0usize;
    while i + 4 <= n {
        let c = _mm256_loadu_pd(base_curr.add(i));
        let p = _mm256_loadu_pd(base_prev.add(i));

        // Lanes whose previous value is 0.0 (ordered compare) or NaN (unordered self-compare)
        // are forced to 0.0, matching the scalar kernel.
        let mask_zero = _mm256_cmp_pd(p, v_zero, _CMP_EQ_OQ);
        let mask_nan = _mm256_cmp_pd(p, p, _CMP_UNORD_Q);
        let mask_invalid = _mm256_or_pd(mask_zero, mask_nan);

        // (c / p) * 100 - 100 as a single fused multiply-add.
        let div = _mm256_div_pd(c, p);
        let res = _mm256_fmadd_pd(div, v_100, v_m100);

        let blended = _mm256_blendv_pd(res, v_zero, mask_invalid);
        _mm256_storeu_pd(base_out.add(i), blended);
        i += 4;
    }

    // Scalar tail for the last `n % 4` elements; same formula as the vector body.
    while i < n {
        let p = *base_prev.add(i);
        let c = *base_curr.add(i);
        *base_out.add(i) = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
        i += 1;
    }
}
/// AVX-512F/FMA single-series ROC kernel.
///
/// For every `i` in `first + period .. data.len()` writes
/// `out[i] = data[i] / data[i - period] * 100 - 100`. It writes `0.0` where the earlier value
/// is `0.0` or NaN. Eight lanes are processed per iteration and the remainder by a scalar
/// loop. Entries before `first + period` are not written.
///
/// The tail uses the same fused multiply-add as the vector body and the other kernels, so
/// every element is rounded identically regardless of where it falls.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA (the function is compiled with those features).
/// - `out.len()` must be at least `data.len()`: stores go through raw pointers up to index
///   `data.len() - 1` with no bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
pub unsafe fn roc_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = first + period;
    // Warm-up covers the whole series: nothing to compute.
    if start >= len {
        return;
    }

    // Output element `start + i` compares data[start + i] with data[first + i].
    let n = len - start;
    let base_curr = data.as_ptr().add(start);
    let base_prev = data.as_ptr().add(first);
    let base_out = out.as_mut_ptr().add(start);

    let v_zero = _mm512_set1_pd(0.0);
    let v_m100 = _mm512_set1_pd(-100.0);
    let v_100 = _mm512_set1_pd(100.0);

    let mut i = 0usize;
    while i + 8 <= n {
        let c = _mm512_loadu_pd(base_curr.add(i));
        let p = _mm512_loadu_pd(base_prev.add(i));

        // Lanes whose previous value is 0.0 or NaN are forced to 0.0.
        let k_zero = _mm512_cmp_pd_mask(p, v_zero, _CMP_EQ_OQ);
        let k_nan = _mm512_cmp_pd_mask(p, p, _CMP_UNORD_Q);
        let k_invalid = k_zero | k_nan;

        let div = _mm512_div_pd(c, p);
        let res = _mm512_fmadd_pd(div, v_100, v_m100);

        let blended = _mm512_mask_mov_pd(res, k_invalid, v_zero);
        _mm512_storeu_pd(base_out.add(i), blended);
        i += 8;
    }

    // Scalar tail: fused multiply-add to match `_mm512_fmadd_pd` above and the other kernels.
    while i < n {
        let p = *base_prev.add(i);
        let c = *base_curr.add(i);
        *base_out.add(i) = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
        i += 1;
    }
}
/// "Short period" AVX-512 entry point; currently delegates to [`roc_scalar`].
///
/// # Safety
/// Performs no unsafe operations itself; `unsafe` is kept for API symmetry.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn roc_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    roc_scalar(data, period, first, out);
}

/// "Long period" AVX-512 entry point; currently delegates to [`roc_scalar`].
///
/// # Safety
/// Performs no unsafe operations itself; `unsafe` is kept for API symmetry.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn roc_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    roc_scalar(data, period, first, out);
}
415
/// Scalar row kernel used by the batch paths: computes one ROC row into `out`.
///
/// For every `i` in `first + period .. data.len()` writes
/// `out[i] = 100 * data[i] / data[i - period] - 100` (fused multiply-add). It writes `0.0`
/// where the earlier value is `0.0` or NaN. Entries before `first + period`, and any entries
/// past `data.len()`, are left untouched. `_stride`, `_weights` and `_inv_n` are unused; they
/// exist so all row kernels share one signature.
///
/// The loop runs over bounds-checked subslices computed once up front, so an undersized `out`
/// panics instead of being written out of bounds. The zipped iterator lets LLVM vectorize
/// without the previous manual 4x unroll.
///
/// # Safety
/// No unsafe operations are performed; `unsafe` is kept for signature compatibility with the
/// SIMD row kernels.
///
/// # Panics
/// Panics if `first + period < data.len()` and `out.len() < data.len()`.
#[inline(always)]
pub unsafe fn roc_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    let len = data.len();
    let start = first + period;
    if start >= len {
        return;
    }

    // `prev[j]` is exactly `period` samples before `curr[j]`; both have `len - start` items.
    let prev = &data[first..len - period];
    let curr = &data[start..];
    let dst = &mut out[start..len];

    for ((d, &c), &p) in dst.iter_mut().zip(curr).zip(prev) {
        *d = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
    }
}
/// AVX2 row kernel for batch ROC; currently forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same requirements as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out);
}

/// AVX-512 row kernel for batch ROC; currently forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same requirements as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out);
}

/// "Short period" AVX-512 row kernel; currently forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same requirements as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out);
}

/// "Long period" AVX-512 row kernel; currently forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same requirements as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out);
}
538
/// Incremental ROC calculator that consumes one sample at a time.
///
/// Keeps the last `period` samples in a ring buffer. Once `period` samples have been stored
/// after the first non-NaN input, each update returns `100 * value / value_period_ago - 100`,
/// or `0.0` when the earlier value is `0.0` or NaN.
#[derive(Debug, Clone)]
pub struct RocStream {
    /// Look-back distance in samples (always >= 1, enforced by `try_new`).
    period: usize,
    /// Ring buffer of the most recent `period` samples.
    buffer: Vec<f64>,
    /// Next slot to overwrite; once the buffer is full it holds the oldest stored sample.
    head: usize,
    /// Set once the first non-NaN sample has been accepted (leading NaNs are skipped).
    filled: bool,
    /// Number of samples stored so far; stops increasing once it reaches `period`.
    count: usize,
}
547
548impl RocStream {
549 pub fn try_new(params: RocParams) -> Result<Self, RocError> {
550 let period = params.period.unwrap_or(9);
551 if period == 0 {
552 return Err(RocError::InvalidPeriod {
553 period,
554 data_len: 0,
555 });
556 }
557 Ok(Self {
558 period,
559 buffer: vec![f64::NAN; period],
560 head: 0,
561 filled: false,
562 count: 0,
563 })
564 }
565
566 #[inline(always)]
567 pub fn update(&mut self, value: f64) -> Option<f64> {
568 if !self.filled {
569 if value.is_nan() {
570 return None;
571 }
572
573 self.buffer[self.head] = value;
574
575 self.head += 1;
576 if self.head == self.period {
577 self.head = 0;
578 }
579 self.filled = true;
580 self.count = 1;
581 return None;
582 }
583
584 if self.count < self.period {
585 self.buffer[self.head] = value;
586 self.head += 1;
587 if self.head == self.period {
588 self.head = 0;
589 }
590 self.count += 1;
591 return None;
592 }
593
594 let old_value = self.buffer[self.head];
595 self.buffer[self.head] = value;
596
597 self.head += 1;
598 if self.head == self.period {
599 self.head = 0;
600 }
601
602 if old_value == 0.0 || old_value.is_nan() {
603 Some(0.0)
604 } else {
605 Some((value / old_value).mul_add(100.0, -100.0))
606 }
607 }
608}
609
/// Period sweep for batch ROC.
#[derive(Clone, Debug)]
pub struct RocBatchRange {
    /// Inclusive `(start, end, step)`; `step == 0` or `start == end` means the single period
    /// `start` (see `expand_grid`).
    pub period: (usize, usize, usize),
}

/// Default sweep: periods 9 through 258 in steps of 1.
impl Default for RocBatchRange {
    fn default() -> Self {
        Self {
            period: (9, 258, 1),
        }
    }
}
622
/// Fluent builder for batch ROC over a period sweep.
#[derive(Clone, Debug, Default)]
pub struct RocBatchBuilder {
    /// Periods to evaluate.
    range: RocBatchRange,
    /// Kernel passed to `roc_batch_with_kernel` (must be `Auto` or a batch kernel).
    kernel: Kernel,
}
628
629impl RocBatchBuilder {
630 pub fn new() -> Self {
631 Self::default()
632 }
633 pub fn kernel(mut self, k: Kernel) -> Self {
634 self.kernel = k;
635 self
636 }
637 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
638 self.range.period = (start, end, step);
639 self
640 }
641 pub fn period_static(mut self, p: usize) -> Self {
642 self.range.period = (p, p, 0);
643 self
644 }
645 pub fn apply_slice(self, data: &[f64]) -> Result<RocBatchOutput, RocError> {
646 roc_batch_with_kernel(data, &self.range, self.kernel)
647 }
648 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RocBatchOutput, RocError> {
649 RocBatchBuilder::new().kernel(k).apply_slice(data)
650 }
651 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RocBatchOutput, RocError> {
652 let slice = source_type(c, src);
653 self.apply_slice(slice)
654 }
655 pub fn with_default_candles(c: &Candles) -> Result<RocBatchOutput, RocError> {
656 RocBatchBuilder::new()
657 .kernel(Kernel::Auto)
658 .apply_candles(c, "close")
659 }
660}
661
662pub fn roc_batch_with_kernel(
663 data: &[f64],
664 sweep: &RocBatchRange,
665 k: Kernel,
666) -> Result<RocBatchOutput, RocError> {
667 let kernel = match k {
668 Kernel::Auto => detect_best_batch_kernel(),
669 other if other.is_batch() => other,
670 other => return Err(RocError::InvalidKernelForBatch(other)),
671 };
672
673 let simd = match kernel {
674 Kernel::Avx512Batch => Kernel::Avx512,
675 Kernel::Avx2Batch => Kernel::Avx2,
676 Kernel::ScalarBatch => Kernel::Scalar,
677 _ => unreachable!(),
678 };
679 roc_batch_par_slice(data, sweep, simd)
680}
681
/// Result of a batch ROC sweep.
#[derive(Clone, Debug)]
pub struct RocBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` is the series computed with `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter set for each row, in row order.
    pub combos: Vec<RocParams>,
    /// Number of rows (one per parameter set).
    pub rows: usize,
    /// Number of columns (one per input sample).
    pub cols: usize,
}
689
690impl RocBatchOutput {
691 pub fn row_for_params(&self, p: &RocParams) -> Option<usize> {
692 self.combos
693 .iter()
694 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
695 }
696 pub fn values_for(&self, p: &RocParams) -> Option<&[f64]> {
697 self.row_for_params(p).and_then(|row| {
698 let start = row.checked_mul(self.cols)?;
699 let end = start.checked_add(self.cols)?;
700 self.values.get(start..end)
701 })
702 }
703}
704
705#[inline(always)]
706pub(crate) fn expand_grid(r: &RocBatchRange) -> Result<Vec<RocParams>, RocError> {
707 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RocError> {
708 if step == 0 || start == end {
709 return Ok(vec![start]);
710 }
711 let mut out = Vec::new();
712 if start < end {
713 let mut v = start;
714 while v <= end {
715 out.push(v);
716 match v.checked_add(step) {
717 Some(next) => {
718 if next == v {
719 break;
720 }
721 v = next;
722 }
723 None => break,
724 }
725 }
726 } else {
727 let mut v = start;
728 while v >= end {
729 out.push(v);
730 if v < end + step {
731 break;
732 }
733 v -= step;
734 if v == 0 {
735 break;
736 }
737 }
738 }
739 if out.is_empty() {
740 return Err(RocError::InvalidRange { start, end, step });
741 }
742 Ok(out)
743 }
744
745 let periods = axis_usize(r.period)?;
746 Ok(periods
747 .into_iter()
748 .map(|p| RocParams { period: Some(p) })
749 .collect())
750}
751
752#[inline(always)]
753pub fn roc_batch_slice(
754 data: &[f64],
755 sweep: &RocBatchRange,
756 kern: Kernel,
757) -> Result<RocBatchOutput, RocError> {
758 roc_batch_inner(data, sweep, kern, false)
759}
760#[inline(always)]
761pub fn roc_batch_par_slice(
762 data: &[f64],
763 sweep: &RocBatchRange,
764 kern: Kernel,
765) -> Result<RocBatchOutput, RocError> {
766 roc_batch_inner(data, sweep, kern, true)
767}
768
769#[inline(always)]
770fn roc_batch_inner(
771 data: &[f64],
772 sweep: &RocBatchRange,
773 kern: Kernel,
774 parallel: bool,
775) -> Result<RocBatchOutput, RocError> {
776 let combos = expand_grid(sweep)?;
777 let first = data
778 .iter()
779 .position(|x| !x.is_nan())
780 .ok_or(RocError::AllValuesNaN)?;
781 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
782 if data.len() - first < max_p {
783 return Err(RocError::NotEnoughValidData {
784 needed: max_p,
785 valid: data.len() - first,
786 });
787 }
788 let rows = combos.len();
789 let cols = data.len();
790 let _total = rows.checked_mul(cols).ok_or(RocError::InvalidRange {
791 start: sweep.period.0,
792 end: sweep.period.1,
793 step: sweep.period.2,
794 })?;
795
796 let mut buf_mu = make_uninit_matrix(rows, cols);
797
798 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
799 init_matrix_prefixes(&mut buf_mu, cols, &warm);
800
801 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
802 let out: &mut [f64] = unsafe {
803 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
804 };
805
806 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
807 let period = combos[row].period.unwrap();
808 match kern {
809 Kernel::Scalar => {
810 roc_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, out_row)
811 }
812 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
813 Kernel::Avx2 => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, out_row),
814 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
815 Kernel::Avx512 => {
816 roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, out_row)
817 }
818 _ => unreachable!(),
819 }
820 };
821
822 if parallel {
823 #[cfg(not(target_arch = "wasm32"))]
824 {
825 out.par_chunks_mut(cols)
826 .enumerate()
827 .for_each(|(row, slice)| do_row(row, slice));
828 }
829
830 #[cfg(target_arch = "wasm32")]
831 {
832 for (row, slice) in out.chunks_mut(cols).enumerate() {
833 do_row(row, slice);
834 }
835 }
836 } else {
837 for (row, slice) in out.chunks_mut(cols).enumerate() {
838 do_row(row, slice);
839 }
840 }
841
842 let values = unsafe {
843 Vec::from_raw_parts(
844 buf_guard.as_mut_ptr() as *mut f64,
845 buf_guard.len(),
846 buf_guard.capacity(),
847 )
848 };
849
850 Ok(RocBatchOutput {
851 values,
852 combos,
853 rows,
854 cols,
855 })
856}
857
858#[inline(always)]
859fn roc_batch_inner_into(
860 data: &[f64],
861 sweep: &RocBatchRange,
862 kern: Kernel,
863 parallel: bool,
864 out: &mut [f64],
865) -> Result<Vec<RocParams>, RocError> {
866 let combos = expand_grid(sweep)?;
867 let first = data
868 .iter()
869 .position(|x| !x.is_nan())
870 .ok_or(RocError::AllValuesNaN)?;
871 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
872 if data.len() - first < max_p {
873 return Err(RocError::NotEnoughValidData {
874 needed: max_p,
875 valid: data.len() - first,
876 });
877 }
878
879 let rows = combos.len();
880 let cols = data.len();
881 let expected = rows.checked_mul(cols).ok_or(RocError::InvalidRange {
882 start: sweep.period.0,
883 end: sweep.period.1,
884 step: sweep.period.2,
885 })?;
886 if out.len() != expected {
887 return Err(RocError::OutputLengthMismatch {
888 expected,
889 got: out.len(),
890 });
891 }
892
893 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
894 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
895 };
896 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
897 init_matrix_prefixes(out_mu, cols, &warm);
898
899 let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
900 let period = combos[row].period.unwrap();
901 let dst: &mut [f64] =
902 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
903 match kern {
904 Kernel::Scalar => roc_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst),
905 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
906 Kernel::Avx2 => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, dst),
907 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
908 Kernel::Avx512 => roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, dst),
909 _ => unreachable!(),
910 }
911 };
912
913 if parallel {
914 #[cfg(not(target_arch = "wasm32"))]
915 {
916 out_mu
917 .par_chunks_mut(cols)
918 .enumerate()
919 .for_each(|(row, row_mu)| do_row(row, row_mu));
920 }
921 #[cfg(target_arch = "wasm32")]
922 {
923 for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
924 do_row(row, row_mu);
925 }
926 }
927 } else {
928 for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
929 do_row(row, row_mu);
930 }
931 }
932
933 Ok(combos)
934}
935
/// Python binding: ROC over a 1-D float64 array, returned as a new NumPy array.
///
/// The computation runs with the GIL released; errors become `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "roc")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn roc_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let values_in = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;

    let input = RocInput::from_slice(
        values_in,
        RocParams {
            period: Some(period),
        },
    );

    let computed = py
        .allow_threads(|| roc_with_kernel(&input, chosen))
        .map(|output| output.values)
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok(computed.into_pyarray(py))
}
961
/// Python binding: batch ROC over a period sweep.
///
/// Returns a dict with `"values"` (a `rows x cols` float64 array, one row per period) and
/// `"periods"` (the period of each row). The sweep is computed with the GIL released, writing
/// directly into the NumPy output buffer. Errors become `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "roc_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn roc_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = RocBatchRange {
        period: period_range,
    };

    // Expanded here only to size the output; `roc_batch_inner_into` expands it again and its
    // returned combos are the ones reported.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Uninitialised 1-D array; `roc_batch_inner_into` writes every element before it is read
    // (warm-up prefixes plus row kernels). It is reshaped to (rows, cols) below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Batch kernel -> per-row kernel. NOTE(review): assumes `validate_kernel(.., true)`
            // only yields `Auto` or `*Batch` kernels; anything else panics here.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };

            roc_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1022
/// Python wrapper around [`RocStream`] for incremental ROC updates.
#[cfg(feature = "python")]
#[pyclass(name = "RocStream")]
pub struct RocStreamPy {
    inner: RocStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RocStreamPy {
    /// Creates a stream with the given period; raises `ValueError` if it is 0.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        RocStream::try_new(RocParams {
            period: Some(period),
        })
        .map(|inner| RocStreamPy { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value; returns the ROC value once enough history exists, else `None`.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1045
/// Python handle to a 2-D `f32` ROC result living on a CUDA device.
///
/// Exposes the buffer through `__cuda_array_interface__` and, once, through `__dlpack__`.
/// After a DLPack export the inner buffer is moved out and further access raises `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct RocDeviceArrayF32Py {
    /// Device buffer; `None` after it has been handed off via `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32Roc>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl RocDeviceArrayF32Py {
    /// CUDA Array Interface (v3) description of the row-major `(rows, cols)` float32 buffer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        // Byte stride between consecutive rows (contiguous row-major layout).
        let row_stride = inner
            .cols
            .checked_mul(itemsize)
            .ok_or_else(|| PyValueError::new_err("byte stride overflow"))?;
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (row_stride, itemsize))?;
        // (device pointer, read-only flag): `false` marks the buffer as writable.
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `2` is `kDLCUDA` in DLPack's `DLDeviceType`, followed by the
    /// device ordinal.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        Ok((2, inner.device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership; callable only once.
    ///
    /// A `dl_device` that differs from the allocation device is rejected (cross-device copy is
    /// not implemented). `stream` is accepted but ignored. `copy` is only consulted when the
    /// device differs. NOTE(review): `copy=True` on the matching device returns the buffer
    /// without copying — confirm this is acceptable under the DLPack protocol.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        // A `dl_device` that is not an `(int, int)` tuple is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Move the buffer out so the capsule owns it; later calls see `None`.
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1127
1128#[inline]
1129pub fn roc_into_slice(dst: &mut [f64], input: &RocInput, kern: Kernel) -> Result<(), RocError> {
1130 let (data, period, first, chosen) = roc_prepare(input, kern)?;
1131 if dst.len() != data.len() {
1132 return Err(RocError::OutputLengthMismatch {
1133 expected: data.len(),
1134 got: dst.len(),
1135 });
1136 }
1137 let warmup_end = first + period;
1138 for v in &mut dst[..warmup_end] {
1139 *v = f64::NAN;
1140 }
1141 roc_compute_into(data, period, first, chosen, dst);
1142 Ok(())
1143}
1144
/// WASM binding: ROC over `data`, returned as a new array; errors become JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RocInput::from_slice(
        data,
        RocParams {
            period: Some(period),
        },
    );
    let mut result = vec![0.0; data.len()];
    match roc_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(()) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1160
/// WASM helper: allocates uninitialised room for `len` f64 values and leaks it to JS.
///
/// Release the memory with `roc_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_alloc(len: usize) -> *mut f64 {
    let mut storage: Vec<f64> = Vec::with_capacity(len);
    let raw = storage.as_mut_ptr();
    std::mem::forget(storage);
    raw
}
1169
/// WASM binding: releases a buffer previously returned by `roc_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must be the same value that was passed to
/// `roc_alloc`, because it is used as the allocation's capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `roc_alloc`, so
    // `len` is its capacity. The Vec is rebuilt with length 0 because the buffer may
    // never have been written, and `from_raw_parts` requires the first `length`
    // elements to be initialised. `f64` has no destructor, so dropping a
    // zero-length Vec only deallocates, which is exactly what is wanted here.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1179
/// WASM binding: computes ROC from a raw input buffer into a raw output buffer.
///
/// Both pointers must address `len` f64 values in linear memory (for example
/// buffers from `roc_alloc`). When `in_ptr == out_ptr`, the result is first computed
/// into a scratch vector and then copied back, so exact in-place use is supported.
///
/// # Errors
/// Returns a JS string error for a null pointer or for any ROC validation failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: `in_ptr` is non-null, and the caller guarantees it addresses `len`
    // initialised f64 values that remain valid for the duration of this call.
    let src = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = RocInput::from_slice(
        src,
        RocParams {
            period: Some(period),
        },
    );

    let same_buffer = in_ptr == out_ptr;
    if same_buffer {
        // Input and output share storage: compute into scratch space so the
        // source is never read after it has been overwritten.
        let mut scratch = vec![0.0; len];
        roc_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and addresses `len` writable f64 values.
        // `src`/`input` are not used past this point, so this is the only live view.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dst.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null and addresses `len` writable f64 values. It is
        // assumed not to overlap the input buffer; partial overlap is not detected.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        roc_into_slice(dst, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1213
1214#[cfg(all(feature = "python", feature = "cuda"))]
1215use crate::cuda::cuda_available;
1216#[cfg(all(feature = "python", feature = "cuda"))]
1217use crate::cuda::oscillators::roc_wrapper::{CudaRoc, DeviceArrayF32Roc};
1218
/// Python binding: runs a ROC parameter sweep on a CUDA device.
///
/// `period_range` is forwarded as `RocBatchRange::period`. The GPU work runs with the
/// GIL released (`py.allow_threads`), and the resulting device buffer is wrapped in a
/// `RocDeviceArrayF32Py`, which can be exported via `__dlpack__`.
///
/// Raises `ValueError` when CUDA is unavailable or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "roc_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn roc_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<RocDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let series = data.as_slice()?;
    let range = RocBatchRange {
        period: period_range,
    };

    let device_out = py.allow_threads(|| {
        let gpu = CudaRoc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let computed = gpu.roc_batch_dev(series, &range);
        computed.map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(RocDeviceArrayF32Py {
        inner: Some(device_out),
    })
}
1242
/// Python binding: computes ROC with a single `period` across many series on a CUDA device.
///
/// `data_tm` is a flat time-major buffer whose length must equal `cols * rows`. The GPU
/// work runs with the GIL released, and the result is wrapped in a `RocDeviceArrayF32Py`.
///
/// Raises `ValueError` if CUDA is unavailable, `cols * rows` overflows, the buffer
/// length does not match, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "roc_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn roc_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<RocDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm.as_slice()?;
    let needed = match cols.checked_mul(rows) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if flat.len() != needed {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let device_out = py.allow_threads(|| {
        let gpu = CudaRoc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let computed = gpu.roc_many_series_one_param_time_major_dev(flat, cols, rows, period);
        computed.map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(RocDeviceArrayF32Py {
        inner: Some(device_out),
    })
}
1271
/// Configuration object accepted by the WASM `roc_batch` binding, deserialized from a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RocBatchConfig {
    /// Period sweep forwarded as `RocBatchRange::period`. It is presumably
    /// `(start, end, step)`, matching `RocBatchBuilder::period_range`; TODO confirm.
    pub period_range: (usize, usize, usize),
}
1277
/// Serializable result returned to JS by the WASM `roc_batch` binding.
///
/// The fields are copied verbatim from the native batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RocBatchJsOutput {
    /// Flattened `rows x cols` matrix, row-major: one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row of `values`.
    pub combos: Vec<RocParams>,
    /// Number of rows, one per parameter combination.
    pub rows: usize,
    /// Number of columns, equal to the input series length.
    pub cols: usize,
}
1286
/// WASM binding: runs a ROC period sweep over `data`.
///
/// `config` must deserialize into `RocBatchConfig`. The result is returned as a
/// serialized `RocBatchJsOutput` (flat values, combos, rows, cols).
///
/// # Errors
/// Returns a JS string error for an invalid config, a ROC failure, or a
/// serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_batch(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: RocBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = RocBatchRange {
        period: parsed.period_range,
    };
    let batch = roc_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = RocBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    let encoded = serde_wasm_bindgen::to_value(&payload);
    encoded.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1310
#[cfg(test)]
mod tests {
    //! Unit, parity and property tests for the ROC indicator.
    //!
    //! `generate_all_roc_tests!` and `gen_batch_tests!` turn each `check_*` function
    //! into one `#[test]` per kernel. A check that returns `Err` fails its test
    //! (the result is unwrapped), so `?`-propagated errors and `prop_assert!`
    //! failures are not silently discarded.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;

    #[test]
    fn test_roc_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let params = RocParams { period: Some(10) };
        let input = RocInput::from_candles(&candles, "close", params);

        let baseline = roc(&input)?.values;

        let mut out = vec![0.0; candles.close.len()];
        roc_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
            let equal = (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12);
            assert!(
                equal,
                "roc_into parity mismatch at idx {}: api={} into={}",
                i, a, b
            );
        }
        Ok(())
    }

    fn check_roc_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = RocParams { period: None };
        let input_default = RocInput::from_candles(&candles, "close", default_params);
        let output_default = roc_with_kernel(&input_default, kernel)?;
        assert_eq!(output_default.values.len(), candles.close.len());

        let params_period_14 = RocParams { period: Some(14) };
        let input_period_14 = RocInput::from_candles(&candles, "hl2", params_period_14);
        let output_period_14 = roc_with_kernel(&input_period_14, kernel)?;
        assert_eq!(output_period_14.values.len(), candles.close.len());

        let params_custom = RocParams { period: Some(20) };
        let input_custom = RocInput::from_candles(&candles, "hlc3", params_custom);
        let output_custom = roc_with_kernel(&input_custom, kernel)?;
        assert_eq!(output_custom.values.len(), candles.close.len());

        Ok(())
    }

    fn check_roc_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let close_prices = &candles.close;
        let params = RocParams { period: Some(10) };
        let input = RocInput::from_candles(&candles, "close", params);
        let roc_result = roc_with_kernel(&input, kernel)?;

        assert_eq!(roc_result.values.len(), close_prices.len());

        let expected_last_five_roc = [
            -0.22551709049294377,
            -0.5561903481650754,
            -0.32752013235864963,
            -0.49454153980722504,
            -1.5045927020536976,
        ];
        assert!(roc_result.values.len() >= 5);
        let start_index = roc_result.values.len() - 5;
        let result_last_five_roc = &roc_result.values[start_index..];
        for (i, &value) in result_last_five_roc.iter().enumerate() {
            let expected_value = expected_last_five_roc[i];
            assert!(
                (value - expected_value).abs() < 1e-7,
                "[{}] ROC mismatch at index {}: expected {}, got {}",
                test_name,
                i,
                expected_value,
                value
            );
        }
        // ROC at index i needs the sample at i - period, so at least the first
        // `period` outputs are warm-up NaNs (the warm-up is `first + period` entries).
        let period = input.get_period();
        for i in 0..period {
            assert!(
                roc_result.values[i].is_nan(),
                "[{}] expected NaN during warmup at index {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    fn check_roc_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocInput::with_default_candles(&candles);
        match input.data {
            RocData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected RocData::Candles"),
        }
        let output = roc_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_roc_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = RocParams { period: Some(0) };
        let input = RocInput::from_slice(&input_data, params);
        let res = roc_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_roc_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = RocParams { period: Some(10) };
        let input = RocInput::from_slice(&data_small, params);
        let res = roc_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_roc_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = RocParams { period: Some(9) };
        let input = RocInput::from_slice(&single_point, params);
        let res = roc_with_kernel(&input, kernel);
        assert!(res.is_err());
        Ok(())
    }

    fn check_roc_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let first_params = RocParams { period: Some(14) };
        let first_input = RocInput::from_candles(&candles, "close", first_params);
        let first_result = roc_with_kernel(&first_input, kernel)?;

        let second_params = RocParams { period: Some(14) };
        let second_input = RocInput::from_slice(&first_result.values, second_params);
        let second_result = roc_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] Expected no NaN after index 28, found NaN at {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    fn check_roc_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = RocInput::from_candles(&candles, "close", RocParams { period: Some(9) });
        let res = roc_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    fn check_roc_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 9;

        let input = RocInput::from_candles(
            &candles,
            "close",
            RocParams {
                period: Some(period),
            },
        );
        let batch_output = roc_with_kernel(&input, kernel)?.values;

        let mut stream = RocStream::try_new(RocParams {
            period: Some(period),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(val) => stream_values.push(val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] ROC streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_roc_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            RocParams::default(),
            RocParams { period: Some(2) },
            RocParams { period: Some(5) },
            RocParams { period: Some(7) },
            RocParams { period: Some(9) },
            RocParams { period: Some(14) },
            RocParams { period: Some(20) },
            RocParams { period: Some(30) },
            RocParams { period: Some(50) },
            RocParams { period: Some(100) },
            RocParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = RocInput::from_candles(&candles, "close", params.clone());
            let output = roc_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_roc_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_roc_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=100).prop_flat_map(|period| {
            prop_oneof![
                (
                    prop::collection::vec(
                        (1f64..1e6f64)
                            .prop_filter("finite positive", |x| x.is_finite() && *x > 0.0),
                        period..400,
                    ),
                    Just(period),
                ),
                (
                    prop::collection::vec(
                        prop_oneof![
                            (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
                            Just(100.0),
                        ],
                        period..400,
                    ),
                    Just(period),
                ),
                (
                    (100f64..10000f64, 0.01f64..0.1f64).prop_map(move |(start, step)| {
                        let len = period + (400 - period) / 2;
                        (0..len)
                            .map(|i| start + (i as f64) * step)
                            .collect::<Vec<_>>()
                    }),
                    Just(period),
                ),
                (
                    (10000f64..100000f64, 0.01f64..0.1f64).prop_map(move |(start, step)| {
                        let len = period + (400 - period) / 2;
                        (0..len)
                            .map(|i| start - (i as f64) * step)
                            .collect::<Vec<_>>()
                    }),
                    Just(period),
                ),
            ]
        });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
            let params = RocParams {
                period: Some(period),
            };
            let input = RocInput::from_slice(&data, params);

            let RocOutput { values: out } = roc_with_kernel(&input, kernel)?;
            let RocOutput { values: ref_out } = roc_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

            for i in 0..period {
                prop_assert!(
                    out[i].is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    out[i]
                );
            }

            for i in period..data.len() {
                let current = data[i];
                let previous = data[i - period];
                let roc_val = out[i];

                let expected_roc = if previous == 0.0 || previous.is_nan() {
                    0.0
                } else {
                    ((current / previous) - 1.0) * 100.0
                };

                if !roc_val.is_nan() {
                    // Relative tolerance: with prices spanning 1..1e6 the ROC can reach
                    // ~1e8, where an absolute 1e-9 bound is below f64 resolution.
                    let tol = 1e-9 * expected_roc.abs().max(1.0);
                    prop_assert!(
                        (roc_val - expected_roc).abs() <= tol,
                        "ROC calculation mismatch at {}: got {}, expected {} (current={}, previous={})",
                        i, roc_val, expected_roc, current, previous
                    );

                    if current > previous && previous > 0.0 {
                        prop_assert!(
                            roc_val > -1e-9,
                            "ROC should be positive when current > previous at {}: roc={}, current={}, previous={}",
                            i, roc_val, current, previous
                        );
                    }
                    if current < previous && previous > 0.0 {
                        prop_assert!(
                            roc_val < 1e-9,
                            "ROC should be negative when current < previous at {}: roc={}, current={}, previous={}",
                            i, roc_val, current, previous
                        );
                    }
                    if (current - previous).abs() < 1e-12 && previous > 0.0 {
                        prop_assert!(
                            roc_val.abs() < 1e-9,
                            "ROC should be ~0 when current ≈ previous at {}: roc={}, current={}, previous={}",
                            i, roc_val, current, previous
                        );
                    }
                }
            }

            let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);
            if is_constant && data.len() > period {
                for i in period..data.len() {
                    if !out[i].is_nan() {
                        prop_assert!(
                            out[i].abs() < 1e-9,
                            "ROC of constant data should be 0 at {}: got {}",
                            i,
                            out[i]
                        );
                    }
                }
            }

            let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0]);
            let is_monotonic_decreasing = data.windows(2).all(|w| w[1] <= w[0]);

            if is_monotonic_increasing && !is_constant {
                for i in period..data.len() {
                    if !out[i].is_nan() && data[i] > data[i - period] {
                        prop_assert!(
                            out[i] > -1e-9,
                            "ROC should be positive for increasing data at {}: got {}",
                            i,
                            out[i]
                        );
                    }
                }
            }

            if is_monotonic_decreasing && !is_constant {
                for i in period..data.len() {
                    if !out[i].is_nan() && data[i] < data[i - period] {
                        prop_assert!(
                            out[i] < 1e-9,
                            "ROC should be negative for decreasing data at {}: got {}",
                            i,
                            out[i]
                        );
                    }
                }
            }

            prop_assert_eq!(out.len(), ref_out.len(), "Kernel output length mismatch");

            for i in 0..out.len() {
                let y = out[i];
                let r = ref_out[i];

                if !y.is_finite() || !r.is_finite() {
                    prop_assert!(
                        y.to_bits() == r.to_bits(),
                        "NaN/Inf mismatch at {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                } else {
                    // Relative bound for the same reason as the formula check above.
                    let tolerance = 1e-9 * r.abs().max(1.0);
                    prop_assert!(
                        (y - r).abs() <= tolerance,
                        "Kernel mismatch at {}: {} vs {}, diff={}",
                        i,
                        y,
                        r,
                        (y - r).abs()
                    );
                }
            }

            #[cfg(debug_assertions)]
            {
                for (i, &val) in out.iter().enumerate() {
                    if !val.is_nan() {
                        let bits = val.to_bits();
                        prop_assert_ne!(
                            bits,
                            0x11111111_11111111,
                            "Found alloc_with_nan_prefix poison at {}",
                            i
                        );
                        prop_assert_ne!(
                            bits,
                            0x22222222_22222222,
                            "Found init_matrix_prefixes poison at {}",
                            i
                        );
                        prop_assert_ne!(
                            bits,
                            0x33333333_33333333,
                            "Found make_uninit_matrix poison at {}",
                            i
                        );
                    }
                }
            }

            Ok(())
        })?;

        Ok(())
    }

    // Each generated test unwraps the check's Result so that returned errors
    // (including proptest failures, which are returned rather than panicked) fail
    // the test. The AVX cfg is repeated on every generated function: an attribute
    // placed outside a `$(...)*` repetition would only apply to its first item.
    macro_rules! generate_all_roc_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                            .unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                            .unwrap();
                    }
                )*
            }
        }
    }
    generate_all_roc_tests!(
        check_roc_partial_params,
        check_roc_accuracy,
        check_roc_default_candles,
        check_roc_zero_period,
        check_roc_period_exceeds_length,
        check_roc_very_small_dataset,
        check_roc_reinput,
        check_roc_nan_handling,
        check_roc_streaming,
        check_roc_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_roc_tests!(check_roc_property);

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = RocBatchBuilder::new()
            .period_static(10)
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let test_params = RocParams { period: Some(10) };
        let row = output
            .values_for(&test_params)
            .expect("period=10 row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [
            -0.22551709049294377,
            -0.5561903481650754,
            -0.32752013235864963,
            -0.49454153980722504,
            -1.5045927020536976,
        ];
        assert!(row.len() >= expected.len(), "[{test}] row too short");
        let start = row.len() - expected.len();
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-7,
                "[{test}] period=10 row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![
            (2, 10, 2),
            (5, 25, 5),
            (30, 60, 15),
            (2, 5, 1),
            (9, 9, 0),
            (14, 21, 7),
            (10, 50, 10),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = RocBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(9)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(9)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(9)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // As above: unwrap so that an `Err` returned by a batch check fails the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test]
                fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test]
                fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}