Skip to main content

vector_ta/indicators/moving_averages/
ema.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(not(target_arch = "wasm32"))]
8use rayon::prelude::*;
9use std::convert::AsRef;
10use std::mem::MaybeUninit;
11use thiserror::Error;
12
13#[cfg(feature = "python")]
14use crate::utilities::kernel_validation::validate_kernel;
15#[cfg(feature = "python")]
16use numpy;
17#[cfg(feature = "python")]
18use numpy::PyUntypedArrayMethods;
19#[cfg(feature = "python")]
20use pyo3::exceptions::PyValueError;
21#[cfg(feature = "python")]
22use pyo3::prelude::*;
23
24#[cfg(all(feature = "python", feature = "cuda"))]
25use crate::cuda::cuda_available;
26#[cfg(all(feature = "python", feature = "cuda"))]
27use crate::cuda::moving_averages::CudaEma;
28#[cfg(all(feature = "python", feature = "cuda"))]
29use cust::context::Context;
30#[cfg(all(feature = "python", feature = "cuda"))]
31use cust::memory::DeviceBuffer;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34
/// Python-facing handle to a row-major `rows x cols` matrix of `f32` values held in a
/// CUDA device buffer.
///
/// The buffer can be described through `__cuda_array_interface__` or handed off once
/// through `__dlpack__`; after that hand-off `buf` is `None`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct EmaDeviceArrayF32Py {
    /// Device allocation; `None` once `__dlpack__` has taken it.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of rows reported in the exported shape.
    pub(crate) rows: usize,
    /// Number of columns reported in the exported shape (also the row stride in elements).
    pub(crate) cols: usize,
    /// Shared CUDA context handle. NOTE(review): presumably held so the context outlives
    /// `buf` -- confirm against the code that constructs this type.
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EmaDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dictionary (version 3) describing the buffer
    /// as a row-major little-endian `f32` matrix.
    ///
    /// Fails with `ValueError` when the buffer has already been moved out by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let elem_size = std::mem::size_of::<f32>();
        let iface = pyo3::types::PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem_size, elem_size))?;

        let device_buf = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let address = device_buf.as_device_ptr().as_raw() as usize;
        // The second tuple element is the interface's read-only flag.
        iface.set_item("data", (address, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// Returns `(device_type, device_id)`; the device type code is `2` (DLPack's CUDA device).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the allocation.
    ///
    /// May only succeed once per object. A `dl_device` that decodes to a different
    /// `(type, id)` pair than this buffer's device is rejected; a `dl_device` that does not
    /// decode as a pair of integers is ignored. `stream` is accepted but unused.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (own_type, own_ordinal) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (own_type, own_ordinal) {
                // Only an explicit, well-formed `copy=True` counts as a copy request.
                let copy_requested = matches!(
                    copy.as_ref().map(|flag| flag.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let message = if copy_requested {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        let device_buf = match self.buf.take() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            device_buf,
            self.rows,
            self.cols,
            own_ordinal,
            max_version.map(|obj| obj.into_bound(py)),
        )
    }
}
126impl<'a> AsRef<[f64]> for EmaInput<'a> {
127    #[inline(always)]
128    fn as_ref(&self) -> &[f64] {
129        match &self.data {
130            EmaData::Slice(slice) => slice,
131            EmaData::Candles { candles, source } => source_type(candles, source),
132        }
133    }
134}
135
/// Where the EMA input series comes from.
#[derive(Debug, Clone)]
pub enum EmaData<'a> {
    /// A candle set plus the name of the field to read; resolved via `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
144
/// Result of a single-series EMA computation.
#[derive(Debug, Clone)]
pub struct EmaOutput {
    /// One value per input sample; positions before the first non-NaN input are NaN.
    pub values: Vec<f64>,
}
149
/// Tunable parameters of the EMA.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct EmaParams {
    /// Smoothing window length; `None` means "use the default of 9".
    pub period: Option<usize>,
}

impl Default for EmaParams {
    /// Default parameters: an explicit period of 9.
    fn default() -> Self {
        let period = Some(9);
        EmaParams { period }
    }
}
164
/// Bundles the input series with the parameters for one EMA computation.
#[derive(Debug, Clone)]
pub struct EmaInput<'a> {
    /// Source of the series (raw slice or candle field).
    pub data: EmaData<'a>,
    /// EMA parameters; see [`EmaInput::get_period`] for how the period is resolved.
    pub params: EmaParams,
}
170
171impl<'a> EmaInput<'a> {
172    #[inline]
173    pub fn from_candles(c: &'a Candles, s: &'a str, p: EmaParams) -> Self {
174        Self {
175            data: EmaData::Candles {
176                candles: c,
177                source: s,
178            },
179            params: p,
180        }
181    }
182    #[inline]
183    pub fn from_slice(sl: &'a [f64], p: EmaParams) -> Self {
184        Self {
185            data: EmaData::Slice(sl),
186            params: p,
187        }
188    }
189    #[inline]
190    pub fn with_default_candles(c: &'a Candles) -> Self {
191        Self::from_candles(c, "close", EmaParams::default())
192    }
193    #[inline]
194    pub fn get_period(&self) -> usize {
195        self.params.period.unwrap_or(9)
196    }
197}
198
199#[derive(Copy, Clone, Debug)]
200pub struct EmaBuilder {
201    period: Option<usize>,
202    kernel: Kernel,
203}
204
205impl Default for EmaBuilder {
206    fn default() -> Self {
207        Self {
208            period: None,
209            kernel: Kernel::Auto,
210        }
211    }
212}
213
214impl EmaBuilder {
215    #[inline(always)]
216    pub fn new() -> Self {
217        Self::default()
218    }
219    #[inline(always)]
220    pub fn period(mut self, n: usize) -> Self {
221        self.period = Some(n);
222        self
223    }
224    #[inline(always)]
225    pub fn kernel(mut self, k: Kernel) -> Self {
226        self.kernel = k;
227        self
228    }
229
230    #[inline(always)]
231    pub fn apply(self, c: &Candles) -> Result<EmaOutput, EmaError> {
232        let p = EmaParams {
233            period: self.period,
234        };
235        let i = EmaInput::from_candles(c, "close", p);
236        ema_with_kernel(&i, self.kernel)
237    }
238
239    #[inline(always)]
240    pub fn apply_slice(self, d: &[f64]) -> Result<EmaOutput, EmaError> {
241        let p = EmaParams {
242            period: self.period,
243        };
244        let i = EmaInput::from_slice(d, p);
245        ema_with_kernel(&i, self.kernel)
246    }
247
248    #[inline(always)]
249    pub fn into_stream(self) -> Result<EmaStream, EmaError> {
250        let p = EmaParams {
251            period: self.period,
252        };
253        EmaStream::try_new(p)
254    }
255}
256
/// Errors reported by the EMA single-series, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum EmaError {
    /// The input series has length zero.
    #[error("ema: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN sample exists in the input.
    #[error("ema: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("ema: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` samples remain from the first non-NaN sample onward.
    #[error("ema: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer does not match the input length.
    #[error("ema: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period range could not be expanded into any period.
    #[error("ema: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch API.
    #[error("ema: Invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A size computation (named by `context`) overflowed `usize`.
    #[error("ema: arithmetic overflow while computing {context}")]
    ArithmeticOverflow { context: &'static str },
}
280
/// Computes the EMA of `input` with automatic kernel selection.
///
/// Equivalent to `ema_with_kernel(input, Kernel::Auto)`; see that function for the errors
/// that can be returned.
#[inline]
pub fn ema(input: &EmaInput) -> Result<EmaOutput, EmaError> {
    ema_with_kernel(input, Kernel::Auto)
}
285
286#[inline(always)]
287fn ema_prepare<'a>(
288    input: &'a EmaInput,
289    kernel: Kernel,
290) -> Result<(&'a [f64], usize, usize, f64, f64, Kernel), EmaError> {
291    let data: &[f64] = input.as_ref();
292
293    let len = data.len();
294    if len == 0 {
295        return Err(EmaError::EmptyInputData);
296    }
297
298    let first = data
299        .iter()
300        .position(|x| !x.is_nan())
301        .ok_or(EmaError::AllValuesNaN)?;
302    let period = input.get_period();
303    if period == 0 || period > len {
304        return Err(EmaError::InvalidPeriod {
305            period,
306            data_len: len,
307        });
308    }
309    if len - first < period {
310        return Err(EmaError::NotEnoughValidData {
311            needed: period,
312            valid: len - first,
313        });
314    }
315
316    let alpha = 2.0 / (period as f64 + 1.0);
317    let beta = 1.0 - alpha;
318    let chosen = if matches!(kernel, Kernel::Auto) {
319        Kernel::Scalar
320    } else {
321        kernel
322    };
323    Ok((data, period, first, alpha, beta, chosen))
324}
325
326#[inline(always)]
327fn ema_compute_into(
328    data: &[f64],
329    period: usize,
330    first: usize,
331    alpha: f64,
332    beta: f64,
333    kernel: Kernel,
334    out: &mut [f64],
335) {
336    unsafe {
337        match kernel {
338            Kernel::Scalar | Kernel::ScalarBatch => {
339                ema_scalar_into(data, period, first, alpha, beta, out)
340            }
341            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
342            Kernel::Avx2 | Kernel::Avx2Batch => {
343                ema_avx2_into(data, period, first, alpha, beta, out)
344            }
345            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
346            Kernel::Avx512 | Kernel::Avx512Batch => {
347                ema_avx512_into(data, period, first, alpha, beta, out)
348            }
349
350            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
351            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
352                ema_scalar_into(data, period, first, alpha, beta, out)
353            }
354            _ => unreachable!(),
355        }
356    }
357}
358
359pub fn ema_with_kernel(input: &EmaInput, kernel: Kernel) -> Result<EmaOutput, EmaError> {
360    let (data, period, first, alpha, beta, chosen) = ema_prepare(input, kernel)?;
361
362    let mut out = alloc_with_nan_prefix(data.len(), first);
363    ema_compute_into(data, period, first, alpha, beta, chosen, &mut out);
364
365    Ok(EmaOutput { values: out })
366}
367
368#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
369pub fn ema_into(input: &EmaInput, out: &mut [f64]) -> Result<(), EmaError> {
370    let (data, period, first, alpha, beta, chosen) = ema_prepare(input, Kernel::Auto)?;
371
372    if out.len() != data.len() {
373        return Err(EmaError::OutputLengthMismatch {
374            expected: data.len(),
375            got: out.len(),
376        });
377    }
378
379    let warm = first.min(out.len());
380    for i in 0..warm {
381        out[i] = f64::from_bits(0x7ff8_0000_0000_0000);
382    }
383
384    ema_compute_into(data, period, first, alpha, beta, chosen, out);
385
386    Ok(())
387}
388
/// Returns `true` when `x` is neither NaN nor infinite.
///
/// A value is non-finite exactly when all eleven exponent bits are set, so only the
/// exponent field needs to be inspected.
#[inline(always)]
fn is_finite_fast(x: f64) -> bool {
    const EXPONENT_ALL_ONES: u64 = 0x7ff;
    let exponent = (x.to_bits() >> 52) & EXPONENT_ALL_ONES;
    exponent != EXPONENT_ALL_ONES
}
394
395#[inline]
396pub fn ema_into_slice(dst: &mut [f64], input: &EmaInput, kern: Kernel) -> Result<(), EmaError> {
397    let (data, period, first, alpha, beta, chosen) = ema_prepare(input, kern)?;
398
399    if dst.len() != data.len() {
400        return Err(EmaError::OutputLengthMismatch {
401            expected: data.len(),
402            got: dst.len(),
403        });
404    }
405
406    ema_compute_into(data, period, first, alpha, beta, chosen, dst);
407
408    for v in &mut dst[..first] {
409        *v = f64::NAN;
410    }
411
412    Ok(())
413}
414
415#[inline(always)]
416pub unsafe fn ema_scalar(
417    data: &[f64],
418    period: usize,
419    first_val: usize,
420    out: &mut Vec<f64>,
421) -> Result<EmaOutput, EmaError> {
422    let alpha = 2.0 / (period as f64 + 1.0);
423    let beta = 1.0 - alpha;
424    ema_scalar_into(data, period, first_val, alpha, beta, out);
425    let values = std::mem::take(out);
426    Ok(EmaOutput { values })
427}
428
429#[inline(always)]
430unsafe fn ema_scalar_into(
431    data: &[f64],
432    period: usize,
433    first_val: usize,
434    alpha: f64,
435    beta: f64,
436    out: &mut [f64],
437) {
438    let len = data.len();
439    debug_assert_eq!(out.len(), len);
440
441    let mut mean = *data.get_unchecked(first_val);
442    *out.get_unchecked_mut(first_val) = mean;
443    let mut valid_count = 1usize;
444
445    let warmup_end = (first_val + period).min(len);
446    for i in (first_val + 1)..warmup_end {
447        let x = *data.get_unchecked(i);
448        if is_finite_fast(x) {
449            valid_count += 1;
450            let vc = valid_count as f64;
451            mean = ((vc - 1.0) * mean + x) / vc;
452        }
453
454        *out.get_unchecked_mut(i) = mean;
455    }
456
457    if warmup_end < len {
458        let mut prev = mean;
459        for i in warmup_end..len {
460            let x = *data.get_unchecked(i);
461            if is_finite_fast(x) {
462                prev = beta.mul_add(prev, alpha * x);
463            }
464
465            *out.get_unchecked_mut(i) = prev;
466        }
467    }
468}
469
/// AVX2 entry point; currently delegates unchanged to [`ema_scalar`].
///
/// # Safety
/// Same requirements as [`ema_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ema_avx2(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut Vec<f64>,
) -> Result<EmaOutput, EmaError> {
    ema_scalar(data, period, first_val, out)
}
480
/// AVX2 in-place kernel; currently delegates unchanged to `ema_scalar_into`.
///
/// # Safety
/// Same requirements as `ema_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_avx2_into(
    data: &[f64],
    period: usize,
    first_val: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    ema_scalar_into(data, period, first_val, alpha, beta, out)
}
493
/// AVX-512 entry point; currently delegates unchanged to [`ema_scalar`].
///
/// # Safety
/// Same requirements as [`ema_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ema_avx512(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut Vec<f64>,
) -> Result<EmaOutput, EmaError> {
    ema_scalar(data, period, first_val, out)
}
504
/// AVX-512 in-place kernel; currently delegates unchanged to `ema_scalar_into`.
///
/// # Safety
/// Same requirements as `ema_scalar_into`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_avx512_into(
    data: &[f64],
    period: usize,
    first_val: usize,
    alpha: f64,
    beta: f64,
    out: &mut [f64],
) {
    ema_scalar_into(data, period, first_val, alpha, beta, out)
}
517
/// Incremental EMA that consumes one sample at a time.
#[derive(Debug, Clone)]
pub struct EmaStream {
    /// Smoothing window length.
    period: usize,
    /// Weight of the newest sample, `2 / (period + 1)`.
    alpha: f64,
    /// Weight of the previous EMA value, `1 - alpha`.
    beta: f64,
    /// Number of finite samples consumed so far.
    count: usize,
    /// Current running mean / EMA value; NaN until the first finite sample arrives.
    mean: f64,
    /// Set once `period` finite samples have been consumed.
    filled: bool,

