Skip to main content

vector_ta/indicators/
er.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyAny, PyDict, PyList};
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14use crate::utilities::data_loader::{source_type, Candles};
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{
17    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
18    make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22use aligned_vec::{AVec, CACHELINE_ALIGN};
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31
32impl<'a> AsRef<[f64]> for ErInput<'a> {
33    #[inline(always)]
34    fn as_ref(&self) -> &[f64] {
35        match &self.data {
36            ErData::Slice(slice) => slice,
37            ErData::Candles { candles, source } => source_type(candles, source),
38        }
39    }
40}
41
/// Where the input series for an ER computation comes from.
#[derive(Debug, Clone)]
pub enum ErData<'a> {
    /// A candle set plus the name of the field to read; the field is resolved
    /// with `source_type(candles, source)` when the input is borrowed as a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-supplied series, used as-is.
    Slice(&'a [f64]),
}
50
/// Result of a single-series ER computation.
#[derive(Debug, Clone)]
pub struct ErOutput {
    /// Output series. `er_with_kernel` allocates it with the same length as the
    /// input via `alloc_with_nan_prefix(len, warm)`.
    pub values: Vec<f64>,
}
55
/// Tunable parameters of the ER indicator.
///
/// `Serialize`/`Deserialize` are derived only for wasm32 builds with the
/// `wasm` feature enabled.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ErParams {
    /// Window length; readers in this file substitute 5 when it is `None`.
    pub period: Option<usize>,
}
64
65impl Default for ErParams {
66    fn default() -> Self {
67        Self { period: Some(5) }
68    }
69}
70
/// Bundles the data source and the parameters for one ER computation.
#[derive(Debug, Clone)]
pub struct ErInput<'a> {
    /// The series to process (candles + field name, or a raw slice).
    pub data: ErData<'a>,
    /// Indicator parameters.
    pub params: ErParams,
}
76
77impl<'a> ErInput<'a> {
78    #[inline]
79    pub fn from_candles(c: &'a Candles, s: &'a str, p: ErParams) -> Self {
80        Self {
81            data: ErData::Candles {
82                candles: c,
83                source: s,
84            },
85            params: p,
86        }
87    }
88    #[inline]
89    pub fn from_slice(sl: &'a [f64], p: ErParams) -> Self {
90        Self {
91            data: ErData::Slice(sl),
92            params: p,
93        }
94    }
95    #[inline]
96    pub fn with_default_candles(c: &'a Candles) -> Self {
97        Self::from_candles(c, "close", ErParams::default())
98    }
99    #[inline]
100    pub fn get_period(&self) -> usize {
101        self.params.period.unwrap_or(5)
102    }
103}
104
/// Builder for single-series ER computations and for [`ErStream`].
#[derive(Copy, Clone, Debug)]
pub struct ErBuilder {
    // Window length; `None` leaves the choice to the consumers' default.
    period: Option<usize>,
    // Kernel forwarded to `er_with_kernel`.
    kernel: Kernel,
}
110
111impl Default for ErBuilder {
112    fn default() -> Self {
113        Self {
114            period: None,
115            kernel: Kernel::Auto,
116        }
117    }
118}
119
120impl ErBuilder {
121    #[inline(always)]
122    pub fn new() -> Self {
123        Self::default()
124    }
125    #[inline(always)]
126    pub fn period(mut self, n: usize) -> Self {
127        self.period = Some(n);
128        self
129    }
130    #[inline(always)]
131    pub fn kernel(mut self, k: Kernel) -> Self {
132        self.kernel = k;
133        self
134    }
135    #[inline(always)]
136    pub fn apply(self, c: &Candles) -> Result<ErOutput, ErError> {
137        let p = ErParams {
138            period: self.period,
139        };
140        let i = ErInput::from_candles(c, "close", p);
141        er_with_kernel(&i, self.kernel)
142    }
143    #[inline(always)]
144    pub fn apply_slice(self, d: &[f64]) -> Result<ErOutput, ErError> {
145        let p = ErParams {
146            period: self.period,
147        };
148        let i = ErInput::from_slice(d, p);
149        er_with_kernel(&i, self.kernel)
150    }
151    #[inline(always)]
152    pub fn into_stream(self) -> Result<ErStream, ErError> {
153        let p = ErParams {
154            period: self.period,
155        };
156        ErStream::try_new(p)
157    }
158}
159
/// Errors reported by the ER functions in this file.
#[derive(Debug, Error)]
pub enum ErError {
    /// The input series has length zero.
    #[error("er: Input data slice is empty.")]
    EmptyInputData,
    /// No element of the input is non-NaN.
    #[error("er: All input data values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("er: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer elements remain from the first non-NaN value onward than the period needs.
    #[error("er: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not have the required length.
    #[error("er: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep produced no parameter combinations, or `rows * cols` overflowed.
    #[error("er: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A kernel for which `is_batch()` is false was passed to the batch entry point.
    #[error("er: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
181
// Lets wasm entry points surface an `ErError` as its display string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<ErError> for JsValue {
    fn from(err: ErError) -> Self {
        let message = err.to_string();
        JsValue::from_str(message.as_str())
    }
}
188
189#[inline]
190pub fn er(input: &ErInput) -> Result<ErOutput, ErError> {
191    er_with_kernel(input, Kernel::Auto)
192}
193
194#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
195#[inline]
196pub fn er_into(input: &ErInput, out: &mut [f64]) -> Result<(), ErError> {
197    er_into_slice(out, input, Kernel::Auto)
198}
199
200pub fn er_with_kernel(input: &ErInput, kernel: Kernel) -> Result<ErOutput, ErError> {
201    let data: &[f64] = input.as_ref();
202    let len = data.len();
203    if len == 0 {
204        return Err(ErError::EmptyInputData);
205    }
206    let first = data
207        .iter()
208        .position(|x| !x.is_nan())
209        .ok_or(ErError::AllValuesNaN)?;
210    let period = input.get_period();
211    if period == 0 || period > len {
212        return Err(ErError::InvalidPeriod {
213            period,
214            data_len: len,
215        });
216    }
217    if (len - first) < period {
218        return Err(ErError::NotEnoughValidData {
219            needed: period,
220            valid: len - first,
221        });
222    }
223
224    let chosen = match kernel {
225        Kernel::Auto => Kernel::Scalar,
226        other => other,
227    };
228
229    let warm = first + period - 1;
230    let mut out = alloc_with_nan_prefix(len, warm);
231    unsafe {
232        match chosen {
233            Kernel::Scalar | Kernel::ScalarBatch => er_scalar(data, period, first, &mut out),
234            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235            Kernel::Avx2 | Kernel::Avx2Batch => er_avx2(data, period, first, &mut out),
236
237            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238            Kernel::Avx512 | Kernel::Avx512Batch => er_scalar(data, period, first, &mut out),
239            _ => unreachable!(),
240        }
241    }
242    Ok(ErOutput { values: out })
243}
244
245#[inline]
246pub fn er_into_slice(dst: &mut [f64], input: &ErInput, kern: Kernel) -> Result<(), ErError> {
247    let data: &[f64] = input.as_ref();
248    let len = data.len();
249    if len == 0 {
250        return Err(ErError::EmptyInputData);
251    }
252    let first = data
253        .iter()
254        .position(|x| !x.is_nan())
255        .ok_or(ErError::AllValuesNaN)?;
256    let period = input.get_period();
257    if period == 0 || period > len {
258        return Err(ErError::InvalidPeriod {
259            period,
260            data_len: len,
261        });
262    }
263    if (len - first) < period {
264        return Err(ErError::NotEnoughValidData {
265            needed: period,
266            valid: len - first,
267        });
268    }
269    if dst.len() != len {
270        return Err(ErError::OutputLengthMismatch {
271            expected: len,
272            got: dst.len(),
273        });
274    }
275
276    let chosen = match kern {
277        Kernel::Auto => Kernel::Scalar,
278        other => other,
279    };
280
281    unsafe {
282        match chosen {
283            Kernel::Scalar | Kernel::ScalarBatch => er_scalar(data, period, first, dst),
284            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
285            Kernel::Avx2 | Kernel::Avx2Batch => er_avx2(data, period, first, dst),
286
287            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
288            Kernel::Avx512 | Kernel::Avx512Batch => er_scalar(data, period, first, dst),
289            _ => unreachable!(),
290        }
291    }
292
293    let warm_end = first + period - 1;
294    for v in &mut dst[..warm_end] {
295        *v = f64::NAN;
296    }
297
298    Ok(())
299}
300
/// Scalar ER kernel.
///
/// For every `i` in `first + period - 1 .. data.len()` this writes
/// `out[i] = min(1, |data[i] - data[s]| / roll)` with `s = i + 1 - period` and
/// `roll` the sum of `|data[k + 1] - data[k]|` for `k` in `s..i`; when `roll`
/// is not strictly positive the value written is `0.0`. Elements of `out`
/// before the warm-up index are left untouched.
///
/// `first` is the index of the first element to use (callers in this file pass
/// the first non-NaN index). The function does nothing when `period` is zero or
/// when the window does not fit into `data`.
///
/// # Panics
/// Panics if `out` is shorter than `data`.
#[inline]
pub fn er_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    // A zero period would underflow `first + period - 1` (a panic in debug
    // builds, a wrap-around in release builds), so reject it up front.
    if period == 0 {
        return;
    }
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    // Path length of the first window, accumulated left to right.
    let mut roll = data[first..=warm]
        .windows(2)
        .fold(0.0f64, |acc, pair| acc + (pair[1] - pair[0]).abs());

    for i in warm..n {
        let start = i + 1 - period;
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };

        // Slide the window: add the step entering on the right and drop the
        // step leaving on the left.
        if i + 1 < n {
            let add = (data[i + 1] - data[i]).abs();
            let sub = (data[start + 1] - data[start]).abs();
            roll = roll + add - sub;
        }
    }
}
336
/// AVX-512 entry point that is safe to call on any x86_64 CPU.
///
/// Dispatches to `er_avx512_short` for `period <= 32` and to `er_avx512_long`
/// otherwise. Because this function is safe, it verifies at runtime that the
/// CPU supports AVX-512F and uses the scalar kernel when it does not.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn er_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    if !std::is_x86_feature_detected!("avx512f") {
        er_scalar(data, period, first, out);
        return;
    }
    // SAFETY: AVX-512F availability was confirmed just above, which is the
    // target feature both kernels are compiled for.
    unsafe {
        if period <= 32 {
            er_avx512_short(data, period, first, out);
        } else {
            er_avx512_long(data, period, first, out);
        }
    }
}
348
/// AVX2 variant of the ER kernel.
///
/// Only the initial window's path length (sum of absolute one-step
/// differences) is vectorised, four differences per iteration; the sliding
/// update that follows is the same scalar loop as in `er_scalar`. Writes
/// `out[i]` for `i` in `first + period - 1 .. data.len()` and returns without
/// writing when that range is empty.
///
/// # Safety
/// The CPU must support AVX2 (the function is compiled with
/// `#[target_feature(enable = "avx2")]`).
///
/// # Panics
/// Panics if `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn er_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    // Horizontal sum of the four lanes of `x`.
    #[inline(always)]
    unsafe fn hsum256(x: __m256d) -> f64 {
        let hi = _mm256_extractf128_pd(x, 1);
        let lo = _mm256_castpd256_pd128(x);
        let s = _mm_add_pd(hi, lo);
        let sh = _mm_unpackhi_pd(s, s);
        _mm_cvtsd_f64(_mm_add_sd(s, sh))
    }
    // Lane-wise absolute value: clears the sign bit of every lane.
    #[inline(always)]
    unsafe fn vabs(a: __m256d) -> __m256d {
        let sign = _mm256_set1_pd(-0.0);
        _mm256_andnot_pd(sign, a)
    }

    let n = data.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let ptr = data.as_ptr();
    let mut acc = unsafe { _mm256_setzero_pd() };
    let mut j = first;
    // The second load reads up to index `j + 4`, which the loop condition
    // keeps at or below `warm` (< n).
    while j + 4 <= warm {
        let a = unsafe { _mm256_loadu_pd(ptr.add(j)) };
        let b = unsafe { _mm256_loadu_pd(ptr.add(j + 1)) };
        acc = unsafe { _mm256_add_pd(acc, vabs(_mm256_sub_pd(b, a))) };
        j += 4;
    }
    let mut roll = unsafe { hsum256(acc) };
    // Scalar tail for the differences left over after the 4-wide loop.
    while j < warm {
        roll += (data[j + 1] - data[j]).abs();
        j += 1;
    }

    let mut start = first;
    let mut i = warm;
    while i < n {
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };
        if i + 1 == n {
            break;
        }
        // Slide the window one step to the right.
        let add = (data[i + 1] - data[i]).abs();
        let sub = (data[start + 1] - data[start]).abs();
        roll = roll + add - sub;
        start += 1;
        i += 1;
    }
}
408
/// AVX-512 variant of the ER kernel.
///
/// Only the initial window's path length (sum of absolute one-step
/// differences) is vectorised, eight differences per iteration; the sliding
/// update that follows is the same scalar loop as in `er_scalar`. Writes
/// `out[i]` for `i` in `first + period - 1 .. data.len()` and returns without
/// writing when that range is empty.
///
/// The horizontal reduction uses `_mm512_reduce_add_pd`. The previous
/// hand-written reduction started with a lane shuffle whose immediate
/// (`0b11_10_01_00`) is the identity permutation, so it returned twice the
/// true sum whenever the 8-wide loop ran. The absolute value uses
/// `_mm512_abs_pd`, so the body needs nothing beyond AVX-512F.
///
/// # Safety
/// The CPU must support AVX-512F (the function is compiled with
/// `#[target_feature(enable = "avx512f")]`).
///
/// # Panics
/// Panics if `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn er_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let n = data.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let ptr = data.as_ptr();
    let mut acc = _mm512_setzero_pd();
    let mut j = first;
    // The second load reads up to index `j + 8`, which the loop condition
    // keeps at or below `warm` (< n).
    while j + 8 <= warm {
        let a = _mm512_loadu_pd(ptr.add(j));
        let b = _mm512_loadu_pd(ptr.add(j + 1));
        acc = _mm512_add_pd(acc, _mm512_abs_pd(_mm512_sub_pd(b, a)));
        j += 8;
    }
    let mut roll = _mm512_reduce_add_pd(acc);
    // Scalar tail for the differences left over after the 8-wide loop.
    while j < warm {
        roll += (data[j + 1] - data[j]).abs();
        j += 1;
    }

    let mut start = first;
    let mut i = warm;
    while i < n {
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };
        if i + 1 == n {
            break;
        }
        // Slide the window one step to the right.
        let add = (data[i + 1] - data[i]).abs();
        let sub = (data[start + 1] - data[start]).abs();
        roll = roll + add - sub;
        start += 1;
        i += 1;
    }
}
470
/// Long-period AVX-512 entry point; currently forwards to `er_avx512_short`
/// unchanged.
///
/// # Safety
/// The CPU must support AVX-512F (the function is compiled with
/// `#[target_feature(enable = "avx512f")]`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn er_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    er_avx512_short(data, period, first, out)
}
477
478#[derive(Debug, Clone)]
479pub struct ErStream {
480    period: usize,
481    buffer: Vec<f64>,
482    head: usize,
483    filled: bool,
484    len: usize,
485    denom: f64,
486}
487
488impl ErStream {
489    pub fn try_new(params: ErParams) -> Result<Self, ErError> {
490        let period = params.period.unwrap_or(5);
491        if period == 0 {
492            return Err(ErError::InvalidPeriod {
493                period,
494                data_len: 0,
495            });
496        }
497        Ok(Self {
498            period,
499            buffer: vec![f64::NAN; period],
500            head: 0,
501            filled: false,
502            len: 0,
503            denom: 0.0,
504        })
505    }
506
507    #[inline(always)]
508    pub fn update(&mut self, value: f64) -> Option<f64> {
509        if self.period == 1 {
510            self.buffer[0] = value;
511            self.head = 0;
512            self.filled = true;
513            self.len = 1;
514            self.denom = 0.0;
515            return Some(0.0);
516        }
517
518        if !self.filled {
519            if self.len == 0 {
520                self.buffer[self.head] = value;
521                self.head = (self.head + 1) % self.period;
522                self.len = 1;
523                return None;
524            } else {
525                let prev_idx = if self.head == 0 {
526                    self.period - 1
527                } else {
528                    self.head - 1
529                };
530                self.denom += (value - self.buffer[prev_idx]).abs();
531
532                self.buffer[self.head] = value;
533                self.head = (self.head + 1) % self.period;
534                self.len += 1;
535
536                if self.len < self.period {
537                    return None;
538                }
539
540                self.filled = true;
541
542                let start = self.head;
543                let end = if start == 0 {
544                    self.period - 1
545                } else {
546                    start - 1
547                };
548                debug_assert!(self.len == self.period);
549
550                let delta = (self.buffer[end] - self.buffer[start]).abs();
551                if self.denom > 0.0 {
552                    return Some(if delta >= self.denom {
553                        1.0
554                    } else {
555                        delta / self.denom
556                    });
557                } else {
558                    return Some(0.0);
559                }
560            }
561        }
562
563        let start = self.head;
564        let second = if start + 1 == self.period {
565            0
566        } else {
567            start + 1
568        };
569        let end_prev = if start == 0 {
570            self.period - 1
571        } else {
572            start - 1
573        };
574
575        let sub = (self.buffer[second] - self.buffer[start]).abs();
576        let add = (value - self.buffer[end_prev]).abs();
577        let new_denom = self.denom + add - sub;
578
579        let delta = (value - self.buffer[second]).abs();
580
581        self.denom = new_denom;
582        self.buffer[start] = value;
583        self.head = second;
584
585        if self.denom > 0.0 {
586            Some(if delta >= self.denom {
587                1.0
588            } else {
589                delta / self.denom
590            })
591        } else {
592            Some(0.0)
593        }
594    }
595}
596
/// Parameter sweep for batch ER computations.
#[derive(Clone, Debug)]
pub struct ErBatchRange {
    /// `(start, end, step)` of the period axis, both ends inclusive. A step of
    /// 0 or `start == end` expands to the single period `start`.
    pub period: (usize, usize, usize),
}
601
602impl Default for ErBatchRange {
603    fn default() -> Self {
604        Self {
605            period: (5, 254, 1),
606        }
607    }
608}
609
/// Builder for batch (multi-period) ER computations.
#[derive(Clone, Debug, Default)]
pub struct ErBatchBuilder {
    // Period sweep to evaluate.
    range: ErBatchRange,
    // Kernel forwarded to `er_batch_with_kernel`.
    kernel: Kernel,
}
615
616impl ErBatchBuilder {
617    pub fn new() -> Self {
618        Self::default()
619    }
620    pub fn kernel(mut self, k: Kernel) -> Self {
621        self.kernel = k;
622        self
623    }
624    #[inline]
625    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
626        self.range.period = (start, end, step);
627        self
628    }
629    #[inline]
630    pub fn period_static(mut self, p: usize) -> Self {
631        self.range.period = (p, p, 0);
632        self
633    }
634    pub fn apply_slice(self, data: &[f64]) -> Result<ErBatchOutput, ErError> {
635        er_batch_with_kernel(data, &self.range, self.kernel)
636    }
637    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ErBatchOutput, ErError> {
638        ErBatchBuilder::new().kernel(k).apply_slice(data)
639    }
640    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ErBatchOutput, ErError> {
641        let slice = source_type(c, src);
642        self.apply_slice(slice)
643    }
644    pub fn with_default_candles(c: &Candles) -> Result<ErBatchOutput, ErError> {
645        ErBatchBuilder::new()
646            .kernel(Kernel::Auto)
647            .apply_candles(c, "close")
648    }
649}
650
651pub fn er_batch_with_kernel(
652    data: &[f64],
653    sweep: &ErBatchRange,
654    k: Kernel,
655) -> Result<ErBatchOutput, ErError> {
656    let kernel = match k {
657        Kernel::Auto => detect_best_batch_kernel(),
658        other if other.is_batch() => other,
659        other => return Err(ErError::InvalidKernelForBatch(other)),
660    };
661    let simd = match kernel {
662        Kernel::Avx512Batch => Kernel::Avx512,
663        Kernel::Avx2Batch => Kernel::Avx2,
664        Kernel::ScalarBatch => Kernel::Scalar,
665        _ => unreachable!(),
666    };
667    er_batch_par_slice(data, sweep, simd)
668}
669
/// Result of a batch ER computation: a row-major `rows x cols` matrix.
#[derive(Clone, Debug)]
pub struct ErBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameters used for each row, in row order.
    pub combos: Vec<ErParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
