1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(not(target_arch = "wasm32"))]
8use rayon::prelude::*;
9use std::convert::AsRef;
10use std::mem::MaybeUninit;
11use thiserror::Error;
12
13#[cfg(feature = "python")]
14use crate::utilities::kernel_validation::validate_kernel;
15#[cfg(feature = "python")]
16use numpy;
17#[cfg(feature = "python")]
18use numpy::PyUntypedArrayMethods;
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23
24#[cfg(all(feature = "python", feature = "cuda"))]
25use crate::cuda::cuda_available;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use crate::cuda::moving_averages::CudaEma;
28#[cfg(all(feature = "python", feature = "cuda"))]
29use cust::context::Context;
30#[cfg(all(feature = "python", feature = "cuda"))]
31use cust::memory::DeviceBuffer;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34
// Python-visible handle to a row-major `rows x cols` `f32` CUDA device buffer.
// Plain `//` comments are used on purpose: pyo3 turns `///` docs into Python docstrings.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct EmaDeviceArrayF32Py {
    // Device allocation; `__dlpack__` takes it out and leaves `None` behind.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    // Number of rows reported in the exported shape.
    pub(crate) rows: usize,
    // Number of columns reported in the exported shape.
    pub(crate) cols: usize,
    // NOTE(review): presumably held only to keep the CUDA context alive
    // for as long as the buffer exists -- confirm against the constructor.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EmaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict (version 3) describing the
    // buffer as a C-contiguous `<f4` array of shape `(rows, cols)`.
    // Fails with `ValueError` once the buffer has been handed out by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let row_bytes = self.cols * elem_bytes;

        let iface = pyo3::types::PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Strides are expressed in bytes: one full row, then one element.
        iface.set_item("strides", (row_bytes, elem_bytes))?;

        let device_buf = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let raw_addr = device_buf.as_device_ptr().as_raw() as usize;
        // Second tuple element is the read-only flag.
        iface.set_item("data", (raw_addr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; 2 is the DLPack device-type code
    // for CUDA (`kDLCUDA` in the DLPack specification).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule. Ownership of the device buffer
    // moves into the capsule, so this can succeed at most once per object.
    // A `dl_device` that differs from the buffer's own device is rejected
    // (no cross-device copy is implemented). `stream` is accepted but unused.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (own_type, own_id) = self.__dlpack_device__();

        // A `dl_device` that is not an `(int, int)` tuple is silently ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((req_type, req_id)) = requested {
            if req_type != own_type || req_id != own_id {
                let copy_requested = match copy.as_ref() {
                    Some(flag) => flag.extract::<bool>(py).unwrap_or(false),
                    None => false,
                };
                if copy_requested {
                    return Err(PyValueError::new_err(
                        "device copy not implemented for __dlpack__",
                    ));
                }
                return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
            }
        }
        let _ = stream;

        let device_buf = match self.buf.take() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            device_buf,
            self.rows,
            self.cols,
            own_id,
            max_version_bound,
        )
    }
}
126impl<'a> AsRef<[f64]> for EmaInput<'a> {
127 #[inline(always)]
128 fn as_ref(&self) -> &[f64] {
129 match &self.data {
130 EmaData::Slice(slice) => slice,
131 EmaData::Candles { candles, source } => source_type(candles, source),
132 }
133 }
134}
135
136#[derive(Debug, Clone)]
137pub enum EmaData<'a> {
138 Candles {
139 candles: &'a Candles,
140 source: &'a str,
141 },
142 Slice(&'a [f64]),
143}
144
/// Result of a single-series EMA computation.
#[derive(Clone, Debug)]
pub struct EmaOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
149
/// Tunable parameters of the EMA.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmaParams {
    /// Smoothing period; consumers in this file fall back to 9 when `None`.
    pub period: Option<usize>,
}
158
159impl Default for EmaParams {
160 fn default() -> Self {
161 Self { period: Some(9) }
162 }
163}
164
165#[derive(Debug, Clone)]
166pub struct EmaInput<'a> {
167 pub data: EmaData<'a>,
168 pub params: EmaParams,
169}
170
171impl<'a> EmaInput<'a> {
172 #[inline]
173 pub fn from_candles(c: &'a Candles, s: &'a str, p: EmaParams) -> Self {
174 Self {
175 data: EmaData::Candles {
176 candles: c,
177 source: s,
178 },
179 params: p,
180 }
181 }
182 #[inline]
183 pub fn from_slice(sl: &'a [f64], p: EmaParams) -> Self {
184 Self {
185 data: EmaData::Slice(sl),
186 params: p,
187 }
188 }
189 #[inline]
190 pub fn with_default_candles(c: &'a Candles) -> Self {
191 Self::from_candles(c, "close", EmaParams::default())
192 }
193 #[inline]
194 pub fn get_period(&self) -> usize {
195 self.params.period.unwrap_or(9)
196 }
197}
198
199#[derive(Copy, Clone, Debug)]
200pub struct EmaBuilder {
201 period: Option<usize>,
202 kernel: Kernel,
203}
204
205impl Default for EmaBuilder {
206 fn default() -> Self {
207 Self {
208 period: None,
209 kernel: Kernel::Auto,
210 }
211 }
212}
213
214impl EmaBuilder {
215 #[inline(always)]
216 pub fn new() -> Self {
217 Self::default()
218 }
219 #[inline(always)]
220 pub fn period(mut self, n: usize) -> Self {
221 self.period = Some(n);
222 self
223 }
224 #[inline(always)]
225 pub fn kernel(mut self, k: Kernel) -> Self {
226 self.kernel = k;
227 self
228 }
229
230 #[inline(always)]
231 pub fn apply(self, c: &Candles) -> Result<EmaOutput, EmaError> {
232 let p = EmaParams {
233 period: self.period,
234 };
235 let i = EmaInput::from_candles(c, "close", p);
236 ema_with_kernel(&i, self.kernel)
237 }
238
239 #[inline(always)]
240 pub fn apply_slice(self, d: &[f64]) -> Result<EmaOutput, EmaError> {
241 let p = EmaParams {
242 period: self.period,
243 };
244 let i = EmaInput::from_slice(d, p);
245 ema_with_kernel(&i, self.kernel)
246 }
247
248 #[inline(always)]
249 pub fn into_stream(self) -> Result<EmaStream, EmaError> {
250 let p = EmaParams {
251 period: self.period,
252 };
253 EmaStream::try_new(p)
254 }
255}
256
257#[derive(Debug, Error)]
258pub enum EmaError {
259 #[error("ema: Input data slice is empty.")]
260 EmptyInputData,
261 #[error("ema: All values are NaN.")]
262 AllValuesNaN,
263 #[error("ema: Invalid period: period = {period}, data length = {data_len}")]
264 InvalidPeriod { period: usize, data_len: usize },
265 #[error("ema: Not enough valid data: needed = {needed}, valid = {valid}")]
266 NotEnoughValidData { needed: usize, valid: usize },
267 #[error("ema: Output length mismatch: expected = {expected}, got = {got}")]
268 OutputLengthMismatch { expected: usize, got: usize },
269 #[error("ema: Invalid range: start = {start}, end = {end}, step = {step}")]
270 InvalidRange {
271 start: usize,
272 end: usize,
273 step: usize,
274 },
275 #[error("ema: Invalid kernel for batch API: {0:?}")]
276 InvalidKernelForBatch(Kernel),
277 #[error("ema: arithmetic overflow while computing {context}")]
278 ArithmeticOverflow { context: &'static str },
279}
280
281#[inline]
282pub fn ema(input: &EmaInput) -> Result<EmaOutput, EmaError> {
283 ema_with_kernel(input, Kernel::Auto)
284}
285
286#[inline(always)]
287fn ema_prepare<'a>(
288 input: &'a EmaInput,
289 kernel: Kernel,
290) -> Result<(&'a [f64], usize, usize, f64, f64, Kernel), EmaError> {
291 let data: &[f64] = input.as_ref();
292
293 let len = data.len();
294 if len == 0 {
295 return Err(EmaError::EmptyInputData);
296 }
297
298 let first = data
299 .iter()
300 .position(|x| !x.is_nan())
301 .ok_or(EmaError::AllValuesNaN)?;
302 let period = input.get_period();
303 if period == 0 || period > len {
304 return Err(EmaError::InvalidPeriod {
305 period,
306 data_len: len,
307 });
308 }
309 if len - first < period {
310 return Err(EmaError::NotEnoughValidData {
311 needed: period,
312 valid: len - first,
313 });
314 }
315
316 let alpha = 2.0 / (period as f64 + 1.0);
317 let beta = 1.0 - alpha;
318 let chosen = if matches!(kernel, Kernel::Auto) {
319 Kernel::Scalar
320 } else {
321 kernel
322 };
323 Ok((data, period, first, alpha, beta, chosen))
324}
325
326#[inline(always)]
327fn ema_compute_into(
328 data: &[f64],
329 period: usize,
330 first: usize,
331 alpha: f64,
332 beta: f64,
333 kernel: Kernel,
334 out: &mut [f64],
335) {
336 unsafe {
337 match kernel {
338 Kernel::Scalar | Kernel::ScalarBatch => {
339 ema_scalar_into(data, period, first, alpha, beta, out)
340 }
341 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
342 Kernel::Avx2 | Kernel::Avx2Batch => {
343 ema_avx2_into(data, period, first, alpha, beta, out)
344 }
345 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
346 Kernel::Avx512 | Kernel::Avx512Batch => {
347 ema_avx512_into(data, period, first, alpha, beta, out)
348 }
349
350 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
351 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
352 ema_scalar_into(data, period, first, alpha, beta, out)
353 }
354 _ => unreachable!(),
355 }
356 }
357}
358
359pub fn ema_with_kernel(input: &EmaInput, kernel: Kernel) -> Result<EmaOutput, EmaError> {
360 let (data, period, first, alpha, beta, chosen) = ema_prepare(input, kernel)?;
361
362 let mut out = alloc_with_nan_prefix(data.len(), first);
363 ema_compute_into(data, period, first, alpha, beta, chosen, &mut out);
364
365 Ok(EmaOutput { values: out })
366}
367
368#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
369pub fn ema_into(input: &EmaInput, out: &mut [f64]) -> Result<(), EmaError> {
370 let (data, period, first, alpha, beta, chosen) = ema_prepare(input, Kernel::Auto)?;
371
372 if out.len() != data.len() {
373 return Err(EmaError::OutputLengthMismatch {
374 expected: data.len(),
375 got: out.len(),
376 });
377 }
378
379 let warm = first.min(out.len());
380 for i in 0..warm {
381 out[i] = f64::from_bits(0x7ff8_0000_0000_0000);
382 }
383
384 ema_compute_into(data, period, first, alpha, beta, chosen, out);
385
386 Ok(())
387}
388
/// Returns `true` when `x` is neither NaN nor infinite.
///
/// NaN and the infinities are exactly the values whose exponent bits are all
/// set, which is the same condition `f64::is_finite` tests.
#[inline(always)]
fn is_finite_fast(x: f64) -> bool {
    x.is_finite()
}
394
395#[inline]
396pub fn ema_into_slice(dst: &mut [f64], input: &EmaInput, kern: Kernel) -> Result<(), EmaError> {
397 let (data, period, first, alpha, beta, chosen) = ema_prepare(input, kern)?;
398
399 if dst.len() != data.len() {
400 return Err(EmaError::OutputLengthMismatch {
401 expected: data.len(),
402 got: dst.len(),
403 });
404 }
405
406 ema_compute_into(data, period, first, alpha, beta, chosen, dst);
407
408 for v in &mut dst[..first] {
409 *v = f64::NAN;
410 }
411
412 Ok(())
413}
414
415#[inline(always)]
416pub unsafe fn ema_scalar(
417 data: &[f64],
418 period: usize,
419 first_val: usize,
420 out: &mut Vec<f64>,
421) -> Result<EmaOutput, EmaError> {
422 let alpha = 2.0 / (period as f64 + 1.0);
423 let beta = 1.0 - alpha;
424 ema_scalar_into(data, period, first_val, alpha, beta, out);
425 let values = std::mem::take(out);
426 Ok(EmaOutput { values })
427}
428
429#[inline(always)]
430unsafe fn ema_scalar_into(
431 data: &[f64],
432 period: usize,
433 first_val: usize,
434 alpha: f64,
435 beta: f64,
436 out: &mut [f64],
437) {
438 let len = data.len();
439 debug_assert_eq!(out.len(), len);
440
441 let mut mean = *data.get_unchecked(first_val);
442 *out.get_unchecked_mut(first_val) = mean;
443 let mut valid_count = 1usize;
444
445 let warmup_end = (first_val + period).min(len);
446 for i in (first_val + 1)..warmup_end {
447 let x = *data.get_unchecked(i);
448 if is_finite_fast(x) {
449 valid_count += 1;
450 let vc = valid_count as f64;
451 mean = ((vc - 1.0) * mean + x) / vc;
452 }
453
454 *out.get_unchecked_mut(i) = mean;
455 }
456
457 if warmup_end < len {
458 let mut prev = mean;
459 for i in warmup_end..len {
460 let x = *data.get_unchecked(i);
461 if is_finite_fast(x) {
462 prev = beta.mul_add(prev, alpha * x);
463 }
464
465 *out.get_unchecked_mut(i) = prev;
466 }
467 }
468}
469
/// AVX2 entry point; currently forwards to the scalar implementation.
///
/// # Safety
/// Same contract as [`ema_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ema_avx2(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut Vec<f64>,
) -> Result<EmaOutput, EmaError> {
    let result = ema_scalar(data, period, first_val, out);
    result
}
480
/// AVX2 in-place kernel; currently forwards to `ema_scalar_into`.
///
/// # Safety
/// Same contract as `ema_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_avx2_into(
    data: &[f64],
    period: usize,
    first_val: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    ema_scalar_into(data, period, first_val, alpha, beta, out);
}
493
/// AVX-512 entry point; currently forwards to the scalar implementation.
///
/// # Safety
/// Same contract as [`ema_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ema_avx512(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut Vec<f64>,
) -> Result<EmaOutput, EmaError> {
    let result = ema_scalar(data, period, first_val, out);
    result
}
504
/// AVX-512 in-place kernel; currently forwards to `ema_scalar_into`.
///
/// # Safety
/// Same contract as `ema_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_avx512_into(
    data: &[f64],
    period: usize,
    first_val: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    ema_scalar_into(data, period, first_val, alpha, beta, out);
}
517
/// Incremental EMA state, fed one sample at a time.
#[derive(Clone, Debug)]
pub struct EmaStream {
    /// Configured period.
    period: usize,
    /// Smoothing factor `2 / (period + 1)`.
    alpha: f64,
    /// `1 - alpha`.
    beta: f64,
    /// Number of finite samples consumed so far.
    count: usize,
    /// Current running value (NaN before the first finite sample).
    mean: f64,
    /// Set once `period` finite samples have been consumed.
    filled: bool,

