1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::moving_averages::ma::{ma, MaData};
16use crate::utilities::data_loader::{source_type, Candles};
17use crate::utilities::enums::Kernel;
18use crate::utilities::helpers::{
19 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
20 make_uninit_matrix,
21};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::convert::AsRef;
29use thiserror::Error;
30
/// Input data for the Elder Ray Index (ERI).
///
/// `Candles` reads the `high`/`low` fields from a candle set and selects the moving-average
/// source series with `source_type`; `Slices` borrows three caller-provided series.
#[derive(Debug, Clone)]
pub enum EriData<'a> {
    /// Candle data; `source` names the field fed to the moving average (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Raw series. The ERI routines index all three at the same positions, so they are
    /// expected to be the same length.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        source: &'a [f64],
    },
}
43
44impl<'a> AsRef<[f64]> for EriInput<'a> {
45 #[inline(always)]
46 fn as_ref(&self) -> &[f64] {
47 match &self.data {
48 EriData::Candles { candles, source } => source_type(candles, source),
49 EriData::Slices { source, .. } => source,
50 }
51 }
52}
53
/// Result of an ERI computation: one bull and one bear value per input bar.
#[derive(Debug, Clone)]
pub struct EriOutput {
    /// Bull power, `high - MA(source)`. The warmup prefix is produced by the allocator
    /// (`alloc_with_nan_prefix`), i.e. NaN.
    pub bull: Vec<f64>,
    /// Bear power, `low - MA(source)`, with the same warmup prefix as `bull`.
    pub bear: Vec<f64>,
}