677impl ErBatchOutput {
678    pub fn row_for_params(&self, p: &ErParams) -> Option<usize> {
679        self.combos
680            .iter()
681            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
682    }
683    pub fn values_for(&self, p: &ErParams) -> Option<&[f64]> {
684        self.row_for_params(p).map(|row| {
685            let start = row * self.cols;
686            &self.values[start..start + self.cols]
687        })
688    }
689}
690
691#[inline(always)]
692fn expand_grid(r: &ErBatchRange) -> Vec<ErParams> {
693    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
694        if step == 0 || start == end {
695            return vec![start];
696        }
697        let st = step.max(1);
698        if start < end {
699            (start..=end).step_by(st).collect()
700        } else {
701            let mut v = Vec::new();
702            let mut x = start as isize;
703            let end_i = end as isize;
704            let st_i = st as isize;
705            while x >= end_i {
706                v.push(x as usize);
707                x -= st_i;
708            }
709            v
710        }
711    }
712    let periods = axis_usize(r.period);
713    let mut out = Vec::with_capacity(periods.len());
714    for &p in &periods {
715        out.push(ErParams { period: Some(p) });
716    }
717    out
718}
719
720#[inline(always)]
721pub fn er_batch_slice(
722    data: &[f64],
723    sweep: &ErBatchRange,
724    kern: Kernel,
725) -> Result<ErBatchOutput, ErError> {
726    er_batch_inner(data, sweep, kern, false)
727}
728
729#[inline(always)]
730pub fn er_batch_par_slice(
731    data: &[f64],
732    sweep: &ErBatchRange,
733    kern: Kernel,
734) -> Result<ErBatchOutput, ErError> {
735    er_batch_inner(data, sweep, kern, true)
736}
737
738#[inline(always)]
739fn er_batch_inner_into(
740    data: &[f64],
741    sweep: &ErBatchRange,
742    kern: Kernel,
743    parallel: bool,
744    out: &mut [f64],
745) -> Result<Vec<ErParams>, ErError> {
746    let combos = expand_grid(sweep);
747    if combos.is_empty() {
748        return Err(ErError::InvalidRange {
749            start: sweep.period.0.to_string(),
750            end: sweep.period.1.to_string(),
751            step: sweep.period.2.to_string(),
752        });
753    }
754
755    let cols = data.len();
756    if cols == 0 {
757        return Err(ErError::EmptyInputData);
758    }
759    let first = data
760        .iter()
761        .position(|x| !x.is_nan())
762        .ok_or(ErError::AllValuesNaN)?;
763    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
764    if cols - first < max_p {
765        return Err(ErError::NotEnoughValidData {
766            needed: max_p,
767            valid: cols - first,
768        });
769    }
770
771    let rows = combos.len();
772    let out_mu = unsafe {
773        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
774    };
775    let expected = rows
776        .checked_mul(cols)
777        .ok_or_else(|| ErError::InvalidRange {
778            start: "rows*cols".into(),
779            end: "overflow".into(),
780            step: "*".into(),
781        })?;
782    if out.len() != expected {
783        return Err(ErError::OutputLengthMismatch {
784            expected,
785            got: out.len(),
786        });
787    }
788    let warm: Vec<usize> = combos
789        .iter()
790        .map(|c| first + c.period.unwrap() - 1)
791        .collect();
792    init_matrix_prefixes(out_mu, cols, &warm);
793
794    let mut prefix = vec![0.0f64; cols];
795    if first < cols {
796        let mut j = first;
797        while j + 1 < cols {
798            let d = (data[j + 1] - data[j]).abs();
799            prefix[j + 1] = prefix[j] + d;
800            j += 1;
801        }
802    }
803
804    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
805        let period = combos[row].period.unwrap();
806        match kern {
807            Kernel::Scalar => er_row_scalar_with_prefix(data, &prefix, first, period, out_row),
808            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
809            Kernel::Avx2 => er_row_avx2(data, first, period, out_row),
810            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
811            Kernel::Avx512 => er_row_avx512(data, first, period, out_row),
812            _ => unreachable!(),
813        }
814    };
815
816    if parallel {
817        #[cfg(not(target_arch = "wasm32"))]
818        {
819            out.par_chunks_mut(cols)
820                .enumerate()
821                .for_each(|(row, slice)| do_row(row, slice));
822        }
823
824        #[cfg(target_arch = "wasm32")]
825        {
826            for (row, slice) in out.chunks_mut(cols).enumerate() {
827                do_row(row, slice);
828            }
829        }
830    } else {
831        for (row, slice) in out.chunks_mut(cols).enumerate() {
832            do_row(row, slice);
833        }
834    }
835
836    Ok(combos)
837}
838
839#[inline(always)]
840fn er_batch_inner(
841    data: &[f64],
842    sweep: &ErBatchRange,
843    kern: Kernel,
844    parallel: bool,
845) -> Result<ErBatchOutput, ErError> {
846    let combos = expand_grid(sweep);
847    if combos.is_empty() {
848        return Err(ErError::InvalidRange {
849            start: sweep.period.0.to_string(),
850            end: sweep.period.1.to_string(),
851            step: sweep.period.2.to_string(),
852        });
853    }
854
855    let cols = data.len();
856    if cols == 0 {
857        return Err(ErError::EmptyInputData);
858    }
859    let first = data
860        .iter()
861        .position(|x| !x.is_nan())
862        .ok_or(ErError::AllValuesNaN)?;
863    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
864    if cols - first < max_p {
865        return Err(ErError::NotEnoughValidData {
866            needed: max_p,
867            valid: cols - first,
868        });
869    }
870
871    let rows = combos.len();
872    let _total = rows
873        .checked_mul(cols)
874        .ok_or_else(|| ErError::InvalidRange {
875            start: "rows*cols".into(),
876            end: "overflow".into(),
877            step: "*".into(),
878        })?;
879    let mut buf_mu = make_uninit_matrix(rows, cols);
880
881    let warm: Vec<usize> = combos
882        .iter()
883        .map(|c| first + c.period.unwrap() - 1)
884        .collect();
885    init_matrix_prefixes(&mut buf_mu, cols, &warm);
886
887    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
888    let values: &mut [f64] = unsafe {
889        std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
890    };
891
892    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
893        let period = combos[row].period.unwrap();
894        match kern {
895            Kernel::Scalar => er_row_scalar(data, first, period, out_row),
896            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
897            Kernel::Avx2 => er_row_avx2(data, first, period, out_row),
898            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
899            Kernel::Avx512 => er_row_avx512(data, first, period, out_row),
900            _ => unreachable!(),
901        }
902    };
903
904    if parallel {
905        #[cfg(not(target_arch = "wasm32"))]
906        {
907            values
908                .par_chunks_mut(cols)
909                .enumerate()
910                .for_each(|(row, slice)| do_row(row, slice));
911        }
912
913        #[cfg(target_arch = "wasm32")]
914        {
915            for (row, slice) in values.chunks_mut(cols).enumerate() {
916                do_row(row, slice);
917            }
918        }
919    } else {
920        for (row, slice) in values.chunks_mut(cols).enumerate() {
921            do_row(row, slice);
922        }
923    }
924
925    let values = unsafe {
926        Vec::from_raw_parts(
927            buf_guard.as_mut_ptr() as *mut f64,
928            buf_guard.len(),
929            buf_guard.capacity(),
930        )
931    };
932
933    Ok(ErBatchOutput {
934        values,
935        combos,
936        rows,
937        cols,
938    })
939}
940
/// Batch row adapter: forwards to `er_scalar`, swapping the order of `first`
/// and `period`.
///
/// # Safety
/// Declared `unsafe` like the other row kernels; the body only calls the safe
/// `er_scalar`.
#[inline(always)]
unsafe fn er_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_scalar(data, period, first, out)
}
945
/// Scalar batch row kernel driven by a precomputed prefix sum.
///
/// `prefix[k]` is expected to hold the cumulative sum of absolute one-step
/// differences of `data` up to index `k`, so the path length of a window is
/// the difference of two prefix entries. Writes `out[i]` for `i` in
/// `first + period - 1 .. data.len()`; returns without writing when that range
/// is empty.
#[inline(always)]
fn er_row_scalar_with_prefix(
    data: &[f64],
    prefix: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }
    for end in warm..n {
        let begin = end + 1 - period;
        let net = (data[end] - data[begin]).abs();
        let path = prefix[end] - prefix[begin];
        out[end] = if path > 0.0 {
            (net / path).min(1.0)
        } else {
            0.0
        };
    }
}
972
/// Batch row adapter: forwards to `er_avx2`, swapping the order of `first`
/// and `period`.
///
/// # Safety
/// Same requirement as `er_avx2`: the CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx2(data, period, first, out)
}
978
/// Batch row dispatcher for AVX-512: periods up to 32 use the short row
/// kernel, longer periods the long one.
///
/// # Safety
/// The CPU must support AVX-512F, as required by the kernels this reaches.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        er_row_avx512_short(data, first, period, out);
    } else {
        er_row_avx512_long(data, first, period, out);
    }
}
988
/// Batch row adapter: forwards to `er_avx512_short`, swapping the order of
/// `first` and `period`.
///
/// # Safety
/// Same requirement as `er_avx512_short`: the CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx512_short(data, period, first, out)
}
994
/// Batch row adapter: forwards to `er_avx512_long`, swapping the order of
/// `first` and `period`.
///
/// # Safety
/// Same requirement as `er_avx512_long`: the CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx512_long(data, period, first, out)
}
1000
// Python binding `er(data, period, kernel=None)`.
//
// Borrows the NumPy array as a slice, runs `er_with_kernel` with the GIL
// released, and returns the result as a new 1-D array. Any `ErError` is
// raised as `ValueError` carrying the error's display text.
// (Plain `//` comments are used so the Python-visible docstring is unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "er")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn er_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    // `as_slice` fails (and the error is propagated) if the array cannot be
    // viewed as a slice.
    let slice_in = data.as_slice()?;
    // `false` is the helper's second argument for the non-batch entry point;
    // `er_batch_py` passes `true`.
    let kern = validate_kernel(kernel, false)?;

    let params = ErParams {
        period: Some(period),
    };
    let input = ErInput::from_slice(slice_in, params);

    // The computation itself runs without holding the GIL.
    let result_vec = py
        .allow_threads(|| er_with_kernel(&input, kern))
        .map(|result| result.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1027
// Python-facing wrapper around `ErStream`, exposed to Python as `ErStream`.
// (Plain `//` comments are used so the Python-visible docstring is unchanged.)
#[cfg(feature = "python")]
#[pyclass(name = "ErStream")]
pub struct ErStreamPy {
    // The wrapped Rust stream.
    stream: ErStream,
}
1033
// Python methods of `ErStream`.
// (Plain `//` comments are used so the Python-visible docstrings are unchanged.)
#[cfg(feature = "python")]
#[pymethods]
impl ErStreamPy {
    // Constructor `ErStream(period)`; an invalid period is raised as
    // `ValueError` carrying the error's display text.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = ErParams {
            period: Some(period),
        };
        let stream = ErStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(ErStreamPy { stream })
    }

    // Feeds one value to the wrapped stream and returns its result unchanged
    // (`None` while the stream is still warming up).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1050
