Skip to main content

vector_ta/indicators/
roc.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
18#[cfg(feature = "python")]
19use crate::utilities::kernel_validation::validate_kernel;
20#[cfg(feature = "python")]
21use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
22#[cfg(feature = "python")]
23use pyo3::exceptions::PyValueError;
24#[cfg(feature = "python")]
25use pyo3::prelude::*;
26#[cfg(feature = "python")]
27use pyo3::types::PyDict;
28
29#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
30use serde::{Deserialize, Serialize};
31#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
32use wasm_bindgen::prelude::*;
33
/// Where the ROC (rate of change) indicator reads its series from.
#[derive(Debug, Clone)]
pub enum RocData<'a> {
    /// The `source` column of `candles`, resolved through `source_type`
    /// (the default constructors use `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series of values.
    Slice(&'a [f64]),
}
42
/// Output of a single-series ROC computation.
#[derive(Debug, Clone)]
pub struct RocOutput {
    /// One entry per input element. The leading `first_valid + period` entries are the
    /// warm-up prefix produced by `alloc_with_nan_prefix` (presumably NaN). The remaining
    /// entries are the percentage change versus the value `period` samples earlier.
    pub values: Vec<f64>,
}
47
/// Tunable parameters for ROC.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RocParams {
    /// Look-back distance in samples. `None` means "use the default of 9".
    pub period: Option<usize>,
}
56
/// Default parameters: an explicit period of 9.
impl Default for RocParams {
    fn default() -> Self {
        Self { period: Some(9) }
    }
}
62
/// A ROC request: the data source plus the parameters to apply to it.
#[derive(Debug, Clone)]
pub struct RocInput<'a> {
    /// Series to read (a candle column or a raw slice).
    pub data: RocData<'a>,
    /// Indicator parameters.
    pub params: RocParams,
}
68
69impl<'a> AsRef<[f64]> for RocInput<'a> {
70    #[inline(always)]
71    fn as_ref(&self) -> &[f64] {
72        match &self.data {
73            RocData::Slice(slice) => slice,
74            RocData::Candles { candles, source } => source_type(candles, source),
75        }
76    }
77}
78
79impl<'a> RocInput<'a> {
80    #[inline]
81    pub fn from_candles(candles: &'a Candles, source: &'a str, params: RocParams) -> Self {
82        Self {
83            data: RocData::Candles { candles, source },
84            params,
85        }
86    }
87    #[inline]
88    pub fn from_slice(slice: &'a [f64], params: RocParams) -> Self {
89        Self {
90            data: RocData::Slice(slice),
91            params,
92        }
93    }
94    #[inline]
95    pub fn with_default_candles(candles: &'a Candles) -> Self {
96        Self::from_candles(candles, "close", RocParams::default())
97    }
98    #[inline]
99    pub fn get_period(&self) -> usize {
100        self.params.period.unwrap_or(9)
101    }
102}
103
/// Fluent builder for one-shot ROC runs and streaming instances.
#[derive(Copy, Clone, Debug)]
pub struct RocBuilder {
    /// Period override; `None` defers to the default of 9 applied downstream.
    period: Option<usize>,
    /// Kernel requested for `apply` / `apply_slice` (streams ignore it).
    kernel: Kernel,
}
109
110impl Default for RocBuilder {
111    fn default() -> Self {
112        Self {
113            period: None,
114            kernel: Kernel::Auto,
115        }
116    }
117}
118
119impl RocBuilder {
120    #[inline(always)]
121    pub fn new() -> Self {
122        Self::default()
123    }
124    #[inline(always)]
125    pub fn period(mut self, n: usize) -> Self {
126        self.period = Some(n);
127        self
128    }
129    #[inline(always)]
130    pub fn kernel(mut self, k: Kernel) -> Self {
131        self.kernel = k;
132        self
133    }
134    #[inline(always)]
135    pub fn apply(self, c: &Candles) -> Result<RocOutput, RocError> {
136        let p = RocParams {
137            period: self.period,
138        };
139        let i = RocInput::from_candles(c, "close", p);
140        roc_with_kernel(&i, self.kernel)
141    }
142    #[inline(always)]
143    pub fn apply_slice(self, d: &[f64]) -> Result<RocOutput, RocError> {
144        let p = RocParams {
145            period: self.period,
146        };
147        let i = RocInput::from_slice(d, p);
148        roc_with_kernel(&i, self.kernel)
149    }
150    #[inline(always)]
151    pub fn into_stream(self) -> Result<RocStream, RocError> {
152        let p = RocParams {
153            period: self.period,
154        };
155        RocStream::try_new(p)
156    }
157}
158
/// Errors returned by the ROC indicator and its batch / streaming variants.
#[derive(Debug, Error)]
pub enum RocError {
    /// The input series has no elements (returned by `roc_prepare`).
    #[error("roc: Input data slice is empty.")]
    EmptyInputData,
    /// Same message as `EmptyInputData`; not constructed anywhere in this section.
    /// NOTE(review): presumably kept for API compatibility — confirm before removing.
    #[error("roc: Input data slice is empty.")]
    EmptyData,
    /// Every element of the input is NaN.
    #[error("roc: All values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or larger than the data length.
    /// `RocStream::try_new` reports `data_len: 0` because it has no data yet.
    #[error("roc: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain from the first non-NaN value onward.
    #[error("roc: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("roc: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A period sweep could not be expanded, or its `rows * cols` size overflowed.
    #[error("roc: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("roc: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
182
/// Computes ROC for `input` with `Kernel::Auto`.
///
/// # Errors
/// See [`roc_with_kernel`].
#[inline]
pub fn roc(input: &RocInput) -> Result<RocOutput, RocError> {
    roc_with_kernel(input, Kernel::Auto)
}
187
188#[inline(always)]
189fn roc_prepare<'a>(
190    input: &'a RocInput,
191    kernel: Kernel,
192) -> Result<(&'a [f64], usize, usize, Kernel), RocError> {
193    let data: &[f64] = input.as_ref();
194    let len = data.len();
195    if len == 0 {
196        return Err(RocError::EmptyInputData);
197    }
198    let first = data
199        .iter()
200        .position(|x| !x.is_nan())
201        .ok_or(RocError::AllValuesNaN)?;
202    let period = input.get_period();
203    if period == 0 || period > len {
204        return Err(RocError::InvalidPeriod {
205            period,
206            data_len: len,
207        });
208    }
209    if len - first < period {
210        return Err(RocError::NotEnoughValidData {
211            needed: period,
212            valid: len - first,
213        });
214    }
215
216    let chosen = match kernel {
217        Kernel::Auto => Kernel::Scalar,
218        k => k,
219    };
220    Ok((data, period, first, chosen))
221}
222
223#[inline(always)]
224fn roc_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
225    unsafe {
226        match kernel {
227            Kernel::Scalar => roc_scalar(data, period, first, out),
228            Kernel::ScalarBatch => roc_scalar(data, period, first, out),
229            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
230            Kernel::Avx2 => roc_avx2(data, period, first, out),
231            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
232            Kernel::Avx2Batch => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, out),
233            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
234            Kernel::Avx512 => roc_avx512(data, period, first, out),
235            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
236            Kernel::Avx512Batch => {
237                roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, out)
238            }
239            _ => unreachable!(),
240        }
241    }
242}
243
244pub fn roc_with_kernel(input: &RocInput, kernel: Kernel) -> Result<RocOutput, RocError> {
245    let (data, period, first, chosen) = roc_prepare(input, kernel)?;
246
247    let mut out = alloc_with_nan_prefix(data.len(), first + period);
248    roc_compute_into(data, period, first, chosen, &mut out);
249    Ok(RocOutput { values: out })
250}
251
/// Computes ROC for `input` into the caller-provided `out` buffer with `Kernel::Auto`.
///
/// Delegates to `roc_into_slice`, which is not shown in this section. `out` presumably
/// must match the input length; confirm against that function.
///
/// Not compiled for wasm32 builds with the `wasm` feature.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn roc_into(input: &RocInput, out: &mut [f64]) -> Result<(), RocError> {
    roc_into_slice(out, input, Kernel::Auto)
}
257
258#[inline(always)]
259pub unsafe fn roc_indicator(input: &RocInput) -> Result<RocOutput, RocError> {
260    roc_with_kernel(input, Kernel::Auto)
261}
262#[inline(always)]
263pub unsafe fn roc_indicator_with_kernel(
264    input: &RocInput,
265    k: Kernel,
266) -> Result<RocOutput, RocError> {
267    roc_with_kernel(input, k)
268}
269#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
270#[inline(always)]
271pub unsafe fn roc_indicator_avx512(input: &RocInput) -> Result<RocOutput, RocError> {
272    roc_with_kernel(input, Kernel::Avx512)
273}
274#[inline(always)]
275pub unsafe fn roc_indicator_scalar(input: &RocInput) -> Result<RocOutput, RocError> {
276    roc_with_kernel(input, Kernel::Scalar)
277}
278#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
279#[inline(always)]
280pub unsafe fn roc_indicator_avx2(input: &RocInput) -> Result<RocOutput, RocError> {
281    roc_with_kernel(input, Kernel::Avx2)
282}
283#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284#[inline(always)]
285pub unsafe fn roc_indicator_avx512_short(input: &RocInput) -> Result<RocOutput, RocError> {
286    roc_with_kernel(input, Kernel::Avx512)
287}
288#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289#[inline(always)]
290pub unsafe fn roc_indicator_avx512_long(input: &RocInput) -> Result<RocOutput, RocError> {
291    roc_with_kernel(input, Kernel::Avx512)
292}
293
/// Scalar ROC kernel.
///
/// For every `i` in `first + period..data.len()` it writes
/// `out[i] = (data[i] / data[i - period]).mul_add(100.0, -100.0)`, that is the percentage
/// change versus `period` samples earlier, computed with a fused multiply-add. When
/// the earlier value is `0.0` or NaN the output is `0.0` instead of ±inf/NaN.
///
/// Entries before `first + period` are left untouched (the caller owns the warm-up
/// prefix). When `first + period >= data.len()` there is nothing to compute and the
/// function returns without touching `out`. This is the same early exit the SIMD and
/// row kernels take, so out-of-range arguments no longer panic on a slice underflow.
///
/// # Panics
/// Panics if there is work to do and `out.len() < first + period`.
#[inline(always)]
pub fn roc_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = match first.checked_add(period) {
        Some(s) if s < len => s,
        _ => return,
    };

    // `start < len` implies `period < len`, so `len - period` cannot underflow.
    let dst = &mut out[start..];
    let curr = &data[start..];
    let prev = &data[first..len - period];
    for ((d, &c), &p) in dst.iter_mut().zip(curr).zip(prev) {
        *d = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
    }
}
310
/// AVX2/FMA ROC kernel.
///
/// For every `i` in `first + period..data.len()` it writes
/// `out[i] = (data[i] / data[i - period]) * 100 - 100` as one fused multiply-add. It
/// processes four lanes at a time, with a scalar loop for the remaining 0–3 elements.
///
/// Where the earlier value is `0.0` or NaN the output is `0.0`. In the vector body the
/// division still runs for those lanes and its result is discarded by a blend. Entries
/// before `first + period` are not written, and the function returns immediately when
/// `first + period >= data.len()`.
///
/// # Safety
/// - The CPU must support AVX2 and FMA (enabled via `#[target_feature]`).
/// - `out.len()` must be at least `data.len()`: stores go through raw pointers without
///   bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn roc_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = first + period;
    // Nothing to compute when the warm-up covers the whole series.
    if start >= len {
        return;
    }

    // `n` outputs; `base_prev[i]` is the value `period` samples before `base_curr[i]`.
    let n = len - start;
    let base_curr = data.as_ptr().add(start);
    let base_prev = data.as_ptr().add(first);
    let base_out = out.as_mut_ptr().add(start);

    let v_zero = _mm256_set1_pd(0.0);
    let v_m100 = _mm256_set1_pd(-100.0);
    let v_100 = _mm256_set1_pd(100.0);

    let mut i = 0usize;
    // Vector body: four doubles per iteration, unaligned loads/stores.
    while i + 4 <= n {
        let c = _mm256_loadu_pd(base_curr.add(i));
        let p = _mm256_loadu_pd(base_prev.add(i));

        // Ordered equality with 0.0 (also true for -0.0, like the scalar `p == 0.0`).
        let mask_zero = _mm256_cmp_pd(p, v_zero, _CMP_EQ_OQ);
        // An unordered self-comparison is true exactly for NaN lanes.
        let mask_nan = _mm256_cmp_pd(p, p, _CMP_UNORD_Q);
        let mask_invalid = _mm256_or_pd(mask_zero, mask_nan);

        // (c / p) * 100 + (-100) as a single fused multiply-add.
        let div = _mm256_div_pd(c, p);
        let res = _mm256_fmadd_pd(div, v_100, v_m100);

        // Take 0.0 in lanes where the mask is set, the computed value elsewhere.
        let blended = _mm256_blendv_pd(res, v_zero, mask_invalid);
        _mm256_storeu_pd(base_out.add(i), blended);
        i += 4;
    }

    // Scalar tail for the last 0–3 elements, using the same formula and guard.
    while i < n {
        let p = *base_prev.add(i);
        let c = *base_curr.add(i);
        *base_out.add(i) = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
        i += 1;
    }
}
/// AVX-512 ROC kernel.
///
/// For every `i` in `first + period..data.len()` it writes
/// `out[i] = (data[i] / data[i - period]).mul_add(100.0, -100.0)`. It processes eight
/// lanes at a time, with a scalar loop for the remaining 0–7 elements.
///
/// Where the earlier value is `0.0` or NaN the output is `0.0`. In the vector body the
/// division still runs for those lanes and its result is replaced through a mask.
/// Entries before `first + period` are not written, and the function returns
/// immediately when `first + period >= data.len()`.
///
/// The scalar tail uses the same fused multiply-add as the vector body and the other
/// kernels. Every element is therefore rounded identically regardless of its position
/// or the kernel chosen.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA (enabled via `#[target_feature]`).
/// - `out.len()` must be at least `data.len()`: stores go through raw pointers without
///   bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
pub unsafe fn roc_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let start = first + period;
    // Nothing to compute when the warm-up covers the whole series.
    if start >= len {
        return;
    }

    // `n` outputs; `base_prev[i]` is the value `period` samples before `base_curr[i]`.
    let n = len - start;
    let base_curr = data.as_ptr().add(start);
    let base_prev = data.as_ptr().add(first);
    let base_out = out.as_mut_ptr().add(start);

    let v_zero = _mm512_set1_pd(0.0);
    let v_m100 = _mm512_set1_pd(-100.0);
    let v_100 = _mm512_set1_pd(100.0);

    let mut i = 0usize;
    // Vector body: eight doubles per iteration, unaligned loads/stores.
    while i + 8 <= n {
        let c = _mm512_loadu_pd(base_curr.add(i));
        let p = _mm512_loadu_pd(base_prev.add(i));

        // Lanes whose previous value is 0.0 (ordered compare) or NaN (unordered self-compare).
        let k_zero = _mm512_cmp_pd_mask(p, v_zero, _CMP_EQ_OQ);
        let k_nan = _mm512_cmp_pd_mask(p, p, _CMP_UNORD_Q);
        let k_invalid = k_zero | k_nan;

        // (c / p) * 100 + (-100) as a single fused multiply-add.
        let div = _mm512_div_pd(c, p);
        let res = _mm512_fmadd_pd(div, v_100, v_m100);

        // Replace invalid lanes with 0.0.
        let blended = _mm512_mask_mov_pd(res, k_invalid, v_zero);
        _mm512_storeu_pd(base_out.add(i), blended);
        i += 8;
    }

    // Scalar tail: same fused formula as the vector body so rounding is consistent.
    while i < n {
        let p = *base_prev.add(i);
        let c = *base_curr.add(i);
        *base_out.add(i) = if p == 0.0 || p.is_nan() {
            0.0
        } else {
            (c / p).mul_add(100.0, -100.0)
        };
        i += 1;
    }
}
/// Forwards to [`roc_scalar`]; despite the name, no AVX-512 instructions are used.
///
/// # Safety
/// No requirement beyond those of [`roc_scalar`] (which is safe).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn roc_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    roc_scalar(data, period, first, out)
}
/// Forwards to [`roc_scalar`]; despite the name, no AVX-512 instructions are used.
///
/// # Safety
/// No requirement beyond those of [`roc_scalar`] (which is safe).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn roc_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    roc_scalar(data, period, first, out)
}
415
/// Row kernel used by the batch paths.
///
/// For every `i` in `first + period..data.len()` it writes
/// `out[i] = (data[i] / data[i - period]).mul_add(100.0, -100.0)`, or `0.0` when the
/// earlier value is `0.0` or NaN. Entries before `first + period` are not written.
/// The loop is unrolled by four, with a remainder loop.
///
/// `_stride`, `_weights` and `_inv_n` are unused; they only keep the signature shared
/// with the other row kernels.
///
/// # Safety
/// `out.len()` must be at least `data.len()`: writes go through raw pointers without
/// bounds checks.
#[inline(always)]
pub unsafe fn roc_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    /// One ROC sample; a zero or NaN reference value maps to 0.0.
    #[inline(always)]
    fn roc_point(curr: f64, prev: f64) -> f64 {
        if prev == 0.0 || prev.is_nan() {
            0.0
        } else {
            (curr / prev).mul_add(100.0, -100.0)
        }
    }

    let total = data.len();
    let begin = first + period;
    if begin >= total {
        return;
    }
    let count = total - begin;

    // SAFETY: `first <= begin < total`, so both read pointers stay inside `data` for
    // `count` elements. The write pointer stays inside `out` by the caller's
    // `out.len() >= data.len()` contract.
    let src = data.as_ptr();
    let lagged = src.add(first);
    let current = src.add(begin);
    let target = out.as_mut_ptr().add(begin);

    let unrolled = count - count % 4;
    let mut j = 0usize;
    while j < unrolled {
        for lane in 0..4 {
            let k = j + lane;
            *target.add(k) = roc_point(*current.add(k), *lagged.add(k));
        }
        j += 4;
    }
    for k in unrolled..count {
        *target.add(k) = roc_point(*current.add(k), *lagged.add(k));
    }
}
/// Row kernel for the AVX2 batch path; forwards to [`roc_row_scalar`] (no AVX2 code).
///
/// # Safety
/// Same contract as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out)
}
/// Row kernel for the AVX-512 batch path; forwards to [`roc_row_scalar`] (no AVX-512 code).
///
/// # Safety
/// Same contract as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out)
}
/// "Short" AVX-512 row kernel; forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same contract as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out)
}
/// "Long" AVX-512 row kernel; forwards to [`roc_row_scalar`].
///
/// # Safety
/// Same contract as [`roc_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn roc_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    _weights: *const f64,
    _inv_n: f64,
    out: &mut [f64],
) {
    roc_row_scalar(data, first, period, _stride, _weights, _inv_n, out)
}
538
/// Streaming ROC calculator, fed one value at a time via `RocStream::update`.
#[derive(Debug, Clone)]
pub struct RocStream {
    /// Look-back distance (at least 1).
    period: usize,
    /// Ring buffer holding the last `period` accepted values.
    buffer: Vec<f64>,
    /// Next slot of `buffer` to read (oldest value) and then overwrite.
    head: usize,
    /// Set once the first non-NaN value has been seen; leading NaNs are skipped until then.
    filled: bool,
    /// Values stored during warm-up; stops increasing once it reaches `period`.
    count: usize,
}
547
548impl RocStream {
549    pub fn try_new(params: RocParams) -> Result<Self, RocError> {
550        let period = params.period.unwrap_or(9);
551        if period == 0 {
552            return Err(RocError::InvalidPeriod {
553                period,
554                data_len: 0,
555            });
556        }
557        Ok(Self {
558            period,
559            buffer: vec![f64::NAN; period],
560            head: 0,
561            filled: false,
562            count: 0,
563        })
564    }
565
566    #[inline(always)]
567    pub fn update(&mut self, value: f64) -> Option<f64> {
568        if !self.filled {
569            if value.is_nan() {
570                return None;
571            }
572
573            self.buffer[self.head] = value;
574
575            self.head += 1;
576            if self.head == self.period {
577                self.head = 0;
578            }
579            self.filled = true;
580            self.count = 1;
581            return None;
582        }
583
584        if self.count < self.period {
585            self.buffer[self.head] = value;
586            self.head += 1;
587            if self.head == self.period {
588                self.head = 0;
589            }
590            self.count += 1;
591            return None;
592        }
593
594        let old_value = self.buffer[self.head];
595        self.buffer[self.head] = value;
596
597        self.head += 1;
598        if self.head == self.period {
599            self.head = 0;
600        }
601
602        if old_value == 0.0 || old_value.is_nan() {
603            Some(0.0)
604        } else {
605            Some((value / old_value).mul_add(100.0, -100.0))
606        }
607    }
608}
609
/// Period sweep for batch ROC.
#[derive(Clone, Debug)]
pub struct RocBatchRange {
    /// `(start, end, step)`, expanded by `expand_grid`; `step == 0` or
    /// `start == end` means the single period `start`.
    pub period: (usize, usize, usize),
}
614
/// Default sweep: periods 9 through 258 inclusive with step 1 (250 rows).
impl Default for RocBatchRange {
    fn default() -> Self {
        Self {
            period: (9, 258, 1),
        }
    }
}
622
/// Builder for batch ROC runs over a period sweep.
#[derive(Clone, Debug, Default)]
pub struct RocBatchBuilder {
    /// Periods to evaluate; defaults to `RocBatchRange::default()`.
    range: RocBatchRange,
    /// Requested kernel; the derived default is `Kernel::default()`.
    /// NOTE(review): confirm that this default is `Kernel::Auto`.
    kernel: Kernel,
}
628
629impl RocBatchBuilder {
630    pub fn new() -> Self {
631        Self::default()
632    }
633    pub fn kernel(mut self, k: Kernel) -> Self {
634        self.kernel = k;
635        self
636    }
637    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
638        self.range.period = (start, end, step);
639        self
640    }
641    pub fn period_static(mut self, p: usize) -> Self {
642        self.range.period = (p, p, 0);
643        self
644    }
645    pub fn apply_slice(self, data: &[f64]) -> Result<RocBatchOutput, RocError> {
646        roc_batch_with_kernel(data, &self.range, self.kernel)
647    }
648    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<RocBatchOutput, RocError> {
649        RocBatchBuilder::new().kernel(k).apply_slice(data)
650    }
651    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<RocBatchOutput, RocError> {
652        let slice = source_type(c, src);
653        self.apply_slice(slice)
654    }
655    pub fn with_default_candles(c: &Candles) -> Result<RocBatchOutput, RocError> {
656        RocBatchBuilder::new()
657            .kernel(Kernel::Auto)
658            .apply_candles(c, "close")
659    }
660}
661
662pub fn roc_batch_with_kernel(
663    data: &[f64],
664    sweep: &RocBatchRange,
665    k: Kernel,
666) -> Result<RocBatchOutput, RocError> {
667    let kernel = match k {
668        Kernel::Auto => detect_best_batch_kernel(),
669        other if other.is_batch() => other,
670        other => return Err(RocError::InvalidKernelForBatch(other)),
671    };
672
673    let simd = match kernel {
674        Kernel::Avx512Batch => Kernel::Avx512,
675        Kernel::Avx2Batch => Kernel::Avx2,
676        Kernel::ScalarBatch => Kernel::Scalar,
677        _ => unreachable!(),
678    };
679    roc_batch_par_slice(data, sweep, simd)
680}
681
/// Result of a batch ROC run, stored row-major.
#[derive(Clone, Debug)]
pub struct RocBatchOutput {
    /// `rows * cols` values; row `r` holds the series computed with `combos[r]`.
    pub values: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<RocParams>,
    /// Number of parameter combinations (`combos.len()`).
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
689
690impl RocBatchOutput {
691    pub fn row_for_params(&self, p: &RocParams) -> Option<usize> {
692        self.combos
693            .iter()
694            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
695    }
696    pub fn values_for(&self, p: &RocParams) -> Option<&[f64]> {
697        self.row_for_params(p).and_then(|row| {
698            let start = row.checked_mul(self.cols)?;
699            let end = start.checked_add(self.cols)?;
700            self.values.get(start..end)
701        })
702    }
703}
704
705#[inline(always)]
706pub(crate) fn expand_grid(r: &RocBatchRange) -> Result<Vec<RocParams>, RocError> {
707    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, RocError> {
708        if step == 0 || start == end {
709            return Ok(vec![start]);
710        }
711        let mut out = Vec::new();
712        if start < end {
713            let mut v = start;
714            while v <= end {
715                out.push(v);
716                match v.checked_add(step) {
717                    Some(next) => {
718                        if next == v {
719                            break;
720                        }
721                        v = next;
722                    }
723                    None => break,
724                }
725            }
726        } else {
727            let mut v = start;
728            while v >= end {
729                out.push(v);
730                if v < end + step {
731                    break;
732                }
733                v -= step;
734                if v == 0 {
735                    break;
736                }
737            }
738        }
739        if out.is_empty() {
740            return Err(RocError::InvalidRange { start, end, step });
741        }
742        Ok(out)
743    }
744
745    let periods = axis_usize(r.period)?;
746    Ok(periods
747        .into_iter()
748        .map(|p| RocParams { period: Some(p) })
749        .collect())
750}
751
/// Sequential batch ROC: one row per period in `sweep`, using row kernel `kern`.
///
/// `kern` is passed straight through to `roc_batch_inner` (no `Auto` resolution here).
///
/// # Errors
/// Any error from `roc_batch_inner`.
#[inline(always)]
pub fn roc_batch_slice(
    data: &[f64],
    sweep: &RocBatchRange,
    kern: Kernel,
) -> Result<RocBatchOutput, RocError> {
    roc_batch_inner(data, sweep, kern, false)
}
/// Parallel batch ROC (rayon rows on non-wasm32 targets, sequential on wasm32);
/// otherwise identical to [`roc_batch_slice`].
#[inline(always)]
pub fn roc_batch_par_slice(
    data: &[f64],
    sweep: &RocBatchRange,
    kern: Kernel,
) -> Result<RocBatchOutput, RocError> {
    roc_batch_inner(data, sweep, kern, true)
}
768
769#[inline(always)]
770fn roc_batch_inner(
771    data: &[f64],
772    sweep: &RocBatchRange,
773    kern: Kernel,
774    parallel: bool,
775) -> Result<RocBatchOutput, RocError> {
776    let combos = expand_grid(sweep)?;
777    let first = data
778        .iter()
779        .position(|x| !x.is_nan())
780        .ok_or(RocError::AllValuesNaN)?;
781    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
782    if data.len() - first < max_p {
783        return Err(RocError::NotEnoughValidData {
784            needed: max_p,
785            valid: data.len() - first,
786        });
787    }
788    let rows = combos.len();
789    let cols = data.len();
790    let _total = rows.checked_mul(cols).ok_or(RocError::InvalidRange {
791        start: sweep.period.0,
792        end: sweep.period.1,
793        step: sweep.period.2,
794    })?;
795
796    let mut buf_mu = make_uninit_matrix(rows, cols);
797
798    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
799    init_matrix_prefixes(&mut buf_mu, cols, &warm);
800
801    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
802    let out: &mut [f64] = unsafe {
803        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
804    };
805
806    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
807        let period = combos[row].period.unwrap();
808        match kern {
809            Kernel::Scalar => {
810                roc_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, out_row)
811            }
812            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
813            Kernel::Avx2 => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, out_row),
814            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
815            Kernel::Avx512 => {
816                roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, out_row)
817            }
818            _ => unreachable!(),
819        }
820    };
821
822    if parallel {
823        #[cfg(not(target_arch = "wasm32"))]
824        {
825            out.par_chunks_mut(cols)
826                .enumerate()
827                .for_each(|(row, slice)| do_row(row, slice));
828        }
829
830        #[cfg(target_arch = "wasm32")]
831        {
832            for (row, slice) in out.chunks_mut(cols).enumerate() {
833                do_row(row, slice);
834            }
835        }
836    } else {
837        for (row, slice) in out.chunks_mut(cols).enumerate() {
838            do_row(row, slice);
839        }
840    }
841
842    let values = unsafe {
843        Vec::from_raw_parts(
844            buf_guard.as_mut_ptr() as *mut f64,
845            buf_guard.len(),
846            buf_guard.capacity(),
847        )
848    };
849
850    Ok(RocBatchOutput {
851        values,
852        combos,
853        rows,
854        cols,
855    })
856}
857
858#[inline(always)]
859fn roc_batch_inner_into(
860    data: &[f64],
861    sweep: &RocBatchRange,
862    kern: Kernel,
863    parallel: bool,
864    out: &mut [f64],
865) -> Result<Vec<RocParams>, RocError> {
866    let combos = expand_grid(sweep)?;
867    let first = data
868        .iter()
869        .position(|x| !x.is_nan())
870        .ok_or(RocError::AllValuesNaN)?;
871    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
872    if data.len() - first < max_p {
873        return Err(RocError::NotEnoughValidData {
874            needed: max_p,
875            valid: data.len() - first,
876        });
877    }
878
879    let rows = combos.len();
880    let cols = data.len();
881    let expected = rows.checked_mul(cols).ok_or(RocError::InvalidRange {
882        start: sweep.period.0,
883        end: sweep.period.1,
884        step: sweep.period.2,
885    })?;
886    if out.len() != expected {
887        return Err(RocError::OutputLengthMismatch {
888            expected,
889            got: out.len(),
890        });
891    }
892
893    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
894        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
895    };
896    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
897    init_matrix_prefixes(out_mu, cols, &warm);
898
899    let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
900        let period = combos[row].period.unwrap();
901        let dst: &mut [f64] =
902            core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
903        match kern {
904            Kernel::Scalar => roc_row_scalar(data, first, period, 0, std::ptr::null(), 0.0, dst),
905            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
906            Kernel::Avx2 => roc_row_avx2(data, first, period, 0, std::ptr::null(), 0.0, dst),
907            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
908            Kernel::Avx512 => roc_row_avx512(data, first, period, 0, std::ptr::null(), 0.0, dst),
909            _ => unreachable!(),
910        }
911    };
912
913    if parallel {
914        #[cfg(not(target_arch = "wasm32"))]
915        {
916            out_mu
917                .par_chunks_mut(cols)
918                .enumerate()
919                .for_each(|(row, row_mu)| do_row(row, row_mu));
920        }
921        #[cfg(target_arch = "wasm32")]
922        {
923            for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
924                do_row(row, row_mu);
925            }
926        }
927    } else {
928        for (row, row_mu) in out_mu.chunks_mut(cols).enumerate() {
929            do_row(row, row_mu);
930        }
931    }
932
933    Ok(combos)
934}
935
// Python binding: single-series ROC over a float64 array.
// Validation and computation errors are raised as `ValueError`; the GIL is
// released while the indicator runs.
#[cfg(feature = "python")]
#[pyfunction(name = "roc")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn roc_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let values = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = RocInput::from_slice(
        values,
        RocParams {
            period: Some(period),
        },
    );

    let output = py
        .allow_threads(|| roc_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
961
// Python binding: batch ROC over a `(start, end, step)` period sweep.
//
// Returns a dict with:
// - `values`: a `rows x cols` float64 array, one row per period;
// - `periods`: the period of each row as uint64.
// Validation and computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "roc_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn roc_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    // `true` presumably restricts the choice to batch kernels (or auto); see `validate_kernel`.
    let kern = validate_kernel(kernel, true)?;

    let sweep = RocBatchRange {
        period: period_range,
    };

    // Expand once up front to size the output array. `roc_batch_inner_into` expands
    // the same sweep again and returns the combos it actually used.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is created uninitialized. On success `roc_batch_inner_into`
    // writes the warm-up prefixes and every row, presumably covering all elements
    // before the array is returned. On error the array is dropped unread.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created and is not shared yet, so this mutable view is unique.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // The computation runs with the GIL released.
    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row kernel used inside the batch loop.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };

            roc_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1022
