1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::mem::MaybeUninit;
15use thiserror::Error;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::cuda_available;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::moving_averages::CudaReflex;
21#[cfg(all(feature = "python", feature = "cuda"))]
22use crate::cuda::moving_averages::DeviceArrayF32;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
25#[cfg(feature = "python")]
26use crate::utilities::kernel_validation::validate_kernel;
27#[cfg(feature = "python")]
28use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
29#[cfg(feature = "python")]
30use pyo3::exceptions::PyValueError;
31#[cfg(feature = "python")]
32use pyo3::prelude::*;
33#[cfg(feature = "python")]
34use pyo3::types::{PyDict, PyList};
35
36#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
37use wasm_bindgen::prelude::*;
38
39impl<'a> AsRef<[f64]> for ReflexInput<'a> {
40 #[inline(always)]
41 fn as_ref(&self) -> &[f64] {
42 match &self.data {
43 ReflexData::Slice(slice) => slice,
44 ReflexData::Candles { candles, source } => source_type(candles, source),
45 }
46 }
47}
48
/// Where a [`ReflexInput`] takes its series from.
#[derive(Debug, Clone)]
pub enum ReflexData<'a> {
    /// A column of a candle set, selected by source name (e.g. `"close"`) and
    /// resolved through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values.
    Slice(&'a [f64]),
}

/// Result of a single-series Reflex computation; `values` has one entry per input sample.
#[derive(Debug, Clone)]
pub struct ReflexOutput {
    pub values: Vec<f64>,
}

/// Parameters of the Reflex indicator.
#[derive(Debug, Clone, Copy)]
pub struct ReflexParams {
    /// Look-back length; `None` falls back to 20 (see `ReflexInput::get_period`).
    pub period: Option<usize>,
}

impl Default for ReflexParams {
    // The default period is 20.
    fn default() -> Self {
        Self { period: Some(20) }
    }
}

/// A series together with the parameters used to evaluate Reflex over it.
#[derive(Debug, Clone)]
pub struct ReflexInput<'a> {
    pub data: ReflexData<'a>,
    pub params: ReflexParams,
}
79
80impl<'a> ReflexInput<'a> {
81 #[inline]
82 pub fn from_candles(c: &'a Candles, s: &'a str, p: ReflexParams) -> Self {
83 Self {
84 data: ReflexData::Candles {
85 candles: c,
86 source: s,
87 },
88 params: p,
89 }
90 }
91 #[inline]
92 pub fn from_slice(sl: &'a [f64], p: ReflexParams) -> Self {
93 Self {
94 data: ReflexData::Slice(sl),
95 params: p,
96 }
97 }
98 #[inline]
99 pub fn with_default_candles(c: &'a Candles) -> Self {
100 Self::from_candles(c, "close", ReflexParams::default())
101 }
102 #[inline]
103 pub fn get_period(&self) -> usize {
104 self.params.period.unwrap_or(20)
105 }
106}
107
/// Fluent front-end for single-series Reflex computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct ReflexBuilder {
    /// Look-back period; `None` means the library default.
    period: Option<usize>,
    /// Kernel requested for the computation.
    kernel: Kernel,
}

impl Default for ReflexBuilder {
    // No explicit period, automatic kernel selection.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
122
123impl ReflexBuilder {
124 #[inline(always)]
125 pub fn new() -> Self {
126 Self::default()
127 }
128 #[inline(always)]
129 pub fn period(mut self, n: usize) -> Self {
130 self.period = Some(n);
131 self
132 }
133 #[inline(always)]
134 pub fn kernel(mut self, k: Kernel) -> Self {
135 self.kernel = k;
136 self
137 }
138 #[inline(always)]
139 pub fn apply(self, c: &Candles) -> Result<ReflexOutput, ReflexError> {
140 let p = ReflexParams {
141 period: self.period,
142 };
143 let i = ReflexInput::from_candles(c, "close", p);
144 reflex_with_kernel(&i, self.kernel)
145 }
146 #[inline(always)]
147 pub fn apply_slice(self, d: &[f64]) -> Result<ReflexOutput, ReflexError> {
148 let p = ReflexParams {
149 period: self.period,
150 };
151 let i = ReflexInput::from_slice(d, p);
152 reflex_with_kernel(&i, self.kernel)
153 }
154 #[inline(always)]
155 pub fn into_stream(self) -> Result<ReflexStream, ReflexError> {
156 let p = ReflexParams {
157 period: self.period,
158 };
159 ReflexStream::try_new(p)
160 }
161}
162
/// Errors produced by the Reflex functions in this module.
#[derive(Debug, Error)]
pub enum ReflexError {
    /// The input series has no samples.
    #[error("reflex: No data available (input data slice is empty).")]
    EmptyInputData,
    /// Every input sample is NaN.
    #[error("reflex: All values are NaN.")]
    AllValuesNaN,
    /// The period is below 2. `data_len` is the input length (0 when raised by
    /// `ReflexStream::try_new`, which has no data).
    #[error("reflex: period must be >=2 (period = {period}, data length = {data_len})")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer samples remain after the first non-NaN value than the period requires.
    #[error("reflex: Not enough data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("reflex: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("reflex: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A period sweep expanded to nothing; also reused for a `rows * cols` overflow.
    #[error("reflex: invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
}
184
185#[inline]
186pub fn reflex(input: &ReflexInput) -> Result<ReflexOutput, ReflexError> {
187 reflex_with_kernel(input, Kernel::Auto)
188}
189
190pub fn reflex_with_kernel(
191 input: &ReflexInput,
192 kernel: Kernel,
193) -> Result<ReflexOutput, ReflexError> {
194 let (data, period, first, chosen) = reflex_prepare(input, kernel)?;
195 let len = data.len();
196
197 let mut out = alloc_with_nan_prefix(len, period);
198
199 reflex_compute_into(data, period, first, chosen, &mut out);
200
201 out[..period.min(len)].fill(0.0);
202
203 Ok(ReflexOutput { values: out })
204}
205
206#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
207#[inline]
208pub fn reflex_into(input: &ReflexInput, out: &mut [f64]) -> Result<(), ReflexError> {
209 reflex_into_slice(out, input, Kernel::Auto)
210}
211
212#[inline]
213pub fn reflex_into_slice(
214 dst: &mut [f64],
215 input: &ReflexInput,
216 kern: Kernel,
217) -> Result<(), ReflexError> {
218 let (data, period, first, chosen) = reflex_prepare(input, kern)?;
219
220 if dst.len() != data.len() {
221 return Err(ReflexError::OutputLengthMismatch {
222 expected: data.len(),
223 got: dst.len(),
224 });
225 }
226
227 reflex_compute_into(data, period, first, chosen, dst);
228
229 let end = period.min(dst.len());
230 for x in &mut dst[..end] {
231 *x = 0.0;
232 }
233
234 Ok(())
235}
236
/// Scalar Reflex kernel.
///
/// The valid part of the input (`data[first..]`) goes through a two-pole recursive
/// smoother with critical period `period / 2`. Each output sample then measures how far
/// the smoothed series deviates from the straight line joining the current smoothed
/// value and the one `period` samples earlier. It is normalised by an exponential moving
/// average of the squared deviation (`ms`, weights 0.96 / 0.04).
///
/// Every element of `out[..data.len()]` is written:
/// * `out[..first]` is `NaN` (no valid input yet);
/// * `out[first..first + period]` is `0.0` (warm-up);
/// * `out[first + period..]` holds the Reflex value, or `0.0` while `ms` is not positive
///   (for example on a constant series), matching `ReflexStream::update`.
///
/// Elements of `out` beyond `data.len()` are left untouched. Nothing at all is written
/// when `data.len() < 2` or `period < 2`.
///
/// # Panics
/// Panics if `out` is shorter than `data`.
#[inline]
pub fn reflex_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if len < 2 || period < 2 {
        return;
    }
    assert!(
        out.len() >= len,
        "reflex_scalar: output length {} is shorter than input length {}",
        out.len(),
        len
    );
    let out = &mut out[..len];