/// Parameters for the ERI.
#[derive(Debug, Clone)]
pub struct EriParams {
    /// Moving-average lookback; `None` falls back to 13.
    pub period: Option<usize>,
    /// Moving-average type name (e.g. `"ema"`, `"sma"`); `None` falls back to `"ema"`.
    pub ma_type: Option<String>,
}
65
66impl Default for EriParams {
67 fn default() -> Self {
68 Self {
69 period: Some(13),
70 ma_type: Some("ema".to_string()),
71 }
72 }
73}
74
/// Complete ERI input: where the data comes from plus the indicator parameters.
#[derive(Debug, Clone)]
pub struct EriInput<'a> {
    /// Source of the `high`, `low` and moving-average series.
    pub data: EriData<'a>,
    /// Period and MA type; unset fields are resolved by `get_period` / `get_ma_type`.
    pub params: EriParams,
}
80
81impl<'a> EriInput<'a> {
82 #[inline]
83 pub fn from_candles(candles: &'a Candles, source: &'a str, params: EriParams) -> Self {
84 Self {
85 data: EriData::Candles { candles, source },
86 params,
87 }
88 }
89 #[inline]
90 pub fn from_slices(
91 high: &'a [f64],
92 low: &'a [f64],
93 source: &'a [f64],
94 params: EriParams,
95 ) -> Self {
96 Self {
97 data: EriData::Slices { high, low, source },
98 params,
99 }
100 }
101 #[inline]
102 pub fn with_default_candles(candles: &'a Candles) -> Self {
103 Self::from_candles(candles, "close", EriParams::default())
104 }
105 #[inline]
106 pub fn get_period(&self) -> usize {
107 self.params.period.unwrap_or(13)
108 }
109 #[inline]
110 pub fn get_ma_type(&self) -> &str {
111 self.params.ma_type.as_deref().unwrap_or("ema")
112 }
113}
114
115#[derive(Clone, Debug)]
116pub struct EriBuilder {
117 period: Option<usize>,
118 ma_type: Option<String>,
119 kernel: Kernel,
120}
121
122impl Default for EriBuilder {
123 fn default() -> Self {
124 Self {
125 period: None,
126 ma_type: None,
127 kernel: Kernel::Auto,
128 }
129 }
130}
131
132impl EriBuilder {
133 #[inline(always)]
134 pub fn new() -> Self {
135 Self::default()
136 }
137 #[inline(always)]
138 pub fn period(mut self, n: usize) -> Self {
139 self.period = Some(n);
140 self
141 }
142 #[inline(always)]
143 pub fn ma_type<S: Into<String>>(mut self, t: S) -> Self {
144 self.ma_type = Some(t.into());
145 self
146 }
147 #[inline(always)]
148 pub fn kernel(mut self, k: Kernel) -> Self {
149 self.kernel = k;
150 self
151 }
152 #[inline(always)]
153 pub fn apply(self, c: &Candles) -> Result<EriOutput, EriError> {
154 let p = EriParams {
155 period: self.period,
156 ma_type: self.ma_type,
157 };
158 let i = EriInput::from_candles(c, "close", p);
159 eri_with_kernel(&i, self.kernel)
160 }
161 #[inline(always)]
162 pub fn apply_slices(
163 self,
164 high: &[f64],
165 low: &[f64],
166 src: &[f64],
167 ) -> Result<EriOutput, EriError> {
168 let p = EriParams {
169 period: self.period,
170 ma_type: self.ma_type,
171 };
172 let i = EriInput::from_slices(high, low, src, p);
173 eri_with_kernel(&i, self.kernel)
174 }
175 #[inline(always)]
176 pub fn into_stream(self) -> Result<EriStream, EriError> {
177 let p = EriParams {
178 period: self.period,
179 ma_type: self.ma_type,
180 };
181 EriStream::try_new(p)
182 }
183}
184
/// Errors produced by the ERI functions in this module.
#[derive(Debug, Error)]
pub enum EriError {
    /// No index has `high`, `low` and `source` all non-NaN.
    #[error("eri: All input values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or exceeds the input length (the streaming constructor reports
    /// `data_len: 0`).
    #[error("eri: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain from the first valid index onward.
    #[error("eri: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The moving-average computation failed; carries the underlying error message.
    #[error("eri: MA calculation error: {0}")]
    MaCalculationError(String),
    /// An input series is empty, or a required candle field could not be selected.
    #[error("eri: Empty data provided.")]
    EmptyInputData,
    /// A caller-provided output buffer does not have `rows * cols` elements.
    #[error("eri: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// The batch period sweep (or the resulting matrix size) overflowed `usize`.
    #[error("eri: Invalid range expansion: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("eri: Invalid kernel for batch operation. Got {0:?}")]
    InvalidKernelForBatch(Kernel),
}
208
209#[inline]
210pub fn eri(input: &EriInput) -> Result<EriOutput, EriError> {
211 eri_with_kernel(input, Kernel::Auto)
212}
213
214pub fn eri_with_kernel(input: &EriInput, kernel: Kernel) -> Result<EriOutput, EriError> {
215 let (high, low, source_data) = match &input.data {
216 EriData::Candles { candles, source } => {
217 let high = candles
218 .select_candle_field("high")
219 .map_err(|_| EriError::EmptyInputData)?;
220 let low = candles
221 .select_candle_field("low")
222 .map_err(|_| EriError::EmptyInputData)?;
223 let src = source_type(candles, source);
224 (high, low, src)
225 }
226 EriData::Slices { high, low, source } => (*high, *low, *source),
227 };
228
229 if source_data.is_empty() || high.is_empty() || low.is_empty() {
230 return Err(EriError::EmptyInputData);
231 }
232
233 let period = input.get_period();
234 if period == 0 || period > source_data.len() {
235 return Err(EriError::InvalidPeriod {
236 period,
237 data_len: source_data.len(),
238 });
239 }
240
241 let mut first_valid_idx = None;
242 for i in 0..source_data.len() {
243 if !(source_data[i].is_nan() || high[i].is_nan() || low[i].is_nan()) {
244 first_valid_idx = Some(i);
245 break;
246 }
247 }
248 let first_valid_idx = match first_valid_idx {
249 Some(idx) => idx,
250 None => return Err(EriError::AllValuesNaN),
251 };
252
253 if (source_data.len() - first_valid_idx) < period {
254 return Err(EriError::NotEnoughValidData {
255 needed: period,
256 valid: source_data.len() - first_valid_idx,
257 });
258 }
259
260 let ma_type = input.get_ma_type();
261 let warmup_period = first_valid_idx + period - 1;
262 let mut bull = alloc_with_nan_prefix(source_data.len(), warmup_period);
263 let mut bear = alloc_with_nan_prefix(source_data.len(), warmup_period);
264
265 if ma_type == "sma" || ma_type == "SMA" {
266 unsafe {
267 eri_scalar_classic_sma(
268 high,
269 low,
270 &source_data,
271 period,
272 first_valid_idx,
273 &mut bull,
274 &mut bear,
275 )?;
276 }
277 return Ok(EriOutput { bull, bear });
278 } else if ma_type == "ema" || ma_type == "EMA" {
279 unsafe {
280 eri_scalar_classic_ema(
281 high,
282 low,
283 &source_data,
284 period,
285 first_valid_idx,
286 &mut bull,
287 &mut bear,
288 )?;
289 }
290 return Ok(EriOutput { bull, bear });
291 }
292
293 let full_ma = ma(&ma_type, MaData::Slice(&source_data), period)
294 .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
295
296 let chosen = match kernel {
297 Kernel::Auto => Kernel::Scalar,
298 other => other,
299 };
300
301 unsafe {
302 match chosen {
303 Kernel::Scalar | Kernel::ScalarBatch => eri_scalar(
304 high,
305 low,
306 &full_ma,
307 period,
308 first_valid_idx,
309 &mut bull,
310 &mut bear,
311 ),
312 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
313 Kernel::Avx2 | Kernel::Avx2Batch => eri_avx2(
314 high,
315 low,
316 &full_ma,
317 period,
318 first_valid_idx,
319 &mut bull,
320 &mut bear,
321 ),
322 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
323 Kernel::Avx512 | Kernel::Avx512Batch => eri_avx512(
324 high,
325 low,
326 &full_ma,
327 period,
328 first_valid_idx,
329 &mut bull,
330 &mut bear,
331 ),
332 _ => unreachable!(),
333 }
334 }
335
336 Ok(EriOutput { bull, bear })
337}
338
/// Scalar kernel: writes `bull[i] = high[i] - ma[i]` and `bear[i] = low[i] - ma[i]`.
///
/// Writing starts at index `first_valid + period - 1`; earlier entries are left untouched so the
/// caller's warmup prefix survives. Processing stops at the shortest of the five slices, so
/// mismatched lengths never panic or write out of bounds. A `period` of 0 writes nothing.
#[inline]
pub fn eri_scalar(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    if period == 0 {
        return;
    }
    let start = first_valid + period - 1;
    let end = high
        .len()
        .min(low.len())
        .min(ma.len())
        .min(bull.len())
        .min(bear.len());
    if start >= end {
        return;
    }

    // Equal-length sub-slices let the compiler drop per-element bounds checks and vectorise.
    let inputs = high[start..end]
        .iter()
        .zip(&low[start..end])
        .zip(&ma[start..end]);
    let outputs = bull[start..end]
        .iter_mut()
        .zip(bear[start..end].iter_mut());
    for ((b, r), ((&h, &l), &m)) in outputs.zip(inputs) {
        *b = h - m;
        *r = l - m;
    }
}
392
393#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
394#[inline]
395pub fn eri_avx512(
396 high: &[f64],
397 low: &[f64],
398 ma: &[f64],
399 period: usize,
400 first_valid: usize,
401 bull: &mut [f64],
402 bear: &mut [f64],
403) {
404 unsafe { eri_avx512_long(high, low, ma, period, first_valid, bull, bear) }
405}
406
407#[inline]
408pub fn eri_avx2(
409 high: &[f64],
410 low: &[f64],
411 ma: &[f64],
412 period: usize,
413 first_valid: usize,
414 bull: &mut [f64],
415 bear: &mut [f64],
416) {
417 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
418 unsafe {
419 return eri_avx2_core(high, low, ma, period, first_valid, bull, bear);
420 }
421 eri_scalar(high, low, ma, period, first_valid, bull, bear)
422}
423
424#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
425#[inline]
426pub fn eri_avx512_short(
427 high: &[f64],
428 low: &[f64],
429 ma: &[f64],
430 period: usize,
431 first_valid: usize,
432 bull: &mut [f64],
433 bear: &mut [f64],
434) {
435 unsafe { eri_avx512_core(high, low, ma, period, first_valid, bull, bear) }
436}
437
438#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
439#[inline]
440pub fn eri_avx512_long(
441 high: &[f64],
442 low: &[f64],
443 ma: &[f64],
444 period: usize,
445 first_valid: usize,
446 bull: &mut [f64],
447 bear: &mut [f64],
448) {
449 unsafe { eri_avx512_core(high, low, ma, period, first_valid, bull, bear) }
450}
451
/// AVX2 body of the single-series kernel: `bull = high - ma`, `bear = low - ma`.
///
/// Writes indices `first_valid + period - 1 .. n`, where `n` is the shortest of the five slice
/// lengths, so every raw-pointer access below stays inside its slice. A `period` of 0 writes
/// nothing.
///
/// # Safety
/// The CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn eri_avx2_core(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    use core::arch::x86_64::*;

    if period == 0 {
        return;
    }
    let start = first_valid + period - 1;
    // Clamp to the shortest slice: all loads/stores below are unchecked raw-pointer accesses.
    let n = high
        .len()
        .min(low.len())
        .min(ma.len())
        .min(bull.len())
        .min(bear.len());
    if start >= n {
        return;
    }
    let len = n - start;

    let h_ptr = high.as_ptr().add(start);
    let l_ptr = low.as_ptr().add(start);
    let m_ptr = ma.as_ptr().add(start);
    let b_ptr = bull.as_mut_ptr().add(start);
    let r_ptr = bear.as_mut_ptr().add(start);

    let mut k = 0usize;
    while k + 4 <= len {
        let m = _mm256_loadu_pd(m_ptr.add(k));
        let h = _mm256_loadu_pd(h_ptr.add(k));
        let l = _mm256_loadu_pd(l_ptr.add(k));
        _mm256_storeu_pd(b_ptr.add(k), _mm256_sub_pd(h, m));
        _mm256_storeu_pd(r_ptr.add(k), _mm256_sub_pd(l, m));
        k += 4;
    }
    // At most three elements remain.
    while k < len {
        let m = *m_ptr.add(k);
        *b_ptr.add(k) = *h_ptr.add(k) - m;
        *r_ptr.add(k) = *l_ptr.add(k) - m;
        k += 1;
    }
}
524
/// AVX-512 body of the single-series kernel: `bull = high - ma`, `bear = low - ma`.
///
/// Writes indices `first_valid + period - 1 .. n`, where `n` is the shortest of the five slice
/// lengths, so every raw-pointer access stays inside its slice. A `period` of 0 writes nothing.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn eri_avx512_core(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    use core::arch::x86_64::*;

    if period == 0 {
        return;
    }
    let start = first_valid + period - 1;
    // Clamp to the shortest slice: all loads/stores below are unchecked raw-pointer accesses.
    let n = high
        .len()
        .min(low.len())
        .min(ma.len())
        .min(bull.len())
        .min(bear.len());
    if start >= n {
        return;
    }
    let len = n - start;

    let h_ptr = high.as_ptr().add(start);
    let l_ptr = low.as_ptr().add(start);
    let m_ptr = ma.as_ptr().add(start);
    let b_ptr = bull.as_mut_ptr().add(start);
    let r_ptr = bear.as_mut_ptr().add(start);

    let mut k = 0usize;
    while k + 8 <= len {
        let m = _mm512_loadu_pd(m_ptr.add(k));
        let h = _mm512_loadu_pd(h_ptr.add(k));
        let l = _mm512_loadu_pd(l_ptr.add(k));
        _mm512_storeu_pd(b_ptr.add(k), _mm512_sub_pd(h, m));
        _mm512_storeu_pd(r_ptr.add(k), _mm512_sub_pd(l, m));
        k += 8;
    }
    // At most seven elements remain. A plain scalar tail avoids the previous
    // `cfg(target_feature = "avx2")` split, which tested crate-wide build flags rather than
    // this function's enabled features; results are bit-identical either way.
    while k < len {
        let m = *m_ptr.add(k);
        *b_ptr.add(k) = *h_ptr.add(k) - m;
        *r_ptr.add(k) = *l_ptr.add(k) - m;
        k += 1;
    }
}
641
642#[inline]
643pub fn eri_batch_with_kernel(
644 high: &[f64],
645 low: &[f64],
646 source: &[f64],
647 sweep: &EriBatchRange,
648 k: Kernel,
649) -> Result<EriBatchOutput, EriError> {
650 let kernel = match k {
651 Kernel::Auto => detect_best_batch_kernel(),
652 other if other.is_batch() => other,
653 other => return Err(EriError::InvalidKernelForBatch(other)),
654 };
655 let simd = match kernel {
656 Kernel::Avx512Batch => Kernel::Avx512,
657 Kernel::Avx2Batch => Kernel::Avx2,
658 Kernel::ScalarBatch => Kernel::Scalar,
659 _ => unreachable!(),
660 };
661 eri_batch_par_slice(high, low, source, sweep, simd)
662}
663
664#[derive(Clone, Debug)]
665pub struct EriBatchRange {
666 pub period: (usize, usize, usize),
667 pub ma_type: String,
668}
669
670impl Default for EriBatchRange {
671 fn default() -> Self {
672 Self {
673 period: (13, 262, 1),
674 ma_type: "ema".into(),
675 }
676 }
677}
678
679#[derive(Clone, Debug, Default)]
680pub struct EriBatchBuilder {
681 range: EriBatchRange,
682 kernel: Kernel,
683}
684
685impl EriBatchBuilder {
686 pub fn new() -> Self {
687 Self::default()
688 }
689 pub fn kernel(mut self, k: Kernel) -> Self {
690 self.kernel = k;
691 self
692 }
693 #[inline]
694 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
695 self.range.period = (start, end, step);
696 self
697 }
698 #[inline]
699 pub fn period_static(mut self, p: usize) -> Self {
700 self.range.period = (p, p, 0);
701 self
702 }
703 pub fn apply_slices(
704 self,
705 high: &[f64],
706 low: &[f64],
707 source: &[f64],
708 ) -> Result<EriBatchOutput, EriError> {
709 eri_batch_with_kernel(high, low, source, &self.range, self.kernel)
710 }
711}
712
713#[derive(Clone, Debug)]
714pub struct EriBatchOutput {
715 pub bull: Vec<f64>,
716 pub bear: Vec<f64>,
717 pub params: Vec<EriParams>,
718 pub rows: usize,
719 pub cols: usize,
720}
721
722impl EriBatchOutput {
723 pub fn row_for_params(&self, p: &EriParams) -> Option<usize> {
724 self.params
725 .iter()
726 .position(|c| c.period == p.period && c.ma_type == p.ma_type)
727 }
728 pub fn values_for_bull(&self, p: &EriParams) -> Option<&[f64]> {
729 self.row_for_params(p).map(|row| {
730 let start = row * self.cols;
731 &self.bull[start..start + self.cols]
732 })
733 }
734 pub fn values_for_bear(&self, p: &EriParams) -> Option<&[f64]> {
735 self.row_for_params(p).map(|row| {
736 let start = row * self.cols;
737 &self.bear[start..start + self.cols]
738 })
739 }
740}
741
742#[inline(always)]
743fn expand_grid(r: &EriBatchRange) -> Result<Vec<EriParams>, EriError> {
744 let (start, end, step) = r.period;
745
746 if step == 0 {
747 return Ok(vec![EriParams {
748 period: Some(start),
749 ma_type: Some(r.ma_type.clone()),
750 }]);
751 }
752
753 let mut out: Vec<EriParams> = Vec::new();
754 if start == end {
755 out.push(EriParams {
756 period: Some(start),
757 ma_type: Some(r.ma_type.clone()),
758 });
759 return Ok(out);
760 }
761
762 if start < end {
763 let mut p = start;
764 while p <= end {
765 out.push(EriParams {
766 period: Some(p),
767 ma_type: Some(r.ma_type.clone()),
768 });
769 match p.checked_add(step) {
770 Some(next) => {
771 if next == p {
772 break;
773 }
774 p = next;
775 }
776 None => return Err(EriError::InvalidRange { start, end, step }),
777 }
778 }
779 } else {
780 let mut p = start;
781 while p >= end {
782 out.push(EriParams {
783 period: Some(p),
784 ma_type: Some(r.ma_type.clone()),
785 });
786
787 if p < step {
788 break;
789 }
790 p -= step;
791 if p == usize::MAX {
792 break;
793 }
794 }
795
796 if out.is_empty() {
797 return Err(EriError::InvalidRange { start, end, step });
798 }
799 }
800
801 if out.is_empty() {
802 return Err(EriError::InvalidRange { start, end, step });
803 }
804 Ok(out)
805}
806
807#[inline(always)]
808pub fn eri_batch_slice(
809 high: &[f64],
810 low: &[f64],
811 source: &[f64],
812 sweep: &EriBatchRange,
813 kern: Kernel,
814) -> Result<EriBatchOutput, EriError> {
815 eri_batch_inner(high, low, source, sweep, kern, false)
816}
817
818#[inline(always)]
819pub fn eri_batch_par_slice(
820 high: &[f64],
821 low: &[f64],
822 source: &[f64],
823 sweep: &EriBatchRange,
824 kern: Kernel,
825) -> Result<EriBatchOutput, EriError> {
826 eri_batch_inner(high, low, source, sweep, kern, true)
827}
828
829#[inline(always)]
830fn eri_batch_inner(
831 high: &[f64],
832 low: &[f64],
833 source: &[f64],
834 sweep: &EriBatchRange,
835 kern: Kernel,
836 parallel: bool,
837) -> Result<EriBatchOutput, EriError> {
838 let combos = expand_grid(sweep)?;
839
840 let first = high
841 .iter()
842 .zip(low.iter())
843 .zip(source.iter())
844 .position(|((h, l), s)| !h.is_nan() && !l.is_nan() && !s.is_nan())
845 .ok_or(EriError::AllValuesNaN)?;
846 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
847 if source.len() - first < max_p {
848 return Err(EriError::NotEnoughValidData {
849 needed: max_p,
850 valid: source.len() - first,
851 });
852 }
853 let rows = combos.len();
854 let cols = source.len();
855 let total = rows.checked_mul(cols).ok_or(EriError::InvalidRange {
856 start: sweep.period.0,
857 end: sweep.period.1,
858 step: sweep.period.2,
859 })?;
860
861 let mut buf_bull = make_uninit_matrix(rows, cols);
862 let mut buf_bear = make_uninit_matrix(rows, cols);
863
864 let warmup_periods: Vec<usize> = combos
865 .iter()
866 .map(|c| first + c.period.unwrap() - 1)
867 .collect();
868 init_matrix_prefixes(&mut buf_bull, cols, &warmup_periods);
869 init_matrix_prefixes(&mut buf_bear, cols, &warmup_periods);
870
871 let mut buf_bull_guard = std::mem::ManuallyDrop::new(buf_bull);
872 let mut buf_bear_guard = std::mem::ManuallyDrop::new(buf_bear);
873
874 let mut bull =
875 unsafe { std::slice::from_raw_parts_mut(buf_bull_guard.as_mut_ptr() as *mut f64, total) };
876 let mut bear =
877 unsafe { std::slice::from_raw_parts_mut(buf_bear_guard.as_mut_ptr() as *mut f64, total) };
878
879 let do_row = |row: usize, bull_row: &mut [f64], bear_row: &mut [f64]| -> Result<(), EriError> {
880 let period = combos[row].period.unwrap();
881 let ma_type = combos[row].ma_type.as_deref().unwrap();
882 let ma_vec = ma(ma_type, MaData::Slice(source), period)
883 .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
884 match kern {
885 Kernel::Scalar => unsafe {
886 eri_row_scalar(high, low, &ma_vec, first, period, bull_row, bear_row)
887 },
888 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
889 Kernel::Avx2 => unsafe {
890 eri_row_avx2(high, low, &ma_vec, first, period, bull_row, bear_row)
891 },
892 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
893 Kernel::Avx512 => unsafe {
894 eri_row_avx512(high, low, &ma_vec, first, period, bull_row, bear_row)
895 },
896 _ => unreachable!(),
897 }
898 Ok(())
899 };
900
901 if parallel {
902 #[cfg(not(target_arch = "wasm32"))]
903 {
904 bull.par_chunks_mut(cols)
905 .zip(bear.par_chunks_mut(cols))
906 .enumerate()
907 .for_each(|(row, (bull_row, bear_row))| {
908 let _ = do_row(row, bull_row, bear_row);
909 });
910 }
911 #[cfg(target_arch = "wasm32")]
912 {
913 for (row, (bull_row, bear_row)) in
914 bull.chunks_mut(cols).zip(bear.chunks_mut(cols)).enumerate()
915 {
916 let _ = do_row(row, bull_row, bear_row);
917 }
918 }
919 } else {
920 for (row, (bull_row, bear_row)) in
921 bull.chunks_mut(cols).zip(bear.chunks_mut(cols)).enumerate()
922 {
923 let _ = do_row(row, bull_row, bear_row);
924 }
925 }
926
927 let bull_vec =
928 unsafe { Vec::from_raw_parts(buf_bull_guard.as_mut_ptr() as *mut f64, total, total) };
929 let bear_vec =
930 unsafe { Vec::from_raw_parts(buf_bear_guard.as_mut_ptr() as *mut f64, total, total) };
931
932 Ok(EriBatchOutput {
933 bull: bull_vec,
934 bear: bear_vec,
935 params: combos,
936 rows,
937 cols,
938 })
939}
940
941#[inline(always)]
942pub fn eri_batch_inner_into(
943 high: &[f64],
944 low: &[f64],
945 source: &[f64],
946 sweep: &EriBatchRange,
947 kern: Kernel,
948 parallel: bool,
949 bull_out: &mut [f64],
950 bear_out: &mut [f64],
951) -> Result<Vec<EriParams>, EriError> {
952 let combos = expand_grid(sweep)?;
953
954 let first = high
955 .iter()
956 .zip(low.iter())
957 .zip(source.iter())
958 .position(|((h, l), s)| !h.is_nan() && !l.is_nan() && !s.is_nan())
959 .ok_or(EriError::AllValuesNaN)?;
960 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
961 if source.len() - first < max_p {
962 return Err(EriError::NotEnoughValidData {
963 needed: max_p,
964 valid: source.len() - first,
965 });
966 }
967
968 let rows = combos.len();
969 let cols = source.len();
970
971 let expected = rows.checked_mul(cols).ok_or(EriError::InvalidRange {
972 start: sweep.period.0,
973 end: sweep.period.1,
974 step: sweep.period.2,
975 })?;
976 if bull_out.len() != expected || bear_out.len() != expected {
977 return Err(EriError::OutputLengthMismatch {
978 expected,
979 got: bull_out.len().max(bear_out.len()),
980 });
981 }
982
983 let do_row = |row: usize, bull_row: &mut [f64], bear_row: &mut [f64]| -> Result<(), EriError> {
984 let period = combos[row].period.unwrap();
985 let ma_type = combos[row].ma_type.as_deref().unwrap();
986 let ma_vec = ma(ma_type, MaData::Slice(source), period)
987 .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
988 match kern {
989 Kernel::Scalar => unsafe {
990 eri_row_scalar(high, low, &ma_vec, first, period, bull_row, bear_row)
991 },
992 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
993 Kernel::Avx2 => unsafe {
994 eri_row_avx2(high, low, &ma_vec, first, period, bull_row, bear_row)
995 },
996 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
997 Kernel::Avx512 => unsafe {
998 eri_row_avx512(high, low, &ma_vec, first, period, bull_row, bear_row)
999 },
1000 _ => unreachable!(),
1001 }
1002 Ok(())
1003 };
1004
1005 if parallel {
1006 #[cfg(not(target_arch = "wasm32"))]
1007 {
1008 use rayon::prelude::*;
1009 bull_out
1010 .par_chunks_mut(cols)
1011 .zip(bear_out.par_chunks_mut(cols))
1012 .enumerate()
1013 .for_each(|(row, (bull_row, bear_row))| {
1014 let _ = do_row(row, bull_row, bear_row);
1015 });
1016 }
1017 #[cfg(target_arch = "wasm32")]
1018 for row in 0..rows {
1019 let bull_row = &mut bull_out[row * cols..(row + 1) * cols];
1020 let bear_row = &mut bear_out[row * cols..(row + 1) * cols];
1021 let _ = do_row(row, bull_row, bear_row);
1022 }
1023 } else {
1024 for row in 0..rows {
1025 let bull_row = &mut bull_out[row * cols..(row + 1) * cols];
1026 let bear_row = &mut bear_out[row * cols..(row + 1) * cols];
1027 let _ = do_row(row, bull_row, bear_row);
1028 }
1029 }
1030
1031 Ok(combos)
1032}
1033
/// Scalar row kernel for the batch path: `bull[i] = high[i] - ma[i]`, `bear[i] = low[i] - ma[i]`.
///
/// Writes from index `first + period - 1` up to the shortest of the five slices; earlier
/// entries are left untouched. A `period` of 0 writes nothing.
///
/// # Safety
/// No extra requirements: the body uses only bounds-checked slicing. The function stays
/// `unsafe` so existing call sites are unchanged.
#[inline(always)]
unsafe fn eri_row_scalar(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    if period == 0 {
        return;
    }
    let start = first + period - 1;
    // The previous raw-pointer loop trusted `high.len()` alone; clamp to every slice instead.
    let end = high
        .len()
        .min(low.len())
        .min(ma.len())
        .min(bull.len())
        .min(bear.len());
    if start >= end {
        return;
    }

    let inputs = high[start..end]
        .iter()
        .zip(&low[start..end])
        .zip(&ma[start..end]);
    let outputs = bull[start..end]
        .iter_mut()
        .zip(bear[start..end].iter_mut());
    for ((b, r), ((&h, &l), &m)) in outputs.zip(inputs) {
        *b = h - m;
        *r = l - m;
    }
}
1104
/// AVX2 row kernel for the batch path; forwards to `eri_avx2_core`.
///
/// # Safety
/// The CPU must support AVX2 (the core is compiled with `target_feature(enable = "avx2")`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx2(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    // Note the argument order: the core takes `period` before `first_valid`.
    let (lookback, first_valid) = (period, first);
    eri_avx2_core(high, low, ma, lookback, first_valid, bull, bear)
}

