Skip to main content

vector_ta/indicators/
eri.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::indicators::moving_averages::ma::{ma, MaData};
16use crate::utilities::data_loader::{source_type, Candles};
17use crate::utilities::enums::Kernel;
18use crate::utilities::helpers::{
19    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
20    make_uninit_matrix,
21};
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::convert::AsRef;
29use thiserror::Error;
30
31#[derive(Debug, Clone)]
32pub enum EriData<'a> {
33    Candles {
34        candles: &'a Candles,
35        source: &'a str,
36    },
37    Slices {
38        high: &'a [f64],
39        low: &'a [f64],
40        source: &'a [f64],
41    },
42}
43
44impl<'a> AsRef<[f64]> for EriInput<'a> {
45    #[inline(always)]
46    fn as_ref(&self) -> &[f64] {
47        match &self.data {
48            EriData::Candles { candles, source } => source_type(candles, source),
49            EriData::Slices { source, .. } => source,
50        }
51    }
52}
53
/// Result of an ERI run: one value per input index.
///
/// Entries before the warmup index are NaN.
#[derive(Clone, Debug)]
pub struct EriOutput {
    /// `high - ma` at each index.
    pub bull: Vec<f64>,
    /// `low - ma` at each index.
    pub bear: Vec<f64>,
}
59
/// Tunable ERI parameters; `None` fields fall back to the defaults below.
#[derive(Clone, Debug)]
pub struct EriParams {
    /// Moving-average length (default 13).
    pub period: Option<usize>,
    /// Moving-average kind, e.g. "ema" or "sma" (default "ema").
    pub ma_type: Option<String>,
}

impl Default for EriParams {
    /// Period 13 with an "ema" moving average.
    fn default() -> Self {
        let period = Some(13);
        let ma_type = Some(String::from("ema"));
        EriParams { period, ma_type }
    }
}
74
75#[derive(Debug, Clone)]
76pub struct EriInput<'a> {
77    pub data: EriData<'a>,
78    pub params: EriParams,
79}
80
81impl<'a> EriInput<'a> {
82    #[inline]
83    pub fn from_candles(candles: &'a Candles, source: &'a str, params: EriParams) -> Self {
84        Self {
85            data: EriData::Candles { candles, source },
86            params,
87        }
88    }
89    #[inline]
90    pub fn from_slices(
91        high: &'a [f64],
92        low: &'a [f64],
93        source: &'a [f64],
94        params: EriParams,
95    ) -> Self {
96        Self {
97            data: EriData::Slices { high, low, source },
98            params,
99        }
100    }
101    #[inline]
102    pub fn with_default_candles(candles: &'a Candles) -> Self {
103        Self::from_candles(candles, "close", EriParams::default())
104    }
105    #[inline]
106    pub fn get_period(&self) -> usize {
107        self.params.period.unwrap_or(13)
108    }
109    #[inline]
110    pub fn get_ma_type(&self) -> &str {
111        self.params.ma_type.as_deref().unwrap_or("ema")
112    }
113}
114
115#[derive(Clone, Debug)]
116pub struct EriBuilder {
117    period: Option<usize>,
118    ma_type: Option<String>,
119    kernel: Kernel,
120}
121
122impl Default for EriBuilder {
123    fn default() -> Self {
124        Self {
125            period: None,
126            ma_type: None,
127            kernel: Kernel::Auto,
128        }
129    }
130}
131
132impl EriBuilder {
133    #[inline(always)]
134    pub fn new() -> Self {
135        Self::default()
136    }
137    #[inline(always)]
138    pub fn period(mut self, n: usize) -> Self {
139        self.period = Some(n);
140        self
141    }
142    #[inline(always)]
143    pub fn ma_type<S: Into<String>>(mut self, t: S) -> Self {
144        self.ma_type = Some(t.into());
145        self
146    }
147    #[inline(always)]
148    pub fn kernel(mut self, k: Kernel) -> Self {
149        self.kernel = k;
150        self
151    }
152    #[inline(always)]
153    pub fn apply(self, c: &Candles) -> Result<EriOutput, EriError> {
154        let p = EriParams {
155            period: self.period,
156            ma_type: self.ma_type,
157        };
158        let i = EriInput::from_candles(c, "close", p);
159        eri_with_kernel(&i, self.kernel)
160    }
161    #[inline(always)]
162    pub fn apply_slices(
163        self,
164        high: &[f64],
165        low: &[f64],
166        src: &[f64],
167    ) -> Result<EriOutput, EriError> {
168        let p = EriParams {
169            period: self.period,
170            ma_type: self.ma_type,
171        };
172        let i = EriInput::from_slices(high, low, src, p);
173        eri_with_kernel(&i, self.kernel)
174    }
175    #[inline(always)]
176    pub fn into_stream(self) -> Result<EriStream, EriError> {
177        let p = EriParams {
178            period: self.period,
179            ma_type: self.ma_type,
180        };
181        EriStream::try_new(p)
182    }
183}
184
185#[derive(Debug, Error)]
186pub enum EriError {
187    #[error("eri: All input values are NaN.")]
188    AllValuesNaN,
189    #[error("eri: Invalid period: period = {period}, data length = {data_len}")]
190    InvalidPeriod { period: usize, data_len: usize },
191    #[error("eri: Not enough valid data: needed = {needed}, valid = {valid}")]
192    NotEnoughValidData { needed: usize, valid: usize },
193    #[error("eri: MA calculation error: {0}")]
194    MaCalculationError(String),
195    #[error("eri: Empty data provided.")]
196    EmptyInputData,
197    #[error("eri: Output slice length mismatch: expected = {expected}, got = {got}")]
198    OutputLengthMismatch { expected: usize, got: usize },
199    #[error("eri: Invalid range expansion: start={start}, end={end}, step={step}")]
200    InvalidRange {
201        start: usize,
202        end: usize,
203        step: usize,
204    },
205    #[error("eri: Invalid kernel for batch operation. Got {0:?}")]
206    InvalidKernelForBatch(Kernel),
207}
208
209#[inline]
210pub fn eri(input: &EriInput) -> Result<EriOutput, EriError> {
211    eri_with_kernel(input, Kernel::Auto)
212}
213
214pub fn eri_with_kernel(input: &EriInput, kernel: Kernel) -> Result<EriOutput, EriError> {
215    let (high, low, source_data) = match &input.data {
216        EriData::Candles { candles, source } => {
217            let high = candles
218                .select_candle_field("high")
219                .map_err(|_| EriError::EmptyInputData)?;
220            let low = candles
221                .select_candle_field("low")
222                .map_err(|_| EriError::EmptyInputData)?;
223            let src = source_type(candles, source);
224            (high, low, src)
225        }
226        EriData::Slices { high, low, source } => (*high, *low, *source),
227    };
228
229    if source_data.is_empty() || high.is_empty() || low.is_empty() {
230        return Err(EriError::EmptyInputData);
231    }
232
233    let period = input.get_period();
234    if period == 0 || period > source_data.len() {
235        return Err(EriError::InvalidPeriod {
236            period,
237            data_len: source_data.len(),
238        });
239    }
240
241    let mut first_valid_idx = None;
242    for i in 0..source_data.len() {
243        if !(source_data[i].is_nan() || high[i].is_nan() || low[i].is_nan()) {
244            first_valid_idx = Some(i);
245            break;
246        }
247    }
248    let first_valid_idx = match first_valid_idx {
249        Some(idx) => idx,
250        None => return Err(EriError::AllValuesNaN),
251    };
252
253    if (source_data.len() - first_valid_idx) < period {
254        return Err(EriError::NotEnoughValidData {
255            needed: period,
256            valid: source_data.len() - first_valid_idx,
257        });
258    }
259
260    let ma_type = input.get_ma_type();
261    let warmup_period = first_valid_idx + period - 1;
262    let mut bull = alloc_with_nan_prefix(source_data.len(), warmup_period);
263    let mut bear = alloc_with_nan_prefix(source_data.len(), warmup_period);
264
265    if ma_type == "sma" || ma_type == "SMA" {
266        unsafe {
267            eri_scalar_classic_sma(
268                high,
269                low,
270                &source_data,
271                period,
272                first_valid_idx,
273                &mut bull,
274                &mut bear,
275            )?;
276        }
277        return Ok(EriOutput { bull, bear });
278    } else if ma_type == "ema" || ma_type == "EMA" {
279        unsafe {
280            eri_scalar_classic_ema(
281                high,
282                low,
283                &source_data,
284                period,
285                first_valid_idx,
286                &mut bull,
287                &mut bear,
288            )?;
289        }
290        return Ok(EriOutput { bull, bear });
291    }
292
293    let full_ma = ma(&ma_type, MaData::Slice(&source_data), period)
294        .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
295
296    let chosen = match kernel {
297        Kernel::Auto => Kernel::Scalar,
298        other => other,
299    };
300
301    unsafe {
302        match chosen {
303            Kernel::Scalar | Kernel::ScalarBatch => eri_scalar(
304                high,
305                low,
306                &full_ma,
307                period,
308                first_valid_idx,
309                &mut bull,
310                &mut bear,
311            ),
312            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
313            Kernel::Avx2 | Kernel::Avx2Batch => eri_avx2(
314                high,
315                low,
316                &full_ma,
317                period,
318                first_valid_idx,
319                &mut bull,
320                &mut bear,
321            ),
322            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
323            Kernel::Avx512 | Kernel::Avx512Batch => eri_avx512(
324                high,
325                low,
326                &full_ma,
327                period,
328                first_valid_idx,
329                &mut bull,
330                &mut bear,
331            ),
332            _ => unreachable!(),
333        }
334    }
335
336    Ok(EriOutput { bull, bear })
337}
338
/// Portable ERI kernel.
///
/// For each `i` in `first_valid + period - 1 .. high.len()`, writes
/// `bull[i] = high[i] - ma[i]` and `bear[i] = low[i] - ma[i]`. Entries before
/// that range are left untouched. Does nothing if the range is empty.
///
/// # Panics
/// If the range is non-empty and `low`, `ma`, `bull` or `bear` is shorter than `high`.
#[inline]
pub fn eri_scalar(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let start = first_valid + period - 1;
    let end = high.len();
    for idx in start..end {
        let avg = ma[idx];
        bull[idx] = high[idx] - avg;
        bear[idx] = low[idx] - avg;
    }
}
392
/// AVX-512 ERI kernel with a safe interface.
///
/// Computes the same values as `eri_scalar`. It uses the AVX-512F path only when
/// the running CPU supports it, and falls back to `eri_scalar` otherwise, so
/// calling it on older hardware is sound.
///
/// # Panics
/// On the SIMD path, if `low`, `ma`, `bull` or `bear` is shorter than `high`.
/// The vector core uses unchecked pointer arithmetic, so this is enforced up front.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn eri_avx512(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    if std::arch::is_x86_feature_detected!("avx512f") {
        let n = high.len();
        assert!(
            low.len() >= n && ma.len() >= n && bull.len() >= n && bear.len() >= n,
            "eri_avx512: low, ma, bull and bear must be at least as long as high"
        );
        eri_avx512_long(high, low, ma, period, first_valid, bull, bear)
    } else {
        eri_scalar(high, low, ma, period, first_valid, bull, bear)
    }
}
406
407#[inline]
408pub fn eri_avx2(
409    high: &[f64],
410    low: &[f64],
411    ma: &[f64],
412    period: usize,
413    first_valid: usize,
414    bull: &mut [f64],
415    bear: &mut [f64],
416) {
417    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
418    unsafe {
419        return eri_avx2_core(high, low, ma, period, first_valid, bull, bear);
420    }
421    eri_scalar(high, low, ma, period, first_valid, bull, bear)
422}
423
/// AVX-512 ERI kernel, short-period entry point (currently shares the core with
/// the long variant).
///
/// Uses the AVX-512F core only when the running CPU supports it, and falls back
/// to `eri_scalar` otherwise.
///
/// # Panics
/// On the SIMD path, if `low`, `ma`, `bull` or `bear` is shorter than `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn eri_avx512_short(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    if std::arch::is_x86_feature_detected!("avx512f") {
        let n = high.len();
        assert!(
            low.len() >= n && ma.len() >= n && bull.len() >= n && bear.len() >= n,
            "eri_avx512_short: low, ma, bull and bear must be at least as long as high"
        );
        // SAFETY: AVX-512F support was detected at runtime, and the assert above
        // keeps every index the core touches (< high.len()) inside each slice.
        unsafe { eri_avx512_core(high, low, ma, period, first_valid, bull, bear) }
    } else {
        eri_scalar(high, low, ma, period, first_valid, bull, bear)
    }
}
437
/// AVX-512 ERI kernel, long-period entry point.
///
/// Uses the AVX-512F core only when the running CPU supports it, and falls back
/// to `eri_scalar` otherwise.
///
/// # Panics
/// On the SIMD path, if `low`, `ma`, `bull` or `bear` is shorter than `high`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn eri_avx512_long(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    if std::arch::is_x86_feature_detected!("avx512f") {
        let n = high.len();
        assert!(
            low.len() >= n && ma.len() >= n && bull.len() >= n && bear.len() >= n,
            "eri_avx512_long: low, ma, bull and bear must be at least as long as high"
        );
        // SAFETY: AVX-512F support was detected at runtime, and the assert above
        // keeps every index the core touches (< high.len()) inside each slice.
        unsafe { eri_avx512_core(high, low, ma, period, first_valid, bull, bear) }
    } else {
        eri_scalar(high, low, ma, period, first_valid, bull, bear)
    }
}
451
/// AVX2 vector core: `bull = high - ma`, `bear = low - ma` over
/// `first_valid + period - 1 .. high.len()`. It processes four lanes per step,
/// then handles the tail with one two-lane SSE2 step and one scalar step.
///
/// # Safety
/// - The running CPU must support AVX2.
/// - `low`, `ma`, `bull` and `bear` must each hold at least `high.len()`
///   elements; all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn eri_avx2_core(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    use core::arch::x86_64::*;

    let start = first_valid + period - 1;
    let n = high.len();
    if start >= n {
        return;
    }

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let mp = ma.as_ptr();
    let bp = bull.as_mut_ptr();
    let rp = bear.as_mut_ptr();

    let mut i = start;

    while i + 4 <= n {
        let avg = _mm256_loadu_pd(mp.add(i));
        _mm256_storeu_pd(bp.add(i), _mm256_sub_pd(_mm256_loadu_pd(hp.add(i)), avg));
        _mm256_storeu_pd(rp.add(i), _mm256_sub_pd(_mm256_loadu_pd(lp.add(i)), avg));
        i += 4;
    }

    if i + 2 <= n {
        let avg = _mm_loadu_pd(mp.add(i));
        _mm_storeu_pd(bp.add(i), _mm_sub_pd(_mm_loadu_pd(hp.add(i)), avg));
        _mm_storeu_pd(rp.add(i), _mm_sub_pd(_mm_loadu_pd(lp.add(i)), avg));
        i += 2;
    }

    if i < n {
        let avg = *mp.add(i);
        *bp.add(i) = *hp.add(i) - avg;
        *rp.add(i) = *lp.add(i) - avg;
    }
}
524
/// AVX-512F vector core: `bull = high - ma`, `bear = low - ma` over
/// `first_valid + period - 1 .. high.len()`. It processes eight lanes per step,
/// then handles a tail of up to seven elements with one 4-lane, one 2-lane and
/// one scalar step.
///
/// # Safety
/// - The running CPU must support AVX-512F (such CPUs also provide AVX, which
///   the 4-lane tail step uses).
/// - `low`, `ma`, `bull` and `bear` must each hold at least `high.len()`
///   elements; all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn eri_avx512_core(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    period: usize,
    first_valid: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    use core::arch::x86_64::*;

    let start = first_valid + period - 1;
    let n = high.len();
    if start >= n {
        return;
    }

    let hp = high.as_ptr();
    let lp = low.as_ptr();
    let mp = ma.as_ptr();
    let bp = bull.as_mut_ptr();
    let rp = bear.as_mut_ptr();

    let mut i = start;

    while i + 8 <= n {
        let avg = _mm512_loadu_pd(mp.add(i));
        _mm512_storeu_pd(bp.add(i), _mm512_sub_pd(_mm512_loadu_pd(hp.add(i)), avg));
        _mm512_storeu_pd(rp.add(i), _mm512_sub_pd(_mm512_loadu_pd(lp.add(i)), avg));
        i += 8;
    }

    // 4-lane tail step. It is deliberately not gated on
    // `cfg(target_feature = "avx2")`: that cfg reflects the crate-wide compile
    // flags, not this function's `#[target_feature]`, so it would disable the
    // vector step in ordinary builds even though AVX-512F hardware always has AVX.
    // Element-wise subtraction gives identical results on either path.
    if i + 4 <= n {
        let avg = _mm256_loadu_pd(mp.add(i));
        _mm256_storeu_pd(bp.add(i), _mm256_sub_pd(_mm256_loadu_pd(hp.add(i)), avg));
        _mm256_storeu_pd(rp.add(i), _mm256_sub_pd(_mm256_loadu_pd(lp.add(i)), avg));
        i += 4;
    }

    if i + 2 <= n {
        let avg = _mm_loadu_pd(mp.add(i));
        _mm_storeu_pd(bp.add(i), _mm_sub_pd(_mm_loadu_pd(hp.add(i)), avg));
        _mm_storeu_pd(rp.add(i), _mm_sub_pd(_mm_loadu_pd(lp.add(i)), avg));
        i += 2;
    }

    if i < n {
        let avg = *mp.add(i);
        *bp.add(i) = *hp.add(i) - avg;
        *rp.add(i) = *lp.add(i) - avg;
    }
}
641
642#[inline]
643pub fn eri_batch_with_kernel(
644    high: &[f64],
645    low: &[f64],
646    source: &[f64],
647    sweep: &EriBatchRange,
648    k: Kernel,
649) -> Result<EriBatchOutput, EriError> {
650    let kernel = match k {
651        Kernel::Auto => detect_best_batch_kernel(),
652        other if other.is_batch() => other,
653        other => return Err(EriError::InvalidKernelForBatch(other)),
654    };
655    let simd = match kernel {
656        Kernel::Avx512Batch => Kernel::Avx512,
657        Kernel::Avx2Batch => Kernel::Avx2,
658        Kernel::ScalarBatch => Kernel::Scalar,
659        _ => unreachable!(),
660    };
661    eri_batch_par_slice(high, low, source, sweep, simd)
662}
663
/// Parameter sweep for batch ERI.
#[derive(Clone, Debug)]
pub struct EriBatchRange {
    /// `(start, end, step)` for the period; `step == 0` means `start` only.
    pub period: (usize, usize, usize),
    /// MA type shared by every row of the sweep.
    pub ma_type: String,
}

