1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyAny, PyDict, PyList};
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14use crate::utilities::data_loader::{source_type, Candles};
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{
17 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
18 make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22use aligned_vec::{AVec, CACHELINE_ALIGN};
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31
32impl<'a> AsRef<[f64]> for ErInput<'a> {
33 #[inline(always)]
34 fn as_ref(&self) -> &[f64] {
35 match &self.data {
36 ErData::Slice(slice) => slice,
37 ErData::Candles { candles, source } => source_type(candles, source),
38 }
39 }
40}
41
/// Where the ER input series comes from.
#[derive(Debug, Clone)]
pub enum ErData<'a> {
    /// A column of a `Candles` set, selected by name through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values.
    Slice(&'a [f64]),
}
50
/// Result of a single-series ER computation.
#[derive(Debug, Clone)]
pub struct ErOutput {
    /// One value per input sample (same length as the input series).
    pub values: Vec<f64>,
}
55
/// Parameters for the Efficiency Ratio indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ErParams {
    /// Window length in samples; consumers fall back to 5 when this is `None`.
    pub period: Option<usize>,
}

impl Default for ErParams {
    /// A five-sample window.
    fn default() -> Self {
        let period = Some(5);
        ErParams { period }
    }
}
70
/// Input to the ER indicator: a data source together with its parameters.
#[derive(Debug, Clone)]
pub struct ErInput<'a> {
    /// Series to evaluate (candle column or raw slice).
    pub data: ErData<'a>,
    /// Indicator parameters; see `ErInput::get_period` for the period fallback.
    pub params: ErParams,
}
76
77impl<'a> ErInput<'a> {
78 #[inline]
79 pub fn from_candles(c: &'a Candles, s: &'a str, p: ErParams) -> Self {
80 Self {
81 data: ErData::Candles {
82 candles: c,
83 source: s,
84 },
85 params: p,
86 }
87 }
88 #[inline]
89 pub fn from_slice(sl: &'a [f64], p: ErParams) -> Self {
90 Self {
91 data: ErData::Slice(sl),
92 params: p,
93 }
94 }
95 #[inline]
96 pub fn with_default_candles(c: &'a Candles) -> Self {
97 Self::from_candles(c, "close", ErParams::default())
98 }
99 #[inline]
100 pub fn get_period(&self) -> usize {
101 self.params.period.unwrap_or(5)
102 }
103}
104
105#[derive(Copy, Clone, Debug)]
106pub struct ErBuilder {
107 period: Option<usize>,
108 kernel: Kernel,
109}
110
111impl Default for ErBuilder {
112 fn default() -> Self {
113 Self {
114 period: None,
115 kernel: Kernel::Auto,
116 }
117 }
118}
119
120impl ErBuilder {
121 #[inline(always)]
122 pub fn new() -> Self {
123 Self::default()
124 }
125 #[inline(always)]
126 pub fn period(mut self, n: usize) -> Self {
127 self.period = Some(n);
128 self
129 }
130 #[inline(always)]
131 pub fn kernel(mut self, k: Kernel) -> Self {
132 self.kernel = k;
133 self
134 }
135 #[inline(always)]
136 pub fn apply(self, c: &Candles) -> Result<ErOutput, ErError> {
137 let p = ErParams {
138 period: self.period,
139 };
140 let i = ErInput::from_candles(c, "close", p);
141 er_with_kernel(&i, self.kernel)
142 }
143 #[inline(always)]
144 pub fn apply_slice(self, d: &[f64]) -> Result<ErOutput, ErError> {
145 let p = ErParams {
146 period: self.period,
147 };
148 let i = ErInput::from_slice(d, p);
149 er_with_kernel(&i, self.kernel)
150 }
151 #[inline(always)]
152 pub fn into_stream(self) -> Result<ErStream, ErError> {
153 let p = ErParams {
154 period: self.period,
155 };
156 ErStream::try_new(p)
157 }
158}
159
/// Errors reported by the ER functions in this module.
#[derive(Debug, Error)]
pub enum ErError {
    /// The input series has no elements.
    #[error("er: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN.
    #[error("er: All input data values are NaN.")]
    AllValuesNaN,
    /// The period is zero or longer than the data (`ErStream` reports `data_len: 0`).
    #[error("er: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values remain after the leading NaNs.
    #[error("er: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("er: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep produced no periods, or `rows * cols` overflowed
    /// (the overflow case is reported with placeholder strings).
    #[error("er: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("er: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
181
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<ErError> for JsValue {
    /// Surfaces the error's `Display` text as a JS string value.
    fn from(err: ErError) -> Self {
        let message = err.to_string();
        JsValue::from_str(&message)
    }
}
188
189#[inline]
190pub fn er(input: &ErInput) -> Result<ErOutput, ErError> {
191 er_with_kernel(input, Kernel::Auto)
192}
193
194#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
195#[inline]
196pub fn er_into(input: &ErInput, out: &mut [f64]) -> Result<(), ErError> {
197 er_into_slice(out, input, Kernel::Auto)
198}
199
200pub fn er_with_kernel(input: &ErInput, kernel: Kernel) -> Result<ErOutput, ErError> {
201 let data: &[f64] = input.as_ref();
202 let len = data.len();
203 if len == 0 {
204 return Err(ErError::EmptyInputData);
205 }
206 let first = data
207 .iter()
208 .position(|x| !x.is_nan())
209 .ok_or(ErError::AllValuesNaN)?;
210 let period = input.get_period();
211 if period == 0 || period > len {
212 return Err(ErError::InvalidPeriod {
213 period,
214 data_len: len,
215 });
216 }
217 if (len - first) < period {
218 return Err(ErError::NotEnoughValidData {
219 needed: period,
220 valid: len - first,
221 });
222 }
223
224 let chosen = match kernel {
225 Kernel::Auto => Kernel::Scalar,
226 other => other,
227 };
228
229 let warm = first + period - 1;
230 let mut out = alloc_with_nan_prefix(len, warm);
231 unsafe {
232 match chosen {
233 Kernel::Scalar | Kernel::ScalarBatch => er_scalar(data, period, first, &mut out),
234 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235 Kernel::Avx2 | Kernel::Avx2Batch => er_avx2(data, period, first, &mut out),
236
237 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238 Kernel::Avx512 | Kernel::Avx512Batch => er_scalar(data, period, first, &mut out),
239 _ => unreachable!(),
240 }
241 }
242 Ok(ErOutput { values: out })
243}
244
245#[inline]
246pub fn er_into_slice(dst: &mut [f64], input: &ErInput, kern: Kernel) -> Result<(), ErError> {
247 let data: &[f64] = input.as_ref();
248 let len = data.len();
249 if len == 0 {
250 return Err(ErError::EmptyInputData);
251 }
252 let first = data
253 .iter()
254 .position(|x| !x.is_nan())
255 .ok_or(ErError::AllValuesNaN)?;
256 let period = input.get_period();
257 if period == 0 || period > len {
258 return Err(ErError::InvalidPeriod {
259 period,
260 data_len: len,
261 });
262 }
263 if (len - first) < period {
264 return Err(ErError::NotEnoughValidData {
265 needed: period,
266 valid: len - first,
267 });
268 }
269 if dst.len() != len {
270 return Err(ErError::OutputLengthMismatch {
271 expected: len,
272 got: dst.len(),
273 });
274 }
275
276 let chosen = match kern {
277 Kernel::Auto => Kernel::Scalar,
278 other => other,
279 };
280
281 unsafe {
282 match chosen {
283 Kernel::Scalar | Kernel::ScalarBatch => er_scalar(data, period, first, dst),
284 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
285 Kernel::Avx2 | Kernel::Avx2Batch => er_avx2(data, period, first, dst),
286
287 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
288 Kernel::Avx512 | Kernel::Avx512Batch => er_scalar(data, period, first, dst),
289 _ => unreachable!(),
290 }
291 }
292
293 let warm_end = first + period - 1;
294 for v in &mut dst[..warm_end] {
295 *v = f64::NAN;
296 }
297
298 Ok(())
299}
300
/// Scalar Efficiency Ratio kernel.
///
/// For every `i` in `first + period - 1 .. data.len()` this writes
/// `|data[i] - data[s]| / sum(|data[k+1] - data[k]|)` with `s = i + 1 - period`,
/// where the sum runs over `k` in `s..i`. The ratio is capped at 1.0, and the
/// result is 0.0 when the sum is not positive. Entries before the warm-up index
/// are left untouched.
///
/// The denominator is maintained as a rolling sum (one add and one subtract per step).
///
/// Nothing is written when `period == 0` (no window is defined) or when the
/// warm-up index does not fit inside `data`.
///
/// # Panics
/// Panics if `out` is shorter than `data` and at least one value would be written.
#[inline]
pub fn er_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    // `period - 1` below would underflow for a zero period.
    if period == 0 {
        return;
    }
    let warm = match first.checked_add(period - 1) {
        Some(w) if w < n => w,
        _ => return,
    };
    // A single up-front length check lets the loop below skip per-index checks on `out`.
    let out = &mut out[..n];

