Skip to main content

vector_ta/indicators/moving_averages/
reflex.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::cuda_available;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::moving_averages::CudaReflex;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::moving_averages::DeviceArrayF32;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
25#[cfg(feature = "python")]
26use crate::utilities::kernel_validation::validate_kernel;
27#[cfg(feature = "python")]
28use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
29#[cfg(feature = "python")]
30use pyo3::exceptions::PyValueError;
31#[cfg(feature = "python")]
32use pyo3::prelude::*;
33#[cfg(feature = "python")]
34use pyo3::types::{PyDict, PyList};
35
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use wasm_bindgen::prelude::*;
38
39impl<'a> AsRef<[f64]> for ReflexInput<'a> {
40    #[inline(always)]
41    fn as_ref(&self) -> &[f64] {
42        match &self.data {
43            ReflexData::Slice(slice) => slice,
44            ReflexData::Candles { candles, source } => source_type(candles, source),
45        }
46    }
47}
48
/// Source of the price series fed to the Reflex indicator.
#[derive(Debug, Clone)]
pub enum ReflexData<'a> {
    /// Candle data plus the name of the field to read (resolved through
    /// `source_type` wherever the series is needed).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used directly.
    Slice(&'a [f64]),
}
57
/// Result of a single-series Reflex computation.
#[derive(Debug, Clone)]
pub struct ReflexOutput {
    /// One value per input sample; `reflex_with_kernel` sets the first
    /// `period` entries to `0.0`.
    pub values: Vec<f64>,
}
62
/// Tunable parameters of the Reflex indicator.
#[derive(Debug, Clone, Copy)]
pub struct ReflexParams {
    /// Look-back period; `None` is treated as 20 by the code in this file.
    pub period: Option<usize>,
}
67
impl Default for ReflexParams {
    /// Default parameters: a period of 20.
    fn default() -> Self {
        Self { period: Some(20) }
    }
}
73
/// Bundles the data source and the parameters for one Reflex computation.
#[derive(Debug, Clone)]
pub struct ReflexInput<'a> {
    /// Where the price series comes from.
    pub data: ReflexData<'a>,
    /// Indicator parameters (currently only the period).
    pub params: ReflexParams,
}
79
80impl<'a> ReflexInput<'a> {
81    #[inline]
82    pub fn from_candles(c: &'a Candles, s: &'a str, p: ReflexParams) -> Self {
83        Self {
84            data: ReflexData::Candles {
85                candles: c,
86                source: s,
87            },
88            params: p,
89        }
90    }
91    #[inline]
92    pub fn from_slice(sl: &'a [f64], p: ReflexParams) -> Self {
93        Self {
94            data: ReflexData::Slice(sl),
95            params: p,
96        }
97    }
98    #[inline]
99    pub fn with_default_candles(c: &'a Candles) -> Self {
100        Self::from_candles(c, "close", ReflexParams::default())
101    }
102    #[inline]
103    pub fn get_period(&self) -> usize {
104        self.params.period.unwrap_or(20)
105    }
106}
107
/// Fluent builder for single-series Reflex runs and streams.
#[derive(Copy, Clone, Debug)]
pub struct ReflexBuilder {
    // Optional period override; forwarded unchanged into `ReflexParams`.
    period: Option<usize>,
    // Kernel handed to `reflex_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
113
impl Default for ReflexBuilder {
    /// A builder with no period override and `Kernel::Auto`.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
122
123impl ReflexBuilder {
124    #[inline(always)]
125    pub fn new() -> Self {
126        Self::default()
127    }
128    #[inline(always)]
129    pub fn period(mut self, n: usize) -> Self {
130        self.period = Some(n);
131        self
132    }
133    #[inline(always)]
134    pub fn kernel(mut self, k: Kernel) -> Self {
135        self.kernel = k;
136        self
137    }
138    #[inline(always)]
139    pub fn apply(self, c: &Candles) -> Result<ReflexOutput, ReflexError> {
140        let p = ReflexParams {
141            period: self.period,
142        };
143        let i = ReflexInput::from_candles(c, "close", p);
144        reflex_with_kernel(&i, self.kernel)
145    }
146    #[inline(always)]
147    pub fn apply_slice(self, d: &[f64]) -> Result<ReflexOutput, ReflexError> {
148        let p = ReflexParams {
149            period: self.period,
150        };
151        let i = ReflexInput::from_slice(d, p);
152        reflex_with_kernel(&i, self.kernel)
153    }
154    #[inline(always)]
155    pub fn into_stream(self) -> Result<ReflexStream, ReflexError> {
156        let p = ReflexParams {
157            period: self.period,
158        };
159        ReflexStream::try_new(p)
160    }
161}
162
/// Errors reported by the Reflex functions in this file.
#[derive(Debug, Error)]
pub enum ReflexError {
    /// The input series has length zero.
    #[error("reflex: No data available (input data slice is empty).")]
    EmptyInputData,
    /// No non-NaN value was found in the input series.
    #[error("reflex: All values are NaN.")]
    AllValuesNaN,
    /// The period is below the minimum of 2.
    #[error("reflex: period must be >=2 (period = {period}, data length = {data_len})")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer samples remain from the first non-NaN value onward than the period requires.
    #[error("reflex: Not enough data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the expected length.
    #[error("reflex: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("reflex: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A parameter range could not be expanded; also used by the batch code
    /// to report a `rows * cols` overflow.
    #[error("reflex: invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
}
184
/// Computes the Reflex indicator for `input` using `Kernel::Auto`.
///
/// Convenience wrapper around `reflex_with_kernel`; errors are passed through.
#[inline]
pub fn reflex(input: &ReflexInput) -> Result<ReflexOutput, ReflexError> {
    reflex_with_kernel(input, Kernel::Auto)
}
189
190pub fn reflex_with_kernel(
191    input: &ReflexInput,
192    kernel: Kernel,
193) -> Result<ReflexOutput, ReflexError> {
194    let (data, period, first, chosen) = reflex_prepare(input, kernel)?;
195    let len = data.len();
196
197    let mut out = alloc_with_nan_prefix(len, period);
198
199    reflex_compute_into(data, period, first, chosen, &mut out);
200
201    out[..period.min(len)].fill(0.0);
202
203    Ok(ReflexOutput { values: out })
204}
205
/// Computes the Reflex indicator into a caller-provided buffer using
/// `Kernel::Auto`.
///
/// Thin wrapper around `reflex_into_slice`; compiled out when targeting
/// `wasm32` with the `wasm` feature enabled.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn reflex_into(input: &ReflexInput, out: &mut [f64]) -> Result<(), ReflexError> {
    reflex_into_slice(out, input, Kernel::Auto)
}
211
212#[inline]
213pub fn reflex_into_slice(
214    dst: &mut [f64],
215    input: &ReflexInput,
216    kern: Kernel,
217) -> Result<(), ReflexError> {
218    let (data, period, first, chosen) = reflex_prepare(input, kern)?;
219
220    if dst.len() != data.len() {
221        return Err(ReflexError::OutputLengthMismatch {
222            expected: data.len(),
223            got: dst.len(),
224        });
225    }
226
227    reflex_compute_into(data, period, first, chosen, dst);
228
229    let end = period.min(dst.len());
230    for x in &mut dst[..end] {
231        *x = 0.0;
232    }
233
234    Ok(())
235}
236
/// Scalar Reflex kernel.
///
/// Builds a recursively smoothed series (`ssf`) from `data`, seeded with the
/// first two raw samples, and for every index `i >= period` writes the
/// normalised deviation to `out[i]`:
///
/// * `my_sum = beta * ssf[i] + alpha * ssf[i - period] - mean(ssf[i - period .. i])`
/// * `ms     = 0.96 * ms_prev + 0.04 * my_sum^2`
/// * `out[i] = my_sum / sqrt(ms)` when `ms > 0.0`, otherwise `0.0`
///
/// Indices below `period` are never written, so callers own the warm-up
/// region. Writing `0.0` in the degenerate case (instead of skipping the
/// slot) matches `ReflexStream::update` and guarantees that every slot from
/// `period` onward is initialised even when the caller passes a buffer with
/// unspecified contents.
///
/// # Parameters
/// * `data` – input series.
/// * `period` – look-back length; nothing is computed when it is below 2.
/// * `first` – index of the first non-NaN sample. It is accepted for
///   signature compatibility with the other kernels but is not used here:
///   the recursion always starts at index 0.
/// * `out` – destination buffer, at least as long as `data`.
///
/// # Panics
/// Panics if `out` is shorter than `data` (only checked when there is work to
/// do, i.e. `data.len() >= 2` and `period >= 2`). Previously this case wrote
/// past the end of `out`.
#[inline]
pub fn reflex_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // Intentionally unused, see the parameter documentation.
    let _ = first;