    /// Reciprocals `1/1, 1/2, ..., 1/period` used during warm-up.
    inv: Box<[f64]>,
}
529
530impl EmaStream {
531 #[inline]
532 pub fn try_new(params: EmaParams) -> Result<Self, EmaError> {
533 let period = params.period.unwrap_or(9);
534 if period == 0 {
535 return Err(EmaError::InvalidPeriod {
536 period,
537 data_len: 0,
538 });
539 }
540
541 let alpha = 2.0 / (period as f64 + 1.0);
542 let beta = 1.0 - alpha;
543
544 let mut inv = Vec::with_capacity(period);
545 for n in 1..=period {
546 inv.push(1.0 / n as f64);
547 }
548
549 Ok(Self {
550 period,
551 alpha,
552 beta,
553 count: 0,
554 mean: f64::NAN,
555 filled: false,
556 inv: inv.into_boxed_slice(),
557 })
558 }
559
560 #[inline(always)]
561 pub fn update(&mut self, x: f64) -> Option<f64> {
562 if !is_finite_fast(x) {
563 return if self.filled { Some(self.mean) } else { None };
564 }
565
566 self.count += 1;
567 let c = self.count;
568
569 if c == 1 {
570 self.mean = x;
571 } else if c <= self.period {
572 let inv = self.inv[c - 1];
573 self.mean = (x - self.mean).mul_add(inv, self.mean);
574 } else {
575 self.mean = self.beta.mul_add(self.mean, self.alpha * x);
576 }
577
578 if !self.filled && c >= self.period {
579 self.filled = true;
580 }
581 if self.filled {
582 Some(self.mean)
583 } else {
584 None
585 }
586 }
587}
588
/// Period sweep for batch computations.
#[derive(Debug, Clone)]
pub struct EmaBatchRange {
    /// `(start, end, step)`; a step of 0 or `start == end` means a single period.
    pub period: (usize, usize, usize),
}
593
594impl Default for EmaBatchRange {
595 fn default() -> Self {
596 Self {
597 period: (9, 258, 1),
598 }
599 }
600}
601
602#[derive(Clone, Debug, Default)]
603pub struct EmaBatchBuilder {
604 range: EmaBatchRange,
605 kernel: Kernel,
606}
607
608impl EmaBatchBuilder {
609 pub fn new() -> Self {
610 Self::default()
611 }
612 pub fn kernel(mut self, k: Kernel) -> Self {
613 self.kernel = k;
614 self
615 }
616
617 #[inline]
618 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
619 self.range.period = (start, end, step);
620 self
621 }
622 #[inline]
623 pub fn period_static(mut self, p: usize) -> Self {
624 self.range.period = (p, p, 0);
625 self
626 }
627
628 pub fn apply_slice(self, data: &[f64]) -> Result<EmaBatchOutput, EmaError> {
629 ema_batch_with_kernel(data, &self.range, self.kernel)
630 }
631
632 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EmaBatchOutput, EmaError> {
633 EmaBatchBuilder::new().kernel(k).apply_slice(data)
634 }
635
636 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EmaBatchOutput, EmaError> {
637 let slice = source_type(c, src);
638 self.apply_slice(slice)
639 }
640
641 pub fn with_default_candles(c: &Candles) -> Result<EmaBatchOutput, EmaError> {
642 EmaBatchBuilder::new()
643 .kernel(Kernel::Auto)
644 .apply_candles(c, "close")
645 }
646}
647
648pub fn ema_batch_with_kernel(
649 data: &[f64],
650 sweep: &EmaBatchRange,
651 k: Kernel,
652) -> Result<EmaBatchOutput, EmaError> {
653 let kernel = match k {
654 Kernel::Auto => detect_best_batch_kernel(),
655 other if other.is_batch() => other,
656 _ => return Err(EmaError::InvalidKernelForBatch(k)),
657 };
658
659 let simd = match kernel {
660 Kernel::Avx512Batch => Kernel::Avx512,
661 Kernel::Avx2Batch => Kernel::Avx2,
662 Kernel::ScalarBatch => Kernel::Scalar,
663 _ => unreachable!(),
664 };
665 ema_batch_par_slice(data, sweep, simd)
666}
667
668#[derive(Clone, Debug)]
669pub struct EmaBatchOutput {
670 pub values: Vec<f64>,
671 pub combos: Vec<EmaParams>,
672 pub rows: usize,
673 pub cols: usize,
674}
675impl EmaBatchOutput {
676 pub fn row_for_params(&self, p: &EmaParams) -> Option<usize> {
677 self.combos
678 .iter()
679 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
680 }
681
682 pub fn values_for(&self, p: &EmaParams) -> Option<&[f64]> {
683 self.row_for_params(p).map(|row| {
684 let start = row * self.cols;
685 &self.values[start..start + self.cols]
686 })
687 }
688}
689
690#[inline(always)]
691fn expand_grid(r: &EmaBatchRange) -> Result<Vec<EmaParams>, EmaError> {
692 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
693 if step == 0 || start == end {
694 return vec![start];
695 }
696 let (lo, hi) = if start <= end {
697 (start, end)
698 } else {
699 (end, start)
700 };
701 (lo..=hi).step_by(step).collect()
702 }
703
704 let periods = axis_usize(r.period);
705 if periods.is_empty() {
706 return Err(EmaError::InvalidRange {
707 start: r.period.0,
708 end: r.period.1,
709 step: r.period.2,
710 });
711 }
712 let mut out = Vec::with_capacity(periods.len());
713 for &p in &periods {
714 out.push(EmaParams { period: Some(p) });
715 }
716 Ok(out)
717}
718
719#[inline(always)]
720pub fn ema_batch_slice(
721 data: &[f64],
722 sweep: &EmaBatchRange,
723 kern: Kernel,
724) -> Result<EmaBatchOutput, EmaError> {
725 ema_batch_inner(data, sweep, kern, false)
726}
727
728#[inline(always)]
729pub fn ema_batch_par_slice(
730 data: &[f64],
731 sweep: &EmaBatchRange,
732 kern: Kernel,
733) -> Result<EmaBatchOutput, EmaError> {
734 ema_batch_inner(data, sweep, kern, true)
735}
736
737#[inline(always)]
738fn ema_batch_inner(
739 data: &[f64],
740 sweep: &EmaBatchRange,
741 kern: Kernel,
742 parallel: bool,
743) -> Result<EmaBatchOutput, EmaError> {
744 let combos = expand_grid(sweep)?;
745 let rows = combos.len();
746 let cols = data.len();
747
748 if cols == 0 {
749 return Err(EmaError::EmptyInputData);
750 }
751
752 let first = data
753 .iter()
754 .position(|x| !x.is_nan())
755 .ok_or(EmaError::AllValuesNaN)?;
756
757 let _total = rows.checked_mul(cols).ok_or(EmaError::ArithmeticOverflow {
758 context: "rows*cols",
759 })?;
760 let mut buf_mu = make_uninit_matrix(rows, cols);
761
762 let warm: Vec<usize> = std::iter::repeat(first).take(rows).collect();
763 init_matrix_prefixes(&mut buf_mu, cols, &warm);
764
765 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
766 let out: &mut [f64] =
767 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
768
769 let returned_combos = ema_batch_inner_into(data, sweep, kern, parallel, out)?;
770
771 let values = unsafe {
772 Vec::from_raw_parts(
773 guard.as_mut_ptr() as *mut f64,
774 guard.len(),
775 guard.capacity(),
776 )
777 };
778
779 Ok(EmaBatchOutput {
780 values,
781 combos: returned_combos,
782 rows,
783 cols,
784 })
785}
786
787#[inline(always)]
788fn ema_batch_inner_into(
789 data: &[f64],
790 sweep: &EmaBatchRange,
791 kern: Kernel,
792 parallel: bool,
793 out: &mut [f64],
794) -> Result<Vec<EmaParams>, EmaError> {
795 let combos = expand_grid(sweep)?;
796
797 if data.is_empty() {
798 return Err(EmaError::EmptyInputData);
799 }
800
801 let first = data
802 .iter()
803 .position(|x| !x.is_nan())
804 .ok_or(EmaError::AllValuesNaN)?;
805 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
806 if data.len() - first < max_p {
807 return Err(EmaError::NotEnoughValidData {
808 needed: max_p,
809 valid: data.len() - first,
810 });
811 }
812
813 let rows = combos.len();
814 let cols = data.len();
815
816 let raw = unsafe {
817 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
818 };
819
820 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
821 let period = combos[row].period.unwrap();
822
823 let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
824
825 match kern {
826 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
827 Kernel::Avx512 => ema_row_avx512(data, first, period, dst),
828 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829 Kernel::Avx2 => ema_row_avx2(data, first, period, dst),
830 _ => ema_row_scalar(data, first, period, dst),
831 }
832 };
833
834 if parallel {
835 #[cfg(not(target_arch = "wasm32"))]
836 {
837 raw.par_chunks_mut(cols)
838 .enumerate()
839 .for_each(|(row, slice)| do_row(row, slice));
840 }
841
842 #[cfg(target_arch = "wasm32")]
843 {
844 for (row, slice) in raw.chunks_mut(cols).enumerate() {
845 do_row(row, slice);
846 }
847 }
848 } else {
849 for (row, slice) in raw.chunks_mut(cols).enumerate() {
850 do_row(row, slice);
851 }
852 }
853
854 Ok(combos)
855}
856
857#[inline(always)]
858unsafe fn ema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
859 let alpha = 2.0 / (period as f64 + 1.0);
860 let beta = 1.0 - alpha;
861
862 let len = data.len();
863
864 let mut mean = unsafe { *data.get_unchecked(first) };
865 unsafe { *out.get_unchecked_mut(first) = mean };
866 let mut valid_count = 1usize;
867
868 let warmup_end = (first + period).min(len);
869 for i in (first + 1)..warmup_end {
870 let x = unsafe { *data.get_unchecked(i) };
871 if is_finite_fast(x) {
872 valid_count += 1;
873 let vc = valid_count as f64;
874 mean = ((vc - 1.0) * mean + x) / vc;
875 }
876
877 unsafe { *out.get_unchecked_mut(i) = mean };
878 }
879
880 if warmup_end < len {
881 let mut prev = mean;
882 for i in warmup_end..len {
883 let x = unsafe { *data.get_unchecked(i) };
884 if is_finite_fast(x) {
885 prev = beta.mul_add(prev, alpha * x);
886 }
887
888 unsafe { *out.get_unchecked_mut(i) = prev };
889 }
890 }
891}
892
/// AVX2 batch-row kernel; currently forwards to `ema_row_scalar`.
///
/// # Safety
/// Same contract as `ema_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    ema_row_scalar(data, first, period, out)
}
898
/// AVX-512 batch-row kernel; currently forwards to `ema_row_scalar`.
///
/// # Safety
/// Same contract as `ema_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    ema_row_scalar(data, first, period, out)
}
904
905#[cfg(test)]
906mod tests {
907 use super::*;
908 use crate::skip_if_unsupported;
909 use crate::utilities::data_loader::read_candles_from_csv;
910 use proptest::prelude::*;
911
912 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
913 #[test]
914 fn test_ema_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
915 let mut data = Vec::with_capacity(256);
916 for _ in 0..5 {
917 data.push(f64::NAN);
918 }
919 for i in 0..251 {
920 let x = (i as f64).sin() * 3.14159 + 100.0 + ((i % 7) as f64) * 0.01;
921 data.push(x);
922 }
923
924 let input = EmaInput::from_slice(&data, EmaParams::default());
925 let baseline = ema(&input)?.values;
926
927 let mut out = vec![0.0; data.len()];
928 ema_into(&input, &mut out)?;
929
930 assert_eq!(baseline.len(), out.len());
931 fn eq_or_both_nan(a: f64, b: f64) -> bool {
932 (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
933 }
934 for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
935 assert!(
936 eq_or_both_nan(a, b),
937 "mismatch at index {}: api={} into={}",
938 i,
939 a,
940 b
941 );
942 }
943 Ok(())
944 }
945
946 fn check_ema_partial_params(
947 test_name: &str,
948 kernel: Kernel,
949 ) -> Result<(), Box<dyn std::error::Error>> {
950 skip_if_unsupported!(kernel, test_name);
951 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
952 let candles = read_candles_from_csv(file_path)?;
953
954 let default_params = EmaParams { period: None };
955 let input = EmaInput::from_candles(&candles, "close", default_params);
956 let output = ema_with_kernel(&input, kernel)?;
957 assert_eq!(output.values.len(), candles.close.len());
958 Ok(())
959 }
960
961 fn check_ema_accuracy(
962 test_name: &str,
963 kernel: Kernel,
964 ) -> Result<(), Box<dyn std::error::Error>> {
965 skip_if_unsupported!(kernel, test_name);
966 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
967 let candles = read_candles_from_csv(file_path)?;
968
969 let input = EmaInput::from_candles(&candles, "close", EmaParams::default());
970 let result = ema_with_kernel(&input, kernel)?;
971 let expected_last_five = [59302.2, 59277.9, 59230.2, 59215.1, 59103.1];
972 let start = result.values.len().saturating_sub(5);
973 for (i, &val) in result.values[start..].iter().enumerate() {
974 let diff = (val - expected_last_five[i]).abs();
975 assert!(
976 diff < 1e-1,
977 "[{}] EMA {:?} mismatch at idx {}: got {}, expected {}",
978 test_name,
979 kernel,
980 i,
981 val,
982 expected_last_five[i]
983 );
984 }
985 Ok(())
986 }
987
988 fn check_ema_default_candles(
989 test_name: &str,
990 kernel: Kernel,
991 ) -> Result<(), Box<dyn std::error::Error>> {
992 skip_if_unsupported!(kernel, test_name);
993 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
994 let candles = read_candles_from_csv(file_path)?;
995
996 let input = EmaInput::with_default_candles(&candles);
997 match input.data {
998 EmaData::Candles { source, .. } => assert_eq!(source, "close"),
999 _ => panic!("Expected EmaData::Candles"),
1000 }
1001 let output = ema_with_kernel(&input, kernel)?;
1002 assert_eq!(output.values.len(), candles.close.len());
1003 Ok(())
1004 }
1005
1006 fn check_ema_zero_period(
1007 test_name: &str,
1008 kernel: Kernel,
1009 ) -> Result<(), Box<dyn std::error::Error>> {
1010 skip_if_unsupported!(kernel, test_name);
1011 let input_data = [10.0, 20.0, 30.0];
1012 let params = EmaParams { period: Some(0) };
1013 let input = EmaInput::from_slice(&input_data, params);
1014 let res = ema_with_kernel(&input, kernel);
1015 assert!(
1016 res.is_err(),
1017 "[{}] EMA should fail with zero period",
1018 test_name
1019 );
1020 Ok(())
1021 }
1022
1023 fn check_ema_period_exceeds_length(
1024 test_name: &str,
1025 kernel: Kernel,
1026 ) -> Result<(), Box<dyn std::error::Error>> {
1027 skip_if_unsupported!(kernel, test_name);
1028 let data_small = [10.0, 20.0, 30.0];
1029 let params = EmaParams { period: Some(10) };
1030 let input = EmaInput::from_slice(&data_small, params);
1031 let res = ema_with_kernel(&input, kernel);
1032 assert!(
1033 res.is_err(),
1034 "[{}] EMA should fail with period exceeding length",
1035 test_name
1036 );
1037 Ok(())
1038 }
1039
1040 fn check_ema_very_small_dataset(
1041 test_name: &str,
1042 kernel: Kernel,
1043 ) -> Result<(), Box<dyn std::error::Error>> {
1044 skip_if_unsupported!(kernel, test_name);
1045 let single_point = [42.0];
1046 let params = EmaParams { period: Some(9) };
1047 let input = EmaInput::from_slice(&single_point, params);
1048 let res = ema_with_kernel(&input, kernel);
1049 assert!(
1050 res.is_err(),
1051 "[{}] EMA should fail with insufficient data",
1052 test_name
1053 );
1054 Ok(())
1055 }
1056
1057 fn check_ema_empty_input(
1058 test_name: &str,
1059 kernel: Kernel,
1060 ) -> Result<(), Box<dyn std::error::Error>> {
1061 skip_if_unsupported!(kernel, test_name);
1062 let empty: [f64; 0] = [];
1063 let input = EmaInput::from_slice(&empty, EmaParams::default());
1064 let res = ema_with_kernel(&input, kernel);
1065 assert!(
1066 matches!(res, Err(EmaError::EmptyInputData)),
1067 "[{}] EMA should fail with empty input",
1068 test_name
1069 );
1070 Ok(())
1071 }
1072
1073 fn check_ema_reinput(
1074 test_name: &str,
1075 kernel: Kernel,
1076 ) -> Result<(), Box<dyn std::error::Error>> {
1077 skip_if_unsupported!(kernel, test_name);
1078 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1079 let candles = read_candles_from_csv(file_path)?;
1080
1081 let first_params = EmaParams { period: Some(9) };
1082 let first_input = EmaInput::from_candles(&candles, "close", first_params);
1083 let first_result = ema_with_kernel(&first_input, kernel)?;
1084
1085 let second_params = EmaParams { period: Some(5) };
1086 let second_input = EmaInput::from_slice(&first_result.values, second_params);
1087 let second_result = ema_with_kernel(&second_input, kernel)?;
1088
1089 assert_eq!(second_result.values.len(), first_result.values.len());
1090 if second_result.values.len() > 240 {
1091 for (i, &val) in second_result.values[240..].iter().enumerate() {
1092 assert!(
1093 !val.is_nan(),
1094 "[{}] Found unexpected NaN at out-index {}",
1095 test_name,
1096 240 + i
1097 );
1098 }
1099 }
1100 Ok(())
1101 }
1102
1103 fn check_ema_nan_handling(
1104 test_name: &str,
1105 kernel: Kernel,
1106 ) -> Result<(), Box<dyn std::error::Error>> {
1107 skip_if_unsupported!(kernel, test_name);
1108 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1109 let candles = read_candles_from_csv(file_path)?;
1110
1111 let input = EmaInput::from_candles(&candles, "close", EmaParams { period: Some(9) });
1112 let res = ema_with_kernel(&input, kernel)?;
1113 assert_eq!(res.values.len(), candles.close.len());
1114 if res.values.len() > 240 {
1115 for (i, &val) in res.values[240..].iter().enumerate() {
1116 assert!(
1117 !val.is_nan(),
1118 "[{}] Found unexpected NaN at out-index {}",
1119 test_name,
1120 240 + i
1121 );
1122 }
1123 }
1124 Ok(())
1125 }
1126
1127 fn check_ema_streaming(
1128 test_name: &str,
1129 kernel: Kernel,
1130 ) -> Result<(), Box<dyn std::error::Error>> {
1131 skip_if_unsupported!(kernel, test_name);
1132
1133 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1134 let candles = read_candles_from_csv(file_path)?;
1135
1136 let period = 9;
1137 let warm_up = 240;
1138
1139 let input = EmaInput::from_candles(
1140 &candles,
1141 "close",
1142 EmaParams {
1143 period: Some(period),
1144 },
1145 );
1146 let batch_output = ema_with_kernel(&input, kernel)?.values;
1147
1148 let mut stream = EmaStream::try_new(EmaParams {
1149 period: Some(period),
1150 })?;
1151 let mut stream_values = Vec::with_capacity(candles.close.len());
1152
1153 for (i, &price) in candles.close.iter().enumerate() {
1154 let stream_val = stream.update(price);
1155
1156 if i < period - 1 {
1157 assert!(
1158 stream_val.is_none(),
1159 "[{}] Stream should return None during warmup at idx {}",
1160 test_name,
1161 i
1162 );
1163 stream_values.push(f64::NAN);
1164 } else {
1165 stream_values.push(stream_val.unwrap_or(f64::NAN));
1166 }
1167 }
1168
1169 assert_eq!(batch_output.len(), stream_values.len());
1170
1171 for (i, (&b, &s)) in batch_output
1172 .iter()
1173 .zip(&stream_values)
1174 .enumerate()
1175 .skip(warm_up)
1176 {
1177 if b.is_nan() && s.is_nan() {
1178 continue;
1179 }
1180 let diff = (b - s).abs();
1181 assert!(
1182 diff < 1e-9,
1183 "[{}] EMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1184 test_name,
1185 i,
1186 b,
1187 s,
1188 diff
1189 );
1190 }
1191 Ok(())
1192 }
1193
1194 #[cfg(debug_assertions)]
1195 fn check_ema_no_poison(
1196 test_name: &str,
1197 kernel: Kernel,
1198 ) -> Result<(), Box<dyn std::error::Error>> {
1199 skip_if_unsupported!(kernel, test_name);
1200
1201 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1202 let candles = read_candles_from_csv(file_path)?;
1203
1204 let test_periods = vec![2, 5, 9, 14, 20, 50, 100, 200];
1205 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1206
1207 for period in &test_periods {
1208 for source in &test_sources {
1209 let input = EmaInput::from_candles(
1210 &candles,
1211 source,
1212 EmaParams {
1213 period: Some(*period),
1214 },
1215 );
1216 let output = ema_with_kernel(&input, kernel)?;
1217
1218 for (i, &val) in output.values.iter().enumerate() {
1219 if val.is_nan() {
1220 continue;
1221 }
1222
1223 let bits = val.to_bits();
1224
1225 if bits == 0x11111111_11111111 {
1226 panic!(
1227 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1228 test_name, val, bits, i, period, source
1229 );
1230 }
1231
1232 if bits == 0x22222222_22222222 {
1233 panic!(
1234 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1235 test_name, val, bits, i, period, source
1236 );
1237 }
1238
1239 if bits == 0x33333333_33333333 {
1240 panic!(
1241 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1242 test_name, val, bits, i, period, source
1243 );
1244 }
1245 }
1246 }
1247 }
1248
1249 Ok(())
1250 }
1251
// Release-build stand-in for the debug-only poison scan of the same name:
// it performs no checks and always returns `Ok(())`, so the generated test
// list compiles identically in both build profiles.
#[cfg(not(debug_assertions))]
fn check_ema_no_poison(
    _test_name: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
1259
1260 fn check_ema_property(
1261 test_name: &str,
1262 kernel: Kernel,
1263 ) -> Result<(), Box<dyn std::error::Error>> {
1264 use proptest::prelude::*;
1265 skip_if_unsupported!(kernel, test_name);
1266
1267 let strat = (1usize..=100).prop_flat_map(|period| {
1268 (
1269 prop::collection::vec(
1270 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1271 period + 10..400,
1272 ),
1273 Just(period),
1274 )
1275 });
1276
1277 proptest::test_runner::TestRunner::default()
1278 .run(&strat, |(data, period)| {
1279 let params = EmaParams {
1280 period: Some(period),
1281 };
1282 let input = EmaInput::from_slice(&data, params);
1283
1284 let EmaOutput { values: out } = ema_with_kernel(&input, kernel).unwrap();
1285
1286 let EmaOutput { values: ref_out } =
1287 ema_with_kernel(&input, Kernel::Scalar).unwrap();
1288
1289 let alpha = 2.0 / (period as f64 + 1.0);
1290 let beta = 1.0 - alpha;
1291
1292 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1293
1294 for i in 0..data.len() {
1295 let y = out[i];
1296 let r = ref_out[i];
1297
1298 if i < first_valid {
1299 prop_assert!(
1300 y.is_nan(),
1301 "[{}] Expected NaN during warmup at idx {}, got {}",
1302 test_name,
1303 i,
1304 y
1305 );
1306 continue;
1307 }
1308
1309 if i >= first_valid {
1310 let window = &data[first_valid..=i];
1311 let lo = window
1312 .iter()
1313 .cloned()
1314 .filter(|x| x.is_finite())
1315 .fold(f64::INFINITY, f64::min);
1316 let hi = window
1317 .iter()
1318 .cloned()
1319 .filter(|x| x.is_finite())
1320 .fold(f64::NEG_INFINITY, f64::max);
1321
1322 if !y.is_nan() && lo.is_finite() && hi.is_finite() {
1323 prop_assert!(
1324 y >= lo - 1e-9 && y <= hi + 1e-9,
1325 "[{}] idx {}: {} not in [{}, {}]",
1326 test_name,
1327 i,
1328 y,
1329 lo,
1330 hi
1331 );
1332 }
1333 }
1334
1335 if period == 1 && i >= first_valid && data[i].is_finite() {
1336 prop_assert!(
1337 (y - data[i]).abs() <= 1e-10,
1338 "[{}] Period=1 mismatch at idx {}: {} vs {}",
1339 test_name,
1340 i,
1341 y,
1342 data[i]
1343 );
1344 }
1345
1346 if i >= first_valid + period {
1347 let window_start = i.saturating_sub(period);
1348 let window = &data[window_start..=i];
1349 if window
1350 .iter()
1351 .all(|&x| (x - data[window_start]).abs() < 1e-10)
1352 {
1353 let expected = data[window_start];
1354 prop_assert!(
1355 (y - expected).abs() <= 1e-6,
1356 "[{}] Constant data convergence failed at idx {}: {} vs {}",
1357 test_name,
1358 i,
1359 y,
1360 expected
1361 );
1362 }
1363 }
1364
1365 if !y.is_finite() || !r.is_finite() {
1366 prop_assert!(
1367 y.to_bits() == r.to_bits(),
1368 "[{}] NaN/infinite mismatch at idx {}: {} vs {}",
1369 test_name,
1370 i,
1371 y,
1372 r
1373 );
1374 } else {
1375 let abs_diff = (y - r).abs();
1376 let rel_diff = if r.abs() > 1e-10 {
1377 abs_diff / r.abs()
1378 } else {
1379 abs_diff
1380 };
1381
1382 prop_assert!(
1383 abs_diff <= 1e-9 || rel_diff <= 1e-9,
1384 "[{}] Kernel mismatch at idx {}: {} vs {} (abs_diff={}, rel_diff={})",
1385 test_name,
1386 i,
1387 y,
1388 r,
1389 abs_diff,
1390 rel_diff
1391 );
1392 }
1393
1394 if i >= first_valid + period
1395 && y.is_finite()
1396 && out[i - 1].is_finite()
1397 && data[i].is_finite()
1398 {
1399 let expected_ema = alpha * data[i] + beta * out[i - 1];
1400 let diff = (y - expected_ema).abs();
1401
1402 prop_assert!(
1403 diff <= 1e-9 * ((i - first_valid) as f64).max(1.0),
1404 "[{}] EMA recursive property failed at idx {}: {} vs {} (diff={})",
1405 test_name,
1406 i,
1407 y,
1408 expected_ema,
1409 diff
1410 );
1411 }
1412
1413 if i >= first_valid + period * 2 {
1414 let historical = &data[first_valid..=i];
1415 let hist_min = historical
1416 .iter()
1417 .filter(|x| x.is_finite())
1418 .fold(f64::INFINITY, |a, &b| a.min(b));
1419 let hist_max = historical
1420 .iter()
1421 .filter(|x| x.is_finite())
1422 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1423
1424 if hist_min.is_finite() && hist_max.is_finite() && y.is_finite() {
1425 prop_assert!(
1426 y >= hist_min - 1e-6 && y <= hist_max + 1e-6,
1427 "[{}] EMA outside historical bounds at idx {}: {} not in [{}, {}]",
1428 test_name,
1429 i,
1430 y,
1431 hist_min,
1432 hist_max
1433 );
1434 }
1435 }
1436 }
1437
1438 if first_valid < data.len() && out[first_valid].is_finite() {
1439 prop_assert!(
1440 (out[first_valid] - data[first_valid]).abs() <= 1e-10,
1441 "[{}] First valid output should equal first valid input: {} vs {}",
1442 test_name,
1443 out[first_valid],
1444 data[first_valid]
1445 );
1446 }
1447
1448 Ok(())
1449 })
1450 .unwrap();
1451
1452 Ok(())
1453 }
1454
// Expands every listed `check_*` function into one `#[test]` for the scalar
// kernel and, when the `nightly-avx` feature is enabled on x86_64, one each
// for AVX2 and AVX-512. The generated test's name is passed to the check
// function as its `test_name` argument.
//
// The generated tests return the check's `Result`, so an `Err` (for example a
// fixture file that fails to load) fails the test instead of being discarded.
macro_rules! generate_all_ema_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    }
}
1479
// Instantiate the per-kernel `#[test]` wrappers for every single-series EMA
// check function.
generate_all_ema_tests!(
    check_ema_partial_params,
    check_ema_accuracy,
    check_ema_default_candles,
    check_ema_zero_period,
    check_ema_period_exceeds_length,
    check_ema_very_small_dataset,
    check_ema_empty_input,
    check_ema_reinput,
    check_ema_nan_handling,
    check_ema_streaming,
    check_ema_property,
    check_ema_no_poison
);
1494
1495 fn check_batch_default_row(
1496 test: &str,
1497 kernel: Kernel,
1498 ) -> Result<(), Box<dyn std::error::Error>> {
1499 skip_if_unsupported!(kernel, test);
1500
1501 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1502 let c = read_candles_from_csv(file)?;
1503
1504 let output = EmaBatchBuilder::new()
1505 .kernel(kernel)
1506 .apply_candles(&c, "close")?;
1507
1508 let def = EmaParams::default();
1509 let row = output.values_for(&def).expect("default row missing");
1510
1511 assert_eq!(row.len(), c.close.len());
1512
1513 let expected = [59302.2, 59277.9, 59230.2, 59215.1, 59103.1];
1514 let start = row.len() - 5;
1515 for (i, &v) in row[start..].iter().enumerate() {
1516 assert!(
1517 (v - expected[i]).abs() < 1e-1,
1518 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1519 );
1520 }
1521 Ok(())
1522 }
1523
1524 #[cfg(debug_assertions)]
1525 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1526 skip_if_unsupported!(kernel, test);
1527
1528 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1529 let c = read_candles_from_csv(file)?;
1530
1531 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1532
1533 for source in &test_sources {
1534 let output = EmaBatchBuilder::new()
1535 .kernel(kernel)
1536 .period_range(2, 200, 3)
1537 .apply_candles(&c, source)?;
1538
1539 for (idx, &val) in output.values.iter().enumerate() {
1540 if val.is_nan() {
1541 continue;
1542 }
1543
1544 let bits = val.to_bits();
1545 let row = idx / output.cols;
1546 let col = idx % output.cols;
1547
1548 if bits == 0x11111111_11111111 {
1549 panic!(
1550 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1551 test, val, bits, row, col, idx, source
1552 );
1553 }
1554
1555 if bits == 0x22222222_22222222 {
1556 panic!(
1557 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1558 test, val, bits, row, col, idx, source
1559 );
1560 }
1561
1562 if bits == 0x33333333_33333333 {
1563 panic!(
1564 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1565 test, val, bits, row, col, idx, source
1566 );
1567 }
1568 }
1569 }
1570
1571 let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1572 for (start, end, step) in edge_case_ranges {
1573 let output = EmaBatchBuilder::new()
1574 .kernel(kernel)
1575 .period_range(start, end, step)
1576 .apply_candles(&c, "close")?;
1577
1578 for (idx, &val) in output.values.iter().enumerate() {
1579 if val.is_nan() {
1580 continue;
1581 }
1582
1583 let bits = val.to_bits();
1584 let row = idx / output.cols;
1585 let col = idx % output.cols;
1586
1587 if bits == 0x11111111_11111111
1588 || bits == 0x22222222_22222222
1589 || bits == 0x33333333_33333333
1590 {
1591 panic!(
1592 "[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1593 test, val, bits, row, col, start, end, step
1594 );
1595 }
1596 }
1597 }
1598
1599 Ok(())
1600 }
1601
// Release-build stand-in for the debug-only batch poison scan of the same
// name: it performs no checks and always returns `Ok(())`.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(
    _test: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