// Python binding `er_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` (start, end, step) into a grid of parameter sets,
// computes one output row per set directly into a NumPy buffer, and returns a
// dict with:
//   "values"  - 2-D array of shape (rows, cols)
//   "periods" - 1-D u64 array with the period used for each row
//   "rows"    - number of parameter sets
//   "cols"    - length of the input series
//
// Errors from kernel validation, the size computation, or the batch routine are
// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "er_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn er_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = ErBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // Guard the element count: an unchecked `rows * cols` could wrap (release)
    // or panic (debug) before the buffer is even allocated.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("er_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is created uninitialised and immediately handed to the
    // batch routine below as its output buffer.
    // NOTE(review): soundness relies on `er_batch_inner_into` writing every
    // element on success — confirm against its implementation.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Translate the requested (batch) kernel into the per-row kernel
            // passed to the batch routine.
            let simd = match kern {
                Kernel::Auto => {
                    let base = detect_best_kernel();
                    match base {
                        // NOTE(review): a detected AVX-512 is mapped to Scalar
                        // here, unlike the explicit `Avx512Batch` request
                        // below — looks deliberate, confirm.
                        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                        Kernel::Avx512 => Kernel::Scalar,
                        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                        Kernel::Avx2 => Kernel::Avx2,
                        _ => Kernel::Scalar,
                    }
                }
                other => match other {
                    Kernel::ScalarBatch => Kernel::Scalar,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx2Batch => Kernel::Avx2,
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx512Batch => Kernel::Avx512,
                    // Presumably excluded by `validate_kernel(kernel, true)`.
                    _ => unreachable!(),
                },
            };
            er_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1116