    /// Precomputed reciprocals: `inv[k] == 1 / (k + 1)` for `k` in `0..period`.
    inv: Box<[f64]>,
}
529
530impl EmaStream {
531    #[inline]
532    pub fn try_new(params: EmaParams) -> Result<Self, EmaError> {
533        let period = params.period.unwrap_or(9);
534        if period == 0 {
535            return Err(EmaError::InvalidPeriod {
536                period,
537                data_len: 0,
538            });
539        }
540
541        let alpha = 2.0 / (period as f64 + 1.0);
542        let beta = 1.0 - alpha;
543
544        let mut inv = Vec::with_capacity(period);
545        for n in 1..=period {
546            inv.push(1.0 / n as f64);
547        }
548
549        Ok(Self {
550            period,
551            alpha,
552            beta,
553            count: 0,
554            mean: f64::NAN,
555            filled: false,
556            inv: inv.into_boxed_slice(),
557        })
558    }
559
560    #[inline(always)]
561    pub fn update(&mut self, x: f64) -> Option<f64> {
562        if !is_finite_fast(x) {
563            return if self.filled { Some(self.mean) } else { None };
564        }
565
566        self.count += 1;
567        let c = self.count;
568
569        if c == 1 {
570            self.mean = x;
571        } else if c <= self.period {
572            let inv = self.inv[c - 1];
573            self.mean = (x - self.mean).mul_add(inv, self.mean);
574        } else {
575            self.mean = self.beta.mul_add(self.mean, self.alpha * x);
576        }
577
578        if !self.filled && c >= self.period {
579            self.filled = true;
580        }
581        if self.filled {
582            Some(self.mean)
583        } else {
584            None
585        }
586    }
587}
588
/// Parameter sweep for the batch API.
#[derive(Clone, Debug)]
pub struct EmaBatchRange {
    /// Period sweep as `(start, end, step)`.
    pub period: (usize, usize, usize),
}

