1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaDema, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
12use core::arch::x86_64::*;
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use serde::{Deserialize, Serialize};
21use std::convert::AsRef;
22use std::error::Error;
23use std::mem::MaybeUninit;
24use thiserror::Error;
25#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
26use wasm_bindgen::prelude::*;
27
/// Where the DEMA routines read their input series from.
#[derive(Debug, Clone)]
pub enum DemaData<'a> {
    /// A borrowed candle set together with the name of the field to read;
    /// the name is resolved to a slice through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice of values used directly.
    Slice(&'a [f64]),
}
36
37impl<'a> AsRef<[f64]> for DemaInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 DemaData::Slice(slice) => slice,
42 DemaData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
/// Tunable parameters for DEMA.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct DemaParams {
    /// Smoothing period; `None` means "use the default" (30 everywhere in this file).
    pub period: Option<usize>,
}
55
56impl Default for DemaParams {
57 fn default() -> Self {
58 Self { period: Some(30) }
59 }
60}
61
/// A DEMA request: the data to read plus the parameters to apply.
#[derive(Debug, Clone)]
pub struct DemaInput<'a> {
    /// Input series (candle field or raw slice).
    pub data: DemaData<'a>,
    /// Indicator parameters.
    pub params: DemaParams,
}
67
68impl<'a> DemaInput<'a> {
69 #[inline]
70 pub fn from_candles(c: &'a Candles, s: &'a str, p: DemaParams) -> Self {
71 Self {
72 data: DemaData::Candles {
73 candles: c,
74 source: s,
75 },
76 params: p,
77 }
78 }
79 #[inline]
80 pub fn from_slice(sl: &'a [f64], p: DemaParams) -> Self {
81 Self {
82 data: DemaData::Slice(sl),
83 params: p,
84 }
85 }
86 #[inline]
87 pub fn with_default_candles(c: &'a Candles) -> Self {
88 Self::from_candles(c, "close", DemaParams::default())
89 }
90 #[inline]
91 pub fn get_period(&self) -> usize {
92 self.params.period.unwrap_or(30)
93 }
94}
95
/// Result of a single-series DEMA computation.
#[derive(Debug, Clone)]
pub struct DemaOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
100
/// Fluent builder for single-series and streaming DEMA computations.
#[derive(Copy, Clone, Debug)]
pub struct DemaBuilder {
    // Period override; `None` defers to the default used by the computation.
    period: Option<usize>,
    // Kernel passed through to `dema_with_kernel`.
    kernel: Kernel,
}
106
107impl Default for DemaBuilder {
108 fn default() -> Self {
109 Self {
110 period: None,
111 kernel: Kernel::Auto,
112 }
113 }
114}
115
116impl DemaBuilder {
117 #[inline(always)]
118 pub fn new() -> Self {
119 Self::default()
120 }
121 #[inline(always)]
122 pub fn period(mut self, n: usize) -> Self {
123 self.period = Some(n);
124 self
125 }
126 #[inline(always)]
127 pub fn kernel(mut self, k: Kernel) -> Self {
128 self.kernel = k;
129 self
130 }
131
132 #[inline(always)]
133 pub fn apply(self, c: &Candles) -> Result<DemaOutput, DemaError> {
134 let p = DemaParams {
135 period: self.period,
136 };
137 let i = DemaInput::from_candles(c, "close", p);
138 dema_with_kernel(&i, self.kernel)
139 }
140 #[inline(always)]
141 pub fn apply_slice(self, d: &[f64]) -> Result<DemaOutput, DemaError> {
142 let p = DemaParams {
143 period: self.period,
144 };
145 let i = DemaInput::from_slice(d, p);
146 dema_with_kernel(&i, self.kernel)
147 }
148 #[inline(always)]
149 pub fn into_stream(self) -> Result<DemaStream, DemaError> {
150 let p = DemaParams {
151 period: self.period,
152 };
153 DemaStream::try_new(p)
154 }
155}
156
/// Errors reported by the DEMA entry points in this file.
#[derive(Debug, Error)]
pub enum DemaError {
    /// The input series has length zero.
    #[error("dema: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("dema: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("dema: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// The whole input is shorter than `2 * (period - 1)`.
    #[error("dema: Not enough data: needed = {needed}, valid = {valid}")]
    NotEnoughData { needed: usize, valid: usize },
    /// The input after the first non-NaN element is shorter than `2 * (period - 1)`.
    #[error("dema: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer does not have the required length.
    #[error("dema: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range expanded to no parameter combinations.
    #[error("dema: invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("dema: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A size computation (`context` names which one) overflowed `usize`.
    #[error("dema: size overflow when computing {context}")]
    SizeOverflow { context: &'static str },
}
182
/// Computes DEMA for `input` with automatic kernel selection.
///
/// Thin wrapper over [`dema_with_kernel`] using `Kernel::Auto`; errors are
/// whatever that function returns.
#[inline]
pub fn dema(input: &DemaInput) -> Result<DemaOutput, DemaError> {
    dema_with_kernel(input, Kernel::Auto)
}
187
188#[inline(always)]
189fn dema_prepare<'a>(
190 input: &'a DemaInput,
191 kernel: Kernel,
192) -> Result<(&'a [f64], usize, usize, usize, Kernel), DemaError> {
193 let data: &[f64] = match &input.data {
194 DemaData::Candles { candles, source } => source_type(candles, source),
195 DemaData::Slice(sl) => sl,
196 };
197
198 let len = data.len();
199 if len == 0 {
200 return Err(DemaError::EmptyInputData);
201 }
202
203 let first = data
204 .iter()
205 .position(|x| !x.is_nan())
206 .ok_or(DemaError::AllValuesNaN)?;
207
208 let period = input.get_period();
209
210 if period < 1 || period > len {
211 return Err(DemaError::InvalidPeriod {
212 period,
213 data_len: len,
214 });
215 }
216 let needed = 2 * (period - 1);
217 if len < needed {
218 return Err(DemaError::NotEnoughData { needed, valid: len });
219 }
220 let valid = len - first;
221 if valid < needed {
222 return Err(DemaError::NotEnoughValidData { needed, valid });
223 }
224
225 let chosen = match kernel {
226 Kernel::Auto => match detect_best_kernel() {
227 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
228 Kernel::Avx512 => Kernel::Avx512,
229 _ => Kernel::Scalar,
230 },
231 other => other,
232 };
233
234 let warm = first + period - 1;
235
236 Ok((data, period, first, warm, chosen))
237}
238
239#[inline(always)]
240fn dema_compute_into(data: &[f64], period: usize, first: usize, chosen: Kernel, out: &mut [f64]) {
241 unsafe {
242 match chosen {
243 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
244 Kernel::Avx512 => dema_avx512(data, period, first, out),
245 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
246 Kernel::Avx2 => dema_avx2(data, period, first, out),
247 _ => dema_scalar(data, period, first, out),
248 }
249 }
250}
251
252pub fn dema_with_kernel(input: &DemaInput, kernel: Kernel) -> Result<DemaOutput, DemaError> {
253 let (data, period, first, warm, chosen) = dema_prepare(input, kernel)?;
254 let len = data.len();
255 let mut out = alloc_with_nan_prefix(len, warm);
256 dema_compute_into(data, period, first, chosen, &mut out);
257
258 out[..warm].fill(f64::NAN);
259 Ok(DemaOutput { values: out })
260}
261
/// Computes DEMA into a caller-provided buffer with automatic kernel selection.
///
/// Forwards to [`dema_into_slice`] with `Kernel::Auto`. Not compiled for the
/// `wasm32` + `wasm` feature combination.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn dema_into(input: &DemaInput, out: &mut [f64]) -> Result<(), DemaError> {
    dema_into_slice(out, input, Kernel::Auto)
}
267
268#[inline]
269pub fn dema_into_slice(dst: &mut [f64], input: &DemaInput, kern: Kernel) -> Result<(), DemaError> {
270 let (data, period, first, warmup, chosen) = dema_prepare(input, kern)?;
271
272 if dst.len() != data.len() {
273 return Err(DemaError::OutputLengthMismatch {
274 expected: data.len(),
275 got: dst.len(),
276 });
277 }
278
279 dema_compute_into(data, period, first, chosen, dst);
280
281 for v in &mut dst[..warmup] {
282 *v = f64::NAN;
283 }
284
285 Ok(())
286}
287
/// Scalar DEMA kernel for x86 / x86_64 (FMA-enabled).
///
/// Seeds both EMAs with `data[first]`, writes that seed to `out[first]`, then
/// for every later index computes `ema1`, `ema2` (EMA of `ema1`) and stores
/// `2 * ema1 - ema2`. Elements before `first` are left untouched. Does nothing
/// when `first >= data.len()`.
///
/// # Safety
/// * `out.len()` must equal `data.len()` (only checked by `debug_assert!`).
/// * The CPU must support FMA, since the function is compiled with
///   `target_feature(enable = "fma")`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "fma")]
#[inline]
pub unsafe fn dema_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is actually compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::{_mm_prefetch, _MM_HINT_T0};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    debug_assert!(period >= 1 && data.len() == out.len());
    let n = data.len();
    if first >= n {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    let mut ema1 = *data.get_unchecked(first);
    let mut ema2 = ema1;
    *out.get_unchecked_mut(first) = ema1;

    let mut i = first + 1;
    let mut p = data.as_ptr().add(i);
    let mut q = out.as_mut_ptr().add(i);

    // Unrolled by four: the loop runs while indices i..=i+3 are in bounds.
    let limit = n.saturating_sub(4);
    while i <= limit {
        if i + 32 < n {
            _mm_prefetch(p.add(32) as *const i8, _MM_HINT_T0);
        }

        let x0 = *p;
        ema1 = ema1.mul_add(a, x0 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q = ema1.mul_add(2.0, -ema2);

        let x1 = *p.add(1);
        ema1 = ema1.mul_add(a, x1 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(1) = ema1.mul_add(2.0, -ema2);

        let x2 = *p.add(2);
        ema1 = ema1.mul_add(a, x2 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(2) = ema1.mul_add(2.0, -ema2);

        let x3 = *p.add(3);
        ema1 = ema1.mul_add(a, x3 * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q.add(3) = ema1.mul_add(2.0, -ema2);

        p = p.add(4);
        q = q.add(4);
        i += 4;
    }

    // Remainder (at most three elements).
    while i < n {
        let x = *p;
        ema1 = ema1.mul_add(a, x * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *q = ema1.mul_add(2.0, -ema2);

        p = p.add(1);
        q = q.add(1);
        i += 1;
    }
}
354
/// Scalar DEMA kernel for targets other than x86 / x86_64.
///
/// Seeds both EMAs with `data[first]`, writes that seed to `out[first]`, then
/// for each later element updates `ema1`, `ema2` and stores `2 * ema1 - ema2`.
/// Elements before `first` are not written. Does nothing when
/// `first >= data.len()`.
///
/// # Safety
/// `out.len()` must equal `data.len()`; this is only checked by `debug_assert!`.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[inline]
pub unsafe fn dema_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(period >= 1 && data.len() == out.len());
    let n = data.len();
    if first >= n {
        return;
    }

    let alpha = 2.0 / (period as f64 + 1.0);
    let decay = 1.0 - alpha;

    let seed = *data.get_unchecked(first);
    *out.get_unchecked_mut(first) = seed;
    let (mut fast, mut slow) = (seed, seed);

    // Walk input and output in lockstep from the element after the seed.
    let src = data.get_unchecked(first + 1..);
    let dst = out.get_unchecked_mut(first + 1..n);
    for (y, &x) in dst.iter_mut().zip(src) {
        fast = fast * decay + x * alpha;
        slow = slow * decay + fast * alpha;
        *y = 2.0 * fast - slow;
    }
}
413
/// Extracts the highest lane (lane 3) of a 4 x f64 vector as a scalar.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn last_lane_256(v: __m256d) -> f64 {
    // Upper 128 bits hold lanes 2 and 3.
    let hi: __m128d = _mm256_extractf128_pd(v, 1);
    // Duplicate the high element so it sits in the low position.
    let dup_hi: __m128d = _mm_unpackhi_pd(hi, hi);
    _mm_cvtsd_f64(dup_hi)
}
421
/// Extracts the highest lane (lane 7) of an 8 x f64 vector as a scalar.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn last_lane_512(v: __m512d) -> f64 {
    // 128-bit chunk 3 holds lanes 6 and 7.
    let hi2: __m128d = _mm512_extractf64x2_pd(v, 3);
    // Duplicate the high element so it sits in the low position.
    let dup_hi: __m128d = _mm_unpackhi_pd(hi2, hi2);
    _mm_cvtsd_f64(dup_hi)
}
429
/// Shifts the four lanes up by one position, inserting zero in lane 0:
/// `[x0, x1, x2, x3] -> [0, x0, x1, x2]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl1_256(x: __m256d) -> __m256d {
    let lo: __m128d = _mm256_castpd256_pd128(x);
    let hi: __m128d = _mm256_extractf128_pd(x, 1);
    // Low half becomes [0, x0].
    let lo_res = _mm_unpacklo_pd(_mm_setzero_pd(), lo);
    // High half becomes [x1, x2].
    let hi_res = _mm_shuffle_pd(_mm_unpackhi_pd(lo, lo), _mm_unpacklo_pd(hi, hi), 0x0);
    _mm256_insertf128_pd(_mm256_castpd128_pd256(lo_res), hi_res, 1)
}
439
/// Shifts the four lanes up by two positions, zero-filling the low half:
/// `[x0, x1, x2, x3] -> [0, 0, x0, x1]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl2_256(x: __m256d) -> __m256d {
    let lo: __m128d = _mm256_castpd256_pd128(x);
    _mm256_insertf128_pd(_mm256_castpd128_pd256(_mm_setzero_pd()), lo, 1)
}
446
/// Two-step weighted prefix scan over four lanes.
///
/// Callers in this file pass `a1 = a` and `a2 = a^2` (broadcast), which makes
/// lane `k` of the result `sum_{j<=k} a^(k-j) * v[j]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn scan4(v: __m256d, a1: __m256d, a2: __m256d) -> __m256d {
    // Combine each lane with its neighbour one position below.
    let t1 = _mm256_fmadd_pd(a1, shl1_256(v), v);
    // Combine each partial sum with the one two positions below.
    let t2 = _mm256_fmadd_pd(a2, shl2_256(t1), t1);
    t2
}
454
/// Shifts the eight lanes up by one position, zeroing lane 0.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl1_512(x: __m512d) -> __m512d {
    // Lane i reads source lane i-1; lane 0's index is irrelevant because it is masked.
    let idx: __m512i = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);
    let mask: __mmask8 = 0b1111_1110;
    _mm512_maskz_permutexvar_pd(mask, idx, x)
}
/// Shifts the eight lanes up by two positions, zeroing lanes 0 and 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl2_512(x: __m512d) -> __m512d {
    // Lane i reads source lane i-2; the two lowest lanes are masked to zero.
    let idx: __m512i = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);
    let mask: __mmask8 = 0b1111_1100;
    _mm512_maskz_permutexvar_pd(mask, idx, x)
}
/// Shifts the eight lanes up by four positions, zeroing lanes 0 through 3.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn shl4_512(x: __m512d) -> __m512d {
    // Lane i reads source lane i-4; the four lowest lanes are masked to zero.
    let idx: __m512i = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);
    let mask: __mmask8 = 0b1111_0000;
    _mm512_maskz_permutexvar_pd(mask, idx, x)
}
/// Three-step weighted prefix scan over eight lanes.
///
/// Callers in this file pass `a1 = a`, `a2 = a^2`, `a4 = a^4` (broadcast),
/// which makes lane `k` of the result `sum_{j<=k} a^(k-j) * v[j]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn scan8(v: __m512d, a1: __m512d, a2: __m512d, a4: __m512d) -> __m512d {
    // Distance-1, distance-2 and distance-4 combination steps.
    let t1 = _mm512_fmadd_pd(a1, shl1_512(v), v);
    let t2 = _mm512_fmadd_pd(a2, shl2_512(t1), t1);
    let t3 = _mm512_fmadd_pd(a4, shl4_512(t2), t2);
    t3
}
484
/// AVX2 + FMA DEMA kernel processing four elements per iteration.
///
/// Seeds both EMAs with `data[first]` and writes that seed to `out[first]`.
/// Each block of four inputs is turned into four EMA values with an
/// in-register prefix scan (`scan4`) plus the carried-in previous EMA scaled
/// by `a^1..a^4`; the same is repeated on the first EMA to get the second.
/// Leftover elements are handled by a scalar loop. Elements before `first`
/// are not written.
///
/// # Safety
/// * `out.len()` must equal `data.len()` (only checked by `debug_assert!`).
/// * The CPU must support AVX2 and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn dema_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    if first >= data.len() {
        return;
    }

    let n = data.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    // Seed both EMAs with the first valid sample.
    let mut i = first;
    let mut ema1 = *data.get_unchecked(i);
    let mut ema2 = ema1;
    *out.get_unchecked_mut(i) = ema1;
    i += 1;
    if i >= n {
        return;
    }

    let alpha_v = _mm256_set1_pd(alpha);
    let a1_s = a;
    let a2_s = a1_s * a1_s;
    let a3_s = a2_s * a1_s;
    let a4_s = a2_s * a2_s;
    // `_mm256_set_pd` takes the highest lane first, so lane k holds a^(k+1).
    let pow_vec = _mm256_set_pd(a4_s, a3_s, a2_s, a1_s);
    let a1_v = _mm256_set1_pd(a1_s);
    let a2_v = _mm256_set1_pd(a2_s);

    while i + 4 <= n {
        let x = _mm256_loadu_pd(data.as_ptr().add(i));

        // First EMA: scan of alpha*x plus the decayed carry-in.
        let v1 = _mm256_mul_pd(alpha_v, x);
        let t1 = scan4(v1, a1_v, a2_v);
        let prev1 = _mm256_set1_pd(ema1);
        let ema1_vec = _mm256_fmadd_pd(pow_vec, prev1, t1);

        // Second EMA: same construction applied to the first EMA.
        let v2 = _mm256_mul_pd(alpha_v, ema1_vec);
        let t2 = scan4(v2, a1_v, a2_v);
        let prev2 = _mm256_set1_pd(ema2);
        let ema2_vec = _mm256_fmadd_pd(pow_vec, prev2, t2);

        // DEMA = 2 * ema1 - ema2.
        let two_ema1 = _mm256_add_pd(ema1_vec, ema1_vec);
        let dema_v = _mm256_sub_pd(two_ema1, ema2_vec);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), dema_v);

        // Carry the last lane into the next block.
        ema1 = last_lane_256(ema1_vec);
        ema2 = last_lane_256(ema2_vec);
        i += 4;
    }

    // Scalar tail for the remaining (fewer than four) elements.
    while i < n {
        let price = *data.get_unchecked(i);
        ema1 = ema1.mul_add(a, price * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *out.get_unchecked_mut(i) = (2.0 * ema1) - ema2;
        i += 1;
    }
}
546
/// AVX-512 + FMA DEMA kernel processing eight elements per iteration.
///
/// Same scheme as the AVX2 kernel but with eight lanes: both EMAs are seeded
/// with `data[first]`, each block of eight inputs goes through an in-register
/// prefix scan (`scan8`) combined with the carried-in previous EMA scaled by
/// `a^1..a^8`, and leftovers are handled by a scalar loop. Elements before
/// `first` are not written.
///
/// # Safety
/// * `out.len()` must equal `data.len()` (only checked by `debug_assert!`).
/// * The CPU must support AVX-512F, AVX-512DQ and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
pub unsafe fn dema_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(data.len() == out.len());
    if first >= data.len() {
        return;
    }

    let n = data.len();
    let alpha = 2.0 / (period as f64 + 1.0);
    let a = 1.0 - alpha;

    // Seed both EMAs with the first valid sample.
    let mut i = first;
    let mut ema1 = *data.get_unchecked(i);
    let mut ema2 = ema1;
    *out.get_unchecked_mut(i) = ema1;
    i += 1;
    if i >= n {
        return;
    }

    let alpha_v = _mm512_set1_pd(alpha);
    let a1_s = a;
    let a2_s = a1_s * a1_s;
    let a3_s = a2_s * a1_s;
    let a4_s = a2_s * a2_s;
    let a5_s = a4_s * a1_s;
    let a6_s = a3_s * a3_s;
    let a7_s = a6_s * a1_s;
    let a8_s = a4_s * a4_s;
    // `_mm512_set_pd` takes the highest lane first, so lane k holds a^(k+1).
    let pow_vec = _mm512_set_pd(a8_s, a7_s, a6_s, a5_s, a4_s, a3_s, a2_s, a1_s);
    let a1_v = _mm512_set1_pd(a1_s);
    let a2_v = _mm512_set1_pd(a2_s);
    let a4_v = _mm512_set1_pd(a4_s);

    while i + 8 <= n {
        let x = _mm512_loadu_pd(data.as_ptr().add(i));

        // First EMA: scan of alpha*x plus the decayed carry-in.
        let v1 = _mm512_mul_pd(alpha_v, x);
        let t1 = scan8(v1, a1_v, a2_v, a4_v);
        let prev1 = _mm512_set1_pd(ema1);
        let ema1_vec = _mm512_fmadd_pd(pow_vec, prev1, t1);

        // Second EMA: same construction applied to the first EMA.
        let v2 = _mm512_mul_pd(alpha_v, ema1_vec);
        let t2 = scan8(v2, a1_v, a2_v, a4_v);
        let prev2 = _mm512_set1_pd(ema2);
        let ema2_vec = _mm512_fmadd_pd(pow_vec, prev2, t2);

        // DEMA = 2 * ema1 - ema2.
        let two_ema1 = _mm512_add_pd(ema1_vec, ema1_vec);
        let dema_v = _mm512_sub_pd(two_ema1, ema2_vec);
        _mm512_storeu_pd(out.as_mut_ptr().add(i), dema_v);

        // Carry the last lane into the next block.
        ema1 = last_lane_512(ema1_vec);
        ema2 = last_lane_512(ema2_vec);
        i += 8;
    }

    // Scalar tail for the remaining (fewer than eight) elements.
    while i < n {
        let price = *data.get_unchecked(i);
        ema1 = ema1.mul_add(a, price * alpha);
        ema2 = ema2.mul_add(a, ema1 * alpha);
        *out.get_unchecked_mut(i) = (2.0 * ema1) - ema2;
        i += 1;
    }
}
613
/// Incremental DEMA calculator fed one value at a time.
#[derive(Debug, Clone)]
pub struct DemaStream {
    // Configured period.
    period: usize,
    // Smoothing factor `2 / (period + 1)`.
    alpha: f64,
    // `1 - alpha`.
    alpha_1: f64,
    // Running first EMA (NaN until the first update).
    ema: f64,
    // Running EMA of `ema` (NaN until the first update).
    ema2: f64,
    // Number of values consumed so far (saturating).
    filled: usize,
    // Number of leading updates that yield `None` (`period - 1`).
    nan_fill: usize,
}
624
625impl DemaStream {
626 pub fn try_new(params: DemaParams) -> Result<Self, DemaError> {
627 let period = params.period.unwrap_or(30);
628 if period < 1 {
629 return Err(DemaError::InvalidPeriod {
630 period,
631 data_len: 0,
632 });
633 }
634 Ok(Self {
635 period,
636 alpha: 2.0 / (period as f64 + 1.0),
637 alpha_1: 1.0 - 2.0 / (period as f64 + 1.0),
638 ema: f64::NAN,
639 ema2: f64::NAN,
640 filled: 0,
641 nan_fill: period - 1,
642 })
643 }
644
645 #[inline(always)]
646 fn fmadd(a: f64, b: f64, c: f64) -> f64 {
647 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
648 {
649 a.mul_add(b, c)
650 }
651 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
652 {
653 a * b + c
654 }
655 }
656
657 #[inline(always)]
658 pub fn update(&mut self, value: f64) -> Option<f64> {
659 if self.filled == 0 {
660 self.ema = value;
661 self.ema2 = value;
662 self.filled = 1;
663
664 return if self.nan_fill == 0 {
665 Some(value)
666 } else {
667 None
668 };
669 }
670
671 let a = self.alpha;
672 let a1 = self.alpha_1;
673
674 self.ema = Self::fmadd(self.ema, a1, value * a);
675
676 self.ema2 = Self::fmadd(self.ema2, a1, self.ema * a);
677
678 let y = Self::fmadd(self.ema, 2.0, -self.ema2);
679
680 self.filled = self.filled.saturating_add(1);
681
682 if self.filled > self.nan_fill {
683 Some(y)
684 } else {
685 None
686 }
687 }
688
689 #[inline(always)]
690 pub fn update_nan(&mut self, value: f64) -> f64 {
691 match self.update(value) {
692 Some(v) => v,
693 None => f64::NAN,
694 }
695 }
696}
697
/// Approximates `1 / d`: an `f32` reciprocal refined by one Newton-Raphson step in `f64`.
#[inline(always)]
fn fast_recip_nr1(d: f64) -> f64 {
    let seed = f64::from((d as f32).recip());
    let correction = 2.0 - d * seed;
    seed * correction
}
703
/// Parameter sweep description for batch DEMA.
#[derive(Clone, Debug)]
pub struct DemaBatchRange {
    /// `(start, end, step)` for the period axis; expanded by `expand_grid`.
    pub period: (usize, usize, usize),
}
708
709impl Default for DemaBatchRange {
710 fn default() -> Self {
711 Self {
712 period: (30, 279, 1),
713 }
714 }
715}
716
/// Fluent builder for batch (multi-period) DEMA computations.
#[derive(Clone, Debug, Default)]
pub struct DemaBatchBuilder {
    // Period sweep to evaluate.
    range: DemaBatchRange,
    // Kernel passed through to `dema_batch_with_kernel`.
    kernel: Kernel,
}
722
723impl DemaBatchBuilder {
724 pub fn new() -> Self {
725 Self::default()
726 }
727 pub fn kernel(mut self, k: Kernel) -> Self {
728 self.kernel = k;
729 self
730 }
731
732 #[inline]
733 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
734 self.range.period = (start, end, step);
735 self
736 }
737 #[inline]
738 pub fn period_static(mut self, p: usize) -> Self {
739 self.range.period = (p, p, 0);
740 self
741 }
742 pub fn apply_slice(self, data: &[f64]) -> Result<DemaBatchOutput, DemaError> {
743 dema_batch_with_kernel(data, &self.range, self.kernel)
744 }
745 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<DemaBatchOutput, DemaError> {
746 DemaBatchBuilder::new().kernel(k).apply_slice(data)
747 }
748 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<DemaBatchOutput, DemaError> {
749 let slice = source_type(c, src);
750 self.apply_slice(slice)
751 }
752 pub fn with_default_candles(c: &Candles) -> Result<DemaBatchOutput, DemaError> {
753 DemaBatchBuilder::new()
754 .kernel(Kernel::Auto)
755 .apply_candles(c, "close")
756 }
757}
758
/// Result of a batch DEMA run: a row-major `rows x cols` matrix.
pub struct DemaBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameters for each row, in row order.
    pub combos: Vec<DemaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
