Skip to main content

vector_ta/indicators/moving_averages/
epma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaEpma;
7use crate::utilities::data_loader::{source_type, Candles};
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11    make_uninit_matrix,
12};
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(all(feature = "python", feature = "cuda"))]
16use cust::context::Context;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19use std::convert::AsRef;
20use std::mem::{ManuallyDrop, MaybeUninit};
21#[cfg(all(feature = "python", feature = "cuda"))]
22use std::sync::Arc;
23use thiserror::Error;
24
25impl<'a> AsRef<[f64]> for EpmaInput<'a> {
26    #[inline(always)]
27    fn as_ref(&self) -> &[f64] {
28        match &self.data {
29            EpmaData::Slice(slice) => slice,
30            EpmaData::Candles { candles, source } => source_type(candles, source),
31        }
32    }
33}
34
35#[derive(Debug, Clone)]
36pub enum EpmaData<'a> {
37    Candles {
38        candles: &'a Candles,
39        source: &'a str,
40    },
41    Slice(&'a [f64]),
42}
43
/// Result of a single-series EPMA run.
#[derive(Clone, Debug)]
pub struct EpmaOutput {
    /// One output value per input sample; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
48
/// Tunable EPMA parameters; a `None` field means "use the default".
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct EpmaParams {
    /// Window length (readers in this file fall back to 11 when `None`).
    pub period: Option<usize>,
    /// Weight offset (readers in this file fall back to 4 when `None`).
    pub offset: Option<usize>,
}

impl Default for EpmaParams {
    /// Returns the explicit defaults: period 11, offset 4.
    fn default() -> Self {
        EpmaParams {
            period: Some(11),
            offset: Some(4),
        }
    }
}
66
67#[derive(Debug, Clone)]
68pub struct EpmaInput<'a> {
69    pub data: EpmaData<'a>,
70    pub params: EpmaParams,
71}
72
73impl<'a> EpmaInput<'a> {
74    #[inline]
75    pub fn from_candles(c: &'a Candles, s: &'a str, p: EpmaParams) -> Self {
76        Self {
77            data: EpmaData::Candles {
78                candles: c,
79                source: s,
80            },
81            params: p,
82        }
83    }
84    #[inline]
85    pub fn from_slice(sl: &'a [f64], p: EpmaParams) -> Self {
86        Self {
87            data: EpmaData::Slice(sl),
88            params: p,
89        }
90    }
91    #[inline]
92    pub fn with_default_candles(c: &'a Candles) -> Self {
93        Self::from_candles(c, "close", EpmaParams::default())
94    }
95    #[inline]
96    pub fn get_period(&self) -> usize {
97        self.params.period.unwrap_or(11)
98    }
99    #[inline]
100    pub fn get_offset(&self) -> usize {
101        self.params.offset.unwrap_or(4)
102    }
103}
104
105#[derive(Debug, Error)]
106pub enum EpmaError {
107    #[error("epma: Input data slice is empty.")]
108    EmptyInputData,
109
110    #[error("epma: All values are NaN.")]
111    AllValuesNaN,
112
113    #[error("epma: Invalid period: period = {period}, data length = {data_len}")]
114    InvalidPeriod { period: usize, data_len: usize },
115
116    #[error("epma: Invalid offset: {offset}")]
117    InvalidOffset { offset: usize },
118
119    #[error("epma: Not enough valid data: needed = {needed}, valid = {valid}")]
120    NotEnoughValidData { needed: usize, valid: usize },
121
122    #[error("epma: output length mismatch: expected = {expected}, got = {got}")]
123    OutputLengthMismatch { expected: usize, got: usize },
124
125    #[error("epma: Invalid kernel for batch operation: got {0:?}")]
126    InvalidKernelForBatch(Kernel),
127
128    #[error("epma: invalid range: start={start}, end={end}, step={step}")]
129    InvalidRange {
130        start: usize,
131        end: usize,
132        step: usize,
133    },
134
135    #[error("epma: size overflow computing rows*cols: rows={rows}, cols={cols}")]
136    SizeOverflow { rows: usize, cols: usize },
137}
138
139#[derive(Copy, Clone, Debug)]
140pub struct EpmaBuilder {
141    period: Option<usize>,
142    offset: Option<usize>,
143    kernel: Kernel,
144}
145impl Default for EpmaBuilder {
146    fn default() -> Self {
147        Self {
148            period: None,
149            offset: None,
150            kernel: Kernel::Auto,
151        }
152    }
153}
154impl EpmaBuilder {
155    #[inline(always)]
156    pub fn new() -> Self {
157        Self::default()
158    }
159    #[inline(always)]
160    pub fn period(mut self, n: usize) -> Self {
161        self.period = Some(n);
162        self
163    }
164    #[inline(always)]
165    pub fn offset(mut self, o: usize) -> Self {
166        self.offset = Some(o);
167        self
168    }
169    #[inline(always)]
170    pub fn kernel(mut self, k: Kernel) -> Self {
171        self.kernel = k;
172        self
173    }
174    #[inline(always)]
175    pub fn apply(self, c: &Candles) -> Result<EpmaOutput, EpmaError> {
176        let p = EpmaParams {
177            period: self.period,
178            offset: self.offset,
179        };
180        let i = EpmaInput::from_candles(c, "close", p);
181        epma_with_kernel(&i, self.kernel)
182    }
183    #[inline(always)]
184    pub fn apply_slice(self, d: &[f64]) -> Result<EpmaOutput, EpmaError> {
185        let p = EpmaParams {
186            period: self.period,
187            offset: self.offset,
188        };
189        let i = EpmaInput::from_slice(d, p);
190        epma_with_kernel(&i, self.kernel)
191    }
192    #[inline(always)]
193    pub fn into_stream(self) -> Result<EpmaStream, EpmaError> {
194        let p = EpmaParams {
195            period: self.period,
196            offset: self.offset,
197        };
198        EpmaStream::try_new(p)
199    }
200}
201
202#[inline]
203pub fn epma(input: &EpmaInput) -> Result<EpmaOutput, EpmaError> {
204    epma_with_kernel(input, Kernel::Auto)
205}
206
207#[inline(always)]
208fn epma_prepare<'a>(
209    input: &'a EpmaInput,
210    kernel: Kernel,
211) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), EpmaError> {
212    let data: &[f64] = input.as_ref();
213    let len = data.len();
214    if len == 0 {
215        return Err(EpmaError::EmptyInputData);
216    }
217
218    let first = data
219        .iter()
220        .position(|x| !x.is_nan())
221        .ok_or(EpmaError::AllValuesNaN)?;
222    let period = input.get_period();
223    let offset = input.get_offset();
224
225    if offset >= period {
226        return Err(EpmaError::InvalidOffset { offset });
227    }
228
229    if period < 2 || period > len {
230        return Err(EpmaError::InvalidPeriod {
231            period,
232            data_len: len,
233        });
234    }
235    let needed = period + offset + 1;
236    if (len - first) < needed {
237        return Err(EpmaError::NotEnoughValidData {
238            needed,
239            valid: len - first,
240        });
241    }
242
243    let chosen = match kernel {
244        Kernel::Auto => detect_best_kernel(),
245        other => other,
246    };
247    let warmup = first + period + offset + 1;
248
249    Ok((data, period, offset, first, warmup, chosen))
250}
251
252#[inline(always)]
253fn epma_compute_into(
254    data: &[f64],
255    period: usize,
256    offset: usize,
257    first: usize,
258    kernel: Kernel,
259    out: &mut [f64],
260) {
261    unsafe {
262        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
263        {
264            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
265                epma_simd128(data, period, offset, first, out);
266                return;
267            }
268        }
269
270        match kernel {
271            Kernel::Scalar | Kernel::ScalarBatch => epma_scalar(data, period, offset, first, out),
272            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
273            Kernel::Avx2 | Kernel::Avx2Batch => epma_avx2(data, period, offset, first, out),
274            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
275            Kernel::Avx512 | Kernel::Avx512Batch => epma_avx512(data, period, offset, first, out),
276            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
277            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
278                epma_scalar(data, period, offset, first, out)
279            }
280            _ => unreachable!(),
281        }
282    }
283}
284
285pub fn epma_with_kernel(input: &EpmaInput, kernel: Kernel) -> Result<EpmaOutput, EpmaError> {
286    let (data, period, offset, first, warmup, chosen) = epma_prepare(input, kernel)?;
287
288    let mut out = alloc_with_nan_prefix(data.len(), warmup);
289    epma_compute_into(data, period, offset, first, chosen, &mut out);
290
291    Ok(EpmaOutput { values: out })
292}
293
294#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
295#[inline]
296pub fn epma_into(input: &EpmaInput, out: &mut [f64]) -> Result<(), EpmaError> {
297    let (data, period, offset, first, warmup, chosen) = epma_prepare(input, Kernel::Auto)?;
298
299    if out.len() != data.len() {
300        return Err(EpmaError::OutputLengthMismatch {
301            expected: data.len(),
302            got: out.len(),
303        });
304    }
305
306    let w = warmup.min(out.len());
307    for v in &mut out[..w] {
308        *v = f64::from_bits(0x7ff8_0000_0000_0000);
309    }
310
311    epma_compute_into(data, period, offset, first, chosen, out);
312    Ok(())
313}
314
315#[inline]
316pub fn epma_into_slice(dst: &mut [f64], input: &EpmaInput, kern: Kernel) -> Result<(), EpmaError> {
317    let (data, period, offset, first, warmup, chosen) = epma_prepare(input, kern)?;
318
319    if dst.len() != data.len() {
320        return Err(EpmaError::OutputLengthMismatch {
321            expected: data.len(),
322            got: dst.len(),
323        });
324    }
325
326    epma_compute_into(data, period, offset, first, chosen, dst);
327
328    for v in &mut dst[..warmup] {
329        *v = f64::from_bits(0x7ff8_0000_0000_0000);
330    }
331
332    Ok(())
333}
334
/// Scalar EPMA kernel.
///
/// For every `j >= first_valid + period + offset + 1` it takes the
/// `period - 1` samples ending at `j`, weights the oldest by `2 - offset` and
/// each following sample by one more than the previous, and writes the
/// weighted sum divided by the sum of weights into `out[j]`. Earlier slots of
/// `out` are not written.
#[inline(always)]
pub fn epma_scalar(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let taps = period - 1;

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    // Closed form of sum_{i=0}^{taps-1} (base_weight + i).
    let weight_sum = taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);
    let inv = 1.0 / weight_sum;

    let warmup = first_valid + period + offset + 1;
    for j in warmup..len {
        let start = j + 1 - taps;

        // Oldest-to-newest accumulation; the FMA chain order is part of the result.
        let mut weight = base_weight;
        let mut acc = 0.0;
        for &sample in &data[start..start + taps] {
            acc = sample.mul_add(weight, acc);
            weight += 1.0;
        }

        out[j] = acc * inv;
    }
}
376
/// wasm `simd128` EPMA kernel processing two taps per vector operation.
///
/// Weights are tabulated newest-first (`weights[0]` belongs to the newest
/// sample), so the table is read back-to-front while the data window is read
/// oldest-first.
///
/// # Safety
///
/// Uses raw vector loads on `data`; the caller must pass arguments that keep
/// every window `[j + 1 - (period - 1), j]` inside `data`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn epma_simd128(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    const LANES: usize = 2;
    let len = data.len();
    let taps = period - 1;

    let mut weights = Vec::with_capacity(taps);
    let mut weight_sum = 0.0;
    for i in 0..taps {
        let w = (period as i32 - i as i32 - offset as i32) as f64;
        weight_sum += w;
        weights.push(w);
    }

    let full_blocks = taps / LANES;
    let has_tail = taps % LANES != 0;

    let warmup = first_valid + period + offset + 1;
    for j in warmup..len {
        let start = j + 1 - taps;
        let window = data.as_ptr().add(start);
        let mut acc = f64x2_splat(0.0);

        for blk in 0..full_blocks {
            let idx = blk * LANES;

            let w = f64x2(weights[taps - 1 - idx], weights[taps - 2 - idx]);
            let d = v128_load(window.add(idx) as *const v128);
            acc = f64x2_add(acc, f64x2_mul(d, w));
        }

        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);

        // An odd tap count leaves the newest sample for scalar handling.
        if has_tail {
            sum += data[start + taps - 1] * weights[0];
        }

        out[j] = sum / weight_sum;
    }
}
427
/// Builds the `period - 1` EPMA weights in oldest-to-newest order together
/// with their sum.
///
/// Entry `k` equals `2 + k - offset`, computed in signed arithmetic so it may
/// be negative.
#[inline(always)]
fn build_weights_rev(period: usize, offset: usize) -> (Vec<f64>, f64) {
    let taps = period - 1;

    let weights: Vec<f64> = (0..taps)
        .map(|k| (period as isize - (taps - 1 - k) as isize - offset as isize) as f64)
        .collect();
    // Explicit left fold from +0.0 keeps the same accumulation as a running sum.
    let total = weights.iter().fold(0.0, |acc, &w| acc + w);

    (weights, total)
}
441
/// AVX2 + FMA EPMA kernel processing four taps per vector operation.
///
/// Weights are generated on the fly as `2 - offset + tap_index`; a masked
/// load handles the final partial group of taps.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA. Arguments must keep every window
/// `[j + 1 - (period - 1), j]` inside `data` and `out.len() >= data.len()`,
/// because loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn epma_avx2(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 4;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    // Lanes with an all-ones mask element are loaded; the rest read as zero.
    let tail_mask = match rem {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let weight_sum = taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);
    let inv = 1.0 / weight_sum;

    let lane_ramp = _mm256_setr_pd(0.0, 1.0, 2.0, 3.0);

    let warmup = first_valid + period + offset + 1;
    for j in warmup..data.len() {
        let window = data.as_ptr().add(j + 1 - taps);

        let mut acc = _mm256_setzero_pd();

        for blk in 0..full_blocks {
            let tap = blk * LANES;
            let w = _mm256_add_pd(_mm256_set1_pd(base_weight + tap as f64), lane_ramp);
            let d = _mm256_loadu_pd(window.add(tap));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        if rem != 0 {
            let tap = full_blocks * LANES;
            let w_tail = _mm256_add_pd(_mm256_set1_pd(base_weight + tap as f64), lane_ramp);
            let d_tail = _mm256_maskload_pd(window.add(tap), tail_mask);
            acc = _mm256_fmadd_pd(d_tail, w_tail, acc);
        }

        // Horizontal sum: fold the upper half onto the lower, then the two lanes.
        let upper = _mm256_extractf128_pd(acc, 1);
        let lower = _mm256_castpd256_pd128(acc);
        let pair = _mm_add_pd(upper, lower);
        let total = _mm_add_pd(pair, _mm_unpackhi_pd(pair, pair));
        *out.get_unchecked_mut(j) = _mm_cvtsd_f64(total) * inv;
    }
}
497
/// AVX-512 EPMA entry point: periods up to 32 use the simple kernel, longer
/// periods use the four-accumulator kernel.
///
/// # Safety
///
/// Same requirements as the two kernels it forwards to (CPU support for the
/// enabled AVX-512 features and in-bounds windows).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn epma_avx512(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    match period {
        0..=32 => epma_avx512_short(data, period, offset, first_valid, out),
        _ => epma_avx512_long(data, period, offset, first_valid, out),
    }
}
/// AVX-512 EPMA kernel with a single accumulator, eight taps per vector
/// operation; a zero-masked load handles the final partial group.
///
/// # Safety
///
/// The CPU must support the enabled AVX-512 features. Arguments must keep
/// every window `[j + 1 - (period - 1), j]` inside `data` and
/// `out.len() >= data.len()`, because loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn epma_avx512_short(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    // Low `rem` bits set: only the leftover taps are loaded.
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let weight_sum = taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);
    let inv = 1.0 / weight_sum;

    let lane_ramp = _mm512_setr_pd(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

    let warmup = first_valid + period + offset + 1;
    for j in warmup..data.len() {
        let window = data.as_ptr().add(j + 1 - taps);
        let mut acc = _mm512_setzero_pd();

        for blk in 0..full_blocks {
            let tap = blk * LANES;
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + tap as f64), lane_ramp);
            let d = _mm512_loadu_pd(window.add(tap));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        if rem != 0 {
            let tap = full_blocks * LANES;
            let w_tail = _mm512_add_pd(_mm512_set1_pd(base_weight + tap as f64), lane_ramp);
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, window.add(tap));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        *out.get_unchecked_mut(j) = _mm512_reduce_add_pd(acc) * inv;
    }
}
/// AVX-512 EPMA kernel with four independent accumulators.
///
/// Groups of eight taps are consumed four at a time; leftover full groups and
/// the masked partial group are folded into the first accumulator before the
/// four are summed.
///
/// # Safety
///
/// The CPU must support the enabled AVX-512 features. Arguments must keep
/// every window `[j + 1 - (period - 1), j]` inside `data` and
/// `out.len() >= data.len()`, because loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn epma_avx512_long(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    const UNROLL: usize = 4;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    // Largest multiple of UNROLL not exceeding `full_blocks`.
    let unrolled_blocks = full_blocks & !(UNROLL - 1);
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let weight_sum = taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);
    let inv = 1.0 / weight_sum;

    let lane_ramp = _mm512_setr_pd(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

    let warmup = first_valid + period + offset + 1;
    for j in warmup..data.len() {
        let window = data.as_ptr().add(j + 1 - taps);

        let mut acc0 = _mm512_setzero_pd();
        let mut acc1 = _mm512_setzero_pd();
        let mut acc2 = _mm512_setzero_pd();
        let mut acc3 = _mm512_setzero_pd();

        let mut blk = 0usize;
        while blk < unrolled_blocks {
            let tap0 = blk * LANES;
            let tap1 = (blk + 1) * LANES;
            let tap2 = (blk + 2) * LANES;
            let tap3 = (blk + 3) * LANES;

            let w0 = _mm512_add_pd(_mm512_set1_pd(base_weight + tap0 as f64), lane_ramp);
            let w1 = _mm512_add_pd(_mm512_set1_pd(base_weight + tap1 as f64), lane_ramp);
            let w2 = _mm512_add_pd(_mm512_set1_pd(base_weight + tap2 as f64), lane_ramp);
            let w3 = _mm512_add_pd(_mm512_set1_pd(base_weight + tap3 as f64), lane_ramp);

            acc0 = _mm512_fmadd_pd(_mm512_loadu_pd(window.add(tap0)), w0, acc0);
            acc1 = _mm512_fmadd_pd(_mm512_loadu_pd(window.add(tap1)), w1, acc1);
            acc2 = _mm512_fmadd_pd(_mm512_loadu_pd(window.add(tap2)), w2, acc2);
            acc3 = _mm512_fmadd_pd(_mm512_loadu_pd(window.add(tap3)), w3, acc3);

            blk += UNROLL;
        }

        // Remaining full groups (fewer than UNROLL of them).
        while blk < full_blocks {
            let tap = blk * LANES;
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + tap as f64), lane_ramp);
            let d = _mm512_loadu_pd(window.add(tap));
            acc0 = _mm512_fmadd_pd(d, w, acc0);
            blk += 1;
        }

        if rem != 0 {
            let tap = full_blocks * LANES;
            let w_tail = _mm512_add_pd(_mm512_set1_pd(base_weight + tap as f64), lane_ramp);
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, window.add(tap));
            acc0 = _mm512_fmadd_pd(d_tail, w_tail, acc0);
        }

        let combined = _mm512_add_pd(_mm512_add_pd(acc0, acc1), _mm512_add_pd(acc2, acc3));
        *out.get_unchecked_mut(j) = _mm512_reduce_add_pd(combined) * inv;
    }
}
633
/// Incremental EPMA state updated one sample at a time.
#[derive(Clone, Debug)]
pub struct EpmaStream {
    // Validated parameters; `p1` caches `period - 1`, the number of taps.
    period: usize,
    offset: usize,
    p1: usize,