/// Python class `RocStream`: incremental ROC that wraps the native [`RocStream`].
#[cfg(feature = "python")]
#[pyclass(name = "RocStream")]
pub struct RocStreamPy {
    // Native streaming state; the Python `new`/`update` methods delegate to it.
    inner: RocStream,
}
1028
#[cfg(feature = "python")]
#[pymethods]
impl RocStreamPy {
    /// Builds a streaming ROC calculator with the given look-back `period`.
    ///
    /// Raises `ValueError` when `RocStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        RocStream::try_new(RocParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one sample into the stream and returns the current ROC value, or `None`
    /// when the underlying stream reports no value.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
1045
/// Python handle to a device-resident f32 ROC result matrix returned by the
/// `roc_cuda_*_dev` functions.
///
/// `inner` holds the buffer until `__dlpack__` takes it. After that it is `None`, and the
/// interface methods raise `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct RocDeviceArrayF32Py {
    // `Some` until the buffer has been exported through `__dlpack__`.
    pub(crate) inner: Option<DeviceArrayF32Roc>,
}
1051
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl RocDeviceArrayF32Py {
    /// Describes the buffer via the CUDA Array Interface (version 3) as a row-major
    /// `(rows, cols)` matrix of little-endian f32 (`"<f4"`), writable.
    ///
    /// Raises `ValueError` once the buffer has been exported via `__dlpack__`, or when the
    /// row stride in bytes overflows `usize`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let arr = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ));
            }
        };
        let elem_bytes = std::mem::size_of::<f32>();
        let row_bytes = match arr.cols.checked_mul(elem_bytes) {
            Some(n) => n,
            None => return Err(PyValueError::new_err("byte stride overflow")),
        };

        let iface = PyDict::new(py);
        iface.set_item("shape", (arr.rows, arr.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;
        // (device pointer, read-only flag)
        iface.set_item("data", (arr.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns the DLPack device pair `(2, device_id)`; 2 is DLPack's CUDA device type.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        match &self.inner {
            Some(arr) => Ok((2, arr.device_id as i32)),
            None => Err(PyValueError::new_err(
                "buffer already exported via __dlpack__",
            )),
        }
    }

    /// Hands the device buffer to a DLPack consumer; succeeds at most once.
    ///
    /// A `dl_device` that extracts as an `(i32, i32)` pair must equal this buffer's device,
    /// otherwise `ValueError` is raised (cross-device copies are not implemented).
    /// `stream` is not used; `max_version` is forwarded to the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let own_device = self.__dlpack_device__()?;

        // A `dl_device` value that is not an (i32, i32) pair is ignored.
        let requested_device = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested_device {
            if target != own_device {
                let copy_requested = copy
                    .as_ref()
                    .map_or(false, |c| c.extract::<bool>(py).unwrap_or(false));
                let msg = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        let _ = stream;

        let exported = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
        let (rows, cols) = (exported.rows, exported.cols);
        let buf = exported.buf;
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_device.1, max_version_bound)
    }
}
1127
1128#[inline]
1129pub fn roc_into_slice(dst: &mut [f64], input: &RocInput, kern: Kernel) -> Result<(), RocError> {
1130    let (data, period, first, chosen) = roc_prepare(input, kern)?;
1131    if dst.len() != data.len() {
1132        return Err(RocError::OutputLengthMismatch {
1133            expected: data.len(),
1134            got: dst.len(),
1135        });
1136    }
1137    let warmup_end = first + period;
1138    for v in &mut dst[..warmup_end] {
1139        *v = f64::NAN;
1140    }
1141    roc_compute_into(data, period, first, chosen, dst);
1142    Ok(())
1143}
1144
/// WebAssembly export: computes ROC over `data` with the given `period`.
///
/// Returns a vector of the same length as `data`, or a JS string error if the
/// computation is rejected.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = RocInput::from_slice(
        data,
        RocParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];
    match roc_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1160