// JS binding: computes ER over `data` with the given `period` and returns a
// freshly allocated vector of the same length. Indicator errors are converted
// to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ErInput::from_slice(
        data,
        ErParams {
            period: Some(period),
        },
    );

    // Zero-filled destination; the indicator writes its results into it.
    let mut output = vec![0.0; data.len()];

    match er_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1132
// JS binding: reserves room for `len` f64 values and returns the raw pointer.
// The allocation is deliberately leaked; the caller is expected to release it
// with `er_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the Vec's destructor from running, so the buffer
    // outlives this call.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1141
// JS binding: releases a buffer previously obtained from `er_alloc(len)`.
// A null `ptr` is ignored.
//
// The caller must pass the same `len` that was given to `er_alloc`, and must
// not use or free the pointer again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is expected to come from `er_alloc(len)`, i.e. a `Vec<f64>`
    // allocation with capacity `len`. The Vec is rebuilt with length 0 rather
    // than `len`: `er_alloc` never initialises the elements, and
    // `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. Only the capacity matters for deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1151
// JS binding: computes ER from `len` values at `in_ptr` into `len` slots at
// `out_ptr`.
//
// Rejects null pointers and a `period` that is zero or larger than `len`.
// When both pointers are equal the result is computed into a temporary buffer
// and copied back, so in-place use is supported.
//
// The caller must ensure both pointers are valid for `len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to er_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        // Valid periods are 1..=len.
        if !(1..=len).contains(&period) {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = ErInput::from_slice(
            data,
            ErParams {
                period: Some(period),
            },
        );

        if in_ptr != out_ptr {
            // Distinct buffers: write straight into the destination.
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            er_into_slice(dst, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            // Same buffer: compute into scratch space first, then overwrite.
            let mut scratch = vec![0.0; len];
            er_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            dst.copy_from_slice(&scratch);
        }

        Ok(())
    }
}
1191
/// Configuration object deserialised from the JS side for the `er_batch` call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ErBatchConfig {
    /// Period sweep; copied verbatim into `ErBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1197