    // Path length of the first window: sum of |step| over data[first..=warm].
    let mut roll = 0.0f64;
    for j in first..warm {
        roll += (data[j + 1] - data[j]).abs();
    }

    let mut start = first;
    for i in warm..n {
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };

        if i + 1 == n {
            break;
        }
        // Slide the window: add the incoming step, drop the outgoing one.
        let add = (data[i + 1] - data[i]).abs();
        let sub = (data[start + 1] - data[start]).abs();
        roll = roll + add - sub;
        start += 1;
    }
}
336
/// AVX-512 Efficiency Ratio entry point with the same contract as `er_scalar`.
///
/// This is a safe function, so it must not run AVX-512 instructions on CPUs that
/// lack them. The CPU is checked at runtime, and the scalar kernel is used when
/// AVX-512F is unavailable.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn er_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    if !is_x86_feature_detected!("avx512f") {
        er_scalar(data, period, first, out);
        return;
    }
    // SAFETY: AVX-512F support was verified at runtime just above.
    unsafe {
        if period <= 32 {
            er_avx512_short(data, period, first, out);
        } else {
            er_avx512_long(data, period, first, out);
        }
    }
}
348
/// AVX2 variant of `er_scalar`.
///
/// Only the warm-up path-length sum is vectorized; the rolling-window loop that
/// follows is the same as the scalar kernel's. Because the initial sum is
/// accumulated in a different order, results can differ from `er_scalar` in the
/// last bits.
///
/// Writes `out[i]` for `i` in `first + period - 1 .. data.len()` and leaves
/// earlier entries untouched.
///
/// # Safety
/// The CPU must support AVX2. `period` must be at least 1, otherwise
/// `first + period - 1` underflows. `out` must be at least `data.len()` long,
/// otherwise the scalar loop panics on an out-of-bounds index.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn er_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    // Horizontal sum of the four f64 lanes of `x`.
    #[inline(always)]
    unsafe fn hsum256(x: __m256d) -> f64 {
        let hi = _mm256_extractf128_pd(x, 1);
        let lo = _mm256_castpd256_pd128(x);
        let s = _mm_add_pd(hi, lo);
        let sh = _mm_unpackhi_pd(s, s);
        _mm_cvtsd_f64(_mm_add_sd(s, sh))
    }
    // Lane-wise absolute value: clear the sign bit (and-not with -0.0).
    #[inline(always)]
    unsafe fn vabs(a: __m256d) -> __m256d {
        let sign = _mm256_set1_pd(-0.0);
        _mm256_andnot_pd(sign, a)
    }

    let n = data.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let ptr = data.as_ptr();
    let mut acc = unsafe { _mm256_setzero_pd() };
    let mut j = first;
    // Four one-step differences per iteration (data[j+1..=j+4] - data[j..=j+3]).
    // `j + 4 <= warm < n` keeps the shifted load in bounds.
    while j + 4 <= warm {
        let a = unsafe { _mm256_loadu_pd(ptr.add(j)) };
        let b = unsafe { _mm256_loadu_pd(ptr.add(j + 1)) };
        acc = unsafe { _mm256_add_pd(acc, vabs(_mm256_sub_pd(b, a))) };
        j += 4;
    }
    let mut roll = unsafe { hsum256(acc) };
    // Scalar tail of the warm-up sum.
    while j < warm {
        roll += (data[j + 1] - data[j]).abs();
        j += 1;
    }

    // Rolling window: same arithmetic as `er_scalar`.
    let mut start = first;
    let mut i = warm;
    while i < n {
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };
        if i + 1 == n {
            break;
        }
        let add = (data[i + 1] - data[i]).abs();
        let sub = (data[start + 1] - data[start]).abs();
        roll = roll + add - sub;
        start += 1;
        i += 1;
    }
}
408
/// AVX-512 variant of `er_scalar`.
///
/// Only the warm-up path-length sum is vectorized (eight differences per step);
/// the rolling-window loop matches the scalar kernel's. Writes `out[i]` for `i`
/// in `first + period - 1 .. data.len()` and leaves earlier entries untouched.
///
/// # Safety
/// The CPU must support AVX-512F. `period` must be at least 1, and `out` must be
/// at least `data.len()` long (otherwise the scalar loop panics).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn er_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    // Horizontal sum of all eight f64 lanes of `x`, folding 512 -> 256 -> 128 -> 64 bits.
    // Only AVX-512F/AVX intrinsics are used, so nothing here depends on AVX512DQ/VL.
    // Note: `_mm512_shuffle_f64x2(x, x, 0b11_10_01_00)` is the identity permutation,
    // so it cannot be used for this reduction (it would count every lane twice).
    #[inline(always)]
    unsafe fn hsum512(x: __m512d) -> f64 {
        let lo4 = _mm512_castpd512_pd256(x);
        let hi4 = _mm512_extractf64x4_pd(x, 1);
        let s4 = _mm256_add_pd(lo4, hi4);
        let lo2 = _mm256_castpd256_pd128(s4);
        let hi2 = _mm256_extractf128_pd(s4, 1);
        let s2 = _mm_add_pd(lo2, hi2);
        let sh = _mm_unpackhi_pd(s2, s2);
        _mm_cvtsd_f64(_mm_add_sd(s2, sh))
    }
    // Lane-wise absolute value: clear the sign bit (and-not with -0.0).
    #[inline(always)]
    unsafe fn vabs(a: __m512d) -> __m512d {
        let sign = _mm512_set1_pd(-0.0);
        _mm512_andnot_pd(sign, a)
    }

    let n = data.len();
    let warm = first + period - 1;
    if warm >= n {
        return;
    }

    let ptr = data.as_ptr();
    let mut acc = _mm512_setzero_pd();
    let mut j = first;
    // Eight one-step differences per iteration; `j + 8 <= warm < n` keeps the
    // shifted load (indices j+1..=j+8) in bounds.
    while j + 8 <= warm {
        let a = _mm512_loadu_pd(ptr.add(j));
        let b = _mm512_loadu_pd(ptr.add(j + 1));
        acc = _mm512_add_pd(acc, vabs(_mm512_sub_pd(b, a)));
        j += 8;
    }
    let mut roll = hsum512(acc);
    // Scalar tail of the warm-up sum.
    while j < warm {
        roll += (data[j + 1] - data[j]).abs();
        j += 1;
    }

    // Rolling window: same arithmetic as `er_scalar`.
    let mut start = first;
    let mut i = warm;
    while i < n {
        let delta = (data[i] - data[start]).abs();
        out[i] = if roll > 0.0 {
            (delta / roll).min(1.0)
        } else {
            0.0
        };
        if i + 1 == n {
            break;
        }
        let add = (data[i + 1] - data[i]).abs();
        let sub = (data[start + 1] - data[start]).abs();
        roll = roll + add - sub;
        start += 1;
        i += 1;
    }
}
470
/// AVX-512 ER kernel for long periods.
///
/// Currently delegates straight to `er_avx512_short`; there is no separate
/// long-period implementation.
///
/// # Safety
/// Same requirements as `er_avx512_short` (the CPU must support AVX-512F).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn er_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    er_avx512_short(data, period, first, out)
}
477
/// Incremental Efficiency Ratio calculator that consumes one value at a time.
#[derive(Debug, Clone)]
pub struct ErStream {
    /// Window length in samples.
    period: usize,
    /// Ring buffer holding the last `period` inputs.
    buffer: Vec<f64>,
    /// Next slot to write; once the window is full this is also the oldest sample's slot.
    head: usize,
    /// Set once `period` samples have been received.
    filled: bool,
    /// Number of samples stored so far during warm-up (not advanced afterwards).
    len: usize,
    /// Running sum of absolute one-step changes inside the current window.
    denom: f64,
}
487
impl ErStream {
    /// Creates a streaming ER calculator.
    ///
    /// `params.period` falls back to 5 when `None`.
    ///
    /// # Errors
    /// Returns `ErError::InvalidPeriod` (with `data_len: 0`) when the period is 0.
    pub fn try_new(params: ErParams) -> Result<Self, ErError> {
        let period = params.period.unwrap_or(5);
        if period == 0 {
            return Err(ErError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        Ok(Self {
            period,
            buffer: vec![f64::NAN; period],
            head: 0,
            filled: false,
            len: 0,
            denom: 0.0,
        })
    }

    /// Feeds one value and returns the ER of the most recent `period` samples.
    ///
    /// Returns `None` until `period` values have been seen. The exception is
    /// `period == 1`, which always returns `Some(0.0)`. Once the window is full,
    /// the result is `|newest - oldest| / denom`, capped at 1.0, or `Some(0.0)`
    /// when `denom` is not positive.
    ///
    /// `denom` is updated incrementally, so after a NaN input it stays NaN and
    /// every later call returns `Some(0.0)`.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        // A one-sample window has no path length; ER is reported as 0.
        if self.period == 1 {
            self.buffer[0] = value;
            self.head = 0;
            self.filled = true;
            self.len = 1;
            self.denom = 0.0;
            return Some(0.0);
        }

        if !self.filled {
            if self.len == 0 {
                // First sample: nothing to difference against yet.
                self.buffer[self.head] = value;
                self.head = (self.head + 1) % self.period;
                self.len = 1;
                return None;
            } else {
                // Slot of the most recently stored sample.
                let prev_idx = if self.head == 0 {
                    self.period - 1
                } else {
                    self.head - 1
                };
                // Grow the path length while the window fills.
                self.denom += (value - self.buffer[prev_idx]).abs();

                self.buffer[self.head] = value;
                self.head = (self.head + 1) % self.period;
                self.len += 1;

                if self.len < self.period {
                    return None;
                }

                self.filled = true;

                // The window just became full: `head` has wrapped to the oldest
                // sample, and the slot before it holds the newest.
                let start = self.head;
                let end = if start == 0 {
                    self.period - 1
                } else {
                    start - 1
                };
                debug_assert!(self.len == self.period);

                let delta = (self.buffer[end] - self.buffer[start]).abs();
                if self.denom > 0.0 {
                    return Some(if delta >= self.denom {
                        1.0
                    } else {
                        delta / self.denom
                    });
                } else {
                    return Some(0.0);
                }
            }
        }

        // Steady state. `start` is the oldest slot (about to be overwritten),
        // `second` is the next-oldest (the new window's first sample), and
        // `end_prev` is the newest stored sample.
        let start = self.head;
        let second = if start + 1 == self.period {
            0
        } else {
            start + 1
        };
        let end_prev = if start == 0 {
            self.period - 1
        } else {
            start - 1
        };

        // Slide the path length: drop the outgoing step, add the incoming one.
        let sub = (self.buffer[second] - self.buffer[start]).abs();
        let add = (value - self.buffer[end_prev]).abs();
        let new_denom = self.denom + add - sub;

        let delta = (value - self.buffer[second]).abs();

        self.denom = new_denom;
        self.buffer[start] = value;
        self.head = second;

        if self.denom > 0.0 {
            Some(if delta >= self.denom {
                1.0
            } else {
                delta / self.denom
            })
        } else {
            Some(0.0)
        }
    }
}
596
/// Period sweep for batch ER evaluation, given as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct ErBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for ErBatchRange {
    /// Sweeps periods 5 through 254 with a step of 1.
    fn default() -> Self {
        let (start, end, step) = (5, 254, 1);
        ErBatchRange {
            period: (start, end, step),
        }
    }
}
609
/// Fluent builder for batch (period-sweep) ER computations.
#[derive(Clone, Debug, Default)]
pub struct ErBatchBuilder {
    /// Periods to evaluate.
    range: ErBatchRange,
    /// Kernel passed to `er_batch_with_kernel`.
    kernel: Kernel,
}
615
616impl ErBatchBuilder {
617 pub fn new() -> Self {
618 Self::default()
619 }
620 pub fn kernel(mut self, k: Kernel) -> Self {
621 self.kernel = k;
622 self
623 }
624 #[inline]
625 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
626 self.range.period = (start, end, step);
627 self
628 }
629 #[inline]
630 pub fn period_static(mut self, p: usize) -> Self {
631 self.range.period = (p, p, 0);
632 self
633 }
634 pub fn apply_slice(self, data: &[f64]) -> Result<ErBatchOutput, ErError> {
635 er_batch_with_kernel(data, &self.range, self.kernel)
636 }
637 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ErBatchOutput, ErError> {
638 ErBatchBuilder::new().kernel(k).apply_slice(data)
639 }
640 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ErBatchOutput, ErError> {
641 let slice = source_type(c, src);
642 self.apply_slice(slice)
643 }
644 pub fn with_default_candles(c: &Candles) -> Result<ErBatchOutput, ErError> {
645 ErBatchBuilder::new()
646 .kernel(Kernel::Auto)
647 .apply_candles(c, "close")
648 }
649}
650
651pub fn er_batch_with_kernel(
652 data: &[f64],
653 sweep: &ErBatchRange,
654 k: Kernel,
655) -> Result<ErBatchOutput, ErError> {
656 let kernel = match k {
657 Kernel::Auto => detect_best_batch_kernel(),
658 other if other.is_batch() => other,
659 other => return Err(ErError::InvalidKernelForBatch(other)),
660 };
661 let simd = match kernel {
662 Kernel::Avx512Batch => Kernel::Avx512,
663 Kernel::Avx2Batch => Kernel::Avx2,
664 Kernel::ScalarBatch => Kernel::Scalar,
665 _ => unreachable!(),
666 };
667 er_batch_par_slice(data, sweep, simd)
668}
669
/// Result of a batch ER sweep.
#[derive(Clone, Debug)]
pub struct ErBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` corresponds to `combos[r]`.
    pub values: Vec<f64>,
    /// Parameters evaluated, one per row.
    pub combos: Vec<ErParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
677impl ErBatchOutput {
678 pub fn row_for_params(&self, p: &ErParams) -> Option<usize> {
679 self.combos
680 .iter()
681 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
682 }
683 pub fn values_for(&self, p: &ErParams) -> Option<&[f64]> {
684 self.row_for_params(p).map(|row| {
685 let start = row * self.cols;
686 &self.values[start..start + self.cols]
687 })
688 }
689}
690
691#[inline(always)]
692fn expand_grid(r: &ErBatchRange) -> Vec<ErParams> {
693 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
694 if step == 0 || start == end {
695 return vec![start];
696 }
697 let st = step.max(1);
698 if start < end {
699 (start..=end).step_by(st).collect()
700 } else {
701 let mut v = Vec::new();
702 let mut x = start as isize;
703 let end_i = end as isize;
704 let st_i = st as isize;
705 while x >= end_i {
706 v.push(x as usize);
707 x -= st_i;
708 }
709 v
710 }
711 }
712 let periods = axis_usize(r.period);
713 let mut out = Vec::with_capacity(periods.len());
714 for &p in &periods {
715 out.push(ErParams { period: Some(p) });
716 }
717 out
718}
719
720#[inline(always)]
721pub fn er_batch_slice(
722 data: &[f64],
723 sweep: &ErBatchRange,
724 kern: Kernel,
725) -> Result<ErBatchOutput, ErError> {
726 er_batch_inner(data, sweep, kern, false)
727}
728
729#[inline(always)]
730pub fn er_batch_par_slice(
731 data: &[f64],
732 sweep: &ErBatchRange,
733 kern: Kernel,
734) -> Result<ErBatchOutput, ErError> {
735 er_batch_inner(data, sweep, kern, true)
736}
737
738#[inline(always)]
739fn er_batch_inner_into(
740 data: &[f64],
741 sweep: &ErBatchRange,
742 kern: Kernel,
743 parallel: bool,
744 out: &mut [f64],
745) -> Result<Vec<ErParams>, ErError> {
746 let combos = expand_grid(sweep);
747 if combos.is_empty() {
748 return Err(ErError::InvalidRange {
749 start: sweep.period.0.to_string(),
750 end: sweep.period.1.to_string(),
751 step: sweep.period.2.to_string(),
752 });
753 }
754
755 let cols = data.len();
756 if cols == 0 {
757 return Err(ErError::EmptyInputData);
758 }
759 let first = data
760 .iter()
761 .position(|x| !x.is_nan())
762 .ok_or(ErError::AllValuesNaN)?;
763 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
764 if cols - first < max_p {
765 return Err(ErError::NotEnoughValidData {
766 needed: max_p,
767 valid: cols - first,
768 });
769 }
770
771 let rows = combos.len();
772 let out_mu = unsafe {
773 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
774 };
775 let expected = rows
776 .checked_mul(cols)
777 .ok_or_else(|| ErError::InvalidRange {
778 start: "rows*cols".into(),
779 end: "overflow".into(),
780 step: "*".into(),
781 })?;
782 if out.len() != expected {
783 return Err(ErError::OutputLengthMismatch {
784 expected,
785 got: out.len(),
786 });
787 }
788 let warm: Vec<usize> = combos
789 .iter()
790 .map(|c| first + c.period.unwrap() - 1)
791 .collect();
792 init_matrix_prefixes(out_mu, cols, &warm);
793
794 let mut prefix = vec![0.0f64; cols];
795 if first < cols {
796 let mut j = first;
797 while j + 1 < cols {
798 let d = (data[j + 1] - data[j]).abs();
799 prefix[j + 1] = prefix[j] + d;
800 j += 1;
801 }
802 }
803
804 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
805 let period = combos[row].period.unwrap();
806 match kern {
807 Kernel::Scalar => er_row_scalar_with_prefix(data, &prefix, first, period, out_row),
808 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
809 Kernel::Avx2 => er_row_avx2(data, first, period, out_row),
810 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
811 Kernel::Avx512 => er_row_avx512(data, first, period, out_row),
812 _ => unreachable!(),
813 }
814 };
815
816 if parallel {
817 #[cfg(not(target_arch = "wasm32"))]
818 {
819 out.par_chunks_mut(cols)
820 .enumerate()
821 .for_each(|(row, slice)| do_row(row, slice));
822 }
823
824 #[cfg(target_arch = "wasm32")]
825 {
826 for (row, slice) in out.chunks_mut(cols).enumerate() {
827 do_row(row, slice);
828 }
829 }
830 } else {
831 for (row, slice) in out.chunks_mut(cols).enumerate() {
832 do_row(row, slice);
833 }
834 }
835
836 Ok(combos)
837}
838
839#[inline(always)]
840fn er_batch_inner(
841 data: &[f64],
842 sweep: &ErBatchRange,
843 kern: Kernel,
844 parallel: bool,
845) -> Result<ErBatchOutput, ErError> {
846 let combos = expand_grid(sweep);
847 if combos.is_empty() {
848 return Err(ErError::InvalidRange {
849 start: sweep.period.0.to_string(),
850 end: sweep.period.1.to_string(),
851 step: sweep.period.2.to_string(),
852 });
853 }
854
855 let cols = data.len();
856 if cols == 0 {
857 return Err(ErError::EmptyInputData);
858 }
859 let first = data
860 .iter()
861 .position(|x| !x.is_nan())
862 .ok_or(ErError::AllValuesNaN)?;
863 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
864 if cols - first < max_p {
865 return Err(ErError::NotEnoughValidData {
866 needed: max_p,
867 valid: cols - first,
868 });
869 }
870
871 let rows = combos.len();
872 let _total = rows
873 .checked_mul(cols)
874 .ok_or_else(|| ErError::InvalidRange {
875 start: "rows*cols".into(),
876 end: "overflow".into(),
877 step: "*".into(),
878 })?;
879 let mut buf_mu = make_uninit_matrix(rows, cols);
880
881 let warm: Vec<usize> = combos
882 .iter()
883 .map(|c| first + c.period.unwrap() - 1)
884 .collect();
885 init_matrix_prefixes(&mut buf_mu, cols, &warm);
886
887 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
888 let values: &mut [f64] = unsafe {
889 std::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
890 };
891
892 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
893 let period = combos[row].period.unwrap();
894 match kern {
895 Kernel::Scalar => er_row_scalar(data, first, period, out_row),
896 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
897 Kernel::Avx2 => er_row_avx2(data, first, period, out_row),
898 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
899 Kernel::Avx512 => er_row_avx512(data, first, period, out_row),
900 _ => unreachable!(),
901 }
902 };
903
904 if parallel {
905 #[cfg(not(target_arch = "wasm32"))]
906 {
907 values
908 .par_chunks_mut(cols)
909 .enumerate()
910 .for_each(|(row, slice)| do_row(row, slice));
911 }
912
913 #[cfg(target_arch = "wasm32")]
914 {
915 for (row, slice) in values.chunks_mut(cols).enumerate() {
916 do_row(row, slice);
917 }
918 }
919 } else {
920 for (row, slice) in values.chunks_mut(cols).enumerate() {
921 do_row(row, slice);
922 }
923 }
924
925 let values = unsafe {
926 Vec::from_raw_parts(
927 buf_guard.as_mut_ptr() as *mut f64,
928 buf_guard.len(),
929 buf_guard.capacity(),
930 )
931 };
932
933 Ok(ErBatchOutput {
934 values,
935 combos,
936 rows,
937 cols,
938 })
939}
940
941#[inline(always)]
942unsafe fn er_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
943 er_scalar(data, period, first, out)
944}
945
/// Scalar ER row kernel that takes its denominators from precomputed prefix sums.
///
/// `prefix[i] - prefix[s]` must equal the sum of `|data[k+1] - data[k]|` over
/// `k` in `s..i`. For every `i` in `first + period - 1 .. data.len()`, the
/// result is `|data[i] - data[s]| / (prefix[i] - prefix[s])` with
/// `s = i + 1 - period`, capped at 1.0, or 0.0 when the denominator is not
/// positive. Earlier entries of `out` are left untouched.
#[inline(always)]
fn er_row_scalar_with_prefix(
    data: &[f64],
    prefix: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let first_out = first + period - 1;
    if first_out >= len {
        return;
    }
    for idx in first_out..len {
        let lo = idx + 1 - period;
        let net = (data[idx] - data[lo]).abs();
        let path = prefix[idx] - prefix[lo];
        out[idx] = if path > 0.0 {
            (net / path).min(1.0)
        } else {
            0.0
        };
    }
}
972
/// Batch row adapter for `er_avx2` (the row API takes `first, period`).
///
/// # Safety
/// Same requirements as `er_avx2`: the CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx2(data, period, first, out)
}
978
/// Batch row adapter for the AVX-512 kernels.
///
/// Periods up to 32 go to the "short" entry and longer ones to the "long" entry.
/// Both currently run the same code, because `er_avx512_long` delegates to
/// `er_avx512_short`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        er_row_avx512_short(data, first, period, out);
    } else {
        er_row_avx512_long(data, first, period, out);
    }
}
988
/// Row adapter for `er_avx512_short` (reorders `first, period` to `period, first`).
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx512_short(data, period, first, out)
}
994
/// Row adapter for `er_avx512_long` (reorders `first, period` to `period, first`).
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn er_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    er_avx512_long(data, period, first, out)
}
1000
// Python binding: computes ER for a 1-D float64 array and returns a new array.
// Kernel validation happens before the GIL is released; the computation runs
// inside `allow_threads`.
#[cfg(feature = "python")]
#[pyfunction(name = "er")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn er_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = ErInput::from_slice(
        slice_in,
        ErParams {
            period: Some(period),
        },
    );

    let output = py
        .allow_threads(|| er_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(output.values.into_pyarray(py))
}
1027
// Python wrapper (exposed as `ErStream`) around the incremental `ErStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "ErStream")]
pub struct ErStreamPy {
    // Underlying Rust stream state.
    stream: ErStream,
}
1033
#[cfg(feature = "python")]
#[pymethods]
impl ErStreamPy {
    // Constructs the stream; an invalid period surfaces as ValueError.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        ErStream::try_new(ErParams {
            period: Some(period),
        })
        .map(|stream| ErStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value; returns None until the window has filled.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1050
// Python binding for the batch sweep. Returns a dict with a (rows, cols) `values`
// matrix, the `periods` evaluated, and the `rows`/`cols` counts.
#[cfg(feature = "python")]
#[pyfunction(name = "er_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn er_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = ErBatchRange {
        period: period_range,
    };
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    // Check the product before allocating; an unchecked multiply panics in debug
    // builds (and wraps in release) for absurd sweeps.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("er: rows*cols overflow"))?;

    // The array is left uninitialised here; `er_batch_inner_into` is expected to fill
    // every element (warm-up prefixes via `init_matrix_prefixes`, the rest via the row
    // kernel). TODO confirm `init_matrix_prefixes` writes each full prefix. On error the
    // array is dropped without being returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let simd = match kern {
                // Auto: AVX2 rows when detected; AVX-512 and everything else use the
                // scalar prefix-sum row kernel.
                Kernel::Auto => match detect_best_kernel() {
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    Kernel::Avx2 => Kernel::Avx2,
                    _ => Kernel::Scalar,
                },
                Kernel::ScalarBatch => Kernel::Scalar,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                // Unsupported in this build: report instead of panicking.
                other => return Err(ErError::InvalidKernelForBatch(other)),
            };
            er_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1116