1609
// Expands one batch check function into `#[test]`s for `ScalarBatch`, `Auto`
// and, when the `nightly-avx` feature is enabled on x86_64, `Avx2Batch` and
// `Avx512Batch`. The generated test's name is passed to the check function.
//
// The generated tests return the check's `Result`, so an `Err` fails the test
// instead of being silently dropped.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }
            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
// Instantiate the per-kernel `#[test]` wrappers for the two batch checks.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
1636
// Feeds a short series containing NaN gaps through both the one-shot `ema`
// and `EmaStream`, then checks that from index `period` onwards the two agree
// (within 1e-10, or both NaN), and that the one-shot warm-up values differ
// from the first input.
#[test]
fn test_batch_stream_consistency() -> Result<(), Box<dyn std::error::Error>> {
    let test_data = [
        1.0,
        2.0,
        3.0,
        4.0,
        5.0,
        f64::NAN,
        6.0,
        7.0,
        8.0,
        9.0,
        10.0,
        f64::NAN,
        11.0,
        12.0,
        13.0,
        14.0,
        15.0,
    ];
    let period = 5;
    let params = EmaParams {
        period: Some(period),
    };

    let input = EmaInput::from_slice(&test_data, params.clone());
    let batch_output = ema(&input)?;

    // A `None` from the stream is recorded as NaN so both series line up by index.
    let mut stream = EmaStream::try_new(params)?;
    let stream_output: Vec<f64> = test_data
        .iter()
        .map(|&sample| stream.update(sample).unwrap_or(f64::NAN))
        .collect();

    for i in period..test_data.len() {
        let (batch_val, stream_val) = (batch_output.values[i], stream_output[i]);

        if !(batch_val.is_finite() && stream_val.is_finite()) {
            assert_eq!(
                batch_val.is_nan(),
                stream_val.is_nan(),
                "Batch/Stream NaN mismatch at index {}: batch={}, stream={}",
                i,
                batch_val,
                stream_val
            );
            continue;
        }

        let diff = (batch_val - stream_val).abs();
        assert!(
            diff < 1e-10,
            "Batch/Stream mismatch at index {}: batch={}, stream={}, diff={}",
            i,
            batch_val,
            stream_val,
            diff
        );
    }

    // Index 0 is exempt: only later warm-up slots are compared with the first input.
    for i in 1..period.min(test_data.len()) {
        if test_data[i].is_finite() && batch_output.values[i].is_finite() {
            assert!(
                (batch_output.values[i] - test_data[0]).abs() > 1e-10,
                "Batch should use running mean during warmup, not just first value"
            );
        }
    }

    Ok(())
}
1714}
1715
/// Compute the EMA of a 1-D float64 array.
///
/// `data` may be any 1-D NumPy array; contiguous input is read in place,
/// anything else (strided or reversed views) is first gathered into a
/// temporary buffer in logical element order. `period` is the EMA period and
/// `kernel` an optional kernel name checked by `validate_kernel`.
///
/// Returns a new 1-D float64 array. Raises `ValueError` if the computation
/// reports an error; kernel-name errors are propagated from `validate_kernel`.
#[cfg(feature = "python")]
#[pyfunction(name = "ema")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn ema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;

    let params = EmaParams {
        period: Some(period),
    };

    // Non-contiguous input is copied element by element. Going through
    // `to_owned().as_slice()` is not reliable here: the owned copy may keep
    // the view's memory order (e.g. a reversed view), in which case
    // `as_slice()` yields `None`.
    let gathered: Vec<f64>;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            gathered = data.as_array().iter().copied().collect();
            &gathered
        }
    };

    let ema_in = EmaInput::from_slice(slice_in, params);
    // The computation runs inside `allow_threads`, i.e. with the GIL released.
    let result_vec: Vec<f64> = py
        .allow_threads(|| ema_with_kernel(&ema_in, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1749
/// Compute EMAs for a sweep of periods over one contiguous float64 series.
///
/// `period_range` is forwarded as `EmaBatchRange::period`. Returns a dict
/// with `"values"` (a `rows x cols` float64 array, one row per expanded
/// parameter combination) and `"periods"` (the period of each row as uint64).
///
/// Raises `ValueError` when the grid cannot be expanded, when `rows * cols`
/// overflows, or when the batch computation fails; a non-contiguous `data`
/// array and an invalid kernel name are reported by the respective helpers.
#[cfg(feature = "python")]
#[pyfunction(name = "ema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn ema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;

    let sweep = EmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    // Guard the element count before it sizes an uninitialised allocation.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("overflow rows*cols"))?;

    // SAFETY: the array is freshly allocated and not yet shared with Python,
    // so taking a mutable slice over it cannot alias. Its contents are
    // uninitialised here; the NaN prefix is written below and the rest is
    // expected to be filled by `ema_batch_inner_into` (TODO confirm it writes
    // every element from `first` onwards).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // Pre-fill each row's leading cells (up to the first non-NaN input) with NaN.
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    for r in 0..rows {
        let row_start = r * cols;
        slice_out[row_start..row_start + first].fill(f64::NAN);
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Translate the batch kernel into its per-row counterpart.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            ema_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1814
/// Run the EMA period sweep on a CUDA device and return the result as a
/// device-resident float32 array.
///
/// `data_f32` must be contiguous. Raises `ValueError` when CUDA is not
/// available or when the CUDA backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range=(9, 9, 0), device_id=0))]
pub fn ema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<EmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = EmaBatchRange {
        period: period_range,
    };

    // Device work, executed inside `allow_threads`.
    let launch = || -> PyResult<_> {
        let engine = CudaEma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let handle = engine
            .ema_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((
            handle.buf,
            handle.rows,
            handle.cols,
            engine.context_arc(),
            engine.device_id(),
        ))
    };
    let (buf, rows, cols, _ctx, device_id) = py.allow_threads(launch)?;

    // The context handle is stored next to the buffer in the returned object.
    Ok(EmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx,
        device_id,
    })
}
1851
/// Run one EMA period over many series on a CUDA device.
///
/// `data_tm_f32` is a contiguous 2-D float32 array whose first axis is taken
/// as the series length and whose second axis as the number of series
/// (time-major layout, per the backend method name). Returns a device-resident
/// float32 array.
///
/// Raises `ValueError` when CUDA is not available, when `period` is zero, or
/// when the CUDA backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn ema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<EmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let series_len = shape[0];
    let num_series = shape[1];
    let params = EmaParams {
        period: Some(period),
    };

    let (buf, rows, cols, ctx, dev) = py.allow_threads(|| {
        let cuda = CudaEma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let handle = cuda
            .ema_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        Ok::<_, PyErr>((handle.buf, handle.rows, handle.cols, ctx, dev))
    })?;

    // The context handle is stored next to the buffer in the returned object.
    Ok(EmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev,
    })
}
1894
/// Python wrapper around the incremental `EmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "EmaStream")]
pub struct EmaStreamPy {
    inner: EmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl EmaStreamPy {
    /// Create a stream for the given period; raises `ValueError` if
    /// `EmaStream::try_new` rejects the parameters.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        EmaStream::try_new(EmaParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feed one value and return the stream's output, or NaN when the stream
    /// yields no value for this update.
    pub fn update(&mut self, value: f64) -> f64 {
        match self.inner.update(value) {
            Some(ema) => ema,
            None => f64::NAN,
        }
    }
}
1917
1918#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1919use serde::{Deserialize, Serialize};
1920#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1921use wasm_bindgen::prelude::*;
1922
/// Compute the EMA of `data` with the given `period` and return it as a new
/// vector of the same length. Errors are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = EmaInput::from_slice(
        data,
        EmaParams {
            period: Some(period),
        },
    );

    // Zero-filled destination that `ema_into_slice` writes into.
    let mut output = vec![0.0; data.len()];

    match ema_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1938