    // Leading NaNs are reported as NaN and kept out of the recursion: a single NaN fed
    // into the recursive smoother would poison every later sample.
    let first = first.min(len);
    let warm = first.saturating_add(period).min(len);
    out[..first].fill(f64::NAN);
    out[first..warm].fill(0.0);

    let src = &data[first..];
    let dst = &mut out[first..];
    let n = src.len();
    if n < 2 {
        return;
    }

    let half_p = (period / 2).max(1) as f64;
    let a = (-1.414_f64 * std::f64::consts::PI / half_p).exp();
    let a2 = a * a;
    let b = 2.0 * a * (1.414_f64 * std::f64::consts::PI / half_p).cos();
    let c = 0.5 * (1.0 + a2 - b);

    let inv_p = 1.0 / (period as f64);
    let alpha = 0.5 * (1.0 + inv_p);
    let beta = 1.0 - alpha;

    // Ring of `period + 1` smoothed values: slot `i % ring_len` holds ssf[i], so the
    // slot right after it still holds ssf[i - period].
    let ring_len = period + 1;
    let mut ssf = vec![0.0_f64; ring_len];
    ssf[0] = src[0];
    ssf[1] = src[1];

    // Sum of the (up to) `period` smoothed values preceding the current sample.
    let mut ssf_sum = ssf[0] + ssf[1];
    let mut ms = 0.0_f64;

    let mut idx_im2 = 0usize;
    let mut idx_im1 = 1usize;
    let mut idx = 2usize;

    for i in 2..n {
        let t0 = c * (src[i] + src[i - 1]);
        let t1 = (-a2).mul_add(ssf[idx_im2], t0);
        let ssf_i = b.mul_add(ssf[idx_im1], t1);
        ssf[idx] = ssf_i;

        if i < period {
            ssf_sum += ssf_i;
        } else {
            let idx_ip = if idx + 1 == ring_len { 0 } else { idx + 1 };
            let ssf_ip = ssf[idx_ip];

            let mean_lp = ssf_sum * inv_p;
            let my_sum = ssf_i.mul_add(beta, ssf_ip * alpha) - mean_lp;

            ms = 0.96_f64.mul_add(ms, 0.04_f64 * (my_sum * my_sum));
            // Always write the slot: callers may hand in buffers whose tail is not yet
            // initialised (e.g. from `make_uninit_matrix`).
            dst[i] = if ms > 0.0 { my_sum / ms.sqrt() } else { 0.0 };

            ssf_sum += ssf_i - ssf_ip;
        }

        idx_im2 = idx_im1;
        idx_im1 = idx;
        idx = if idx + 1 == ring_len { 0 } else { idx + 1 };
    }
}
318
/// AVX2 entry point for the Reflex kernel.
///
/// Currently forwards to `reflex_scalar`: the smoother recursion depends on the two
/// previous outputs, so there is no lane-parallel variant. The `target_feature` attribute
/// only allows the compiler to use AVX2/FMA encodings for the forwarded code.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn reflex_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    reflex_scalar(data, period, first, out);
}