// WASM binding: computes ER over `data` and returns a freshly allocated result vector.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ErInput::from_slice(
        data,
        ErParams {
            period: Some(period),
        },
    );
    let mut output = vec![0.0; data.len()];
    match er_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1132
// WASM helper: allocates room for `len` f64 values and hands ownership to the caller.
// The buffer is intended to be released with `er_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_alloc(len: usize) -> *mut f64 {
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1141
// WASM helper: releases a buffer previously obtained from `er_alloc(len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: `ptr` must come from `er_alloc(len)`, which allocated exactly `len`
        // f64 slots (`Vec::with_capacity` guarantees the exact capacity). Length 0 is
        // used because the contents may never have been initialised, and
        // `from_raw_parts` requires the first `length` elements to be initialised.
        // f64 has no destructor, so only the allocation itself is released.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
1151
// WASM binding: computes ER over `len` values at `in_ptr` and writes `len` results
// to `out_ptr`. Both pointers must reference at least `len` f64 values in linear
// memory (e.g. buffers from `er_alloc`); that is the caller's responsibility.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to er_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Any overlap between the input and output ranges (not only `in_ptr == out_ptr`)
    // would alias a shared slice with a mutable one, so such calls go through a
    // temporary buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = ErParams {
            period: Some(period),
        };
        let input = ErInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            er_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            er_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1191