impl Default for EmaBatchRange {
    /// Default sweep: periods 9 through 258 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (9, 258, 1);
        EmaBatchRange {
            period: (start, end, step),
        }
    }
}
601
/// Fluent builder for batch EMA runs over a range of periods.
#[derive(Clone, Debug, Default)]
pub struct EmaBatchBuilder {
    /// Period sweep to evaluate.
    range: EmaBatchRange,
    /// Kernel passed to `ema_batch_with_kernel`.
    kernel: Kernel,
}
607
608impl EmaBatchBuilder {
609    pub fn new() -> Self {
610        Self::default()
611    }
612    pub fn kernel(mut self, k: Kernel) -> Self {
613        self.kernel = k;
614        self
615    }
616
617    #[inline]
618    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
619        self.range.period = (start, end, step);
620        self
621    }
622    #[inline]
623    pub fn period_static(mut self, p: usize) -> Self {
624        self.range.period = (p, p, 0);
625        self
626    }
627
628    pub fn apply_slice(self, data: &[f64]) -> Result<EmaBatchOutput, EmaError> {
629        ema_batch_with_kernel(data, &self.range, self.kernel)
630    }
631
632    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EmaBatchOutput, EmaError> {
633        EmaBatchBuilder::new().kernel(k).apply_slice(data)
634    }
635
636    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EmaBatchOutput, EmaError> {
637        let slice = source_type(c, src);
638        self.apply_slice(slice)
639    }
640
641    pub fn with_default_candles(c: &Candles) -> Result<EmaBatchOutput, EmaError> {
642        EmaBatchBuilder::new()
643            .kernel(Kernel::Auto)
644            .apply_candles(c, "close")
645    }
646}
647
648pub fn ema_batch_with_kernel(
649    data: &[f64],
650    sweep: &EmaBatchRange,
651    k: Kernel,
652) -> Result<EmaBatchOutput, EmaError> {
653    let kernel = match k {
654        Kernel::Auto => detect_best_batch_kernel(),
655        other if other.is_batch() => other,
656        _ => return Err(EmaError::InvalidKernelForBatch(k)),
657    };
658
659    let simd = match kernel {
660        Kernel::Avx512Batch => Kernel::Avx512,
661        Kernel::Avx2Batch => Kernel::Avx2,
662        Kernel::ScalarBatch => Kernel::Scalar,
663        _ => unreachable!(),
664    };
665    ema_batch_par_slice(data, sweep, simd)
666}
667
/// Result of a batch EMA run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct EmaBatchOutput {
    /// Row-major `rows * cols` matrix; row `i` belongs to `combos[i]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row.
    pub combos: Vec<EmaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