765impl DemaBatchOutput {
766 pub fn row_for_params(&self, p: &DemaParams) -> Option<usize> {
767 self.combos
768 .iter()
769 .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
770 }
771 pub fn values_for(&self, p: &DemaParams) -> Option<&[f64]> {
772 self.row_for_params(p).map(|row| {
773 let start = row * self.cols;
774 &self.values[start..start + self.cols]
775 })
776 }
777}
778
779#[inline(always)]
780fn expand_grid(r: &DemaBatchRange) -> Vec<DemaParams> {
781 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
782 if step == 0 || start == end {
783 return vec![start];
784 }
785 let mut vals = Vec::new();
786 if start < end {
787 let mut v = start;
788 while v <= end {
789 vals.push(v);
790 match v.checked_add(step) {
791 Some(n) if n != v => v = n,
792 _ => break,
793 }
794 }
795 } else {
796 let mut v = start;
797 loop {
798 vals.push(v);
799 if v <= end {
800 break;
801 }
802 match v.checked_sub(step) {
803 Some(n) if n != v => v = n,
804 _ => break,
805 }
806 }
807 }
808 vals
809 }
810 let periods = axis_usize(r.period);
811 let mut out = Vec::with_capacity(periods.len());
812 for &p in &periods {
813 out.push(DemaParams { period: Some(p) });
814 }
815 out
816}
817
/// Runs a batch sweep over `data`, processing rows sequentially.
///
/// `kern` is forwarded unchanged to `dema_batch_inner`.
#[inline(always)]
pub fn dema_batch_slice(
    data: &[f64],
    sweep: &DemaBatchRange,
    kern: Kernel,
) -> Result<DemaBatchOutput, DemaError> {
    dema_batch_inner(data, sweep, kern, false)
}
826
/// Runs a batch sweep over `data` with the parallel flag set.
///
/// `kern` is forwarded unchanged to `dema_batch_inner`, which decides how the
/// parallel request is honoured on the current target.
#[inline(always)]
pub fn dema_batch_par_slice(
    data: &[f64],
    sweep: &DemaBatchRange,
    kern: Kernel,
) -> Result<DemaBatchOutput, DemaError> {
    dema_batch_inner(data, sweep, kern, true)
}
835
836#[inline(always)]
837pub(crate) fn dema_batch_with_kernel(
838 data: &[f64],
839 sweep: &DemaBatchRange,
840 k: Kernel,
841) -> Result<DemaBatchOutput, DemaError> {
842 let kernel = match k {
843 Kernel::Auto => Kernel::ScalarBatch,
844 other if other.is_batch() => other,
845 other => return Err(DemaError::InvalidKernelForBatch(other)),
846 };
847
848 let simd = match kernel {
849 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
850 Kernel::Avx512Batch => Kernel::Avx512,
851 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
852 Kernel::Avx2Batch => Kernel::Scalar,
853 Kernel::ScalarBatch => Kernel::Scalar,
854 _ => unreachable!(),
855 };
856 dema_batch_par_slice(data, sweep, simd)
857}
858
859#[inline(always)]
860fn dema_batch_inner(
861 data: &[f64],
862 sweep: &DemaBatchRange,
863 kern: Kernel,
864 parallel: bool,
865) -> Result<DemaBatchOutput, DemaError> {
866 let combos = {
867 let v = expand_grid(sweep);
868 if v.is_empty() {
869 return Err(DemaError::InvalidRange {
870 start: sweep.period.0,
871 end: sweep.period.1,
872 step: sweep.period.2,
873 });
874 }
875 v
876 };
877 let cols = data.len();
878 let rows = combos.len();
879
880 let _total = rows.checked_mul(cols).ok_or(DemaError::SizeOverflow {
881 context: "rows*cols for batch buffer",
882 })?;
883
884 if cols == 0 {
885 return Err(DemaError::EmptyInputData);
886 }
887
888 let mut buf_mu = make_uninit_matrix(rows, cols);
889
890 let warm: Vec<usize> = combos
891 .iter()
892 .map(|c| {
893 data.iter()
894 .position(|x| !x.is_nan())
895 .unwrap_or(0)
896 .saturating_add(c.period.unwrap().saturating_sub(1))
897 })
898 .collect();
899
900 init_matrix_prefixes(&mut buf_mu, cols, &warm);
901
902 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
903 let out: &mut [f64] = unsafe {
904 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
905 };
906
907 dema_batch_inner_into(data, sweep, kern, parallel, out)?;
908
909 let values = unsafe {
910 Vec::from_raw_parts(
911 buf_guard.as_mut_ptr() as *mut f64,
912 buf_guard.len(),
913 buf_guard.capacity(),
914 )
915 };
916
917 Ok(DemaBatchOutput {
918 values,
919 combos,
920 rows,
921 cols,
922 })
923}
924
925#[inline(always)]
926fn dema_batch_inner_into(
927 data: &[f64],
928 sweep: &DemaBatchRange,
929 kern: Kernel,
930 parallel: bool,
931 out: &mut [f64],
932) -> Result<Vec<DemaParams>, DemaError> {
933 let combos = {
934 let v = expand_grid(sweep);
935 if v.is_empty() {
936 return Err(DemaError::InvalidRange {
937 start: sweep.period.0,
938 end: sweep.period.1,
939 step: sweep.period.2,
940 });
941 }
942 v
943 };
944
945 if data.is_empty() {
946 return Err(DemaError::EmptyInputData);
947 }
948
949 let first = data
950 .iter()
951 .position(|x| !x.is_nan())
952 .ok_or(DemaError::AllValuesNaN)?;
953
954 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
955 let needed = 2 * (max_p - 1);
956 if data.len() < needed {
957 return Err(DemaError::NotEnoughData {
958 needed,
959 valid: data.len(),
960 });
961 }
962 let valid = data.len() - first;
963 if valid < needed {
964 return Err(DemaError::NotEnoughValidData { needed, valid });
965 }
966
967 let rows = combos.len();
968 let cols = data.len();
969
970 let expected = rows.checked_mul(cols).ok_or(DemaError::SizeOverflow {
971 context: "rows*cols when validating output buffer",
972 })?;
973 if out.len() != expected {
974 return Err(DemaError::OutputLengthMismatch {
975 expected,
976 got: out.len(),
977 });
978 }
979
980 let do_row = |row: usize, dst: &mut [f64]| unsafe {
981 let p = combos[row].period.unwrap();
982
983 match kern {
984 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
985 Kernel::Avx512 => dema_row_avx512(data, first, p, dst),
986 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
987 Kernel::Avx2 => dema_row_avx2(data, first, p, dst),
988 _ => dema_row_scalar(data, first, p, dst),
989 }
990
991 let warm = first + p - 1;
992 dst[..warm].fill(f64::NAN);
993 };
994
995 if parallel {
996 #[cfg(not(target_arch = "wasm32"))]
997 {
998 out.par_chunks_mut(cols)
999 .enumerate()
1000 .for_each(|(row, slice)| do_row(row, slice));
1001 }
1002 #[cfg(target_arch = "wasm32")]
1003 {
1004 for (row, slice) in out.chunks_mut(cols).enumerate() {
1005 do_row(row, slice);
1006 }
1007 }
1008 } else {
1009 for (row, slice) in out.chunks_mut(cols).enumerate() {
1010 do_row(row, slice);
1011 }
1012 }
1013
1014 Ok(combos)
1015}
/// Batch-row adapter for the scalar kernel (note the swapped `first`/`period` order).
///
/// # Safety
/// Same contract as `dema_scalar`.
#[inline(always)]
unsafe fn dema_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dema_scalar(data, period, first, out)
}
/// Batch-row adapter for the AVX2 kernel (note the swapped `first`/`period` order).
///
/// # Safety
/// Same contract as `dema_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dema_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dema_avx2(data, period, first, out)
}
/// Batch-row adapter for the AVX-512 kernel (note the swapped `first`/`period` order).
///
/// The feature set mirrors `dema_avx512` (including `avx512dq`) so the wrapper
/// declares the same CPU requirements as the kernel it calls.
///
/// # Safety
/// Same contract as `dema_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn dema_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    dema_avx512(data, period, first, out)
}
1030
// Python-visible handle to a 2-D f32 device array produced by the CUDA DEMA path.
// Plain `//` comments are used here so the generated Python docstring is unchanged.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Dema", unsendable)]
pub struct DeviceArrayF32DemaPy {
    // Device buffer plus its `rows` / `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Held only to keep the CUDA context alive alongside the buffer.
    _ctx_guard: std::sync::Arc<cust::context::Context>,
    // Device ordinal reported through `__dlpack_device__`.
    _device_id: u32,
}
1038
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32DemaPy {
    // Direct construction from Python is rejected with a TypeError.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // Builds the `__cuda_array_interface__` dict: shape (rows, cols),
    // little-endian f32, row-major byte strides, and the raw device pointer
    // (0 when the array is empty).
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (self.inner.cols * itemsize, itemsize))?;
        let size = self.inner.rows.saturating_mul(self.inner.cols);
        let ptr_val: usize = if size == 0 {
            0
        } else {
            self.inner.buf.as_device_ptr().as_raw() as usize
        };
        // Second tuple element is the read-only flag.
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns (device_type, device_id); the type code is the constant 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer as a DLPack capsule. The buffer is moved out of
    // `self` (replaced by an empty 0x0 array), so the export is one-shot.
    // `stream` is accepted but ignored. A `dl_device` that differs from the
    // allocation device raises ValueError.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that is not an (i32, i32) tuple is silently ignored.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    // Both branches fail; only the message differs.
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        let _ = stream;

        // Swap in an empty buffer so ownership of the real one can move to the exporter.
        let dummy = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1126
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32DemaPy {
    /// Rust-side constructor wrapping a device array together with the context
    /// handle and the device ordinal it belongs to.
    pub fn new(
        inner: DeviceArrayF32,
        ctx_guard: std::sync::Arc<cust::context::Context>,
        device_id: u32,
    ) -> Self {
        Self {
            inner,
            _ctx_guard: ctx_guard,
            _device_id: device_id,
        }
    }
}
1141
1142#[cfg(test)]
1143mod tests {
1144 use super::*;
1145 use crate::skip_if_unsupported;
1146 use crate::utilities::data_loader::read_candles_from_csv;
1147 use proptest::prelude::*;
1148
1149 fn check_dema_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1150 skip_if_unsupported!(kernel, test_name);
1151 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1152 let candles = read_candles_from_csv(file_path)?;
1153
1154 let default_params = DemaParams { period: None };
1155 let input_default = DemaInput::from_candles(&candles, "close", default_params);
1156 let output_default = dema_with_kernel(&input_default, kernel)?;
1157 assert_eq!(output_default.values.len(), candles.close.len());
1158
1159 let params_period_14 = DemaParams { period: Some(14) };
1160 let input_period_14 = DemaInput::from_candles(&candles, "hl2", params_period_14);
1161 let output_period_14 = dema_with_kernel(&input_period_14, kernel)?;
1162 assert_eq!(output_period_14.values.len(), candles.close.len());
1163
1164 let params_custom = DemaParams { period: Some(20) };
1165 let input_custom = DemaInput::from_candles(&candles, "hlc3", params_custom);
1166 let output_custom = dema_with_kernel(&input_custom, kernel)?;
1167 assert_eq!(output_custom.values.len(), candles.close.len());
1168 Ok(())
1169 }
1170
1171 fn check_dema_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1172 skip_if_unsupported!(kernel, test_name);
1173 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1174 let candles = read_candles_from_csv(file_path)?;
1175
1176 let input = DemaInput::with_default_candles(&candles);
1177 let result = dema_with_kernel(&input, kernel)?;
1178
1179 let expected_last_five = [
1180 59189.73193987478,
1181 59129.24920772847,
1182 59058.80282420511,
1183 59011.5555611042,
1184 58908.370159946775,
1185 ];
1186 let start_index = result.values.len().saturating_sub(5);
1187 let last_five = &result.values[start_index..];
1188 for (i, &val) in last_five.iter().enumerate() {
1189 let exp = expected_last_five[i];
1190 assert!(
1191 (val - exp).abs() < 1e-6,
1192 "DEMA mismatch at index {}: expected {}, got {}",
1193 start_index + i,
1194 exp,
1195 val
1196 );
1197 }
1198 Ok(())
1199 }
1200
1201 fn check_dema_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1202 skip_if_unsupported!(kernel, test_name);
1203 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1204 let candles = read_candles_from_csv(file_path)?;
1205 let input = DemaInput::with_default_candles(&candles);
1206 match input.data {
1207 DemaData::Candles { source, .. } => assert_eq!(source, "close"),
1208 _ => panic!("Expected DemaData::Candles"),
1209 }
1210 assert_eq!(input.params.period, Some(30));
1211 let output = dema_with_kernel(&input, kernel)?;
1212 assert_eq!(output.values.len(), candles.close.len());
1213 Ok(())
1214 }
1215
1216 fn check_dema_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1217 skip_if_unsupported!(kernel, test_name);
1218 let input_data = [10.0, 20.0, 30.0];
1219 let params = DemaParams { period: Some(0) };
1220 let input = DemaInput::from_slice(&input_data, params);
1221 let result = dema_with_kernel(&input, kernel);
1222 assert!(result.is_err());
1223 Ok(())
1224 }
1225
1226 fn check_dema_period_exceeds_length(
1227 test_name: &str,
1228 kernel: Kernel,
1229 ) -> Result<(), Box<dyn Error>> {
1230 skip_if_unsupported!(kernel, test_name);
1231 let data_small = [10.0, 20.0, 30.0];
1232 let params = DemaParams { period: Some(10) };
1233 let input = DemaInput::from_slice(&data_small, params);
1234 let result = dema_with_kernel(&input, kernel);
1235 assert!(result.is_err());
1236 Ok(())
1237 }
1238
1239 fn check_dema_very_small_dataset(
1240 test_name: &str,
1241 kernel: Kernel,
1242 ) -> Result<(), Box<dyn Error>> {
1243 skip_if_unsupported!(kernel, test_name);
1244 let single_point = [42.0];
1245 let params = DemaParams { period: Some(9) };
1246 let input = DemaInput::from_slice(&single_point, params);
1247 let result = dema_with_kernel(&input, kernel);
1248 assert!(result.is_err());
1249 Ok(())
1250 }
1251
1252 fn check_dema_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1253 skip_if_unsupported!(kernel, test_name);
1254 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1255 let candles = read_candles_from_csv(file_path)?;
1256
1257 let first_params = DemaParams { period: Some(80) };
1258 let first_input = DemaInput::from_candles(&candles, "close", first_params);
1259 let first_result = dema_with_kernel(&first_input, kernel)?;
1260
1261 let second_params = DemaParams { period: Some(60) };
1262 let second_input = DemaInput::from_slice(&first_result.values, second_params);
1263 let second_result = dema_with_kernel(&second_input, kernel)?;
1264
1265 assert_eq!(second_result.values.len(), first_result.values.len());
1266 if second_result.values.len() > 240 {
1267 for i in 240..second_result.values.len() {
1268 assert!(!second_result.values[i].is_nan());
1269 }
1270 }
1271 Ok(())
1272 }
1273
1274 fn check_dema_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1275 skip_if_unsupported!(kernel, test_name);
1276 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1277 let candles = read_candles_from_csv(file_path)?;
1278 let params = DemaParams { period: Some(30) };
1279 let input = DemaInput::from_candles(&candles, "close", params);
1280 let result = dema_with_kernel(&input, kernel)?;
1281 if result.values.len() > 240 {
1282 for i in 240..result.values.len() {
1283 assert!(!result.values[i].is_nan());
1284 }
1285 }
1286 Ok(())
1287 }
1288
1289 fn check_dema_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1290 skip_if_unsupported!(kernel, test_name);
1291 let empty: [f64; 0] = [];
1292 let input = DemaInput::from_slice(&empty, DemaParams::default());
1293 let res = dema_with_kernel(&input, kernel);
1294 assert!(matches!(res, Err(DemaError::EmptyInputData)));
1295 Ok(())
1296 }
1297
1298 fn check_dema_not_enough_valid(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1299 skip_if_unsupported!(kernel, test_name);
1300 let data = [f64::NAN, f64::NAN, 1.0, 2.0];
1301 let params = DemaParams { period: Some(3) };
1302 let input = DemaInput::from_slice(&data, params);
1303 let res = dema_with_kernel(&input, kernel);
1304 assert!(matches!(res, Err(DemaError::NotEnoughValidData { .. })));
1305 Ok(())
1306 }
1307
    /// Property-based test of DEMA on random finite series.
    ///
    /// For each generated `(data, period, a, b)` it checks that:
    /// * the requested `kernel` and `Kernel::Scalar` either fail with the same
    ///   error variant or both succeed with near-identical values;
    /// * outputs before index `period - 1` are NaN;
    /// * a window of identical inputs yields (approximately) that same value;
    /// * transforming the input by `a * x + b` transforms the output the same way;
    /// * `DemaStream` fed value-by-value agrees with the batch output.
    ///
    /// A few fixed invalid inputs are checked for errors at the end.
    #[allow(clippy::float_cmp)]
    fn check_dema_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use float_cmp::approx_eq;
        use proptest::prelude::*;

        skip_if_unsupported!(kernel, test_name);

        // Strategy: period in 1..=32, a data vector of at least `2 * max(period, 2)`
        // (and fewer than 400) values in [-1e6, 1e6), a non-zero scale `a` and an
        // offset `b`, both in [-1e3, 1e3).
        let strat = (1usize..=32).prop_flat_map(|period| {
            let min_len = 2 * period.max(2);
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    min_len..400,
                ),
                Just(period),
                (-1e3f64..1e3f64).prop_filter("non-zero scale", |a| a.is_finite() && *a != 0.0),
                -1e3f64..1e3f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, a, b)| {
                let params = DemaParams {
                    period: Some(period),
                };
                let input = DemaInput::from_slice(&data, params.clone());

                // `fast` uses the kernel under test; `slow` is the scalar reference.
                let fast = dema_with_kernel(&input, kernel);
                let slow = dema_with_kernel(&input, Kernel::Scalar);

                match (fast, slow) {
                    // Both failed with the same error variant: nothing more to compare.
                    (Err(e1), Err(e2))
                        if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
                    {
                        return Ok(())
                    }

                    (Err(e1), Err(e2)) => {
                        prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
                    }

                    (Err(e1), Ok(_)) => {
                        prop_assert!(false, "fast errored {e1:?} but scalar succeeded")
                    }
                    (Ok(_), Err(e2)) => {
                        prop_assert!(false, "scalar errored {e2:?} but fast succeeded")
                    }

                    (Ok(fast), Ok(reference)) => {
                        let DemaOutput { values: out } = fast;
                        let DemaOutput { values: rref } = reference;

                        // Streaming output; a `None` from `update` is recorded as NaN.
                        let mut stream = DemaStream::try_new(params.clone()).unwrap();
                        let mut s_out = Vec::with_capacity(data.len());
                        for &v in &data {
                            s_out.push(stream.update(v).unwrap_or(f64::NAN));
                        }

                        // DEMA of the affinely transformed series, computed through `dema`.
                        let transformed: Vec<f64> = data.iter().map(|x| a * *x + b).collect();
                        let t_out =
                            dema(&DemaInput::from_slice(&transformed, params.clone()))?.values;

                        let nan_fill = period - 1;
                        for i in 0..data.len() {
                            let y = out[i];
                            let yr = rref[i];
                            let ys = s_out[i];
                            let yt = t_out[i];

                            // With period 1 the output is expected to equal the input.
                            if period == 1 && y.is_finite() {
                                prop_assert!(approx_eq!(f64, y, data[i], ulps = 2));
                            }

                            // Past the warmup a constant window must map to that constant;
                            // inside the warmup the output must be NaN.
                            if i >= period - 1 {
                                let window = &data[i.saturating_sub(period - 1)..=i];
                                if window.iter().all(|v| *v == window[0]) {
                                    prop_assert!(approx_eq!(f64, y, window[0], epsilon = 1e-9));
                                }
                            } else {
                                prop_assert!(y.is_nan(), "Expected NaN during warmup at index {i}");
                            }

                            // Affine check: output of the transformed input vs. the
                            // transformed output, within a mixed absolute/relative
                            // tolerance or 8 ULPs.
                            if i >= nan_fill {
                                if y.is_finite() {
                                    let expected = a * y + b;
                                    let diff = (yt - expected).abs();

                                    let tol = 1e-7_f64.max(expected.abs() * 1e-9);
                                    let ulp = yt.to_bits().abs_diff(expected.to_bits());
                                    prop_assert!(
                                        diff <= tol || ulp <= 8,
                                        "idx {i}: affine mismatch diff={diff:e} ULP={ulp}"
                                    );
                                } else {
                                    // Non-finite outputs must match bit-for-bit.
                                    prop_assert_eq!(
                                        y.to_bits(),
                                        yt.to_bits(),
                                        "idx {}: special-value mismatch under affine map",
                                        i
                                    );
                                }
                            }

                            // Kernel under test vs. scalar reference.
                            let ulp = y.to_bits().abs_diff(yr.to_bits());
                            prop_assert!(
                                (y - yr).abs() <= 1e-9 || ulp <= 4,
                                "idx {i}: fast={y} ref={yr} ULP={ulp}"
                            );

                            // Batch vs. streaming.
                            if period == 1 {
                                prop_assert!(
                                    (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
                                    "idx {i}: stream mismatch for period=1 - batch={y}, stream={ys}"
                                );
                            } else if i < period - 1 {
                                prop_assert!(
                                    ys.is_nan(),
                                    "idx {i}: stream should return NaN during warmup, got {ys}"
                                );
                            } else {
                                prop_assert!(
                                    (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
                                    "idx {i}: stream mismatch - batch={y}, stream={ys}"
                                );
                            }
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        // Fixed invalid inputs: empty data, all-NaN data, period larger than the
        // data, and a zero period must all be rejected.
        assert!(dema(&DemaInput::from_slice(&[], DemaParams::default())).is_err());
        assert!(dema(&DemaInput::from_slice(
            &[f64::NAN; 12],
            DemaParams::default()
        ))
        .is_err());
        assert!(dema(&DemaInput::from_slice(
            &[1.0; 5],
            DemaParams { period: Some(12) }
        ))
        .is_err());
        assert!(dema(&DemaInput::from_slice(
            &[1.0; 5],
            DemaParams { period: Some(0) }
        ))
        .is_err());

        Ok(())
    }
1463
1464 fn check_dema_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465 skip_if_unsupported!(kernel, test_name);
1466
1467 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1468 let candles = read_candles_from_csv(file_path)?;
1469
1470 let period = 30;
1471 let input = DemaInput::from_candles(
1472 &candles,
1473 "close",
1474 DemaParams {
1475 period: Some(period),
1476 },
1477 );
1478 let batch_output = dema_with_kernel(&input, kernel)?.values;
1479
1480 let mut stream = DemaStream::try_new(DemaParams {
1481 period: Some(period),
1482 })?;
1483 let mut stream_values = Vec::with_capacity(candles.close.len());
1484 for &price in &candles.close {
1485 stream_values.push(stream.update(price).unwrap_or(f64::NAN));
1486 }
1487
1488 assert_eq!(batch_output.len(), stream_values.len());
1489
1490 for (i, (&b, &s)) in batch_output
1491 .iter()
1492 .zip(&stream_values)
1493 .enumerate()
1494 .skip(period)
1495 {
1496 if b.is_nan() && s.is_nan() {
1497 continue;
1498 }
1499
1500 let diff = (b - s).abs();
1501 assert!(
1502 diff < 1e-9,
1503 "[{}] DEMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1504 test_name,
1505 i,
1506 b,
1507 s,
1508 diff
1509 );
1510 }
1511 Ok(())
1512 }
1513
1514 #[cfg(debug_assertions)]
1515 fn check_dema_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1516 skip_if_unsupported!(kernel, test_name);
1517
1518 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1519 let candles = read_candles_from_csv(file_path)?;
1520
1521 let test_params = vec![
1522 DemaParams::default(),
1523 DemaParams { period: Some(2) },
1524 DemaParams { period: Some(3) },
1525 DemaParams { period: Some(5) },
1526 DemaParams { period: Some(7) },
1527 DemaParams { period: Some(10) },
1528 DemaParams { period: Some(12) },
1529 DemaParams { period: Some(20) },
1530 DemaParams { period: Some(30) },
1531 DemaParams { period: Some(50) },
1532 DemaParams { period: Some(100) },
1533 DemaParams { period: Some(200) },
1534 DemaParams { period: Some(1) },
1535 DemaParams { period: Some(250) },
1536 ];
1537
1538 for (param_idx, params) in test_params.iter().enumerate() {
1539 let input = DemaInput::from_candles(&candles, "close", params.clone());
1540 let output = dema_with_kernel(&input, kernel)?;
1541
1542 for (i, &val) in output.values.iter().enumerate() {
1543 if val.is_nan() {
1544 continue;
1545 }
1546
1547 let bits = val.to_bits();
1548
1549 if bits == 0x11111111_11111111 {
1550 panic!(
1551 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1552 with params: period={}",
1553 test_name,
1554 val,
1555 bits,
1556 i,
1557 params.period.unwrap_or(30)
1558 );
1559 }
1560
1561 if bits == 0x22222222_22222222 {
1562 panic!(
1563 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1564 with params: period={}",
1565 test_name,
1566 val,
1567 bits,
1568 i,
1569 params.period.unwrap_or(30)
1570 );
1571 }
1572
1573 if bits == 0x33333333_33333333 {
1574 panic!(
1575 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1576 with params: period={}",
1577 test_name,
1578 val,
1579 bits,
1580 i,
1581 params.period.unwrap_or(30)
1582 );
1583 }
1584 }
1585 }
1586
1587 Ok(())
1588 }
1589
1590 #[cfg(not(debug_assertions))]
1591 fn check_dema_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1592 Ok(())
1593 }
1594
1595 macro_rules! generate_all_dema_tests {
1596 ($($test_fn:ident),*) => {
1597 paste::paste! {
1598 $(
1599 #[test]
1600 fn [<$test_fn _scalar_f64>]() {
1601 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1602 }
1603 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1604 #[test]
1605 fn [<$test_fn _avx2_f64>]() {
1606 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1607 }
1608 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1609 #[test]
1610 fn [<$test_fn _avx512_f64>]() {
1611 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1612 }
1613 )*
1614 }
1615 }
1616 }
1617
1618 fn check_dema_warmup_nan_preservation(
1619 test_name: &str,
1620 kernel: Kernel,
1621 ) -> Result<(), Box<dyn Error>> {
1622 skip_if_unsupported!(kernel, test_name);
1623 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1624 let candles = read_candles_from_csv(file_path)?;
1625
1626 let test_periods = vec![10, 20, 30, 50];
1627
1628 for period in test_periods {
1629 let params = DemaParams {
1630 period: Some(period),
1631 };
1632 let input = DemaInput::from_candles(&candles, "close", params);
1633 let result = dema_with_kernel(&input, kernel)?;
1634
1635 let warmup = period - 1;
1636 for i in 0..warmup {
1637 assert!(
1638 result.values[i].is_nan(),
1639 "[{}] Expected NaN at index {} (warmup={}) for period={}, got {}",
1640 test_name,
1641 i,
1642 warmup,
1643 period,
1644 result.values[i]
1645 );
1646 }
1647
1648 for i in warmup..warmup + 10 {
1649 assert!(
1650 !result.values[i].is_nan(),
1651 "[{}] Expected non-NaN at index {} (warmup={}) for period={}, got NaN",
1652 test_name,
1653 i,
1654 warmup,
1655 period
1656 );
1657 }
1658 }
1659 Ok(())
1660 }
1661
    // Instantiate one test per kernel for every single-series check listed below.
    generate_all_dema_tests!(
        check_dema_partial_params,
        check_dema_accuracy,
        check_dema_default_candles,
        check_dema_zero_period,
        check_dema_period_exceeds_length,
        check_dema_very_small_dataset,
        check_dema_empty_input,
        check_dema_not_enough_valid,
        check_dema_reinput,
        check_dema_nan_handling,
        check_dema_streaming,
        check_dema_property,
        check_dema_no_poison,
        check_dema_warmup_nan_preservation
    );
1678
1679 #[test]
1680 fn test_dema_into_matches_api() -> Result<(), Box<dyn Error>> {
1681 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1682 let candles = read_candles_from_csv(file_path)?;
1683
1684 let input = DemaInput::with_default_candles(&candles);
1685 let baseline = dema(&input)?.values;
1686
1687 let mut out = vec![0.0; candles.close.len()];
1688
1689 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1690 {
1691 dema_into(&input, &mut out)?;
1692 }
1693 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1694 {
1695 dema_into_slice(&mut out, &input, Kernel::Auto)?;
1696 }
1697
1698 assert_eq!(out.len(), baseline.len());
1699 for i in 0..out.len() {
1700 let a = out[i];
1701 let b = baseline[i];
1702 if a.is_nan() || b.is_nan() {
1703 assert!(a.is_nan() && b.is_nan(), "NaN mismatch at index {}", i);
1704 } else {
1705 assert!(a == b, "Value mismatch at index {}: {} != {}", i, a, b);
1706 }
1707 }
1708 Ok(())
1709 }
1710
1711 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1712 skip_if_unsupported!(kernel, test);
1713
1714 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1715 let c = read_candles_from_csv(file)?;
1716
1717 let output = DemaBatchBuilder::new()
1718 .kernel(kernel)
1719 .apply_candles(&c, "close")?;
1720
1721 let def = DemaParams::default();
1722 let row = output.values_for(&def).expect("default row missing");
1723
1724 assert_eq!(row.len(), c.close.len());
1725
1726 let expected = [
1727 59189.73193987478,
1728 59129.24920772847,
1729 59058.80282420511,
1730 59011.5555611042,
1731 58908.370159946775,
1732 ];
1733 let start = row.len() - 5;
1734 for (i, &v) in row[start..].iter().enumerate() {
1735 assert!(
1736 (v - expected[i]).abs() < 1e-6,
1737 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1738 );
1739 }
1740 Ok(())
1741 }
1742
1743 #[cfg(debug_assertions)]
1744 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1745 skip_if_unsupported!(kernel, test);
1746
1747 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1748 let c = read_candles_from_csv(file)?;
1749
1750 let test_configs = vec![
1751 (2, 5, 1),
1752 (5, 25, 5),
1753 (10, 50, 10),
1754 (1, 3, 1),
1755 (50, 150, 25),
1756 (10, 30, 2),
1757 (10, 30, 10),
1758 (100, 300, 50),
1759 ];
1760
1761 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1762 let output = DemaBatchBuilder::new()
1763 .kernel(kernel)
1764 .period_range(p_start, p_end, p_step)
1765 .apply_candles(&c, "close")?;
1766
1767 for (idx, &val) in output.values.iter().enumerate() {
1768 if val.is_nan() {
1769 continue;
1770 }
1771
1772 let bits = val.to_bits();
1773 let row = idx / output.cols;
1774 let col = idx % output.cols;
1775 let combo = &output.combos[row];
1776
1777 if bits == 0x11111111_11111111 {
1778 panic!(
1779 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1780 at row {} col {} (flat index {}) with params: period={}",
1781 test,
1782 cfg_idx,
1783 val,
1784 bits,
1785 row,
1786 col,
1787 idx,
1788 combo.period.unwrap_or(30)
1789 );
1790 }
1791
1792 if bits == 0x22222222_22222222 {
1793 panic!(
1794 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1795 at row {} col {} (flat index {}) with params: period={}",
1796 test,
1797 cfg_idx,
1798 val,
1799 bits,
1800 row,
1801 col,
1802 idx,
1803 combo.period.unwrap_or(30)
1804 );
1805 }
1806
1807 if bits == 0x33333333_33333333 {
1808 panic!(
1809 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1810 at row {} col {} (flat index {}) with params: period={}",
1811 test,
1812 cfg_idx,
1813 val,
1814 bits,
1815 row,
1816 col,
1817 idx,
1818 combo.period.unwrap_or(30)
1819 );
1820 }
1821 }
1822 }
1823
1824 Ok(())
1825 }
1826
1827 #[cfg(not(debug_assertions))]
1828 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1829 Ok(())
1830 }
1831
1832 macro_rules! gen_batch_tests {
1833 ($fn_name:ident) => {
1834 paste::paste! {
1835 #[test]
1836 fn [<$fn_name _scalar>]() {
1837 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1838 }
1839 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1840 #[test]
1841 fn [<$fn_name _avx2>]() {
1842 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1843 }
1844 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1845 #[test]
1846 fn [<$fn_name _avx512>]() {
1847 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1848 }
1849 #[test]
1850 fn [<$fn_name _auto_detect>]() {
1851 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1852 }
1853 }
1854 };
1855 }
1856 fn check_batch_warmup_nan_preservation(
1857 test: &str,
1858 kernel: Kernel,
1859 ) -> Result<(), Box<dyn Error>> {
1860 skip_if_unsupported!(kernel, test);
1861
1862 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1863 let c = read_candles_from_csv(file)?;
1864
1865 let output = DemaBatchBuilder::new()
1866 .kernel(kernel)
1867 .period_range(10, 30, 10)
1868 .apply_candles(&c, "close")?;
1869
1870 for (row_idx, combo) in output.combos.iter().enumerate() {
1871 let period = combo.period.unwrap_or(30);
1872 let warmup = period - 1;
1873 let row_start = row_idx * output.cols;
1874
1875 for i in 0..warmup {
1876 let val = output.values[row_start + i];
1877 assert!(
1878 val.is_nan(),
1879 "[{}] Batch row {} (period={}): Expected NaN at index {}, got {}",
1880 test,
1881 row_idx,
1882 period,
1883 i,
1884 val
1885 );
1886 }
1887
1888 for i in warmup..warmup.min(output.cols).min(warmup + 10) {
1889 let val = output.values[row_start + i];
1890 assert!(
1891 !val.is_nan(),
1892 "[{}] Batch row {} (period={}): Expected non-NaN at index {}, got NaN",
1893 test,
1894 row_idx,
1895 period,
1896 i
1897 );
1898 }
1899 }
1900 Ok(())
1901 }
1902
    // Instantiate the per-kernel tests for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
    gen_batch_tests!(check_batch_warmup_nan_preservation);