    // Ring buffer of the last `period` samples; `head` is the next write slot.
    buffer: Vec<f64>,
    head: usize,

    // `seen` counts every update; `included` counts samples folded into the
    // running sums until the window holds `p1` of them.
    seen: usize,
    included: usize,

    // Running plain sum and index-weighted sum of the window, each paired
    // with its Kahan compensation term.
    sum: f64,
    sum_c: f64,
    ramp: f64,
    ramp_c: f64,

    // Weight of the oldest tap (`2 - offset`) and reciprocal of the weight sum.
    c0: f64,
    inv_wsum: f64,
}
654
/// Adds `x` to `*sum` using Kahan compensated summation.
///
/// `*c` carries the rounding error from earlier additions and is updated
/// together with the sum.
#[inline(always)]
fn kahan_add(sum: &mut f64, c: &mut f64, x: f64) {
    let previous = *sum;
    let corrected = x - *c;
    let updated = previous + corrected;
    // What was actually added minus what should have been added.
    *c = (updated - previous) - corrected;
    *sum = updated;
}
662
663impl EpmaStream {
664    pub fn try_new(params: EpmaParams) -> Result<Self, EpmaError> {
665        let period = params.period.unwrap_or(11);
666        let offset = params.offset.unwrap_or(4);
667
668        if period < 2 {
669            return Err(EpmaError::InvalidPeriod {
670                period,
671                data_len: 0,
672            });
673        }
674        if offset >= period {
675            return Err(EpmaError::InvalidOffset { offset });
676        }
677
678        let p1 = period - 1;
679        let c0 = 2.0 - (offset as f64);
680
681        let p1f = p1 as f64;
682        let wsum = p1f.mul_add(c0, 0.5 * (p1f - 1.0) * p1f);
683        let inv_wsum = 1.0 / wsum;
684
685        Ok(Self {
686            period,
687            offset,
688            p1,
689
690            buffer: vec![0.0; period],
691            head: 0,
692
693            seen: 0,
694            included: 0,
695
696            sum: 0.0,
697            sum_c: 0.0,
698            ramp: 0.0,
699            ramp_c: 0.0,
700
701            c0,
702            inv_wsum,
703        })
704    }
705
706    #[inline(always)]
707    pub fn update(&mut self, value: f64) -> Option<f64> {
708        let p = self.period;
709        let p1m1 = (self.p1 - 1) as f64;
710
711        let idx_out = (self.head + 1) % p;
712        let x_out = if self.included == self.p1 {
713            self.buffer[idx_out]
714        } else {
715            0.0
716        };
717
718        self.buffer[self.head] = value;
719        self.head = (self.head + 1) % p;
720        self.seen += 1;
721
722        if self.included < self.p1 {
723            let m = self.included as f64;
724
725            kahan_add(&mut self.sum, &mut self.sum_c, value);
726
727            kahan_add(&mut self.ramp, &mut self.ramp_c, m * value);
728
729            self.included += 1;
730        } else {
731            let s_old = self.sum;
732
733            let delta_s = value - x_out;
734            kahan_add(&mut self.sum, &mut self.sum_c, delta_s);
735
736            let delta_r = p1m1.mul_add(value, x_out - s_old);
737            kahan_add(&mut self.ramp, &mut self.ramp_c, delta_r);
738        }
739
740        if self.seen <= (self.period + self.offset + 1) {
741            return Some(value);
742        }
743
744        let num = self.c0.mul_add(self.sum, self.ramp);
745        Some(num * self.inv_wsum)
746    }
747}
748
/// Parameter sweep for batch runs; each axis is `(start, end, step)`.
#[derive(Debug, Clone)]
pub struct EpmaBatchRange {
    pub period: (usize, usize, usize),
    pub offset: (usize, usize, usize),
}