675impl EmaBatchOutput {
676    pub fn row_for_params(&self, p: &EmaParams) -> Option<usize> {
677        self.combos
678            .iter()
679            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
680    }
681
682    pub fn values_for(&self, p: &EmaParams) -> Option<&[f64]> {
683        self.row_for_params(p).map(|row| {
684            let start = row * self.cols;
685            &self.values[start..start + self.cols]
686        })
687    }
688}
689
690#[inline(always)]
691fn expand_grid(r: &EmaBatchRange) -> Result<Vec<EmaParams>, EmaError> {
692    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
693        if step == 0 || start == end {
694            return vec![start];
695        }
696        let (lo, hi) = if start <= end {
697            (start, end)
698        } else {
699            (end, start)
700        };
701        (lo..=hi).step_by(step).collect()
702    }
703
704    let periods = axis_usize(r.period);
705    if periods.is_empty() {
706        return Err(EmaError::InvalidRange {
707            start: r.period.0,
708            end: r.period.1,
709            step: r.period.2,
710        });
711    }
712    let mut out = Vec::with_capacity(periods.len());
713    for &p in &periods {
714        out.push(EmaParams { period: Some(p) });
715    }
716    Ok(out)
717}
718
/// Batch EMA over `data` for every period in `sweep`, processing rows sequentially.
///
/// `kern` is the per-row kernel. Errors are those of `ema_batch_inner`.
#[inline(always)]
pub fn ema_batch_slice(
    data: &[f64],
    sweep: &EmaBatchRange,
    kern: Kernel,
) -> Result<EmaBatchOutput, EmaError> {
    ema_batch_inner(data, sweep, kern, false)
}
727
/// Batch EMA over `data` for every period in `sweep`, with the parallel flag set.
///
/// `kern` is the per-row kernel. Errors are those of `ema_batch_inner`.
#[inline(always)]
pub fn ema_batch_par_slice(
    data: &[f64],
    sweep: &EmaBatchRange,
    kern: Kernel,
) -> Result<EmaBatchOutput, EmaError> {
    ema_batch_inner(data, sweep, kern, true)
}
736
737#[inline(always)]
738fn ema_batch_inner(
739    data: &[f64],
740    sweep: &EmaBatchRange,
741    kern: Kernel,
742    parallel: bool,
743) -> Result<EmaBatchOutput, EmaError> {
744    let combos = expand_grid(sweep)?;
745    let rows = combos.len();
746    let cols = data.len();
747
748    if cols == 0 {
749        return Err(EmaError::EmptyInputData);
750    }
751
752    let first = data
753        .iter()
754        .position(|x| !x.is_nan())
755        .ok_or(EmaError::AllValuesNaN)?;
756
757    let _total = rows.checked_mul(cols).ok_or(EmaError::ArithmeticOverflow {
758        context: "rows*cols",
759    })?;
760    let mut buf_mu = make_uninit_matrix(rows, cols);
761
762    let warm: Vec<usize> = std::iter::repeat(first).take(rows).collect();
763    init_matrix_prefixes(&mut buf_mu, cols, &warm);
764
765    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
766    let out: &mut [f64] =
767        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
768
769    let returned_combos = ema_batch_inner_into(data, sweep, kern, parallel, out)?;
770
771    let values = unsafe {
772        Vec::from_raw_parts(
773            guard.as_mut_ptr() as *mut f64,
774            guard.len(),
775            guard.capacity(),
776        )
777    };
778
779    Ok(EmaBatchOutput {
780        values,
781        combos: returned_combos,
782        rows,
783        cols,
784    })
785}
786
787#[inline(always)]
788fn ema_batch_inner_into(
789    data: &[f64],
790    sweep: &EmaBatchRange,
791    kern: Kernel,
792    parallel: bool,
793    out: &mut [f64],
794) -> Result<Vec<EmaParams>, EmaError> {
795    let combos = expand_grid(sweep)?;
796
797    if data.is_empty() {
798        return Err(EmaError::EmptyInputData);
799    }
800
801    let first = data
802        .iter()
803        .position(|x| !x.is_nan())
804        .ok_or(EmaError::AllValuesNaN)?;
805    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
806    if data.len() - first < max_p {
807        return Err(EmaError::NotEnoughValidData {
808            needed: max_p,
809            valid: data.len() - first,
810        });
811    }
812
813    let rows = combos.len();
814    let cols = data.len();
815
816    let raw = unsafe {
817        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
818    };
819
820    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
821        let period = combos[row].period.unwrap();
822
823        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
824
825        match kern {
826            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
827            Kernel::Avx512 => ema_row_avx512(data, first, period, dst),
828            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829            Kernel::Avx2 => ema_row_avx2(data, first, period, dst),
830            _ => ema_row_scalar(data, first, period, dst),
831        }
832    };
833
834    if parallel {
835        #[cfg(not(target_arch = "wasm32"))]
836        {
837            raw.par_chunks_mut(cols)
838                .enumerate()
839                .for_each(|(row, slice)| do_row(row, slice));
840        }
841
842        #[cfg(target_arch = "wasm32")]
843        {
844            for (row, slice) in raw.chunks_mut(cols).enumerate() {
845                do_row(row, slice);
846            }
847        }
848    } else {
849        for (row, slice) in raw.chunks_mut(cols).enumerate() {
850            do_row(row, slice);
851        }
852    }
853
854    Ok(combos)
855}
856
857#[inline(always)]
858unsafe fn ema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
859    let alpha = 2.0 / (period as f64 + 1.0);
860    let beta = 1.0 - alpha;
861
862    let len = data.len();
863
864    let mut mean = unsafe { *data.get_unchecked(first) };
865    unsafe { *out.get_unchecked_mut(first) = mean };
866    let mut valid_count = 1usize;
867
868    let warmup_end = (first + period).min(len);
869    for i in (first + 1)..warmup_end {
870        let x = unsafe { *data.get_unchecked(i) };
871        if is_finite_fast(x) {
872            valid_count += 1;
873            let vc = valid_count as f64;
874            mean = ((vc - 1.0) * mean + x) / vc;
875        }
876
877        unsafe { *out.get_unchecked_mut(i) = mean };
878    }
879
880    if warmup_end < len {
881        let mut prev = mean;
882        for i in warmup_end..len {
883            let x = unsafe { *data.get_unchecked(i) };
884            if is_finite_fast(x) {
885                prev = beta.mul_add(prev, alpha * x);
886            }
887
888            unsafe { *out.get_unchecked_mut(i) = prev };
889        }
890    }
891}
892
/// AVX2 batch-row kernel; currently delegates unchanged to `ema_row_scalar`.
///
/// # Safety
/// Same requirements as `ema_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    ema_row_scalar(data, first, period, out);
}
898
/// AVX-512 batch-row kernel; currently delegates unchanged to `ema_row_scalar`.
///
/// # Safety
/// Same requirements as `ema_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn ema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    ema_row_scalar(data, first, period, out);
}
904
905#[cfg(test)]
906mod tests {
907    use super::*;
908    use crate::skip_if_unsupported;
909    use crate::utilities::data_loader::read_candles_from_csv;
910    use proptest::prelude::*;
911
912    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
913    #[test]
914    fn test_ema_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
915        let mut data = Vec::with_capacity(256);
916        for _ in 0..5 {
917            data.push(f64::NAN);
918        }
919        for i in 0..251 {
920            let x = (i as f64).sin() * 3.14159 + 100.0 + ((i % 7) as f64) * 0.01;
921            data.push(x);
922        }
923
924        let input = EmaInput::from_slice(&data, EmaParams::default());
925        let baseline = ema(&input)?.values;
926
927        let mut out = vec![0.0; data.len()];
928        ema_into(&input, &mut out)?;
929
930        assert_eq!(baseline.len(), out.len());
931        fn eq_or_both_nan(a: f64, b: f64) -> bool {
932            (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
933        }
934        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
935            assert!(
936                eq_or_both_nan(a, b),
937                "mismatch at index {}: api={} into={}",
938                i,
939                a,
940                b
941            );
942        }
943        Ok(())
944    }
945
946    fn check_ema_partial_params(
947        test_name: &str,
948        kernel: Kernel,
949    ) -> Result<(), Box<dyn std::error::Error>> {
950        skip_if_unsupported!(kernel, test_name);
951        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
952        let candles = read_candles_from_csv(file_path)?;
953
954        let default_params = EmaParams { period: None };
955        let input = EmaInput::from_candles(&candles, "close", default_params);
956        let output = ema_with_kernel(&input, kernel)?;
957        assert_eq!(output.values.len(), candles.close.len());
958        Ok(())
959    }
960
961    fn check_ema_accuracy(
962        test_name: &str,
963        kernel: Kernel,
964    ) -> Result<(), Box<dyn std::error::Error>> {
965        skip_if_unsupported!(kernel, test_name);
966        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
967        let candles = read_candles_from_csv(file_path)?;
968
969        let input = EmaInput::from_candles(&candles, "close", EmaParams::default());
970        let result = ema_with_kernel(&input, kernel)?;
971        let expected_last_five = [59302.2, 59277.9, 59230.2, 59215.1, 59103.1];
972        let start = result.values.len().saturating_sub(5);
973        for (i, &val) in result.values[start..].iter().enumerate() {
974            let diff = (val - expected_last_five[i]).abs();
975            assert!(
976                diff < 1e-1,
977                "[{}] EMA {:?} mismatch at idx {}: got {}, expected {}",
978                test_name,
979                kernel,
980                i,
981                val,
982                expected_last_five[i]
983            );
984        }
985        Ok(())
986    }
987
988    fn check_ema_default_candles(
989        test_name: &str,
990        kernel: Kernel,
991    ) -> Result<(), Box<dyn std::error::Error>> {
992        skip_if_unsupported!(kernel, test_name);
993        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
994        let candles = read_candles_from_csv(file_path)?;
995
996        let input = EmaInput::with_default_candles(&candles);
997        match input.data {
998            EmaData::Candles { source, .. } => assert_eq!(source, "close"),
999            _ => panic!("Expected EmaData::Candles"),
1000        }
1001        let output = ema_with_kernel(&input, kernel)?;
1002        assert_eq!(output.values.len(), candles.close.len());
1003        Ok(())
1004    }
1005
1006    fn check_ema_zero_period(
1007        test_name: &str,
1008        kernel: Kernel,
1009    ) -> Result<(), Box<dyn std::error::Error>> {
1010        skip_if_unsupported!(kernel, test_name);
1011        let input_data = [10.0, 20.0, 30.0];
1012        let params = EmaParams { period: Some(0) };
1013        let input = EmaInput::from_slice(&input_data, params);
1014        let res = ema_with_kernel(&input, kernel);
1015        assert!(
1016            res.is_err(),
1017            "[{}] EMA should fail with zero period",
1018            test_name
1019        );
1020        Ok(())
1021    }
1022
1023    fn check_ema_period_exceeds_length(
1024        test_name: &str,
1025        kernel: Kernel,
1026    ) -> Result<(), Box<dyn std::error::Error>> {
1027        skip_if_unsupported!(kernel, test_name);
1028        let data_small = [10.0, 20.0, 30.0];
1029        let params = EmaParams { period: Some(10) };
1030        let input = EmaInput::from_slice(&data_small, params);
1031        let res = ema_with_kernel(&input, kernel);
1032        assert!(
1033            res.is_err(),
1034            "[{}] EMA should fail with period exceeding length",
1035            test_name
1036        );
1037        Ok(())
1038    }
1039
1040    fn check_ema_very_small_dataset(
1041        test_name: &str,
1042        kernel: Kernel,
1043    ) -> Result<(), Box<dyn std::error::Error>> {
1044        skip_if_unsupported!(kernel, test_name);
1045        let single_point = [42.0];
1046        let params = EmaParams { period: Some(9) };
1047        let input = EmaInput::from_slice(&single_point, params);
1048        let res = ema_with_kernel(&input, kernel);
1049        assert!(
1050            res.is_err(),
1051            "[{}] EMA should fail with insufficient data",
1052            test_name
1053        );
1054        Ok(())
1055    }
1056
1057    fn check_ema_empty_input(
1058        test_name: &str,
1059        kernel: Kernel,
1060    ) -> Result<(), Box<dyn std::error::Error>> {
1061        skip_if_unsupported!(kernel, test_name);
1062        let empty: [f64; 0] = [];
1063        let input = EmaInput::from_slice(&empty, EmaParams::default());
1064        let res = ema_with_kernel(&input, kernel);
1065        assert!(
1066            matches!(res, Err(EmaError::EmptyInputData)),
1067            "[{}] EMA should fail with empty input",
1068            test_name
1069        );
1070        Ok(())
1071    }
1072
1073    fn check_ema_reinput(
1074        test_name: &str,
1075        kernel: Kernel,
1076    ) -> Result<(), Box<dyn std::error::Error>> {
1077        skip_if_unsupported!(kernel, test_name);
1078        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1079        let candles = read_candles_from_csv(file_path)?;
1080
1081        let first_params = EmaParams { period: Some(9) };
1082        let first_input = EmaInput::from_candles(&candles, "close", first_params);
1083        let first_result = ema_with_kernel(&first_input, kernel)?;
1084
1085        let second_params = EmaParams { period: Some(5) };
1086        let second_input = EmaInput::from_slice(&first_result.values, second_params);
1087        let second_result = ema_with_kernel(&second_input, kernel)?;
1088
1089        assert_eq!(second_result.values.len(), first_result.values.len());
1090        if second_result.values.len() > 240 {
1091            for (i, &val) in second_result.values[240..].iter().enumerate() {
1092                assert!(
1093                    !val.is_nan(),
1094                    "[{}] Found unexpected NaN at out-index {}",
1095                    test_name,
1096                    240 + i
1097                );
1098            }
1099        }
1100        Ok(())
1101    }
1102
1103    fn check_ema_nan_handling(
1104        test_name: &str,
1105        kernel: Kernel,
1106    ) -> Result<(), Box<dyn std::error::Error>> {
1107        skip_if_unsupported!(kernel, test_name);
1108        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1109        let candles = read_candles_from_csv(file_path)?;
1110
1111        let input = EmaInput::from_candles(&candles, "close", EmaParams { period: Some(9) });
1112        let res = ema_with_kernel(&input, kernel)?;
1113        assert_eq!(res.values.len(), candles.close.len());
1114        if res.values.len() > 240 {
1115            for (i, &val) in res.values[240..].iter().enumerate() {
1116                assert!(
1117                    !val.is_nan(),
1118                    "[{}] Found unexpected NaN at out-index {}",
1119                    test_name,
1120                    240 + i
1121                );
1122            }
1123        }
1124        Ok(())
1125    }
1126
1127    fn check_ema_streaming(
1128        test_name: &str,
1129        kernel: Kernel,
1130    ) -> Result<(), Box<dyn std::error::Error>> {
1131        skip_if_unsupported!(kernel, test_name);
1132
1133        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1134        let candles = read_candles_from_csv(file_path)?;
1135
1136        let period = 9;
1137        let warm_up = 240;
1138
1139        let input = EmaInput::from_candles(
1140            &candles,
1141            "close",
1142            EmaParams {
1143                period: Some(period),
1144            },
1145        );
1146        let batch_output = ema_with_kernel(&input, kernel)?.values;
1147
1148        let mut stream = EmaStream::try_new(EmaParams {
1149            period: Some(period),
1150        })?;
1151        let mut stream_values = Vec::with_capacity(candles.close.len());
1152
1153        for (i, &price) in candles.close.iter().enumerate() {
1154            let stream_val = stream.update(price);
1155
1156            if i < period - 1 {
1157                assert!(
1158                    stream_val.is_none(),
1159                    "[{}] Stream should return None during warmup at idx {}",
1160                    test_name,
1161                    i
1162                );
1163                stream_values.push(f64::NAN);
1164            } else {
1165                stream_values.push(stream_val.unwrap_or(f64::NAN));
1166            }
1167        }
1168
1169        assert_eq!(batch_output.len(), stream_values.len());
1170
1171        for (i, (&b, &s)) in batch_output
1172            .iter()
1173            .zip(&stream_values)
1174            .enumerate()
1175            .skip(warm_up)
1176        {
1177            if b.is_nan() && s.is_nan() {
1178                continue;
1179            }
1180            let diff = (b - s).abs();
1181            assert!(
1182                diff < 1e-9,
1183                "[{}] EMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1184                test_name,
1185                i,
1186                b,
1187                s,
1188                diff
1189            );
1190        }
1191        Ok(())
1192    }
1193
1194    #[cfg(debug_assertions)]
1195    fn check_ema_no_poison(
1196        test_name: &str,
1197        kernel: Kernel,
1198    ) -> Result<(), Box<dyn std::error::Error>> {
1199        skip_if_unsupported!(kernel, test_name);
1200
1201        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1202        let candles = read_candles_from_csv(file_path)?;
1203
1204        let test_periods = vec![2, 5, 9, 14, 20, 50, 100, 200];
1205        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1206
1207        for period in &test_periods {
1208            for source in &test_sources {
1209                let input = EmaInput::from_candles(
1210                    &candles,
1211                    source,
1212                    EmaParams {
1213                        period: Some(*period),
1214                    },
1215                );
1216                let output = ema_with_kernel(&input, kernel)?;
1217
1218                for (i, &val) in output.values.iter().enumerate() {
1219                    if val.is_nan() {
1220                        continue;
1221                    }
1222
1223                    let bits = val.to_bits();
1224
1225                    if bits == 0x11111111_11111111 {
1226                        panic!(
1227                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1228                            test_name, val, bits, i, period, source
1229                        );
1230                    }
1231
1232                    if bits == 0x22222222_22222222 {
1233                        panic!(
1234                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1235                            test_name, val, bits, i, period, source
1236                        );
1237                    }
1238
1239                    if bits == 0x33333333_33333333 {
1240                        panic!(
1241                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1242                            test_name, val, bits, i, period, source
1243                        );
1244                    }
1245                }
1246            }
1247        }
1248
1249        Ok(())
1250    }
1251
/// Release-build stand-in for the poison scan: does nothing and succeeds, so
/// the test list stays the same in both build profiles.
#[cfg(not(debug_assertions))]
fn check_ema_no_poison(
    _test_name: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
1259
1260    fn check_ema_property(
1261        test_name: &str,
1262        kernel: Kernel,
1263    ) -> Result<(), Box<dyn std::error::Error>> {
1264        use proptest::prelude::*;
1265        skip_if_unsupported!(kernel, test_name);
1266
1267        let strat = (1usize..=100).prop_flat_map(|period| {
1268            (
1269                prop::collection::vec(
1270                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1271                    period + 10..400,
1272                ),
1273                Just(period),
1274            )
1275        });
1276
1277        proptest::test_runner::TestRunner::default()
1278            .run(&strat, |(data, period)| {
1279                let params = EmaParams {
1280                    period: Some(period),
1281                };
1282                let input = EmaInput::from_slice(&data, params);
1283
1284                let EmaOutput { values: out } = ema_with_kernel(&input, kernel).unwrap();
1285
1286                let EmaOutput { values: ref_out } =
1287                    ema_with_kernel(&input, Kernel::Scalar).unwrap();
1288
1289                let alpha = 2.0 / (period as f64 + 1.0);
1290                let beta = 1.0 - alpha;
1291
1292                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1293
1294                for i in 0..data.len() {
1295                    let y = out[i];
1296                    let r = ref_out[i];
1297
1298                    if i < first_valid {
1299                        prop_assert!(
1300                            y.is_nan(),
1301                            "[{}] Expected NaN during warmup at idx {}, got {}",
1302                            test_name,
1303                            i,
1304                            y
1305                        );
1306                        continue;
1307                    }
1308
1309                    if i >= first_valid {
1310                        let window = &data[first_valid..=i];
1311                        let lo = window
1312                            .iter()
1313                            .cloned()
1314                            .filter(|x| x.is_finite())
1315                            .fold(f64::INFINITY, f64::min);
1316                        let hi = window
1317                            .iter()
1318                            .cloned()
1319                            .filter(|x| x.is_finite())
1320                            .fold(f64::NEG_INFINITY, f64::max);
1321
1322                        if !y.is_nan() && lo.is_finite() && hi.is_finite() {
1323                            prop_assert!(
1324                                y >= lo - 1e-9 && y <= hi + 1e-9,
1325                                "[{}] idx {}: {} not in [{}, {}]",
1326                                test_name,
1327                                i,
1328                                y,
1329                                lo,
1330                                hi
1331                            );
1332                        }
1333                    }
1334
1335                    if period == 1 && i >= first_valid && data[i].is_finite() {
1336                        prop_assert!(
1337                            (y - data[i]).abs() <= 1e-10,
1338                            "[{}] Period=1 mismatch at idx {}: {} vs {}",
1339                            test_name,
1340                            i,
1341                            y,
1342                            data[i]
1343                        );
1344                    }
1345
1346                    if i >= first_valid + period {
1347                        let window_start = i.saturating_sub(period);
1348                        let window = &data[window_start..=i];
1349                        if window
1350                            .iter()
1351                            .all(|&x| (x - data[window_start]).abs() < 1e-10)
1352                        {
1353                            let expected = data[window_start];
1354                            prop_assert!(
1355                                (y - expected).abs() <= 1e-6,
1356                                "[{}] Constant data convergence failed at idx {}: {} vs {}",
1357                                test_name,
1358                                i,
1359                                y,
1360                                expected
1361                            );
1362                        }
1363                    }
1364
1365                    if !y.is_finite() || !r.is_finite() {
1366                        prop_assert!(
1367                            y.to_bits() == r.to_bits(),
1368                            "[{}] NaN/infinite mismatch at idx {}: {} vs {}",
1369                            test_name,
1370                            i,
1371                            y,
1372                            r
1373                        );
1374                    } else {
1375                        let abs_diff = (y - r).abs();
1376                        let rel_diff = if r.abs() > 1e-10 {
1377                            abs_diff / r.abs()
1378                        } else {
1379                            abs_diff
1380                        };
1381
1382                        prop_assert!(
1383                            abs_diff <= 1e-9 || rel_diff <= 1e-9,
1384                            "[{}] Kernel mismatch at idx {}: {} vs {} (abs_diff={}, rel_diff={})",
1385                            test_name,
1386                            i,
1387                            y,
1388                            r,
1389                            abs_diff,
1390                            rel_diff
1391                        );
1392                    }
1393
1394                    if i >= first_valid + period
1395                        && y.is_finite()
1396                        && out[i - 1].is_finite()
1397                        && data[i].is_finite()
1398                    {
1399                        let expected_ema = alpha * data[i] + beta * out[i - 1];
1400                        let diff = (y - expected_ema).abs();
1401
1402                        prop_assert!(
1403                            diff <= 1e-9 * ((i - first_valid) as f64).max(1.0),
1404                            "[{}] EMA recursive property failed at idx {}: {} vs {} (diff={})",
1405                            test_name,
1406                            i,
1407                            y,
1408                            expected_ema,
1409                            diff
1410                        );
1411                    }
1412
1413                    if i >= first_valid + period * 2 {
1414                        let historical = &data[first_valid..=i];
1415                        let hist_min = historical
1416                            .iter()
1417                            .filter(|x| x.is_finite())
1418                            .fold(f64::INFINITY, |a, &b| a.min(b));
1419                        let hist_max = historical
1420                            .iter()
1421                            .filter(|x| x.is_finite())
1422                            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1423
1424                        if hist_min.is_finite() && hist_max.is_finite() && y.is_finite() {
1425                            prop_assert!(
1426                                y >= hist_min - 1e-6 && y <= hist_max + 1e-6,
1427                                "[{}] EMA outside historical bounds at idx {}: {} not in [{}, {}]",
1428                                test_name,
1429                                i,
1430                                y,
1431                                hist_min,
1432                                hist_max
1433                            );
1434                        }
1435                    }
1436                }
1437
1438                if first_valid < data.len() && out[first_valid].is_finite() {
1439                    prop_assert!(
1440                        (out[first_valid] - data[first_valid]).abs() <= 1e-10,
1441                        "[{}] First valid output should equal first valid input: {} vs {}",
1442                        test_name,
1443                        out[first_valid],
1444                        data[first_valid]
1445                    );
1446                }
1447
1448                Ok(())
1449            })
1450            .unwrap();
1451
1452        Ok(())
1453    }
1454
1455    macro_rules! generate_all_ema_tests {
1456        ($($test_fn:ident),*) => {
1457            paste::paste! {
1458                $(
1459                    #[test]
1460                    fn [<$test_fn _scalar_f64>]() {
1461                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1462                    }
1463
1464                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1465                    #[test]
1466                    fn [<$test_fn _avx2_f64>]() {
1467                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1468                    }
1469
1470                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1471                    #[test]
1472                    fn [<$test_fn _avx512_f64>]() {
1473                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1474                    }
1475                )*
1476            }
1477        }
1478    }
1479
1480    generate_all_ema_tests!(
1481        check_ema_partial_params,
1482        check_ema_accuracy,
1483        check_ema_default_candles,
1484        check_ema_zero_period,
1485        check_ema_period_exceeds_length,
1486        check_ema_very_small_dataset,
1487        check_ema_empty_input,
1488        check_ema_reinput,
1489        check_ema_nan_handling,
1490        check_ema_streaming,
1491        check_ema_property,
1492        check_ema_no_poison
1493    );
1494
1495    fn check_batch_default_row(
1496        test: &str,
1497        kernel: Kernel,
1498    ) -> Result<(), Box<dyn std::error::Error>> {
1499        skip_if_unsupported!(kernel, test);
1500
1501        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1502        let c = read_candles_from_csv(file)?;
1503
1504        let output = EmaBatchBuilder::new()
1505            .kernel(kernel)
1506            .apply_candles(&c, "close")?;
1507
1508        let def = EmaParams::default();
1509        let row = output.values_for(&def).expect("default row missing");
1510
1511        assert_eq!(row.len(), c.close.len());
1512
1513        let expected = [59302.2, 59277.9, 59230.2, 59215.1, 59103.1];
1514        let start = row.len() - 5;
1515        for (i, &v) in row[start..].iter().enumerate() {
1516            assert!(
1517                (v - expected[i]).abs() < 1e-1,
1518                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1519            );
1520        }
1521        Ok(())
1522    }
1523
1524    #[cfg(debug_assertions)]
1525    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1526        skip_if_unsupported!(kernel, test);
1527
1528        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1529        let c = read_candles_from_csv(file)?;
1530
1531        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1532
1533        for source in &test_sources {
1534            let output = EmaBatchBuilder::new()
1535                .kernel(kernel)
1536                .period_range(2, 200, 3)
1537                .apply_candles(&c, source)?;
1538
1539            for (idx, &val) in output.values.iter().enumerate() {
1540                if val.is_nan() {
1541                    continue;
1542                }
1543
1544                let bits = val.to_bits();
1545                let row = idx / output.cols;
1546                let col = idx % output.cols;
1547
1548                if bits == 0x11111111_11111111 {
1549                    panic!(
1550                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1551                        test, val, bits, row, col, idx, source
1552                    );
1553                }
1554
1555                if bits == 0x22222222_22222222 {
1556                    panic!(
1557                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1558                        test, val, bits, row, col, idx, source
1559                    );
1560                }
1561
1562                if bits == 0x33333333_33333333 {
1563                    panic!(
1564                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1565                        test, val, bits, row, col, idx, source
1566                    );
1567                }
1568            }
1569        }
1570
1571        let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1572        for (start, end, step) in edge_case_ranges {
1573            let output = EmaBatchBuilder::new()
1574                .kernel(kernel)
1575                .period_range(start, end, step)
1576                .apply_candles(&c, "close")?;
1577
1578            for (idx, &val) in output.values.iter().enumerate() {
1579                if val.is_nan() {
1580                    continue;
1581                }
1582
1583                let bits = val.to_bits();
1584                let row = idx / output.cols;
1585                let col = idx % output.cols;
1586
1587                if bits == 0x11111111_11111111
1588                    || bits == 0x22222222_22222222
1589                    || bits == 0x33333333_33333333
1590                {
1591                    panic!(
1592						"[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1593						test, val, bits, row, col, start, end, step
1594					);
1595                }
1596            }
1597        }
1598
1599        Ok(())
1600    }
1601
/// Release-build stand-in for the batch poison scan: does nothing and
/// succeeds, so the test list stays the same in both build profiles.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(
    _test: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