    let len = data.len();
    if len < 2 || period < 2 {
        return;
    }
    assert!(
        out.len() >= len,
        "reflex_scalar: output buffer too short (out = {}, data = {})",
        out.len(),
        len
    );

    // Smoothing-filter coefficients derived from half the period.
    let half_p = (period / 2).max(1) as f64;
    let a = (-1.414_f64 * std::f64::consts::PI / half_p).exp();
    let a2 = a * a;
    let b = 2.0 * a * (1.414_f64 * std::f64::consts::PI / half_p).cos();
    let c = 0.5 * (1.0 + a2 - b);

    // Ring buffer holding the last `period + 1` smoothed values, so that the
    // slot right after the write position is the value from `period` steps ago.
    let ring_len = period + 1;
    let mut ssf = vec![0.0_f64; ring_len];
    ssf[0] = data[0];
    ssf[1] = data[1];

    // Running sum of the smoothed values inside the look-back window.
    let mut ssf_sum = ssf[0] + ssf[1];

    let inv_p = 1.0 / (period as f64);
    let alpha = 0.5 * (1.0 + inv_p);
    let beta = 1.0 - alpha;

    let mut ms = 0.0_f64;

    let mut idx_im2 = 0usize;
    let mut idx_im1 = 1usize;
    let mut idx = 2usize;

    for i in 2..len {
        let t0 = c * (data[i] + data[i - 1]);
        let t1 = (-a2).mul_add(ssf[idx_im2], t0);
        let ssf_i = b.mul_add(ssf[idx_im1], t1);
        ssf[idx] = ssf_i;

        if i < period {
            // Still filling the first window.
            ssf_sum += ssf_i;
        } else {
            let idx_ip = if idx + 1 == ring_len { 0 } else { idx + 1 };
            let ssf_ip = ssf[idx_ip];

            let mean_lp = ssf_sum * inv_p;
            let my_sum = ssf_i.mul_add(beta, ssf_ip * alpha) - mean_lp;

            ms = (0.96_f64).mul_add(ms, 0.04_f64 * (my_sum * my_sum));
            out[i] = if ms > 0.0 { my_sum / ms.sqrt() } else { 0.0 };

            // Slide the window: add the newest value, drop the oldest.
            ssf_sum += ssf_i - ssf_ip;
        }

        idx_im2 = idx_im1;
        idx_im1 = idx;
        idx += 1;
        if idx == ring_len {
            idx = 0;
        }
    }
}
318
/// AVX2 entry point for the Reflex kernel.
///
/// Currently forwards to `reflex_scalar`; the function is only compiled with
/// the `avx2` and `fma` target features enabled.
///
/// # Safety
/// Declared with `#[target_feature(enable = "avx2,fma")]`, so the caller must
/// ensure the running CPU supports those features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn reflex_avx2(data: &[f64], period: usize, _first: usize, out: &mut [f64]) {
    reflex_scalar(data, period, _first, out)
}
324
/// AVX-512 entry point for the Reflex kernel.
///
/// Currently forwards to `reflex_scalar`; the function is only compiled with
/// the `avx512f`, `avx512dq` and `fma` target features enabled.
///
/// # Safety
/// Declared with `#[target_feature(enable = "avx512f,avx512dq,fma")]`, so the
/// caller must ensure the running CPU supports those features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn reflex_avx512(data: &[f64], period: usize, _first: usize, out: &mut [f64]) {
    reflex_scalar(data, period, _first, out)
}
330
331#[inline(always)]
332fn reflex_prepare<'a>(
333    input: &'a ReflexInput,
334    kernel: Kernel,
335) -> Result<(&'a [f64], usize, usize, Kernel), ReflexError> {
336    let data: &[f64] = match &input.data {
337        ReflexData::Candles { candles, source } => source_type(candles, source),
338        ReflexData::Slice(sl) => sl,
339    };
340
341    let len = data.len();
342    if len == 0 {
343        return Err(ReflexError::EmptyInputData);
344    }
345
346    let first = data
347        .iter()
348        .position(|x| !x.is_nan())
349        .ok_or(ReflexError::AllValuesNaN)?;
350    let period = input.get_period();
351
352    if period < 2 {
353        return Err(ReflexError::InvalidPeriod {
354            period,
355            data_len: len,
356        });
357    }
358    if period > (len - first) {
359        return Err(ReflexError::NotEnoughValidData {
360            needed: period,
361            valid: len - first,
362        });
363    }
364
365    let chosen = match kernel {
366        Kernel::Auto => Kernel::Scalar,
367        other => other,
368    };
369
370    Ok((data, period, first, chosen))
371}
372
373#[inline(always)]
374fn reflex_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
375    unsafe {
376        match kernel {
377            Kernel::Scalar | Kernel::ScalarBatch => reflex_scalar(data, period, first, out),
378            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
379            Kernel::Avx2 | Kernel::Avx2Batch => reflex_avx2(data, period, first, out),
380            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
381            Kernel::Avx2 | Kernel::Avx2Batch => reflex_scalar(data, period, first, out),
382            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
383            Kernel::Avx512 | Kernel::Avx512Batch => reflex_avx512(data, period, first, out),
384            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
385            Kernel::Avx512 | Kernel::Avx512Batch => reflex_scalar(data, period, first, out),
386            _ => unreachable!(),
387        }
388    }
389}
390
/// Incremental (one sample at a time) Reflex calculator.
#[derive(Debug, Clone)]
pub struct ReflexStream {
    // Look-back period (validated to be >= 2 in `try_new`).
    period: usize,