/// Result object serialised back to JS by the `er_batch` call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ErBatchJsOutput {
    /// Output values copied from the batch result.
    pub values: Vec<f64>,
    /// Parameter sets copied from the batch result.
    pub combos: Vec<ErParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
1206
// JS binding exported as `er_batch`: deserialises an `ErBatchConfig` from
// `config`, runs the batch computation with automatic kernel selection and
// returns an `ErBatchJsOutput` serialised back to a JS value. All failures are
// reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = er_batch)]
pub fn er_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: ErBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = ErBatchRange {
        period: parsed.period_range,
    };

    let batch = match er_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Move the relevant fields into the serialisable mirror struct.
    let js_output = ErBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1229
// JS binding: batch ER written into caller-provided memory.
//
// Reads `len` values from `in_ptr`, expands the period sweep
// `(period_start, period_end, period_step)` and writes `rows * len` values to
// `out_ptr`. Returns the number of rows (parameter sets).
//
// Errors (as string `JsValue`): null pointers, an element count that does not
// fit in `usize`, or any error from the batch routine.
//
// The caller must ensure `in_ptr` is valid for `len` f64 reads and `out_ptr`
// for `rows * len` f64 writes.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to er_batch_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let sweep = ErBatchRange {
            period: (period_start, period_end, period_step),
        };
        let combos = expand_grid(&sweep);
        let rows = combos.len();
        let cols = len;

        // `usize` is 32 bits on wasm32, so the product is checked instead of
        // being allowed to wrap and build a too-short output slice.
        let total = rows.checked_mul(cols).ok_or_else(|| {
            JsValue::from_str("er_batch_into: rows * cols overflows usize")
        })?;

        if total > 0 {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);

            // Map the detected batch kernel to the per-row kernel passed on.
            let batch_kernel = detect_best_batch_kernel();
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            er_batch_inner_into(data, &sweep, simd, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(rows)
    }
}
1268
1269#[cfg(all(feature = "python", feature = "cuda"))]
1270use crate::cuda::er_wrapper::{CudaEr, DeviceArrayF32Er};
1271#[cfg(all(feature = "python", feature = "cuda"))]
1272use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1273#[cfg(all(feature = "python", feature = "cuda"))]
1274use numpy::PyReadonlyArray1;
1275#[cfg(all(feature = "python", feature = "cuda"))]
1276#[cfg(all(feature = "python", feature = "cuda"))]
1277use pyo3::prelude::*;
1278
// Python binding `er_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise creates a `CudaEr` for `device_id`, runs the batch sweep on
// the f32 input with the GIL released and wraps the returned device array.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "er_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn er_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32ErPy> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = ErBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| {
        CudaEr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.er_batch_dev(host, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(DeviceArrayF32ErPy { inner })
}
1303
// Python binding
// `er_cuda_many_series_one_param_dev(data_tm_f32, cols, rows, period, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise creates a `CudaEr` for `device_id` and forwards the flat f32
// buffer together with `cols`, `rows` and `period` to the time-major
// many-series routine, with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "er_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn er_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32ErPy> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_tm_f32.as_slice()?;

    let inner = py.allow_threads(|| {
        CudaEr::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.er_many_series_one_param_time_major_dev(host, cols, rows, period)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(DeviceArrayF32ErPy { inner })
}
1327
// Python handle to an ER result that lives in CUDA device memory. Declared
// `unsendable`, so pyo3 restricts it to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32ErPy {
    // Wrapped device array (buffer, rows, cols, context, device id).
    pub(crate) inner: DeviceArrayF32Er,
}
1333
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32ErPy {
    // Builds the CUDA Array Interface dict (version 3) describing the buffer as
    // a row-major `rows x cols` little-endian f32 array with explicit byte
    // strides; the data entry is `(device pointer, false)`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let row_stride = self.inner.cols * elem_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_stride, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; 2 is the DLPack device-type code for
    // CUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Exports the buffer through DLPack.
    //
    // If `dl_device` is given it must be a `(device_type, device_id)` tuple that
    // matches this buffer's device, otherwise `ValueError` is raised. `stream`
    // and `copy` are accepted but ignored. The device buffer is moved out of
    // `self` (which is left holding an empty 0 x 0 placeholder) and handed to
    // `export_f32_cuda_dlpack_2d`.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;
        use pyo3::types::PyAny;
        use pyo3::Bound;

        let (dev_ty, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            match dev_obj.extract::<(i32, i32)>(py) {
                Ok(requested) if requested == (dev_ty, alloc_dev) => {}
                Ok(_) => {
                    return Err(PyValueError::new_err(
                        "__dlpack__ dl_device does not match ER buffer device",
                    ));
                }
                Err(_) => {
                    return Err(PyValueError::new_err(
                        "__dlpack__ dl_device must be a (device_type, device_id) tuple",
                    ));
                }
            }
        }

        // Accepted for protocol compatibility only.
        let _ = stream;
        let _ = copy;

        // Allocate the placeholder first so a failure leaves `self` untouched.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let emptied = DeviceArrayF32Er {
            buf: placeholder,
            rows: 0,
            cols: 0,
            ctx: self.inner.ctx.clone(),
            device_id: self.inner.device_id,
        };
        let taken = std::mem::replace(&mut self.inner, emptied);

        let max_version_bound: Option<Bound<'py, PyAny>> =
            max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, taken.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1412