1609
1610    macro_rules! gen_batch_tests {
1611        ($fn_name:ident) => {
1612            paste::paste! {
1613                #[test]
1614                fn [<$fn_name _scalar>]() {
1615                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1616                }
1617                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1618                #[test]
1619                fn [<$fn_name _avx2>]() {
1620                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1621                }
1622                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1623                #[test]
1624                fn [<$fn_name _avx512>]() {
1625                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1626                }
1627                #[test]
1628                fn [<$fn_name _auto_detect>]() {
1629                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1630                }
1631            }
1632        };
1633    }
1634    gen_batch_tests!(check_batch_default_row);
1635    gen_batch_tests!(check_batch_no_poison);
1636
/// Checks that the one-shot `ema` and `EmaStream` agree on a small series
/// containing NaN gaps: from index `period` onwards finite values must match
/// to 1e-10 and NaN-ness must match otherwise; during the warmup the batch
/// output (after index 0) must differ from the first input value.
#[test]
fn test_batch_stream_consistency() -> Result<(), Box<dyn std::error::Error>> {
    let test_data = vec![
        1.0,
        2.0,
        3.0,
        4.0,
        5.0,
        f64::NAN,
        6.0,
        7.0,
        8.0,
        9.0,
        10.0,
        f64::NAN,
        11.0,
        12.0,
        13.0,
        14.0,
        15.0,
    ];

    let period = 5;
    let params = EmaParams {
        period: Some(period),
    };

    let input = EmaInput::from_slice(&test_data, params.clone());
    let batch_output = ema(&input)?;

    // Feed the stream one value at a time; a missing value is recorded as NaN.
    let mut stream = EmaStream::try_new(params)?;
    let stream_output: Vec<f64> = test_data
        .iter()
        .map(|&val| stream.update(val).unwrap_or(f64::NAN))
        .collect();

    for i in period..test_data.len() {
        let batch_val = batch_output.values[i];
        let stream_val = stream_output[i];

        if !(batch_val.is_finite() && stream_val.is_finite()) {
            assert_eq!(
                batch_val.is_nan(),
                stream_val.is_nan(),
                "Batch/Stream NaN mismatch at index {}: batch={}, stream={}",
                i,
                batch_val,
                stream_val
            );
            continue;
        }

        let diff = (batch_val - stream_val).abs();
        assert!(
            diff < 1e-10,
            "Batch/Stream mismatch at index {}: batch={}, stream={}, diff={}",
            i,
            batch_val,
            stream_val,
            diff
        );
    }

    // Index 0 is never checked here, so the scan starts at 1.
    for i in 1..period.min(test_data.len()) {
        if test_data[i].is_finite() && batch_output.values[i].is_finite() {
            assert!(
                (batch_output.values[i] - test_data[0]).abs() > 1e-10,
                "Batch should use running mean during warmup, not just first value"
            );
        }
    }

    Ok(())
}
1714}
1715
// Python entry point `ema(data, period, kernel=None)`.
//
// Computes the EMA of a 1-D float64 NumPy array and returns a new 1-D array.
// `kernel` is validated by `validate_kernel`; errors from validation or from
// the computation surface as Python exceptions (`ValueError` for the latter).
//
// Contiguous input is borrowed directly. Any other layout (strided or
// reversed views) is copied element by element in logical order, because an
// owned ndarray copy is not guaranteed to be in standard layout and therefore
// cannot be relied on to expose a slice.
#[cfg(feature = "python")]
#[pyfunction(name = "ema")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn ema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;

    let params = EmaParams {
        period: Some(period),
    };

    // Only initialized on the non-contiguous path; declared here so the copy
    // outlives the borrowed slice below.
    let copied: Vec<f64>;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(contiguous) => contiguous,
        Err(_) => {
            copied = data.as_array().iter().copied().collect();
            &copied
        }
    };

    let ema_in = EmaInput::from_slice(slice_in, params);
    // The computation runs inside `allow_threads`, as in the batch binding.
    let result_vec: Vec<f64> = py
        .allow_threads(|| ema_with_kernel(&ema_in, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1749
// Python entry point `ema_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into a list of periods, computes one EMA row per
// period over `data`, and returns a dict with
//   "values":  2-D float64 array of shape (rows, cols) = (periods, len(data))
//   "periods": 1-D uint64 array of the periods, one per row.
//
// `data` must be contiguous (`as_slice` fails otherwise). Errors from the
// grid expansion and from the computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn ema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;

    let sweep = EmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    // Guard the flat size: a wrapped product would under-allocate the
    // uninitialized output buffer below.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("ema_batch: rows * cols overflows usize"))?;

    // The array is created uninitialized.
    // NOTE(review): soundness relies on the NaN-prefix loop plus
    // `ema_batch_inner_into` writing every element — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // Pre-fill every row's leading region (up to the first non-NaN input)
    // with NaN.
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    for r in 0..rows {
        let row_start = r * cols;
        for i in 0..first {
            slice_out[row_start + i] = f64::NAN;
        }
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row kernel handed to the inner
            // routine.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // NOTE(review): presumably `validate_kernel(_, true)` only
                // yields batch kernels or `Auto` — confirm.
                _ => unreachable!(),
            };
            ema_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1814
// Python entry point `ema_cuda_batch_dev(data_f32, period_range=(9, 9, 0),
// device_id=0)`.
//
// Runs the CUDA batch EMA over a contiguous float32 series and returns the
// result as an `EmaDeviceArrayF32Py` that owns the device buffer together
// with the CUDA context handle. Raises `ValueError` when CUDA is unavailable
// or when `CudaEma` reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range=(9, 9, 0), device_id=0))]
pub fn ema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<EmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series = data_f32.as_slice()?;
    let sweep = EmaBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| {
        let cuda = CudaEma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let handle = cuda
            .ema_batch_dev(host_series, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        Ok::<_, PyErr>((handle.buf, handle.rows, handle.cols, ctx, dev))
    });
    let (buf, rows, cols, ctx, dev) = launched?;

    let device_array = EmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev,
    };
    Ok(device_array)
}
1851
// Python entry point `ema_cuda_many_series_one_param_dev(data_tm_f32, period,
// device_id=0)`.
//
// Runs the CUDA EMA with a single period over a contiguous 2-D float32 array
// whose first axis is passed on as the series length and whose second axis as
// the number of series. Returns an `EmaDeviceArrayF32Py` owning the device
// buffer and the CUDA context handle. Raises `ValueError` when CUDA is
// unavailable, when `period` is zero, or when `CudaEma` reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn ema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<EmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (series_len, num_series) = (dims[0], dims[1]);
    let params = EmaParams {
        period: Some(period),
    };

    let launched = py.allow_threads(|| {
        let cuda = CudaEma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let handle = cuda
            .ema_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        Ok::<_, PyErr>((handle.buf, handle.rows, handle.cols, ctx, dev))
    });
    let (buf, rows, cols, ctx, dev) = launched?;

    let device_array = EmaDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        _ctx: ctx,
        device_id: dev,
    };
    Ok(device_array)
}
1894
/// Python-facing wrapper, exported under the class name `EmaStream`, that owns an
/// inner `EmaStream` and forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "EmaStream")]
pub struct EmaStreamPy {
    /// Wrapped streaming state; built in `new` and advanced by `update`.
    inner: EmaStream,
}
1900
#[cfg(feature = "python")]
#[pymethods]
impl EmaStreamPy {
    /// Builds a stream for the given `period`.
    ///
    /// Any error returned by `EmaStream::try_new` is surfaced as a `ValueError`
    /// carrying the error's display text.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        EmaStream::try_new(EmaParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Feeds one `value` into the stream and returns the resulting output,
    /// or NaN when the inner stream produces no value for this call.
    pub fn update(&mut self, value: f64) -> f64 {
        let next = self.inner.update(value);
        next.unwrap_or(f64::NAN)
    }
}
1917
1918#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1919use serde::{Deserialize, Serialize};
1920#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1921use wasm_bindgen::prelude::*;
1922
/// Computes the EMA of `data` for a single `period` and returns a freshly allocated
/// vector with one output per input element.
///
/// Errors from `ema_into_slice` are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = EmaInput::from_slice(
        data,
        EmaParams {
            period: Some(period),
        },
    );

    // Output buffer matches the input length; the kernel writes into it in place.
    let mut values = vec![0.0; data.len()];

    match ema_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1938