1906}
1907
// Python binding `dema(data, period, kernel=None)`.
//
// Computes DEMA over a 1-D float64 NumPy array and returns a new 1-D array.
// The computation runs with the GIL released; an invalid kernel name or an
// indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "dema")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn dema_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = DemaInput::from_slice(
        prices,
        DemaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| dema_with_kernel(&input, chosen_kernel).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1933
// Python class `DemaStream`: a thin wrapper over the Rust `DemaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "DemaStream")]
pub struct DemaStreamPy {
    stream: DemaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DemaStreamPy {
    // Creates a stream for `period`; a construction error is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let built = DemaStream::try_new(DemaParams {
            period: Some(period),
        });
        match built {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1957
// Python binding `dema_batch(data, period_range, kernel=None)`.
//
// Expands `period_range = (start, end, step)` into parameter combinations,
// computes one DEMA row per combination with the GIL released, and returns a
// dict with:
//   "values"  - 2-D float64 array of shape (rows, cols), one row per period
//   "periods" - 1-D array with the period of each row
//
// Raises `ValueError` for an invalid kernel name, an empty expansion, or an
// indicator error.
#[cfg(feature = "python")]
#[pyfunction(name = "dema_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn dema_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;
    use std::mem::ManuallyDrop;

    let slice_in = data.as_slice()?;
    let sweep = DemaBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err(
            "invalid period range: empty expansion",
        ));
    }
    let rows = combos.len();
    let cols = slice_in.len();

    // Allocate the output matrix and initialise each row's warmup prefix, whose
    // length is the index of the first non-NaN input plus `period - 1`.
    let mut buf_mu = make_uninit_matrix(rows, cols);
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Map the batch kernel onto the matching per-row kernel.
    let simd = match match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    } {
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512Batch => Kernel::Avx512,
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => unreachable!(),
    };

    // Run the computation while `buf_mu` still owns the allocation, so an early
    // return through `?` frees the buffer instead of leaking it.
    let combos = {
        // SAFETY: the pointer and length come from the live `buf_mu` allocation,
        // which outlives `out`; `MaybeUninit<f64>` has the same layout as `f64`.
        let out: &mut [f64] = unsafe {
            core::slice::from_raw_parts_mut(buf_mu.as_mut_ptr() as *mut f64, buf_mu.len())
        };
        py.allow_threads(|| dema_batch_inner_into(slice_in, &sweep, simd, true, out))
            .map_err(|e| PyValueError::new_err(e.to_string()))?
    };

    // Hand the allocation over to a `Vec<f64>` without copying.
    let mut guard = ManuallyDrop::new(buf_mu);
    // SAFETY: pointer, length and capacity describe the allocation that `guard`
    // no longer drops. This relies on the prefix initialisation plus the batch
    // kernel having written every cell.
    let values: Vec<f64> = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    let arr = values.into_pyarray(py).reshape((rows, cols))?;

    let dict = PyDict::new(py);
    dict.set_item("values", arr)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2035