impl Default for EpmaBatchRange {
    /// Periods 11 through 260 in steps of 1, with the offset fixed at 4.
    fn default() -> Self {
        let period = (11, 260, 1);
        let offset = (4, 4, 0);
        EpmaBatchRange { period, offset }
    }
}
762#[derive(Clone, Debug, Default)]
763pub struct EpmaBatchBuilder {
764    range: EpmaBatchRange,
765    kernel: Kernel,
766}
767impl EpmaBatchBuilder {
768    pub fn new() -> Self {
769        Self::default()
770    }
771    pub fn kernel(mut self, k: Kernel) -> Self {
772        self.kernel = k;
773        self
774    }
775    #[inline]
776    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
777        self.range.period = (start, end, step);
778        self
779    }
780    #[inline]
781    pub fn period_static(mut self, p: usize) -> Self {
782        self.range.period = (p, p, 0);
783        self
784    }
785    #[inline]
786    pub fn offset_range(mut self, start: usize, end: usize, step: usize) -> Self {
787        self.range.offset = (start, end, step);
788        self
789    }
790    #[inline]
791    pub fn offset_static(mut self, o: usize) -> Self {
792        self.range.offset = (o, o, 0);
793        self
794    }
795    pub fn apply_slice(self, data: &[f64]) -> Result<EpmaBatchOutput, EpmaError> {
796        epma_batch_with_kernel(data, &self.range, self.kernel)
797    }
798    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EpmaBatchOutput, EpmaError> {
799        EpmaBatchBuilder::new().kernel(k).apply_slice(data)
800    }
801    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EpmaBatchOutput, EpmaError> {
802        let slice = source_type(c, src);
803        self.apply_slice(slice)
804    }
805    pub fn with_default_candles(c: &Candles) -> Result<EpmaBatchOutput, EpmaError> {
806        EpmaBatchBuilder::new()
807            .kernel(Kernel::Auto)
808            .apply_candles(c, "close")
809    }
810}
811
812#[derive(Clone, Debug)]
813pub struct EpmaBatchOutput {
814    pub values: Vec<f64>,
815    pub combos: Vec<EpmaParams>,
816    pub rows: usize,
817    pub cols: usize,
818}
819impl EpmaBatchOutput {
820    pub fn row_for_params(&self, p: &EpmaParams) -> Option<usize> {
821        self.combos.iter().position(|c| {
822            c.period.unwrap_or(11) == p.period.unwrap_or(11)
823                && c.offset.unwrap_or(4) == p.offset.unwrap_or(4)
824        })
825    }
826    pub fn values_for(&self, p: &EpmaParams) -> Option<&[f64]> {
827        self.row_for_params(p).map(|row| {
828            let start = row * self.cols;
829            &self.values[start..start + self.cols]
830        })
831    }
832}
833
834#[inline(always)]
835fn expand_grid(r: &EpmaBatchRange) -> Vec<EpmaParams> {
836    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
837        if step == 0 {
838            return vec![start];
839        }
840        let (lo, hi) = if start <= end {
841            (start, end)
842        } else {
843            (end, start)
844        };
845        (lo..=hi).step_by(step).collect()
846    }
847    let periods = axis_usize(r.period);
848    let offsets = axis_usize(r.offset);
849    let mut out = Vec::with_capacity(periods.len() * offsets.len());
850    for &p in &periods {
851        for &o in &offsets {
852            out.push(EpmaParams {
853                period: Some(p),
854                offset: Some(o),
855            });
856        }
857    }
858    out
859}
860
861#[inline(always)]
862pub fn epma_batch_with_kernel(
863    data: &[f64],
864    sweep: &EpmaBatchRange,
865    k: Kernel,
866) -> Result<EpmaBatchOutput, EpmaError> {
867    let kernel = match k {
868        Kernel::Auto => detect_best_batch_kernel(),
869        other if other.is_batch() => other,
870        other => return Err(EpmaError::InvalidKernelForBatch(other)),
871    };
872    let simd = match kernel {
873        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
874        Kernel::Avx512Batch => Kernel::Avx512,
875        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
876        Kernel::Avx2Batch => Kernel::Avx2,
877        Kernel::ScalarBatch => Kernel::Scalar,
878        _ => unreachable!(),
879    };
880
881    epma_batch_par_slice(data, sweep, simd)
882}
883
884#[inline(always)]
885pub fn epma_batch_slice(
886    data: &[f64],
887    sweep: &EpmaBatchRange,
888    kern: Kernel,
889) -> Result<EpmaBatchOutput, EpmaError> {
890    epma_batch_inner(data, sweep, kern, false)
891}
892#[inline(always)]
893pub fn epma_batch_par_slice(
894    data: &[f64],
895    sweep: &EpmaBatchRange,
896    kern: Kernel,
897) -> Result<EpmaBatchOutput, EpmaError> {
898    epma_batch_inner(data, sweep, kern, true)
899}
900#[inline(always)]
901fn epma_batch_inner(
902    data: &[f64],
903    sweep: &EpmaBatchRange,
904    kern: Kernel,
905    parallel: bool,
906) -> Result<EpmaBatchOutput, EpmaError> {
907    let combos = expand_grid(sweep);
908    let rows = combos.len();
909    let cols = data.len();
910
911    let _total_cells = rows
912        .checked_mul(cols)
913        .ok_or(EpmaError::SizeOverflow { rows, cols })?;
914
915    let mut buf_mu = make_uninit_matrix(rows, cols);
916
917    let combos = epma_batch_inner_into_uninit(data, sweep, kern, parallel, &mut buf_mu)?;
918
919    let values = unsafe {
920        let mut buf_guard = ManuallyDrop::new(buf_mu);
921        Vec::from_raw_parts(
922            buf_guard.as_mut_ptr() as *mut f64,
923            buf_guard.len(),
924            buf_guard.capacity(),
925        )
926    };
927
928    Ok(EpmaBatchOutput {
929        values,
930        combos,
931        rows,
932        cols,
933    })
934}
935
/// Runs a batch sweep directly into a caller-provided `f64` buffer.
///
/// `out` is reinterpreted as a `MaybeUninit<f64>` slice of the same length
/// and forwarded to `epma_batch_inner_into_uninit`; the parameter
/// combinations are returned in row order.
#[cfg(any(feature = "python", feature = "wasm"))]
#[inline(always)]
pub fn epma_batch_inner_into(
    data: &[f64],
    sweep: &EpmaBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<EpmaParams>, EpmaError> {
    let len = out.len();
    let base = out.as_mut_ptr() as *mut MaybeUninit<f64>;
    // SAFETY: `base`/`len` describe exactly the memory of `out`, which stays
    // exclusively borrowed for the duration of this call.
    let cells = unsafe { std::slice::from_raw_parts_mut(base, len) };
    epma_batch_inner_into_uninit(data, sweep, kern, parallel, cells)
}
950
951#[inline(always)]
952fn epma_batch_inner_into_uninit(
953    data: &[f64],
954    sweep: &EpmaBatchRange,
955    kern: Kernel,
956    parallel: bool,
957    buf_mu: &mut [MaybeUninit<f64>],
958) -> Result<Vec<EpmaParams>, EpmaError> {
959    if data.is_empty() {
960        return Err(EpmaError::EmptyInputData);
961    }
962    let combos = expand_grid(sweep);
963    if combos.is_empty() {
964        return Err(EpmaError::InvalidRange {
965            start: 0,
966            end: 0,
967            step: 0,
968        });
969    }
970    for c in &combos {
971        let p = c.period.unwrap();
972        let o = c.offset.unwrap();
973        if p < 2 {
974            return Err(EpmaError::InvalidPeriod {
975                period: p,
976                data_len: data.len(),
977            });
978        }
979        if o >= p {
980            return Err(EpmaError::InvalidOffset { offset: o });
981        }
982    }
983    let first = data
984        .iter()
985        .position(|x| !x.is_nan())
986        .ok_or(EpmaError::AllValuesNaN)?;
987    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
988    let max_off = combos.iter().map(|c| c.offset.unwrap()).max().unwrap();
989    let needed = max_p + max_off + 1;
990    if data.len() - first < needed {
991        return Err(EpmaError::NotEnoughValidData {
992            needed,
993            valid: data.len() - first,
994        });
995    }
996    let cols = data.len();
997
998    let warm: Vec<usize> = combos
999        .iter()
1000        .map(|c| first + c.period.unwrap() + c.offset.unwrap() + 1)
1001        .collect();
1002    init_matrix_prefixes(buf_mu, cols, &warm);
1003
1004    let use_prefix = {
1005        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1006        {
1007            matches!(kern, Kernel::Avx2 | Kernel::Avx512)
1008        }
1009        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1010        {
1011            false
1012        }
1013    };
1014
1015    if use_prefix {
1016        let mut p: Vec<f64> = Vec::with_capacity(cols);
1017        let mut q: Vec<f64> = Vec::with_capacity(cols);
1018        let mut ps = 0.0f64;
1019        let mut qs = 0.0f64;
1020        for (idx, &x) in data.iter().enumerate() {
1021            ps += x;
1022            qs = (idx as f64).mul_add(x, qs);
1023            p.push(ps);
1024            q.push(qs);
1025        }
1026
1027        let do_row_prefix = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1028            let period = combos[row].period.unwrap();
1029            let offset = combos[row].offset.unwrap();
1030            let p1 = period - 1;
1031            let c0 = 2.0 - (offset as f64);
1032            let p1f = p1 as f64;
1033            let wsum = p1f.mul_add(c0, 0.5 * (p1f - 1.0) * p1f);
1034            let inv = 1.0 / wsum;
1035            let warmup = warm[row];
1036
1037            let dst = unsafe {
1038                core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1039            };
1040
1041            for j in warmup..cols {
1042                let a = j + 1 - p1;
1043                let b = j;
1044
1045                let sumx = if a > 0 { p[b] - p[a - 1] } else { p[b] };
1046
1047                let sum_abs = if a > 0 { q[b] - q[a - 1] } else { q[b] };
1048
1049                let sum_ix = sum_abs - (a as f64) * sumx;
1050                let y = (c0 * sumx + sum_ix) * inv;
1051
1052                unsafe {
1053                    *dst.get_unchecked_mut(j) = y;
1054                }
1055            }
1056        };
1057
1058        if parallel {
1059            #[cfg(not(target_arch = "wasm32"))]
1060            {
1061                buf_mu
1062                    .par_chunks_mut(cols)
1063                    .enumerate()
1064                    .for_each(|(row, slice)| do_row_prefix(row, slice));
1065            }
1066            #[cfg(target_arch = "wasm32")]
1067            {
1068                for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1069                    do_row_prefix(row, slice);
1070                }
1071            }
1072        } else {
1073            for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1074                do_row_prefix(row, slice);
1075            }
1076        }
1077    } else {
1078        let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1079            let period = combos[row].period.unwrap();
1080            let offset = combos[row].offset.unwrap();
1081
1082            let dst =
1083                core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1084
1085            match kern {
1086                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1087                Kernel::Avx512 => epma_row_avx512(data, first, period, offset, dst),
1088                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1089                Kernel::Avx2 => epma_row_avx2(data, first, period, offset, dst),
1090                _ => epma_row_scalar(data, first, period, offset, dst),
1091            }
1092        };
1093
1094        if parallel {
1095            #[cfg(not(target_arch = "wasm32"))]
1096            {
1097                buf_mu
1098                    .par_chunks_mut(cols)
1099                    .enumerate()
1100                    .for_each(|(row, slice)| do_row(row, slice));
1101            }
1102            #[cfg(target_arch = "wasm32")]
1103            {
1104                for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1105                    do_row(row, slice);
1106                }
1107            }
1108        } else {
1109            for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1110                do_row(row, slice);
1111            }
1112        }
1113    }
1114
1115    Ok(combos)
1116}
1117
1118#[inline(always)]
1119unsafe fn epma_row_scalar(
1120    data: &[f64],
1121    first: usize,
1122    period: usize,
1123    offset: usize,
1124    out: &mut [f64],
1125) {
1126    epma_scalar(data, period, offset, first, out);
1127}
1128
/// AVX2 row kernel used by the batch dispatcher.
///
/// Delegates to `epma_avx2`, moving `first` from the second position to the
/// fourth to match the callee's parameter order.
///
/// # Safety
///
/// Compiled with the `avx2` and `fma` target features enabled; the caller must
/// only invoke it on a CPU that supports both.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn epma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    offset: usize,
    out: &mut [f64],
) {
    epma_avx2(data, period, offset, first, out);
}
1134
/// AVX-512 row kernel used by the batch dispatcher.
///
/// Delegates to `epma_avx512`, moving `first` from the second position to the
/// fourth to match the callee's parameter order.
///
/// # Safety
///
/// Compiled with the `avx512f`, `avx512dq` and `fma` target features enabled;
/// the caller must only invoke it on a CPU that supports all three.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn epma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    offset: usize,
    out: &mut [f64],
) {
    epma_avx512(data, period, offset, first, out);
}
1146
1147#[cfg(test)]
1148mod tests {
1149    use super::*;
1150    use crate::skip_if_unsupported;
1151    use crate::utilities::data_loader::read_candles_from_csv;
1152    use std::error::Error;
1153
1154    #[test]
1155    fn test_epma_into_matches_api() -> Result<(), Box<dyn Error>> {
1156        let mut data = Vec::with_capacity(256);
1157        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
1158        for i in 0..253u32 {
1159            let v = (i as f64).sin() * 10.0 + (i as f64) * 0.01;
1160            data.push(v);
1161        }
1162
1163        let input = EpmaInput::from_slice(&data, EpmaParams::default());
1164
1165        let baseline = epma(&input)?.values;
1166
1167        let mut out = vec![0.0; data.len()];
1168        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1169        {
1170            epma_into(&input, &mut out)?;
1171        }
1172
1173        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1174            (a.is_nan() && b.is_nan()) || (a == b)
1175        }
1176
1177        assert_eq!(baseline.len(), out.len());
1178        for i in 0..out.len() {
1179            assert!(
1180                eq_or_both_nan(baseline[i], out[i]),
1181                "Mismatch at {}: baseline={} out={}",
1182                i,
1183                baseline[i],
1184                out[i]
1185            );
1186        }
1187        Ok(())
1188    }
1189
1190    fn check_epma_partial_params(
1191        test_name: &str,
1192        kernel: Kernel,
1193    ) -> Result<(), Box<dyn std::error::Error>> {
1194        skip_if_unsupported!(kernel, test_name);
1195        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1196        let candles = read_candles_from_csv(file_path)?;
1197
1198        let default_params = EpmaParams {
1199            period: None,
1200            offset: None,
1201        };
1202        let input = EpmaInput::from_candles(&candles, "close", default_params);
1203        let output = epma_with_kernel(&input, kernel)?;
1204        assert_eq!(output.values.len(), candles.close.len());
1205        Ok(())
1206    }
1207    fn check_epma_accuracy(
1208        test_name: &str,
1209        kernel: Kernel,
1210    ) -> Result<(), Box<dyn std::error::Error>> {
1211        skip_if_unsupported!(kernel, test_name);
1212        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1213        let candles = read_candles_from_csv(file_path)?;
1214        let default_params = EpmaParams::default();
1215        let input = EpmaInput::from_candles(&candles, "close", default_params);
1216        let result = epma_with_kernel(&input, kernel)?;
1217        let expected_last_five = [59174.48, 59201.04, 59167.60, 59200.32, 59117.04];
1218        let start_index = result.values.len().saturating_sub(5);
1219        let result_last_five = &result.values[start_index..];
1220        for (i, &value) in result_last_five.iter().enumerate() {
1221            assert!(
1222                (value - expected_last_five[i]).abs() < 1e-1,
1223                "[{}] EPMA {:?} mismatch at idx {}: got {}, expected {}",
1224                test_name,
1225                kernel,
1226                i,
1227                value,
1228                expected_last_five[i]
1229            );
1230        }
1231        Ok(())
1232    }
1233    fn check_epma_default_candles(
1234        test_name: &str,
1235        kernel: Kernel,
1236    ) -> Result<(), Box<dyn std::error::Error>> {
1237        skip_if_unsupported!(kernel, test_name);
1238        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1239        let candles = read_candles_from_csv(file_path)?;
1240        let input = EpmaInput::with_default_candles(&candles);
1241        match input.data {
1242            EpmaData::Candles { source, .. } => assert_eq!(source, "close"),
1243            _ => panic!("Expected EpmaData::Candles"),
1244        }
1245        let output = epma_with_kernel(&input, kernel)?;
1246        assert_eq!(output.values.len(), candles.close.len());
1247        Ok(())
1248    }
1249    fn check_epma_zero_period(
1250        test_name: &str,
1251        kernel: Kernel,
1252    ) -> Result<(), Box<dyn std::error::Error>> {
1253        skip_if_unsupported!(kernel, test_name);
1254        let input_data = [10.0, 20.0, 30.0];
1255        let params = EpmaParams {
1256            period: Some(0),
1257            offset: None,
1258        };
1259        let input = EpmaInput::from_slice(&input_data, params);
1260        let res = epma_with_kernel(&input, kernel);
1261        assert!(
1262            res.is_err(),
1263            "[{}] EPMA should fail with zero period",
1264            test_name
1265        );
1266        Ok(())
1267    }
1268    fn check_epma_period_exceeds_length(
1269        test_name: &str,
1270        kernel: Kernel,
1271    ) -> Result<(), Box<dyn std::error::Error>> {
1272        skip_if_unsupported!(kernel, test_name);
1273        let data_small = [10.0, 20.0, 30.0];
1274        let params = EpmaParams {
1275            period: Some(10),
1276            offset: None,
1277        };
1278        let input = EpmaInput::from_slice(&data_small, params);
1279        let res = epma_with_kernel(&input, kernel);
1280        assert!(
1281            res.is_err(),
1282            "[{}] EPMA should fail with period exceeding length",
1283            test_name
1284        );
1285        Ok(())
1286    }
1287    fn check_epma_very_small_dataset(
1288        test_name: &str,
1289        kernel: Kernel,
1290    ) -> Result<(), Box<dyn std::error::Error>> {
1291        skip_if_unsupported!(kernel, test_name);
1292        let single_point = [42.0];
1293        let params = EpmaParams {
1294            period: Some(9),
1295            offset: None,
1296        };
1297        let input = EpmaInput::from_slice(&single_point, params);
1298        let res = epma_with_kernel(&input, kernel);
1299        assert!(
1300            res.is_err(),
1301            "[{}] EPMA should fail with insufficient data",
1302            test_name
1303        );
1304        Ok(())
1305    }
1306    fn check_epma_empty_input(
1307        test_name: &str,
1308        kernel: Kernel,
1309    ) -> Result<(), Box<dyn std::error::Error>> {
1310        skip_if_unsupported!(kernel, test_name);
1311        let empty: [f64; 0] = [];
1312        let input = EpmaInput::from_slice(&empty, EpmaParams::default());
1313        let res = epma_with_kernel(&input, kernel);
1314        assert!(
1315            matches!(res, Err(EpmaError::EmptyInputData)),
1316            "[{}] EPMA should fail with empty input",
1317            test_name
1318        );
1319        Ok(())
1320    }
1321    fn check_epma_invalid_offset(
1322        test_name: &str,
1323        kernel: Kernel,
1324    ) -> Result<(), Box<dyn std::error::Error>> {
1325        skip_if_unsupported!(kernel, test_name);
1326        let data = [1.0, 2.0, 3.0, 4.0];
1327        let params = EpmaParams {
1328            period: Some(3),
1329            offset: Some(3),
1330        };
1331        let input = EpmaInput::from_slice(&data, params);
1332        let res = epma_with_kernel(&input, kernel);
1333        assert!(
1334            matches!(res, Err(EpmaError::InvalidOffset { .. })),
1335            "[{}] EPMA should fail with invalid offset",
1336            test_name
1337        );
1338        Ok(())
1339    }
1340    fn check_epma_property(
1341        test_name: &str,
1342        kernel: Kernel,
1343    ) -> Result<(), Box<dyn std::error::Error>> {
1344        use proptest::prelude::*;
1345        skip_if_unsupported!(kernel, test_name);
1346
1347        let strat = (2usize..=50).prop_flat_map(|period| {
1348            (
1349                prop::collection::vec(
1350                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1351                    (period * 2 + 10)..500,
1352                ),
1353                Just(period),
1354                0usize..period,
1355            )
1356        });
1357
1358        proptest::test_runner::TestRunner::default()
1359			.run(&strat, |(data, period, offset)| {
1360				let params = EpmaParams {
1361					period: Some(period),
1362					offset: Some(offset),
1363				};
1364				let input = EpmaInput::from_slice(&data, params);
1365
1366
1367				let EpmaOutput { values: out } = epma_with_kernel(&input, kernel).unwrap();
1368
1369
1370				let EpmaOutput { values: ref_out } = epma_with_kernel(&input, Kernel::Scalar).unwrap();
1371
1372
1373				let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1374
1375
1376				let warmup = first_valid + period + offset + 1;
1377
1378
1379				for i in 0..warmup.min(out.len()) {
1380					prop_assert!(
1381						out[i].is_nan(),
1382						"[{}] Expected NaN during warmup at index {}, got {}",
1383						test_name,
1384						i,
1385						out[i]
1386					);
1387				}
1388
1389
1390
1391				if warmup < out.len() && data[warmup].is_finite() {
1392
1393					let p1 = period - 1;
1394					let mut weight_sum = 0.0;
1395					for i in 0..p1 {
1396						let w = (period as i32 - i as i32 - offset as i32) as f64;
1397						weight_sum += w;
1398					}
1399
1400
1401					if weight_sum.abs() > 1e-10 {
1402						prop_assert!(
1403							!out[warmup].is_nan(),
1404							"[{}] Expected valid value at warmup index {}, got NaN",
1405							test_name,
1406							warmup
1407						);
1408					}
1409				}
1410
1411
1412
1413				let p1 = period - 1;
1414				let mut weight_sum = 0.0;
1415				for i in 0..p1 {
1416					let w = (period as i32 - i as i32 - offset as i32) as f64;
1417					weight_sum += w;
1418				}
1419
1420				if weight_sum.abs() > 1e-10 {
1421
1422					for i in warmup..data.len() {
1423						let y = out[i];
1424						prop_assert!(
1425							y.is_finite(),
1426							"[{}] EPMA output at index {} is not finite: {} (period={}, offset={}, weight_sum={})",
1427							test_name,
1428							i,
1429							y,
1430							period,
1431							offset,
1432							weight_sum
1433						);
1434					}
1435				} else {
1436
1437
1438					for i in warmup..data.len() {
1439						let both_nan = out[i].is_nan() && ref_out[i].is_nan();
1440						let both_inf = out[i].is_infinite() && ref_out[i].is_infinite();
1441						prop_assert!(
1442							both_nan || both_inf,
1443							"[{}] With weight_sum=0, expected consistent NaN or Inf at index {} but got: kernel={}, scalar={} (period={}, offset={})",
1444							test_name,
1445							i,
1446							out[i],
1447							ref_out[i],
1448							period,
1449							offset
1450						);
1451					}
1452				}
1453
1454
1455
1456
1457				if period == 2 && offset == 0 && warmup < data.len() {
1458
1459					for i in warmup..data.len() {
1460						if data[i].is_finite() {
1461
1462							prop_assert!(
1463								(out[i] - data[i]).abs() < 1e-9,
1464								"[{}] Period=2,offset=0 mismatch at {}: got {}, expected {}",
1465								test_name,
1466								i,
1467								out[i],
1468								data[i]
1469							);
1470						}
1471					}
1472				}
1473
1474
1475
1476				if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && data.iter().any(|x| x.is_finite() && x.abs() > 1e-10) {
1477					let constant = *data.iter().find(|x| x.is_finite()).unwrap();
1478
1479					let p1 = period - 1;
1480					let mut weight_sum = 0.0;
1481					for i in 0..p1 {
1482						let w = (period as i32 - i as i32 - offset as i32) as f64;
1483						weight_sum += w;
1484					}
1485
1486					if weight_sum.abs() > 1e-10 {
1487						for i in warmup..data.len() {
1488							prop_assert!(
1489								(out[i] - constant).abs() < 1e-9,
1490								"[{}] Constant data mismatch at {}: got {}, expected {}",
1491								test_name,
1492								i,
1493								out[i],
1494								constant
1495							);
1496						}
1497					}
1498				}
1499
1500
1501
1502				for i in (warmup.saturating_add(1))..data.len() {
1503					let y = out[i];
1504					let r = ref_out[i];
1505
1506					if !y.is_finite() || !r.is_finite() {
1507						prop_assert!(
1508							y.to_bits() == r.to_bits(),
1509							"[{}] finite/NaN mismatch at idx {}: {} vs {}",
1510							test_name,
1511							i,
1512							y,
1513							r
1514						);
1515						continue;
1516					}
1517
1518					let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
1519
1520					let rel_error = if r.abs() > 1e-10 {
1521						(y - r).abs() / r.abs()
1522					} else {
1523						(y - r).abs()
1524					};
1525
1526
1527					prop_assert!(
1528						rel_error <= 1e-4 || (y - r).abs() <= 1e-9 || ulp_diff <= 100,
1529						"[{}] Kernel mismatch at idx {}: {} vs {} (ULP={}, rel_err={})",
1530						test_name,
1531						i,
1532						y,
1533						r,
1534						ulp_diff,
1535						rel_error
1536					);
1537				}
1538
1539
1540
1541				if warmup + 5 < data.len() {
1542
1543					let p1 = period - 1;
1544					let mut weights = Vec::with_capacity(p1);
1545					let mut weight_sum = 0.0;
1546					for i in 0..p1 {
1547						let w = (period as i32 - i as i32 - offset as i32) as f64;
1548						weights.push(w);
1549						weight_sum += w;
1550					}
1551
1552
1553					if weight_sum.abs() > 1e-10 {
1554
1555						for idx in [warmup + 1, warmup + 2, data.len() - 1].iter().copied() {
1556							if idx > warmup && idx < data.len() {
1557								let start = idx + 1 - p1;
1558								let mut expected_sum = 0.0;
1559								for i in 0..p1 {
1560									expected_sum += data[start + i] * weights[p1 - 1 - i];
1561								}
1562								let expected = expected_sum / weight_sum;
1563
1564
1565								if out[idx].is_finite() && expected.is_finite() {
1566
1567									let tolerance = if expected.abs() > 1000.0 {
1568										expected.abs() * 1e-12
1569									} else {
1570										1e-9
1571									};
1572									prop_assert!(
1573										(out[idx] - expected).abs() < tolerance,
1574										"[{}] EPMA formula mismatch at {}: got {}, expected {} (diff: {})",
1575										test_name,
1576										idx,
1577										out[idx],
1578										expected,
1579										(out[idx] - expected).abs()
1580									);
1581								} else {
1582
1583									prop_assert!(
1584										out[idx].is_nan() == expected.is_nan() &&
1585										out[idx].is_infinite() == expected.is_infinite(),
1586										"[{}] EPMA formula NaN/Inf mismatch at {}: got {}, expected {}",
1587										test_name,
1588										idx,
1589										out[idx],
1590										expected
1591									);
1592								}
1593							}
1594						}
1595					}
1596				}
1597
1598
1599
1600
1601				if offset == period - 1 && warmup < data.len() && weight_sum.abs() > 1e-10 {
1602
1603					for i in warmup..data.len() {
1604						prop_assert!(
1605							out[i].is_finite(),
1606							"[{}] Edge case offset={} produced non-finite at {}",
1607							test_name,
1608							offset,
1609							i
1610						);
1611					}
1612				}
1613
1614				Ok(())
1615			})
1616			.unwrap();
1617
1618        Ok(())
1619    }
1620    fn check_epma_invalid_params(
1621        test_name: &str,
1622        kernel: Kernel,
1623    ) -> Result<(), Box<dyn std::error::Error>> {
1624        skip_if_unsupported!(kernel, test_name);
1625
1626        let zero_weight_cases = vec![(4, 3), (5, 3), (6, 4), (8, 6)];
1627
1628        for (period, offset) in zero_weight_cases {
1629            let p1 = period - 1;
1630            let mut weight_sum = 0.0;
1631            for i in 0..p1 {
1632                let w = (period as i32 - i as i32 - offset as i32) as f64;
1633                weight_sum += w;
1634            }
1635
1636            let data = vec![1.0; period * 2];
1637            let params = EpmaParams {
1638                period: Some(period),
1639                offset: Some(offset),
1640            };
1641            let input = EpmaInput::from_slice(&data, params);
1642
1643            if weight_sum.abs() < 1e-10 {
1644                let out = epma_with_kernel(&input, kernel)?;
1645                let scalar_out = epma_with_kernel(&input, Kernel::Scalar)?;
1646
1647                let warmup = period + offset + 1;
1648                for i in warmup..data.len() {
1649                    let both_nan = out.values[i].is_nan() && scalar_out.values[i].is_nan();
1650                    let both_inf =
1651                        out.values[i].is_infinite() && scalar_out.values[i].is_infinite();
1652                    assert!(
1653						both_nan || both_inf,
1654						"[{}] Period={}, Offset={} (weight_sum=0) should produce consistent NaN or Inf, got kernel={}, scalar={}",
1655						test_name,
1656						period,
1657						offset,
1658						out.values[i],
1659						scalar_out.values[i]
1660					);
1661                }
1662            }
1663        }
1664
1665        Ok(())
1666    }
1667
1668    fn check_epma_reinput(
1669        test_name: &str,
1670        kernel: Kernel,
1671    ) -> Result<(), Box<dyn std::error::Error>> {
1672        skip_if_unsupported!(kernel, test_name);
1673        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1674        let candles = read_candles_from_csv(file_path)?;
1675        let first_params = EpmaParams {
1676            period: Some(9),
1677            offset: None,
1678        };
1679        let first_input = EpmaInput::from_candles(&candles, "close", first_params);
1680        let first_result = epma_with_kernel(&first_input, kernel)?;
1681        let second_params = EpmaParams {
1682            period: Some(3),
1683            offset: None,
1684        };
1685        let second_input = EpmaInput::from_slice(&first_result.values, second_params);
1686        let second_result = epma_with_kernel(&second_input, kernel)?;
1687        assert_eq!(second_result.values.len(), first_result.values.len());
1688        Ok(())
1689    }
1690    fn check_epma_nan_handling(
1691        test_name: &str,
1692        kernel: Kernel,
1693    ) -> Result<(), Box<dyn std::error::Error>> {
1694        skip_if_unsupported!(kernel, test_name);
1695        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1696        let candles = read_candles_from_csv(file_path)?;
1697        let params = EpmaParams {
1698            period: Some(11),
1699            offset: Some(4),
1700        };
1701        let input = EpmaInput::from_candles(&candles, "close", params.clone());
1702        let res = epma_with_kernel(&input, kernel)?;
1703        assert_eq!(res.values.len(), candles.close.len());
1704        if res.values.len() > 240 {
1705            for (i, &val) in res.values[240..].iter().enumerate() {
1706                assert!(
1707                    !val.is_nan(),
1708                    "[{}] Found unexpected NaN at out-index {}",
1709                    test_name,
1710                    240 + i
1711                );
1712            }
1713        }
1714        Ok(())
1715    }
1716    fn check_epma_streaming(
1717        test_name: &str,
1718        kernel: Kernel,
1719    ) -> Result<(), Box<dyn std::error::Error>> {
1720        skip_if_unsupported!(kernel, test_name);
1721        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1722        let candles = read_candles_from_csv(file_path)?;
1723        let period = 11;
1724        let offset = 4;
1725        let input = EpmaInput::from_candles(
1726            &candles,
1727            "close",
1728            EpmaParams {
1729                period: Some(period),
1730                offset: Some(offset),
1731            },
1732        );
1733        let batch_output = epma_with_kernel(&input, kernel)?.values;
1734        let mut stream = EpmaStream::try_new(EpmaParams {
1735            period: Some(period),
1736            offset: Some(offset),
1737        })?;
1738        let mut stream_values = Vec::with_capacity(candles.close.len());
1739        for &price in &candles.close {
1740            match stream.update(price) {
1741                Some(val) => stream_values.push(val),
1742                None => stream_values.push(f64::NAN),
1743            }
1744        }
1745        assert_eq!(batch_output.len(), stream_values.len());
1746        for (i, (&b, &s)) in batch_output
1747            .iter()
1748            .zip(stream_values.iter())
1749            .enumerate()
1750            .skip(period + offset + 1)
1751        {
1752            if b.is_nan() && s.is_nan() {
1753                continue;
1754            }
1755            let diff = (b - s).abs();
1756            assert!(
1757                diff < 1e-9,
1758                "[{}] EPMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1759                test_name,
1760                i,
1761                b,
1762                s,
1763                diff
1764            );
1765        }
1766        Ok(())
1767    }
1768
1769    #[cfg(debug_assertions)]
1770    fn check_epma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1771        skip_if_unsupported!(kernel, test_name);
1772
1773        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1774        let candles = read_candles_from_csv(file_path)?;
1775
1776        let test_cases = vec![
1777            EpmaParams::default(),
1778            EpmaParams {
1779                period: Some(2),
1780                offset: Some(0),
1781            },
1782            EpmaParams {
1783                period: Some(5),
1784                offset: Some(1),
1785            },
1786            EpmaParams {
1787                period: Some(10),
1788                offset: Some(3),
1789            },
1790            EpmaParams {
1791                period: Some(10),
1792                offset: Some(9),
1793            },
1794            EpmaParams {
1795                period: Some(20),
1796                offset: Some(5),
1797            },
1798            EpmaParams {
1799                period: Some(30),
1800                offset: Some(10),
1801            },
1802            EpmaParams {
1803                period: Some(15),
1804                offset: Some(14),
1805            },
1806        ];
1807
1808        for params in test_cases {
1809            let input = EpmaInput::from_candles(&candles, "close", params.clone());
1810            let output = epma_with_kernel(&input, kernel)?;
1811
1812            for (i, &val) in output.values.iter().enumerate() {
1813                if val.is_nan() {
1814                    continue;
1815                }
1816
1817                let bits = val.to_bits();
1818
1819                if bits == 0x11111111_11111111 {
1820                    panic!(
1821                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1822                        test_name, val, bits, i, params.period, params.offset
1823                    );
1824                }
1825
1826                if bits == 0x22222222_22222222 {
1827                    panic!(
1828                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1829                        test_name, val, bits, i, params.period, params.offset
1830                    );
1831                }
1832
1833                if bits == 0x33333333_33333333 {
1834                    panic!(
1835                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1836                        test_name, val, bits, i, params.period, params.offset
1837                    );
1838                }
1839            }
1840        }
1841
1842        Ok(())
1843    }
1844
1845    #[cfg(not(debug_assertions))]
1846    fn check_epma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1847        Ok(())
1848    }
1849
1850    macro_rules! generate_all_epma_tests {
1851        ($($test_fn:ident),*) => {
1852            paste::paste! {
1853                $(
1854                    #[test]
1855                    fn [<$test_fn _scalar_f64>]() {
1856                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1857                    }
1858                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1859                    #[test]
1860                    fn [<$test_fn _avx2_f64>]() {
1861                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1862                    }
1863                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1864                    #[test]
1865                    fn [<$test_fn _avx512_f64>]() {
1866                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1867                    }
1868                )*
1869
1870                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1871                $(
1872                    #[test]
1873                    fn [<$test_fn _simd128_f64>]() {
1874                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1875                    }
1876                )*
1877            }
1878        }
1879    }
1880    generate_all_epma_tests!(
1881        check_epma_partial_params,
1882        check_epma_accuracy,
1883        check_epma_default_candles,
1884        check_epma_zero_period,
1885        check_epma_period_exceeds_length,
1886        check_epma_very_small_dataset,
1887        check_epma_empty_input,
1888        check_epma_invalid_offset,
1889        check_epma_invalid_params,
1890        check_epma_reinput,
1891        check_epma_nan_handling,
1892        check_epma_streaming,
1893        check_epma_property,
1894        check_epma_no_poison
1895    );
1896
1897    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1898    #[test]
1899    fn test_epma_simd128_correctness() {
1900        let data = vec![
1901            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1902        ];
1903        let period = 5;
1904        let offset = 2;
1905
1906        let params = EpmaParams {
1907            period: Some(period),
1908            offset: Some(offset),
1909        };
1910        let input = EpmaInput::from_slice(&data, params);
1911
1912        let mut scalar_out = vec![0.0; data.len()];
1913        epma_scalar(&data, period, offset, 0, &mut scalar_out);
1914
1915        let simd128_output = epma_with_kernel(&input, Kernel::Scalar).unwrap();
1916
1917        let warmup = period + offset + 1;
1918        for i in warmup..data.len() {
1919            assert!(
1920                (scalar_out[i] - simd128_output.values[i]).abs() < 1e-10,
1921                "SIMD128 mismatch at index {}: scalar={}, simd128={}",
1922                i,
1923                scalar_out[i],
1924                simd128_output.values[i]
1925            );
1926        }
1927    }
1928
1929    fn check_batch_default_row(
1930        test: &str,
1931        kernel: Kernel,
1932    ) -> Result<(), Box<dyn std::error::Error>> {
1933        skip_if_unsupported!(kernel, test);
1934        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1935        let c = read_candles_from_csv(file)?;
1936        let output = EpmaBatchBuilder::new()
1937            .kernel(kernel)
1938            .apply_candles(&c, "close")?;
1939        let def = EpmaParams::default();
1940        let row = output.values_for(&def).expect("default row missing");
1941        assert_eq!(row.len(), c.close.len());
1942        let expected = [59174.48, 59201.04, 59167.60, 59200.32, 59117.04];
1943        let start = row.len() - 5;
1944        for (i, &v) in row[start..].iter().enumerate() {
1945            assert!(
1946                (v - expected[i]).abs() < 1e-1,
1947                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1948            );
1949        }
1950        Ok(())
1951    }
1952
1953    #[cfg(debug_assertions)]
1954    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1955        skip_if_unsupported!(kernel, test);
1956
1957        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1958        let c = read_candles_from_csv(file)?;
1959
1960        let test_configs = vec![
1961            ((2, 5, 1), (0, 2, 1)),
1962            ((10, 20, 5), (0, 19, 3)),
1963            ((20, 30, 2), (5, 15, 5)),
1964            ((15, 25, 5), (10, 14, 2)),
1965            ((5, 10, 1), (0, 9, 1)),
1966        ];
1967
1968        for (period_range, offset_range) in test_configs {
1969            let output = EpmaBatchBuilder::new()
1970                .kernel(kernel)
1971                .period_range(period_range.0, period_range.1, period_range.2)
1972                .offset_range(offset_range.0, offset_range.1, offset_range.2)
1973                .apply_candles(&c, "close")?;
1974
1975            for (idx, &val) in output.values.iter().enumerate() {
1976                if val.is_nan() {
1977                    continue;
1978                }
1979
1980                let bits = val.to_bits();
1981                let row = idx / output.cols;
1982                let col = idx % output.cols;
1983                let params = &output.combos[row];
1984
1985                if bits == 0x11111111_11111111 {
1986                    panic!(
1987                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
1988                        test, val, bits, row, col, params.period, params.offset
1989                    );
1990                }
1991
1992                if bits == 0x22222222_22222222 {
1993                    panic!(
1994                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
1995                        test, val, bits, row, col, params.period, params.offset
1996                    );
1997                }
1998
1999                if bits == 0x33333333_33333333 {
2000                    panic!(
2001                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
2002                        test, val, bits, row, col, params.period, params.offset
2003                    );
2004                }
2005            }
2006        }
2007
2008        Ok(())
2009    }
2010
2011    #[cfg(not(debug_assertions))]
2012    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2013        Ok(())
2014    }
2015
2016    macro_rules! gen_batch_tests {
2017        ($fn_name:ident) => {
2018            paste::paste! {
2019                #[test]
2020                fn [<$fn_name _scalar>]() {
2021                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2022                }
2023                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2024                #[test]
2025                fn [<$fn_name _avx2>]() {
2026                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2027                }
2028                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2029                #[test]
2030                fn [<$fn_name _avx512>]() {
2031                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2032                }
2033                #[test]
2034                fn [<$fn_name _auto_detect>]() {
2035                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2036                }
2037            }
2038        };
2039    }
2040    gen_batch_tests!(check_batch_default_row);
2041    gen_batch_tests!(check_batch_no_poison);
2042
2043    #[test]
2044    fn test_invalid_output_len_error() {
2045        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2046        let params = EpmaParams {
2047            period: Some(3),
2048            offset: Some(1),
2049        };
2050        let input = EpmaInput::from_slice(&data, params);
2051        let mut wrong_size_dst = vec![0.0; 3];
2052
2053        let result = epma_into_slice(&mut wrong_size_dst, &input, Kernel::Scalar);
2054        assert!(result.is_err());
2055
2056        match result {
2057            Err(EpmaError::OutputLengthMismatch { expected, got }) => {
2058                assert_eq!(expected, 5);
2059                assert_eq!(got, 3);
2060            }
2061            _ => panic!("Expected OutputLengthMismatch error"),
2062        }
2063    }
2064
2065    #[test]
2066    fn test_invalid_kernel_error() {
2067        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2068        let sweep = EpmaBatchRange {
2069            period: (3, 5, 1),
2070            offset: (1, 2, 1),
2071        };
2072
2073        let result = epma_batch_with_kernel(&data, &sweep, Kernel::Scalar);
2074        assert!(result.is_err());
2075
2076        match result {
2077            Err(EpmaError::InvalidKernelForBatch(k)) => {
2078                assert_eq!(k, Kernel::Scalar);
2079            }
2080            _ => panic!("Expected InvalidKernelForBatch error"),
2081        }
2082    }
2083}
2084
2085#[cfg(feature = "python")]
2086use crate::utilities::kernel_validation::validate_kernel;
2087#[cfg(all(feature = "python", feature = "cuda"))]
2088use numpy::PyReadonlyArray2;
2089#[cfg(feature = "python")]
2090use numpy::PyUntypedArrayMethods;
2091#[cfg(feature = "python")]
2092use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
2093#[cfg(feature = "python")]
2094use pyo3::{exceptions::PyValueError, prelude::*, types::PyDict};
2095
// Python entry point `epma(data, period, offset, kernel=None)`.
//
// Computes the indicator for one parameter pair and returns a new 1-D NumPy
// array with the same length as `data`. Failures from `data.as_slice()` and
// `validate_kernel` propagate unchanged; a failure from `epma_into_slice` is
// raised as `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "epma")]
#[pyo3(signature = (data, period, offset, kernel=None))]
pub fn epma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    offset: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let input = EpmaInput::from_slice(
        prices,
        EpmaParams {
            period: Some(period),
            offset: Some(offset),
        },
    );
    let chosen_kernel = validate_kernel(kernel, false)?;

    // The output array is created uninitialised and handed to the kernel as a
    // plain slice. NOTE(review): this relies on `epma_into_slice` writing every
    // element on success — confirm.
    let result = unsafe { PyArray1::<f64>::new(py, [prices.len()], false) };
    let dst = unsafe { result.as_slice_mut()? };

    // The numeric work runs with the GIL released.
    match py.allow_threads(|| epma_into_slice(dst, &input, chosen_kernel)) {
        Ok(_) => Ok(result),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2122
// Python entry point `epma_batch(data, period_range, offset_range, kernel=None)`.
//
// Expands the two range triples into a parameter grid, computes one output row
// per combination, and returns a dict with:
//   "values"  - 2-D array of shape (rows, cols), cols == len(data)
//   "periods" - u64 array with the period of each row
//   "offsets" - u64 array with the offset of each row
//
// Errors from `data.as_slice()` and `validate_kernel` propagate unchanged; an
// overflowing `rows * cols` or a failure in `epma_batch_inner_into` is raised
// as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "epma_batch")]
#[pyo3(signature = (data, period_range, offset_range, kernel=None))]
pub fn epma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    offset_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    let slice_in = data.as_slice()?;

    // Reject a bad kernel name before allocating the (possibly large) output.
    let kern = validate_kernel(kernel, true)?;

    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // A wrapped product would under-allocate the output buffer, so the size is
    // computed with an explicit overflow check.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflows usize"))?;

    // Uninitialised flat buffer; reshaped to (rows, cols) once filled.
    // NOTE(review): relies on `epma_batch_inner_into` writing every element on
    // success — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos_done = py
        .allow_threads(|| {
            let actual = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Translate the batch kernel into the per-row kernel that is
            // passed down to the inner routine.
            let simd = match actual {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // Presumably `validate_kernel(_, true)` and
                // `detect_best_batch_kernel` only yield the batch variants
                // above — TODO confirm.
                _ => unreachable!(),
            };
            epma_batch_inner_into(slice_in, &sweep, simd, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos_done
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "offsets",
        combos_done
            .iter()
            .map(|p| p.offset.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2185
// Python entry point `epma_cuda_batch_dev(data_f32, period_range, offset_range,
// device_id=0)`.
//
// Runs the parameter sweep through `CudaEpma` and returns the result wrapped in
// `EpmaDeviceArrayF32Py`. Raises `ValueError` when CUDA is unavailable or when
// any CUDA-side call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "epma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, offset_range, device_id=0))]
pub fn epma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    offset_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<EpmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };

    // Device work runs with the GIL released. The context handle is taken
    // before launching and returned alongside the result so it can be stored
    // with the array.
    let (inner, ctx_guard) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEpma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_array = cuda
            .epma_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev_array, ctx))
    })?;

    Ok(EpmaDeviceArrayF32Py {
        inner,
        _ctx_guard: ctx_guard,
        device_id: device_id as u32,
    })
}
2221
// Python entry point `epma_cuda_many_series_one_param_dev(data_tm_f32, period,
// offset, device_id=0)`.
//
// Applies one (period, offset) pair to a 2-D f32 input via `CudaEpma` and
// returns the result wrapped in `EpmaDeviceArrayF32Py`. Raises `ValueError`
// when CUDA is unavailable or when any CUDA-side call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "epma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, offset, device_id=0))]
pub fn epma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    offset: usize,
    device_id: usize,
) -> PyResult<EpmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = EpmaParams {
        period: Some(period),
        offset: Some(offset),
    };

    // Device work runs with the GIL released. Note the callee takes the
    // dimensions as (cols, rows).
    let (inner, ctx_guard) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaEpma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_array = cuda
            .epma_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev_array, ctx))
    })?;

    Ok(EpmaDeviceArrayF32Py {
        inner,
        _ctx_guard: ctx_guard,
        device_id: device_id as u32,
    })
}
2260
// Python-visible wrapper around a device-resident f32 matrix produced by the
// CUDA EPMA routines. Marked `unsendable` for PyO3.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct EpmaDeviceArrayF32Py {
    // Device buffer together with its row/column counts.
    pub(crate) inner: DeviceArrayF32,
    // Shared handle to the CUDA context obtained from `CudaEpma::context_arc()`.
    // NOTE(review): presumably held only to keep the context alive for as long
    // as `inner` exists — confirm.
    pub(crate) _ctx_guard: Arc<Context>,
    // Device ordinal the buffer was created on; reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2268