/// AVX-512 entry point for the Reflex kernel; forwards to `reflex_scalar` like the AVX2
/// variant.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX-512F, AVX-512DQ and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn reflex_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    reflex_scalar(data, period, first, out);
}
330
331#[inline(always)]
332fn reflex_prepare<'a>(
333 input: &'a ReflexInput,
334 kernel: Kernel,
335) -> Result<(&'a [f64], usize, usize, Kernel), ReflexError> {
336 let data: &[f64] = match &input.data {
337 ReflexData::Candles { candles, source } => source_type(candles, source),
338 ReflexData::Slice(sl) => sl,
339 };
340
341 let len = data.len();
342 if len == 0 {
343 return Err(ReflexError::EmptyInputData);
344 }
345
346 let first = data
347 .iter()
348 .position(|x| !x.is_nan())
349 .ok_or(ReflexError::AllValuesNaN)?;
350 let period = input.get_period();
351
352 if period < 2 {
353 return Err(ReflexError::InvalidPeriod {
354 period,
355 data_len: len,
356 });
357 }
358 if period > (len - first) {
359 return Err(ReflexError::NotEnoughValidData {
360 needed: period,
361 valid: len - first,
362 });
363 }
364
365 let chosen = match kernel {
366 Kernel::Auto => Kernel::Scalar,
367 other => other,
368 };
369
370 Ok((data, period, first, chosen))
371}
372
373#[inline(always)]
374fn reflex_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
375 unsafe {
376 match kernel {
377 Kernel::Scalar | Kernel::ScalarBatch => reflex_scalar(data, period, first, out),
378 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
379 Kernel::Avx2 | Kernel::Avx2Batch => reflex_avx2(data, period, first, out),
380 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
381 Kernel::Avx2 | Kernel::Avx2Batch => reflex_scalar(data, period, first, out),
382 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
383 Kernel::Avx512 | Kernel::Avx512Batch => reflex_avx512(data, period, first, out),
384 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
385 Kernel::Avx512 | Kernel::Avx512Batch => reflex_scalar(data, period, first, out),
386 _ => unreachable!(),
387 }
388 }
389}
390
/// Incremental (one sample per call) Reflex calculator.
///
/// Uses the same smoother and normalisation recurrence as `reflex_scalar`. `update`
/// returns `None` for the first `period` samples.
#[derive(Debug, Clone)]
pub struct ReflexStream {
    /// Look-back length (>= 2, enforced by `try_new`).
    period: usize,

    // Coefficients of the two-pole smoothing filter.
    a_sq: f64,
    b: f64,
    c: f64,

    // Weights of the smoothed value `period` samples back (`alpha`) and the current one
    // (`beta`) in the line term, plus `1 / period`.
    alpha: f64,
    beta: f64,
    inv_p: f64,

    /// Ring buffer of the last `period + 1` smoothed values.
    ssf_buf: Vec<f64>,
    /// Next slot of `ssf_buf` to write.
    head: usize,
    /// Slot holding the smoothed value from `period` samples ago (once warm).
    tail: usize,

    /// Running sum of the most recent (up to `period`) smoothed values before the current one.
    ssf_sum: f64,
    /// Exponential average (0.96 / 0.04) of the squared deviation.
    last_ms: f64,
    /// Previous raw input.
    prev_x: f64,
    /// Smoothed values one and two samples back.
    last_ssf1: f64,
    last_ssf2: f64,
    /// Number of samples consumed so far.
    count: usize,
}
414
415impl ReflexStream {
416 #[inline]
417 pub fn try_new(params: ReflexParams) -> Result<Self, ReflexError> {
418 let period = params.period.unwrap_or(20);
419 if period < 2 {
420 return Err(ReflexError::InvalidPeriod {
421 period,
422 data_len: 0,
423 });
424 }
425
426 let half_p = (period / 2).max(1) as f64;
427 let a = (-1.414_f64 * std::f64::consts::PI / half_p).exp();
428 let a_sq = a * a;
429 let b = 2.0 * a * (1.414_f64 * std::f64::consts::PI / half_p).cos();
430 let c = 0.5 * (1.0 + a_sq - b);
431
432 let inv_p = 1.0 / (period as f64);
433 let alpha = 0.5 * (1.0 + inv_p);
434 let beta = 1.0 - alpha;
435
436 Ok(Self {
437 period,
438 a_sq,
439 b,
440 c,
441 alpha,
442 beta,
443 inv_p,
444
445 ssf_buf: vec![0.0; period + 1],
446 head: 0,
447 tail: 0,
448
449 ssf_sum: 0.0,
450 last_ms: 0.0,
451 prev_x: 0.0,
452 last_ssf1: 0.0,
453 last_ssf2: 0.0,
454 count: 0,
455 })
456 }
457
458 #[inline(always)]
459 pub fn update(&mut self, x: f64) -> Option<f64> {
460 let p = self.period;
461 let ring_len = p + 1;
462 let t = self.count;
463
464 if t == 0 {
465 self.prev_x = x;
466 self.last_ssf1 = x;
467 self.ssf_buf[self.head] = x;
468 self.head += 1;
469 if self.head == ring_len {
470 self.head = 0;
471 }
472 self.ssf_sum += x;
473 self.count = 1;
474 return None;
475 }
476 if t == 1 {
477 self.prev_x = x;
478 self.last_ssf2 = self.last_ssf1;
479 self.last_ssf1 = x;
480 self.ssf_buf[self.head] = x;
481 self.head += 1;
482 if self.head == ring_len {
483 self.head = 0;
484 }
485 self.ssf_sum += x;
486 self.count = 2;
487 return None;
488 }
489
490 let t0 = self.c * (x + self.prev_x);
491 let t1 = (-self.a_sq).mul_add(self.last_ssf2, t0);
492 let ssf_t = self.b.mul_add(self.last_ssf1, t1);
493
494 let mut out = None;
495 if t >= p {
496 let ssf_tp = self.ssf_buf[self.tail];
497
498 let mean_lp = self.ssf_sum * self.inv_p;
499
500 let my_sum = self.beta.mul_add(ssf_t, self.alpha * ssf_tp) - mean_lp;
501
502 let ms = 0.96_f64.mul_add(self.last_ms, 0.04_f64 * (my_sum * my_sum));
503 self.last_ms = ms;
504 out = if ms > 0.0 {
505 Some(my_sum / ms.sqrt())
506 } else {
507 Some(0.0)
508 };
509
510 self.ssf_sum += ssf_t - ssf_tp;
511 self.tail += 1;
512 if self.tail == ring_len {
513 self.tail = 0;
514 }
515 } else {
516 self.ssf_sum += ssf_t;
517 }
518
519 self.ssf_buf[self.head] = ssf_t;
520 self.head += 1;
521 if self.head == ring_len {
522 self.head = 0;
523 }
524
525 self.prev_x = x;
526 self.last_ssf2 = self.last_ssf1;
527 self.last_ssf1 = ssf_t;
528 self.count = t + 1;
529
530 out
531 }
532}
533
/// Period sweep for batch evaluation, given as `(start, end, step)`.
///
/// `step == 0` or `start == end` yields the single period `start`; `start > end` sweeps
/// downwards (see `expand_grid_checked`).
#[derive(Clone, Debug)]
pub struct ReflexBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for ReflexBatchRange {
    // Periods 20 through 269 inclusive, step 1.
    fn default() -> Self {
        Self {
            period: (20, 269, 1),
        }
    }
}