// Python binding `dema_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the CUDA batch DEMA for a float32 series over a period sweep and
// returns the device array wrapper. Raises `ValueError` when CUDA is not
// available or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dema_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn dema_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32DemaPy> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = DemaBatchRange {
        period: period_range,
    };

    // All CUDA work happens with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaDema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let arr = cuda
            .dema_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32DemaPy::new(inner, ctx, dev_id))
}
2068
// Python binding
// `dema_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D float32 array whose first axis is used as the series length and
// whose second axis is used as the number of series, runs the CUDA DEMA with a
// single period, and returns the device array wrapper. Raises `ValueError` when
// CUDA is not available, when `period` is zero, or when the CUDA wrapper
// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "dema_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn dema_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32DemaPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }
    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let series_len = shape[0];
    let num_series = shape[1];
    let params = DemaParams {
        period: Some(period),
    };
    // All CUDA work happens with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaDema::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let arr = cuda
            .dema_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;
    Ok(DeviceArrayF32DemaPy::new(inner, ctx, dev_id))
}
2104
// WASM binding: computes DEMA for `data` with the given `period` and returns a
// new vector of the same length. Errors are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = DemaInput::from_slice(
        data,
        DemaParams {
            period: Some(period),
        },
    );

    let mut result = vec![0.0; data.len()];
    match dema_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2120