// Interop protocol methods exposed to Python: the CUDA Array Interface getter
// and the two DLPack hooks.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EpmaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict describing the buffer as a
    // 2-D little-endian f32 array with row stride `cols * 4` bytes and element
    // stride 4 bytes.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);

        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, flag); NOTE(review): the second element is presumably
        // the interface's read-only flag — confirm against the spec.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`.
    // NOTE(review): 2 is presumably DLPack's `kDLCUDA` device-type code — confirm.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer through `export_f32_cuda_dlpack_2d`.
    //
    // `stream` is accepted but ignored. If `dl_device` is supplied and can be
    // read as an `(i32, i32)` pair that differs from this array's device, a
    // `ValueError` is raised (with a different message when `copy` is truthy).
    // A `dl_device` that cannot be read as such a pair is not checked at all.
    //
    // The device buffer is moved out of `self` and replaced by an empty 0x0
    // placeholder, so after a successful or failed export this object no longer
    // refers to the original data.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // A `copy` value that is missing or not a bool counts as false.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // Explicitly discarded: no stream handling is performed here.
        let _ = stream;

        // Allocate the empty placeholder first so that a failure here leaves
        // `self.inner` untouched.
        let dummy = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2347
// Python class `EpmaStream`: a thin wrapper that owns an `EpmaStream` and
// forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "EpmaStream")]
pub struct EpmaStreamPy {
    // The wrapped streaming calculator.
    stream: EpmaStream,
}
2353
#[cfg(feature = "python")]
#[pymethods]
impl EpmaStreamPy {
    // Constructor: builds the wrapped stream from `period` and `offset`.
    // A failure from `EpmaStream::try_new` is raised as `ValueError` carrying
    // the error's display text.
    #[new]
    fn new(period: usize, offset: usize) -> PyResult<Self> {
        EpmaStream::try_new(EpmaParams {
            period: Some(period),
            offset: Some(offset),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(value)
    }
}
2371
2372#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2373use serde::{Deserialize, Serialize};
2374#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2375use wasm_bindgen::prelude::*;
2376
// WASM export: computes the indicator for `data` with the given `period` and
// `offset`, using the kernel chosen by `detect_best_kernel()`. Returns a new
// vector of the same length as `data`, or the error's display text as a JS
// string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_js(data: &[f64], period: usize, offset: usize) -> Result<Vec<f64>, JsValue> {
    let input = EpmaInput::from_slice(
        data,
        EpmaParams {
            period: Some(period),
            offset: Some(offset),
        },
    );
    let mut values = vec![0.0; data.len()];
    match epma_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2390
/// Configuration object deserialized from JavaScript by the `epma_batch`
/// export. Each triple is copied unchanged into the matching field of
/// `EpmaBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EpmaBatchConfig {
    /// Period sweep triple. NOTE(review): presumably (start, end, step) — confirm.
    pub period_range: (usize, usize, usize),
    /// Offset sweep triple. NOTE(review): presumably (start, end, step) — confirm.
    pub offset_range: (usize, usize, usize),
}
2397
/// Result object serialized back to JavaScript by the `epma_batch` export.
/// All four fields are copied from the batch routine's output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EpmaBatchJsOutput {
    /// Flat buffer of computed values.
    /// NOTE(review): presumably row-major with `rows * cols` entries — confirm.
    pub values: Vec<f64>,
    /// Parameter combinations reported by the batch routine.
    pub combos: Vec<EpmaParams>,
    /// Row count reported by the batch routine.
    pub rows: usize,
    /// Column count reported by the batch routine.
    pub cols: usize,
}
2406
// WASM export `epma_batch`: deserializes `config` into `EpmaBatchConfig`, runs
// the batch routine over `data` with the kernel from `detect_best_kernel()`,
// and serializes the result as `EpmaBatchJsOutput`. Every failure is returned
// as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = epma_batch)]
pub fn epma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let EpmaBatchConfig {
        period_range,
        offset_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };
    let batch = epma_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = EpmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2428