/// WebAssembly export: reserves room for `len` f64 values and returns the raw pointer.
///
/// The memory is not initialized. Release it with `roc_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_alloc(len: usize) -> *mut f64 {
    let mut storage: Vec<f64> = Vec::with_capacity(len);
    let raw = storage.as_mut_ptr();
    // Ownership moves to the caller: skip the destructor so the allocation stays live.
    core::mem::forget(storage);
    raw
}
1169
/// WebAssembly export: releases a buffer previously returned by `roc_alloc`.
///
/// `len` must be the value originally passed to `roc_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from `Vec::<f64>::with_capacity(len)` in `roc_alloc`, so it owns
    // an allocation of capacity `len`. The length is passed as 0 because the caller may
    // never have written every slot, and `from_raw_parts` requires the first `length`
    // elements to be initialized. `f64` has no destructor, so dropping a zero-length Vec
    // with capacity `len` frees exactly the same allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1179
/// WebAssembly export: reads `len` values at `in_ptr` and writes their ROC to `len` slots
/// at `out_ptr`.
///
/// Input and output may be the same buffer or overlap in any way. Overlap is detected on
/// the byte ranges, and in that case the result is computed into a scratch buffer and then
/// copied out.
///
/// # Errors
/// Returns a JS string error for a null pointer or when the ROC computation fails.
///
/// Caller contract: both pointers must address at least `len` properly aligned `f64` slots
/// in this module's memory (for example, buffers from `roc_alloc`), and the input slots
/// must be initialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let params = RocParams {
        period: Some(period),
    };

    // Compare the byte ranges so that partial overlap, not just an identical start address,
    // takes the scratch-buffer path; otherwise a `&[f64]` and a `&mut [f64]` would alias.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    // SAFETY: per the caller contract, `in_ptr` addresses `len` initialized f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = RocInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        roc_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` addresses `len` writable f64 slots. `data`/`input` are not read
        // past this point, so the mutable view does not alias a live shared borrow.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` addresses `len` writable f64 slots whose byte range was just
        // shown to be disjoint from the input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        roc_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1213