/// JS-side configuration for `er_batch`: the `(start, end, step)` period sweep.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ErBatchConfig {
    pub period_range: (usize, usize, usize),
}
1197
/// Serializable mirror of `ErBatchOutput` returned to JS by `er_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ErBatchJsOutput {
    /// Row-major `rows x cols` matrix of ER values.
    pub values: Vec<f64>,
    /// Parameters for each row.
    pub combos: Vec<ErParams>,
    pub rows: usize,
    pub cols: usize,
}
1206
// WASM binding (`er_batch`): deserializes an `ErBatchConfig`, runs the sweep with
// automatic kernel selection, and serializes the result back to a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = er_batch)]
pub fn er_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: ErBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = ErBatchRange {
        period: config.period_range,
    };

    let ErBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = er_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = ErBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1229
/// Writes the ER batch matrix for `len` input values into caller-owned
/// memory and returns the number of rows (parameter combinations).
///
/// `in_ptr` must address `len` initialised `f64`s. `out_ptr` must address at
/// least `rows * len` writable `f64`s, where `rows` is the number of periods
/// produced by `(period_start, period_end, period_step)`; each row holds `len`
/// values. The two buffers may overlap: overlapping calls are computed into a
/// scratch buffer first, mirroring how `er_into` handles in-place calls.
///
/// # Errors
/// Returns a JS string error for null pointers, when `rows * len` (or its
/// byte size) does not fit in memory, or when the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn er_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to er_batch_into"));
    }

    let sweep = ErBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let cols = len;

    // On wasm32 `usize` is 32 bits, so `rows * cols` can wrap; a wrapped length
    // would make the raw output slice shorter than the matrix being written.
    // Slices must also span at most `isize::MAX` bytes.
    let elem = std::mem::size_of::<f64>();
    let total = rows
        .checked_mul(cols)
        .filter(|n| {
            n.checked_mul(elem)
                .map_or(false, |bytes| bytes <= isize::MAX as usize)
        })
        .ok_or_else(|| JsValue::from_str("er_batch_into: rows * len overflows"))?;
    if total == 0 {
        return Ok(rows);
    }

    // Byte ranges cannot overflow: `len * elem <= total * elem`, checked above.
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len * elem);
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total * elem);
    let overlaps = in_start < out_end && out_start < in_end;

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    // SAFETY: `in_ptr` is non-null and, per this function's contract, addresses
    // `len` initialised f64s that remain valid for the duration of the call.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if overlaps {
        // Compute into scratch so no row reads input already overwritten by an
        // earlier row, and so no `&mut` slice aliases `data` while it is read.
        let mut scratch = vec![0.0; total];
        er_batch_inner_into(data, &sweep, simd, false, &mut scratch[..])
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and addresses `total` writable f64s;
        // `data` is not used past this point, so no live shared borrow aliases it.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null, addresses `total` writable f64s, and
        // (checked above) does not overlap the memory behind `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        er_batch_inner_into(data, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1268