// WASM export: reserves room for `len` f64 values and returns the raw pointer.
// The backing `Vec` is never dropped here, so the memory stays allocated until
// the caller releases it (see `epma_free`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the allocation on
    // purpose so the pointer remains valid after return.
    let mut buf = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
2437
// WASM export: releases a buffer by rebuilding a `Vec<f64>` of length and
// capacity `len` from `ptr` and dropping it. A null `ptr` is ignored.
//
// Caller contract: `ptr` and `len` must describe an allocation that is valid
// for `Vec::from_raw_parts(ptr, len, len)` (such as one obtained from
// `epma_alloc(len)`) and must not be freed twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_free(ptr: *mut f64, len: usize) {
    // Rebuilding a Vec from a null pointer is undefined behaviour, so treat a
    // null argument as "nothing to free".
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is non-null; the remaining requirements are the caller
    // contract stated above.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2445
// WASM export: computes the indicator over `len` values at `in_ptr` and writes
// `len` results to `out_ptr`, using the kernel from `detect_best_kernel()`.
//
// Returns "null pointer" if either pointer is null, otherwise the display text
// of any computation error. When both pointers are equal the result is first
// computed into a temporary buffer and then copied over the input.
//
// Caller contract: both pointers must be valid for `len` f64 elements.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    offset: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let params = EpmaParams {
        period: Some(period),
        offset: Some(offset),
    };

    // SAFETY: `in_ptr` is non-null; the caller guarantees it is valid for `len`
    // reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = EpmaInput::from_slice(data, params);

    if in_ptr == out_ptr {
        // In-place request: compute into scratch space so the input is not
        // overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        epma_into_slice(&mut scratch, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null; the caller guarantees it is valid for
        // `len` writes.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&scratch);
    } else {
        // SAFETY: `out_ptr` is non-null; the caller guarantees it is valid for
        // `len` writes.
        let dst = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        epma_into_slice(dst, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2482
// WASM export: runs a period/offset sweep over `len` values at `in_ptr` and
// writes one output row per parameter combination to `out_ptr`. Returns the
// number of rows on success.
//
// Errors (as JS string values): "null pointer" if either pointer is null, an
// overflow message if `rows * len` does not fit in `usize`, or the display
// text of any error from the batch routine.
//
// Caller contract: `in_ptr` must be valid for `len` reads and `out_ptr` for
// `rows * len` writes, where `rows` is the number of combinations produced by
// the two range triples.
// NOTE(review): unlike `epma_into`, overlapping input/output buffers are not
// detected here — callers should pass distinct buffers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    p_start: usize,
    p_end: usize,
    p_step: usize,
    o_start: usize,
    o_end: usize,
    o_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = EpmaBatchRange {
        period: (p_start, p_end, p_step),
        offset: (o_start, o_end, o_step),
    };
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    // `usize` is 32 bits on wasm32, so the product can realistically wrap; a
    // wrapped value would build an output slice shorter than the data written.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflows usize"))?;

    // SAFETY: both pointers are non-null and `total` did not overflow; validity
    // of the two regions is the caller contract stated above.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        epma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}