/// Configuration object accepted by the JS-facing `ema_batch` entry point;
/// it is deserialized from the `config` argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmaBatchConfig {
    /// Period sweep, forwarded unchanged as `EmaBatchRange::period`
    /// (the same slot `ema_batch_metadata_js` fills with `(start, end, step)`).
    pub period_range: (usize, usize, usize),
}
1944
/// Serializable result returned to JS by the `ema_batch` entry point; each
/// field is copied from the batch output of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmaBatchJsOutput {
    /// Flat result buffer (presumably row-major, `rows * cols` long — TODO confirm).
    pub values: Vec<f64>,
    /// Parameter combinations produced by the sweep.
    pub combos: Vec<EmaParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1953
/// JS entry point `ema_batch`: deserializes `config` into an
/// [`EmaBatchConfig`], runs the batch EMA over `data`, and returns the result
/// serialized as an [`EmaBatchJsOutput`]. All failures become string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ema_batch)]
pub fn ema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let EmaBatchConfig { period_range } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = EmaBatchRange {
        period: period_range,
    };

    let batch = match ema_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = EmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1977
/// Expand the `(period_start, period_end, period_step)` sweep and return the
/// period of every resulting combination as `f64`, in grid order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = EmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let periods = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(periods)
}
1998
/// Allocate room for `len` `f64` values and hand the raw pointer to the
/// caller. The buffer is deliberately not dropped here; it is expected to be
/// released later with `ema_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2007
/// Release a buffer previously obtained from `ema_alloc(len)`.
///
/// A null `ptr` is ignored. `ptr` and `len` must be exactly the values
/// returned by / passed to `ema_alloc`; anything else is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `ema_alloc(len)`, i.e. a
    // `Vec<f64>` allocation with capacity `len`. The vector is rebuilt with
    // length 0 because the buffer may never have been fully initialised, and
    // `Vec::from_raw_parts` requires the first `length` elements to be valid;
    // `f64` has no destructor, so only the allocation itself needs freeing.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2017