1269#[cfg(all(feature = "python", feature = "cuda"))]
1270use crate::cuda::er_wrapper::{CudaEr, DeviceArrayF32Er};
1271#[cfg(all(feature = "python", feature = "cuda"))]
1272use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1273#[cfg(all(feature = "python", feature = "cuda"))]
1274use numpy::PyReadonlyArray1;
1275#[cfg(all(feature = "python", feature = "cuda"))]
1276#[cfg(all(feature = "python", feature = "cuda"))]
1277use pyo3::prelude::*;
1278
/// Python `er_cuda_batch_dev`: computes ER for every period in `period_range`
/// (`(start, end, step)`) over a float32 series on CUDA device `device_id` and
/// returns a handle to the device-resident result matrix.
///
/// The GIL is released (`allow_threads`) while the CUDA wrapper runs.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the input array is not
/// contiguous, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "er_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn er_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32ErPy> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = ErBatchRange {
        period: period_range,
    };

    let device_result = py.allow_threads(|| -> PyResult<DeviceArrayF32Er> {
        let engine = CudaEr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .er_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32ErPy {
        inner: device_result,
    })
}
1303
/// Python `er_cuda_many_series_one_param_dev`: runs ER with a single `period`
/// over a time-major float32 matrix on CUDA device `device_id` and returns a
/// handle to the device-resident result.
///
/// NOTE(review): `data_tm_f32` is presumably `rows` time steps by `cols` series,
/// flattened time-major; shape validation is left to `CudaEr` — confirm there.
///
/// # Errors
/// Raises `ValueError` when CUDA is unavailable, the input array is not
/// contiguous, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "er_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, device_id=0))]
pub fn er_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32ErPy> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_tm = data_tm_f32.as_slice()?;

    let device_result = py.allow_threads(|| -> PyResult<DeviceArrayF32Er> {
        let engine = CudaEr::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .er_many_series_one_param_time_major_dev(host_tm, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32ErPy {
        inner: device_result,
    })
}
1327
/// Python handle for an ER result matrix that lives in CUDA device memory.
///
/// Declared `unsendable`, so PyO3 refuses access from threads other than the
/// one that created the object.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32ErPy {
    /// Owned device buffer together with its shape, CUDA context and device id.
    pub(crate) inner: DeviceArrayF32Er,
}
1333
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32ErPy {
    /// CUDA Array Interface (version 3) dictionary describing the buffer as a
    /// `(rows, cols)` little-endian float32 matrix with row-major byte strides.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let item = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        // One row spans `cols` elements; consecutive elements are one f32 apart.
        iface.set_item("strides", (cols * item, item))?;
        // Second tuple entry is the read-only flag: consumers may write.
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(device_type, device_id)`; `2` is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the
    /// device memory to the consumer.
    ///
    /// After the call this wrapper holds an empty `0 x 0` placeholder buffer.
    /// A `dl_device` request must name this buffer's own device; `stream` and
    /// `copy` are accepted but currently ignored.
    ///
    /// # Errors
    /// Raises `ValueError` if `dl_device` is not a `(type, id)` tuple or names
    /// a different device, or if allocating the placeholder buffer fails.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use cust::memory::DeviceBuffer;
        use pyo3::types::PyAny;
        use pyo3::Bound;

        let (dev_ty, alloc_dev) = self.__dlpack_device__();
        if let Some(requested) = dl_device.as_ref() {
            let wanted = requested.extract::<(i32, i32)>(py).map_err(|_| {
                PyValueError::new_err(
                    "__dlpack__ dl_device must be a (device_type, device_id) tuple",
                )
            })?;
            // No cross-device copy is implemented, so any other device is refused.
            if wanted != (dev_ty, alloc_dev) {
                return Err(PyValueError::new_err(
                    "__dlpack__ dl_device does not match ER buffer device",
                ));
            }
        }

        // NOTE(review): `stream` and `copy` are not honoured; presumably the
        // producer has already synchronised its work — confirm in `CudaEr`.
        let _ = (&stream, &copy);

        let placeholder_buf =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let placeholder = DeviceArrayF32Er {
            buf: placeholder_buf,
            rows: 0,
            cols: 0,
            ctx: self.inner.ctx.clone(),
            device_id: self.inner.device_id,
        };
        // Move the real buffer out so the capsule, not this wrapper, owns it.
        let exported = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound: Option<Bound<'py, PyAny>> =
            max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, exported.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Bitfinex 4h candle fixture used by the candle-based tests; the path is
    /// relative to the working directory of the test run.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Bit patterns treated as poison, paired with the name of the allocation
    /// helper each pattern is attributed to in failure messages.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Name of the helper whose poison pattern equals `bits`, if any.
    #[cfg(debug_assertions)]
    fn poison_origin(bits: u64) -> Option<&'static str> {
        POISON_PATTERNS
            .iter()
            .find(|&&(pattern, _)| pattern == bits)
            .map(|&(_, origin)| origin)
    }

    /// `period: None` must still produce one output per candle.
    fn check_er_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let default_params = ErParams { period: None };
        let input = ErInput::from_candles(&candles, "close", default_params);
        let output = er_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    /// `er_into` must write exactly what the allocating `er` API returns
    /// (NaNs compare equal to NaNs). Not built for wasm+`wasm`, where the
    /// native `er_into` is not the one in scope.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_er_into_matches_api() -> Result<(), Box<dyn Error>> {
        fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }

        let n = 256usize;
        // Three leading NaNs followed by a smooth, slowly trending signal.
        let data: Vec<f64> = (0..n)
            .map(|i| {
                if i < 3 {
                    f64::NAN
                } else {
                    let x = i as f64;
                    (x * 0.01).sin() * (x * 0.02).cos() + 0.001 * x
                }
            })
            .collect();

        let input = ErInput::from_slice(&data, ErParams::default());
        let base = er(&input)?.values;

        let mut out = vec![0.0; n];
        er_into(&input, &mut out)?;

        assert_eq!(base.len(), out.len());
        for (i, (&b, &o)) in base.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan_eps(b, o),
                "mismatch at {}: base={:?}, into={:?}",
                i,
                b,
                o
            );
        }
        Ok(())
    }

    /// `ErInput::with_default_candles` must read the "close" source.
    fn check_er_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = ErInput::with_default_candles(&candles);
        match input.data {
            ErData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected ErData::Candles"),
        }
        let output = er_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A period of zero is rejected.
    fn check_er_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ErParams { period: Some(0) };
        let input = ErInput::from_slice(&input_data, params);
        let res = er_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ER should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input is rejected.
    fn check_er_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = ErParams { period: Some(10) };
        let input = ErInput::from_slice(&data_small, params);
        let res = er_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ER should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A single data point cannot satisfy the default-sized period.
    fn check_er_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = ErParams { period: Some(5) };
        let input = ErInput::from_slice(&single_point, params);
        let res = er_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ER should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// ER output (with its NaN warmup) can be fed back in as input.
    fn check_er_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let first_params = ErParams { period: Some(5) };
        let first_input = ErInput::from_candles(&candles, "close", first_params);
        let first_result = er_with_kernel(&first_input, kernel)?;

        let second_params = ErParams { period: Some(5) };
        let second_input = ErInput::from_slice(&first_result.values, second_params);
        let second_result = er_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// Well past the warmup (index 240 onward) the fixture yields no NaNs.
    fn check_er_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = ErInput::from_candles(&candles, "close", ErParams { period: Some(5) });
        let res = er_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    /// `ErStream` fed point by point must match the batch output; a `None`
    /// from `update` is recorded as NaN.
    fn check_er_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let period = 5;

        let input = ErInput::from_candles(
            &candles,
            "close",
            ErParams {
                period: Some(period),
            },
        );
        let batch_output = er_with_kernel(&input, kernel)?.values;

        let mut stream = ErStream::try_new(ErParams {
            period: Some(period),
        })?;
        let stream_values: Vec<f64> = candles
            .close
            .iter()
            .map(|&price| stream.update(price).unwrap_or(f64::NAN))
            .collect();

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] ER streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Debug builds only: no output value may carry a poison bit pattern,
    /// across a spread of periods from 1 to 2000.
    #[cfg(debug_assertions)]
    fn check_er_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = [
            ErParams::default(),
            ErParams { period: Some(1) },
            ErParams { period: Some(2) },
            ErParams { period: Some(3) },
            ErParams { period: Some(4) },
            ErParams { period: Some(5) },
            ErParams { period: Some(10) },
            ErParams { period: Some(14) },
            ErParams { period: Some(20) },
            ErParams { period: Some(30) },
            ErParams { period: Some(50) },
            ErParams { period: Some(100) },
            ErParams { period: Some(200) },
            ErParams { period: Some(500) },
            ErParams { period: Some(1000) },
            ErParams { period: Some(2000) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = ErInput::from_candles(&candles, "close", params.clone());
            let output = er_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(origin) = poison_origin(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        origin,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(5),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_er_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates one `#[test]` per helper and kernel. The SIMD variants carry
    // their own `cfg` so each one is gated individually (an attribute placed
    // before a `$(...)*` repetition would only reach the first generated item).
    // Tests return the helper's `Result`, so an `Err` fails the test.
    macro_rules! generate_all_er_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    generate_all_er_tests!(
        check_er_partial_params,
        check_er_default_candles,
        check_er_zero_period,
        check_er_period_exceeds_length,
        check_er_very_small_dataset,
        check_er_reinput,
        check_er_nan_handling,
        check_er_streaming,
        check_er_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_er_tests!(check_er_property);

    /// The default batch sweep must contain a full-length row for
    /// `ErParams::default()`.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = ErBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = ErParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        Ok(())
    }

    // Generates one batch `#[test]` per kernel; errors from the helper fail
    // the test instead of being discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    /// Debug builds only: no cell of any batch matrix may carry a poison bit
    /// pattern, across several `(start, end, step)` sweeps (including step 0).
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let test_configs = [
            (1, 5, 1),
            (2, 10, 2),
            (5, 30, 5),
            (10, 100, 10),
            (50, 500, 50),
            (100, 1000, 100),
            (14, 14, 0),
            (3, 15, 1),
            (20, 200, 20),
            (25, 50, 5),
        ];

        for (cfg_idx, &(period_start, period_end, period_step)) in test_configs.iter().enumerate() {
            let output = ErBatchBuilder::new()
                .kernel(kernel)
                .period_range(period_start, period_end, period_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(origin) = poison_origin(bits) {
                    // The matrix is row-major: one row of `cols` values per combo.
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    let combo = &output.combos[row];
                    panic!(
                        "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        origin,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(5)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    /// Property test over random trending/noisy price paths: `kernel` must
    /// agree with the scalar kernel, emit NaN during warmup, stay within
    /// [0, 1], score high on monotonic windows, 0/NaN on flat windows and low
    /// on choppy windows.
    #[cfg(feature = "proptest")]
    fn check_er_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50)
            .prop_flat_map(|period| {
                let min_len = period * 2;
                (
                    (100.0f64..5000.0f64, 0.01f64..0.1f64),
                    -0.02f64..0.02f64,
                    Just(period),
                    min_len..400,
                )
            })
            .prop_flat_map(|((base_price, volatility), trend, period, len)| {
                let price_changes = prop::collection::vec(-1.0f64..1.0f64, len);

                (
                    Just(base_price),
                    Just(volatility),
                    Just(trend),
                    Just(period),
                    price_changes,
                )
            })
            .prop_map(|(base_price, volatility, trend, period, changes)| {
                // Compound a fixed drift plus bounded noise; the floor of 1.0
                // keeps prices positive.
                let mut data = Vec::with_capacity(changes.len());
                let mut price = base_price;

                for &noise in &changes {
                    price *= 1.0 + trend;
                    price *= 1.0 + noise * volatility;
                    price = price.max(1.0);
                    data.push(price);
                }

                (data, period)
            });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = ErParams {
                    period: Some(period),
                };
                let input = ErInput::from_slice(&data, params);

                let ErOutput { values: out } = er_with_kernel(&input, kernel).unwrap();
                let ErOutput { values: ref_out } = er_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());

                let warmup = period - 1;
                for i in 0..warmup {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in warmup..data.len() {
                    let val = out[i];
                    if !val.is_nan() {
                        prop_assert!(
                            val >= -1e-10 && val <= 1.0 + 1e-10,
                            "ER value {} at index {} outside valid range [0, 1]",
                            val,
                            i
                        );
                    }
                }

                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y.to_bits(),
                            r.to_bits(),
                            "NaN/Inf mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                    } else {
                        let diff = (y - r).abs();
                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                        prop_assert!(
                            diff <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (diff={}, ULP={})",
                            i,
                            y,
                            r,
                            diff,
                            ulp_diff
                        );
                    }
                }

                // The window checks below start at `warmup + 1 == period`, so
                // every window `data[i + 1 - period..=i]` is in bounds.
                if data.len() >= period + 10 {
                    for i in period..data.len() {
                        let window_start = i + 1 - period;
                        let window_end = i;

                        let window = &data[window_start..=window_end];
                        let is_monotonic_up = window.windows(2).all(|w| w[1] >= w[0] - 1e-10);
                        let is_monotonic_down = window.windows(2).all(|w| w[1] <= w[0] + 1e-10);
                        let is_constant = window.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);

                        if !is_constant && (is_monotonic_up || is_monotonic_down) {
                            let er_val = out[i];
                            let net_change = (window[window.len() - 1] - window[0]).abs();
                            if !er_val.is_nan() && net_change > 1e-6 {
                                prop_assert!(
                                    er_val >= 0.90,
                                    "Expected high ER (>0.90) for monotonic move at index {}, got {}",
                                    i,
                                    er_val
                                );
                            }
                        }
                    }
                }

                for i in period..data.len() {
                    let window_start = i + 1 - period;
                    let window_end = i;
                    let window = &data[window_start..=window_end];
                    let is_constant = window.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);

                    if is_constant {
                        let er_val = out[i];

                        prop_assert!(
                            er_val.is_nan() || er_val.abs() < 1e-10,
                            "Constant prices should yield NaN or 0, got {} at index {}",
                            er_val,
                            i
                        );
                    }
                }

                for i in warmup..data.len() {
                    let val = out[i];
                    if !val.is_nan() {
                        prop_assert!(
                            val >= -1e-10,
                            "ER should be non-negative, got {} at index {}",
                            val,
                            i
                        );
                    }
                }

                if period >= 4 && data.len() >= period * 3 {
                    for i in period..data.len() {
                        let window_start = i + 1 - period;
                        let window_end = i;

                        // Reference ratio: |net change| / sum of |step changes|.
                        let net_change = (data[window_end] - data[window_start]).abs();
                        let total_movement: f64 = data[window_start..=window_end]
                            .windows(2)
                            .map(|w| (w[1] - w[0]).abs())
                            .sum();

                        if total_movement > 0.0 && net_change / total_movement < 0.3 {
                            let er_val = out[i];
                            if !er_val.is_nan() {
                                prop_assert!(
                                    er_val <= 0.35,
                                    "Expected low ER (<0.35) for choppy market at index {}, got {}",
                                    i,
                                    er_val
                                );
                            }
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
}