impl Default for EriBatchRange {
    /// Periods 13 through 262 in steps of 1, using "ema".
    fn default() -> Self {
        let period = (13, 262, 1);
        let ma_type = String::from("ema");
        EriBatchRange { period, ma_type }
    }
}
678
679#[derive(Clone, Debug, Default)]
680pub struct EriBatchBuilder {
681    range: EriBatchRange,
682    kernel: Kernel,
683}
684
685impl EriBatchBuilder {
686    pub fn new() -> Self {
687        Self::default()
688    }
689    pub fn kernel(mut self, k: Kernel) -> Self {
690        self.kernel = k;
691        self
692    }
693    #[inline]
694    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
695        self.range.period = (start, end, step);
696        self
697    }
698    #[inline]
699    pub fn period_static(mut self, p: usize) -> Self {
700        self.range.period = (p, p, 0);
701        self
702    }
703    pub fn apply_slices(
704        self,
705        high: &[f64],
706        low: &[f64],
707        source: &[f64],
708    ) -> Result<EriBatchOutput, EriError> {
709        eri_batch_with_kernel(high, low, source, &self.range, self.kernel)
710    }
711}
712
713#[derive(Clone, Debug)]
714pub struct EriBatchOutput {
715    pub bull: Vec<f64>,
716    pub bear: Vec<f64>,
717    pub params: Vec<EriParams>,
718    pub rows: usize,
719    pub cols: usize,
720}
721
722impl EriBatchOutput {
723    pub fn row_for_params(&self, p: &EriParams) -> Option<usize> {
724        self.params
725            .iter()
726            .position(|c| c.period == p.period && c.ma_type == p.ma_type)
727    }
728    pub fn values_for_bull(&self, p: &EriParams) -> Option<&[f64]> {
729        self.row_for_params(p).map(|row| {
730            let start = row * self.cols;
731            &self.bull[start..start + self.cols]
732        })
733    }
734    pub fn values_for_bear(&self, p: &EriParams) -> Option<&[f64]> {
735        self.row_for_params(p).map(|row| {
736            let start = row * self.cols;
737            &self.bear[start..start + self.cols]
738        })
739    }
740}
741
742#[inline(always)]
743fn expand_grid(r: &EriBatchRange) -> Result<Vec<EriParams>, EriError> {
744    let (start, end, step) = r.period;
745
746    if step == 0 {
747        return Ok(vec![EriParams {
748            period: Some(start),
749            ma_type: Some(r.ma_type.clone()),
750        }]);
751    }
752
753    let mut out: Vec<EriParams> = Vec::new();
754    if start == end {
755        out.push(EriParams {
756            period: Some(start),
757            ma_type: Some(r.ma_type.clone()),
758        });
759        return Ok(out);
760    }
761
762    if start < end {
763        let mut p = start;
764        while p <= end {
765            out.push(EriParams {
766                period: Some(p),
767                ma_type: Some(r.ma_type.clone()),
768            });
769            match p.checked_add(step) {
770                Some(next) => {
771                    if next == p {
772                        break;
773                    }
774                    p = next;
775                }
776                None => return Err(EriError::InvalidRange { start, end, step }),
777            }
778        }
779    } else {
780        let mut p = start;
781        while p >= end {
782            out.push(EriParams {
783                period: Some(p),
784                ma_type: Some(r.ma_type.clone()),
785            });
786
787            if p < step {
788                break;
789            }
790            p -= step;
791            if p == usize::MAX {
792                break;
793            }
794        }
795
796        if out.is_empty() {
797            return Err(EriError::InvalidRange { start, end, step });
798        }
799    }
800
801    if out.is_empty() {
802        return Err(EriError::InvalidRange { start, end, step });
803    }
804    Ok(out)
805}
806
807#[inline(always)]
808pub fn eri_batch_slice(
809    high: &[f64],
810    low: &[f64],
811    source: &[f64],
812    sweep: &EriBatchRange,
813    kern: Kernel,
814) -> Result<EriBatchOutput, EriError> {
815    eri_batch_inner(high, low, source, sweep, kern, false)
816}
817
818#[inline(always)]
819pub fn eri_batch_par_slice(
820    high: &[f64],
821    low: &[f64],
822    source: &[f64],
823    sweep: &EriBatchRange,
824    kern: Kernel,
825) -> Result<EriBatchOutput, EriError> {
826    eri_batch_inner(high, low, source, sweep, kern, true)
827}
828
829#[inline(always)]
830fn eri_batch_inner(
831    high: &[f64],
832    low: &[f64],
833    source: &[f64],
834    sweep: &EriBatchRange,
835    kern: Kernel,
836    parallel: bool,
837) -> Result<EriBatchOutput, EriError> {
838    let combos = expand_grid(sweep)?;
839
840    let first = high
841        .iter()
842        .zip(low.iter())
843        .zip(source.iter())
844        .position(|((h, l), s)| !h.is_nan() && !l.is_nan() && !s.is_nan())
845        .ok_or(EriError::AllValuesNaN)?;
846    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
847    if source.len() - first < max_p {
848        return Err(EriError::NotEnoughValidData {
849            needed: max_p,
850            valid: source.len() - first,
851        });
852    }
853    let rows = combos.len();
854    let cols = source.len();
855    let total = rows.checked_mul(cols).ok_or(EriError::InvalidRange {
856        start: sweep.period.0,
857        end: sweep.period.1,
858        step: sweep.period.2,
859    })?;
860
861    let mut buf_bull = make_uninit_matrix(rows, cols);
862    let mut buf_bear = make_uninit_matrix(rows, cols);
863
864    let warmup_periods: Vec<usize> = combos
865        .iter()
866        .map(|c| first + c.period.unwrap() - 1)
867        .collect();
868    init_matrix_prefixes(&mut buf_bull, cols, &warmup_periods);
869    init_matrix_prefixes(&mut buf_bear, cols, &warmup_periods);
870
871    let mut buf_bull_guard = std::mem::ManuallyDrop::new(buf_bull);
872    let mut buf_bear_guard = std::mem::ManuallyDrop::new(buf_bear);
873
874    let mut bull =
875        unsafe { std::slice::from_raw_parts_mut(buf_bull_guard.as_mut_ptr() as *mut f64, total) };
876    let mut bear =
877        unsafe { std::slice::from_raw_parts_mut(buf_bear_guard.as_mut_ptr() as *mut f64, total) };
878
879    let do_row = |row: usize, bull_row: &mut [f64], bear_row: &mut [f64]| -> Result<(), EriError> {
880        let period = combos[row].period.unwrap();
881        let ma_type = combos[row].ma_type.as_deref().unwrap();
882        let ma_vec = ma(ma_type, MaData::Slice(source), period)
883            .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
884        match kern {
885            Kernel::Scalar => unsafe {
886                eri_row_scalar(high, low, &ma_vec, first, period, bull_row, bear_row)
887            },
888            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
889            Kernel::Avx2 => unsafe {
890                eri_row_avx2(high, low, &ma_vec, first, period, bull_row, bear_row)
891            },
892            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
893            Kernel::Avx512 => unsafe {
894                eri_row_avx512(high, low, &ma_vec, first, period, bull_row, bear_row)
895            },
896            _ => unreachable!(),
897        }
898        Ok(())
899    };
900
901    if parallel {
902        #[cfg(not(target_arch = "wasm32"))]
903        {
904            bull.par_chunks_mut(cols)
905                .zip(bear.par_chunks_mut(cols))
906                .enumerate()
907                .for_each(|(row, (bull_row, bear_row))| {
908                    let _ = do_row(row, bull_row, bear_row);
909                });
910        }
911        #[cfg(target_arch = "wasm32")]
912        {
913            for (row, (bull_row, bear_row)) in
914                bull.chunks_mut(cols).zip(bear.chunks_mut(cols)).enumerate()
915            {
916                let _ = do_row(row, bull_row, bear_row);
917            }
918        }
919    } else {
920        for (row, (bull_row, bear_row)) in
921            bull.chunks_mut(cols).zip(bear.chunks_mut(cols)).enumerate()
922        {
923            let _ = do_row(row, bull_row, bear_row);
924        }
925    }
926
927    let bull_vec =
928        unsafe { Vec::from_raw_parts(buf_bull_guard.as_mut_ptr() as *mut f64, total, total) };
929    let bear_vec =
930        unsafe { Vec::from_raw_parts(buf_bear_guard.as_mut_ptr() as *mut f64, total, total) };
931
932    Ok(EriBatchOutput {
933        bull: bull_vec,
934        bear: bear_vec,
935        params: combos,
936        rows,
937        cols,
938    })
939}
940
941#[inline(always)]
942pub fn eri_batch_inner_into(
943    high: &[f64],
944    low: &[f64],
945    source: &[f64],
946    sweep: &EriBatchRange,
947    kern: Kernel,
948    parallel: bool,
949    bull_out: &mut [f64],
950    bear_out: &mut [f64],
951) -> Result<Vec<EriParams>, EriError> {
952    let combos = expand_grid(sweep)?;
953
954    let first = high
955        .iter()
956        .zip(low.iter())
957        .zip(source.iter())
958        .position(|((h, l), s)| !h.is_nan() && !l.is_nan() && !s.is_nan())
959        .ok_or(EriError::AllValuesNaN)?;
960    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
961    if source.len() - first < max_p {
962        return Err(EriError::NotEnoughValidData {
963            needed: max_p,
964            valid: source.len() - first,
965        });
966    }
967
968    let rows = combos.len();
969    let cols = source.len();
970
971    let expected = rows.checked_mul(cols).ok_or(EriError::InvalidRange {
972        start: sweep.period.0,
973        end: sweep.period.1,
974        step: sweep.period.2,
975    })?;
976    if bull_out.len() != expected || bear_out.len() != expected {
977        return Err(EriError::OutputLengthMismatch {
978            expected,
979            got: bull_out.len().max(bear_out.len()),
980        });
981    }
982
983    let do_row = |row: usize, bull_row: &mut [f64], bear_row: &mut [f64]| -> Result<(), EriError> {
984        let period = combos[row].period.unwrap();
985        let ma_type = combos[row].ma_type.as_deref().unwrap();
986        let ma_vec = ma(ma_type, MaData::Slice(source), period)
987            .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
988        match kern {
989            Kernel::Scalar => unsafe {
990                eri_row_scalar(high, low, &ma_vec, first, period, bull_row, bear_row)
991            },
992            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
993            Kernel::Avx2 => unsafe {
994                eri_row_avx2(high, low, &ma_vec, first, period, bull_row, bear_row)
995            },
996            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
997            Kernel::Avx512 => unsafe {
998                eri_row_avx512(high, low, &ma_vec, first, period, bull_row, bear_row)
999            },
1000            _ => unreachable!(),
1001        }
1002        Ok(())
1003    };
1004
1005    if parallel {
1006        #[cfg(not(target_arch = "wasm32"))]
1007        {
1008            use rayon::prelude::*;
1009            bull_out
1010                .par_chunks_mut(cols)
1011                .zip(bear_out.par_chunks_mut(cols))
1012                .enumerate()
1013                .for_each(|(row, (bull_row, bear_row))| {
1014                    let _ = do_row(row, bull_row, bear_row);
1015                });
1016        }
1017        #[cfg(target_arch = "wasm32")]
1018        for row in 0..rows {
1019            let bull_row = &mut bull_out[row * cols..(row + 1) * cols];
1020            let bear_row = &mut bear_out[row * cols..(row + 1) * cols];
1021            let _ = do_row(row, bull_row, bear_row);
1022        }
1023    } else {
1024        for row in 0..rows {
1025            let bull_row = &mut bull_out[row * cols..(row + 1) * cols];
1026            let bear_row = &mut bear_out[row * cols..(row + 1) * cols];
1027            let _ = do_row(row, bull_row, bear_row);
1028        }
1029    }
1030
1031    Ok(combos)
1032}
1033
/// Scalar row kernel used by the batch drivers.
///
/// For each `i` in `first + period - 1 .. high.len()`, writes
/// `bull[i] = high[i] - ma[i]` and `bear[i] = low[i] - ma[i]`. Entries before
/// that range are untouched. Does nothing if the range is empty.
///
/// # Safety
/// The function stays `unsafe` only so its signature matches the SIMD row
/// kernels, and it imposes no extra requirements. The body uses checked slicing
/// (hoisted once, so the loop still vectorises). A length mismatch therefore
/// panics instead of causing undefined behaviour.
///
/// # Panics
/// If the range is non-empty and `low`, `ma`, `bull` or `bear` is shorter than `high`.
#[inline(always)]
unsafe fn eri_row_scalar(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let start = first + period - 1;
    let n = high.len();
    if start >= n {
        return;
    }

    let hs = &high[start..n];
    let ls = &low[start..n];
    let ms = &ma[start..n];
    let bs = &mut bull[start..n];
    let rs = &mut bear[start..n];

    for ((((b, r), &h), &l), &m) in bs.iter_mut().zip(rs.iter_mut()).zip(hs).zip(ls).zip(ms) {
        *b = h - m;
        *r = l - m;
    }
}
1104
/// AVX2 row kernel used by the batch path.
///
/// Adapter only: row kernels receive `(first, period)`, while `eri_avx2_core` is invoked with
/// `(period, first)`, so the two are swapped when forwarding.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx2(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let (core_period, core_first) = (period, first);
    eri_avx2_core(high, low, ma, core_period, core_first, bull, bear)
}
1118
/// AVX-512 row kernel used by the batch path.
///
/// Dispatches to the short-period variant for `period <= 32` and to the long-period
/// variant otherwise.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    match period {
        0..=32 => eri_row_avx512_short(high, low, ma, first, period, bull, bear),
        _ => eri_row_avx512_long(high, low, ma, first, period, bull, bear),
    }
}
1136
/// AVX-512 row kernel for short periods (`period <= 32`, per `eri_row_avx512`).
///
/// Forwards to `eri_avx512_core`, which takes `(period, first)` in that order.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512_short(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let (core_period, core_first) = (period, first);
    eri_avx512_core(high, low, ma, core_period, core_first, bull, bear)
}
1150
/// AVX-512 row kernel for long periods (`period > 32`, per `eri_row_avx512`).
///
/// Currently forwards to the same `eri_avx512_core` as the short variant, passing
/// `(period, first)` in that order.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn eri_row_avx512_long(
    high: &[f64],
    low: &[f64],
    ma: &[f64],
    first: usize,
    period: usize,
    bull: &mut [f64],
    bear: &mut [f64],
) {
    let (core_period, core_first) = (period, first);
    eri_avx512_core(high, low, ma, core_period, core_first, bull, bear)
}
1164
/// Incremental (bar-by-bar) ERI calculator.
///
/// Each `update` feeds one bar and, once the moving-average engine has produced a value,
/// yields `(high - ma, low - ma)`.
#[derive(Debug, Clone)]
pub struct EriStream {
    // Moving-average length (defaults to 13 in `try_new`).
    period: usize,
    // MA name exactly as supplied; `make_engine` matches it case-insensitively and
    // `generic_update` receives it unchanged.
    ma_type: String,
    // Per-MA incremental state, rebuilt by `reset`.
    engine: StreamMa,
    // Set after the first emitted value and cleared by `reset`.
    // NOTE(review): not read by any code visible here — confirm whether it is still needed.
    ready: bool,
}
1172
/// Incremental moving-average engine chosen by `make_engine` from the `ma_type` name.
///
/// `Generic` is the fallback for names without a dedicated incremental implementation; it
/// re-runs the batch `ma` function over a rolling window on every update.
#[derive(Debug, Clone)]
enum StreamMa {
    // Simple moving average.
    Sma(SmaState),
    // Exponential MA, `alpha = 2 / (period + 1)`.
    Ema(EmaState),
    // Wilder / smoothed MA: same state shape as `Ema`, with `alpha = 1 / period`.
    Rma(EmaState),
    // Double exponential MA (`2 * e1 - e2`).
    Dema(DemaState),
    // Triple exponential MA (`3 * e1 - 3 * e2 + e3`).
    Tema(TemaState),
    // Linearly weighted MA.
    Wma(WmaState),
    // Any other MA name, evaluated via the batch `ma` function.
    Generic(GenericState),
}
1183
/// Incremental state for the simple moving average engine (`StreamMa::Sma`).
#[derive(Debug, Clone)]
struct SmaState {
    // Ring buffer of the last `period` inputs.
    buf: Vec<f64>,
    // Next slot of `buf` to overwrite; advances modulo the buffer length.
    pos: usize,
    // Inputs stored so far; stops growing once the buffer is full.
    count: usize,
    // Running sum of the values currently in `buf`.
    sum: f64,
    // `1 / period`, so the mean is `sum * inv_n`.
    inv_n: f64,
}
1192
/// Incremental state shared by the EMA (`StreamMa::Ema`) and Wilder/RMA (`StreamMa::Rma`)
/// engines; the two differ only in the `alpha` chosen by `make_engine`.
#[derive(Debug, Clone)]
struct EmaState {
    // Period; the first `n` inputs are averaged to seed `ema`.
    n: usize,
    // Weight of the newest input.
    alpha: f64,
    // `1 - alpha`, weight of the previous average.
    beta: f64,