1413#[cfg(test)]
1414mod tests {
1415    use super::*;
1416    use crate::skip_if_unsupported;
1417    use crate::utilities::data_loader::read_candles_from_csv;
1418
1419    fn check_er_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1420        skip_if_unsupported!(kernel, test_name);
1421        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1422        let candles = read_candles_from_csv(file_path)?;
1423
1424        let default_params = ErParams { period: None };
1425        let input = ErInput::from_candles(&candles, "close", default_params);
1426        let output = er_with_kernel(&input, kernel)?;
1427        assert_eq!(output.values.len(), candles.close.len());
1428
1429        Ok(())
1430    }
1431
1432    #[test]
1433    fn test_er_into_matches_api() -> Result<(), Box<dyn Error>> {
1434        let n = 256usize;
1435        let mut data = Vec::with_capacity(n);
1436        for i in 0..n {
1437            if i < 3 {
1438                data.push(f64::NAN);
1439            } else {
1440                let x = i as f64;
1441                data.push((x * 0.01).sin() * (x * 0.02).cos() + 0.001 * x);
1442            }
1443        }
1444
1445        let input = ErInput::from_slice(&data, ErParams::default());
1446
1447        let base = er(&input)?.values;
1448
1449        let mut out = vec![0.0; n];
1450
1451        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1452        {
1453            er_into(&input, &mut out)?;
1454        }
1455
1456        assert_eq!(base.len(), out.len());
1457
1458        fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
1459            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1460        }
1461
1462        for i in 0..n {
1463            assert!(
1464                eq_or_both_nan_eps(base[i], out[i]),
1465                "mismatch at {}: base={:?}, into={:?}",
1466                i,
1467                base[i],
1468                out[i]
1469            );
1470        }
1471        Ok(())
1472    }
1473
1474    fn check_er_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1475        skip_if_unsupported!(kernel, test_name);
1476        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1477        let candles = read_candles_from_csv(file_path)?;
1478
1479        let input = ErInput::with_default_candles(&candles);
1480        match input.data {
1481            ErData::Candles { source, .. } => assert_eq!(source, "close"),
1482            _ => panic!("Expected ErData::Candles"),
1483        }
1484        let output = er_with_kernel(&input, kernel)?;
1485        assert_eq!(output.values.len(), candles.close.len());
1486        Ok(())
1487    }
1488
1489    fn check_er_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1490        skip_if_unsupported!(kernel, test_name);
1491        let input_data = [10.0, 20.0, 30.0];
1492        let params = ErParams { period: Some(0) };
1493        let input = ErInput::from_slice(&input_data, params);
1494        let res = er_with_kernel(&input, kernel);
1495        assert!(
1496            res.is_err(),
1497            "[{}] ER should fail with zero period",
1498            test_name
1499        );
1500        Ok(())
1501    }
1502
1503    fn check_er_period_exceeds_length(
1504        test_name: &str,
1505        kernel: Kernel,
1506    ) -> Result<(), Box<dyn Error>> {
1507        skip_if_unsupported!(kernel, test_name);
1508        let data_small = [10.0, 20.0, 30.0];
1509        let params = ErParams { period: Some(10) };
1510        let input = ErInput::from_slice(&data_small, params);
1511        let res = er_with_kernel(&input, kernel);
1512        assert!(
1513            res.is_err(),
1514            "[{}] ER should fail with period exceeding length",
1515            test_name
1516        );
1517        Ok(())
1518    }
1519
1520    fn check_er_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1521        skip_if_unsupported!(kernel, test_name);
1522        let single_point = [42.0];
1523        let params = ErParams { period: Some(5) };
1524        let input = ErInput::from_slice(&single_point, params);
1525        let res = er_with_kernel(&input, kernel);
1526        assert!(
1527            res.is_err(),
1528            "[{}] ER should fail with insufficient data",
1529            test_name
1530        );
1531        Ok(())
1532    }
1533
1534    fn check_er_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1535        skip_if_unsupported!(kernel, test_name);
1536        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1537        let candles = read_candles_from_csv(file_path)?;
1538
1539        let first_params = ErParams { period: Some(5) };
1540        let first_input = ErInput::from_candles(&candles, "close", first_params);
1541        let first_result = er_with_kernel(&first_input, kernel)?;
1542
1543        let second_params = ErParams { period: Some(5) };
1544        let second_input = ErInput::from_slice(&first_result.values, second_params);
1545        let second_result = er_with_kernel(&second_input, kernel)?;
1546
1547        assert_eq!(second_result.values.len(), first_result.values.len());
1548        Ok(())
1549    }
1550
1551    fn check_er_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1552        skip_if_unsupported!(kernel, test_name);
1553        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1554        let candles = read_candles_from_csv(file_path)?;
1555
1556        let input = ErInput::from_candles(&candles, "close", ErParams { period: Some(5) });
1557        let res = er_with_kernel(&input, kernel)?;
1558        assert_eq!(res.values.len(), candles.close.len());
1559        if res.values.len() > 240 {
1560            for (i, &val) in res.values[240..].iter().enumerate() {
1561                assert!(
1562                    !val.is_nan(),
1563                    "[{}] Found unexpected NaN at out-index {}",
1564                    test_name,
1565                    240 + i
1566                );
1567            }
1568        }
1569        Ok(())
1570    }
1571
1572    fn check_er_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1573        skip_if_unsupported!(kernel, test_name);
1574
1575        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1576        let candles = read_candles_from_csv(file_path)?;
1577
1578        let period = 5;
1579
1580        let input = ErInput::from_candles(
1581            &candles,
1582            "close",
1583            ErParams {
1584                period: Some(period),
1585            },
1586        );
1587        let batch_output = er_with_kernel(&input, kernel)?.values;
1588
1589        let mut stream = ErStream::try_new(ErParams {
1590            period: Some(period),
1591        })?;
1592
1593        let mut stream_values = Vec::with_capacity(candles.close.len());
1594        for &price in &candles.close {
1595            match stream.update(price) {
1596                Some(er_val) => stream_values.push(er_val),
1597                None => stream_values.push(f64::NAN),
1598            }
1599        }
1600
1601        assert_eq!(batch_output.len(), stream_values.len());
1602        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1603            if b.is_nan() && s.is_nan() {
1604                continue;
1605            }
1606            let diff = (b - s).abs();
1607            assert!(
1608                diff < 1e-9,
1609                "[{}] ER streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1610                test_name,
1611                i,
1612                b,
1613                s,
1614                diff
1615            );
1616        }
1617        Ok(())
1618    }
1619
1620    #[cfg(debug_assertions)]
1621    fn check_er_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1622        skip_if_unsupported!(kernel, test_name);
1623
1624        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1625        let candles = read_candles_from_csv(file_path)?;
1626
1627        let test_params = vec![
1628            ErParams::default(),
1629            ErParams { period: Some(1) },
1630            ErParams { period: Some(2) },
1631            ErParams { period: Some(3) },
1632            ErParams { period: Some(4) },
1633            ErParams { period: Some(5) },
1634            ErParams { period: Some(10) },
1635            ErParams { period: Some(14) },
1636            ErParams { period: Some(20) },
1637            ErParams { period: Some(30) },
1638            ErParams { period: Some(50) },
1639            ErParams { period: Some(100) },
1640            ErParams { period: Some(200) },
1641            ErParams { period: Some(500) },
1642            ErParams { period: Some(1000) },
1643            ErParams { period: Some(2000) },
1644        ];
1645
1646        for (param_idx, params) in test_params.iter().enumerate() {
1647            let input = ErInput::from_candles(&candles, "close", params.clone());
1648            let output = er_with_kernel(&input, kernel)?;
1649
1650            for (i, &val) in output.values.iter().enumerate() {
1651                if val.is_nan() {
1652                    continue;
1653                }
1654
1655                let bits = val.to_bits();
1656
1657                if bits == 0x11111111_11111111 {
1658                    panic!(
1659                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1660						 with params: period={} (param set {})",
1661                        test_name,
1662                        val,
1663                        bits,
1664                        i,
1665                        params.period.unwrap_or(5),
1666                        param_idx
1667                    );
1668                }
1669
1670                if bits == 0x22222222_22222222 {
1671                    panic!(
1672                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1673						 with params: period={} (param set {})",
1674                        test_name,
1675                        val,
1676                        bits,
1677                        i,
1678                        params.period.unwrap_or(5),
1679                        param_idx
1680                    );
1681                }
1682
1683                if bits == 0x33333333_33333333 {
1684                    panic!(
1685                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1686						 with params: period={} (param set {})",
1687                        test_name,
1688                        val,
1689                        bits,
1690                        i,
1691                        params.period.unwrap_or(5),
1692                        param_idx
1693                    );
1694                }
1695            }
1696        }
1697
1698        Ok(())
1699    }
1700
1701    #[cfg(not(debug_assertions))]
1702    fn check_er_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1703        Ok(())
1704    }
1705
1706    macro_rules! generate_all_er_tests {
1707        ($($test_fn:ident),*) => {
1708            paste::paste! {
1709                $(
1710                    #[test]
1711                    fn [<$test_fn _scalar_f64>]() {
1712                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1713                    }
1714                )*
1715                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1716                $(
1717                    #[test]
1718                    fn [<$test_fn _avx2_f64>]() {
1719                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1720                    }
1721                    #[test]
1722                    fn [<$test_fn _avx512_f64>]() {
1723                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1724                    }
1725                )*
1726            }
1727        }
1728    }
1729
1730    generate_all_er_tests!(
1731        check_er_partial_params,
1732        check_er_default_candles,
1733        check_er_zero_period,
1734        check_er_period_exceeds_length,
1735        check_er_very_small_dataset,
1736        check_er_reinput,
1737        check_er_nan_handling,
1738        check_er_streaming,
1739        check_er_no_poison
1740    );
1741
1742    #[cfg(feature = "proptest")]
1743    generate_all_er_tests!(check_er_property);
1744
1745    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1746        skip_if_unsupported!(kernel, test);
1747
1748        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1749        let c = read_candles_from_csv(file)?;
1750
1751        let output = ErBatchBuilder::new()
1752            .kernel(kernel)
1753            .apply_candles(&c, "close")?;
1754
1755        let def = ErParams::default();
1756        let row = output.values_for(&def).expect("default row missing");
1757        assert_eq!(row.len(), c.close.len());
1758
1759        Ok(())
1760    }
1761
1762    macro_rules! gen_batch_tests {
1763        ($fn_name:ident) => {
1764            paste::paste! {
1765                #[test] fn [<$fn_name _scalar>]()      {
1766                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1767                }
1768                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1769                #[test] fn [<$fn_name _avx2>]()        {
1770                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1771                }
1772                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1773                #[test] fn [<$fn_name _avx512>]()      {
1774                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1775                }
1776                #[test] fn [<$fn_name _auto_detect>]() {
1777                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1778                }
1779            }
1780        };
1781    }
1782    #[cfg(debug_assertions)]
1783    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1784        skip_if_unsupported!(kernel, test);
1785
1786        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1787        let c = read_candles_from_csv(file)?;
1788
1789        let test_configs = vec![
1790            (1, 5, 1),
1791            (2, 10, 2),
1792            (5, 30, 5),
1793            (10, 100, 10),
1794            (50, 500, 50),
1795            (100, 1000, 100),
1796            (14, 14, 0),
1797            (3, 15, 1),
1798            (20, 200, 20),
1799            (25, 50, 5),
1800        ];
1801
1802        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
1803            let output = ErBatchBuilder::new()
1804                .kernel(kernel)
1805                .period_range(period_start, period_end, period_step)
1806                .apply_candles(&c, "close")?;
1807
1808            for (idx, &val) in output.values.iter().enumerate() {
1809                if val.is_nan() {
1810                    continue;
1811                }
1812
1813                let bits = val.to_bits();
1814                let row = idx / output.cols;
1815                let col = idx % output.cols;
1816                let combo = &output.combos[row];
1817
1818                if bits == 0x11111111_11111111 {
1819                    panic!(
1820                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1821						at row {} col {} (flat index {}) with params: period={}",
1822                        test,
1823                        cfg_idx,
1824                        val,
1825                        bits,
1826                        row,
1827                        col,
1828                        idx,
1829                        combo.period.unwrap_or(5)
1830                    );
1831                }
1832
1833                if bits == 0x22222222_22222222 {
1834                    panic!(
1835                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1836						at row {} col {} (flat index {}) with params: period={}",
1837                        test,
1838                        cfg_idx,
1839                        val,
1840                        bits,
1841                        row,
1842                        col,
1843                        idx,
1844                        combo.period.unwrap_or(5)
1845                    );
1846                }
1847
1848                if bits == 0x33333333_33333333 {
1849                    panic!(
1850                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1851						at row {} col {} (flat index {}) with params: period={}",
1852                        test,
1853                        cfg_idx,
1854                        val,
1855                        bits,
1856                        row,
1857                        col,
1858                        idx,
1859                        combo.period.unwrap_or(5)
1860                    );
1861                }
1862            }
1863        }
1864
1865        Ok(())
1866    }
1867
1868    #[cfg(not(debug_assertions))]
1869    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1870        Ok(())
1871    }
1872
1873    gen_batch_tests!(check_batch_default_row);
1874    gen_batch_tests!(check_batch_no_poison);
1875
/// Property-based test of the ER indicator for a given `kernel`.
///
/// Synthetic price paths are generated as a compounded per-step trend plus bounded
/// multiplicative noise, floored at `1.0`. For every generated `(data, period)` pair the
/// following properties are asserted:
///
/// * the output has the same length as the input (for both `kernel` and `Kernel::Scalar`);
/// * the first `period - 1` outputs are NaN;
/// * every non-NaN output from index `period - 1` onwards lies in `[0, 1]` (±1e-10), which
///   also covers non-negativity;
/// * `kernel` agrees with `Kernel::Scalar`: bit-identical when either value is non-finite,
///   otherwise within `1e-9` absolute difference or 4 ULP;
/// * monotonic, non-constant windows with a net move above `1e-6` yield ER >= 0.90;
/// * constant windows yield NaN or a value within `1e-10` of zero;
/// * windows whose recomputed net-change / total-movement ratio is below 0.3 yield
///   ER <= 0.35.
///
/// # Parameters
/// * `test_name` - label forwarded to `skip_if_unsupported!`.
/// * `kernel` - kernel under test; compared against `Kernel::Scalar`.
///
/// # Errors
/// The signature allows an error to be returned, but the visible body only ever returns
/// `Ok(())` (NOTE(review): `skip_if_unsupported!` may return early — confirm in the macro).
///
/// # Panics
/// Panics if `er_with_kernel` returns an error for a generated input, or if the proptest
/// runner reports a failing case (the runner result is unwrapped).
#[cfg(feature = "proptest")]
fn check_er_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    let strat = (2usize..=50)
        .prop_flat_map(|period| {
            // Every series is at least two periods long.
            let min_len = period * 2;
            (
                (100.0f64..5000.0f64, 0.01f64..0.1f64),
                -0.02f64..0.02f64,
                Just(period),
                min_len..400,
            )
        })
        .prop_flat_map(|((base_price, volatility), trend, period, len)| {
            (
                Just(base_price),
                Just(volatility),
                Just(trend),
                Just(period),
                prop::collection::vec(-1.0f64..1.0f64, len),
            )
        })
        .prop_map(|(base_price, volatility, trend, period, changes)| {
            // Walk the price forward: apply the trend, then the scaled noise, then the floor.
            let mut price = base_price;
            let data: Vec<f64> = changes
                .iter()
                .map(|&noise| {
                    price *= 1.0 + trend;
                    price *= 1.0 + (noise * volatility);
                    price = price.max(1.0);
                    price
                })
                .collect();

            (data, period)
        });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, period)| {
            let params = ErParams {
                period: Some(period),
            };
            let input = ErInput::from_slice(&data, params);

            let ErOutput { values: out } = er_with_kernel(&input, kernel).unwrap();
            let ErOutput { values: ref_out } = er_with_kernel(&input, Kernel::Scalar).unwrap();

            prop_assert_eq!(out.len(), data.len());
            // Checked explicitly so a short reference output is reported as a property
            // failure rather than surfacing as an index panic further down.
            prop_assert_eq!(ref_out.len(), data.len());

            let warmup = period - 1;

            // Warmup prefix must be NaN.
            for (i, &val) in out.iter().enumerate().take(warmup) {
                prop_assert!(
                    val.is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    val
                );
            }

            // Range check; its lower bound doubles as the non-negativity property.
            for (i, &val) in out.iter().enumerate().skip(warmup) {
                if !val.is_nan() {
                    prop_assert!(
                        val >= -1e-10 && val <= 1.0 + 1e-10,
                        "ER value {} at index {} outside valid range [0, 1]",
                        val,
                        i
                    );
                }
            }

            // Kernel under test versus the scalar reference.
            for (i, (&y, &r)) in out.iter().zip(ref_out.iter()).enumerate() {
                if !y.is_finite() || !r.is_finite() {
                    prop_assert_eq!(
                        y.to_bits(),
                        r.to_bits(),
                        "NaN/Inf mismatch at index {}: {} vs {}",
                        i,
                        y,
                        r
                    );
                } else {
                    let diff = (y - r).abs();
                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        diff <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
                        i,
                        y,
                        r,
                        diff,
                        ulp_diff
                    );
                }
            }

            // Yields `(i, window)` where `window == data[i + 1 - period..=i]` for every
            // `i` in `period..data.len()`. `skip(1)` drops the window ending at
            // `period - 1`, so the examined indices start at `period`.
            let rolling_windows = || {
                data.windows(period)
                    .enumerate()
                    .skip(1)
                    .map(move |(offset, window)| (offset + warmup, window))
            };

            // Monotonic, non-constant windows should be highly efficient.
            if data.len() >= period + 10 {
                for (i, window) in rolling_windows() {
                    let is_monotonic_up = window.windows(2).all(|w| w[1] >= w[0] - 1e-10);
                    let is_monotonic_down = window.windows(2).all(|w| w[1] <= w[0] + 1e-10);
                    let is_constant = window.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);

                    if !is_constant && (is_monotonic_up || is_monotonic_down) {
                        let er_val = out[i];
                        let net_change = (window[window.len() - 1] - window[0]).abs();
                        if !er_val.is_nan() && net_change > 1e-6 {
                            prop_assert!(
                                er_val >= 0.90,
                                "Expected high ER (>0.90) for monotonic move at index {}, got {}",
                                i,
                                er_val
                            );
                        }
                    }
                }
            }

            // Constant windows carry no movement: NaN or (near) zero is accepted.
            for (i, window) in rolling_windows() {
                let is_constant = window.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);

                if is_constant {
                    let er_val = out[i];

                    prop_assert!(
                        er_val.is_nan() || er_val.abs() < 1e-10,
                        "Constant prices should yield NaN or 0, got {} at index {}",
                        er_val,
                        i
                    );
                }
            }

            // Windows with little net progress relative to total movement should score low.
            if period >= 4 && data.len() >= period * 3 {
                for (i, window) in rolling_windows() {
                    let net_change = (window[window.len() - 1] - window[0]).abs();
                    // Explicit left-to-right fold from 0.0 keeps the summation order fixed.
                    let total_movement = window
                        .windows(2)
                        .fold(0.0_f64, |acc, w| acc + (w[1] - w[0]).abs());

                    if total_movement > 0.0 && net_change / total_movement < 0.3 {
                        let er_val = out[i];
                        if !er_val.is_nan() {
                            prop_assert!(
                                er_val <= 0.35,
                                "Expected low ER (<0.35) for choppy market at index {}, got {}",
                                i,
                                er_val
                            );
                        }
                    }
                }
            }

            Ok(())
        })
        .unwrap();

    Ok(())
}
2078}