/// Compute the EMA of `len` values at `in_ptr` and write `len` results to
/// `out_ptr`.
///
/// The two regions may be identical or overlap in any way; in that case the
/// result is staged in a temporary buffer before being copied out.
///
/// Returns an error for null pointers, for `period == 0` or `period > len`,
/// and for any error reported by the computation.
///
/// # Safety contract (caller side)
/// Both pointers must be valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ema_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = EmaParams {
        period: Some(period),
    };

    // Treat any intersection of the two byte ranges as aliasing, not just
    // identical start addresses: overlapping `&[f64]` / `&mut [f64]` views
    // must never coexist.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps =
        in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span);

    if overlaps {
        let mut temp = vec![0.0; len];
        {
            // SAFETY: caller guarantees `in_ptr` is valid for `len` reads; the
            // shared view ends with this scope, before the output view exists.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = EmaInput::from_slice(data, params);
            ema_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: caller guarantees `out_ptr` is valid for `len` writes and no
        // other reference into that region is live any more.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: caller guarantees validity of both regions for `len`
        // elements, and the overlap test above shows they are disjoint.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = EmaInput::from_slice(data, params);
        ema_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2058
/// Compute EMAs for the `(period_start, period_end, period_step)` sweep over
/// `len` values at `in_ptr`, writing `rows * len` results to `out_ptr`.
/// Returns the number of rows (parameter combinations).
///
/// If the input region overlaps the output region, the input is copied first
/// so that writing results cannot corrupt it.
///
/// Returns an error for null pointers, an invalid grid, `rows * len`
/// overflow, all-NaN input, or any error reported by the computation.
///
/// # Safety contract (caller side)
/// `in_ptr` must be valid for `len` values and `out_ptr` for `rows * len`
/// values, where `rows` is the number of combinations the sweep expands to.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ema_batch_into"));
    }

    let sweep = EmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let elems = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("overflow rows*cols"))?;

    // Detect any intersection between the input and output byte ranges.
    let elem_size = std::mem::size_of::<f64>();
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let in_end = in_addr.saturating_add(len.saturating_mul(elem_size));
    let out_end = out_addr.saturating_add(elems.saturating_mul(elem_size));
    let overlaps = in_addr < out_end && out_addr < in_end;

    // When the regions alias, work from a private copy of the input so the
    // shared and mutable views never cover the same memory.
    let staged: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: caller guarantees `in_ptr` is valid for `len` reads; the
        // temporary view is dropped as soon as the copy is made.
        staged = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &staged
    } else {
        // SAFETY: caller guarantees `in_ptr` is valid for `len` reads, and the
        // region is disjoint from the output region.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: caller guarantees `out_ptr` is valid for `elems` writes; no
    // other live reference covers that region (see the overlap handling).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, elems) };

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;
    // Pre-fill each row's leading cells (up to the first non-NaN input) with NaN.
    for r in 0..rows {
        let s = r * cols;
        out[s..s + first].fill(f64::NAN);
    }

    ema_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}