1214#[cfg(all(feature = "python", feature = "cuda"))]
1215use crate::cuda::cuda_available;
1216#[cfg(all(feature = "python", feature = "cuda"))]
1217use crate::cuda::oscillators::roc_wrapper::{CudaRoc, DeviceArrayF32Roc};
1218
/// Python entry point `roc_cuda_batch_dev`: runs the ROC period sweep on a CUDA device.
///
/// The f32 result stays on the device and is returned as a `RocDeviceArrayF32Py`.
/// Raises `ValueError` when CUDA is unavailable or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "roc_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn roc_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<RocDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let series = data.as_slice()?;
    let range = RocBatchRange {
        period: period_range,
    };

    // The GIL is released while the device work runs.
    let device_buf = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaRoc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .roc_batch_dev(series, &range)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(RocDeviceArrayF32Py {
        inner: Some(device_buf),
    })
}
1242
/// Python entry point `roc_cuda_many_series_one_param_dev`: ROC with one `period` applied to
/// `cols` series stored time-major in a flat array of `rows * cols` f32 values.
///
/// Raises `ValueError` when CUDA is unavailable, when `rows * cols` overflows, when the
/// input length does not equal `rows * cols`, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "roc_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn roc_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<RocDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let series_tm = data_tm.as_slice()?;
    let needed = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if series_tm.len() != needed {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let device_buf = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaRoc::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .roc_many_series_one_param_time_major_dev(series_tm, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(RocDeviceArrayF32Py {
        inner: Some(device_buf),
    })
}
1271
/// Configuration object accepted by the `roc_batch` WebAssembly export; it is deserialized
/// from a JS value with `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RocBatchConfig {
    // Forwarded unchanged as `RocBatchRange::period`; presumably (start, end, step) —
    // confirm against `expand_grid`.
    pub period_range: (usize, usize, usize),
}
1277
/// Result object serialized back to JS by the `roc_batch` WebAssembly export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RocBatchJsOutput {
    // Flattened batch output copied from the batch result; presumably row-major with one
    // row per entry of `combos` — confirm against `roc_batch_with_kernel`.
    pub values: Vec<f64>,
    // Parameter set of every row.
    pub combos: Vec<RocParams>,
    // Row count reported by the batch result.
    pub rows: usize,
    // Column count reported by the batch result.
    pub cols: usize,
}
1286
/// WebAssembly export: ROC over a period sweep described by a JS config object.
///
/// `config` must deserialize into `RocBatchConfig`; the return value is a serialized
/// `RocBatchJsOutput`. Deserialization, computation and serialization failures are all
/// reported as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn roc_batch(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: RocBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = RocBatchRange {
        period: cfg.period_range,
    };
    let batch = roc_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = RocBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1310