/// AVX-512 row kernel; dispatches on period length (both variants currently share one core).
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    match period {
        0..=32 => eri_row_avx512_short(high, low, ma, first, period, bull, bear),
        _ => eri_row_avx512_long(high, low, ma, first, period, bull, bear),
    }
}

/// Short-period AVX-512 row kernel; forwards to `eri_avx512_core`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512_short(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let (lookback, first_valid) = (period, first);
    eri_avx512_core(high, low, ma, lookback, first_valid, bull, bear)
}

/// Long-period AVX-512 row kernel; forwards to `eri_avx512_core`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512_long(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let (lookback, first_valid) = (period, first);
    eri_avx512_core(high, low, ma, lookback, first_valid, bull, bear)
}
1164
/// Incremental (bar-by-bar) ERI calculator.
#[derive(Debug, Clone)]
pub struct EriStream {
    /// Moving-average lookback; `try_new` rejects zero.
    period: usize,
    /// MA type name as supplied; also passed to the generic fallback engine.
    ma_type: String,
    /// Incremental MA state for the selected type; rebuilt on reset.
    engine: StreamMa,
    /// Set after a value is emitted and cleared on reset.
    /// NOTE(review): not read in the code visible here.
    ready: bool,
}

/// Incremental moving-average engines, chosen by `make_engine` from the lowercased type name.
#[derive(Debug, Clone)]
enum StreamMa {
    /// Simple moving average over a ring buffer.
    Sma(SmaState),
    /// EMA with `alpha = 2 / (period + 1)`.
    Ema(EmaState),
    /// Wilder smoothing: EMA state with `alpha = 1 / period`.
    Rma(EmaState),
    /// Double EMA.
    Dema(DemaState),
    /// Triple EMA.
    Tema(TemaState),
    /// Linearly weighted moving average.
    Wma(WmaState),
    /// Any other MA type; presumably evaluated over a rolling window via the generic
    /// dispatcher — TODO confirm against `generic_update`.
    Generic(GenericState),
}

/// Ring-buffer state for the streaming SMA.
#[derive(Debug, Clone)]
struct SmaState {
    /// The last `period` inputs; once full, slot `pos` holds the oldest value.
    buf: Vec<f64>,
    /// Next slot to write.
    pos: usize,
    /// Values pushed so far, stopping at `buf.len()`.
    count: usize,
    /// Running sum of the values in `buf`.
    sum: f64,
    /// Precomputed `1 / period`.
    inv_n: f64,
}

/// State shared by the streaming EMA and Wilder (RMA) engines.
#[derive(Debug, Clone)]
struct EmaState {
    /// Lookback period.
    n: usize,
    /// Smoothing factor set by `make_engine`.
    alpha: f64,
    /// `1 - alpha`.
    beta: f64,

    // Presumably accumulate the first `n` inputs to seed the average — TODO confirm against
    // `ema_like_update`.
    init_sum: f64,
    init_count: usize,
    /// Current average; initialised to NaN.
    ema: f64,
}

/// State for the streaming double EMA.
#[derive(Debug, Clone)]
struct DemaState {
    n: usize,
    /// `2 / (period + 1)`.
    alpha: f64,
    /// `1 - alpha`.
    beta: f64,
    // Presumably seeding accumulators, as in `EmaState` — TODO confirm.
    init_sum: f64,
    init_count: usize,
    /// First and second EMA stages; initialised to NaN.
    e1: f64,
    e2: f64,
}

/// State for the streaming triple EMA.
#[derive(Debug, Clone)]
struct TemaState {
    n: usize,
    /// `2 / (period + 1)`.
    alpha: f64,
    /// `1 - alpha`.
    beta: f64,
    // Presumably seeding accumulators, as in `EmaState` — TODO confirm.
    init_sum: f64,
    init_count: usize,
    /// Three EMA stages; initialised to NaN.
    e1: f64,
    e2: f64,
    e3: f64,
}

/// State for the streaming linearly weighted MA.
#[derive(Debug, Clone)]
struct WmaState {
    n: usize,
    /// `2 / (n * (n + 1))`, the reciprocal of the weight total `1 + 2 + ... + n`.
    den_inv: f64,
    buf: Vec<f64>,
    pos: usize,
    count: usize,
    // Presumably the running plain sum and weighted sum of the window — TODO confirm against
    // `wma_update`.
    s: f64,
    ws: f64,
}