/// Configuration object that `ema_batch` deserializes from its JS `config` argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmaBatchConfig {
    /// Period sweep tuple, copied unchanged into `EmaBatchRange::period`.
    /// `ema_batch_metadata_js` fills that same field as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1944
/// Result object that `ema_batch` serializes back to JS; every field is copied
/// from the value returned by `ema_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EmaBatchJsOutput {
    /// Computed values. Presumably `rows * cols` entries stored row by row, as in
    /// `ema_batch_into` - TODO confirm against `ema_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combinations reported by the batch result.
    pub combos: Vec<EmaParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
1953
/// Batch EMA entry point exported to JS as `ema_batch`.
///
/// `config` is deserialized into an `EmaBatchConfig`; its `period_range` drives the
/// sweep over `data`. The result is returned as a serialized `EmaBatchJsOutput`.
///
/// Returns a JS string error when the config cannot be deserialized, when the batch
/// computation fails, or when the output cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ema_batch)]
pub fn ema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: EmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let range = EmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match ema_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Repackage the batch result into the serde-friendly output struct.
    let payload = EmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1977
/// Lists the periods produced by expanding the `(period_start, period_end, period_step)`
/// sweep, one `f64` per parameter combination, in the order `expand_grid` yields them.
///
/// Errors from `expand_grid` are converted to a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = EmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    expand_grid(&range)
        .map(|combos| {
            combos
                .into_iter()
                // Each combination is expected to carry a period; a missing one panics.
                .map(|combo| combo.period.unwrap() as f64)
                .collect::<Vec<f64>>()
        })
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1998
1999#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2000#[wasm_bindgen]
/// Allocates an uninitialised buffer with room for `len` `f64` values and leaks it,
/// returning the raw pointer to the caller.
///
/// The buffer is never freed by Rust on its own; release it with `ema_free`,
/// passing the same `len`.
pub fn ema_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2007
2008#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2009#[wasm_bindgen]
/// Releases a buffer previously obtained from `ema_alloc`.
///
/// `ptr` must be a pointer returned by `ema_alloc(len)` that has not been freed yet,
/// and `len` must be the value passed to that call. A null `ptr` is ignored.
pub fn ema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `ema_alloc(len)`, i.e. from a
    // `Vec<f64>` with capacity `len`. The vector is rebuilt with length 0 because the
    // buffer contents may never have been initialised; `f64` has no destructor, so
    // only the allocation itself needs to be released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2017