    // Sum of the seeding inputs seen so far.
    init_sum: f64,
    // Number of seeding inputs seen so far (stops at `n`).
    init_count: usize,
    // Current average; NaN until seeded.
    ema: f64,
}
1203
/// Incremental state for the double exponential MA engine (`StreamMa::Dema`).
#[derive(Debug, Clone)]
struct DemaState {
    // Period; the first `n` inputs are averaged to seed `e1` and `e2`.
    n: usize,
    // Weight of the newest value, `2 / (period + 1)`.
    alpha: f64,
    // `1 - alpha`.
    beta: f64,
    // Sum of the seeding inputs seen so far.
    init_sum: f64,
    // Number of seeding inputs seen so far (stops at `n`).
    init_count: usize,
    // EMA of the input; NaN until seeded.
    e1: f64,
    // EMA of `e1`; NaN until seeded.
    e2: f64,
}
1214
/// Incremental state for the triple exponential MA engine (`StreamMa::Tema`).
#[derive(Debug, Clone)]
struct TemaState {
    // Period; the first `n` inputs are averaged to seed `e1`, `e2` and `e3`.
    n: usize,
    // Weight of the newest value, `2 / (period + 1)`.
    alpha: f64,
    // `1 - alpha`.
    beta: f64,
    // Sum of the seeding inputs seen so far.
    init_sum: f64,
    // Number of seeding inputs seen so far (stops at `n`).
    init_count: usize,
    // EMA of the input; NaN until seeded.
    e1: f64,
    // EMA of `e1`; NaN until seeded.
    e2: f64,
    // EMA of `e2`; NaN until seeded.
    e3: f64,
}
1226
/// Incremental state for the linearly weighted MA engine (`StreamMa::Wma`).
///
/// Weights run from 1 (oldest) to `n` (newest); `wma_update` maintains the weighted sum in
/// O(1) per step using the plain window sum.
#[derive(Debug, Clone)]
struct WmaState {
    // Window length.
    n: usize,
    // `2 / (n * (n + 1))`, the reciprocal of the weight total.
    den_inv: f64,
    // Ring buffer of the last `n` inputs.
    buf: Vec<f64>,
    // Next slot of `buf` to overwrite; advances modulo `n`.
    pos: usize,
    // Inputs stored so far; stops growing at `n`.
    count: usize,
    // Plain sum of the window.
    s: f64,
    // Weighted sum of the window (newest weight `n`).
    ws: f64,
}
1237
/// Rolling-window state for MA types without a dedicated incremental engine
/// (`StreamMa::Generic`).
#[derive(Debug, Clone)]
struct GenericState {
    // Window length.
    n: usize,
    // Ring buffer of the last `n` inputs.
    buf: Vec<f64>,
    // Next slot of `buf` to overwrite; after an advance it also indexes the oldest value.
    pos: usize,
    // Inputs stored so far; stops growing at `n`.
    count: usize,
    // Reused buffer holding the window in chronological order for the batch `ma` call.
    scratch: Vec<f64>,
}
1246
1247impl EriStream {
1248    pub fn try_new(params: EriParams) -> Result<Self, EriError> {
1249        let period = params.period.unwrap_or(13);
1250        if period == 0 {
1251            return Err(EriError::InvalidPeriod {
1252                period,
1253                data_len: 0,
1254            });
1255        }
1256        let ma_type = params.ma_type.unwrap_or_else(|| "ema".to_string());
1257
1258        let engine = make_engine(period, &ma_type);
1259        Ok(Self {
1260            period,
1261            ma_type,
1262            engine,
1263            ready: false,
1264        })
1265    }
1266
1267    #[inline(always)]
1268    pub fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64)> {
1269        if high.is_nan() || low.is_nan() || source.is_nan() {
1270            self.reset();
1271            return None;
1272        }
1273
1274        let ma_val = match &mut self.engine {
1275            StreamMa::Sma(st) => sma_update(st, source),
1276            StreamMa::Ema(st) => ema_like_update(st, source),
1277            StreamMa::Rma(st) => ema_like_update(st, source),
1278            StreamMa::Dema(st) => dema_update(st, source),
1279            StreamMa::Tema(st) => tema_update(st, source),
1280            StreamMa::Wma(st) => wma_update(st, source),
1281            StreamMa::Generic(st) => generic_update(st, &self.ma_type, source),
1282        }?;
1283
1284        self.ready = true;
1285        Some((high - ma_val, low - ma_val))
1286    }
1287
1288    #[inline(always)]
1289    fn reset(&mut self) {
1290        self.ready = false;
1291        self.engine = make_engine(self.period, &self.ma_type);
1292    }
1293}
1294
1295#[inline(always)]
1296fn make_engine(period: usize, ma_type: &str) -> StreamMa {
1297    let t = ma_type.to_ascii_lowercase();
1298    match t.as_str() {
1299        "sma" => StreamMa::Sma(SmaState {
1300            buf: vec![0.0; period],
1301            pos: 0,
1302            count: 0,
1303            sum: 0.0,
1304            inv_n: 1.0 / period as f64,
1305        }),
1306        "ema" | "ewma" => StreamMa::Ema(EmaState {
1307            n: period,
1308            alpha: 2.0 / (period as f64 + 1.0),
1309            beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1310            init_sum: 0.0,
1311            init_count: 0,
1312            ema: f64::NAN,
1313        }),
1314        "rma" | "wilder" | "smma" => StreamMa::Rma(EmaState {
1315            n: period,
1316            alpha: 1.0 / period as f64,
1317            beta: 1.0 - (1.0 / period as f64),
1318            init_sum: 0.0,
1319            init_count: 0,
1320            ema: f64::NAN,
1321        }),
1322        "dema" => StreamMa::Dema(DemaState {
1323            n: period,
1324            alpha: 2.0 / (period as f64 + 1.0),
1325            beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1326            init_sum: 0.0,
1327            init_count: 0,
1328            e1: f64::NAN,
1329            e2: f64::NAN,
1330        }),
1331        "tema" => StreamMa::Tema(TemaState {
1332            n: period,
1333            alpha: 2.0 / (period as f64 + 1.0),
1334            beta: 1.0 - (2.0 / (period as f64 + 1.0)),
1335            init_sum: 0.0,
1336            init_count: 0,
1337            e1: f64::NAN,
1338            e2: f64::NAN,
1339            e3: f64::NAN,
1340        }),
1341        "wma" | "lwma" | "linear" | "linear_wma" => {
1342            let n = period as f64;
1343            let den_inv = 2.0 / (n * (n + 1.0));
1344            StreamMa::Wma(WmaState {
1345                n: period,
1346                den_inv,
1347                buf: vec![0.0; period],
1348                pos: 0,
1349                count: 0,
1350                s: 0.0,
1351                ws: 0.0,
1352            })
1353        }
1354        _ => StreamMa::Generic(GenericState {
1355            n: period,
1356            buf: vec![0.0; period],
1357            pos: 0,
1358            count: 0,
1359            scratch: vec![0.0; period],
1360        }),
1361    }
1362}
1363
1364#[inline(always)]
1365fn sma_update(st: &mut SmaState, x: f64) -> Option<f64> {
1366    let n = st.buf.len();
1367    if st.count < n {
1368        st.buf[st.pos] = x;
1369        st.sum += x;
1370        st.pos = (st.pos + 1) % n;
1371        st.count += 1;
1372        return (st.count == n).then(|| st.sum * st.inv_n);
1373    }
1374    let old = st.buf[st.pos];
1375    st.buf[st.pos] = x;
1376    st.sum += x - old;
1377    st.pos = (st.pos + 1) % n;
1378    Some(st.sum * st.inv_n)
1379}
1380
1381#[inline(always)]
1382fn ema_like_update(st: &mut EmaState, x: f64) -> Option<f64> {
1383    if st.init_count < st.n {
1384        st.init_sum += x;
1385        st.init_count += 1;
1386        if st.init_count == st.n {
1387            st.ema = st.init_sum / st.n as f64;
1388            return Some(st.ema);
1389        }
1390        return None;
1391    }
1392    st.ema = x.mul_add(st.alpha, st.beta * st.ema);
1393    Some(st.ema)
1394}
1395
1396#[inline(always)]
1397fn dema_update(st: &mut DemaState, x: f64) -> Option<f64> {
1398    if st.init_count < st.n {
1399        st.init_sum += x;
1400        st.init_count += 1;
1401        if st.init_count == st.n {
1402            st.e1 = st.init_sum / st.n as f64;
1403            st.e2 = st.e1;
1404            return Some(st.e1);
1405        }
1406        return None;
1407    }
1408    st.e1 = x.mul_add(st.alpha, st.beta * st.e1);
1409    st.e2 = st.e1.mul_add(st.alpha, st.beta * st.e2);
1410    Some(2.0f64.mul_add(st.e1, -st.e2))
1411}
1412
1413#[inline(always)]
1414fn tema_update(st: &mut TemaState, x: f64) -> Option<f64> {
1415    if st.init_count < st.n {
1416        st.init_sum += x;
1417        st.init_count += 1;
1418        if st.init_count == st.n {
1419            st.e1 = st.init_sum / st.n as f64;
1420            st.e2 = st.e1;
1421            st.e3 = st.e2;
1422            return Some(st.e1);
1423        }
1424        return None;
1425    }
1426    st.e1 = x.mul_add(st.alpha, st.beta * st.e1);
1427    st.e2 = st.e1.mul_add(st.alpha, st.beta * st.e2);
1428    st.e3 = st.e2.mul_add(st.alpha, st.beta * st.e3);
1429    Some((3.0 * st.e1) - (3.0 * st.e2) + st.e3)
1430}
1431
1432#[inline(always)]
1433fn wma_update(st: &mut WmaState, x: f64) -> Option<f64> {
1434    let n = st.n;
1435    if st.count < n {
1436        st.buf[st.pos] = x;
1437        st.s += x;
1438        st.ws += (st.count as f64 + 1.0) * x;
1439        st.pos = (st.pos + 1) % n;
1440        st.count += 1;
1441        return (st.count == n).then(|| st.ws * st.den_inv);
1442    }
1443    let old = st.buf[st.pos];
1444    st.buf[st.pos] = x;
1445    let s_prev = st.s;
1446    st.s = s_prev - old + x;
1447    st.ws = st.ws - s_prev + (n as f64) * x;
1448    st.pos = (st.pos + 1) % n;
1449    Some(st.ws * st.den_inv)
1450}
1451
1452#[inline(always)]
1453fn generic_update(st: &mut GenericState, ma_type: &str, x: f64) -> Option<f64> {
1454    let n = st.n;
1455    if st.count < n {
1456        st.buf[st.pos] = x;
1457        st.pos = (st.pos + 1) % n;
1458        st.count += 1;
1459        return None;
1460    }
1461    st.buf[st.pos] = x;
1462    st.pos = (st.pos + 1) % n;
1463
1464    for i in 0..n {
1465        let src_idx = (st.pos + i) % n;
1466        st.scratch[i] = st.buf[src_idx];
1467    }
1468    let m = ma(ma_type, MaData::Slice(&st.scratch), n).ok()?;
1469    m.last().copied()
1470}
1471
/// Python entry point `eri(high, low, source, period=13, ma_type="ema", kernel=None)`.
///
/// Returns the `(bull, bear)` arrays. Raises `ValueError` when the three inputs differ in
/// length, the kernel name is invalid, or the ERI computation fails. The computation runs
/// with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "eri")]
#[pyo3(signature = (high, low, source, period=13, ma_type="ema", kernel=None))]
pub fn eri_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let (high_slice, low_slice, source_slice) =
        (high.as_slice()?, low.as_slice()?, source.as_slice()?);

    let len = high_slice.len();
    if low_slice.len() != len || source_slice.len() != len {
        return Err(PyValueError::new_err(
            "high, low, and source arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let input = EriInput::from_slices(
        high_slice,
        low_slice,
        source_slice,
        EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let output = py
        .allow_threads(|| eri_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let bull = output.bull.into_pyarray(py);
    let bear = output.bear.into_pyarray(py);
    Ok((bull, bear))
}
1509
/// Python wrapper around the Rust `EriStream`, exposed to Python as `EriStream`.
#[cfg(feature = "python")]
#[pyclass(name = "EriStream")]
pub struct EriStreamPy {
    // Underlying incremental calculator.
    stream: EriStream,
}
1515
#[cfg(feature = "python")]
#[pymethods]
impl EriStreamPy {
    /// Creates a streaming ERI calculator: `EriStream(period=13, ma_type=None)`.
    ///
    /// `ma_type` defaults to `"ema"` when omitted or `None`; `period` defaults to 13, matching
    /// the batch `eri` function. Raises `ValueError` when the period is rejected (e.g. zero).
    // The explicit signature keeps `ma_type` optional without relying on PyO3's deprecated
    // implicit default for trailing `Option` parameters.
    #[new]
    #[pyo3(signature = (period=13, ma_type=None))]
    fn new(period: usize, ma_type: Option<&str>) -> PyResult<Self> {
        let params = EriParams {
            period: Some(period),
            ma_type: Some(ma_type.unwrap_or("ema").to_string()),
        };
        EriStream::try_new(params)
            .map(|stream| EriStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar; returns `(bull, bear)` once the moving average is warm, else `None`.
    /// A NaN in any input resets the stream and returns `None`.
    fn update(&mut self, high: f64, low: f64, source: f64) -> Option<(f64, f64)> {
        self.stream.update(high, low, source)
    }
}
1536
/// Python entry point `eri_batch(high, low, source, period_range=(13, 13, 0), ma_type="ema",
/// kernel=None)`.
///
/// Returns a dict with `bull_values` / `bear_values` (shape `(rows, cols)`, one row per period
/// combination), `periods` and `ma_types`. Raises `ValueError` for mismatched input lengths,
/// an invalid sweep or kernel, size overflow, or a failed computation.
#[cfg(feature = "python")]
#[pyfunction(name = "eri_batch")]
#[pyo3(signature = (high, low, source, period_range=(13, 13, 0), ma_type="ema", kernel=None))]
pub fn eri_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    source: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let source_slice = source.as_slice()?;

    let cols = high_slice.len();
    if low_slice.len() != cols || source_slice.len() != cols {
        return Err(PyValueError::new_err(
            "high, low, and source arrays must have the same length",
        ));
    }

    let sweep = EriBatchRange {
        period: period_range,
        ma_type: ma_type.to_string(),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Output arrays are allocated uninitialised: the warm-up prefix of each row is filled with
    // NaN below and the rest is left to `eri_batch_inner_into`.
    // NOTE(review): the batch row loop discards per-row errors (`let _ = do_row(..)`), so a
    // failing row could leave uninitialised values here — confirm rows cannot fail.
    let bull_array: Bound<'py, PyArray1<f64>> = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bear_array: Bound<'py, PyArray1<f64>> = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let bull_slice = unsafe { bull_array.as_slice_mut()? };
    let bear_slice = unsafe { bear_array.as_slice_mut()? };

    // First index where all three inputs are non-NaN (0 if there is none).
    let first_valid = (0..cols)
        .find(|&i| !(high_slice[i].is_nan() || low_slice[i].is_nan() || source_slice[i].is_nan()))
        .unwrap_or(0);

    for (row, combo) in combos.iter().enumerate() {
        let warm = (first_valid + combo.period.unwrap() - 1).min(cols);
        let start = row * cols;
        bull_slice[start..start + warm].fill(f64::NAN);
        bear_slice[start..start + warm].fill(f64::NAN);
    }

    let resolved = match validate_kernel(kernel, true)? {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };
    // The batch routine takes the per-row (non-batch) kernel variant.
    let simd = match resolved {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let combos = py
        .allow_threads(|| {
            eri_batch_inner_into(
                high_slice,
                low_slice,
                source_slice,
                &sweep,
                simd,
                true,
                bull_slice,
                bear_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let bull_reshaped = bull_array.reshape([rows, cols])?;
    let bear_reshaped = bear_array.reshape([rows, cols])?;

    let periods: Vec<usize> = combos.iter().map(|c| c.period.unwrap()).collect();
    let ma_types: Vec<&str> = vec![ma_type; combos.len()];

    let dict = PyDict::new(py);
    dict.set_item("bull_values", bull_reshaped)?;
    dict.set_item("bear_values", bear_reshaped)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("ma_types", ma_types)?;

    Ok(dict.into())
}
1638
1639pub fn eri_into_slice(
1640    dst_bull: &mut [f64],
1641    dst_bear: &mut [f64],
1642    input: &EriInput,
1643    kern: Kernel,
1644) -> Result<(), EriError> {
1645    let (high, low, source_data) = match &input.data {
1646        EriData::Candles { candles, source } => {
1647            let high = candles
1648                .select_candle_field("high")
1649                .map_err(|_| EriError::EmptyInputData)?;
1650            let low = candles
1651                .select_candle_field("low")
1652                .map_err(|_| EriError::EmptyInputData)?;
1653            let src = source_type(candles, source);
1654            (high, low, src)
1655        }
1656        EriData::Slices { high, low, source } => (*high, *low, *source),
1657    };
1658
1659    if source_data.is_empty() || high.is_empty() || low.is_empty() {
1660        return Err(EriError::EmptyInputData);
1661    }
1662
1663    if dst_bull.len() != source_data.len() || dst_bear.len() != source_data.len() {
1664        return Err(EriError::OutputLengthMismatch {
1665            expected: source_data.len(),
1666            got: dst_bull.len().max(dst_bear.len()),
1667        });
1668    }
1669
1670    let period = input.get_period();
1671    if period == 0 || period > source_data.len() {
1672        return Err(EriError::InvalidPeriod {
1673            period,
1674            data_len: source_data.len(),
1675        });
1676    }
1677
1678    let mut first_valid_idx = None;
1679    for i in 0..source_data.len() {
1680        if !(source_data[i].is_nan() || high[i].is_nan() || low[i].is_nan()) {
1681            first_valid_idx = Some(i);
1682            break;
1683        }
1684    }
1685    let first_valid_idx = match first_valid_idx {
1686        Some(idx) => idx,
1687        None => return Err(EriError::AllValuesNaN),
1688    };
1689
1690    if (source_data.len() - first_valid_idx) < period {
1691        return Err(EriError::NotEnoughValidData {
1692            needed: period,
1693            valid: source_data.len() - first_valid_idx,
1694        });
1695    }
1696
1697    let ma_type = input.get_ma_type();
1698    let full_ma = ma(&ma_type, MaData::Slice(&source_data), period)
1699        .map_err(|e| EriError::MaCalculationError(e.to_string()))?;
1700
1701    let warmup_period = first_valid_idx + period - 1;
1702
1703    for v in &mut dst_bull[..warmup_period] {
1704        *v = f64::NAN;
1705    }
1706    for v in &mut dst_bear[..warmup_period] {
1707        *v = f64::NAN;
1708    }
1709
1710    let chosen = match kern {
1711        Kernel::Auto => Kernel::Scalar,
1712        other => other,
1713    };
1714
1715    unsafe {
1716        match chosen {
1717            Kernel::Scalar | Kernel::ScalarBatch => eri_scalar(
1718                high,
1719                low,
1720                &full_ma,
1721                period,
1722                first_valid_idx,
1723                dst_bull,
1724                dst_bear,
1725            ),
1726            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1727            Kernel::Avx2 | Kernel::Avx2Batch => eri_avx2(
1728                high,
1729                low,
1730                &full_ma,
1731                period,
1732                first_valid_idx,
1733                dst_bull,
1734                dst_bear,
1735            ),
1736            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1737            Kernel::Avx512 | Kernel::Avx512Batch => eri_avx512(
1738                high,
1739                low,
1740                &full_ma,
1741                period,
1742                first_valid_idx,
1743                dst_bull,
1744                dst_bear,
1745            ),
1746            _ => unreachable!(),
1747        }
1748    }
1749
1750    Ok(())
1751}
1752
1753#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1754#[inline]
1755pub fn eri_into(
1756    input: &EriInput,
1757    bull_out: &mut [f64],
1758    bear_out: &mut [f64],
1759) -> Result<(), EriError> {
1760    eri_into_slice(bull_out, bear_out, input, Kernel::Auto)
1761}
1762
/// Serialized result of `eri_js_flat`: the bull row followed by the bear row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriResult {
    // Row-major values, `rows * cols` entries.
    pub values: Vec<f64>,
    // Always 2 in `eri_js_flat` (bull, bear).
    pub rows: usize,
    // Input length.
    pub cols: usize,
}
1770
/// WASM entry point returning `{ values, rows: 2, cols }`.
///
/// `values` holds the bull series followed by the bear series. Fails with a JS string error
/// when the inputs differ in length, the computation fails, or serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_js_flat(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    period: usize,
    ma_type: &str,
) -> Result<JsValue, JsValue> {
    if high.len() != low.len() || high.len() != source.len() {
        return Err(JsValue::from_str("length mismatch"));
    }
    let params = EriParams {
        period: Some(period),
        ma_type: Some(ma_type.to_string()),
    };
    let input = EriInput::from_slices(high, low, source, params);

    let cols = source.len();
    // Compute straight into one [bull | bear] buffer instead of two Vecs followed by a
    // reallocating `extend_from_slice`. `2 * cols` cannot overflow: an `&[f64]` holds at most
    // `isize::MAX / 8` elements.
    let mut values = vec![0.0; 2 * cols];
    {
        let (bull, bear) = values.split_at_mut(cols);
        eri_into_slice(bull, bear, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    let out = EriResult {
        values,
        rows: 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1805
/// WASM entry point returning a flat `Vec<f64>`: the bull series followed by the bear series.
///
/// Fails with a JS string error when the inputs differ in length, the doubled length
/// overflows, or the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    period: usize,
    ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let len = source.len();
    if high.len() != low.len() || high.len() != len {
        return Err(JsValue::from_str(
            "high, low, and source arrays must have the same length",
        ));
    }

    let input = EriInput::from_slices(
        high,
        low,
        source,
        EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let total = len
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("length overflow"))?;
    let mut output = vec![0.0; total];
    {
        let (bull_part, bear_part) = output.split_at_mut(len);
        eri_into_slice(bull_part, bear_part, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(output)
}
1839
1840#[cfg(all(feature = "python", feature = "cuda"))]
1841use crate::cuda::eri_wrapper::CudaEri;
1842#[cfg(all(feature = "python", feature = "cuda"))]
1843use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
/// Python entry point `eri_cuda_batch_dev`: runs the ERI period sweep on a CUDA device.
///
/// Returns the `(bull, bear)` device arrays produced by `CudaEri::eri_batch_dev`. Raises
/// `ValueError` when CUDA is unavailable or the device computation fails. The device work
/// runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "eri_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, source_f32, period_range, ma_type, device_id=0))]
pub fn eri_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, s) = (high_f32.as_slice()?, low_f32.as_slice()?, source_f32.as_slice()?);
    let sweep = EriBatchRange {
        period: period_range,
        ma_type: ma_type.to_string(),
    };

    let (bull, bear) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEri::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameter combinations are not needed by the caller.
        let ((bull, bear), _combos) = cuda
            .eri_batch_dev(h, l, s, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bull, bear))
    })?;

    let bull_dev = make_device_array_py(device_id, bull)?;
    let bear_dev = make_device_array_py(device_id, bear)?;
    Ok((bull_dev, bear_dev))
}
1877
/// Python entry point `eri_cuda_many_series_one_param_dev`: one ERI parameter set over many
/// series supplied in time-major layout (`cols`/`rows` forwarded unchanged to
/// `CudaEri::eri_many_series_one_param_time_major_dev`).
///
/// Raises `ValueError` when CUDA is unavailable or the device computation fails. The device
/// work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "eri_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, source_tm_f32, cols, rows, period, ma_type, device_id=0))]
pub fn eri_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    low_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    source_tm_f32: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (h, l, s) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        source_tm_f32.as_slice()?,
    );

    let (bull, bear) = py.allow_threads(|| {
        let cuda = CudaEri::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.eri_many_series_one_param_time_major_dev(h, l, s, cols, rows, period, ma_type)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    let bull_dev = make_device_array_py(device_id, bull)?;
    let bear_dev = make_device_array_py(device_id, bear)?;
    Ok((bull_dev, bear_dev))
}
1909
/// WASM pointer API: computes ERI from `len`-element input buffers into `len`-element output
/// buffers in linear memory.
///
/// When any output buffer overlaps an input buffer or the other output, the result is
/// computed into temporaries first and then copied out, so in-place use is supported.
///
/// The caller must ensure every pointer addresses `len` valid `f64` values (inputs readable,
/// outputs writable). Fails with `"Null pointer provided"` for any null pointer, or with the
/// `eri_into_slice` error message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    source_ptr: *const f64,
    bull_ptr: *mut f64,
    bear_ptr: *mut f64,
    len: usize,
    period: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || source_ptr.is_null()
        || bull_ptr.is_null()
        || bear_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Two `len`-element buffers overlap when their byte ranges intersect. Comparing only start
    // addresses (as before) missed partially-overlapping views, e.g. an output offset by one
    // element into an input, which would alias a `&mut` slice with a `&` slice.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: *const f64, b: *const f64| {
        let (a, b) = (a as usize, b as usize);
        a < b.saturating_add(span) && b < a.saturating_add(span)
    };
    let bull_c = bull_ptr as *const f64;
    let bear_c = bear_ptr as *const f64;
    let needs_temp = [high_ptr, low_ptr, source_ptr]
        .iter()
        .any(|&p| overlaps(bull_c, p) || overlaps(bear_c, p))
        || overlaps(bull_c, bear_c);

    // SAFETY: all pointers are non-null and, per the caller contract, address `len` valid f64s.
    // Mutable output slices are only created over memory that no live shared slice aliases:
    // either the buffers are disjoint, or the inputs are no longer used when outputs are
    // written (temporary path).
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let source = std::slice::from_raw_parts(source_ptr, len);

        let params = EriParams {
            period: Some(period),
            ma_type: Some(ma_type.to_string()),
        };
        let input = EriInput::from_slices(high, low, source, params);

        if needs_temp {
            let mut temp_bull = vec![0.0; len];
            let mut temp_bear = vec![0.0; len];
            eri_into_slice(&mut temp_bull, &mut temp_bear, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let bull_out = std::slice::from_raw_parts_mut(bull_ptr, len);
            bull_out.copy_from_slice(&temp_bull);
            let bear_out = std::slice::from_raw_parts_mut(bear_ptr, len);
            bear_out.copy_from_slice(&temp_bear);
        } else {
            let bull_out = std::slice::from_raw_parts_mut(bull_ptr, len);
            let bear_out = std::slice::from_raw_parts_mut(bear_ptr, len);
            eri_into_slice(bull_out, bear_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1970
/// Allocates an uninitialised buffer with room for `len` `f64` values and hands ownership to
/// the caller. Release it with `eri_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1979
/// Releases a buffer obtained from `eri_alloc`. Null pointers are ignored.
///
/// The caller must pass the same `len` that was given to `eri_alloc` and must not use the
/// pointer afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn eri_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `eri_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`, so
    // capacity `len` describes the allocation. Length 0 is used because `from_raw_parts`
    // requires the first `length` elements to be initialised, which the JS side does not
    // guarantee; `f64` has no destructor, so nothing is lost by not dropping elements.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1989
/// Configuration object accepted by the JS `eri_batch` binding (deserialized from a JS value).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriBatchConfig {
    // Period sweep forwarded to `EriBatchRange::period`; presumably `(start, end, step)` —
    // confirm against `expand_grid`.
    pub period_range: (usize, usize, usize),
    // Moving-average name used for every combination.
    pub ma_type: String,
}
1996
/// Serialized result of the JS `eri_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EriBatchJsOutput {
    // All bull rows followed by all bear rows, each `cols` long.
    pub values: Vec<f64>,
    // Total row count: twice the number of parameter combinations (bull + bear).
    pub rows: usize,
    // Row length, taken from the batch output's `cols`.
    pub cols: usize,
    // Period of each combination, in row order.
    pub periods: Vec<usize>,
}
2005
/// WASM entry point `eri_batch(high, low, source, config)`.
///
/// Runs the period sweep described by `config` (`{ period_range, ma_type }`) and returns
/// `{ values, rows, cols, periods }`, where `values` holds every bull row followed by every
/// bear row. Fails with a JS string error for mismatched inputs, an invalid config, a failed
/// computation, size overflow, or serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = eri_batch)]
pub fn eri_batch_js(
    high: &[f64],
    low: &[f64],
    source: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    if high.len() != low.len() || high.len() != source.len() {
        return Err(JsValue::from_str(
            "high, low, and source arrays must have the same length",
        ));
    }

    let EriBatchConfig {
        period_range,
        ma_type,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = EriBatchRange {
        period: period_range,
        ma_type,
    };

    let output = eri_batch_with_kernel(high, low, source, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let rows = output
        .rows
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("rows overflow"))?;
    let cols = output.cols;
    let cap = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Rows are stored contiguously, so all bull rows form one prefix of `bull` (likewise for
    // bear). `cap` did not overflow, hence neither does half of it.
    let per_side = output.rows * cols;
    let mut values = Vec::with_capacity(cap);
    values.extend_from_slice(&output.bull[..per_side]);
    values.extend_from_slice(&output.bear[..per_side]);

    let periods: Vec<usize> = output
        .params
        .iter()
        .map(|combo| combo.period.unwrap())
        .collect();

    serde_wasm_bindgen::to_value(&EriBatchJsOutput {
        values,
        rows,
        cols,
        periods,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2063
2064#[inline]
2065pub unsafe fn eri_scalar_classic_sma(
2066    high: &[f64],
2067    low: &[f64],
2068    source: &[f64],
2069    period: usize,
2070    first_valid_idx: usize,
2071    bull: &mut [f64],
2072    bear: &mut [f64],
2073) -> Result<(), EriError> {
2074    let start_idx = first_valid_idx + period - 1;
2075
2076    let mut sum = 0.0;
2077    for i in 0..period {
2078        sum += source[first_valid_idx + i];
2079    }
2080    let mut sma = sum / period as f64;
2081
2082    bull[start_idx] = high[start_idx] - sma;
2083    bear[start_idx] = low[start_idx] - sma;
2084
2085    for i in (start_idx + 1)..source.len() {
2086        let old_val = source[i - period];
2087        let new_val = source[i];
2088        sum = sum - old_val + new_val;
2089        sma = sum / period as f64;
2090
2091        bull[i] = high[i] - sma;
2092        bear[i] = low[i] - sma;
2093    }
2094
2095    Ok(())
2096}
2097
2098#[inline]
2099pub unsafe fn eri_scalar_classic_ema(
2100    high: &[f64],
2101    low: &[f64],
2102    source: &[f64],
2103    period: usize,
2104    first_valid_idx: usize,
2105    bull: &mut [f64],
2106    bear: &mut [f64],
2107) -> Result<(), EriError> {
2108    let start_idx = first_valid_idx + period - 1;
2109    let alpha = 2.0 / (period as f64 + 1.0);
2110    let beta = 1.0 - alpha;
2111
2112    let mut sum = 0.0;
2113    for i in 0..period {
2114        sum += source[first_valid_idx + i];
2115    }
2116    let mut ema = sum / period as f64;
2117
2118    bull[start_idx] = high[start_idx] - ema;
2119    bear[start_idx] = low[start_idx] - ema;
2120
2121    for i in (start_idx + 1)..source.len() {
2122        ema = alpha * source[i] + beta * ema;
2123
2124        bull[i] = high[i] - ema;
2125        bear[i] = low[i] - ema;
2126    }
2127
2128    Ok(())
2129}
2130
2131#[cfg(test)]
2132mod tests {
2133    use super::*;
2134    use crate::skip_if_unsupported;
2135    use crate::utilities::data_loader::read_candles_from_csv;
2136    use crate::utilities::enums::Kernel;
2137
2138    fn check_eri_partial_params(
2139        test_name: &str,
2140        kernel: Kernel,
2141    ) -> Result<(), Box<dyn std::error::Error>> {
2142        skip_if_unsupported!(kernel, test_name);
2143        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2144        let candles = read_candles_from_csv(file_path)?;
2145
2146        let default_params = EriParams {
2147            period: None,
2148            ma_type: None,
2149        };
2150        let input_default = EriInput::from_candles(&candles, "close", default_params);
2151        let output_default = eri_with_kernel(&input_default, kernel)?;
2152        assert_eq!(output_default.bull.len(), candles.close.len());
2153        assert_eq!(output_default.bear.len(), candles.close.len());
2154
2155        let params_period_14 = EriParams {
2156            period: Some(14),
2157            ma_type: Some("ema".to_string()),
2158        };
2159        let input_period_14 = EriInput::from_candles(&candles, "hl2", params_period_14);
2160        let output_period_14 = eri_with_kernel(&input_period_14, kernel)?;
2161        assert_eq!(output_period_14.bull.len(), candles.close.len());
2162        assert_eq!(output_period_14.bear.len(), candles.close.len());
2163
2164        let params_custom = EriParams {
2165            period: Some(20),
2166            ma_type: Some("sma".to_string()),
2167        };
2168        let input_custom = EriInput::from_candles(&candles, "hlc3", params_custom);
2169        let output_custom = eri_with_kernel(&input_custom, kernel)?;
2170        assert_eq!(output_custom.bull.len(), candles.close.len());
2171        assert_eq!(output_custom.bear.len(), candles.close.len());
2172
2173        Ok(())
2174    }
2175
2176    fn check_eri_accuracy(
2177        test_name: &str,
2178        kernel: Kernel,
2179    ) -> Result<(), Box<dyn std::error::Error>> {
2180        skip_if_unsupported!(kernel, test_name);
2181        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2182        let candles = read_candles_from_csv(file_path)?;
2183        let close_prices = candles
2184            .select_candle_field("close")
2185            .expect("Failed to extract close prices");
2186
2187        let params = EriParams {
2188            period: Some(13),
2189            ma_type: Some("ema".to_string()),
2190        };
2191        let input = EriInput::from_candles(&candles, "close", params);
2192        let eri_result = eri_with_kernel(&input, kernel)?;
2193
2194        assert_eq!(eri_result.bull.len(), close_prices.len());
2195        assert_eq!(eri_result.bear.len(), close_prices.len());
2196
2197        let expected_bull_last_five = [
2198            -103.35343557205488,
2199            6.839912366813223,
2200            -42.851503685589705,
2201            -9.444146016219747,
2202            11.476446271808527,
2203        ];
2204        let expected_bear_last_five = [
2205            -433.3534355720549,
2206            -314.1600876331868,
2207            -414.8515036855897,
2208            -336.44414601621975,
2209            -925.5235537281915,
2210        ];
2211
2212        let start_index = eri_result.bull.len() - 5;
2213        for i in 0..5 {
2214            let actual_bull = eri_result.bull[start_index + i];
2215            let actual_bear = eri_result.bear[start_index + i];
2216            let expected_bull = expected_bull_last_five[i];
2217            let expected_bear = expected_bear_last_five[i];
2218            assert!(
2219                (actual_bull - expected_bull).abs() < 1e-2,
2220                "ERI bull mismatch at index {}: expected {}, got {}",
2221                i,
2222                expected_bull,
2223                actual_bull
2224            );
2225            assert!(
2226                (actual_bear - expected_bear).abs() < 1e-2,
2227                "ERI bear mismatch at index {}: expected {}, got {}",
2228                i,
2229                expected_bear,
2230                actual_bear
2231            );
2232        }
2233        Ok(())
2234    }
2235
2236    fn check_eri_default_candles(
2237        test_name: &str,
2238        kernel: Kernel,
2239    ) -> Result<(), Box<dyn std::error::Error>> {
2240        skip_if_unsupported!(kernel, test_name);
2241        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2242        let candles = read_candles_from_csv(file_path)?;
2243
2244        let input = EriInput::with_default_candles(&candles);
2245        match input.data {
2246            EriData::Candles { source, .. } => assert_eq!(source, "close"),
2247            _ => panic!("Expected EriData::Candles"),
2248        }
2249        let output = eri_with_kernel(&input, kernel)?;
2250        assert_eq!(output.bull.len(), candles.close.len());
2251        assert_eq!(output.bear.len(), candles.close.len());
2252
2253        Ok(())
2254    }
2255
2256    fn check_eri_zero_period(
2257        test_name: &str,
2258        kernel: Kernel,
2259    ) -> Result<(), Box<dyn std::error::Error>> {
2260        skip_if_unsupported!(kernel, test_name);
2261        let high = [10.0, 20.0, 30.0];
2262        let low = [8.0, 18.0, 28.0];
2263        let src = [9.0, 19.0, 29.0];
2264        let params = EriParams {
2265            period: Some(0),
2266            ma_type: Some("ema".to_string()),
2267        };
2268        let input = EriInput::from_slices(&high, &low, &src, params);
2269        let res = eri_with_kernel(&input, kernel);
2270        assert!(
2271            res.is_err(),
2272            "[{}] ERI should fail with zero period",
2273            test_name
2274        );
2275        Ok(())
2276    }
2277
2278    fn check_eri_period_exceeds_length(
2279        test_name: &str,
2280        kernel: Kernel,
2281    ) -> Result<(), Box<dyn std::error::Error>> {
2282        skip_if_unsupported!(kernel, test_name);
2283        let high = [10.0, 20.0, 30.0];
2284        let low = [8.0, 18.0, 28.0];
2285        let src = [9.0, 19.0, 29.0];
2286        let params = EriParams {
2287            period: Some(10),
2288            ma_type: Some("ema".to_string()),
2289        };
2290        let input = EriInput::from_slices(&high, &low, &src, params);
2291        let res = eri_with_kernel(&input, kernel);
2292        assert!(
2293            res.is_err(),
2294            "[{}] ERI should fail with period exceeding length",
2295            test_name
2296        );
2297        Ok(())
2298    }
2299
2300    fn check_eri_very_small_dataset(
2301        test_name: &str,
2302        kernel: Kernel,
2303    ) -> Result<(), Box<dyn std::error::Error>> {
2304        skip_if_unsupported!(kernel, test_name);
2305        let high = [42.0];
2306        let low = [40.0];
2307        let src = [41.0];
2308        let params = EriParams {
2309            period: Some(9),
2310            ma_type: Some("ema".to_string()),
2311        };
2312        let input = EriInput::from_slices(&high, &low, &src, params);
2313        let res = eri_with_kernel(&input, kernel);
2314        assert!(
2315            res.is_err(),
2316            "[{}] ERI should fail with insufficient data",
2317            test_name
2318        );
2319        Ok(())
2320    }
2321
2322    fn check_eri_reinput(
2323        test_name: &str,
2324        kernel: Kernel,
2325    ) -> Result<(), Box<dyn std::error::Error>> {
2326        skip_if_unsupported!(kernel, test_name);
2327        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2328        let candles = read_candles_from_csv(file_path)?;
2329
2330        let first_params = EriParams {
2331            period: Some(14),
2332            ma_type: Some("ema".to_string()),
2333        };
2334        let first_input = EriInput::from_candles(&candles, "close", first_params);
2335        let first_result = eri_with_kernel(&first_input, kernel)?;
2336
2337        assert_eq!(first_result.bull.len(), candles.close.len());
2338        assert_eq!(first_result.bear.len(), candles.close.len());
2339
2340        let second_params = EriParams {
2341            period: Some(14),
2342            ma_type: Some("ema".to_string()),
2343        };
2344        let second_input = EriInput::from_slices(
2345            &first_result.bull,
2346            &first_result.bear,
2347            &first_result.bull,
2348            second_params,
2349        );
2350        let second_result = eri_with_kernel(&second_input, kernel)?;
2351
2352        assert_eq!(second_result.bull.len(), first_result.bull.len());
2353        assert_eq!(second_result.bear.len(), first_result.bear.len());
2354
2355        for i in 28..second_result.bull.len() {
2356            assert!(
2357                !second_result.bull[i].is_nan(),
2358                "Expected no NaN in bull after index 28, but found NaN at index {}",
2359                i
2360            );
2361            assert!(
2362                !second_result.bear[i].is_nan(),
2363                "Expected no NaN in bear after index 28, but found NaN at index {}",
2364                i
2365            );
2366        }
2367        Ok(())
2368    }
2369
2370    fn check_eri_nan_handling(
2371        test_name: &str,
2372        kernel: Kernel,
2373    ) -> Result<(), Box<dyn std::error::Error>> {
2374        skip_if_unsupported!(kernel, test_name);
2375        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2376        let candles = read_candles_from_csv(file_path)?;
2377
2378        let input = EriInput::from_candles(
2379            &candles,
2380            "close",
2381            EriParams {
2382                period: Some(13),
2383                ma_type: Some("ema".to_string()),
2384            },
2385        );
2386        let res = eri_with_kernel(&input, kernel)?;
2387        assert_eq!(res.bull.len(), candles.close.len());
2388        if res.bull.len() > 240 {
2389            for (i, &val) in res.bull[240..].iter().enumerate() {
2390                assert!(
2391                    !val.is_nan(),
2392                    "[{}] Found unexpected NaN at bull-index {}",
2393                    test_name,
2394                    240 + i
2395                );
2396            }
2397        }
2398        if res.bear.len() > 240 {
2399            for (i, &val) in res.bear[240..].iter().enumerate() {
2400                assert!(
2401                    !val.is_nan(),
2402                    "[{}] Found unexpected NaN at bear-index {}",
2403                    test_name,
2404                    240 + i
2405                );
2406            }
2407        }
2408        Ok(())
2409    }
2410
2411    #[cfg(debug_assertions)]
2412    fn check_eri_no_poison(
2413        test_name: &str,
2414        kernel: Kernel,
2415    ) -> Result<(), Box<dyn std::error::Error>> {
2416        skip_if_unsupported!(kernel, test_name);
2417
2418        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2419        let candles = read_candles_from_csv(file_path)?;
2420
2421        let test_params = vec![
2422            EriParams::default(),
2423            EriParams {
2424                period: Some(2),
2425                ma_type: Some("ema".to_string()),
2426            },
2427            EriParams {
2428                period: Some(5),
2429                ma_type: Some("sma".to_string()),
2430            },
2431            EriParams {
2432                period: Some(7),
2433                ma_type: Some("ema".to_string()),
2434            },
2435            EriParams {
2436                period: Some(10),
2437                ma_type: Some("wma".to_string()),
2438            },
2439            EriParams {
2440                period: Some(13),
2441                ma_type: Some("ema".to_string()),
2442            },
2443            EriParams {
2444                period: Some(20),
2445                ma_type: Some("sma".to_string()),
2446            },
2447            EriParams {
2448                period: Some(30),
2449                ma_type: Some("ema".to_string()),
2450            },
2451            EriParams {
2452                period: Some(50),
2453                ma_type: Some("sma".to_string()),
2454            },
2455            EriParams {
2456                period: Some(100),
2457                ma_type: Some("ema".to_string()),
2458            },
2459            EriParams {
2460                period: Some(3),
2461                ma_type: Some("hma".to_string()),
2462            },
2463            EriParams {
2464                period: Some(21),
2465                ma_type: Some("dema".to_string()),
2466            },
2467            EriParams {
2468                period: Some(14),
2469                ma_type: Some("tema".to_string()),
2470            },
2471        ];
2472
2473        for (param_idx, params) in test_params.iter().enumerate() {
2474            let input = EriInput::from_candles(&candles, "close", params.clone());
2475            let output = eri_with_kernel(&input, kernel)?;
2476
2477            for (i, &val) in output.bull.iter().enumerate() {
2478                if val.is_nan() {
2479                    continue;
2480                }
2481
2482                let bits = val.to_bits();
2483
2484                if bits == 0x11111111_11111111 {
2485                    panic!(
2486						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at bull index {} \
2487						 with params: period={}, ma_type={} (param set {})",
2488						test_name,
2489						val,
2490						bits,
2491						i,
2492						params.period.unwrap_or(13),
2493						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2494						param_idx
2495					);
2496                }
2497
2498                if bits == 0x22222222_22222222 {
2499                    panic!(
2500						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at bull index {} \
2501						 with params: period={}, ma_type={} (param set {})",
2502						test_name,
2503						val,
2504						bits,
2505						i,
2506						params.period.unwrap_or(13),
2507						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2508						param_idx
2509					);
2510                }
2511
2512                if bits == 0x33333333_33333333 {
2513                    panic!(
2514						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at bull index {} \
2515						 with params: period={}, ma_type={} (param set {})",
2516						test_name,
2517						val,
2518						bits,
2519						i,
2520						params.period.unwrap_or(13),
2521						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2522						param_idx
2523					);
2524                }
2525            }
2526
2527            for (i, &val) in output.bear.iter().enumerate() {
2528                if val.is_nan() {
2529                    continue;
2530                }
2531
2532                let bits = val.to_bits();
2533
2534                if bits == 0x11111111_11111111 {
2535                    panic!(
2536						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at bear index {} \
2537						 with params: period={}, ma_type={} (param set {})",
2538						test_name,
2539						val,
2540						bits,
2541						i,
2542						params.period.unwrap_or(13),
2543						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2544						param_idx
2545					);
2546                }
2547
2548                if bits == 0x22222222_22222222 {
2549                    panic!(
2550						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at bear index {} \
2551						 with params: period={}, ma_type={} (param set {})",
2552						test_name,
2553						val,
2554						bits,
2555						i,
2556						params.period.unwrap_or(13),
2557						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2558						param_idx
2559					);
2560                }
2561
2562                if bits == 0x33333333_33333333 {
2563                    panic!(
2564						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at bear index {} \
2565						 with params: period={}, ma_type={} (param set {})",
2566						test_name,
2567						val,
2568						bits,
2569						i,
2570						params.period.unwrap_or(13),
2571						params.ma_type.as_ref().unwrap_or(&"ema".to_string()),
2572						param_idx
2573					);
2574                }
2575            }
2576        }
2577
2578        Ok(())
2579    }
2580
2581    #[cfg(not(debug_assertions))]
2582    fn check_eri_no_poison(
2583        _test_name: &str,
2584        _kernel: Kernel,
2585    ) -> Result<(), Box<dyn std::error::Error>> {
2586        Ok(())
2587    }
2588
2589    macro_rules! generate_all_eri_tests {
2590        ($($test_fn:ident),*) => {
2591            paste::paste! {
2592                $(
2593                    #[test]
2594                    fn [<$test_fn _scalar_f64>]() {
2595                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2596                    }
2597                )*
2598                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2599                $(
2600                    #[test]
2601                    fn [<$test_fn _avx2_f64>]() {
2602                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2603                    }
2604                    #[test]
2605                    fn [<$test_fn _avx512_f64>]() {
2606                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2607                    }
2608                )*
2609            }
2610        }
2611    }
2612
2613    generate_all_eri_tests!(
2614        check_eri_partial_params,
2615        check_eri_accuracy,
2616        check_eri_default_candles,
2617        check_eri_zero_period,
2618        check_eri_period_exceeds_length,
2619        check_eri_very_small_dataset,
2620        check_eri_reinput,
2621        check_eri_nan_handling,
2622        check_eri_no_poison
2623    );
2624
2625    #[cfg(test)]
2626    generate_all_eri_tests!(check_eri_property);
2627
2628    fn check_batch_default_row(
2629        test: &str,
2630        kernel: Kernel,
2631    ) -> Result<(), Box<dyn std::error::Error>> {
2632        skip_if_unsupported!(kernel, test);
2633
2634        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2635        let c = read_candles_from_csv(file)?;
2636
2637        let high = c.select_candle_field("high").unwrap();
2638        let low = c.select_candle_field("low").unwrap();
2639        let src = c.select_candle_field("close").unwrap();
2640
2641        let output = EriBatchBuilder::new()
2642            .kernel(kernel)
2643            .period_static(13)
2644            .apply_slices(high, low, src)?;
2645
2646        let def = EriParams::default();
2647        let row = output.values_for_bull(&def).expect("default row missing");
2648
2649        assert_eq!(row.len(), c.close.len());
2650
2651        let expected = [
2652            -103.35343557205488,
2653            6.839912366813223,
2654            -42.851503685589705,
2655            -9.444146016219747,
2656            11.476446271808527,
2657        ];
2658        let start = row.len() - 5;
2659        for (i, &v) in row[start..].iter().enumerate() {
2660            assert!(
2661                (v - expected[i]).abs() < 1e-2,
2662                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2663            );
2664        }
2665        Ok(())
2666    }
2667
2668    #[test]
2669    fn test_eri_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
2670        let len = 256usize;
2671        let mut high = vec![0.0; len];
2672        let mut low = vec![0.0; len];
2673        let mut src = vec![0.0; len];
2674        for i in 0..len {
2675            let base = 100.0 + (i as f64) * 0.1;
2676            src[i] = base;
2677            high[i] = base + 1.0;
2678            low[i] = base - 1.0;
2679        }
2680
2681        let params = EriParams {
2682            period: Some(13),
2683            ma_type: Some("ema".to_string()),
2684        };
2685        let input = EriInput::from_slices(&high, &low, &src, params);
2686
2687        let baseline = eri(&input)?;
2688
2689        let mut bull = vec![0.0; len];
2690        let mut bear = vec![0.0; len];
2691        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2692        {
2693            eri_into(&input, &mut bull, &mut bear)?;
2694        }
2695        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2696        {
2697            eri_into_slice(&mut bull, &mut bear, &input, Kernel::Auto)?;
2698        }
2699
2700        assert_eq!(baseline.bull.len(), bull.len());
2701        assert_eq!(baseline.bear.len(), bear.len());
2702
2703        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2704            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
2705        }
2706
2707        for i in 0..len {
2708            assert!(
2709                eq_or_both_nan(baseline.bull[i], bull[i]),
2710                "bull mismatch at index {i}: baseline={} into={}",
2711                baseline.bull[i],
2712                bull[i]
2713            );
2714            assert!(
2715                eq_or_both_nan(baseline.bear[i], bear[i]),
2716                "bear mismatch at index {i}: baseline={} into={}",
2717                baseline.bear[i],
2718                bear[i]
2719            );
2720        }
2721
2722        Ok(())
2723    }
2724
2725    #[cfg(debug_assertions)]
2726    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2727        skip_if_unsupported!(kernel, test);
2728
2729        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2730        let c = read_candles_from_csv(file)?;
2731
2732        let high = c.select_candle_field("high").unwrap();
2733        let low = c.select_candle_field("low").unwrap();
2734        let src = c.select_candle_field("close").unwrap();
2735
2736        let test_configs = vec![
2737            (2, 10, 2),
2738            (5, 25, 5),
2739            (30, 60, 15),
2740            (2, 5, 1),
2741            (10, 20, 2),
2742            (20, 50, 10),
2743            (13, 13, 0),
2744        ];
2745
2746        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2747            let output = EriBatchBuilder::new()
2748                .kernel(kernel)
2749                .period_range(p_start, p_end, p_step)
2750                .apply_slices(high, low, src)?;
2751
2752            for (idx, &val) in output.bull.iter().enumerate() {
2753                if val.is_nan() {
2754                    continue;
2755                }
2756
2757                let bits = val.to_bits();
2758                let row = idx / output.cols;
2759                let col = idx % output.cols;
2760                let combo = &output.params[row];
2761
2762                if bits == 0x11111111_11111111 {
2763                    panic!(
2764                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2765						 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2766                        test,
2767                        cfg_idx,
2768                        val,
2769                        bits,
2770                        row,
2771                        col,
2772                        idx,
2773                        combo.period.unwrap_or(13),
2774                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2775                    );
2776                }
2777
2778                if bits == 0x22222222_22222222 {
2779                    panic!(
2780                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2781						 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2782                        test,
2783                        cfg_idx,
2784                        val,
2785                        bits,
2786                        row,
2787                        col,
2788                        idx,
2789                        combo.period.unwrap_or(13),
2790                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2791                    );
2792                }
2793
2794                if bits == 0x33333333_33333333 {
2795                    panic!(
2796                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2797						 at bull row {} col {} (flat index {}) with params: period={}, ma_type={}",
2798                        test,
2799                        cfg_idx,
2800                        val,
2801                        bits,
2802                        row,
2803                        col,
2804                        idx,
2805                        combo.period.unwrap_or(13),
2806                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2807                    );
2808                }
2809            }
2810
2811            for (idx, &val) in output.bear.iter().enumerate() {
2812                if val.is_nan() {
2813                    continue;
2814                }
2815
2816                let bits = val.to_bits();
2817                let row = idx / output.cols;
2818                let col = idx % output.cols;
2819                let combo = &output.params[row];
2820
2821                if bits == 0x11111111_11111111 {
2822                    panic!(
2823                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2824						 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2825                        test,
2826                        cfg_idx,
2827                        val,
2828                        bits,
2829                        row,
2830                        col,
2831                        idx,
2832                        combo.period.unwrap_or(13),
2833                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2834                    );
2835                }
2836
2837                if bits == 0x22222222_22222222 {
2838                    panic!(
2839                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2840						 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2841                        test,
2842                        cfg_idx,
2843                        val,
2844                        bits,
2845                        row,
2846                        col,
2847                        idx,
2848                        combo.period.unwrap_or(13),
2849                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2850                    );
2851                }
2852
2853                if bits == 0x33333333_33333333 {
2854                    panic!(
2855                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2856						 at bear row {} col {} (flat index {}) with params: period={}, ma_type={}",
2857                        test,
2858                        cfg_idx,
2859                        val,
2860                        bits,
2861                        row,
2862                        col,
2863                        idx,
2864                        combo.period.unwrap_or(13),
2865                        combo.ma_type.as_ref().unwrap_or(&"ema".to_string())
2866                    );
2867                }
2868            }
2869        }
2870
2871        Ok(())
2872    }
2873
2874    #[cfg(not(debug_assertions))]
2875    fn check_batch_no_poison(
2876        _test: &str,
2877        _kernel: Kernel,
2878    ) -> Result<(), Box<dyn std::error::Error>> {
2879        Ok(())
2880    }
2881
2882    #[cfg(test)]
2883    #[allow(clippy::float_cmp)]
2884    fn check_eri_property(
2885        test_name: &str,
2886        kernel: Kernel,
2887    ) -> Result<(), Box<dyn std::error::Error>> {
2888        use proptest::prelude::*;
2889        skip_if_unsupported!(kernel, test_name);
2890
2891        let strat = (2usize..=50)
2892            .prop_flat_map(|period| {
2893                (
2894                    100.0f64..5000.0f64,
2895                    (period + 20)..400,
2896                    0.001f64..0.05f64,
2897                    -0.01f64..0.01f64,
2898                    Just(period),
2899                    prop::sample::select(vec!["ema", "sma", "wma"]),
2900                )
2901            })
2902            .prop_map(
2903                |(base_price, data_len, volatility, trend, period, ma_type)| {
2904                    let mut high = Vec::with_capacity(data_len);
2905                    let mut low = Vec::with_capacity(data_len);
2906                    let mut close = Vec::with_capacity(data_len);
2907
2908                    let mut price = base_price;
2909                    for i in 0..data_len {
2910                        price *= 1.0 + trend + (i as f64 * 0.0001 * trend);
2911                        let daily_vol = volatility * price;
2912
2913                        let c = price + daily_vol * ((i as f64).sin() * 0.3);
2914                        let h = c + daily_vol * (0.5 + (i as f64 * 0.7).cos().abs() * 0.5);
2915                        let l = c - daily_vol * (0.5 + (i as f64 * 0.7).sin().abs() * 0.5);
2916
2917                        high.push(h);
2918                        low.push(l.min(c));
2919                        close.push(c);
2920                    }
2921
2922                    (high, low, close, period, ma_type.to_string())
2923                },
2924            );
2925
2926        proptest::test_runner::TestRunner::default()
2927            .run(&strat, |(high, low, close, period, ma_type)| {
2928                let params = EriParams {
2929                    period: Some(period),
2930                    ma_type: Some(ma_type.clone()),
2931                };
2932                let input = EriInput::from_slices(&high, &low, &close, params.clone());
2933
2934                let result = match eri_with_kernel(&input, kernel) {
2935                    Ok(r) => r,
2936                    Err(e) => match e {
2937                        EriError::MaCalculationError(msg) if msg.contains("Not enough data") => {
2938                            return Ok(())
2939                        }
2940                        EriError::NotEnoughValidData { .. } => return Ok(()),
2941                        _ => panic!("Unexpected error type: {:?}", e),
2942                    },
2943                };
2944
2945                let reference = match eri_with_kernel(&input, Kernel::Scalar) {
2946                    Ok(r) => r,
2947                    Err(_) => return Ok(()),
2948                };
2949
2950                let first_valid_idx = high
2951                    .iter()
2952                    .zip(low.iter())
2953                    .zip(close.iter())
2954                    .position(|((h, l), c)| !h.is_nan() && !l.is_nan() && !c.is_nan())
2955                    .unwrap_or(0);
2956                let warmup_period = first_valid_idx + period - 1;
2957
2958                for i in 0..warmup_period.min(high.len()) {
2959                    prop_assert!(
2960                        result.bull[i].is_nan(),
2961                        "[{}] Expected NaN in bull warmup at index {}, got {}",
2962                        test_name,
2963                        i,
2964                        result.bull[i]
2965                    );
2966                    prop_assert!(
2967                        result.bear[i].is_nan(),
2968                        "[{}] Expected NaN in bear warmup at index {}, got {}",
2969                        test_name,
2970                        i,
2971                        result.bear[i]
2972                    );
2973                }
2974
2975                for i in warmup_period..high.len() {
2976                    if !high[i].is_nan() && !low[i].is_nan() && !close[i].is_nan() {
2977                        prop_assert!(
2978                            !result.bull[i].is_nan(),
2979                            "[{}] Unexpected NaN in bull at index {} after warmup",
2980                            test_name,
2981                            i
2982                        );
2983                        prop_assert!(
2984                            !result.bear[i].is_nan(),
2985                            "[{}] Unexpected NaN in bear at index {} after warmup",
2986                            test_name,
2987                            i
2988                        );
2989                    }
2990                }
2991
2992                for i in warmup_period..high.len() {
2993                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
2994                        prop_assert!(
2995                            result.bear[i] <= result.bull[i] + 1e-9,
2996                            "[{}] Bear {} > Bull {} at index {} (low={}, high={})",
2997                            test_name,
2998                            result.bear[i],
2999                            result.bull[i],
3000                            i,
3001                            low[i],
3002                            high[i]
3003                        );
3004                    }
3005                }
3006
3007                for i in 0..high.len() {
3008                    let bull_val = result.bull[i];
3009                    let bear_val = result.bear[i];
3010                    let ref_bull = reference.bull[i];
3011                    let ref_bear = reference.bear[i];
3012
3013                    if bull_val.is_finite() && ref_bull.is_finite() {
3014                        let bull_diff = (bull_val - ref_bull).abs();
3015                        let bull_ulp = bull_val.to_bits().abs_diff(ref_bull.to_bits());
3016                        prop_assert!(
3017                            bull_diff <= 1e-9 || bull_ulp <= 4,
3018                            "[{}] Bull kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
3019                            test_name,
3020                            i,
3021                            bull_val,
3022                            ref_bull,
3023                            bull_diff,
3024                            bull_ulp
3025                        );
3026                    } else {
3027                        prop_assert_eq!(
3028                            bull_val.to_bits(),
3029                            ref_bull.to_bits(),
3030                            "[{}] Bull NaN/Inf mismatch at index {}: {} vs {}",
3031                            test_name,
3032                            i,
3033                            bull_val,
3034                            ref_bull
3035                        );
3036                    }
3037
3038                    if bear_val.is_finite() && ref_bear.is_finite() {
3039                        let bear_diff = (bear_val - ref_bear).abs();
3040                        let bear_ulp = bear_val.to_bits().abs_diff(ref_bear.to_bits());
3041                        prop_assert!(
3042                            bear_diff <= 1e-9 || bear_ulp <= 4,
3043                            "[{}] Bear kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
3044                            test_name,
3045                            i,
3046                            bear_val,
3047                            ref_bear,
3048                            bear_diff,
3049                            bear_ulp
3050                        );
3051                    } else {
3052                        prop_assert_eq!(
3053                            bear_val.to_bits(),
3054                            ref_bear.to_bits(),
3055                            "[{}] Bear NaN/Inf mismatch at index {}: {} vs {}",
3056                            test_name,
3057                            i,
3058                            bear_val,
3059                            ref_bear
3060                        );
3061                    }
3062                }
3063
3064                let all_high_same = high.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
3065                let all_low_same = low.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
3066                let all_close_same = close.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
3067
3068                if all_high_same
3069                    && all_low_same
3070                    && all_close_same
3071                    && high.len() > warmup_period + 2 * period
3072                {
3073                    let expected_bull = high[0] - close[0];
3074                    let expected_bear = low[0] - close[0];
3075
3076                    for i in (warmup_period + 2 * period)..high.len() {
3077                        if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
3078                            prop_assert!(
3079                                (result.bull[i] - expected_bull).abs() < 1e-6,
3080                                "[{}] Constant price: Bull {} != expected {} at index {}",
3081                                test_name,
3082                                result.bull[i],
3083                                expected_bull,
3084                                i
3085                            );
3086                            prop_assert!(
3087                                (result.bear[i] - expected_bear).abs() < 1e-6,
3088                                "[{}] Constant price: Bear {} != expected {} at index {}",
3089                                test_name,
3090                                result.bear[i],
3091                                expected_bear,
3092                                i
3093                            );
3094                        }
3095                    }
3096                }
3097
3098                for i in warmup_period..high.len() {
3099                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
3100                        let expected_diff = high[i] - low[i];
3101                        let actual_diff = result.bull[i] - result.bear[i];
3102                        prop_assert!(
3103                            (actual_diff - expected_diff).abs() < 1e-9,
3104                            "[{}] Bull - Bear != High - Low at index {}: {} vs {}",
3105                            test_name,
3106                            i,
3107                            actual_diff,
3108                            expected_diff
3109                        );
3110                    }
3111                }
3112
3113                for i in warmup_period..high.len() {
3114                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
3115                        if result.bull[i] < 0.0 && result.bear[i] > 0.0 {
3116                            prop_assert!(
3117								false,
3118								"[{}] Impossible state: bull {} < 0 but bear {} > 0 at index {} (high={}, low={})",
3119								test_name, result.bull[i], result.bear[i], i, high[i], low[i]
3120							);
3121                        }
3122                    }
3123                }
3124
3125                if period == 1 {
3126                    for i in warmup_period..high.len().min(warmup_period + 10) {
3127                        if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
3128                            let expected_diff = high[i] - low[i];
3129                            let actual_diff = result.bull[i] - result.bear[i];
3130                            prop_assert!(
3131                                (actual_diff - expected_diff).abs() < 1e-6,
3132                                "[{}] Period=1: Bull-Bear mismatch at index {}: {} vs {}",
3133                                test_name,
3134                                i,
3135                                actual_diff,
3136                                expected_diff
3137                            );
3138                        }
3139                    }
3140                }
3141
3142                let max_price = high.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
3143                let min_price = low.iter().cloned().fold(f64::INFINITY, f64::min);
3144                let price_range = max_price - min_price;
3145
3146                for i in warmup_period..high.len() {
3147                    if !result.bull[i].is_nan() && !result.bear[i].is_nan() {
3148                        prop_assert!(
3149                            result.bull[i].abs() <= price_range * 2.0,
3150                            "[{}] Bull {} exceeds reasonable bounds (price range: {}) at index {}",
3151                            test_name,
3152                            result.bull[i],
3153                            price_range,
3154                            i
3155                        );
3156                        prop_assert!(
3157                            result.bear[i].abs() <= price_range * 2.0,
3158                            "[{}] Bear {} exceeds reasonable bounds (price range: {}) at index {}",
3159                            test_name,
3160                            result.bear[i],
3161                            price_range,
3162                            i
3163                        );
3164                    }
3165                }
3166
3167                if period >= high.len() - 5 && period < high.len() {
3168                    let valid_count = result
3169                        .bull
3170                        .iter()
3171                        .zip(result.bear.iter())
3172                        .filter(|(b, r)| !b.is_nan() && !r.is_nan())
3173                        .count();
3174                    prop_assert!(
3175                        valid_count >= 1,
3176                        "[{}] No valid values with period {} and data_len {}",
3177                        test_name,
3178                        period,
3179                        high.len()
3180                    );
3181                }
3182
3183                Ok(())
3184            })
3185            .unwrap();
3186
3187        Ok(())
3188    }
3189
3190    macro_rules! gen_batch_tests {
3191        ($fn_name:ident) => {
3192            paste::paste! {
3193                #[test] fn [<$fn_name _scalar>]()      {
3194                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3195                }
3196                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3197                #[test] fn [<$fn_name _avx2>]()        {
3198                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3199                }
3200                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3201                #[test] fn [<$fn_name _avx512>]()      {
3202                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3203                }
3204                #[test] fn [<$fn_name _auto_detect>]() {
3205                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3206                }
3207            }
3208        };
3209    }
3210    gen_batch_tests!(check_batch_default_row);
3211    gen_batch_tests!(check_batch_no_poison);
3212}