/// State for MA types without a dedicated incremental engine.
#[derive(Debug, Clone)]
struct GenericState {
    n: usize,
    buf: Vec<f64>,
    pos: usize,
    count: usize,
    // Presumably a reusable buffer for the ordered window — TODO confirm against `generic_update`.
    scratch: Vec<f64>,
}
1246
1247impl EriStream {
1248 pub fn try_new(params: EriParams) -> Result<Self, EriError> {
1249 let period = params.period.unwrap_or(13);
1250 if period == 0 {
1251 return Err(EriError::InvalidPeriod {
1252 period,
1253 data_len: 0,
1254 });
1255 }
1256 let ma_type = params.ma_type.unwrap_or_else(|| "ema".to_string());
1257
1258 let engine = make_engine(period, &ma_type);
1259 Ok(Self {
1260 period,
1261 ma_type,
1262 engine,
1263 ready: false,
1264 })
1265 }
1266
1267 #[inline(always)]
1268 pub fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64)> {
1269 if high.is_nan() || low.is_nan() || source.is_nan() {
1270 self.reset();
1271 return None;
1272 }
1273
1274 let ma_val = match &mut self.engine {
1275 StreamMa::Sma(st) => sma_update(st, source),
1276 StreamMa::Ema(st) => ema_like_update(st, source),
1277 StreamMa::Rma(st) => ema_like_update(st, source),
1278 StreamMa::Dema(st) => dema_update(st, source),
1279 StreamMa::Tema(st) => tema_update(st, source),
1280 StreamMa::Wma(st) => wma_update(st, source),
1281 StreamMa::Generic(st) => generic_update(st, &self.ma_type, source),
1282 }?;
1283
1284 self.ready = true;
1285 Some((high - ma_val, low - ma_val))
1286 }
1287
1288 #[inline(always)]
1289 fn reset(&mut self) {
1290 self.ready = false;
1291 self.engine = make_engine(self.period, &self.ma_type);
1292 }
1293}
1294
1295#[inline(always)]
1296fn make_engine(period: usize, ma_type: &str) -> StreamMa {
1297 let t = ma_type.to_ascii_lowercase();
1298 match t.as_str() {
1299 "sma" => StreamMa::Sma(SmaState {
1300 buf: vec![0.0; period],
1301 pos: 0,
1302 count: 0,
1303 sum: 0.0,
1304 inv_n: 1.0 / period as f64,
1305 }),
1306 "ema" | "ewma" => StreamMa::Ema(EmaState {
1307 n: period,
1308 alpha: 2.0 / (period as f64 + 1.0),
1309 beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1310 init_sum: 0.0,
1311 init_count: 0,
1312 ema: f64::NAN,
1313 }),
1314 "rma" | "wilder" | "smma" => StreamMa::Rma(EmaState {
1315 n: period,
1316 alpha: 1.0 / period as f64,
1317 beta: 1.0 - (1.0 / period as f64),
1318 init_sum: 0.0,
1319 init_count: 0,
1320 ema: f64::NAN,
1321 }),
1322 "dema" => StreamMa::Dema(DemaState {
1323 n: period,
1324 alpha: 2.0 / (period as f64 + 1.0),
1325 beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1326 init_sum: 0.0,
1327 init_count: 0,
1328 e1: f64::NAN,
1329 e2: f64::NAN,
1330 }),
1331 "tema" => StreamMa::Tema(TemaState {
1332 n: period,
1333 alpha: 2.0 / (period as f64 + 1.0),
1334 beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1335 init_sum: 0.0,
1336 init_count: 0,
1337 e1: f64::NAN,
1338 e2: f64::NAN,
1339 e3: f64::NAN,
1340 }),
1341 "wma" | "lwma" | "linear" | "linear_wma" => {
1342 let n = period as f64;
1343 let den_inv = 2.0 / (n * (n + 1.0));
1344 StreamMa::Wma(WmaState {
1345 n: period,
1346 den_inv,
1347 buf: vec![0.0; period],
1348 pos: 0,
1349 count: 0,
1350 s: 0.0,
1351 ws: 0.0,
1352 })
1353 }
1354 _ => StreamMa::Generic(GenericState {
1355 n: period,
1356 buf: vec![0.0; period],
1357 pos: 0,
1358 count: 0,
1359 scratch: vec![0.0; period],
1360 }),
1361 }
1362}
1363
1364#[inline(always)]
1365fn sma_update(st: &mut SmaState, x: f64) -> Option<f64> {
1366 let n = st.buf.len();
1367 if st.count < n {
1368 st.buf[st.pos] = x;
1369 st.sum += x;
1370 st.pos = (st.pos + 1) % n;
1371 st.count += 1;
1372 return (st.count == n).then(|| st.sum * st.inv_n);
1373 }
1374 let old = st.buf[st.pos];
1375 st.buf[st.pos] = x;
1376 st.sum += x - old;
1377 st.pos = (st.pos + 1) % n;
1378 Some(st.sum * st.inv_n)
1379}
1380
1381#[inline(always)]
1382fn ema_like_update(st: &mut EmaState, x: f64) -> Option<f64> {
1383 if st.init_count < st.n {
1384 st.init_sum += x;
1385 st.init_count += 1;
1386 if st.init_count == st.n {
1387 st.ema = st.init_sum / st.n as f64;
1388 return Some(st.ema);
1389 }
1390 return None;
1391 }
1392 st.ema = x.mul_add(st.alpha, st.beta * st.ema);
1393 Some(st.ema)
1394}
1395
1396#[inline(always)]
1397fn dema_update(st: &mut DemaState, x: f64) -> Option<f64> {
1398 if st.init_count < st.n {
1399 st.init_sum += x;
1400 st.init_count += 1;
1401 if st.init_count == st.n {
1402 st.e1 = st.init_sum / st.n as f64;
1403 st.e2 = st.e1;
1404 return Some(st.e1);
1405 }
1406 return None;
1407 }
1408 st.e1 = x.mul_add(st.alpha, st.beta * st.e1);
1409 st.e2 = st.e1.mul_add(st.alpha, st.beta * st.e2);
1410 Some(2.0f64.mul_add(st.e1, -st.e2))
1411}
1412
1413#[inline(always)]
1414fn tema_update(st: &mut TemaState, x: f64) -> Option<f64> {
1415 if st.init_count < st.n {
1416 st.init_sum += x;
1417 st.init_count += 1;
1418 if st.init_count == st.n {
1419 st.e1 = st.init_sum / st.n as f64;
1420 st.e2 = st.e1;
1421 st.e3 = st.e2;
1422 return Some(st.e1);
1423 }
1424 return None;
1425 }
1426 st.e1 = x.mul_add(st.alpha, st.beta * st.e1);
1427 st.e2 = st.e1.mul_add(st.alpha, st.beta * st.e2);
1428 st.e3 = st.e2.mul_add(st.alpha, st.beta * st.e3);
1429 Some((3.0 * st.e1) - (3.0 * st.e2) + st.e3)
1430}
1431
1432#[inline(always)]
1433fn wma_update(st: &mut WmaState, x: f64) -> Option<f64> {
1434 let n = st.n;
1435 if st.count < n {
1436 st.buf[st.pos] = x;
1437 st.s += x;
1438 st.ws += (st.count as f64 + 1.0) * x;
1439 st.pos = (st.pos + 1) % n;
1440 st.count += 1;
1441 return (st.count == n).then(|| st.ws * st.den_inv);
1442 }
1443 let old = st.buf[st.pos];
1444 st.buf[st.pos] = x;
1445 let s_prev = st.s;
1446 st.s = s_prev - old + x;
1447 st.ws = st.ws - s_prev + (n as f64) * x;
1448 st.pos = (st.pos + 1) % n;
1449 Some(st.ws * st.den_inv)
1450}
1451
1452#[inline(always)]
1453fn generic_update(st: &mut GenericState, ma_type: &str, x: f64) -> Option<f64> {
1454 let n = st.n;
1455 if st.count < n {
1456 st.buf[st.pos] = x;
1457 st.pos = (st.pos + 1) % n;
1458 st.count += 1;
1459 return None;
1460 }
1461 st.buf[st.pos] = x;
1462 st.pos = (st.pos + 1) % n;
1463
1464 for i in 0..n {
1465 let src_idx = (st.pos + i) % n;
1466 st.scratch[i] = st.buf[src_idx];
1467 }
1468 let m = ma(ma_type, MaData::Slice(&st.scratch), n).ok()?;
1469 m.last().copied()
1470}
1471
/// Python binding: computes the ERI bull and bear series for one parameter set.
///
/// Returns a `(bull, bear)` pair of NumPy arrays with the same length as the inputs.
/// Raises `ValueError` when the input lengths differ or the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "eri")]
#[pyo3(signature = (high, low, source, period=13, ma_type="ema", kernel=None))]
pub fn eri_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let (h, l, s) = (high.as_slice()?, low.as_slice()?, source.as_slice()?);
    let same_len = h.len() == l.len() && h.len() == s.len();
    if !same_len {
        return Err(PyValueError::new_err(
            "high, low, and source arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let input = EriInput::from_slices(
        h,
        l,
        s,
        EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    // Release the GIL while the numeric work runs.
    let out = py
        .allow_threads(|| eri_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let bull = out.bull.into_pyarray(py);
    let bear = out.bear.into_pyarray(py);
    Ok((bull, bear))
}
1509
/// Python-visible wrapper (`EriStream`) around the Rust streaming ERI calculator.
#[cfg(feature = "python")]
#[pyclass(name = "EriStream")]
pub struct EriStreamPy {
    // The wrapped Rust stream; all state lives here.
    stream: EriStream,
}
1515
#[cfg(feature = "python")]
#[pymethods]
impl EriStreamPy {
    /// Creates a stream; `ma_type` defaults to `"ema"` when omitted.
    ///
    /// Raises `ValueError` if the parameters are rejected (e.g. `period == 0`).
    #[new]
    fn new(period: usize, ma_type: Option<&str>) -> PyResult<Self> {
        let ma_type = Some(ma_type.map_or_else(|| "ema".to_string(), str::to_string));
        EriStream::try_new(EriParams {
            period: Some(period),
            ma_type,
        })
        .map(|stream| EriStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar; returns `(bull, bear)` once warmed up, otherwise `None`.
    fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64)> {
        let stream = &mut self.stream;
        stream.update(high, low, source)
    }
}
1536
/// Python binding: ERI over a sweep of periods for a single `ma_type`.
///
/// Returns a dict with `bull_values` and `bear_values` (shape `(rows, cols)`, one row per
/// period combination from `expand_grid`), `periods` (one per row), and `ma_types` (the
/// given `ma_type` repeated once per row).
///
/// Raises `ValueError` for mismatched input lengths, a sweep rejected by `expand_grid`,
/// `rows*cols` overflow, or a failure in the batch computation; kernel-name errors come
/// from `validate_kernel`.
#[cfg(feature = "python")]
#[pyfunction(name = "eri_batch")]
#[pyo3(signature = (high, low, source, period_range=(13, 13, 0), ma_type="ema", kernel=None))]
pub fn eri_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let source_slice = source.as_slice()?;

    if high_slice.len() != low_slice.len() || high_slice.len() != source_slice.len() {
        return Err(PyValueError::new_err(
            "high, low, and source arrays must have the same length",
        ));
    }

    let sweep = EriBatchRange {
        period: period_range,
        ma_type: ma_type.to_string(),
    };

    // One output row per parameter combination; outputs are flat row-major rows*cols buffers.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = high_slice.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // `unsafe` allocation: the arrays are not initialized here. The warm-up prefix of each row
    // is NaN-filled below; the remaining entries are presumably all written by
    // `eri_batch_inner_into` — TODO confirm.
    let bull_array: Bound<'py, PyArray1<f64>> = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bear_array: Bound<'py, PyArray1<f64>> = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let bull_slice = unsafe { bull_array.as_slice_mut()? };
    let bear_slice = unsafe { bear_array.as_slice_mut()? };

    // First index where high, low and source are all non-NaN (0 if there is none).
    let first_valid = high_slice
        .iter()
        .zip(low_slice.iter())
        .zip(source_slice.iter())
        .position(|((h, l), s)| !h.is_nan() && !l.is_nan() && !s.is_nan())
        .unwrap_or(0);

    // NaN-fill each row's warm-up prefix (first_valid + period - 1 entries, clamped to cols).
    for (row, combo) in combos.iter().enumerate() {
        let period = combo.period.unwrap();
        let warmup = first_valid + period - 1;
        let row_start = row * cols;
        for i in 0..warmup.min(cols) {
            bull_slice[row_start + i] = f64::NAN;
            bear_slice[row_start + i] = f64::NAN;
        }
    }

    let kern = validate_kernel(kernel, true)?;
    let kernel_to_use = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // Map the batch kernel to the per-row kernel handed to the inner routine.
    let simd = match kernel_to_use {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => Kernel::Scalar,
    };

    // GIL released during compute. NOTE(review): the `true` argument presumably enables
    // parallel row processing — verify against `eri_batch_inner_into`.
    let combos = py
        .allow_threads(|| {
            eri_batch_inner_into(
                high_slice,
                low_slice,
                source_slice,
                &sweep,
                simd,
                true,
                bull_slice,
                bear_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let bull_reshaped = bull_array.reshape([rows, cols])?;
    let bear_reshaped = bear_array.reshape([rows, cols])?;

    let periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
    let ma_types: Vec<&str> = vec![ma_type; combos.len()];

    let dict = PyDict::new(py);
    dict.set_item("bull_values", bull_reshaped)?;
    dict.set_item("bear_values", bear_reshaped)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("ma_types", ma_types)?;

    Ok(dict.into())
}
1638
1639pub fn eri_into_slice(
1640 dst_bull: &mut [f64],
1641 dst_bear: &mut [f64],
1642 input: &EriInput,
1643 kern: Kernel,
1644) -> Result<(), EriError> {
1645 let (high, low, source_data) = match &input.data {
1646 EriData::Candles { candles, source } => {
1647 let high = candles
1648 .select_candle_field("high")
1649 .map_err(|_| EriError::EmptyInputData)?;
1650 let low = candles
1651 .select_candle_field("low")
1652 .map_err(|_| EriError::EmptyInputData)?;
1653 let src = source_type(candles, source);
1654 (high, low, src)
1655 }
1656 EriData::Slices { high, low, source } => (*high, *low, *source),
1657 };
1658
1659 if source_data.is_empty() || high.is_empty() || low.is_empty() {
1660 return Err(EriError::EmptyInputData);
1661 }
1662
1663 if dst_bull.len() != source_data.len() || dst_bear.len() != source_data.len() {
1664 return Err(EriError::OutputLengthMismatch {
1665 expected: source_data.len(),
1666 got: dst_bull.len().max(dst_bear.len()),
1667 });
1668 }
1669
1670 let period = input.get_period();
1671 if period == 0 || period > source_data.len() {
1672 return Err(EriError::InvalidPeriod {
1673 period,
1674 data_len: source_data.len(),
1675 });
1676 }
1677
1678 let mut first_valid_idx = None;
1679 for i in 0..source_data.len() {
1680 if !(source_data[i].is_nan() || high[i].is_nan() || low[i].is_nan()) {
1681 first_valid_idx = Some(i);
1682 break;
1683 }
1684 }
1685 let first_valid_idx = match first_valid_idx {
1686 Some(idx) => idx,
1687 None => return Err(EriError::AllValuesNaN),
1688 };
1689
1690 if (source_data.len() - first_valid_idx) < period {
1691 return Err(EriError::NotEnoughValidData {
1692 needed: period,
1693 valid: source_data.len() - first_valid_idx,
1694 });
1695 }
1696
1697 let ma_type = input.get_ma_type();
1698 let full_ma = ma(&ma_type, MaData::Slice(&source_data), period)
1699 .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
1700
1701 let warmup_period = first_valid_idx + period - 1;
1702
1703 for v in &mut dst_bull[..warmup_period] {
1704 *v = f64::NAN;
1705 }
1706 for v in &mut dst_bear[..warmup_period] {
1707 *v = f64::NAN;
1708 }
1709
1710 let chosen = match kern {
1711 Kernel::Auto => Kernel::Scalar,
1712 other => other,
1713 };
1714
1715 unsafe {
1716 match chosen {
1717 Kernel::Scalar | Kernel::ScalarBatch => eri_scalar(
1718 high,
1719 low,
1720 &full_ma,
1721 period,
1722 first_valid_idx,
1723 dst_bull,
1724 dst_bear,
1725 ),
1726 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1727 Kernel::Avx2 | Kernel::Avx2Batch => eri_avx2(
1728 high,
1729 low,
1730 &full_ma,
1731 period,
1732 first_valid_idx,
1733 dst_bull,
1734 dst_bear,
1735 ),
1736 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1737 Kernel::Avx512 | Kernel::Avx512Batch => eri_avx512(
1738 high,
1739 low,
1740 &full_ma,
1741 period,
1742 first_valid_idx,
1743 dst_bull,
1744 dst_bear,
1745 ),
1746 _ => unreachable!(),
1747 }
1748 }
1749
1750 Ok(())
1751}
1752
1753#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1754#[inline]
1755pub fn eri_into(
1756 input: &EriInput,
1757 bull_out: &mut [f64],
1758 bear_out: &mut [f64],
1759) -> Result<(), EriError> {
1760 eri_into_slice(bull_out, bear_out, input, Kernel::Auto)
1761}
1762
/// Flat row-major result returned by `eri_js_flat`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriResult {
    /// `rows` consecutive rows of `cols` values each (`eri_js_flat` writes bull, then bear).
    pub values: Vec<f64>,
    /// Number of rows in `values` (`eri_js_flat` always sets 2).
    pub rows: usize,
    /// Length of each row (the input series length).
    pub cols: usize,
}
1770
/// WASM binding: computes ERI and returns a serialized `EriResult`.
///
/// The result has `rows = 2`, `cols = source.len()`, and `values` containing the bull series
/// followed by the bear series.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_js_flat(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    period: usize,
    ma_type: &str,
) -> Result<JsValue, JsValue> {
    let n = source.len();
    if high.len() != low.len() || high.len() != n {
        return Err(JsValue::from_str("length mismatch"));
    }

    let input = EriInput::from_slices(
        high,
        low,
        source,
        EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let mut values = vec![0.0; n];
    let mut bear = vec![0.0; n];
    eri_into_slice(&mut values, &mut bear, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    // Append bear after bull so `values` holds both rows back to back.
    values.extend_from_slice(&bear);

    let result = EriResult {
        values,
        rows: 2,
        cols: n,
    };
    serde_wasm_bindgen::to_value(&result)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1805
/// WASM binding: returns bull and bear in one flat vector of length `2 * source.len()`.
///
/// The first half holds the bull series and the second half holds the bear series.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    period: usize,
    ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let n = source.len();
    if high.len() != low.len() || high.len() != n {
        return Err(JsValue::from_str(
            "high, low, and source arrays must have the same length",
        ));
    }

    let input = EriInput::from_slices(
        high,
        low,
        source,
        EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let total = n
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("length overflow"))?;
    let mut flat = vec![0.0; total];
    let (bull_half, bear_half) = flat.split_at_mut(n);
    eri_into_slice(bull_half, bear_half, &input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(flat)
}
1839
1840#[cfg(all(feature = "python", feature = "cuda"))]
1841use crate::cuda::eri_wrapper::CudaEri;
1842#[cfg(all(feature = "python", feature = "cuda"))]
1843use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
/// Python binding: runs the ERI period sweep on a CUDA device.
///
/// Returns the bull and bear results as device arrays on `device_id`. Raises `ValueError`
/// if CUDA is unavailable or the device computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "eri_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, source_f32, period_range, ma_type, device_id=0))]
pub fn eri_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, s) = (
        high_f32.as_slice()?,
        low_f32.as_slice()?,
        source_f32.as_slice()?,
    );
    let sweep = EriBatchRange {
        period: period_range,
        ma_type: ma_type.to_string(),
    };

    // Device work runs with the GIL released; the parameter combos are not returned.
    let (bull, bear) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEri::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ((bull, bear), _combos) = cuda
            .eri_batch_dev(h, l, s, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bull, bear))
    })?;

    Ok((
        make_device_array_py(device_id, bull)?,
        make_device_array_py(device_id, bear)?,
    ))
}
1877
/// Python binding: ERI for many series with one parameter set on a CUDA device.
///
/// Inputs are flat `f32` buffers passed along with `cols` and `rows` to the
/// time-major CUDA entry point. Returns bull and bear device arrays. Raises `ValueError`
/// if CUDA is unavailable or the device computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "eri_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, source_tm_f32, cols, rows, period, ma_type, device_id=0))]
pub fn eri_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, s) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        source_tm_f32.as_slice()?,
    );

    let (bull, bear) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEri::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.eri_many_series_one_param_time_major_dev(h, l, s, cols, rows, period, ma_type)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok((
        make_device_array_py(device_id, bull)?,
        make_device_array_py(device_id, bear)?,
    ))
}
1909
/// WASM binding: computes ERI from raw input pointers into raw output pointers.
///
/// All five pointers must be non-null. The caller must guarantee that each addresses `len`
/// contiguous `f64` values valid for the duration of the call; this cannot be checked here.
/// If either output region overlaps any input region or the other output, the result is
/// computed into temporary buffers and copied out afterwards. This covers partial overlap,
/// not just identical start pointers. As a result, overlapping `&`/`&mut` slices are never
/// live at the same time.
///
/// # Errors
/// Returns a JS string error for a null pointer or any error from `eri_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    bull_ptr: *mut f64,
    bear_ptr: *mut f64,
    len: usize,
    period: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    // True when the `len`-element f64 regions starting at `a` and `b` share any byte.
    fn regions_overlap(a: *const f64, b: *const f64, len: usize) -> bool {
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || source_ptr.is_null()
        || bull_ptr.is_null()
        || bear_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the pointers are non-null (checked above) and the caller guarantees each
    // addresses `len` initialized f64 values that stay valid for this call.
    let (high, low, source) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(source_ptr, len),
        )
    };

    let params = EriParams {
        period: Some(period),
        ma_type: Some(ma_type.to_string()),
    };
    let input = EriInput::from_slices(high, low, source, params);

    let bull_in = bull_ptr as *const f64;
    let bear_in = bear_ptr as *const f64;
    let needs_temp = regions_overlap(bull_in, bear_in, len)
        || [high_ptr, low_ptr, source_ptr]
            .iter()
            .any(|&inp| regions_overlap(bull_in, inp, len) || regions_overlap(bear_in, inp, len));

    if needs_temp {
        let mut temp_bull = vec![0.0; len];
        let mut temp_bear = vec![0.0; len];
        eri_into_slice(&mut temp_bull, &mut temp_bear, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: outputs are non-null and writable for `len` values per the caller contract.
        // Each `&mut` slice lives only for its own statement, so the two never coexist even if
        // the output regions overlap; the input slices are not used after this point.
        unsafe {
            std::slice::from_raw_parts_mut(bull_ptr, len).copy_from_slice(&temp_bull);
            std::slice::from_raw_parts_mut(bear_ptr, len).copy_from_slice(&temp_bear);
        }
    } else {
        // SAFETY: outputs are non-null and writable for `len` values per the caller contract,
        // and the overlap check above proved they are disjoint from each other and the inputs.
        let (bull_out, bear_out) = unsafe {
            (
                std::slice::from_raw_parts_mut(bull_ptr, len),
                std::slice::from_raw_parts_mut(bear_ptr, len),
            )
        };
        eri_into_slice(bull_out, bear_out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1970
/// WASM helper: allocates an uninitialized buffer with room for `len` `f64` values.
///
/// The allocation is intentionally leaked to the caller; release it with `eri_free`, passing
/// the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor so the allocation outlives this call.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1979
/// WASM helper: releases a buffer obtained from `eri_alloc(len)`.
///
/// `ptr` must come from `eri_alloc` called with the same `len`, and must not have been freed
/// already. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` was allocated by a `Vec<f64>` with capacity `len`.
    // Length 0 is used because the contents may never have been initialized; only the capacity
    // is needed to deallocate, and f64 has no destructor to run.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1989
/// JS-side configuration for `eri_batch`, deserialized from a plain object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriBatchConfig {
    /// Period sweep passed through as `EriBatchRange::period`
    /// (presumably `(start, end, step)` — confirm against `expand_grid`).
    pub period_range: (usize, usize, usize),
    /// Moving-average type applied to every combination.
    pub ma_type: String,
}
1996
/// Serialized result of the WASM `eri_batch` sweep.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriBatchJsOutput {
    /// All bull rows followed by all bear rows, each row `cols` long.
    pub values: Vec<f64>,
    /// Total row count in `values`: twice the number of parameter combinations.
    pub rows: usize,
    /// Length of each row (the input series length).
    pub cols: usize,
    /// Period of each combination, in row order (one entry per bull/bear row pair).
    pub periods: Vec<usize>,
}
2005
/// WASM binding (`eri_batch`): runs an ERI period sweep and returns an `EriBatchJsOutput`.
///
/// `config` must deserialize into `EriBatchConfig`. The output stacks every bull row first,
/// then every bear row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = eri_batch)]
pub fn eri_batch_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if high.len() != low.len() || high.len() != source.len() {
        return Err(JsValue::from_str(
            "high, low, and source arrays must have the same length",
        ));
    }

    let EriBatchConfig {
        period_range,
        ma_type,
    } = serde_wasm_bindgen::from_value::<EriBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = EriBatchRange {
        period: period_range,
        ma_type,
    };

    let output = eri_batch_with_kernel(high, low, source, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = output.cols;
    let rows = output
        .rows
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("rows overflow"))?;
    let cap = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Both matrices are row-major with `output.rows` rows, so their leading
    // `output.rows * cols` elements are exactly the rows to emit.
    let per_matrix = output.rows * cols;
    let mut values = Vec::with_capacity(cap);
    values.extend_from_slice(&output.bull[..per_matrix]);
    values.extend_from_slice(&output.bear[..per_matrix]);

    let periods = output
        .params
        .iter()
        .map(|p| p.period.unwrap())
        .collect::<Vec<usize>>();

    serde_wasm_bindgen::to_value(&EriBatchJsOutput {
        values,
        rows,
        cols,
        periods,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2063
2064#[inline]
2065pub unsafe fn eri_scalar_classic_sma(
2066 high: &[f64],
2067 low: &[f64],
2068 source: &[f64],
2069 period: usize,
2070 first_valid_idx: usize,
2071 bull: &mut [f64],
2072 bear: &mut [f64],
2073) -> Result<(), EriError> {
2074 let start_idx = first_valid_idx + period - 1;
2075
2076 let mut sum = 0.0;
2077 for i in 0..period {
2078 sum += source[first_valid_idx + i];
2079 }
2080 let mut sma = sum / period as f64;
2081
2082 bull[start_idx] = high[start_idx] - sma;
2083 bear[start_idx] = low[start_idx] - sma;
2084
2085 for i in (start_idx + 1)..source.len() {
2086 let old_val = source[i - period];
2087 let new_val = source[i];
2088 sum = sum - old_val + new_val;
2089 sma = sum / period as f64;
2090
2091 bull[i] = high[i] - sma;
2092 bear[i] = low[i] - sma;
2093 }
2094
2095 Ok(())
2096}
2097
2098#[inline]
2099pub unsafe fn eri_scalar_classic_ema(
2100 high: &[f64],
2101 low: &[f64],
2102 source: &[f64],
2103 period: usize,
2104 first_valid_idx: usize,
2105 bull: &mut [f64],
2106 bear: &mut [f64],
2107) -> Result<(), EriError> {
2108 let start_idx = first_valid_idx + period - 1;
2109 let alpha = 2.0 / (period as f64 + 1.0);
2110 let beta = 1.0 - alpha;
2111
2112 let mut sum = 0.0;
2113 for i in 0..period {
2114 sum += source[first_valid_idx + i];
2115 }
2116 let mut ema = sum / period as f64;
2117
2118 bull[start_idx] = high[start_idx] - ema;
2119 bear[start_idx] = low[start_idx] - ema;
2120
2121 for i in (start_idx + 1)..source.len() {
2122 ema = alpha * source[i] + beta * ema;
2123
2124 bull[i] = high[i] - ema;
2125 bear[i] = low[i] - ema;
2126 }
2127
2128 Ok(())
2129}
2130
2131#[cfg(test)]
2132mod tests {
2133 use super::*;
2134 use crate::skip_if_unsupported;
2135 use crate::utilities::data_loader::read_candles_from_csv;
2136 use crate::utilities::enums::Kernel;
2137
2138 fn check_eri_partial_params(
2139 test_name: &str,
2140 kernel: Kernel,
2141 ) -> Result<(), Box<dyn std::error::Error>> {
2142 skip_if_unsupported!(kernel, test_name);
2143 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2144 let candles = read_candles_from_csv(file_path)?;
2145
2146 let default_params = EriParams {
2147 period: None,
2148 ma_type: None,
2149 };
2150 let input_default = EriInput::from_candles(&candles, "close", default_params);
2151 let output_default = eri_with_kernel(&input_default, kernel)?;
2152 assert_eq!(output_default.bull.len(), candles.close.len());
2153 assert_eq!(output_default.bear.len(), candles.close.len());
2154
2155 let params_period_14 = EriParams {
2156 period: Some(14),
2157 ma_type: Some("ema".to_string()),
2158 };
2159 let input_period_14 = EriInput::from_candles(&candles, "hl2", params_period_14);
2160 let output_period_14 = eri_with_kernel(&input_period_14, kernel)?;
2161 assert_eq!(output_period_14.bull.len(), candles.close.len());
2162 assert_eq!(output_period_14.bear.len(), candles.close.len());
2163
2164 let params_custom = EriParams {
2165 period: Some(20),
2166 ma_type: Some("sma".to_string()),
2167 };
2168 let input_custom = EriInput::from_candles(&candles, "hlc3", params_custom);
2169 let output_custom = eri_with_kernel(&input_custom, kernel)?;
2170 assert_eq!(output_custom.bull.len(), candles.close.len());
2171 assert_eq!(output_custom.bear.len(), candles.close.len());
2172
2173 Ok(())
2174 }
2175
2176 fn check_eri_accuracy(
2177 test_name: &str,
2178 kernel: Kernel,
2179 ) -> Result<(), Box<dyn std::error::Error>> {
2180 skip_if_unsupported!(kernel, test_name);
2181 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2182 let candles = read_candles_from_csv(file_path)?;
2183 let close_prices = candles
2184 .select_candle_field("close")
2185 .expect("Failed to extract close prices");
2186
2187 let params = EriParams {
2188 period: Some(13),
2189 ma_type: Some("ema".to_string()),
2190 };
2191 let input = EriInput::from_candles(&candles, "close", params);
2192 let eri_result = eri_with_kernel(&input, kernel)?;
2193
2194 assert_eq!(eri_result.bull.len(), close_prices.len());
2195 assert_eq!(eri_result.bear.len(), close_prices.len());
2196
2197 let expected_bull_last_five = [
2198 -103.35343557205488,
2199 6.839912366813223,
2200 -42.851503685589705,
2201 -9.444146016219747,
2202 11.476446271808527,
2203 ];
2204 let expected_bear_last_five = [
2205 -433.3534355720549,
2206 -314.1600876331868,
2207 -414.8515036855897,
2208 -336.44414601621975,
2209 -925.5235537281915,
2210 ];
2211
2212 let start_index = eri_result.bull.len() - 5;
2213 for i in 0..5 {
2214 let actual_bull = eri_result.bull[start_index + i];
2215 let actual_bear = eri_result.bear[start_index + i];
2216 let expected_bull = expected_bull_last_five[i];
2217 let expected_bear = expected_bear_last_five[i];
2218 assert!(
2219 (actual_bull - expected_bull).abs() < 1e-2,
2220 "ERI bull mismatch at index {}: expected {}, got {}",
2221 i,
2222 expected_bull,
2223 actual_bull
2224 );
2225 assert!(
2226 (actual_bear - expected_bear).abs() < 1e-2,
2227 "ERI bear mismatch at index {}: expected {}, got {}",
2228 i,
2229 expected_bear,
2230 actual_bear
2231 );
2232 }
2233 Ok(())
2234 }
2235
2236 fn check_eri_default_candles(
2237 test_name: &str,
2238 kernel: Kernel,
2239 ) -> Result<(), Box<dyn std::error::Error>> {
2240 skip_if_unsupported!(kernel, test_name);
2241 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2242 let candles = read_candles_from_csv(file_path)?;
2243
2244 let input = EriInput::with_default_candles(&candles);
2245 match input.data {
2246 EriData::Candles { source, .. } => assert_eq!(source, "close"),
2247 _ => panic!("Expected EriData::Candles"),
2248 }
2249 let output = eri_with_kernel(&input, kernel)?;
2250 assert_eq!(output.bull.len(), candles.close.len());
2251 assert_eq!(output.bear.len(), candles.close.len());
2252
2253 Ok(())
2254 }
2255
2256 fn check_eri_zero_period(
2257 test_name: &str,
2258 kernel: Kernel,
2259 ) -> Result<(), Box<dyn std::error::Error>> {
2260 skip_if_unsupported!(kernel, test_name);
2261 let high = [10.0, 20.0, 30.0];
2262 let low = [8.0, 18.0, 28.0];
2263 let src = [9.0, 19.0, 29.0];
2264 let params = EriParams {
2265 period: Some(0),
2266 ma_type: Some("ema".to_string()),
2267 };
2268 let input = EriInput::from_slices(&high, &low, &src, params);
2269 let res = eri_with_kernel(&input, kernel);
2270 assert!(
2271 res.is_err(),
2272 "[{}] ERI should fail with zero period",
2273 test_name
2274 );
2275 Ok(())
2276 }
2277
2278 fn check_eri_period_exceeds_length(
2279 test_name: &str,
2280 kernel: Kernel,
2281 ) -> Result<(), Box<dyn std::error::Error>> {
2282 skip_if_unsupported!(kernel, test_name);
2283 let high = [10.0, 20.0, 30.0];
2284 let low = [8.0, 18.0, 28.0];
2285 let src = [9.0, 19.0, 29.0];
2286 let params = EriParams {
2287 period: Some(10),
2288 ma_type: Some("ema".to_string()),
2289 };
2290 let input = EriInput::from_slices(&high, &low, &src, params);
2291 let res = eri_with_kernel(&input, kernel);
2292 assert!(
2293 res.is_err(),
2294 "[{}] ERI should fail with period exceeding length",
2295 test_name
2296 );
2297 Ok(())
2298 }
2299
2300 fn check_eri_very_small_dataset(
2301 test_name: &str,
2302 kernel: Kernel,
2303 ) -> Result<(), Box<dyn std::error::Error>> {
2304 skip_if_unsupported!(kernel, test_name);
2305 let high = [42.0];
2306 let low = [40.0];
2307 let src = [41.0];
2308 let params = EriParams {
2309 period: Some(9),
2310 ma_type: Some("ema".to_string()),
2311 };
2312 let input = EriInput::from_slices(&high, &low, &src, params);
2313 let res = eri_with_kernel(&input, kernel);
2314 assert!(
2315 res.is_err(),
2316 "[{}] ERI should fail with insufficient data",
2317 test_name
2318 );
2319 Ok(())
2320 }
2321
2322 fn check_eri_reinput(
2323 test_name: &str,
2324 kernel: Kernel,
2325 ) -> Result<(), Box<dyn std::error::Error>> {
2326 skip_if_unsupported!(kernel, test_name);
2327 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2328 let candles = read_candles_from_csv(file_path)?;
2329
2330 let first_params = EriParams {
2331 period: Some(14),
2332 ma_type: Some("ema".to_string()),
2333 };
2334 let first_input = EriInput::from_candles(&candles, "close", first_params);
2335 let first_result = eri_with_kernel(&first_input, kernel)?;
2336
2337 assert_eq!(first_result.bull.len(), candles.close.len());
2338 assert_eq!(first_result.bear.len(), candles.close.len());
2339
2340 let second_params = EriParams {
2341 period: Some(14),
2342 ma_type: Some("ema".to_string()),
2343 };
2344 let second_input = EriInput::from_slices(
2345 &first_result.bull,
2346 &first_result.bear,
2347 &first_result.bull,
2348 second_params,
2349 );
2350 let second_result = eri_with_kernel(&second_input, kernel)?;
2351
2352 assert_eq!(second_result.bull.len(), first_result.bull.len());
2353 assert_eq!(second_result.bear.len(), first_result.bear.len());
2354
2355 for i in 28..second_result.bull.len() {
2356 assert!(
2357 !second_result.bull[i].is_nan(),
2358 "Expected no NaN in bull after index 28, but found NaN at index {}",
2359 i
2360 );
2361 assert!(
2362 !second_result.bear[i].is_nan(),
2363 "Expected no NaN in bear after index 28, but found NaN at index {}",
2364 i
2365 );
2366 }
2367 Ok(())
2368 }
2369
2370 fn check_eri_nan_handling(
2371 test_name: &str,
2372 kernel: Kernel,
2373 ) -> Result<(), Box<dyn std::error::Error>> {
2374 skip_if_unsupported!(kernel, test_name);
2375 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2376 let candles = read_candles_from_csv(file_path)?;
2377
2378 let input = EriInput::from_candles(
2379 &candles,
2380 "close",
2381 EriParams {
2382 period: Some(13),
2383 ma_type: Some("ema".to_string()),
2384 },
2385 );
2386 let res = eri_with_kernel(&input, kernel)?;
2387 assert_eq!(res.bull.len(), candles.close.len());
2388 if res.bull.len() > 240 {
2389 for (i, &val) in res.bull[240..].iter().enumerate() {
2390 assert!(
2391 !val.is_nan(),
2392 "[{}] Found unexpected NaN at bull-index {}",
2393 test_name,
2394 240 + i
2395 );
2396 }
2397 }
2398 if res.bear.len() > 240 {
2399 for (i, &val) in res.bear[240..].iter().enumerate() {
2400 assert!(
2401 !val.is_nan(),
2402 "[{}] Found unexpected NaN at bear-index {}",
2403 test_name,
2404 240 + i
2405 );
2406 }
2407 }
2408 Ok(())
2409 }
2410
2411 #[cfg(debug_assertions)]
2412 fn check_eri_no_poison(
2413 test_name: &str,
2414 kernel: Kernel,
2415 ) -> Result<(), Box<dyn std::error::Error>> {
2416 skip_if_unsupported!(kernel, test_name);
2417
2418 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2419 let candles = read_candles_from_csv(file_path)?;
2420
2421 let test_params = vec![
2422 EriParams::default(),
2423 EriParams {
2424 period: Some(2),
2425 ma_type: Some("ema".to_string()),
2426 },
2427 EriParams {
2428 period: Some(5),
2429 ma_type: Some("sma".to_string()),
2430 },
2431 EriParams {
2432 period: Some(7),
2433 ma_type: Some("ema".to_string()),
2434 },
2435 EriParams {
2436 period: Some(10),
2437 ma_type: Some("wma".to_string()),
2438 },
2439 EriParams {
2440 period: Some(13),
2441 ma_type: Some("ema".to_string()),
2442 },
2443 EriParams {
2444 period: Some(20),
2445 ma_type: Some("sma".to_string()),
2446 },
2447 EriParams {
2448 period: Some(30),
2449 ma_type: Some("ema".to_string()),
2450 },
2451 EriParams {
2452 period: Some(50),
2453 ma_type: Some("sma".to_string()),
2454 },
2455 EriParams {
2456 period: Some(100),
2457 ma_type: Some("ema".to_string()),
2458 },
2459 EriParams {
2460 period: Some(3),
2461 ma_type: Some("hma".to_string()),
2462 },
2463 EriParams {
2464 period: Some(21),
2465 ma_type: Some("dema".to_string()),
2466 },
2467 EriParams {
2468 period: Some(14),
2469 ma_type: Some("tema".to_string()),
2470 },
2471 ];
2472
2473 for (param_idx, params) in test_params.iter().enumerate() {
2474 let input = EriInput::from_candles(&candles, "close", params.clone());
2475 let output = eri_with_kernel(&input, kernel)?;
2476
2477 for (i, &val) in output.bull.iter().enumerate() {
2478 if val.is_nan() {
2479 continue;
2480 }
2481
2482 let bits = val.to_bits();
2483
2484 if bits == 0x11111111_11111111 {
2485 panic!(
2486 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at bull index {} \
2487 with params: period={}, ma_type={} (param set {})",
2488 test_name,
2489 val,
2490 bits,
2491 i,
2492 params.period.unwrap_or(13),
2493 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2494 param_idx
2495 );
2496 }
2497
2498 if bits == 0x22222222_22222222 {
2499 panic!(
2500 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at bull index {} \
2501 with params: period={}, ma_type={} (param set {})",
2502 test_name,
2503 val,
2504 bits,
2505 i,
2506 params.period.unwrap_or(13),
2507 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2508 param_idx
2509 );
2510 }
2511
2512 if bits == 0x33333333_33333333 {
2513 panic!(
2514 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at bull index {} \
2515 with params: period={}, ma_type={} (param set {})",
2516 test_name,
2517 val,
2518 bits,
2519 i,
2520 params.period.unwrap_or(13),
2521 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2522 param_idx
2523 );
2524 }
2525 }
2526
2527 for (i, &val) in output.bear.iter().enumerate() {
2528 if val.is_nan() {
2529 continue;
2530 }
2531
2532 let bits = val.to_bits();
2533
2534 if bits == 0x11111111_11111111 {
2535 panic!(
2536 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at bear index {} \
2537 with params: period={}, ma_type={} (param set {})",
2538 test_name,
2539 val,
2540 bits,
2541 i,
2542 params.period.unwrap_or(13),
2543 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2544 param_idx
2545 );
2546 }
2547
2548 if bits == 0x22222222_22222222 {
2549 panic!(
2550 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at bear index {} \
2551 with params: period={}, ma_type={} (param set {})",
2552 test_name,
2553 val,
2554 bits,
2555 i,
2556 params.period.unwrap_or(13),
2557 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2558 param_idx
2559 );
2560 }
2561
2562 if bits == 0x33333333_33333333 {
2563 panic!(
2564 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at bear index {} \
2565 with params: period={}, ma_type={} (param set {})",
2566 test_name,
2567 val,
2568 bits,
2569 i,
2570 params.period.unwrap_or(13),
2571 params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2572 param_idx
2573 );
2574 }
2575 }
2576 }
2577
2578 Ok(())
2579 }
2580
#[cfg(not(debug_assertions))]
/// Non-debug stand-in for the poison scan: the real check is compiled only under
/// `debug_assertions`, so here it always succeeds.
fn check_eri_no_poison(
    _test_name: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Result::Ok(())
}
2588
// Expands each listed check function into per-kernel `#[test]`s:
// `<name>_scalar_f64` always, plus `<name>_avx2_f64` / `<name>_avx512_f64` when the
// `nightly-avx` feature is enabled on x86_64.
//
// Each generated test returns the check's `Result`, so an `Err` fails the test instead
// of being discarded. The AVX `#[cfg]` is repeated on every generated fn because an
// outer attribute placed before a `$( … )*` repetition only attaches to the first item
// it expands to.
macro_rules! generate_all_eri_tests {
    ($($test_fn:ident),*) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    };
}
2612
2613 generate_all_eri_tests!(
2614 check_eri_partial_params,
2615 check_eri_accuracy,
2616 check_eri_default_candles,
2617 check_eri_zero_period,
2618 check_eri_period_exceeds_length,
2619 check_eri_very_small_dataset,
2620 check_eri_reinput,
2621 check_eri_nan_handling,
2622 check_eri_no_poison
2623 );
2624
2625 #[cfg(test)]
2626 generate_all_eri_tests!(check_eri_property);
2627
2628 fn check_batch_default_row(
2629 test: &str,
2630 kernel: Kernel,
2631 ) -> Result<(), Box<dyn std::error::Error>> {
2632 skip_if_unsupported!(kernel, test);
2633
2634 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2635 let c = read_candles_from_csv(file)?;
2636
2637 let high = c.select_candle_field("high").unwrap();
2638 let low = c.select_candle_field("low").unwrap();
2639 let src = c.select_candle_field("close").unwrap();
2640
2641 let output = EriBatchBuilder::new()
2642 .kernel(kernel)
2643 .period_static(13)
2644 .apply_slices(high, low, src)?;
2645
2646 let def = EriParams::default();
2647 let row = output.values_for_bull(&def).expect("default row missing");
2648
2649 assert_eq!(row.len(), c.close.len());
2650
2651 let expected = [
2652 -103.35343557205488,
2653 6.839912366813223,
2654 -42.851503685589705,
2655 -9.444146016219747,
2656 11.476446271808527,
2657 ];
2658 let start = row.len() - 5;
2659 for (i, &v) in row[start..].iter().enumerate() {
2660 assert!(
2661 (v - expected[i]).abs() < 1e-2,
2662 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2663 );
2664 }
2665 Ok(())
2666 }
2667
#[test]
/// The in-place API (`eri_into`, or `eri_into_slice` on wasm) must reproduce the
/// allocating `eri` output element-for-element on a synthetic price ramp.
fn test_eri_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    // Equal when both are NaN, otherwise within a tight absolute tolerance.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
    }

    let len = 256usize;

    // Linear ramp: close climbs 0.1 per bar, high/low sit one unit above/below it.
    let src: Vec<f64> = (0..len).map(|i| 100.0 + (i as f64) * 0.1).collect();
    let high: Vec<f64> = src.iter().map(|&base| base + 1.0).collect();
    let low: Vec<f64> = src.iter().map(|&base| base - 1.0).collect();

    let input = EriInput::from_slices(
        &high,
        &low,
        &src,
        EriParams {
            period: Some(13),
            ma_type: Some("ema".to_string()),
        },
    );

    let baseline = eri(&input)?;

    let (mut bull, mut bear) = (vec![0.0; len], vec![0.0; len]);
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    {
        eri_into(&input, &mut bull, &mut bear)?;
    }
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    {
        eri_into_slice(&mut bull, &mut bear, &input, Kernel::Auto)?;
    }

    assert_eq!(baseline.bull.len(), bull.len());
    assert_eq!(baseline.bear.len(), bear.len());

    let bull_pairs = baseline.bull.iter().zip(bull.iter());
    let bear_pairs = baseline.bear.iter().zip(bear.iter());
    for (i, ((&want_bull, &got_bull), (&want_bear, &got_bear))) in
        bull_pairs.zip(bear_pairs).enumerate()
    {
        assert!(
            eq_or_both_nan(want_bull, got_bull),
            "bull mismatch at index {i}: baseline={} into={}",
            want_bull,
            got_bull
        );
        assert!(
            eq_or_both_nan(want_bear, got_bear),
            "bear mismatch at index {i}: baseline={} into={}",
            want_bear,
            got_bear
        );
    }

    Ok(())
}
2724
2725 #[cfg(debug_assertions)]
2726 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2727 skip_if_unsupported!(kernel, test);
2728
2729 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2730 let c = read_candles_from_csv(file)?;
2731
2732 let high = c.select_candle_field("high").unwrap();
2733 let low = c.select_candle_field("low").unwrap();
2734 let src = c.select_candle_field("close").unwrap();
2735
2736 let test_configs = vec![
2737 (2, 10, 2),
2738 (5, 25, 5),
2739 (30, 60, 15),
2740 (2, 5, 1),
2741 (10, 20, 2),
2742 (20, 50, 10),
2743 (13, 13, 0),
2744 ];
2745
2746 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2747 let output = EriBatchBuilder::new()
2748 .kernel(kernel)
2749 .period_range(p_start, p_end, p_step)
2750 .apply_slices(high, low, src)?;
2751
2752 for (idx, &val) in output.bull.iter().enumerate() {
2753 if val.is_nan() {
2754 continue;
2755 }
2756
2757 let bits = val.to_bits();
2758 let row = idx / output.cols;
2759 let col = idx % output.cols;
2760 let combo = &output.params[row];
2761
2762 if bits == 0x11111111_11111111 {
2763 panic!(
2764 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2765 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2766 test,
2767 cfg_idx,
2768 val,
2769 bits,
2770 row,
2771 col,
2772 idx,
2773 combo.period.unwrap_or(13),
2774 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2775 );
2776 }
2777
2778 if bits == 0x22222222_22222222 {
2779 panic!(
2780 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2781 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2782 test,
2783 cfg_idx,
2784 val,
2785 bits,
2786 row,
2787 col,
2788 idx,
2789 combo.period.unwrap_or(13),
2790 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2791 );
2792 }
2793
2794 if bits == 0x33333333_33333333 {
2795 panic!(
2796 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2797 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2798 test,
2799 cfg_idx,
2800 val,
2801 bits,
2802 row,
2803 col,
2804 idx,
2805 combo.period.unwrap_or(13),
2806 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2807 );
2808 }
2809 }
2810
2811 for (idx, &val) in output.bear.iter().enumerate() {
2812 if val.is_nan() {
2813 continue;
2814 }
2815
2816 let bits = val.to_bits();
2817 let row = idx / output.cols;
2818 let col = idx % output.cols;
2819 let combo = &output.params[row];
2820
2821 if bits == 0x11111111_11111111 {
2822 panic!(
2823 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2824 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2825 test,
2826 cfg_idx,
2827 val,
2828 bits,
2829 row,
2830 col,
2831 idx,
2832 combo.period.unwrap_or(13),
2833 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2834 );
2835 }
2836
2837 if bits == 0x22222222_22222222 {
2838 panic!(
2839 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2840 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2841 test,
2842 cfg_idx,
2843 val,
2844 bits,
2845 row,
2846 col,
2847 idx,
2848 combo.period.unwrap_or(13),
2849 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2850 );
2851 }
2852
2853 if bits == 0x33333333_33333333 {
2854 panic!(
2855 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2856 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2857 test,
2858 cfg_idx,
2859 val,
2860 bits,
2861 row,
2862 col,
2863 idx,
2864 combo.period.unwrap_or(13),
2865 combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2866 );
2867 }
2868 }
2869 }
2870
2871 Ok(())
2872 }
2873
#[cfg(not(debug_assertions))]
/// Non-debug stand-in for the batch poison scan: the real check is compiled only
/// under `debug_assertions`, so here it always succeeds.
fn check_batch_no_poison(
    _test: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Result::Ok(())
}
2881
/// Property-based consistency check for ERI on synthetic OHLC series.
///
/// Each generated case has a period in 2..=50, an MA type of `ema`/`sma`/`wma`, and at
/// least 20 bars beyond the period. The test runs ERI with `kernel` and with
/// `Kernel::Scalar`, then asserts:
/// - NaN throughout the warmup prefix, and no NaN after it;
/// - `bear <= bull`;
/// - agreement with the scalar reference (<= 1e-9 absolute or <= 4 ULP; non-finite
///   values must be bit-identical);
/// - `bull - bear == high - low`;
/// - both outputs bounded by twice the overall high/low range.
///
/// Cases failing with a "Not enough data" MA error or `NotEnoughValidData` are skipped.
/// Any other error panics.
#[cfg(test)]
#[allow(clippy::float_cmp)]
fn check_eri_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    // Pick the period first so the data length can be drawn relative to it
    // (always at least period + 20 bars).
    let strat = (2usize..=50)
        .prop_flat_map(|period| {
            (
                100.0f64..5000.0f64,
                (period + 20)..400,
                0.001f64..0.05f64,
                -0.01f64..0.01f64,
                Just(period),
                prop::sample::select(vec!["ema", "sma", "wma"]),
            )
        })
        .prop_map(
            |(base_price, data_len, volatility, trend, period, ma_type)| {
                let mut high = Vec::with_capacity(data_len);
                let mut low = Vec::with_capacity(data_len);
                let mut close = Vec::with_capacity(data_len);

                let mut price = base_price;
                for i in 0..data_len {
                    // Compounding drift whose per-bar rate grows slowly with i.
                    price *= 1.0 + trend + (i as f64 * 0.0001 * trend);
                    let daily_vol = volatility * price;

                    // Close oscillates around the drifting price. High/low are offset
                    // from close by between 0.5x and 1x daily_vol.
                    let c = price + daily_vol * ((i as f64).sin() * 0.3);
                    let h = c + daily_vol * (0.5 + (i as f64 * 0.7).cos().abs() * 0.5);
                    let l = c - daily_vol * (0.5 + (i as f64 * 0.7).sin().abs() * 0.5);

                    high.push(h);
                    // Clamp so low never exceeds close.
                    low.push(l.min(c));
                    close.push(c);
                }

                (high, low, close, period, ma_type.to_string())
            },
        );

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(high, low, close, period, ma_type)| {
            let params = EriParams {
                period: Some(period),
                ma_type: Some(ma_type.clone()),
            };
            let input = EriInput::from_slices(&high, &low, &close, params.clone());

            // Insufficient-data errors end the case quietly; any other error is a failure.
            let result = match eri_with_kernel(&input, kernel) {
                Ok(r) => r,
                Err(e) => match e {
                    EriError::MaCalculationError(msg) if msg.contains("Not enough data") => {
                        return Ok(())
                    }
                    EriError::NotEnoughValidData { .. } => return Ok(()),
                    _ => panic!("Unexpected error type: {:?}", e),
                },
            };

            // Scalar-kernel output is the reference for cross-kernel agreement.
            let reference = match eri_with_kernel(&input, Kernel::Scalar) {
                Ok(r) => r,
                Err(_) => return Ok(()),
            };

            // First bar where high, low and close are all non-NaN (0 if none). Outputs are
            // expected to be NaN before `warmup_period` and defined from it onward.
            let first_valid_idx = high
                .iter()
                .zip(low.iter())
                .zip(close.iter())
                .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
                .unwrap_or(0);
            let warmup_period = first_valid_idx + period - 1;

            // Warmup prefix must be NaN in both series.
            for i in 0..warmup_period.min(high.len()) {
                prop_assert!(
                    result.bull[i].is_nan(),
                    "[{}] Expected NaN in bull warmup at index {}, got {}",
                    test_name,
                    i,
                    result.bull[i]
                );
                prop_assert!(
                    result.bear[i].is_nan(),
                    "[{}] Expected NaN in bear warmup at index {}, got {}",
                    test_name,
                    i,
                    result.bear[i]
                );
            }

            // After warmup, valid inputs must yield defined outputs.
            for i in warmup_period..high.len() {
                if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
                    prop_assert!(
                        !result.bull[i].is_nan(),
                        "[{}] Unexpected NaN in bull at index {} after warmup",
                        test_name,
                        i
                    );
                    prop_assert!(
                        !result.bear[i].is_nan(),
                        "[{}] Unexpected NaN in bear at index {} after warmup",
                        test_name,
                        i
                    );
                }
            }

            // Ordering: bear may not exceed bull (small slack for rounding).
            for i in warmup_period..high.len() {
                if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                    prop_assert!(
                        result.bear[i] <= result.bull[i] + 1e-9,
                        "[{}] Bear {} > Bull {} at index {} (low={}, high={})",
                        test_name,
                        result.bear[i],
                        result.bull[i],
                        i,
                        low[i],
                        high[i]
                    );
                }
            }

            // Cross-kernel agreement: finite values within 1e-9 absolute or 4 ULP,
            // non-finite values bit-identical.
            for i in 0..high.len() {
                let bull_val = result.bull[i];
                let bear_val = result.bear[i];
                let ref_bull = reference.bull[i];
                let ref_bear = reference.bear[i];

                if bull_val.is_finite() && ref_bull.is_finite() {
                    let bull_diff = (bull_val - ref_bull).abs();
                    let bull_ulp = bull_val.to_bits().abs_diff(ref_bull.to_bits());
                    prop_assert!(
                        bull_diff <= 1e-9 || bull_ulp <= 4,
                        "[{}] Bull kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
                        test_name,
                        i,
                        bull_val,
                        ref_bull,
                        bull_diff,
                        bull_ulp
                    );
                } else {
                    prop_assert_eq!(
                        bull_val.to_bits(),
                        ref_bull.to_bits(),
                        "[{}] Bull NaN/Inf mismatch at index {}: {} vs {}",
                        test_name,
                        i,
                        bull_val,
                        ref_bull
                    );
                }

                if bear_val.is_finite() && ref_bear.is_finite() {
                    let bear_diff = (bear_val - ref_bear).abs();
                    let bear_ulp = bear_val.to_bits().abs_diff(ref_bear.to_bits());
                    prop_assert!(
                        bear_diff <= 1e-9 || bear_ulp <= 4,
                        "[{}] Bear kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
                        test_name,
                        i,
                        bear_val,
                        ref_bear,
                        bear_diff,
                        bear_ulp
                    );
                } else {
                    prop_assert_eq!(
                        bear_val.to_bits(),
                        ref_bear.to_bits(),
                        "[{}] Bear NaN/Inf mismatch at index {}: {} vs {}",
                        test_name,
                        i,
                        bear_val,
                        ref_bear
                    );
                }
            }

            // Constant series: bull/bear should settle at high - close / low - close.
            // NOTE(review): the generator above always varies close with sin(i), so this
            // branch appears unreachable as written — confirm, or add a constant-series case.
            let all_high_same = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            let all_low_same = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
            let all_close_same = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

            if all_high_same
                && all_low_same
                && all_close_same
                && high.len() > warmup_period + 2 * period
            {
                let expected_bull = high[0] - close[0];
                let expected_bear = low[0] - close[0];

                for i in (warmup_period + 2 * period)..high.len() {
                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                        prop_assert!(
                            (result.bull[i] - expected_bull).abs() < 1e-6,
                            "[{}] Constant price: Bull {} != expected {} at index {}",
                            test_name,
                            result.bull[i],
                            expected_bull,
                            i
                        );
                        prop_assert!(
                            (result.bear[i] - expected_bear).abs() < 1e-6,
                            "[{}] Constant price: Bear {} != expected {} at index {}",
                            test_name,
                            result.bear[i],
                            expected_bear,
                            i
                        );
                    }
                }
            }

            // Both series are measured against the same MA, so their difference is high - low.
            for i in warmup_period..high.len() {
                if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                    let expected_diff = high[i] - low[i];
                    let actual_diff = result.bull[i] - result.bear[i];
                    prop_assert!(
                        (actual_diff - expected_diff).abs() < 1e-9,
                        "[{}] Bull - Bear != High - Low at index {}: {} vs {}",
                        test_name,
                        i,
                        actual_diff,
                        expected_diff
                    );
                }
            }

            // With bear <= bull, a negative bull paired with a positive bear is impossible.
            for i in warmup_period..high.len() {
                if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                    if result.bull[i] < 0.0 && result.bear[i] > 0.0 {
                        prop_assert!(
                            false,
                            "[{}] Impossible state: bull {} < 0 but bear {} > 0 at index {} (high={}, low={})",
                            test_name, result.bull[i], result.bear[i], i, high[i], low[i]
                        );
                    }
                }
            }

            // NOTE(review): the strategy draws period from 2..=50, so this period == 1
            // branch appears unreachable — confirm whether period 1 should be generated.
            if period == 1 {
                for i in warmup_period..high.len().min(warmup_period + 10) {
                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                        let expected_diff = high[i] - low[i];
                        let actual_diff = result.bull[i] - result.bear[i];
                        prop_assert!(
                            (actual_diff - expected_diff).abs() < 1e-6,
                            "[{}] Period=1: Bull-Bear mismatch at index {}: {} vs {}",
                            test_name,
                            i,
                            actual_diff,
                            expected_diff
                        );
                    }
                }
            }

            // Loose sanity bound: |bull| and |bear| at most twice the global high/low range.
            let max_price = high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let min_price = low.iter().cloned().fold(f64::INFINITY, f64::min);
            let price_range = max_price - min_price;

            for i in warmup_period..high.len() {
                if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
                    prop_assert!(
                        result.bull[i].abs() <= price_range * 2.0,
                        "[{}] Bull {} exceeds reasonable bounds (price range: {}) at index {}",
                        test_name,
                        result.bull[i],
                        price_range,
                        i
                    );
                    prop_assert!(
                        result.bear[i].abs() <= price_range * 2.0,
                        "[{}] Bear {} exceeds reasonable bounds (price range: {}) at index {}",
                        test_name,
                        result.bear[i],
                        price_range,
                        i
                    );
                }
            }

            // NOTE(review): data_len is drawn from (period + 20)..400, so
            // `period >= high.len() - 5` never holds and this branch appears unreachable.
            if period >= high.len() - 5 && period < high.len() {
                let valid_count = result
                    .bull
                    .iter()
                    .zip(result.bear.iter())
                    .filter(|(b, r)| !b.is_nan() && !r.is_nan())
                    .count();
                prop_assert!(
                    valid_count >= 1,
                    "[{}] No valid values with period {} and data_len {}",
                    test_name,
                    period,
                    high.len()
                );
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
3189
// Expands a batch check function into `#[test]`s for the scalar batch kernel and the
// auto-detected kernel. The AVX2/AVX-512 batch variants are added only with the
// `nightly-avx` feature on x86_64.
//
// Each generated test returns the check's `Result`, so an `Err` (e.g. a missing data
// file or a batch error propagated with `?`) fails the test instead of being discarded.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }

            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
3210 gen_batch_tests!(check_batch_default_row);
3211 gen_batch_tests!(check_batch_no_poison);
3212}