/// Fluent front-end for batch (multi-period) Reflex computations.
#[derive(Clone, Debug, Default)]
pub struct ReflexBatchBuilder {
    range: ReflexBatchRange,
    kernel: Kernel,
}
552
553impl ReflexBatchBuilder {
554 pub fn new() -> Self {
555 Self::default()
556 }
557 pub fn kernel(mut self, k: Kernel) -> Self {
558 self.kernel = k;
559 self
560 }
561 #[inline]
562 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
563 self.range.period = (start, end, step);
564 self
565 }
566 #[inline]
567 pub fn period_static(mut self, p: usize) -> Self {
568 self.range.period = (p, p, 0);
569 self
570 }
571 pub fn apply_slice(self, data: &[f64]) -> Result<ReflexBatchOutput, ReflexError> {
572 reflex_batch_with_kernel(data, &self.range, self.kernel)
573 }
574 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ReflexBatchOutput, ReflexError> {
575 ReflexBatchBuilder::new().kernel(k).apply_slice(data)
576 }
577 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ReflexBatchOutput, ReflexError> {
578 let slice = source_type(c, src);
579 self.apply_slice(slice)
580 }
581 pub fn with_default_candles(c: &Candles) -> Result<ReflexBatchOutput, ReflexError> {
582 ReflexBatchBuilder::new()
583 .kernel(Kernel::Auto)
584 .apply_candles(c, "close")
585 }
586}
587
588pub fn reflex_batch_with_kernel(
589 data: &[f64],
590 sweep: &ReflexBatchRange,
591 k: Kernel,
592) -> Result<ReflexBatchOutput, ReflexError> {
593 let kernel = match k {
594 Kernel::Auto => detect_best_batch_kernel(),
595 other if other.is_batch() => other,
596 other => return Err(ReflexError::InvalidKernelForBatch(other)),
597 };
598 let simd = match kernel {
599 Kernel::Avx512Batch => Kernel::Avx512,
600 Kernel::Avx2Batch => Kernel::Avx2,
601 Kernel::ScalarBatch => Kernel::Scalar,
602 _ => unreachable!(),
603 };
604 reflex_batch_par_slice(data, sweep, simd)
605}
606
/// Result of a batch Reflex sweep.
#[derive(Clone, Debug)]
pub struct ReflexBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` holds the output for `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter set of each row.
    pub combos: Vec<ReflexParams>,
    /// Number of rows (one per period).
    pub rows: usize,
    /// Number of columns (the input length).
    pub cols: usize,
}
614
615impl ReflexBatchOutput {
616 pub fn row_for_params(&self, p: &ReflexParams) -> Option<usize> {
617 self.combos
618 .iter()
619 .position(|c| c.period.unwrap_or(20) == p.period.unwrap_or(20))
620 }
621 pub fn values_for(&self, p: &ReflexParams) -> Option<&[f64]> {
622 self.row_for_params(p).map(|row| {
623 let start = row * self.cols;
624 &self.values[start..start + self.cols]
625 })
626 }
627}
628
629#[inline(always)]
630fn expand_grid_checked(r: &ReflexBatchRange) -> Result<Vec<ReflexParams>, ReflexError> {
631 fn axis_usize(range: (usize, usize, usize)) -> Result<Vec<usize>, ReflexError> {
632 let (start, end, step) = range;
633 if step == 0 || start == end {
634 return Ok(vec![start]);
635 }
636 let mut out = Vec::new();
637 if start < end {
638 let mut cur = start;
639 while cur <= end {
640 out.push(cur);
641 cur = match cur.checked_add(step) {
642 Some(v) => v,
643 None => break,
644 };
645 }
646 } else {
647 let mut cur = start;
648 while cur >= end {
649 out.push(cur);
650 cur = match cur.checked_sub(step) {
651 Some(v) => v,
652 None => break,
653 };
654 if cur == 0 {
655 break;
656 }
657 }
658 }
659 if out.is_empty() {
660 return Err(ReflexError::InvalidRange { start, end, step });
661 }
662 Ok(out)
663 }
664 let periods = axis_usize(r.period)?;
665 Ok(periods
666 .into_iter()
667 .map(|p| ReflexParams { period: Some(p) })
668 .collect())
669}
670
671#[inline(always)]
672pub fn reflex_batch_slice(
673 data: &[f64],
674 sweep: &ReflexBatchRange,
675 kern: Kernel,
676) -> Result<ReflexBatchOutput, ReflexError> {
677 reflex_batch_inner(data, sweep, kern, false)
678}
679
680#[inline(always)]
681pub fn reflex_batch_par_slice(
682 data: &[f64],
683 sweep: &ReflexBatchRange,
684 kern: Kernel,
685) -> Result<ReflexBatchOutput, ReflexError> {
686 reflex_batch_inner(data, sweep, kern, true)
687}
688
689#[inline(always)]
690fn reflex_batch_inner(
691 data: &[f64],
692 sweep: &ReflexBatchRange,
693 kern: Kernel,
694 parallel: bool,
695) -> Result<ReflexBatchOutput, ReflexError> {
696 let combos = expand_grid_checked(sweep)?;
697 let first = data
698 .iter()
699 .position(|x| !x.is_nan())
700 .ok_or(ReflexError::AllValuesNaN)?;
701 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
702 if data.len() - first < max_p {
703 return Err(ReflexError::NotEnoughValidData {
704 needed: max_p,
705 valid: data.len() - first,
706 });
707 }
708
709 let rows = combos.len();
710 let cols = data.len();
711
712 let _total = rows.checked_mul(cols).ok_or(ReflexError::InvalidRange {
713 start: rows,
714 end: cols,
715 step: 0,
716 })?;
717
718 let mut buf_mu = make_uninit_matrix(rows, cols);
719 let warm: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
720 init_matrix_prefixes(&mut buf_mu, cols, &warm);
721
722 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
723 let out: &mut [f64] =
724 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
725
726 let kernel = match kern {
727 Kernel::Auto => detect_best_batch_kernel(),
728 other => other,
729 };
730
731 let simd = match kernel {
732 Kernel::Avx512Batch => Kernel::Avx512,
733 Kernel::Avx2Batch => Kernel::Avx2,
734 Kernel::ScalarBatch => Kernel::Scalar,
735 other => other,
736 };
737
738 let meta = reflex_batch_inner_into(data, sweep, simd, parallel, out)?;
739
740 let values = unsafe {
741 Vec::from_raw_parts(
742 guard.as_mut_ptr() as *mut f64,
743 guard.len(),
744 guard.capacity(),
745 )
746 };
747
748 Ok(ReflexBatchOutput {
749 values,
750 combos: meta.combos,
751 rows: meta.rows,
752 cols: meta.cols,
753 })
754}
755
756#[inline(always)]
757unsafe fn reflex_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
758 reflex_scalar(data, period, first, out)
759}
760
761#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
762#[inline(always)]
763unsafe fn reflex_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
764 reflex_avx2(data, period, first, out)
765}
766
767#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
768#[inline(always)]
769unsafe fn reflex_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
770 reflex_avx512(data, period, first, out)
771}
772
773#[inline(always)]
774fn reflex_batch_inner_into(
775 data: &[f64],
776 sweep: &ReflexBatchRange,
777 kern: Kernel,
778 parallel: bool,
779 out: &mut [f64],
780) -> Result<ReflexBatchMetadata, ReflexError> {
781 let combos = expand_grid_checked(sweep)?;
782
783 let first = data
784 .iter()
785 .position(|x| !x.is_nan())
786 .ok_or(ReflexError::AllValuesNaN)?;
787 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
788 if data.len() - first < max_p {
789 return Err(ReflexError::NotEnoughValidData {
790 needed: max_p,
791 valid: data.len() - first,
792 });
793 }
794
795 let rows = combos.len();
796 let cols = data.len();
797
798 let expected = rows.checked_mul(cols).ok_or(ReflexError::InvalidRange {
799 start: rows,
800 end: cols,
801 step: 0,
802 })?;
803 if out.len() != expected {
804 return Err(ReflexError::OutputLengthMismatch {
805 expected,
806 got: out.len(),
807 });
808 }
809
810 let do_row = |row: usize, dst: &mut [f64]| unsafe {
811 let period = combos[row].period.unwrap();
812
813 for x in &mut dst[..period.min(cols)] {
814 *x = 0.0;
815 }
816
817 match kern {
818 Kernel::Scalar => reflex_row_scalar(data, first, period, dst),
819 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
820 Kernel::Avx2 => reflex_row_avx2(data, first, period, dst),
821 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
822 Kernel::Avx2 => reflex_row_scalar(data, first, period, dst),
823 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
824 Kernel::Avx512 => reflex_row_avx512(data, first, period, dst),
825 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
826 Kernel::Avx512 => reflex_row_scalar(data, first, period, dst),
827 _ => unreachable!(),
828 }
829 };
830
831 if parallel {
832 #[cfg(not(target_arch = "wasm32"))]
833 {
834 out.par_chunks_mut(cols)
835 .enumerate()
836 .for_each(|(row, slice)| do_row(row, slice));
837 }
838 #[cfg(target_arch = "wasm32")]
839 {
840 for (row, slice) in out.chunks_mut(cols).enumerate() {
841 do_row(row, slice);
842 }
843 }
844 } else {
845 for (row, slice) in out.chunks_mut(cols).enumerate() {
846 do_row(row, slice);
847 }
848 }
849
850 Ok(ReflexBatchMetadata { combos, rows, cols })
851}
852
/// Shape and parameters of a batch result written into a caller-provided buffer.
#[derive(Clone, Debug)]
pub struct ReflexBatchMetadata {
    /// Parameter set of each row, in row order.
    pub combos: Vec<ReflexParams>,
    /// Number of rows (one per period).
    pub rows: usize,
    /// Number of columns (the input length).
    pub cols: usize,
}
859
/// Compute Reflex indicator.
///
/// Parameters
/// ----------
/// data : numpy.ndarray
///     Input data array
/// period : int, default=20
///     Period for the indicator (must be >= 2)
/// kernel : str, optional
///     Kernel to use:
///     - 'auto' or None: Auto-detect best kernel (default)
///     - 'scalar': Use scalar implementation
///     - 'avx2': Use AVX2 implementation (if available)
///     - 'avx512': Use AVX512 implementation (if available)
///
/// Returns
/// -------
/// numpy.ndarray
///     Reflex values
// The text above is a `///` doc comment so pyo3 exposes it as `__doc__`; a bare string
// literal statement inside the body is discarded by the compiler.
#[cfg(feature = "python")]
#[pyfunction(name = "reflex")]
#[pyo3(signature = (data, period = 20, kernel = None), text_signature = "(data, period=20, kernel=None)")]
pub fn reflex_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let data_slice = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let params = ReflexParams {
        period: Some(period),
    };
    let input = ReflexInput::from_slice(data_slice, params);

    // Release the GIL while computing; errors become Python `ValueError`s.
    let result_vec: Vec<f64> = py
        .allow_threads(|| reflex_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
906
/// Compute Reflex indicator for multiple periods.
///
/// Parameters
/// ----------
/// data : numpy.ndarray
///     Input data array
/// periods : tuple of int
///     (start, end, step) for period range
/// kernel : str, optional
///     Kernel to use (see reflex() for options)
///
/// Returns
/// -------
/// dict
///     Dictionary with 'values' (2D array) and 'periods' (list)
// The text above is a `///` doc comment so pyo3 exposes it as `__doc__`; a bare string
// literal statement inside the body is discarded by the compiler.
#[cfg(feature = "python")]
#[pyfunction(name = "reflex_batch")]
#[pyo3(signature = (data, periods, kernel = None), text_signature = "(data, periods, kernel=None)")]
pub fn reflex_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    periods: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Py<PyDict>> {
    let data_slice = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let range = ReflexBatchRange { period: periods };

    let combos = expand_grid_checked(&range)
        .map_err(|e| PyValueError::new_err(format!("reflex batch error: {}", e)))?;
    let rows = combos.len();
    let cols = data_slice.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // The array is created uninitialised; `reflex_batch_inner_into` writes every row or
    // returns an error before writing anything.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let metadata = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };

            reflex_batch_inner_into(data_slice, &range, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(format!("reflex batch error: {}", e)))?;

    let dict = PyDict::new(py);

    let reshaped = out_arr.reshape([rows, cols])?;
    dict.set_item("values", reshaped)?;

    dict.set_item(
        "periods",
        metadata
            .combos
            .iter()
            .map(|c| c.period.unwrap_or(20) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict.into())
}
988
/// Runs the Reflex period sweep on CUDA device `device_id` and returns the
/// device-resident result, keeping the CUDA context alive alongside it.
///
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "reflex_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn reflex_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = ReflexBatchRange {
        period: period_range,
    };

    let (inner, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaReflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let inner = cuda
            .reflex_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((inner, ctx))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(device_id as u32),
    })
}
1021
/// Runs Reflex with a single `period` over many series on CUDA device `device_id`.
///
/// `data_tm_f32` is presumably time-major (rows are time steps, columns are series;
/// confirm against `CudaReflex`). Its shape is passed to the CUDA wrapper as
/// `(cols, rows)`. Raises `ValueError` when CUDA is unavailable, the array is not 2-D,
/// or the CUDA wrapper fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "reflex_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn reflex_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match *data_tm_f32.shape() {
        [r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;

    let (inner, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaReflex::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let inner = cuda
            .reflex_many_series_one_param_time_major_dev(flat, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((inner, ctx))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(device_id as u32),
    })
}
1059
/// Python wrapper (exposed as `ReflexStream`) around the incremental `ReflexStream`.
#[cfg(feature = "python")]
#[pyclass(name = "ReflexStream")]
pub struct ReflexStreamPy {
    inner: ReflexStream,
}
1065
#[cfg(feature = "python")]
#[pymethods]
impl ReflexStreamPy {
    /// Creates a streaming Reflex calculator; raises `ValueError` for `period < 2`.
    #[new]
    #[pyo3(signature = (period = 20))]
    pub fn new(period: usize) -> PyResult<Self> {
        ReflexStream::try_new(ReflexParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(format!("reflex stream error: {}", e)))
    }

