Skip to main content

vector_ta/indicators/moving_averages/
dema.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaDema, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7    make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
12use core::arch::x86_64::*;
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use serde::{Deserialize, Serialize};
21use std::convert::AsRef;
22use std::error::Error;
23use std::mem::MaybeUninit;
24use thiserror::Error;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use wasm_bindgen::prelude::*;
27
28#[derive(Debug, Clone)]
29pub enum DemaData<'a> {
30    Candles {
31        candles: &'a Candles,
32        source: &'a str,
33    },
34    Slice(&'a [f64]),
35}
36
37impl<'a> AsRef<[f64]> for DemaInput<'a> {
38    #[inline(always)]
39    fn as_ref(&self) -> &[f64] {
40        match &self.data {
41            DemaData::Slice(slice) => slice,
42            DemaData::Candles { candles, source } => source_type(candles, source),
43        }
44    }
45}
46
/// Tunable parameters for DEMA.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct DemaParams {
    /// Smoothing period; `None` means "use the default" (30 wherever it is resolved in
    /// this file).
    pub period: Option<usize>,
}

impl Default for DemaParams {
    /// Default parameters: `period = Some(30)`.
    fn default() -> Self {
        DemaParams { period: Some(30) }
    }
}
61
62#[derive(Debug, Clone)]
63pub struct DemaInput<'a> {
64    pub data: DemaData<'a>,
65    pub params: DemaParams,
66}
67
68impl<'a> DemaInput<'a> {
69    #[inline]
70    pub fn from_candles(c: &'a Candles, s: &'a str, p: DemaParams) -> Self {
71        Self {
72            data: DemaData::Candles {
73                candles: c,
74                source: s,
75            },
76            params: p,
77        }
78    }
79    #[inline]
80    pub fn from_slice(sl: &'a [f64], p: DemaParams) -> Self {
81        Self {
82            data: DemaData::Slice(sl),
83            params: p,
84        }
85    }
86    #[inline]
87    pub fn with_default_candles(c: &'a Candles) -> Self {
88        Self::from_candles(c, "close", DemaParams::default())
89    }
90    #[inline]
91    pub fn get_period(&self) -> usize {
92        self.params.period.unwrap_or(30)
93    }
94}
95
/// Result of a single-series DEMA computation.
#[derive(Clone, Debug)]
pub struct DemaOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
100
101#[derive(Copy, Clone, Debug)]
102pub struct DemaBuilder {
103    period: Option<usize>,
104    kernel: Kernel,
105}
106
107impl Default for DemaBuilder {
108    fn default() -> Self {
109        Self {
110            period: None,
111            kernel: Kernel::Auto,
112        }
113    }
114}
115
116impl DemaBuilder {
117    #[inline(always)]
118    pub fn new() -> Self {
119        Self::default()
120    }
121    #[inline(always)]
122    pub fn period(mut self, n: usize) -> Self {
123        self.period = Some(n);
124        self
125    }
126    #[inline(always)]
127    pub fn kernel(mut self, k: Kernel) -> Self {
128        self.kernel = k;
129        self
130    }
131
132    #[inline(always)]
133    pub fn apply(self, c: &Candles) -> Result<DemaOutput, DemaError> {
134        let p = DemaParams {
135            period: self.period,
136        };
137        let i = DemaInput::from_candles(c, "close", p);
138        dema_with_kernel(&i, self.kernel)
139    }
140    #[inline(always)]
141    pub fn apply_slice(self, d: &[f64]) -> Result<DemaOutput, DemaError> {
142        let p = DemaParams {
143            period: self.period,
144        };
145        let i = DemaInput::from_slice(d, p);
146        dema_with_kernel(&i, self.kernel)
147    }
148    #[inline(always)]
149    pub fn into_stream(self) -> Result<DemaStream, DemaError> {
150        let p = DemaParams {
151            period: self.period,
152        };
153        DemaStream::try_new(p)
154    }
155}
156
157#[derive(Debug, Error)]
158pub enum DemaError {
159    #[error("dema: Input data slice is empty.")]
160    EmptyInputData,
161    #[error("dema: All values are NaN.")]
162    AllValuesNaN,
163    #[error("dema: Invalid period: period = {period}, data length = {data_len}")]
164    InvalidPeriod { period: usize, data_len: usize },
165    #[error("dema: Not enough data: needed = {needed}, valid = {valid}")]
166    NotEnoughData { needed: usize, valid: usize },
167    #[error("dema: Not enough valid data: needed = {needed}, valid = {valid}")]
168    NotEnoughValidData { needed: usize, valid: usize },
169    #[error("dema: output length mismatch: expected = {expected}, got = {got}")]
170    OutputLengthMismatch { expected: usize, got: usize },
171    #[error("dema: invalid range: start = {start}, end = {end}, step = {step}")]
172    InvalidRange {
173        start: usize,
174        end: usize,
175        step: usize,
176    },
177    #[error("dema: invalid kernel for batch: {0:?}")]
178    InvalidKernelForBatch(Kernel),
179    #[error("dema: size overflow when computing {context}")]
180    SizeOverflow { context: &'static str },
181}
182
183#[inline]
184pub fn dema(input: &DemaInput) -> Result<DemaOutput, DemaError> {
185    dema_with_kernel(input, Kernel::Auto)
186}
187
188#[inline(always)]
189fn dema_prepare<'a>(
190    input: &'a DemaInput,
191    kernel: Kernel,
192) -> Result<(&'a [f64], usize, usize, usize, Kernel), DemaError> {
193    let data: &[f64] = match &input.data {
194        DemaData::Candles { candles, source } => source_type(candles, source),
195        DemaData::Slice(sl) => sl,
196    };
197
198    let len = data.len();
199    if len == 0 {
200        return Err(DemaError::EmptyInputData);
201    }
202
203    let first = data
204        .iter()
205        .position(|x| !x.is_nan())
206        .ok_or(DemaError::AllValuesNaN)?;
207
208    let period = input.get_period();
209
210    if period < 1 || period > len {
211        return Err(DemaError::InvalidPeriod {
212            period,
213            data_len: len,
214        });
215    }
216    let needed = 2 * (period - 1);
217    if len < needed {
218        return Err(DemaError::NotEnoughData { needed, valid: len });
219    }
220    let valid = len - first;
221    if valid < needed {
222        return Err(DemaError::NotEnoughValidData { needed, valid });
223    }
224
225    let chosen = match kernel {
226        Kernel::Auto => match detect_best_kernel() {
227            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228            Kernel::Avx512 => Kernel::Avx512,
229            _ => Kernel::Scalar,
230        },
231        other => other,
232    };
233
234    let warm = first + period - 1;
235
236    Ok((data, period, first, warm, chosen))
237}
238
239#[inline(always)]
240fn dema_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
241    unsafe {
242        match chosen {
243            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
244            Kernel::Avx512 => dema_avx512(data, period, first, out),
245            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246            Kernel::Avx2 => dema_avx2(data, period, first, out),
247            _ => dema_scalar(data, period, first, out),
248        }
249    }
250}
251
252pub fn dema_with_kernel(input: &DemaInput, kernel: Kernel) -> Result<DemaOutput, DemaError> {
253    let (data, period, first, warm, chosen) = dema_prepare(input, kernel)?;
254    let len = data.len();
255    let mut out = alloc_with_nan_prefix(len, warm);
256    dema_compute_into(data, period, first, chosen, &mut out);
257
258    out[..warm].fill(f64::NAN);
259    Ok(DemaOutput { values: out })
260}
261
262#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
263#[inline]
264pub fn dema_into(input: &DemaInput, out: &mut [f64]) -> Result<(), DemaError> {
265    dema_into_slice(out, input, Kernel::Auto)
266}
267
268#[inline]
269pub fn dema_into_slice(dst: &mut [f64], input: &DemaInput, kern: Kernel) -> Result<(), DemaError> {
270    let (data, period, first, warmup, chosen) = dema_prepare(input, kern)?;
271
272    if dst.len() != data.len() {
273        return Err(DemaError::OutputLengthMismatch {
274            expected: data.len(),
275            got: dst.len(),
276        });
277    }
278
279    dema_compute_into(data, period, first, chosen, dst);
280
281    for v in &mut dst[..warmup] {
282        *v = f64::NAN;
283    }
284
285    Ok(())
286}
287
/// Scalar DEMA kernel for x86 / x86_64, compiled with FMA enabled.
///
/// Seeds both EMAs with `data[first]` and writes that seed to `out[first]`. For every
/// later index it updates `ema1 = ema1*a + x*alpha`, then `ema2 = ema2*a + ema1*alpha`,
/// and stores `2*ema1 - ema2`, with `alpha = 2 / (period + 1)` and `a = 1 - alpha`.
/// Elements of `out` before `first` are not written. Returns without touching `out`
/// when `first >= data.len()`.
///
/// # Safety
/// * The executing CPU must support FMA (required by `#[target_feature]`).
/// * `out.len()` must equal `data.len()`; this is only checked by `debug_assert!` and
///   all element accesses are unchecked.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "fma")]
#[inline]
pub unsafe fn dema_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // The prefetch intrinsic lives in an architecture-specific module. Import the one
    // that exists for the current target so 32-bit x86 builds compile as well.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::{_mm_prefetch, _MM_HINT_T0};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    debug_assert!(period >= 1 && data.len() == out.len());
    let n = data.len();
    if first >= n {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    let mut ema1 = *data.get_unchecked(first);
    let mut ema2 = ema1;
    *out.get_unchecked_mut(first) = ema1;

    let mut i = first + 1;
    let mut p = data.as_ptr().add(i);
    let mut q = out.as_mut_ptr().add(i);

    // Main loop handles four elements per iteration; `i <= n - 4` keeps `p.add(3)` and
    // `q.add(3)` in bounds. For `n < 4` the limit is 0 and the loop is skipped.
    let limit = n.saturating_sub(4);
    while i <= limit {
        // Prefetch only when the target address is still inside `data`.
        if i + 32 < n {
            _mm_prefetch(p.add(32) as *const i8, _MM_HINT_T0);
        }

        let x0 = *p;
        ema1 = ema1.mul_add(a, x0 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q = ema1.mul_add(2.0, -ema2);

        let x1 = *p.add(1);
        ema1 = ema1.mul_add(a, x1 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(1) = ema1.mul_add(2.0, -ema2);

        let x2 = *p.add(2);
        ema1 = ema1.mul_add(a, x2 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(2) = ema1.mul_add(2.0, -ema2);

        let x3 = *p.add(3);
        ema1 = ema1.mul_add(a, x3 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(3) = ema1.mul_add(2.0, -ema2);

        p = p.add(4);
        q = q.add(4);
        i += 4;
    }

    // Remainder: at most three trailing elements.
    while i < n {
        let x = *p;
        ema1 = ema1.mul_add(a, x * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q = ema1.mul_add(2.0, -ema2);

        p = p.add(1);
        q = q.add(1);
        i += 1;
    }
}
354
/// Scalar DEMA kernel for non-x86 targets (plain multiply/add, no FMA).
///
/// Seeds both EMAs with `data[first]`, writes the seed to `out[first]`, and for each
/// later element stores `2*ema1 - ema2` after updating the two EMAs. Elements of `out`
/// before `first` are not written. Returns without touching `out` when
/// `first >= data.len()`.
///
/// # Safety
/// `out.len()` must equal `data.len()`; this is only checked by `debug_assert!` and all
/// element accesses are unchecked.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[inline]
pub unsafe fn dema_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(period >= 1 && data.len() == out.len());
    let len = data.len();
    if first >= len {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let decay = 1.0 - alpha;

    let seed = *data.get_unchecked(first);
    *out.get_unchecked_mut(first) = seed;

    let mut fast = seed;
    let mut slow = seed;
    // One recurrence step: advance both EMAs with sample `x` and return the DEMA value.
    let mut step = |x: f64| -> f64 {
        fast = fast * decay + x * alpha;
        slow = slow * decay + fast * alpha;
        2.0 * fast - slow
    };

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let mut idx = first + 1;

    // Four elements per iteration; `idx <= len - 4` keeps `idx + 3` in bounds.
    let unrolled_end = len.saturating_sub(4);
    while idx <= unrolled_end {
        *dst.add(idx) = step(*src.add(idx));
        *dst.add(idx + 1) = step(*src.add(idx + 1));
        *dst.add(idx + 2) = step(*src.add(idx + 2));
        *dst.add(idx + 3) = step(*src.add(idx + 3));
        idx += 4;
    }

    // Remainder: at most three trailing elements.
    while idx < len {
        *dst.add(idx) = step(*src.add(idx));
        idx += 1;
    }
}
413
/// Extracts the highest (fourth) `f64` lane of a 256-bit vector.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn last_lane_256(v: __m256d) -> f64 {
    let upper: __m128d = _mm256_extractf128_pd(v, 1);
    // Duplicate the upper element into the low slot so it can be read as a scalar.
    _mm_cvtsd_f64(_mm_unpackhi_pd(upper, upper))
}
421
/// Extracts the highest (eighth) `f64` lane of a 512-bit vector.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn last_lane_512(v: __m512d) -> f64 {
    let top_pair: __m128d = _mm512_extractf64x2_pd(v, 3);
    // Duplicate the upper element into the low slot so it can be read as a scalar.
    _mm_cvtsd_f64(_mm_unpackhi_pd(top_pair, top_pair))
}
429
/// Moves every lane of `x` up by one position, zero-filling lane 0:
/// `[x0, x1, x2, x3]` becomes `[0, x0, x1, x2]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl1_256(x: __m256d) -> __m256d {
    let low: __m128d = _mm256_castpd256_pd128(x);
    let high: __m128d = _mm256_extractf128_pd(x, 1);

    // New low half: [0, x0].
    let new_low = _mm_unpacklo_pd(_mm_setzero_pd(), low);

    // New high half: [x1, x2].
    let x1_dup = _mm_unpackhi_pd(low, low);
    let x2_dup = _mm_unpacklo_pd(high, high);
    let new_high = _mm_shuffle_pd(x1_dup, x2_dup, 0x0);

    _mm256_insertf128_pd(_mm256_castpd128_pd256(new_low), new_high, 1)
}
439
/// Moves every lane of `x` up by two positions, zero-filling lanes 0 and 1:
/// `[x0, x1, x2, x3]` becomes `[0, 0, x0, x1]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl2_256(x: __m256d) -> __m256d {
    let low: __m128d = _mm256_castpd256_pd128(x);
    let zero_low = _mm256_castpd128_pd256(_mm_setzero_pd());
    // The insert overwrites the upper half, so the result's low half is the zero vector.
    _mm256_insertf128_pd(zero_low, low, 1)
}
446
/// Two-step in-register scan over four lanes: first combines each lane with its
/// neighbour one position below (weighted by `a1`), then with the lane two positions
/// below (weighted by `a2`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn scan4(v: __m256d, a1: __m256d, a2: __m256d) -> __m256d {
    let stage1 = _mm256_fmadd_pd(a1, shl1_256(v), v);
    _mm256_fmadd_pd(a2, shl2_256(stage1), stage1)
}
454
/// Moves every lane of `x` up by one position, zero-filling lane 0.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl1_512(x: __m512d) -> __m512d {
    // Lane k (k >= 1) reads source lane k-1; lane 0 is cleared by the zero-mask.
    let source_lanes: __m512i = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);
    let keep: __mmask8 = 0b1111_1110;
    _mm512_maskz_permutexvar_pd(keep, source_lanes, x)
}
/// Moves every lane of `x` up by two positions, zero-filling lanes 0 and 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl2_512(x: __m512d) -> __m512d {
    // Lane k (k >= 2) reads source lane k-2; lanes 0..2 are cleared by the zero-mask.
    let source_lanes: __m512i = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);
    let keep: __mmask8 = 0b1111_1100;
    _mm512_maskz_permutexvar_pd(keep, source_lanes, x)
}
/// Moves every lane of `x` up by four positions, zero-filling lanes 0 through 3.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl4_512(x: __m512d) -> __m512d {
    // Lane k (k >= 4) reads source lane k-4; lanes 0..4 are cleared by the zero-mask.
    let source_lanes: __m512i = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);
    let keep: __mmask8 = 0b1111_0000;
    _mm512_maskz_permutexvar_pd(keep, source_lanes, x)
}
/// Three-step in-register scan over eight lanes, combining each lane with the lanes
/// one, two and then four positions below it (weighted by `a1`, `a2`, `a4`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn scan8(v: __m512d, a1: __m512d, a2: __m512d, a4: __m512d) -> __m512d {
    let stage1 = _mm512_fmadd_pd(a1, shl1_512(v), v);
    let stage2 = _mm512_fmadd_pd(a2, shl2_512(stage1), stage1);
    _mm512_fmadd_pd(a4, shl4_512(stage2), stage2)
}
484
/// AVX2 + FMA DEMA kernel processing four elements per vector step.
///
/// Seeds both EMAs with `data[first]`, writes the seed to `out[first]`, then handles
/// full groups of four with an in-register scan and any remainder with scalar code.
/// Elements of `out` before `first` are not written.
///
/// # Safety
/// * The executing CPU must support AVX2 and FMA.
/// * `out.len()` must equal `data.len()`; only checked by `debug_assert!`, and all
///   accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn dema_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    let len = data.len();
    if first >= len {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();

    let mut ema1 = *src.add(first);
    let mut ema2 = ema1;
    *dst.add(first) = ema1;

    let mut idx = first + 1;
    if idx >= len {
        return;
    }

    // Powers of the decay factor: lane k of `decay_powers` holds a^(k+1), which carries
    // the previous block's EMA into lane k.
    let a_pow2 = a * a;
    let a_pow3 = a_pow2 * a;
    let a_pow4 = a_pow2 * a_pow2;
    let alpha_splat = _mm256_set1_pd(alpha);
    let decay_powers = _mm256_set_pd(a_pow4, a_pow3, a_pow2, a);
    let a1_splat = _mm256_set1_pd(a);
    let a2_splat = _mm256_set1_pd(a_pow2);

    while idx + 4 <= len {
        let prices = _mm256_loadu_pd(src.add(idx));

        // First EMA over the four new samples.
        let scaled_prices = _mm256_mul_pd(alpha_splat, prices);
        let scanned1 = scan4(scaled_prices, a1_splat, a2_splat);
        let carry1 = _mm256_set1_pd(ema1);
        let ema1_block = _mm256_fmadd_pd(decay_powers, carry1, scanned1);

        // Second EMA, fed by the first.
        let scaled_ema1 = _mm256_mul_pd(alpha_splat, ema1_block);
        let scanned2 = scan4(scaled_ema1, a1_splat, a2_splat);
        let carry2 = _mm256_set1_pd(ema2);
        let ema2_block = _mm256_fmadd_pd(decay_powers, carry2, scanned2);

        // DEMA = 2*EMA1 - EMA2.
        let doubled = _mm256_add_pd(ema1_block, ema1_block);
        let result = _mm256_sub_pd(doubled, ema2_block);
        _mm256_storeu_pd(dst.add(idx), result);

        // The last lane becomes the carry for the next block.
        ema1 = last_lane_256(ema1_block);
        ema2 = last_lane_256(ema2_block);
        idx += 4;
    }

    // Scalar remainder (fewer than four elements).
    while idx < len {
        let price = *src.add(idx);
        ema1 = ema1.mul_add(a, price * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *dst.add(idx) = (2.0 * ema1) - ema2;
        idx += 1;
    }
}
546
/// AVX-512 DEMA kernel processing eight elements per vector step.
///
/// Seeds both EMAs with `data[first]`, writes the seed to `out[first]`, then handles
/// full groups of eight with an in-register scan and any remainder with scalar code.
/// Elements of `out` before `first` are not written.
///
/// # Safety
/// * The executing CPU must support AVX-512F, AVX-512DQ and FMA.
/// * `out.len()` must equal `data.len()`; only checked by `debug_assert!`, and all
///   accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
pub unsafe fn dema_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    let len = data.len();
    if first >= len {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();

    let mut ema1 = *src.add(first);
    let mut ema2 = ema1;
    *dst.add(first) = ema1;

    let mut idx = first + 1;
    if idx >= len {
        return;
    }

    // Powers of the decay factor: lane k of `decay_powers` holds a^(k+1), which carries
    // the previous block's EMA into lane k.
    let a_pow2 = a * a;
    let a_pow3 = a_pow2 * a;
    let a_pow4 = a_pow2 * a_pow2;
    let a_pow5 = a_pow4 * a;
    let a_pow6 = a_pow3 * a_pow3;
    let a_pow7 = a_pow6 * a;
    let a_pow8 = a_pow4 * a_pow4;
    let alpha_splat = _mm512_set1_pd(alpha);
    let decay_powers = _mm512_set_pd(
        a_pow8, a_pow7, a_pow6, a_pow5, a_pow4, a_pow3, a_pow2, a,
    );
    let a1_splat = _mm512_set1_pd(a);
    let a2_splat = _mm512_set1_pd(a_pow2);
    let a4_splat = _mm512_set1_pd(a_pow4);

    while idx + 8 <= len {
        let prices = _mm512_loadu_pd(src.add(idx));

        // First EMA over the eight new samples.
        let scaled_prices = _mm512_mul_pd(alpha_splat, prices);
        let scanned1 = scan8(scaled_prices, a1_splat, a2_splat, a4_splat);
        let carry1 = _mm512_set1_pd(ema1);
        let ema1_block = _mm512_fmadd_pd(decay_powers, carry1, scanned1);

        // Second EMA, fed by the first.
        let scaled_ema1 = _mm512_mul_pd(alpha_splat, ema1_block);
        let scanned2 = scan8(scaled_ema1, a1_splat, a2_splat, a4_splat);
        let carry2 = _mm512_set1_pd(ema2);
        let ema2_block = _mm512_fmadd_pd(decay_powers, carry2, scanned2);

        // DEMA = 2*EMA1 - EMA2.
        let doubled = _mm512_add_pd(ema1_block, ema1_block);
        let result = _mm512_sub_pd(doubled, ema2_block);
        _mm512_storeu_pd(dst.add(idx), result);

        // The last lane becomes the carry for the next block.
        ema1 = last_lane_512(ema1_block);
        ema2 = last_lane_512(ema2_block);
        idx += 8;
    }

    // Scalar remainder (fewer than eight elements).
    while idx < len {
        let price = *src.add(idx);
        ema1 = ema1.mul_add(a, price * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *dst.add(idx) = (2.0 * ema1) - ema2;
        idx += 1;
    }
}
613
614#[derive(Debug, Clone)]
615pub struct DemaStream {
616    period: usize,
617    alpha: f64,
618    alpha_1: f64,
619    ema: f64,
620    ema2: f64,
621    filled: usize,
622    nan_fill: usize,
623}
624
625impl DemaStream {
626    pub fn try_new(params: DemaParams) -> Result<Self, DemaError> {
627        let period = params.period.unwrap_or(30);
628        if period < 1 {
629            return Err(DemaError::InvalidPeriod {
630                period,
631                data_len: 0,
632            });
633        }
634        Ok(Self {
635            period,
636            alpha: 2.0 / (period as f64 + 1.0),
637            alpha_1: 1.0 - 2.0 / (period as f64 + 1.0),
638            ema: f64::NAN,
639            ema2: f64::NAN,
640            filled: 0,
641            nan_fill: period - 1,
642        })
643    }
644
645    #[inline(always)]
646    fn fmadd(a: f64, b: f64, c: f64) -> f64 {
647        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
648        {
649            a.mul_add(b, c)
650        }
651        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
652        {
653            a * b + c
654        }
655    }
656
657    #[inline(always)]
658    pub fn update(&mut self, value: f64) -> Option<f64> {
659        if self.filled == 0 {
660            self.ema = value;
661            self.ema2 = value;
662            self.filled = 1;
663
664            return if self.nan_fill == 0 {
665                Some(value)
666            } else {
667                None
668            };
669        }
670
671        let a = self.alpha;
672        let a1 = self.alpha_1;
673
674        self.ema = Self::fmadd(self.ema, a1, value * a);
675
676        self.ema2 = Self::fmadd(self.ema2, a1, self.ema * a);
677
678        let y = Self::fmadd(self.ema, 2.0, -self.ema2);
679
680        self.filled = self.filled.saturating_add(1);
681
682        if self.filled > self.nan_fill {
683            Some(y)
684        } else {
685            None
686        }
687    }
688
689    #[inline(always)]
690    pub fn update_nan(&mut self, value: f64) -> f64 {
691        match self.update(value) {
692            Some(v) => v,
693            None => f64::NAN,
694        }
695    }
696}
697
/// Approximates `1 / d` from a single-precision reciprocal refined by one
/// Newton–Raphson step (`x1 = x0 * (2 - d * x0)`).
#[inline(always)]
fn fast_recip_nr1(d: f64) -> f64 {
    let seed = f64::from((d as f32).recip());
    let correction = 2.0 - d * seed;
    seed * correction
}
703
/// Period sweep for batch DEMA, given as `(start, end, step)`.
#[derive(Debug, Clone)]
pub struct DemaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for DemaBatchRange {
    /// Default sweep: periods 30 through 279 in steps of 1.
    fn default() -> Self {
        let period = (30, 279, 1);
        DemaBatchRange { period }
    }
}
716
717#[derive(Clone, Debug, Default)]
718pub struct DemaBatchBuilder {
719    range: DemaBatchRange,
720    kernel: Kernel,
721}
722
723impl DemaBatchBuilder {
724    pub fn new() -> Self {
725        Self::default()
726    }
727    pub fn kernel(mut self, k: Kernel) -> Self {
728        self.kernel = k;
729        self
730    }
731
732    #[inline]
733    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
734        self.range.period = (start, end, step);
735        self
736    }
737    #[inline]
738    pub fn period_static(mut self, p: usize) -> Self {
739        self.range.period = (p, p, 0);
740        self
741    }
742    pub fn apply_slice(self, data: &[f64]) -> Result<DemaBatchOutput, DemaError> {
743        dema_batch_with_kernel(data, &self.range, self.kernel)
744    }
745    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DemaBatchOutput, DemaError> {
746        DemaBatchBuilder::new().kernel(k).apply_slice(data)
747    }
748    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DemaBatchOutput, DemaError> {
749        let slice = source_type(c, src);
750        self.apply_slice(slice)
751    }
752    pub fn with_default_candles(c: &Candles) -> Result<DemaBatchOutput, DemaError> {
753        DemaBatchBuilder::new()
754            .kernel(Kernel::Auto)
755            .apply_candles(c, "close")
756    }
757}
758
759pub struct DemaBatchOutput {
760    pub values: Vec<f64>,
761    pub combos: Vec<DemaParams>,
762    pub rows: usize,
763    pub cols: usize,
764}
765impl DemaBatchOutput {
766    pub fn row_for_params(&self, p: &DemaParams) -> Option<usize> {
767        self.combos
768            .iter()
769            .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
770    }
771    pub fn values_for(&self, p: &DemaParams) -> Option<&[f64]> {
772        self.row_for_params(p).map(|row| {
773            let start = row * self.cols;
774            &self.values[start..start + self.cols]
775        })
776    }
777}
778
779#[inline(always)]
780fn expand_grid(r: &DemaBatchRange) -> Vec<DemaParams> {
781    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
782        if step == 0 || start == end {
783            return vec![start];
784        }
785        let mut vals = Vec::new();
786        if start < end {
787            let mut v = start;
788            while v <= end {
789                vals.push(v);
790                match v.checked_add(step) {
791                    Some(n) if n != v => v = n,
792                    _ => break,
793                }
794            }
795        } else {
796            let mut v = start;
797            loop {
798                vals.push(v);
799                if v <= end {
800                    break;
801                }
802                match v.checked_sub(step) {
803                    Some(n) if n != v => v = n,
804                    _ => break,
805                }
806            }
807        }
808        vals
809    }
810    let periods = axis_usize(r.period);
811    let mut out = Vec::with_capacity(periods.len());
812    for &p in &periods {
813        out.push(DemaParams { period: Some(p) });
814    }
815    out
816}
817
818#[inline(always)]
819pub fn dema_batch_slice(
820    data: &[f64],
821    sweep: &DemaBatchRange,
822    kern: Kernel,
823) -> Result<DemaBatchOutput, DemaError> {
824    dema_batch_inner(data, sweep, kern, false)
825}
826
827#[inline(always)]
828pub fn dema_batch_par_slice(
829    data: &[f64],
830    sweep: &DemaBatchRange,
831    kern: Kernel,
832) -> Result<DemaBatchOutput, DemaError> {
833    dema_batch_inner(data, sweep, kern, true)
834}
835
836#[inline(always)]
837pub(crate) fn dema_batch_with_kernel(
838    data: &[f64],
839    sweep: &DemaBatchRange,
840    k: Kernel,
841) -> Result<DemaBatchOutput, DemaError> {
842    let kernel = match k {
843        Kernel::Auto => Kernel::ScalarBatch,
844        other if other.is_batch() => other,
845        other => return Err(DemaError::InvalidKernelForBatch(other)),
846    };
847
848    let simd = match kernel {
849        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
850        Kernel::Avx512Batch => Kernel::Avx512,
851        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
852        Kernel::Avx2Batch => Kernel::Scalar,
853        Kernel::ScalarBatch => Kernel::Scalar,
854        _ => unreachable!(),
855    };
856    dema_batch_par_slice(data, sweep, simd)
857}
858
859#[inline(always)]
860fn dema_batch_inner(
861    data: &[f64],
862    sweep: &DemaBatchRange,
863    kern: Kernel,
864    parallel: bool,
865) -> Result<DemaBatchOutput, DemaError> {
866    let combos = {
867        let v = expand_grid(sweep);
868        if v.is_empty() {
869            return Err(DemaError::InvalidRange {
870                start: sweep.period.0,
871                end: sweep.period.1,
872                step: sweep.period.2,
873            });
874        }
875        v
876    };
877    let cols = data.len();
878    let rows = combos.len();
879
880    let _total = rows.checked_mul(cols).ok_or(DemaError::SizeOverflow {
881        context: "rows*cols for batch buffer",
882    })?;
883
884    if cols == 0 {
885        return Err(DemaError::EmptyInputData);
886    }
887
888    let mut buf_mu = make_uninit_matrix(rows, cols);
889
890    let warm: Vec<usize> = combos
891        .iter()
892        .map(|c| {
893            data.iter()
894                .position(|x| !x.is_nan())
895                .unwrap_or(0)
896                .saturating_add(c.period.unwrap().saturating_sub(1))
897        })
898        .collect();
899
900    init_matrix_prefixes(&mut buf_mu, cols, &warm);
901
902    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
903    let out: &mut [f64] = unsafe {
904        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
905    };
906
907    dema_batch_inner_into(data, sweep, kern, parallel, out)?;
908
909    let values = unsafe {
910        Vec::from_raw_parts(
911            buf_guard.as_mut_ptr() as *mut f64,
912            buf_guard.len(),
913            buf_guard.capacity(),
914        )
915    };
916
917    Ok(DemaBatchOutput {
918        values,
919        combos,
920        rows,
921        cols,
922    })
923}
924
925#[inline(always)]
926fn dema_batch_inner_into(
927    data: &[f64],
928    sweep: &DemaBatchRange,
929    kern: Kernel,
930    parallel: bool,
931    out: &mut [f64],
932) -> Result<Vec<DemaParams>, DemaError> {
933    let combos = {
934        let v = expand_grid(sweep);
935        if v.is_empty() {
936            return Err(DemaError::InvalidRange {
937                start: sweep.period.0,
938                end: sweep.period.1,
939                step: sweep.period.2,
940            });
941        }
942        v
943    };
944
945    if data.is_empty() {
946        return Err(DemaError::EmptyInputData);
947    }
948
949    let first = data
950        .iter()
951        .position(|x| !x.is_nan())
952        .ok_or(DemaError::AllValuesNaN)?;
953
954    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
955    let needed = 2 * (max_p - 1);
956    if data.len() < needed {
957        return Err(DemaError::NotEnoughData {
958            needed,
959            valid: data.len(),
960        });
961    }
962    let valid = data.len() - first;
963    if valid < needed {
964        return Err(DemaError::NotEnoughValidData { needed, valid });
965    }
966
967    let rows = combos.len();
968    let cols = data.len();
969
970    let expected = rows.checked_mul(cols).ok_or(DemaError::SizeOverflow {
971        context: "rows*cols when validating output buffer",
972    })?;
973    if out.len() != expected {
974        return Err(DemaError::OutputLengthMismatch {
975            expected,
976            got: out.len(),
977        });
978    }
979
980    let do_row = |row: usize, dst: &mut [f64]| unsafe {
981        let p = combos[row].period.unwrap();
982
983        match kern {
984            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
985            Kernel::Avx512 => dema_row_avx512(data, first, p, dst),
986            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
987            Kernel::Avx2 => dema_row_avx2(data, first, p, dst),
988            _ => dema_row_scalar(data, first, p, dst),
989        }
990
991        let warm = first + p - 1;
992        dst[..warm].fill(f64::NAN);
993    };
994
995    if parallel {
996        #[cfg(not(target_arch = "wasm32"))]
997        {
998            out.par_chunks_mut(cols)
999                .enumerate()
1000                .for_each(|(row, slice)| do_row(row, slice));
1001        }
1002        #[cfg(target_arch = "wasm32")]
1003        {
1004            for (row, slice) in out.chunks_mut(cols).enumerate() {
1005                do_row(row, slice);
1006            }
1007        }
1008    } else {
1009        for (row, slice) in out.chunks_mut(cols).enumerate() {
1010            do_row(row, slice);
1011        }
1012    }
1013
1014    Ok(combos)
1015}
1016#[inline(always)]
1017unsafe fn dema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1018    dema_scalar(data, period, first, out)
1019}
/// Computes one batch row with the AVX2 kernel.
///
/// # Safety
/// Same contract as `dema_avx2`, which this forwards to with the arguments reordered.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dema_avx2(data, period, first, out);
}
/// Computes one batch row with the AVX-512 kernel.
///
/// # Safety
/// Same contract as `dema_avx512`, which this forwards to with the arguments reordered.
/// NOTE(review): `dema_avx512` additionally enables `avx512dq`; callers must ensure the
/// CPU supports it as well.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn dema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dema_avx512(data, period, first, out);
}
1030
/// Python-visible wrapper around a device-resident `f32` array produced by the CUDA
/// DEMA path (exposed as `DeviceArrayF32Dema` in `ta_indicators.cuda`).
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Dema", unsendable)]
pub struct DeviceArrayF32DemaPy {
    /// The wrapped device array.
    pub(crate) inner: DeviceArrayF32,
    // NOTE(review): presumably held to keep the CUDA context alive for as long as the
    // device buffer exists — confirm against the constructor.
    _ctx_guard: std::sync::Arc<cust::context::Context>,
    /// Device ordinal reported to Python by `__dlpack_device__`.
    _device_id: u32,
}
1038
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32DemaPy {
    // Direct construction from Python is always rejected with a TypeError.
    #[new]
    fn py_new() -> PyResult<Self> {
        let reason = "use factory methods from CUDA functions";
        Err(pyo3::exceptions::PyTypeError::new_err(reason))
    }

    // Builds the `__cuda_array_interface__` dictionary: a 2-D little-endian f32
    // array with byte strides for a row-major layout, a writable data pointer
    // (read-only flag = false) and interface version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);

        // An empty matrix is advertised with a null pointer instead of
        // whatever address the (possibly placeholder) buffer holds.
        let device_addr: usize = match rows.saturating_mul(cols) {
            0 => 0,
            _ => self.inner.buf.as_device_ptr().as_raw() as usize,
        };

        let iface = pyo3::types::PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        iface.set_item("data", (device_addr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; 2 is the DLPack code for a CUDA device.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_ordinal = self._device_id as i32;
        (2, device_ordinal)
    }

    // Exports the matrix as a DLPack capsule.
    //
    // The device buffer is moved out of `self` and replaced by an empty
    // placeholder, so after a successful call this object describes a 0x0 array.
    // A `dl_device` that parses as `(i32, i32)` and differs from this array's
    // device is rejected; one that does not parse is ignored. `stream` is
    // accepted for signature compatibility but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested_device = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested_device {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|flag| flag.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }

        let _ = stream;

        // Allocate the placeholder first so a failure leaves `self` untouched.
        let placeholder = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1126
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32DemaPy {
    /// Rust-side constructor: bundles a device array with the context handle
    /// and device ordinal it was produced with.
    pub fn new(
        inner: DeviceArrayF32,
        ctx_guard: std::sync::Arc<cust::context::Context>,
        device_id: u32,
    ) -> Self {
        DeviceArrayF32DemaPy {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1141
1142#[cfg(test)]
1143mod tests {
1144    use super::*;
1145    use crate::skip_if_unsupported;
1146    use crate::utilities::data_loader::read_candles_from_csv;
1147    use proptest::prelude::*;
1148
1149    fn check_dema_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1150        skip_if_unsupported!(kernel, test_name);
1151        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1152        let candles = read_candles_from_csv(file_path)?;
1153
1154        let default_params = DemaParams { period: None };
1155        let input_default = DemaInput::from_candles(&candles, "close", default_params);
1156        let output_default = dema_with_kernel(&input_default, kernel)?;
1157        assert_eq!(output_default.values.len(), candles.close.len());
1158
1159        let params_period_14 = DemaParams { period: Some(14) };
1160        let input_period_14 = DemaInput::from_candles(&candles, "hl2", params_period_14);
1161        let output_period_14 = dema_with_kernel(&input_period_14, kernel)?;
1162        assert_eq!(output_period_14.values.len(), candles.close.len());
1163
1164        let params_custom = DemaParams { period: Some(20) };
1165        let input_custom = DemaInput::from_candles(&candles, "hlc3", params_custom);
1166        let output_custom = dema_with_kernel(&input_custom, kernel)?;
1167        assert_eq!(output_custom.values.len(), candles.close.len());
1168        Ok(())
1169    }
1170
1171    fn check_dema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1172        skip_if_unsupported!(kernel, test_name);
1173        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1174        let candles = read_candles_from_csv(file_path)?;
1175
1176        let input = DemaInput::with_default_candles(&candles);
1177        let result = dema_with_kernel(&input, kernel)?;
1178
1179        let expected_last_five = [
1180            59189.73193987478,
1181            59129.24920772847,
1182            59058.80282420511,
1183            59011.5555611042,
1184            58908.370159946775,
1185        ];
1186        let start_index = result.values.len().saturating_sub(5);
1187        let last_five = &result.values[start_index..];
1188        for (i, &val) in last_five.iter().enumerate() {
1189            let exp = expected_last_five[i];
1190            assert!(
1191                (val - exp).abs() < 1e-6,
1192                "DEMA mismatch at index {}: expected {}, got {}",
1193                start_index + i,
1194                exp,
1195                val
1196            );
1197        }
1198        Ok(())
1199    }
1200
1201    fn check_dema_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1202        skip_if_unsupported!(kernel, test_name);
1203        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1204        let candles = read_candles_from_csv(file_path)?;
1205        let input = DemaInput::with_default_candles(&candles);
1206        match input.data {
1207            DemaData::Candles { source, .. } => assert_eq!(source, "close"),
1208            _ => panic!("Expected DemaData::Candles"),
1209        }
1210        assert_eq!(input.params.period, Some(30));
1211        let output = dema_with_kernel(&input, kernel)?;
1212        assert_eq!(output.values.len(), candles.close.len());
1213        Ok(())
1214    }
1215
1216    fn check_dema_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1217        skip_if_unsupported!(kernel, test_name);
1218        let input_data = [10.0, 20.0, 30.0];
1219        let params = DemaParams { period: Some(0) };
1220        let input = DemaInput::from_slice(&input_data, params);
1221        let result = dema_with_kernel(&input, kernel);
1222        assert!(result.is_err());
1223        Ok(())
1224    }
1225
1226    fn check_dema_period_exceeds_length(
1227        test_name: &str,
1228        kernel: Kernel,
1229    ) -> Result<(), Box<dyn Error>> {
1230        skip_if_unsupported!(kernel, test_name);
1231        let data_small = [10.0, 20.0, 30.0];
1232        let params = DemaParams { period: Some(10) };
1233        let input = DemaInput::from_slice(&data_small, params);
1234        let result = dema_with_kernel(&input, kernel);
1235        assert!(result.is_err());
1236        Ok(())
1237    }
1238
1239    fn check_dema_very_small_dataset(
1240        test_name: &str,
1241        kernel: Kernel,
1242    ) -> Result<(), Box<dyn Error>> {
1243        skip_if_unsupported!(kernel, test_name);
1244        let single_point = [42.0];
1245        let params = DemaParams { period: Some(9) };
1246        let input = DemaInput::from_slice(&single_point, params);
1247        let result = dema_with_kernel(&input, kernel);
1248        assert!(result.is_err());
1249        Ok(())
1250    }
1251
1252    fn check_dema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1253        skip_if_unsupported!(kernel, test_name);
1254        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1255        let candles = read_candles_from_csv(file_path)?;
1256
1257        let first_params = DemaParams { period: Some(80) };
1258        let first_input = DemaInput::from_candles(&candles, "close", first_params);
1259        let first_result = dema_with_kernel(&first_input, kernel)?;
1260
1261        let second_params = DemaParams { period: Some(60) };
1262        let second_input = DemaInput::from_slice(&first_result.values, second_params);
1263        let second_result = dema_with_kernel(&second_input, kernel)?;
1264
1265        assert_eq!(second_result.values.len(), first_result.values.len());
1266        if second_result.values.len() > 240 {
1267            for i in 240..second_result.values.len() {
1268                assert!(!second_result.values[i].is_nan());
1269            }
1270        }
1271        Ok(())
1272    }
1273
1274    fn check_dema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1275        skip_if_unsupported!(kernel, test_name);
1276        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1277        let candles = read_candles_from_csv(file_path)?;
1278        let params = DemaParams { period: Some(30) };
1279        let input = DemaInput::from_candles(&candles, "close", params);
1280        let result = dema_with_kernel(&input, kernel)?;
1281        if result.values.len() > 240 {
1282            for i in 240..result.values.len() {
1283                assert!(!result.values[i].is_nan());
1284            }
1285        }
1286        Ok(())
1287    }
1288
1289    fn check_dema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1290        skip_if_unsupported!(kernel, test_name);
1291        let empty: [f64; 0] = [];
1292        let input = DemaInput::from_slice(&empty, DemaParams::default());
1293        let res = dema_with_kernel(&input, kernel);
1294        assert!(matches!(res, Err(DemaError::EmptyInputData)));
1295        Ok(())
1296    }
1297
1298    fn check_dema_not_enough_valid(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1299        skip_if_unsupported!(kernel, test_name);
1300        let data = [f64::NAN, f64::NAN, 1.0, 2.0];
1301        let params = DemaParams { period: Some(3) };
1302        let input = DemaInput::from_slice(&data, params);
1303        let res = dema_with_kernel(&input, kernel);
1304        assert!(matches!(res, Err(DemaError::NotEnoughValidData { .. })));
1305        Ok(())
1306    }
1307
1308    #[allow(clippy::float_cmp)]
1309    fn check_dema_property(
1310        test_name: &str,
1311        kernel: Kernel,
1312    ) -> Result<(), Box<dyn std::error::Error>> {
1313        use float_cmp::approx_eq;
1314        use proptest::prelude::*;
1315
1316        skip_if_unsupported!(kernel, test_name);
1317
1318        let strat = (1usize..=32).prop_flat_map(|period| {
1319            let min_len = 2 * period.max(2);
1320            (
1321                prop::collection::vec(
1322                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1323                    min_len..400,
1324                ),
1325                Just(period),
1326                (-1e3f64..1e3f64).prop_filter("non-zero scale", |a| a.is_finite() && *a != 0.0),
1327                -1e3f64..1e3f64,
1328            )
1329        });
1330
1331        proptest::test_runner::TestRunner::default()
1332            .run(&strat, |(data, period, a, b)| {
1333                let params = DemaParams {
1334                    period: Some(period),
1335                };
1336                let input = DemaInput::from_slice(&data, params.clone());
1337
1338                let fast = dema_with_kernel(&input, kernel);
1339                let slow = dema_with_kernel(&input, Kernel::Scalar);
1340
1341                match (fast, slow) {
1342                    (Err(e1), Err(e2))
1343                        if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
1344                    {
1345                        return Ok(())
1346                    }
1347
1348                    (Err(e1), Err(e2)) => {
1349                        prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
1350                    }
1351
1352                    (Err(e1), Ok(_)) => {
1353                        prop_assert!(false, "fast errored {e1:?} but scalar succeeded")
1354                    }
1355                    (Ok(_), Err(e2)) => {
1356                        prop_assert!(false, "scalar errored {e2:?} but fast succeeded")
1357                    }
1358
1359                    (Ok(fast), Ok(reference)) => {
1360                        let DemaOutput { values: out } = fast;
1361                        let DemaOutput { values: rref } = reference;
1362
1363                        let mut stream = DemaStream::try_new(params.clone()).unwrap();
1364                        let mut s_out = Vec::with_capacity(data.len());
1365                        for &v in &data {
1366                            s_out.push(stream.update(v).unwrap_or(f64::NAN));
1367                        }
1368
1369                        let transformed: Vec<f64> = data.iter().map(|x| a * *x + b).collect();
1370                        let t_out =
1371                            dema(&DemaInput::from_slice(&transformed, params.clone()))?.values;
1372
1373                        let nan_fill = period - 1;
1374                        for i in 0..data.len() {
1375                            let y = out[i];
1376                            let yr = rref[i];
1377                            let ys = s_out[i];
1378                            let yt = t_out[i];
1379
1380                            if period == 1 && y.is_finite() {
1381                                prop_assert!(approx_eq!(f64, y, data[i], ulps = 2));
1382                            }
1383
1384                            if i >= period - 1 {
1385                                let window = &data[i.saturating_sub(period - 1)..=i];
1386                                if window.iter().all(|v| *v == window[0]) {
1387                                    prop_assert!(approx_eq!(f64, y, window[0], epsilon = 1e-9));
1388                                }
1389                            } else {
1390                                prop_assert!(y.is_nan(), "Expected NaN during warmup at index {i}");
1391                            }
1392
1393                            if i >= nan_fill {
1394                                if y.is_finite() {
1395                                    let expected = a * y + b;
1396                                    let diff = (yt - expected).abs();
1397
1398                                    let tol = 1e-7_f64.max(expected.abs() * 1e-9);
1399                                    let ulp = yt.to_bits().abs_diff(expected.to_bits());
1400                                    prop_assert!(
1401                                        diff <= tol || ulp <= 8,
1402                                        "idx {i}: affine mismatch diff={diff:e}  ULP={ulp}"
1403                                    );
1404                                } else {
1405                                    prop_assert_eq!(
1406                                        y.to_bits(),
1407                                        yt.to_bits(),
1408                                        "idx {}: special-value mismatch under affine map",
1409                                        i
1410                                    );
1411                                }
1412                            }
1413
1414                            let ulp = y.to_bits().abs_diff(yr.to_bits());
1415                            prop_assert!(
1416                                (y - yr).abs() <= 1e-9 || ulp <= 4,
1417                                "idx {i}: fast={y} ref={yr} ULP={ulp}"
1418                            );
1419
1420                            if period == 1 {
1421                                prop_assert!(
1422									(y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
1423									"idx {i}: stream mismatch for period=1 - batch={y}, stream={ys}"
1424								);
1425                            } else if i < period - 1 {
1426                                prop_assert!(
1427                                    ys.is_nan(),
1428                                    "idx {i}: stream should return NaN during warmup, got {ys}"
1429                                );
1430                            } else {
1431                                prop_assert!(
1432                                    (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
1433                                    "idx {i}: stream mismatch - batch={y}, stream={ys}"
1434                                );
1435                            }
1436                        }
1437                    }
1438                }
1439
1440                Ok(())
1441            })
1442            .unwrap();
1443
1444        assert!(dema(&DemaInput::from_slice(&[], DemaParams::default())).is_err());
1445        assert!(dema(&DemaInput::from_slice(
1446            &[f64::NAN; 12],
1447            DemaParams::default()
1448        ))
1449        .is_err());
1450        assert!(dema(&DemaInput::from_slice(
1451            &[1.0; 5],
1452            DemaParams { period: Some(12) }
1453        ))
1454        .is_err());
1455        assert!(dema(&DemaInput::from_slice(
1456            &[1.0; 5],
1457            DemaParams { period: Some(0) }
1458        ))
1459        .is_err());
1460
1461        Ok(())
1462    }
1463
1464    fn check_dema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465        skip_if_unsupported!(kernel, test_name);
1466
1467        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1468        let candles = read_candles_from_csv(file_path)?;
1469
1470        let period = 30;
1471        let input = DemaInput::from_candles(
1472            &candles,
1473            "close",
1474            DemaParams {
1475                period: Some(period),
1476            },
1477        );
1478        let batch_output = dema_with_kernel(&input, kernel)?.values;
1479
1480        let mut stream = DemaStream::try_new(DemaParams {
1481            period: Some(period),
1482        })?;
1483        let mut stream_values = Vec::with_capacity(candles.close.len());
1484        for &price in &candles.close {
1485            stream_values.push(stream.update(price).unwrap_or(f64::NAN));
1486        }
1487
1488        assert_eq!(batch_output.len(), stream_values.len());
1489
1490        for (i, (&b, &s)) in batch_output
1491            .iter()
1492            .zip(&stream_values)
1493            .enumerate()
1494            .skip(period)
1495        {
1496            if b.is_nan() && s.is_nan() {
1497                continue;
1498            }
1499
1500            let diff = (b - s).abs();
1501            assert!(
1502                diff < 1e-9,
1503                "[{}] DEMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1504                test_name,
1505                i,
1506                b,
1507                s,
1508                diff
1509            );
1510        }
1511        Ok(())
1512    }
1513
1514    #[cfg(debug_assertions)]
1515    fn check_dema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1516        skip_if_unsupported!(kernel, test_name);
1517
1518        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1519        let candles = read_candles_from_csv(file_path)?;
1520
1521        let test_params = vec![
1522            DemaParams::default(),
1523            DemaParams { period: Some(2) },
1524            DemaParams { period: Some(3) },
1525            DemaParams { period: Some(5) },
1526            DemaParams { period: Some(7) },
1527            DemaParams { period: Some(10) },
1528            DemaParams { period: Some(12) },
1529            DemaParams { period: Some(20) },
1530            DemaParams { period: Some(30) },
1531            DemaParams { period: Some(50) },
1532            DemaParams { period: Some(100) },
1533            DemaParams { period: Some(200) },
1534            DemaParams { period: Some(1) },
1535            DemaParams { period: Some(250) },
1536        ];
1537
1538        for (param_idx, params) in test_params.iter().enumerate() {
1539            let input = DemaInput::from_candles(&candles, "close", params.clone());
1540            let output = dema_with_kernel(&input, kernel)?;
1541
1542            for (i, &val) in output.values.iter().enumerate() {
1543                if val.is_nan() {
1544                    continue;
1545                }
1546
1547                let bits = val.to_bits();
1548
1549                if bits == 0x11111111_11111111 {
1550                    panic!(
1551                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1552                        with params: period={}",
1553                        test_name,
1554                        val,
1555                        bits,
1556                        i,
1557                        params.period.unwrap_or(30)
1558                    );
1559                }
1560
1561                if bits == 0x22222222_22222222 {
1562                    panic!(
1563                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1564                        with params: period={}",
1565                        test_name,
1566                        val,
1567                        bits,
1568                        i,
1569                        params.period.unwrap_or(30)
1570                    );
1571                }
1572
1573                if bits == 0x33333333_33333333 {
1574                    panic!(
1575                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1576                        with params: period={}",
1577                        test_name,
1578                        val,
1579                        bits,
1580                        i,
1581                        params.period.unwrap_or(30)
1582                    );
1583                }
1584            }
1585        }
1586
1587        Ok(())
1588    }
1589
1590    #[cfg(not(debug_assertions))]
1591    fn check_dema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1592        Ok(())
1593    }
1594
1595    macro_rules! generate_all_dema_tests {
1596        ($($test_fn:ident),*) => {
1597            paste::paste! {
1598                $(
1599                    #[test]
1600                    fn [<$test_fn _scalar_f64>]() {
1601                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1602                    }
1603                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1604                    #[test]
1605                    fn [<$test_fn _avx2_f64>]() {
1606                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1607                    }
1608                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1609                    #[test]
1610                    fn [<$test_fn _avx512_f64>]() {
1611                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1612                    }
1613                )*
1614            }
1615        }
1616    }
1617
1618    fn check_dema_warmup_nan_preservation(
1619        test_name: &str,
1620        kernel: Kernel,
1621    ) -> Result<(), Box<dyn Error>> {
1622        skip_if_unsupported!(kernel, test_name);
1623        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1624        let candles = read_candles_from_csv(file_path)?;
1625
1626        let test_periods = vec![10, 20, 30, 50];
1627
1628        for period in test_periods {
1629            let params = DemaParams {
1630                period: Some(period),
1631            };
1632            let input = DemaInput::from_candles(&candles, "close", params);
1633            let result = dema_with_kernel(&input, kernel)?;
1634
1635            let warmup = period - 1;
1636            for i in 0..warmup {
1637                assert!(
1638                    result.values[i].is_nan(),
1639                    "[{}] Expected NaN at index {} (warmup={}) for period={}, got {}",
1640                    test_name,
1641                    i,
1642                    warmup,
1643                    period,
1644                    result.values[i]
1645                );
1646            }
1647
1648            for i in warmup..warmup + 10 {
1649                assert!(
1650                    !result.values[i].is_nan(),
1651                    "[{}] Expected non-NaN at index {} (warmup={}) for period={}, got NaN",
1652                    test_name,
1653                    i,
1654                    warmup,
1655                    period
1656                );
1657            }
1658        }
1659        Ok(())
1660    }
1661
    // Instantiate one #[test] per kernel variant for each single-series check.
    // The list is comma-separated with no trailing comma.
    generate_all_dema_tests!(
        check_dema_partial_params,
        check_dema_accuracy,
        check_dema_default_candles,
        check_dema_zero_period,
        check_dema_period_exceeds_length,
        check_dema_very_small_dataset,
        check_dema_empty_input,
        check_dema_not_enough_valid,
        check_dema_reinput,
        check_dema_nan_handling,
        check_dema_streaming,
        check_dema_property,
        check_dema_no_poison,
        check_dema_warmup_nan_preservation
    );