/// Configuration object deserialized from JavaScript for the `dema_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemaBatchConfig {
    /// Period sweep; passed unchanged as the `period` field of `DemaBatchRange`.
    pub period_range: (usize, usize, usize),
}
2126
/// Result object serialized back to JavaScript by the `dema_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemaBatchJsOutput {
    /// Output values copied from the batch result.
    pub values: Vec<f64>,
    /// Parameter combinations copied from the batch result.
    pub combos: Vec<DemaParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2135
// WASM binding exported as `dema_batch`: deserializes a `DemaBatchConfig` from
// `config`, runs the batch DEMA over `data`, and returns the result serialized
// as a `DemaBatchJsOutput`. Every failure is reported as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = dema_batch)]
pub fn dema_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: DemaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = DemaBatchRange {
        period: parsed.period_range,
    };
    let batch = match dema_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = DemaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Failed to serialize output: {}", e)))
}
2159
// Deprecated WASM binding: expands the period sweep and returns the period of
// each combination as an `f64`, in expansion order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use dema_batch instead")]
pub fn dema_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = DemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let mut periods = Vec::new();
    for combo in expand_grid(&sweep).iter() {
        periods.push(combo.period.unwrap() as f64);
    }
    Ok(periods)
}
2180
// WASM binding: allocates room for `len` f64 values and returns the raw pointer.
// The allocation is deliberately not freed here; it is released with `dema_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the buffer alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2189
// WASM binding: releases a buffer previously returned by `dema_alloc`.
//
// Caller contract: `ptr` must come from `dema_alloc(len)` with the same `len`
// and must not be used or freed again afterwards. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_free(ptr: *mut f64, len: usize) {
    // Rebuilding a `Vec` from a null pointer is undefined behaviour, so bail out.
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the caller contract, `ptr` owns an allocation with capacity
    // `len` created by `dema_alloc`; rebuilding the Vec and dropping it frees it.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2197