    /// Feeds one value; returns `None` for the first `period` updates.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1084
/// WASM binding: Reflex of `data` with the given `period`, returned as a new array.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ReflexInput::from_slice(
        data,
        ReflexParams {
            period: Some(period),
        },
    );
    let mut output = vec![0.0; data.len()];

    match reflex_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1100
/// WASM binding: flattened row-major batch output for the period sweep
/// `period_start..=period_end` by `period_step`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    reflex_batch_with_kernel(data, &range, Kernel::Auto)
        .map(|batch| batch.values)
        .map_err(|e| JsValue::from_str(&format!("reflex batch error: {}", e)))
}
1118
/// WASM binding: the periods a sweep expands to (empty if the sweep is invalid).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let range = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };
    expand_grid_checked(&range)
        .map(|combos| combos.iter().map(|c| c.period.unwrap_or(20)).collect())
        .unwrap_or_default()
}
1134
/// WASM binding: `[rows, cols]` of the batch output for a sweep and input length
/// (`rows` is 0 for an invalid sweep).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let range = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = match expand_grid_checked(&range) {
        Ok(combos) => combos.len(),
        Err(_) => 0,
    };
    vec![rows, data_len]
}
1149
/// WASM binding: allocates room for `len` f64 values and hands ownership to the caller.
/// Release it with `reflex_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_alloc(len: usize) -> *mut f64 {
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}