    // Smoothing-filter coefficients computed once in `try_new`.
    a_sq: f64,
    b: f64,
    c: f64,

    // Weights and reciprocal period used when forming the deviation term.
    alpha: f64,
    beta: f64,
    inv_p: f64,

    // Ring buffer of smoothed values (`period + 1` slots); `head` is the next
    // write position, `tail` the position of the value read back in `update`.
    ssf_buf: Vec<f64>,
    head: usize,
    tail: usize,

    // Running sum of smoothed values in the current window.
    ssf_sum: f64,
    // Previous value of the exponentially averaged squared deviation.
    last_ms: f64,
    // Previous raw input sample.
    prev_x: f64,
    // The two most recent smoothed values (1 = newest, 2 = the one before).
    last_ssf1: f64,
    last_ssf2: f64,
    // Number of samples consumed so far.
    count: usize,
}
414
415impl ReflexStream {
416    #[inline]
417    pub fn try_new(params: ReflexParams) -> Result<Self, ReflexError> {
418        let period = params.period.unwrap_or(20);
419        if period < 2 {
420            return Err(ReflexError::InvalidPeriod {
421                period,
422                data_len: 0,
423            });
424        }
425
426        let half_p = (period / 2).max(1) as f64;
427        let a = (-1.414_f64 * std::f64::consts::PI / half_p).exp();
428        let a_sq = a * a;
429        let b = 2.0 * a * (1.414_f64 * std::f64::consts::PI / half_p).cos();
430        let c = 0.5 * (1.0 + a_sq - b);
431
432        let inv_p = 1.0 / (period as f64);
433        let alpha = 0.5 * (1.0 + inv_p);
434        let beta = 1.0 - alpha;
435
436        Ok(Self {
437            period,
438            a_sq,
439            b,
440            c,
441            alpha,
442            beta,
443            inv_p,
444
445            ssf_buf: vec![0.0; period + 1],
446            head: 0,
447            tail: 0,
448
449            ssf_sum: 0.0,
450            last_ms: 0.0,
451            prev_x: 0.0,
452            last_ssf1: 0.0,
453            last_ssf2: 0.0,
454            count: 0,
455        })
456    }
457
458    #[inline(always)]
459    pub fn update(&mut self, x: f64) -> Option<f64> {
460        let p = self.period;
461        let ring_len = p + 1;
462        let t = self.count;
463
464        if t == 0 {
465            self.prev_x = x;
466            self.last_ssf1 = x;
467            self.ssf_buf[self.head] = x;
468            self.head += 1;
469            if self.head == ring_len {
470                self.head = 0;
471            }
472            self.ssf_sum += x;
473            self.count = 1;
474            return None;
475        }
476        if t == 1 {
477            self.prev_x = x;
478            self.last_ssf2 = self.last_ssf1;
479            self.last_ssf1 = x;
480            self.ssf_buf[self.head] = x;
481            self.head += 1;
482            if self.head == ring_len {
483                self.head = 0;
484            }
485            self.ssf_sum += x;
486            self.count = 2;
487            return None;
488        }
489
490        let t0 = self.c * (x + self.prev_x);
491        let t1 = (-self.a_sq).mul_add(self.last_ssf2, t0);
492        let ssf_t = self.b.mul_add(self.last_ssf1, t1);
493
494        let mut out = None;
495        if t >= p {
496            let ssf_tp = self.ssf_buf[self.tail];
497
498            let mean_lp = self.ssf_sum * self.inv_p;
499
500            let my_sum = self.beta.mul_add(ssf_t, self.alpha * ssf_tp) - mean_lp;
501
502            let ms = 0.96_f64.mul_add(self.last_ms, 0.04_f64 * (my_sum * my_sum));
503            self.last_ms = ms;
504            out = if ms > 0.0 {
505                Some(my_sum / ms.sqrt())
506            } else {
507                Some(0.0)
508            };
509
510            self.ssf_sum += ssf_t - ssf_tp;
511            self.tail += 1;
512            if self.tail == ring_len {
513                self.tail = 0;
514            }
515        } else {
516            self.ssf_sum += ssf_t;
517        }
518
519        self.ssf_buf[self.head] = ssf_t;
520        self.head += 1;
521        if self.head == ring_len {
522            self.head = 0;
523        }
524
525        self.prev_x = x;
526        self.last_ssf2 = self.last_ssf1;
527        self.last_ssf1 = ssf_t;
528        self.count = t + 1;
529
530        out
531    }
532}
533
/// Parameter sweep description for batch runs.
#[derive(Clone, Debug)]
pub struct ReflexBatchRange {
    /// `(start, end, step)` of the period axis, expanded by
    /// `expand_grid_checked`.
    pub period: (usize, usize, usize),
}
538
impl Default for ReflexBatchRange {
    /// Default sweep: periods 20 through 269 in steps of 1.
    fn default() -> Self {
        Self {
            period: (20, 269, 1),
        }
    }
}
546
/// Fluent builder for batch (multi-period) Reflex runs.
#[derive(Clone, Debug, Default)]
pub struct ReflexBatchBuilder {
    // Period sweep to evaluate.
    range: ReflexBatchRange,
    // Kernel handed to `reflex_batch_with_kernel`.
    kernel: Kernel,
}
552
553impl ReflexBatchBuilder {
554    pub fn new() -> Self {
555        Self::default()
556    }
557    pub fn kernel(mut self, k: Kernel) -> Self {
558        self.kernel = k;
559        self
560    }
561    #[inline]
562    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
563        self.range.period = (start, end, step);
564        self
565    }
566    #[inline]
567    pub fn period_static(mut self, p: usize) -> Self {
568        self.range.period = (p, p, 0);
569        self
570    }
571    pub fn apply_slice(self, data: &[f64]) -> Result<ReflexBatchOutput, ReflexError> {
572        reflex_batch_with_kernel(data, &self.range, self.kernel)
573    }
574    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ReflexBatchOutput, ReflexError> {
575        ReflexBatchBuilder::new().kernel(k).apply_slice(data)
576    }
577    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ReflexBatchOutput, ReflexError> {
578        let slice = source_type(c, src);
579        self.apply_slice(slice)
580    }
581    pub fn with_default_candles(c: &Candles) -> Result<ReflexBatchOutput, ReflexError> {
582        ReflexBatchBuilder::new()
583            .kernel(Kernel::Auto)
584            .apply_candles(c, "close")
585    }
586}
587
588pub fn reflex_batch_with_kernel(
589    data: &[f64],
590    sweep: &ReflexBatchRange,
591    k: Kernel,
592) -> Result<ReflexBatchOutput, ReflexError> {
593    let kernel = match k {
594        Kernel::Auto => detect_best_batch_kernel(),
595        other if other.is_batch() => other,
596        other => return Err(ReflexError::InvalidKernelForBatch(other)),
597    };
598    let simd = match kernel {
599        Kernel::Avx512Batch => Kernel::Avx512,
600        Kernel::Avx2Batch => Kernel::Avx2,
601        Kernel::ScalarBatch => Kernel::Scalar,
602        _ => unreachable!(),
603    };
604    reflex_batch_par_slice(data, sweep, simd)
605}
606
/// Result of a batch sweep: one row of values per parameter combination.
#[derive(Clone, Debug)]
pub struct ReflexBatchOutput {
    /// Row-major matrix of `rows * cols` values.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<ReflexParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of samples per row (the input length).
    pub cols: usize,
}
614
615impl ReflexBatchOutput {
616    pub fn row_for_params(&self, p: &ReflexParams) -> Option<usize> {
617        self.combos
618            .iter()
619            .position(|c| c.period.unwrap_or(20) == p.period.unwrap_or(20))
620    }
621    pub fn values_for(&self, p: &ReflexParams) -> Option<&[f64]> {
622        self.row_for_params(p).map(|row| {
623            let start = row * self.cols;
624            &self.values[start..start + self.cols]
625        })
626    }
627}
628
629#[inline(always)]
630fn expand_grid_checked(r: &ReflexBatchRange) -> Result<Vec<ReflexParams>, ReflexError> {
631    fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, ReflexError> {
632        let (start, end, step) = range;
633        if step == 0 || start == end {
634            return Ok(vec![start]);
635        }
636        let mut out = Vec::new();
637        if start < end {
638            let mut cur = start;
639            while cur <= end {
640                out.push(cur);
641                cur = match cur.checked_add(step) {
642                    Some(v) => v,
643                    None => break,
644                };
645            }
646        } else {
647            let mut cur = start;
648            while cur >= end {
649                out.push(cur);
650                cur = match cur.checked_sub(step) {
651                    Some(v) => v,
652                    None => break,
653                };
654                if cur == 0 {
655                    break;
656                }
657            }
658        }
659        if out.is_empty() {
660            return Err(ReflexError::InvalidRange { start, end, step });
661        }
662        Ok(out)
663    }
664    let periods = axis_usize(r.period)?;
665    Ok(periods
666        .into_iter()
667        .map(|p| ReflexParams { period: Some(p) })
668        .collect())
669}
670
/// Runs a batch sweep over `data`, processing the rows sequentially.
///
/// Thin wrapper around `reflex_batch_inner` with `parallel = false`.
#[inline(always)]
pub fn reflex_batch_slice(
    data: &[f64],
    sweep: &ReflexBatchRange,
    kern: Kernel,
) -> Result<ReflexBatchOutput, ReflexError> {
    reflex_batch_inner(data, sweep, kern, false)
}
679
/// Runs a batch sweep over `data` with the parallel flag set.
///
/// Thin wrapper around `reflex_batch_inner` with `parallel = true`.
#[inline(always)]
pub fn reflex_batch_par_slice(
    data: &[f64],
    sweep: &ReflexBatchRange,
    kern: Kernel,
) -> Result<ReflexBatchOutput, ReflexError> {
    reflex_batch_inner(data, sweep, kern, true)
}
688
689#[inline(always)]
690fn reflex_batch_inner(
691    data: &[f64],
692    sweep: &ReflexBatchRange,
693    kern: Kernel,
694    parallel: bool,
695) -> Result<ReflexBatchOutput, ReflexError> {
696    let combos = expand_grid_checked(sweep)?;
697    let first = data
698        .iter()
699        .position(|x| !x.is_nan())
700        .ok_or(ReflexError::AllValuesNaN)?;
701    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
702    if data.len() - first < max_p {
703        return Err(ReflexError::NotEnoughValidData {
704            needed: max_p,
705            valid: data.len() - first,
706        });
707    }
708
709    let rows = combos.len();
710    let cols = data.len();
711
712    let _total = rows.checked_mul(cols).ok_or(ReflexError::InvalidRange {
713        start: rows,
714        end: cols,
715        step: 0,
716    })?;
717
718    let mut buf_mu = make_uninit_matrix(rows, cols);
719    let warm: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
720    init_matrix_prefixes(&mut buf_mu, cols, &warm);
721
722    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
723    let out: &mut [f64] =
724        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
725
726    let kernel = match kern {
727        Kernel::Auto => detect_best_batch_kernel(),
728        other => other,
729    };
730
731    let simd = match kernel {
732        Kernel::Avx512Batch => Kernel::Avx512,
733        Kernel::Avx2Batch => Kernel::Avx2,
734        Kernel::ScalarBatch => Kernel::Scalar,
735        other => other,
736    };
737
738    let meta = reflex_batch_inner_into(data, sweep, simd, parallel, out)?;
739
740    let values = unsafe {
741        Vec::from_raw_parts(
742            guard.as_mut_ptr() as *mut f64,
743            guard.len(),
744            guard.capacity(),
745        )
746    };
747
748    Ok(ReflexBatchOutput {
749        values,
750        combos: meta.combos,
751        rows: meta.rows,
752        cols: meta.cols,
753    })
754}
755
/// Computes one batch row with the scalar kernel.
///
/// Only reorders the arguments and forwards to `reflex_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own; the function is marked
/// `unsafe`, presumably to share a signature with the SIMD row helpers.
#[inline(always)]
unsafe fn reflex_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    reflex_scalar(data, period, first, out)
}
760
/// Computes one batch row through the AVX2 entry point.
///
/// # Safety
/// Calls `reflex_avx2`, so the caller must ensure the CPU supports the target
/// features that function is compiled with.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn reflex_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    reflex_avx2(data, period, first, out)
}
766
/// Computes one batch row through the AVX-512 entry point.
///
/// # Safety
/// Calls `reflex_avx512`, so the caller must ensure the CPU supports the
/// target features that function is compiled with.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn reflex_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    reflex_avx512(data, period, first, out)
}
772
773#[inline(always)]
774fn reflex_batch_inner_into(
775    data: &[f64],
776    sweep: &ReflexBatchRange,
777    kern: Kernel,
778    parallel: bool,
779    out: &mut [f64],
780) -> Result<ReflexBatchMetadata, ReflexError> {
781    let combos = expand_grid_checked(sweep)?;
782
783    let first = data
784        .iter()
785        .position(|x| !x.is_nan())
786        .ok_or(ReflexError::AllValuesNaN)?;
787    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
788    if data.len() - first < max_p {
789        return Err(ReflexError::NotEnoughValidData {
790            needed: max_p,
791            valid: data.len() - first,
792        });
793    }
794
795    let rows = combos.len();
796    let cols = data.len();
797
798    let expected = rows.checked_mul(cols).ok_or(ReflexError::InvalidRange {
799        start: rows,
800        end: cols,
801        step: 0,
802    })?;
803    if out.len() != expected {
804        return Err(ReflexError::OutputLengthMismatch {
805            expected,
806            got: out.len(),
807        });
808    }
809
810    let do_row = |row: usize, dst: &mut [f64]| unsafe {
811        let period = combos[row].period.unwrap();
812
813        for x in &mut dst[..period.min(cols)] {
814            *x = 0.0;
815        }
816
817        match kern {
818            Kernel::Scalar => reflex_row_scalar(data, first, period, dst),
819            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
820            Kernel::Avx2 => reflex_row_avx2(data, first, period, dst),
821            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
822            Kernel::Avx2 => reflex_row_scalar(data, first, period, dst),
823            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
824            Kernel::Avx512 => reflex_row_avx512(data, first, period, dst),
825            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
826            Kernel::Avx512 => reflex_row_scalar(data, first, period, dst),
827            _ => unreachable!(),
828        }
829    };
830
831    if parallel {
832        #[cfg(not(target_arch = "wasm32"))]
833        {
834            out.par_chunks_mut(cols)
835                .enumerate()
836                .for_each(|(row, slice)| do_row(row, slice));
837        }
838        #[cfg(target_arch = "wasm32")]
839        {
840            for (row, slice) in out.chunks_mut(cols).enumerate() {
841                do_row(row, slice);
842            }
843        }
844    } else {
845        for (row, slice) in out.chunks_mut(cols).enumerate() {
846            do_row(row, slice);
847        }
848    }
849
850    Ok(ReflexBatchMetadata { combos, rows, cols })
851}
852
/// Shape and parameter information returned by `reflex_batch_inner_into`,
/// which writes the values themselves into a caller-provided buffer.
#[derive(Clone, Debug)]
pub struct ReflexBatchMetadata {
    /// Parameter set used for each row, in row order.
    pub combos: Vec<ReflexParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of samples per row (the input length).
    pub cols: usize,
}
859
/// Compute Reflex indicator.
///
/// Parameters
/// ----------
/// data : numpy.ndarray
///     Input data array
/// period : int, default=20
///     Period for the indicator (must be >= 2)
/// kernel : str, optional
///     Kernel to use:
///     - 'auto' or None: Auto-detect best kernel (default)
///     - 'scalar': Use scalar implementation
///     - 'avx2': Use AVX2 implementation (if available)
///     - 'avx512': Use AVX512 implementation (if available)
///
/// Returns
/// -------
/// numpy.ndarray
///     Reflex values, same length as `data`; the first `period` values are 0.0
///
/// Raises
/// ------
/// ValueError
///     If the kernel name is rejected or the computation reports an error
#[cfg(feature = "python")]
#[pyfunction(name = "reflex")]
#[pyo3(signature = (data, period = 20, kernel = None), text_signature = "(data, period=20, kernel=None)")]
pub fn reflex_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // The docstring lives in the doc comment above so that PyO3 exposes it as
    // `__doc__`; a string literal inside the body is silently discarded.
    use numpy::{IntoPyArray, PyArrayMethods};

    let data_slice = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let params = ReflexParams {
        period: Some(period),
    };
    let input = ReflexInput::from_slice(data_slice, params);

    // The computation runs inside `allow_threads`.
    let result_vec: Vec<f64> = py
        .allow_threads(|| reflex_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
906
/// Compute Reflex indicator for multiple periods.
///
/// Parameters
/// ----------
/// data : numpy.ndarray
///     Input data array
/// periods : tuple of int
///     (start, end, step) for period range; every period must be >= 2
/// kernel : str, optional
///     Kernel to use (see reflex() for options)
///
/// Returns
/// -------
/// dict
///     Dictionary with 'values' (2D array, one row per period) and
///     'periods' (1D integer array)
///
/// Raises
/// ------
/// ValueError
///     If the kernel name is rejected, the range is invalid, a period is
///     below 2, or the computation reports an error
#[cfg(feature = "python")]
#[pyfunction(name = "reflex_batch")]
#[pyo3(signature = (data, periods, kernel = None), text_signature = "(data, periods, kernel=None)")]
pub fn reflex_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    periods: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Py<PyDict>> {
    // The docstring lives in the doc comment above so that PyO3 exposes it as
    // `__doc__`; a string literal inside the body is silently discarded.
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let data_slice = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let range = ReflexBatchRange { period: periods };

    let combos = expand_grid_checked(&range)
        .map_err(|e| PyValueError::new_err(format!("reflex batch error: {}", e)))?;
    let rows = combos.len();
    let cols = data_slice.len();

    // The output array below is allocated uninitialised, and the kernel skips
    // periods below 2 entirely, so reject them before any memory is exposed.
    if let Some(bad) = combos.iter().filter_map(|c| c.period).find(|&p| p < 2) {
        let err = ReflexError::InvalidPeriod {
            period: bad,
            data_len: cols,
        };
        return Err(PyValueError::new_err(format!(
            "reflex batch error: {}",
            err
        )));
    }

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array was just created here and is written through
    // `slice_out` before it is returned to Python.
    // NOTE(review): confirm nothing else can reference it in between.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let metadata = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };

            reflex_batch_inner_into(data_slice, &range, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(format!("reflex batch error: {}", e)))?;

    let dict = PyDict::new(py);

    let reshaped = out_arr.reshape([rows, cols])?;
    dict.set_item("values", reshaped)?;

    dict.set_item(
        "periods",
        metadata
            .combos
            .iter()
            .map(|c| c.period.unwrap_or(20) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict.into())
}
988
// Runs a Reflex period sweep through `CudaReflex` and returns the resulting
// device array wrapped for Python. Fails with `ValueError("CUDA not
// available")` when `cuda_available()` is false; errors from the CUDA wrapper
// are converted to `ValueError` via their string form.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "reflex_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn reflex_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = ReflexBatchRange {
        period: period_range,
    };

    let launched: PyResult<_> = py.allow_threads(|| {
        let engine =
            CudaReflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Grab the context handle up front; it is stored next to the array.
        let ctx = engine.context_arc();
        let dev_id = device_id as u32;
        match engine.reflex_batch_dev(prices, &sweep) {
            Ok(inner) => Ok((inner, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1021
// Runs Reflex with a single period over a 2D `f32` array through
// `CudaReflex` and returns the resulting device array wrapped for Python.
// The array is passed flattened, with `shape[1]` and `shape[0]` forwarded (in
// that order) to the time-major CUDA entry point. Fails with `ValueError`
// when CUDA is unavailable, the array is not 2D, or the CUDA wrapper reports
// an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "reflex_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn reflex_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let dims = data_tm_f32.shape();
    if dims.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let (rows, cols) = (dims[0], dims[1]);
    let flat = data_tm_f32.as_slice()?;

    let launched: PyResult<_> = py.allow_threads(|| {
        let engine =
            CudaReflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Grab the context handle up front; it is stored next to the array.
        let ctx = engine.context_arc();
        let dev_id = device_id as u32;
        match engine.reflex_many_series_one_param_time_major_dev(flat, cols, rows, period) {
            Ok(inner) => Ok((inner, ctx, dev_id)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
1059
// Python-visible wrapper (`ReflexStream`) around the incremental `ReflexStream` state.
#[cfg(feature = "python")]
#[pyclass(name = "ReflexStream")]
pub struct ReflexStreamPy {
    // The wrapped Rust stream; all calls are delegated to it.
    inner: ReflexStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ReflexStreamPy {
    // Builds a stream for `period` (20 when omitted from Python). A construction failure
    // is reported as a `ValueError` prefixed with "reflex stream error".
    #[new]
    #[pyo3(signature = (period = 20))]
    pub fn new(period: usize) -> PyResult<Self> {
        ReflexStream::try_new(ReflexParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(format!("reflex stream error: {}", e)))
    }

    // Feeds one sample to the wrapped stream and returns whatever it yields
    // (`None` maps to Python `None`).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1084
/// Computes Reflex over `data` with the given `period` and returns a freshly allocated
/// vector of the same length.
///
/// Any indicator error is converted to a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ReflexInput::from_slice(
        data,
        ReflexParams {
            period: Some(period),
        },
    );

    // Zero-initialised destination; the indicator writes into it in place.
    let mut values = vec![0.0; data.len()];

    match reflex_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1100
/// Runs a Reflex period sweep described by `(period_start, period_end, period_step)` over
/// `data` and returns the batch output's flat `values` buffer.
///
/// Errors are returned as a `JsValue` string prefixed with "reflex batch error".
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    reflex_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map(|batch| batch.values)
        .map_err(|e| JsValue::from_str(&format!("reflex batch error: {}", e)))
}
1118
/// Lists the period of every parameter combination produced by the given sweep, in the
/// order `expand_grid_checked` yields them.
///
/// A combination without an explicit period is reported as 20. If the sweep cannot be
/// expanded, an empty vector is returned instead of an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let sweep = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    expand_grid_checked(&sweep)
        .map(|combos| {
            combos
                .iter()
                .map(|combo| combo.period.unwrap_or(20))
                .collect::<Vec<usize>>()
        })
        .unwrap_or_default()
}
1134
/// Returns `[rows, cols]` for a batch run: `rows` is the number of parameter combinations
/// in the sweep (0 when the sweep cannot be expanded) and `cols` is `data_len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = match expand_grid_checked(&sweep) {
        Ok(combos) => combos.len(),
        Err(_) => 0,
    };

    vec![rows, data_len]
}
1149
/// Reserves room for `len` `f64` values and returns a raw pointer to it.
///
/// The backing `Vec` is never dropped here, so the memory stays allocated until the
/// pointer is handed back to `reflex_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaving ownership with the caller.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1158
/// Releases a buffer by rebuilding a `Vec<f64>` of length and capacity `len` from `ptr`
/// and dropping it. A null `ptr` is ignored.
///
/// `ptr`/`len` are expected to be a pair previously produced by `reflex_alloc`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer and length that came from
    // `reflex_alloc(len)` and have not been freed already.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1168
/// Computes Reflex for `len` values at `in_ptr` and writes `len` results to `out_ptr`.
///
/// `in_ptr` and `out_ptr` may be the same address: in that case the result is computed
/// into a temporary buffer and copied over the input afterwards.
///
/// # Errors
/// * "Null pointer provided" when either pointer is null.
/// * "Invalid period" when `period` is 0 or larger than `len`.
/// * The indicator's own error text otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values.
    let source = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = ReflexInput::from_slice(
        source,
        ReflexParams {
            period: Some(period),
        },
    );

    let in_place = in_ptr == out_ptr as *const f64;
    if in_place {
        // Writing straight into the input would corrupt values still to be read, so the
        // result is staged in a scratch buffer first.
        let mut scratch = vec![0.0; len];
        reflex_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values.
        let dest = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dest.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64 values.
        let dest = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        reflex_into_slice(dest, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1207
/// Runs a Reflex period sweep over `len` values at `in_ptr` and writes the row-major
/// result (`rows * len` values) to `out_ptr`, returning the number of rows.
///
/// The input range and the output range may overlap (including `in_ptr == out_ptr`): in
/// that case the input is copied first, so that no shared and mutable slices alias the
/// same memory and earlier rows cannot overwrite input that later rows still need.
///
/// # Errors
/// * "null pointer" when either pointer is null.
/// * "reflex batch error: ..." when the sweep cannot be expanded.
/// * "size overflow" when `rows * len` does not fit in `usize`.
/// * The indicator's own error text otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    let sweep = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_checked(&sweep)
        .map_err(|e| JsValue::from_str(&format!("reflex batch error: {}", e)))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Half-open range intersection test on [in, in+len) and [out, out+total).
    // `wrapping_add` only does address arithmetic; nothing is dereferenced here.
    let out_begin = out_ptr as *const f64;
    let overlaps = out_begin < in_ptr.wrapping_add(len) && in_ptr < out_begin.wrapping_add(total);

    // Take a private copy of the input when it shares memory with the output.
    let owned_input: Option<Vec<f64>> = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64 values; the
        // borrow ends as soon as the copy is made, before `out` is created.
        let view = unsafe { std::slice::from_raw_parts(in_ptr, len) };
        Some(view.to_vec())
    } else {
        None
    };
    let data: &[f64] = match owned_input.as_deref() {
        Some(copy) => copy,
        // SAFETY: same caller guarantee as above; this range was shown not to overlap the
        // output range, so it never aliases `out`.
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` addresses `total` writable f64 values.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    reflex_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}
1239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// CSV fixture loaded by every candle-driven test in this module.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference values for the last five outputs on the fixture's close series with
    /// default parameters; shared by the single-run and the batch accuracy checks.
    const EXPECTED_DEFAULT_TAIL: [f64; 5] = [
        0.8085220962465361,
        0.445264715886137,
        0.13861699036615063,
        -0.03598639652007061,
        -0.224906760543743,
    ];

    /// Maps a sentinel bit pattern to the name of the allocation helper this suite
    /// associates it with, or `None` for any other value.
    #[cfg(debug_assertions)]
    fn poison_origin(bits: u64) -> Option<&'static str> {
        match bits {
            0x11111111_11111111 => Some("alloc_with_nan_prefix"),
            0x22222222_22222222 => Some("init_matrix_prefixes"),
            0x33333333_33333333 => Some("make_uninit_matrix"),
            _ => None,
        }
    }

    /// Runs the indicator with default, `Some(14)` and `Some(30)` periods on different
    /// candle sources and checks that each output is as long as the input.
    fn check_reflex_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let default_params = ReflexParams { period: None };
        let input = ReflexInput::from_candles(&candles, "close", default_params);
        let output = reflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        let params_period_14 = ReflexParams { period: Some(14) };
        let input2 = ReflexInput::from_candles(&candles, "hl2", params_period_14);
        let output2 = reflex_with_kernel(&input2, kernel)?;
        assert_eq!(output2.values.len(), candles.close.len());
        let params_custom = ReflexParams { period: Some(30) };
        let input3 = ReflexInput::from_candles(&candles, "hlc3", params_custom);
        let output3 = reflex_with_kernel(&input3, kernel)?;
        assert_eq!(output3.values.len(), candles.close.len());
        Ok(())
    }

    /// Compares the last five default-parameter outputs against `EXPECTED_DEFAULT_TAIL`
    /// with an absolute tolerance of 1e-7.
    fn check_reflex_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let default_params = ReflexParams::default();
        let input = ReflexInput::from_candles(&candles, "close", default_params);
        let result = reflex_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        let len = result.values.len();
        let start_idx = len - 5;
        let last_five = &result.values[start_idx..];
        for (i, &val) in last_five.iter().enumerate() {
            let exp = EXPECTED_DEFAULT_TAIL[i];
            assert!(
                (val - exp).abs() < 1e-7,
                "[{}] Reflex mismatch at idx {}: got {}, expected {}",
                test_name,
                i,
                val,
                exp
            );
        }
        Ok(())
    }

    /// Checks that `with_default_candles` selects the "close" source and that the
    /// resulting run produces one value per candle.
    fn check_reflex_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = ReflexInput::with_default_candles(&candles);
        match input.data {
            ReflexData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected ReflexData::Candles"),
        }
        let output = reflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A period of 0 must be rejected.
    fn check_reflex_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ReflexParams { period: Some(0) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period of 1 must be rejected.
    fn check_reflex_period_less_than_two(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ReflexParams { period: Some(1) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with period<2",
            test_name
        );
        Ok(())
    }

    /// A single data point with period 2 must be rejected as insufficient data.
    fn check_reflex_very_small_data_set(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [42.0];
        let params = ReflexParams { period: Some(2) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeds the period-14 output back in as input for a period-10 run and checks that
    /// the second output is finite from index 14 onwards.
    fn check_reflex_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let first_params = ReflexParams { period: Some(14) };
        let first_input = ReflexInput::from_candles(&candles, "close", first_params);
        let first_result = reflex_with_kernel(&first_input, kernel)?;
        assert_eq!(first_result.values.len(), candles.close.len());
        let second_params = ReflexParams { period: Some(10) };
        let second_input = ReflexInput::from_slice(&first_result.values, second_params);
        let second_result = reflex_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 14..second_result.values.len() {
            assert!(second_result.values[i].is_finite());
        }
        Ok(())
    }

    /// Checks that every output from index `period` onwards is finite on the fixture.
    fn check_reflex_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let period = 14;
        let params = ReflexParams {
            period: Some(period),
        };
        let input = ReflexInput::from_candles(&candles, "close", params);
        let result = reflex_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        if result.values.len() > period {
            for i in period..result.values.len() {
                assert!(
                    result.values[i].is_finite(),
                    "[{}] Unexpected NaN at index {}",
                    test_name,
                    i
                );
            }
        }
        Ok(())
    }

    /// Compares the one-shot output with values produced sample-by-sample through
    /// `ReflexStream`; a `None` from the stream is recorded as 0.0 before comparing.
    fn check_reflex_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let period = 14;
        let params = ReflexParams {
            period: Some(period),
        };
        let input = ReflexInput::from_candles(&candles, "close", params.clone());
        let batch_output = reflex_with_kernel(&input, kernel)?.values;
        let mut stream = ReflexStream::try_new(params)?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(v) => stream_values.push(v),
                None => stream_values.push(0.0),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] Reflex streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Expands each `check_*` function into one `#[test]` per kernel.
    ///
    /// The `cfg` gate is attached to every AVX function individually: an attribute placed
    /// in front of a `$( ... )*` repetition would only apply to the first expanded item.
    /// Each generated test returns the check's `Result`, so an `Err` (for example an
    /// unreadable fixture) fails the test instead of being discarded.
    macro_rules! generate_all_reflex_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    /// Debug-build check: scans the output for each parameter set and panics if any
    /// non-NaN value carries one of the sentinel bit patterns known to `poison_origin`.
    #[cfg(debug_assertions)]
    fn check_reflex_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_cases = vec![
            ReflexParams { period: Some(20) },
            ReflexParams { period: Some(2) },
            ReflexParams { period: Some(5) },
            ReflexParams { period: Some(10) },
            ReflexParams { period: Some(30) },
            ReflexParams { period: Some(50) },
            ReflexParams { period: Some(15) },
            ReflexParams { period: Some(40) },
            ReflexParams { period: None },
        ];

        for params in test_cases {
            // Remember the period for diagnostics before `params` is moved into the input.
            let period = params.period;
            let input = ReflexInput::from_candles(&candles, "close", params);
            let output = reflex_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(origin) = poison_origin(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, origin, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds skip the sentinel scan entirely.
    #[cfg(not(debug_assertions))]
    fn check_reflex_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test over random finite series (length `period..400`, period 2..=50):
    /// * the first `period` outputs are exactly 0.0,
    /// * the selected kernel agrees with the scalar kernel (1e-9 absolute or 4 ULP),
    /// * outputs from index `period` onwards are finite,
    /// * constant input yields near-zero output after `2 * period` samples.
    #[cfg(feature = "proptest")]
    // Exact comparison is intended: the first `period` outputs are asserted to equal 0.0.
    #[allow(clippy::float_cmp)]
    fn check_reflex_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = ReflexParams {
                    period: Some(period),
                };
                let input = ReflexInput::from_slice(&data, params);

                let ReflexOutput { values: out } = reflex_with_kernel(&input, kernel).unwrap();
                let ReflexOutput { values: ref_out } =
                    reflex_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());

                for i in 0..period.min(data.len()) {
                    prop_assert!(
                        out[i] == 0.0,
                        "[{}] idx {}: expected 0.0 during warmup, got {}",
                        test_name,
                        i,
                        out[i]
                    );
                }

                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    // Non-finite values must match bit-for-bit between the two kernels.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "[{}] finite/NaN mismatch idx {}: {} vs {}",
                            test_name,
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "[{}] mismatch idx {}: {} vs {} (ULP={})",
                        test_name,
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                for i in period..data.len() {
                    if data[i].abs() < 1e10 {
                        prop_assert!(
                            out[i].is_finite(),
                            "[{}] idx {}: expected finite, got {}",
                            test_name,
                            i,
                            out[i]
                        );
                    }
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
                    for i in (period * 2)..data.len() {
                        prop_assert!(
                            out[i].abs() < 0.001,
                            "[{}] idx {}: constant data should yield near-zero, got {}",
                            test_name,
                            i,
                            out[i]
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Without the `proptest` feature the property check is a no-op.
    #[cfg(not(feature = "proptest"))]
    fn check_reflex_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }

    generate_all_reflex_tests!(
        check_reflex_partial_params,
        check_reflex_accuracy,
        check_reflex_default_candles,
        check_reflex_zero_period,
        check_reflex_period_less_than_two,
        check_reflex_very_small_data_set,
        check_reflex_reinput,
        check_reflex_nan_handling,
        check_reflex_streaming,
        check_reflex_no_poison,
        check_reflex_property
    );

    /// Runs the batch builder with its default range and checks that the row for the
    /// default parameters ends with `EXPECTED_DEFAULT_TAIL` (absolute tolerance 1e-7).
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = ReflexBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = ReflexParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            // Report only the element under comparison, not the whole reference array.
            assert!(
                (v - EXPECTED_DEFAULT_TAIL[i]).abs() < 1e-7,
                "[{test}] default-row mismatch at idx {i}: {v} vs {}",
                EXPECTED_DEFAULT_TAIL[i]
            );
        }
        Ok(())
    }

    /// Expands one batch `check_*` function into a `#[test]` per batch kernel plus an
    /// auto-detect variant. The generated tests return the check's `Result` so that an
    /// `Err` fails the test instead of being dropped.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    /// Debug-build check: scans the flat batch output of several period sweeps and panics
    /// if any non-NaN value carries one of the sentinel bit patterns known to
    /// `poison_origin`, reporting the row, column and parameter combination.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        // (start, end, step) period sweeps.
        let batch_configs = vec![
            (10, 30, 10),
            (20, 20, 0),
            (2, 10, 2),
            (25, 50, 25),
            (5, 20, 5),
            (15, 45, 15),
            (3, 15, 3),
            (30, 60, 10),
        ];

        for (p_start, p_end, p_step) in batch_configs {
            let output = ReflexBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(origin) = poison_origin(bits) {
                    // Translate the flat index back to matrix coordinates for the report.
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    let combo = &output.combos[row];
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, origin, val, bits, row, col, idx, combo.period
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds skip the sentinel scan entirely.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Checks that the native `reflex_into` writes the same values as
    /// `reflex_with_kernel(.., Kernel::Auto)` on a series with three leading NaNs.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_reflex_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = vec![f64::from_bits(0x7ff8_0000_0000_0000); 3];
        data.extend((0..256).map(|i| ((i as f64) * 0.1).sin() * 1.23 + (i as f64) * 0.01));

        let input = ReflexInput::from_slice(&data, ReflexParams::default());

        let baseline = reflex_with_kernel(&input, Kernel::Auto)?.values;

        let mut out = vec![0.0; data.len()];
        super::reflex_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        // Two values match when both are NaN, when they are equal, or within 1e-12.
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "mismatch at {}: baseline={} out={}",
                i,
                baseline[i],
                out[i]
            );
        }

        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}