1678
1679    #[test]
1680    fn test_dema_into_matches_api() -> Result<(), Box<dyn Error>> {
1681        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1682        let candles = read_candles_from_csv(file_path)?;
1683
1684        let input = DemaInput::with_default_candles(&candles);
1685        let baseline = dema(&input)?.values;
1686
1687        let mut out = vec![0.0; candles.close.len()];
1688
1689        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1690        {
1691            dema_into(&input, &mut out)?;
1692        }
1693        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1694        {
1695            dema_into_slice(&mut out, &input, Kernel::Auto)?;
1696        }
1697
1698        assert_eq!(out.len(), baseline.len());
1699        for i in 0..out.len() {
1700            let a = out[i];
1701            let b = baseline[i];
1702            if a.is_nan() || b.is_nan() {
1703                assert!(a.is_nan() && b.is_nan(), "NaN mismatch at index {}", i);
1704            } else {
1705                assert!(a == b, "Value mismatch at index {}: {} != {}", i, a, b);
1706            }
1707        }
1708        Ok(())
1709    }
1710
1711    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1712        skip_if_unsupported!(kernel, test);
1713
1714        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1715        let c = read_candles_from_csv(file)?;
1716
1717        let output = DemaBatchBuilder::new()
1718            .kernel(kernel)
1719            .apply_candles(&c, "close")?;
1720
1721        let def = DemaParams::default();
1722        let row = output.values_for(&def).expect("default row missing");
1723
1724        assert_eq!(row.len(), c.close.len());
1725
1726        let expected = [
1727            59189.73193987478,
1728            59129.24920772847,
1729            59058.80282420511,
1730            59011.5555611042,
1731            58908.370159946775,
1732        ];
1733        let start = row.len() - 5;
1734        for (i, &v) in row[start..].iter().enumerate() {
1735            assert!(
1736                (v - expected[i]).abs() < 1e-6,
1737                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1738            );
1739        }
1740        Ok(())
1741    }
1742
1743    #[cfg(debug_assertions)]
1744    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1745        skip_if_unsupported!(kernel, test);
1746
1747        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1748        let c = read_candles_from_csv(file)?;
1749
1750        let test_configs = vec![
1751            (2, 5, 1),
1752            (5, 25, 5),
1753            (10, 50, 10),
1754            (1, 3, 1),
1755            (50, 150, 25),
1756            (10, 30, 2),
1757            (10, 30, 10),
1758            (100, 300, 50),
1759        ];
1760
1761        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1762            let output = DemaBatchBuilder::new()
1763                .kernel(kernel)
1764                .period_range(p_start, p_end, p_step)
1765                .apply_candles(&c, "close")?;
1766
1767            for (idx, &val) in output.values.iter().enumerate() {
1768                if val.is_nan() {
1769                    continue;
1770                }
1771
1772                let bits = val.to_bits();
1773                let row = idx / output.cols;
1774                let col = idx % output.cols;
1775                let combo = &output.combos[row];
1776
1777                if bits == 0x11111111_11111111 {
1778                    panic!(
1779                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1780                        at row {} col {} (flat index {}) with params: period={}",
1781                        test,
1782                        cfg_idx,
1783                        val,
1784                        bits,
1785                        row,
1786                        col,
1787                        idx,
1788                        combo.period.unwrap_or(30)
1789                    );
1790                }
1791
1792                if bits == 0x22222222_22222222 {
1793                    panic!(
1794                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1795                        at row {} col {} (flat index {}) with params: period={}",
1796                        test,
1797                        cfg_idx,
1798                        val,
1799                        bits,
1800                        row,
1801                        col,
1802                        idx,
1803                        combo.period.unwrap_or(30)
1804                    );
1805                }
1806
1807                if bits == 0x33333333_33333333 {
1808                    panic!(
1809                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1810                        at row {} col {} (flat index {}) with params: period={}",
1811                        test,
1812                        cfg_idx,
1813                        val,
1814                        bits,
1815                        row,
1816                        col,
1817                        idx,
1818                        combo.period.unwrap_or(30)
1819                    );
1820                }
1821            }
1822        }
1823
1824        Ok(())
1825    }
1826
1827    #[cfg(not(debug_assertions))]
1828    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1829        Ok(())
1830    }
1831
1832    macro_rules! gen_batch_tests {
1833        ($fn_name:ident) => {
1834            paste::paste! {
1835                #[test]
1836                fn [<$fn_name _scalar>]() {
1837                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1838                }
1839                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1840                #[test]
1841                fn [<$fn_name _avx2>]() {
1842                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1843                }
1844                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1845                #[test]
1846                fn [<$fn_name _avx512>]() {
1847                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1848                }
1849                #[test]
1850                fn [<$fn_name _auto_detect>]() {
1851                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1852                }
1853            }
1854        };
1855    }
1856    fn check_batch_warmup_nan_preservation(
1857        test: &str,
1858        kernel: Kernel,
1859    ) -> Result<(), Box<dyn Error>> {
1860        skip_if_unsupported!(kernel, test);
1861
1862        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1863        let c = read_candles_from_csv(file)?;
1864
1865        let output = DemaBatchBuilder::new()
1866            .kernel(kernel)
1867            .period_range(10, 30, 10)
1868            .apply_candles(&c, "close")?;
1869
1870        for (row_idx, combo) in output.combos.iter().enumerate() {
1871            let period = combo.period.unwrap_or(30);
1872            let warmup = period - 1;
1873            let row_start = row_idx * output.cols;
1874
1875            for i in 0..warmup {
1876                let val = output.values[row_start + i];
1877                assert!(
1878                    val.is_nan(),
1879                    "[{}] Batch row {} (period={}): Expected NaN at index {}, got {}",
1880                    test,
1881                    row_idx,
1882                    period,
1883                    i,
1884                    val
1885                );
1886            }
1887
1888            for i in warmup..warmup.min(output.cols).min(warmup + 10) {
1889                let val = output.values[row_start + i];
1890                assert!(
1891                    !val.is_nan(),
1892                    "[{}] Batch row {} (period={}): Expected non-NaN at index {}, got NaN",
1893                    test,
1894                    row_idx,
1895                    period,
1896                    i
1897                );
1898            }
1899        }
1900        Ok(())
1901    }
1902
1903    gen_batch_tests!(check_batch_default_row);
1904    gen_batch_tests!(check_batch_no_poison);
1905    gen_batch_tests!(check_batch_warmup_nan_preservation);
1906}
1907
// Python entry point `dema(data, period, kernel=None)`.
//
// Views `data` as a slice, validates the optional kernel name (non-batch mode), runs
// `dema_with_kernel` inside `allow_threads`, and hands the output values back as a 1-D NumPy
// array. A computation error is reported to Python as `ValueError` carrying the error text.
#[cfg(feature = "python")]
#[pyfunction(name = "dema")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn dema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = DemaInput::from_slice(
        prices,
        DemaParams {
            period: Some(period),
        },
    );

    // Only the plain `Vec<f64>` leaves the closure.
    let computed = py.allow_threads(|| dema_with_kernel(&input, chosen).map(|out| out.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1933
// Python class `DemaStream`: a wrapper that owns a single native `DemaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DemaStream")]
pub struct DemaStreamPy {
    stream: DemaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DemaStreamPy {
    // Constructor: builds the wrapped stream for `period`. A `DemaStream::try_new` failure is
    // surfaced to Python as `ValueError` with the error's text.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = DemaParams {
            period: Some(period),
        };
        match DemaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Forwards `value` to the wrapped stream and returns its result unchanged.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1957
// Python entry point `dema_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into parameter combinations, computes one output row per combination
// into a single `rows x cols` buffer, and returns a dict with:
//   "values"  - 2-D array of shape (rows, cols)
//   "periods" - 1-D array with the period used for each row
//
// Errors: an empty expansion, an invalid kernel name, or a failure inside
// `dema_batch_inner_into` is raised as a Python exception (`ValueError` for the first and last).
#[cfg(feature = "python")]
#[pyfunction(name = "dema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn dema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;
    use std::mem::ManuallyDrop;

    let slice_in = data.as_slice()?;
    let sweep = DemaBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err(
            "invalid period range: empty expansion",
        ));
    }
    let rows = combos.len();
    let cols = slice_in.len();

    let mut buf_mu = make_uninit_matrix(rows, cols);
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    // Warmup prefix per row. `saturating_sub` avoids an underflow when a combination carries
    // period 0, and the clamp keeps the prefix inside the row when period exceeds the input
    // length; in both cases validation is left to `dema_batch_inner_into` below.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| (first + c.period.unwrap()).saturating_sub(1).min(cols))
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the pointer and length come
    // from the live `buf_mu` allocation, which is not touched again until `out` is done.
    // NOTE(review): assumes `init_matrix_prefixes` plus `dema_batch_inner_into` write every
    // cell before it is read - confirm against those helpers.
    let out: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len()) };

    let simd = match match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    } {
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512Batch => Kernel::Avx512,
        // An AVX2 batch request must select the AVX2 row kernel, not fall back to scalar.
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    // `buf_mu` is still an ordinary owned Vec here, so an early return through `?` frees it.
    let combos = py
        .allow_threads(|| dema_batch_inner_into(slice_in, &sweep, simd, true, out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // SAFETY: same allocation, length and capacity, reinterpreted as `f64`; the original Vec is
    // wrapped in `ManuallyDrop` so the buffer has exactly one owner afterwards.
    let values: Vec<f64> = unsafe {
        let mut owned = ManuallyDrop::new(buf_mu);
        Vec::from_raw_parts(
            owned.as_mut_ptr() as *mut f64,
            owned.len(),
            owned.capacity(),
        )
    };
    let arr = values.into_pyarray(py).reshape((rows, cols))?;

    let dict = PyDict::new(py);
    dict.set_item("values", arr)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2035
// Python entry point `dema_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is false. Otherwise it
// creates a `CudaDema` for `device_id`, runs `dema_batch_dev` over the f32 input inside
// `allow_threads`, and wraps the resulting device array together with the context and device
// id reported by the `CudaDema` instance. CUDA-side errors are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn dema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32DemaPy> {
    use crate::cuda::cuda_available;

    // Shared error conversion for both fallible CUDA calls.
    fn as_value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = DemaBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDema::new(device_id).map_err(as_value_error)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let device_array = cuda.dema_batch_dev(prices, &sweep).map_err(as_value_error)?;
        Ok((device_array, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32DemaPy::new(inner, ctx, dev_id))
}
2068
// Python entry point
// `dema_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D f32 array whose axis 0 is used as the series length and axis 1 as the number of
// series, and runs `dema_many_series_one_param_time_major_dev` with a single `period`.
// Raises `ValueError` when CUDA is unavailable, when `period` is 0, or when a CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn dema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32DemaPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    // Shared error conversion for both fallible CUDA calls.
    fn as_value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (series_len, num_series) = (dims[0], dims[1]);
    let params = DemaParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDema::new(device_id).map_err(as_value_error)?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let device_array = cuda
            .dema_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(as_value_error)?;
        Ok((device_array, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32DemaPy::new(inner, ctx, dev_id))
}
2104
// WASM export `dema_js(data, period)`.
//
// Allocates an output vector as long as `data`, fills it through `dema_into_slice` using
// `Kernel::Auto`, and returns it. A computation error is returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = DemaInput::from_slice(
        data,
        DemaParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    match dema_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2120
/// Configuration object deserialized from the JS argument of `dema_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemaBatchConfig {
    /// Period sweep triple; copied verbatim into `DemaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2126
/// Result object serialized back to JS by `dema_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemaBatchJsOutput {
    /// Output values copied from the batch result.
    pub values: Vec<f64>,
    /// Parameter combinations copied from the batch result.
    pub combos: Vec<DemaParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2135
// WASM export `dema_batch(data, config)`.
//
// Deserializes `config` into `DemaBatchConfig`, runs `dema_batch_inner` over `data` with
// `Kernel::Auto`, and serializes the values, combinations and dimensions back to a JS value.
// Every failure (bad config, computation error, serialization error) is returned as a
// `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dema_batch)]
pub fn dema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let DemaBatchConfig { period_range }: DemaBatchConfig =
        serde_wasm_bindgen::from_value(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = DemaBatchRange {
        period: period_range,
    };

    let batch = match dema_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = DemaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
2159
// Deprecated WASM export: returns the period of every combination produced by expanding
// `(period_start, period_end, period_step)`, each converted to `f64`. Always returns `Ok`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use dema_batch instead")]
pub fn dema_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = DemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let mut metadata: Vec<f64> = Vec::new();
    for params in expand_grid(&sweep).iter() {
        metadata.push(params.period.unwrap() as f64);
    }

    Ok(metadata)
}
2180
// WASM export: reserves room for `len` f64 values and returns the raw pointer. The vector is
// deliberately never dropped here, so the allocation outlives this call; `dema_free` is the
// matching release function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaving the allocation in place.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2189
// WASM export: releases a buffer previously returned by `dema_alloc(len)`.
//
// `ptr` and `len` must be exactly the pointer returned by `dema_alloc` and the length passed
// to it. A null `ptr` is ignored, because handing null to `Vec::from_raw_parts` is undefined
// behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: caller contract above - `ptr` is non-null and originates from
    // `dema_alloc(len)`, so rebuilding the Vec with length and capacity `len` and dropping it
    // returns the allocation to the allocator.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2197
// WASM export: computes DEMA for `len` values at `in_ptr` and writes `len` results to
// `out_ptr`.
//
// Returns an error string when either pointer is null, when `period` is 0 or greater than
// `len`, or when `dema_into_slice` fails. When both pointers are equal the result is computed
// into a temporary vector first and then copied over the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dema_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if !(1..=len).contains(&period) {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = DemaInput::from_slice(
            data,
            DemaParams {
                period: Some(period),
            },
        );

        let in_place = in_ptr == out_ptr;
        if in_place {
            // Input and output share storage: compute into scratch space, then copy back.
            let mut scratch = vec![0.0; len];
            if let Err(err) = dema_into_slice(&mut scratch, &input, Kernel::Auto) {
                return Err(JsValue::from_str(&err.to_string()));
            }
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            if let Err(err) = dema_into_slice(dst, &input, Kernel::Auto) {
                return Err(JsValue::from_str(&err.to_string()));
            }
        }

        Ok(())
    }
}
2238
// WASM export: runs a period sweep over `len` values at `in_ptr` and writes one row of `len`
// results per parameter combination to `out_ptr`.
//
// `out_ptr` must have room for `rows * len` f64 values, where `rows` is the number of
// combinations produced by expanding `(period_start, period_end, period_step)`.
// Returns the number of rows on success. Returns an error string when either pointer is
// null, when `rows * len` does not fit in `usize`, or when `dema_batch_inner_into` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dema_batch_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = DemaBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid(&sweep);
        let rows = combos.len();
        let cols = len;

        // `usize` is 32 bits on wasm32; a wrapped product would describe an output slice
        // smaller than the matrix that is about to be written.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflow in dema_batch_into"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        dema_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}