// WASM binding: computes DEMA for the `len` values at `in_ptr` and writes `len`
// results to `out_ptr`.
//
// Returns an error for null pointers, for `period == 0`, for `period > len`,
// and for any indicator error. When both pointers are equal the result is
// computed into a temporary buffer first and then copied over the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dema_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: `in_ptr` is non-null; the caller is trusted to provide `len`
    // readable f64 values there.
    let source = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = DemaInput::from_slice(
        source,
        DemaParams {
            period: Some(period),
        },
    );

    if in_ptr == out_ptr {
        // In-place call: compute into scratch space, then overwrite the input.
        let mut scratch = vec![0.0; len];
        dema_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `out_ptr` is non-null; the caller is trusted to provide `len`
        // writable f64 slots there.
        let destination = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        destination.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above.
        let destination = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        dema_into_slice(destination, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2238
// WASM binding: runs the batch DEMA for the `len` values at `in_ptr` over the
// period sweep `(period_start, period_end, period_step)` and writes
// `rows * len` results to `out_ptr`, where `rows` is the number of expanded
// parameter combinations. Returns `rows`.
//
// Returns an error for null pointers, when `rows * len` does not fit in
// `usize`, and for any indicator error. The caller must have reserved
// `rows * len` f64 slots at `out_ptr`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn dema_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to dema_batch_into"));
    }

    let sweep = DemaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    // A wrapped product would describe a slice shorter than the real output,
    // so reject overflow before building the raw slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("dema_batch_into: rows * cols overflows usize"))?;

    unsafe {
        // SAFETY: both pointers are non-null; the caller is trusted to provide
        // `len` readable values at `in_ptr` and `total` writable slots at `out_ptr`.
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        dema_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}