/// Computes the EMA of `len` values read from `in_ptr` and writes `len` results to
/// `out_ptr`.
///
/// When both pointers are equal the result is computed into a temporary buffer and
/// copied back, so in-place use is supported.
///
/// Returns a JS string error when either pointer is null, when `period` is zero or
/// larger than `len` ("Invalid period"), or when `ema_into_slice` fails.
///
/// The caller must ensure that `in_ptr` is readable and `out_ptr` is writable for
/// `len` `f64` values each.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ema_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for reading
    // `len` values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = EmaInput::from_slice(
        data,
        EmaParams {
            period: Some(period),
        },
    );

    if std::ptr::eq(in_ptr, out_ptr) {
        // In-place request: compute into scratch space first so the input is not
        // overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        ema_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for
        // writing `len` values; the input slice is no longer used past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for
        // writing `len` values.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ema_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2058
/// Runs the EMA period sweep `(period_start, period_end, period_step)` over `len`
/// values read from `in_ptr` and writes the results to `out_ptr`.
///
/// The output holds `rows * len` values, where `rows` is the number of parameter
/// combinations produced by `expand_grid`; row `r` starts at offset `r * len`.
/// On success the number of rows is returned.
///
/// If the input and output regions overlap (for example an in-place call), the input
/// is copied first so that no shared and mutable slices alias the same memory.
///
/// Returns a JS string error when either pointer is null, when `expand_grid` fails,
/// when `rows * len` overflows, when the input contains only NaN values ("All NaN"),
/// or when the batch computation fails.
///
/// The caller must ensure that `in_ptr` is readable for `len` values and `out_ptr` is
/// writable for `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ema_batch_into"));
    }

    let sweep = EmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // `ok_or_else` keeps the error value from being built on the success path.
    let elems = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("overflow rows*cols"))?;

    // Detect overlap between [in_ptr, in_ptr + len) and [out_ptr, out_ptr + elems).
    let in_end = in_ptr.wrapping_add(len);
    let out_begin = out_ptr as *const f64;
    let out_end = out_begin.wrapping_add(elems);
    let overlaps = in_ptr < out_end && out_begin < in_end;

    // When the regions overlap, snapshot the input before any mutable view of the
    // output exists.
    // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for reading
    // `len` values.
    let snapshot: Option<Vec<f64>> =
        overlaps.then(|| unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec());
    let data: &[f64] = match snapshot.as_deref() {
        Some(copy) => copy,
        // SAFETY: same guarantee as above; the regions do not overlap on this path,
        // so this shared slice never aliases the mutable output slice.
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for writing
    // `elems` values; `data` does not alias it (see above).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, elems) };

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;

    // Pre-fill the leading NaN region of every row. A non-NaN element was found,
    // so `cols >= 1` and `chunks_exact_mut` cannot be handed a zero chunk size.
    for row in out.chunks_exact_mut(cols) {
        row[..first].fill(f64::NAN);
    }

    ema_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}