1311#[cfg(test)]
1312mod tests {
1313    use super::*;
1314    use crate::skip_if_unsupported;
1315    use crate::utilities::data_loader::read_candles_from_csv;
1316    use paste::paste;
1317
1318    #[test]
1319    fn test_roc_into_matches_api() -> Result<(), Box<dyn Error>> {
1320        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1321        let candles = read_candles_from_csv(file_path)?;
1322
1323        let params = RocParams { period: Some(10) };
1324        let input = RocInput::from_candles(&candles, "close", params);
1325
1326        let baseline = roc(&input)?.values;
1327
1328        let mut out = vec![0.0; candles.close.len()];
1329        roc_into(&input, &mut out)?;
1330
1331        assert_eq!(baseline.len(), out.len());
1332
1333        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
1334            let equal = (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12);
1335            assert!(
1336                equal,
1337                "roc_into parity mismatch at idx {}: api={} into={}",
1338                i, a, b
1339            );
1340        }
1341        Ok(())
1342    }
1343
1344    fn check_roc_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1345        skip_if_unsupported!(kernel, test_name);
1346        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1347        let candles = read_candles_from_csv(file_path)?;
1348
1349        let default_params = RocParams { period: None };
1350        let input_default = RocInput::from_candles(&candles, "close", default_params);
1351        let output_default = roc_with_kernel(&input_default, kernel)?;
1352        assert_eq!(output_default.values.len(), candles.close.len());
1353
1354        let params_period_14 = RocParams { period: Some(14) };
1355        let input_period_14 = RocInput::from_candles(&candles, "hl2", params_period_14);
1356        let output_period_14 = roc_with_kernel(&input_period_14, kernel)?;
1357        assert_eq!(output_period_14.values.len(), candles.close.len());
1358
1359        let params_custom = RocParams { period: Some(20) };
1360        let input_custom = RocInput::from_candles(&candles, "hlc3", params_custom);
1361        let output_custom = roc_with_kernel(&input_custom, kernel)?;
1362        assert_eq!(output_custom.values.len(), candles.close.len());
1363
1364        Ok(())
1365    }
1366
1367    fn check_roc_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1368        skip_if_unsupported!(kernel, test_name);
1369        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1370        let candles = read_candles_from_csv(file_path)?;
1371
1372        let close_prices = &candles.close;
1373        let params = RocParams { period: Some(10) };
1374        let input = RocInput::from_candles(&candles, "close", params);
1375        let roc_result = roc_with_kernel(&input, kernel)?;
1376
1377        assert_eq!(roc_result.values.len(), close_prices.len());
1378
1379        let expected_last_five_roc = [
1380            -0.22551709049294377,
1381            -0.5561903481650754,
1382            -0.32752013235864963,
1383            -0.49454153980722504,
1384            -1.5045927020536976,
1385        ];
1386        assert!(roc_result.values.len() >= 5);
1387        let start_index = roc_result.values.len() - 5;
1388        let result_last_five_roc = &roc_result.values[start_index..];
1389        for (i, &value) in result_last_five_roc.iter().enumerate() {
1390            let expected_value = expected_last_five_roc[i];
1391            assert!(
1392                (value - expected_value).abs() < 1e-7,
1393                "[{}] ROC mismatch at index {}: expected {}, got {}",
1394                test_name,
1395                i,
1396                expected_value,
1397                value
1398            );
1399        }
1400        let period = input.get_period();
1401        for i in 0..(period - 1) {
1402            assert!(roc_result.values[i].is_nan());
1403        }
1404        Ok(())
1405    }
1406
1407    fn check_roc_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1408        skip_if_unsupported!(kernel, test_name);
1409        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1410        let candles = read_candles_from_csv(file_path)?;
1411
1412        let input = RocInput::with_default_candles(&candles);
1413        match input.data {
1414            RocData::Candles { source, .. } => assert_eq!(source, "close"),
1415            _ => panic!("Expected RocData::Candles"),
1416        }
1417        let output = roc_with_kernel(&input, kernel)?;
1418        assert_eq!(output.values.len(), candles.close.len());
1419        Ok(())
1420    }
1421
1422    fn check_roc_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1423        skip_if_unsupported!(kernel, test_name);
1424        let input_data = [10.0, 20.0, 30.0];
1425        let params = RocParams { period: Some(0) };
1426        let input = RocInput::from_slice(&input_data, params);
1427        let res = roc_with_kernel(&input, kernel);
1428        assert!(res.is_err());
1429        Ok(())
1430    }
1431
1432    fn check_roc_period_exceeds_length(
1433        test_name: &str,
1434        kernel: Kernel,
1435    ) -> Result<(), Box<dyn Error>> {
1436        skip_if_unsupported!(kernel, test_name);
1437        let data_small = [10.0, 20.0, 30.0];
1438        let params = RocParams { period: Some(10) };
1439        let input = RocInput::from_slice(&data_small, params);
1440        let res = roc_with_kernel(&input, kernel);
1441        assert!(res.is_err());
1442        Ok(())
1443    }
1444
1445    fn check_roc_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1446        skip_if_unsupported!(kernel, test_name);
1447        let single_point = [42.0];
1448        let params = RocParams { period: Some(9) };
1449        let input = RocInput::from_slice(&single_point, params);
1450        let res = roc_with_kernel(&input, kernel);
1451        assert!(res.is_err());
1452        Ok(())
1453    }
1454
1455    fn check_roc_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1456        skip_if_unsupported!(kernel, test_name);
1457        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1458        let candles = read_candles_from_csv(file_path)?;
1459
1460        let first_params = RocParams { period: Some(14) };
1461        let first_input = RocInput::from_candles(&candles, "close", first_params);
1462        let first_result = roc_with_kernel(&first_input, kernel)?;
1463
1464        let second_params = RocParams { period: Some(14) };
1465        let second_input = RocInput::from_slice(&first_result.values, second_params);
1466        let second_result = roc_with_kernel(&second_input, kernel)?;
1467
1468        assert_eq!(second_result.values.len(), first_result.values.len());
1469        for i in 28..second_result.values.len() {
1470            assert!(
1471                !second_result.values[i].is_nan(),
1472                "[{}] Expected no NaN after index 28, found NaN at {}",
1473                test_name,
1474                i
1475            );
1476        }
1477        Ok(())
1478    }
1479
1480    fn check_roc_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1481        skip_if_unsupported!(kernel, test_name);
1482        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1483        let candles = read_candles_from_csv(file_path)?;
1484
1485        let input = RocInput::from_candles(&candles, "close", RocParams { period: Some(9) });
1486        let res = roc_with_kernel(&input, kernel)?;
1487        assert_eq!(res.values.len(), candles.close.len());
1488        if res.values.len() > 240 {
1489            for (i, &val) in res.values[240..].iter().enumerate() {
1490                assert!(
1491                    !val.is_nan(),
1492                    "[{}] Found unexpected NaN at out-index {}",
1493                    test_name,
1494                    240 + i
1495                );
1496            }
1497        }
1498        Ok(())
1499    }
1500
1501    fn check_roc_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1502        skip_if_unsupported!(kernel, test_name);
1503        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1504        let candles = read_candles_from_csv(file_path)?;
1505        let period = 9;
1506
1507        let input = RocInput::from_candles(
1508            &candles,
1509            "close",
1510            RocParams {
1511                period: Some(period),
1512            },
1513        );
1514        let batch_output = roc_with_kernel(&input, kernel)?.values;
1515
1516        let mut stream = RocStream::try_new(RocParams {
1517            period: Some(period),
1518        })?;
1519        let mut stream_values = Vec::with_capacity(candles.close.len());
1520        for &price in &candles.close {
1521            match stream.update(price) {
1522                Some(val) => stream_values.push(val),
1523                None => stream_values.push(f64::NAN),
1524            }
1525        }
1526        assert_eq!(batch_output.len(), stream_values.len());
1527        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1528            if b.is_nan() && s.is_nan() {
1529                continue;
1530            }
1531            let diff = (b - s).abs();
1532            assert!(
1533                diff < 1e-9,
1534                "[{}] ROC streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1535                test_name,
1536                i,
1537                b,
1538                s,
1539                diff
1540            );
1541        }
1542        Ok(())
1543    }
1544
1545    #[cfg(debug_assertions)]
1546    fn check_roc_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1547        skip_if_unsupported!(kernel, test_name);
1548
1549        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1550        let candles = read_candles_from_csv(file_path)?;
1551
1552        let test_params = vec![
1553            RocParams::default(),
1554            RocParams { period: Some(2) },
1555            RocParams { period: Some(5) },
1556            RocParams { period: Some(7) },
1557            RocParams { period: Some(9) },
1558            RocParams { period: Some(14) },
1559            RocParams { period: Some(20) },
1560            RocParams { period: Some(30) },
1561            RocParams { period: Some(50) },
1562            RocParams { period: Some(100) },
1563            RocParams { period: Some(200) },
1564        ];
1565
1566        for (param_idx, params) in test_params.iter().enumerate() {
1567            let input = RocInput::from_candles(&candles, "close", params.clone());
1568            let output = roc_with_kernel(&input, kernel)?;
1569
1570            for (i, &val) in output.values.iter().enumerate() {
1571                if val.is_nan() {
1572                    continue;
1573                }
1574
1575                let bits = val.to_bits();
1576
1577                if bits == 0x11111111_11111111 {
1578                    panic!(
1579                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1580						 with params: period={} (param set {})",
1581                        test_name,
1582                        val,
1583                        bits,
1584                        i,
1585                        params.period.unwrap_or(9),
1586                        param_idx
1587                    );
1588                }
1589
1590                if bits == 0x22222222_22222222 {
1591                    panic!(
1592                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1593						 with params: period={} (param set {})",
1594                        test_name,
1595                        val,
1596                        bits,
1597                        i,
1598                        params.period.unwrap_or(9),
1599                        param_idx
1600                    );
1601                }
1602
1603                if bits == 0x33333333_33333333 {
1604                    panic!(
1605                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1606						 with params: period={} (param set {})",
1607                        test_name,
1608                        val,
1609                        bits,
1610                        i,
1611                        params.period.unwrap_or(9),
1612                        param_idx
1613                    );
1614                }
1615            }
1616        }
1617
1618        Ok(())
1619    }
1620
1621    #[cfg(not(debug_assertions))]
1622    fn check_roc_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1623        Ok(())
1624    }
1625
1626    #[cfg(feature = "proptest")]
1627    #[allow(clippy::float_cmp)]
1628    fn check_roc_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1629        use proptest::prelude::*;
1630        skip_if_unsupported!(kernel, test_name);
1631
1632        let strat = (2usize..=100).prop_flat_map(|period| {
1633            prop_oneof![
1634                (
1635                    prop::collection::vec(
1636                        (1f64..1e6f64)
1637                            .prop_filter("finite positive", |x| x.is_finite() && *x > 0.0),
1638                        period..400,
1639                    ),
1640                    Just(period),
1641                ),
1642                (
1643                    prop::collection::vec(
1644                        prop_oneof![
1645                            (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
1646                            Just(100.0),
1647                        ],
1648                        period..400,
1649                    ),
1650                    Just(period),
1651                ),
1652                (
1653                    (100f64..10000f64, 0.01f64..0.1f64).prop_map(move |(start, step)| {
1654                        let len = period + (400 - period) / 2;
1655                        (0..len)
1656                            .map(|i| start + (i as f64) * step)
1657                            .collect::<Vec<_>>()
1658                    }),
1659                    Just(period),
1660                ),
1661                (
1662                    (10000f64..100000f64, 0.01f64..0.1f64).prop_map(move |(start, step)| {
1663                        let len = period + (400 - period) / 2;
1664                        (0..len)
1665                            .map(|i| start - (i as f64) * step)
1666                            .collect::<Vec<_>>()
1667                    }),
1668                    Just(period),
1669                ),
1670            ]
1671        });
1672
1673        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1674            let params = RocParams {
1675                period: Some(period),
1676            };
1677            let input = RocInput::from_slice(&data, params);
1678
1679            let RocOutput { values: out } = roc_with_kernel(&input, kernel)?;
1680            let RocOutput { values: ref_out } = roc_with_kernel(&input, Kernel::Scalar)?;
1681
1682            prop_assert_eq!(out.len(), data.len(), "Output length mismatch");
1683
1684            for i in 0..period {
1685                prop_assert!(
1686                    out[i].is_nan(),
1687                    "Expected NaN during warmup at index {}, got {}",
1688                    i,
1689                    out[i]
1690                );
1691            }
1692
1693            for i in period..data.len() {
1694                let current = data[i];
1695                let previous = data[i - period];
1696                let roc_val = out[i];
1697
1698                let expected_roc = if previous == 0.0 || previous.is_nan() {
1699                    0.0
1700                } else {
1701                    ((current / previous) - 1.0) * 100.0
1702                };
1703
1704                if !roc_val.is_nan() {
1705                    prop_assert!(
1706							(roc_val - expected_roc).abs() < 1e-9,
1707							"ROC calculation mismatch at {}: got {}, expected {} (current={}, previous={})",
1708							i, roc_val, expected_roc, current, previous
1709						);
1710
1711                    if current > previous && previous > 0.0 {
1712                        prop_assert!(
1713								roc_val > -1e-9,
1714								"ROC should be positive when current > previous at {}: roc={}, current={}, previous={}",
1715								i, roc_val, current, previous
1716							);
1717                    }
1718                    if current < previous && previous > 0.0 {
1719                        prop_assert!(
1720								roc_val < 1e-9,
1721								"ROC should be negative when current < previous at {}: roc={}, current={}, previous={}",
1722								i, roc_val, current, previous
1723							);
1724                    }
1725                    if (current - previous).abs() < 1e-12 && previous > 0.0 {
1726                        prop_assert!(
1727								roc_val.abs() < 1e-9,
1728								"ROC should be ~0 when current ≈ previous at {}: roc={}, current={}, previous={}",
1729								i, roc_val, current, previous
1730							);
1731                    }
1732                }
1733            }
1734
1735            let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12);
1736            if is_constant && data.len() > period {
1737                for i in period..data.len() {
1738                    if !out[i].is_nan() {
1739                        prop_assert!(
1740                            out[i].abs() < 1e-9,
1741                            "ROC of constant data should be 0 at {}: got {}",
1742                            i,
1743                            out[i]
1744                        );
1745                    }
1746                }
1747            }
1748
1749            let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0]);
1750            let is_monotonic_decreasing = data.windows(2).all(|w| w[1] <= w[0]);
1751
1752            if is_monotonic_increasing && !is_constant {
1753                for i in period..data.len() {
1754                    if !out[i].is_nan() && data[i] > data[i - period] {
1755                        prop_assert!(
1756                            out[i] > -1e-9,
1757                            "ROC should be positive for increasing data at {}: got {}",
1758                            i,
1759                            out[i]
1760                        );
1761                    }
1762                }
1763            }
1764
1765            if is_monotonic_decreasing && !is_constant {
1766                for i in period..data.len() {
1767                    if !out[i].is_nan() && data[i] < data[i - period] {
1768                        prop_assert!(
1769                            out[i] < 1e-9,
1770                            "ROC should be negative for decreasing data at {}: got {}",
1771                            i,
1772                            out[i]
1773                        );
1774                    }
1775                }
1776            }
1777
1778            prop_assert_eq!(out.len(), ref_out.len(), "Kernel output length mismatch");
1779
1780            for i in 0..out.len() {
1781                let y = out[i];
1782                let r = ref_out[i];
1783
1784                if !y.is_finite() || !r.is_finite() {
1785                    prop_assert!(
1786                        y.to_bits() == r.to_bits(),
1787                        "NaN/Inf mismatch at {}: {} vs {}",
1788                        i,
1789                        y,
1790                        r
1791                    );
1792                } else {
1793                    let tolerance = 1e-9;
1794                    prop_assert!(
1795                        (y - r).abs() <= tolerance,
1796                        "Kernel mismatch at {}: {} vs {}, diff={}",
1797                        i,
1798                        y,
1799                        r,
1800                        (y - r).abs()
1801                    );
1802                }
1803            }
1804
1805            #[cfg(debug_assertions)]
1806            {
1807                for (i, &val) in out.iter().enumerate() {
1808                    if !val.is_nan() {
1809                        let bits = val.to_bits();
1810                        prop_assert_ne!(
1811                            bits,
1812                            0x11111111_11111111,
1813                            "Found alloc_with_nan_prefix poison at {}",
1814                            i
1815                        );
1816                        prop_assert_ne!(
1817                            bits,
1818                            0x22222222_22222222,
1819                            "Found init_matrix_prefixes poison at {}",
1820                            i
1821                        );
1822                        prop_assert_ne!(
1823                            bits,
1824                            0x33333333_33333333,
1825                            "Found make_uninit_matrix poison at {}",
1826                            i
1827                        );
1828                    }
1829                }
1830            }
1831
1832            Ok(())
1833        })?;
1834
1835        Ok(())
1836    }
1837
1838    macro_rules! generate_all_roc_tests {
1839        ($($test_fn:ident),*) => {
1840            paste! {
1841                $(
1842                    #[test]
1843                    fn [<$test_fn _scalar_f64>]() {
1844                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1845                    }
1846                )*
1847                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1848                $(
1849                    #[test]
1850                    fn [<$test_fn _avx2_f64>]() {
1851                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1852                    }
1853                    #[test]
1854                    fn [<$test_fn _avx512_f64>]() {
1855                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1856                    }
1857                )*
1858            }
1859        }
1860    }
1861    generate_all_roc_tests!(
1862        check_roc_partial_params,
1863        check_roc_accuracy,
1864        check_roc_default_candles,
1865        check_roc_zero_period,
1866        check_roc_period_exceeds_length,
1867        check_roc_very_small_dataset,
1868        check_roc_reinput,
1869        check_roc_nan_handling,
1870        check_roc_streaming,
1871        check_roc_no_poison
1872    );
1873
1874    #[cfg(feature = "proptest")]
1875    generate_all_roc_tests!(check_roc_property);
1876
1877    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1878        skip_if_unsupported!(kernel, test);
1879
1880        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1881        let c = read_candles_from_csv(file)?;
1882
1883        let output = RocBatchBuilder::new()
1884            .period_static(10)
1885            .kernel(kernel)
1886            .apply_candles(&c, "close")?;
1887
1888        let test_params = RocParams { period: Some(10) };
1889        let row = output
1890            .values_for(&test_params)
1891            .expect("period=10 row missing");
1892
1893        assert_eq!(row.len(), c.close.len());
1894
1895        let expected = [
1896            -0.22551709049294377,
1897            -0.5561903481650754,
1898            -0.32752013235864963,
1899            -0.49454153980722504,
1900            -1.5045927020536976,
1901        ];
1902        let start = row.len() - 5;
1903        for (i, &v) in row[start..].iter().enumerate() {
1904            assert!(
1905                (v - expected[i]).abs() < 1e-7,
1906                "[{test}] period=10 row mismatch at idx {i}: {v} vs {expected:?}"
1907            );
1908        }
1909        Ok(())
1910    }
1911
1912    #[cfg(debug_assertions)]
1913    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1914        skip_if_unsupported!(kernel, test);
1915
1916        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1917        let c = read_candles_from_csv(file)?;
1918
1919        let test_configs = vec![
1920            (2, 10, 2),
1921            (5, 25, 5),
1922            (30, 60, 15),
1923            (2, 5, 1),
1924            (9, 9, 0),
1925            (14, 21, 7),
1926            (10, 50, 10),
1927        ];
1928
1929        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1930            let output = RocBatchBuilder::new()
1931                .kernel(kernel)
1932                .period_range(p_start, p_end, p_step)
1933                .apply_candles(&c, "close")?;
1934
1935            for (idx, &val) in output.values.iter().enumerate() {
1936                if val.is_nan() {
1937                    continue;
1938                }
1939
1940                let bits = val.to_bits();
1941                let row = idx / output.cols;
1942                let col = idx % output.cols;
1943                let combo = &output.combos[row];
1944
1945                if bits == 0x11111111_11111111 {
1946                    panic!(
1947                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1948						 at row {} col {} (flat index {}) with params: period={}",
1949                        test,
1950                        cfg_idx,
1951                        val,
1952                        bits,
1953                        row,
1954                        col,
1955                        idx,
1956                        combo.period.unwrap_or(9)
1957                    );
1958                }
1959
1960                if bits == 0x22222222_22222222 {
1961                    panic!(
1962                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1963						 at row {} col {} (flat index {}) with params: period={}",
1964                        test,
1965                        cfg_idx,
1966                        val,
1967                        bits,
1968                        row,
1969                        col,
1970                        idx,
1971                        combo.period.unwrap_or(9)
1972                    );
1973                }
1974
1975                if bits == 0x33333333_33333333 {
1976                    panic!(
1977                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1978						 at row {} col {} (flat index {}) with params: period={}",
1979                        test,
1980                        cfg_idx,
1981                        val,
1982                        bits,
1983                        row,
1984                        col,
1985                        idx,
1986                        combo.period.unwrap_or(9)
1987                    );
1988                }
1989            }
1990        }
1991
1992        Ok(())
1993    }
1994
1995    #[cfg(not(debug_assertions))]
1996    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1997        Ok(())
1998    }
1999
2000    macro_rules! gen_batch_tests {
2001        ($fn_name:ident) => {
2002            paste! {
2003                #[test] fn [<$fn_name _scalar>]()      {
2004                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2005                }
2006                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2007                #[test] fn [<$fn_name _avx2>]()        {
2008                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2009                }
2010                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2011                #[test] fn [<$fn_name _avx512>]()      {
2012                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2013                }
2014                #[test] fn [<$fn_name _auto_detect>]() {
2015                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2016                }
2017            }
2018        };
2019    }
2020    gen_batch_tests!(check_batch_default_row);
2021    gen_batch_tests!(check_batch_no_poison);
2022}