/// WASM binding: frees a buffer obtained from `reflex_alloc(len)`; null is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `reflex_alloc(len)`, i.e. a `Vec<f64>` with capacity `len`.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1168
/// WASM binding: writes Reflex of `len` values at `in_ptr` into `out_ptr`.
///
/// In-place use (`in_ptr == out_ptr`) goes through a scratch buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let input = ReflexInput::from_slice(
        data,
        ReflexParams {
            period: Some(period),
        },
    );
    let to_js = |e: ReflexError| JsValue::from_str(&e.to_string());

    if in_ptr == out_ptr as *const f64 {
        let mut scratch = vec![0.0; len];
        reflex_into_slice(&mut scratch, &input, Kernel::Auto).map_err(to_js)?;
        // SAFETY: the caller guarantees `out_ptr` points to `len` writable f64 values.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above; the buffers are distinct.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        reflex_into_slice(out, &input, Kernel::Auto).map_err(to_js)?;
    }
    Ok(())
}
1207
/// Computes a batch of Reflex rows, one per period in the sweep, over `len` values at
/// `in_ptr`.
///
/// Results are written row-major into `out_ptr`, which must hold `rows * len` f64s.
/// Returns the number of rows written.
///
/// If the input buffer overlaps the output buffer, rows are first computed into a
/// temporary buffer and copied out. This keeps the immutable input slice and the mutable
/// output slice from covering the same memory.
///
/// # Errors
/// - `"null pointer"` if either pointer is null.
/// - `"reflex batch error: ..."` if the period sweep is invalid.
/// - `"size overflow"` if `rows * len` overflows `usize`.
/// - The indicator's own error message if the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn reflex_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    let sweep = ReflexBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_checked(&sweep)
        .map_err(|e| JsValue::from_str(&format!("reflex batch error: {}", e)))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Check whether [in, in+len) intersects [out, out+total).
    let out_const = out_ptr as *const f64;
    let overlaps =
        in_ptr < out_const.wrapping_add(total) && out_const < in_ptr.wrapping_add(len);

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable, initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let kernel = detect_best_kernel();

    if overlaps {
        let mut temp = vec![0.0; total];
        reflex_batch_inner_into(data, &sweep, kernel, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` points to `total` writable f64s. `data` is
        // no longer read, so the mutable view does not conflict with it.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` points to `total` writable f64s, and the
        // range check above proves it is disjoint from the input slice.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        reflex_batch_inner_into(data, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1239
/// Unit, streaming, property and batch tests for the Reflex indicator.
///
/// Each `check_*` helper takes a test name and a `Kernel`. The generator macros below
/// instantiate one `#[test]` per helper and kernel. Generated tests return the helper's
/// `Result`, so errors propagated with `?` fail the test instead of being discarded.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Runs Reflex on the CSV fixture with three configurations:
    /// - default params on `close`;
    /// - period 14 on `hl2`;
    /// - period 30 on `hlc3`.
    ///
    /// Checks that each output is as long as the candle series.
    fn check_reflex_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = ReflexParams { period: None };
        let input = ReflexInput::from_candles(&candles, "close", default_params);
        let output = reflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        let params_period_14 = ReflexParams { period: Some(14) };
        let input2 = ReflexInput::from_candles(&candles, "hl2", params_period_14);
        let output2 = reflex_with_kernel(&input2, kernel)?;
        assert_eq!(output2.values.len(), candles.close.len());
        let params_custom = ReflexParams { period: Some(30) };
        let input3 = ReflexInput::from_candles(&candles, "hlc3", params_custom);
        let output3 = reflex_with_kernel(&input3, kernel)?;
        assert_eq!(output3.values.len(), candles.close.len());
        Ok(())
    }

    /// Compares the last five default-parameter outputs on `close` against reference values
    /// to within 1e-7.
    fn check_reflex_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = ReflexParams::default();
        let input = ReflexInput::from_candles(&candles, "close", default_params);
        let result = reflex_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        let len = result.values.len();
        let expected_last_five = [
            0.8085220962465361,
            0.445264715886137,
            0.13861699036615063,
            -0.03598639652007061,
            -0.224906760543743,
        ];
        let start_idx = len - 5;
        let last_five = &result.values[start_idx..];
        for (i, &val) in last_five.iter().enumerate() {
            let exp = expected_last_five[i];
            assert!(
                (val - exp).abs() < 1e-7,
                "[{}] Reflex mismatch at idx {}: got {}, expected {}",
                test_name,
                i,
                val,
                exp
            );
        }
        Ok(())
    }

    /// Checks that `with_default_candles` selects the `close` source and produces a
    /// full-length output.
    fn check_reflex_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = ReflexInput::with_default_candles(&candles);
        match input.data {
            ReflexData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected ReflexData::Candles"),
        }
        let output = reflex_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A period of 0 must be rejected.
    fn check_reflex_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ReflexParams { period: Some(0) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period of 1 must be rejected (the indicator needs period >= 2).
    fn check_reflex_period_less_than_two(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ReflexParams { period: Some(1) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with period<2",
            test_name
        );
        Ok(())
    }

    /// A single data point with period 2 must be rejected as insufficient data.
    fn check_reflex_very_small_data_set(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [42.0];
        let params = ReflexParams { period: Some(2) };
        let input = ReflexInput::from_slice(&input_data, params);
        let res = reflex_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] Reflex should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeds a period-14 Reflex output back in as input with period 10. Checks that the
    /// second pass is full-length and finite from index 14 onward.
    fn check_reflex_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = ReflexParams { period: Some(14) };
        let first_input = ReflexInput::from_candles(&candles, "close", first_params);
        let first_result = reflex_with_kernel(&first_input, kernel)?;
        assert_eq!(first_result.values.len(), candles.close.len());
        let second_params = ReflexParams { period: Some(10) };
        let second_input = ReflexInput::from_slice(&first_result.values, second_params);
        let second_result = reflex_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 14..second_result.values.len() {
            assert!(second_result.values[i].is_finite());
        }
        Ok(())
    }

    /// Checks that every output from index `period` onward is finite for the CSV fixture.
    fn check_reflex_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 14;
        let params = ReflexParams {
            period: Some(period),
        };
        let input = ReflexInput::from_candles(&candles, "close", params);
        let result = reflex_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        if result.values.len() > period {
            for i in period..result.values.len() {
                assert!(
                    result.values[i].is_finite(),
                    "[{}] Unexpected NaN at index {}",
                    test_name,
                    i
                );
            }
        }
        Ok(())
    }

    /// Checks that `ReflexStream` reproduces the batch output to within 1e-9.
    ///
    /// Warm-up samples, where the stream yields `None`, are recorded as 0.0.
    fn check_reflex_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 14;
        let params = ReflexParams {
            period: Some(period),
        };
        // `ReflexParams` is `Copy`, so it can be used again below without cloning.
        let input = ReflexInput::from_candles(&candles, "close", params);
        let batch_output = reflex_with_kernel(&input, kernel)?.values;
        let mut stream = ReflexStream::try_new(params)?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(v) => stream_values.push(v),
                None => stream_values.push(0.0),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] Reflex streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Generates scalar, AVX2 and AVX-512 `#[test]`s for each listed check function.
    ///
    /// The `cfg` gate sits on every AVX test fn. An attribute placed before a `$(...)*`
    /// repetition would only attach to the first generated item. Each test returns the
    /// check's `Result`, so `?`-propagated errors fail the test.
    macro_rules! generate_all_reflex_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    /// Debug builds only: scans outputs for several periods for the bit patterns
    /// `0x1111…`, `0x2222…` and `0x3333…`. The panic messages attribute these to
    /// `alloc_with_nan_prefix`, `init_matrix_prefixes` and `make_uninit_matrix`
    /// respectively. Finding one would mean an output slot was never written.
    #[cfg(debug_assertions)]
    fn check_reflex_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_cases = vec![
            ReflexParams { period: Some(20) },
            ReflexParams { period: Some(2) },
            ReflexParams { period: Some(5) },
            ReflexParams { period: Some(10) },
            ReflexParams { period: Some(30) },
            ReflexParams { period: Some(50) },
            ReflexParams { period: Some(15) },
            ReflexParams { period: Some(40) },
            ReflexParams { period: None },
        ];

        for params in test_cases {
            let input = ReflexInput::from_candles(&candles, "close", params);
            let output = reflex_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_reflex_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test over random finite data (|x| < 1e6) and periods 2..=50. It checks:
    /// - the warm-up prefix is exactly 0.0;
    /// - `kernel` agrees with the scalar kernel (1e-9 absolute or 4 ULP);
    /// - post-warm-up outputs are finite;
    /// - constant input yields near-zero output after `2 * period`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_reflex_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = ReflexParams {
                    period: Some(period),
                };
                let input = ReflexInput::from_slice(&data, params);

                let ReflexOutput { values: out } = reflex_with_kernel(&input, kernel).unwrap();
                let ReflexOutput { values: ref_out } =
                    reflex_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());

                for i in 0..period.min(data.len()) {
                    prop_assert!(
                        out[i] == 0.0,
                        "[{}] idx {}: expected 0.0 during warmup, got {}",
                        test_name,
                        i,
                        out[i]
                    );
                }

                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "[{}] finite/NaN mismatch idx {}: {} vs {}",
                            test_name,
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "[{}] mismatch idx {}: {} vs {} (ULP={})",
                        test_name,
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                for i in period..data.len() {
                    if data[i].abs() < 1e10 {
                        prop_assert!(
                            out[i].is_finite(),
                            "[{}] idx {}: expected finite, got {}",
                            test_name,
                            i,
                            out[i]
                        );
                    }
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
                    for i in (period * 2)..data.len() {
                        prop_assert!(
                            out[i].abs() < 0.001,
                            "[{}] idx {}: constant data should yield near-zero, got {}",
                            test_name,
                            i,
                            out[i]
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Without the `proptest` feature the property test only performs the kernel-support
    /// check.
    #[cfg(not(feature = "proptest"))]
    fn check_reflex_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }

    generate_all_reflex_tests!(
        check_reflex_partial_params,
        check_reflex_accuracy,
        check_reflex_default_candles,
        check_reflex_zero_period,
        check_reflex_period_less_than_two,
        check_reflex_very_small_data_set,
        check_reflex_reinput,
        check_reflex_nan_handling,
        check_reflex_streaming,
        check_reflex_no_poison,
        check_reflex_property
    );

    /// Checks that the batch builder's default row matches the single-run reference
    /// values for its last five entries.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = ReflexBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = ReflexParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        let expected = [
            0.8085220962465361,
            0.445264715886137,
            0.13861699036615063,
            -0.03598639652007061,
            -0.224906760543743,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            // Report the single expected value for this index, not the whole array.
            assert!(
                (v - expected[i]).abs() < 1e-7,
                "[{test}] default-row mismatch at idx {i}: {v} vs {}",
                expected[i]
            );
        }
        Ok(())
    }

    /// Generates scalar, AVX2, AVX-512 and auto-detect batch `#[test]`s for a check function.
    ///
    /// Each test returns the check's `Result`, so `?`-propagated errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    /// Debug builds only: runs several period sweeps and scans every cell of the batch
    /// matrix for the same three poison bit patterns. The row and column of any hit are
    /// reported.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let batch_configs = vec![
            (10, 30, 10),
            (20, 20, 0),
            (2, 10, 2),
            (25, 50, 25),
            (5, 20, 5),
            (15, 45, 15),
            (3, 15, 3),
            (30, 60, 10),
        ];

        for (p_start, p_end, p_step) in batch_configs {
            let output = ReflexBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the batch poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Checks that the native `reflex_into` writes the same values as `reflex_with_kernel`.
    ///
    /// The input has a leading NaN prefix. NaN is treated as equal to NaN, and other values
    /// may differ by at most 1e-12.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_reflex_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = vec![f64::from_bits(0x7ff8_0000_0000_0000); 3];
        data.extend((0..256).map(|i| ((i as f64) * 0.1).sin() * 1.23 + (i as f64) * 0.01));

        let input = ReflexInput::from_slice(&data, ReflexParams::default());

        let baseline = reflex_with_kernel(&input, Kernel::Auto)?.values;

        let mut out = vec![0.0; data.len()];
        super::reflex_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "mismatch at {}: baseline={} out={}",
                i,
                baseline[i],
                out[i]
            );
        }

        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}