1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaEpma;
7use crate::utilities::data_loader::{source_type, Candles};
8use crate::utilities::enums::Kernel;
9use crate::utilities::helpers::{
10 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
11 make_uninit_matrix,
12};
13#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14use core::arch::x86_64::*;
15#[cfg(all(feature = "python", feature = "cuda"))]
16use cust::context::Context;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19use std::convert::AsRef;
20use std::mem::{ManuallyDrop, MaybeUninit};
21#[cfg(all(feature = "python", feature = "cuda"))]
22use std::sync::Arc;
23use thiserror::Error;
24
25impl<'a> AsRef<[f64]> for EpmaInput<'a> {
26 #[inline(always)]
27 fn as_ref(&self) -> &[f64] {
28 match &self.data {
29 EpmaData::Slice(slice) => slice,
30 EpmaData::Candles { candles, source } => source_type(candles, source),
31 }
32 }
33}
34
/// Where an EPMA calculation reads its input series from.
#[derive(Debug, Clone)]
pub enum EpmaData<'a> {
    /// A candle set together with the name of the field to read; the field is
    /// resolved through `source_type` when the input is borrowed as a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-supplied series that is used as-is.
    Slice(&'a [f64]),
}
43
/// Result of a single EPMA run.
#[derive(Debug, Clone)]
pub struct EpmaOutput {
    /// Output series; `epma_with_kernel` allocates it with the same length as
    /// the input series.
    pub values: Vec<f64>,
}
48
/// Tunable EPMA parameters.
///
/// Both fields are optional; the accessors on `EpmaInput` and the stream
/// constructor substitute 11 (period) and 4 (offset) for `None`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct EpmaParams {
    /// Window length; validated elsewhere to be at least 2.
    pub period: Option<usize>,
    /// Weight offset; validated elsewhere to be strictly below `period`.
    pub offset: Option<usize>,
}

impl Default for EpmaParams {
    /// Returns the explicit defaults: period 11, offset 4.
    fn default() -> Self {
        let (period, offset) = (Some(11), Some(4));
        Self { period, offset }
    }
}
66
67#[derive(Debug, Clone)]
68pub struct EpmaInput<'a> {
69 pub data: EpmaData<'a>,
70 pub params: EpmaParams,
71}
72
73impl<'a> EpmaInput<'a> {
74 #[inline]
75 pub fn from_candles(c: &'a Candles, s: &'a str, p: EpmaParams) -> Self {
76 Self {
77 data: EpmaData::Candles {
78 candles: c,
79 source: s,
80 },
81 params: p,
82 }
83 }
84 #[inline]
85 pub fn from_slice(sl: &'a [f64], p: EpmaParams) -> Self {
86 Self {
87 data: EpmaData::Slice(sl),
88 params: p,
89 }
90 }
91 #[inline]
92 pub fn with_default_candles(c: &'a Candles) -> Self {
93 Self::from_candles(c, "close", EpmaParams::default())
94 }
95 #[inline]
96 pub fn get_period(&self) -> usize {
97 self.params.period.unwrap_or(11)
98 }
99 #[inline]
100 pub fn get_offset(&self) -> usize {
101 self.params.offset.unwrap_or(4)
102 }
103}
104
/// Errors reported by the EPMA entry points in this file.
#[derive(Debug, Error)]
pub enum EpmaError {
    /// The input series has length zero.
    #[error("epma: Input data slice is empty.")]
    EmptyInputData,

    /// No non-NaN element was found in the input series.
    #[error("epma: All values are NaN.")]
    AllValuesNaN,

    /// The period is below 2 or larger than the data length (`data_len` is
    /// reported as 0 by the stream constructor, which has no data yet).
    #[error("epma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// The offset is not strictly smaller than the period.
    #[error("epma: Invalid offset: {offset}")]
    InvalidOffset { offset: usize },

    /// Fewer than `period + offset + 1` elements remain from the first
    /// non-NaN element onwards.
    #[error("epma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-provided output buffer does not match the input length.
    #[error("epma: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A non-batch kernel was passed to the batch entry point.
    #[error("epma: Invalid kernel for batch operation: got {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// The parameter sweep expanded to no combinations.
    #[error("epma: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// `rows * cols` does not fit in `usize`.
    #[error("epma: size overflow computing rows*cols: rows={rows}, cols={cols}")]
    SizeOverflow { rows: usize, cols: usize },
}
138
139#[derive(Copy, Clone, Debug)]
140pub struct EpmaBuilder {
141 period: Option<usize>,
142 offset: Option<usize>,
143 kernel: Kernel,
144}
145impl Default for EpmaBuilder {
146 fn default() -> Self {
147 Self {
148 period: None,
149 offset: None,
150 kernel: Kernel::Auto,
151 }
152 }
153}
154impl EpmaBuilder {
155 #[inline(always)]
156 pub fn new() -> Self {
157 Self::default()
158 }
159 #[inline(always)]
160 pub fn period(mut self, n: usize) -> Self {
161 self.period = Some(n);
162 self
163 }
164 #[inline(always)]
165 pub fn offset(mut self, o: usize) -> Self {
166 self.offset = Some(o);
167 self
168 }
169 #[inline(always)]
170 pub fn kernel(mut self, k: Kernel) -> Self {
171 self.kernel = k;
172 self
173 }
174 #[inline(always)]
175 pub fn apply(self, c: &Candles) -> Result<EpmaOutput, EpmaError> {
176 let p = EpmaParams {
177 period: self.period,
178 offset: self.offset,
179 };
180 let i = EpmaInput::from_candles(c, "close", p);
181 epma_with_kernel(&i, self.kernel)
182 }
183 #[inline(always)]
184 pub fn apply_slice(self, d: &[f64]) -> Result<EpmaOutput, EpmaError> {
185 let p = EpmaParams {
186 period: self.period,
187 offset: self.offset,
188 };
189 let i = EpmaInput::from_slice(d, p);
190 epma_with_kernel(&i, self.kernel)
191 }
192 #[inline(always)]
193 pub fn into_stream(self) -> Result<EpmaStream, EpmaError> {
194 let p = EpmaParams {
195 period: self.period,
196 offset: self.offset,
197 };
198 EpmaStream::try_new(p)
199 }
200}
201
/// Computes EPMA for `input`, letting `epma_with_kernel` pick the kernel
/// (`Kernel::Auto`).
///
/// # Errors
/// Propagates every error returned by `epma_with_kernel`.
#[inline]
pub fn epma(input: &EpmaInput) -> Result<EpmaOutput, EpmaError> {
    epma_with_kernel(input, Kernel::Auto)
}
206
207#[inline(always)]
208fn epma_prepare<'a>(
209 input: &'a EpmaInput,
210 kernel: Kernel,
211) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), EpmaError> {
212 let data: &[f64] = input.as_ref();
213 let len = data.len();
214 if len == 0 {
215 return Err(EpmaError::EmptyInputData);
216 }
217
218 let first = data
219 .iter()
220 .position(|x| !x.is_nan())
221 .ok_or(EpmaError::AllValuesNaN)?;
222 let period = input.get_period();
223 let offset = input.get_offset();
224
225 if offset >= period {
226 return Err(EpmaError::InvalidOffset { offset });
227 }
228
229 if period < 2 || period > len {
230 return Err(EpmaError::InvalidPeriod {
231 period,
232 data_len: len,
233 });
234 }
235 let needed = period + offset + 1;
236 if (len - first) < needed {
237 return Err(EpmaError::NotEnoughValidData {
238 needed,
239 valid: len - first,
240 });
241 }
242
243 let chosen = match kernel {
244 Kernel::Auto => detect_best_kernel(),
245 other => other,
246 };
247 let warmup = first + period + offset + 1;
248
249 Ok((data, period, offset, first, warmup, chosen))
250}
251
252#[inline(always)]
253fn epma_compute_into(
254 data: &[f64],
255 period: usize,
256 offset: usize,
257 first: usize,
258 kernel: Kernel,
259 out: &mut [f64],
260) {
261 unsafe {
262 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
263 {
264 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
265 epma_simd128(data, period, offset, first, out);
266 return;
267 }
268 }
269
270 match kernel {
271 Kernel::Scalar | Kernel::ScalarBatch => epma_scalar(data, period, offset, first, out),
272 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
273 Kernel::Avx2 | Kernel::Avx2Batch => epma_avx2(data, period, offset, first, out),
274 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
275 Kernel::Avx512 | Kernel::Avx512Batch => epma_avx512(data, period, offset, first, out),
276 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
277 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
278 epma_scalar(data, period, offset, first, out)
279 }
280 _ => unreachable!(),
281 }
282 }
283}
284
285pub fn epma_with_kernel(input: &EpmaInput, kernel: Kernel) -> Result<EpmaOutput, EpmaError> {
286 let (data, period, offset, first, warmup, chosen) = epma_prepare(input, kernel)?;
287
288 let mut out = alloc_with_nan_prefix(data.len(), warmup);
289 epma_compute_into(data, period, offset, first, chosen, &mut out);
290
291 Ok(EpmaOutput { values: out })
292}
293
294#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
295#[inline]
296pub fn epma_into(input: &EpmaInput, out: &mut [f64]) -> Result<(), EpmaError> {
297 let (data, period, offset, first, warmup, chosen) = epma_prepare(input, Kernel::Auto)?;
298
299 if out.len() != data.len() {
300 return Err(EpmaError::OutputLengthMismatch {
301 expected: data.len(),
302 got: out.len(),
303 });
304 }
305
306 let w = warmup.min(out.len());
307 for v in &mut out[..w] {
308 *v = f64::from_bits(0x7ff8_0000_0000_0000);
309 }
310
311 epma_compute_into(data, period, offset, first, chosen, out);
312 Ok(())
313}
314
315#[inline]
316pub fn epma_into_slice(dst: &mut [f64], input: &EpmaInput, kern: Kernel) -> Result<(), EpmaError> {
317 let (data, period, offset, first, warmup, chosen) = epma_prepare(input, kern)?;
318
319 if dst.len() != data.len() {
320 return Err(EpmaError::OutputLengthMismatch {
321 expected: data.len(),
322 got: dst.len(),
323 });
324 }
325
326 epma_compute_into(data, period, offset, first, chosen, dst);
327
328 for v in &mut dst[..warmup] {
329 *v = f64::from_bits(0x7ff8_0000_0000_0000);
330 }
331
332 Ok(())
333}
334
/// Scalar EPMA kernel.
///
/// For every index `j >= first_valid + period + offset + 1` it writes
/// `out[j] = Σ data[j + 1 - p1 + i] * (2 - offset + i) / W` over
/// `i in 0..p1`, with `p1 = period - 1` and `W` the sum of those weights.
/// Earlier positions of `out` are left untouched.
///
/// # Panics
/// Panics if `out` is shorter than `data` (index out of range), and on
/// `period == 0` in builds with overflow checks.
#[inline(always)]
pub fn epma_scalar(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let taps = period - 1;

    // Weight of the oldest sample in the window; each newer sample gets +1.
    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    // Closed form of Σ (base_weight + i) for i in 0..taps.
    let norm = 1.0 / taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);

    let first_out = first_valid + period + offset + 1;
    for end in first_out..data.len() {
        let start = end + 1 - taps;
        let window = &data[start..start + taps];

        // Accumulate oldest-to-newest with fused multiply-adds; the weights are
        // small integers, so stepping by 1.0 is exact.
        let (acc, _) = window
            .iter()
            .fold((0.0_f64, base_weight), |(acc, weight), &x| {
                (x.mul_add(weight, acc), weight + 1.0)
            });

        out[end] = acc * norm;
    }
}
376
/// wasm32 `simd128` EPMA kernel; produces the same weighted window sum as
/// `epma_scalar`, two lanes at a time.
///
/// # Safety
/// Performs raw 128-bit loads from `data`; the window `[j + 1 - p1, j]` must
/// lie inside `data` for every written `j`, which holds when
/// `first_valid + period + offset + 1 >= period - 1` (always true for
/// non-negative inputs). `out` must be at least as long as `data`, otherwise
/// the indexed store panics.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn epma_simd128(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    const LANES: usize = 2;
    let taps = period - 1;

    // weights[i] belongs to the sample `i` steps back from the newest one in
    // the window, so the table is read back-to-front below.
    let weights: Vec<f64> = (0..taps)
        .map(|i| (period as i32 - i as i32 - offset as i32) as f64)
        .collect();
    let norm = weights.iter().fold(0.0, |acc, &w| acc + w);

    let full_pairs = taps / LANES;
    let has_tail = taps % LANES != 0;

    for end in (first_valid + period + offset + 1)..data.len() {
        let start = end + 1 - taps;
        let mut acc = f64x2_splat(0.0);

        for pair in 0..full_pairs {
            let at = pair * LANES;
            let w = f64x2(weights[taps - 1 - at], weights[taps - 2 - at]);
            let d = v128_load(data.as_ptr().add(start + at) as *const v128);
            acc = f64x2_add(acc, f64x2_mul(d, w));
        }

        let mut total = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);

        // An odd tap count leaves the newest sample for scalar handling.
        if has_tail {
            total += data[start + taps - 1] * weights[0];
        }

        out[end] = total / norm;
    }
}
427
/// Builds the `period - 1` EPMA weights in oldest-to-newest order together
/// with their sum.
///
/// Entry `k` equals `period - (period - 2 - k) - offset`, i.e. `2 - offset + k`.
///
/// # Panics
/// `period == 0` underflows in builds with overflow checks.
#[inline(always)]
fn build_weights_rev(period: usize, offset: usize) -> (Vec<f64>, f64) {
    let taps = period - 1;

    let weights: Vec<f64> = (0..taps)
        .map(|k| (period as isize - (taps - 1 - k) as isize - offset as isize) as f64)
        .collect();
    // Explicit fold from +0.0 keeps the exact accumulation order of a manual loop.
    let total = weights.iter().fold(0.0, |acc, &w| acc + w);

    (weights, total)
}
441
/// AVX2 + FMA EPMA kernel: same weighted window sum as `epma_scalar`,
/// processed four samples per step with a masked load for the remainder.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `out` must be at least as long as `data`; stores use
///   `get_unchecked_mut`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn epma_avx2(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 4;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;

    // Lane-enable mask for the partial block: the first `rem` lanes are loaded,
    // the others read as zero and therefore contribute nothing.
    let tail_mask = match rem {
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        0 => _mm256_setzero_si256(),
        _ => unreachable!(),
    };
    let tail_at = full_blocks * LANES;

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let norm = 1.0 / taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);

    let lane_ramp = _mm256_setr_pd(0.0, 1.0, 2.0, 3.0);

    for end in (first_valid + period + offset + 1)..data.len() {
        let window = data.as_ptr().add(end + 1 - taps);
        let mut acc = _mm256_setzero_pd();

        for blk in 0..full_blocks {
            let at = blk * LANES;
            let w = _mm256_add_pd(_mm256_set1_pd(base_weight + at as f64), lane_ramp);
            let d = _mm256_loadu_pd(window.add(at));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        if rem != 0 {
            let w = _mm256_add_pd(_mm256_set1_pd(base_weight + tail_at as f64), lane_ramp);
            let d = _mm256_maskload_pd(window.add(tail_at), tail_mask);
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Horizontal add: upper half + lower half, then the two remaining lanes.
        let upper = _mm256_extractf128_pd(acc, 1);
        let lower = _mm256_castpd256_pd128(acc);
        let halves = _mm_add_pd(upper, lower);
        let folded = _mm_add_pd(halves, _mm_unpackhi_pd(halves, halves));
        *out.get_unchecked_mut(end) = _mm_cvtsd_f64(folded) * norm;
    }
}
497
/// AVX-512 EPMA entry point: periods up to 32 use the single-accumulator
/// kernel, longer periods the four-accumulator one.
///
/// # Safety
/// Same contract as the kernels it forwards to: the CPU must support the
/// AVX-512 features they enable and `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn epma_avx512(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    if period > 32 {
        return epma_avx512_long(data, period, offset, first_valid, out);
    }
    epma_avx512_short(data, period, offset, first_valid, out)
}
/// AVX-512 EPMA kernel with one accumulator, eight samples per step and a
/// zero-masked load for the remainder.
///
/// # Safety
/// * The CPU must support AVX-512F, AVX-512DQ and FMA.
/// * `out` must be at least as long as `data`; stores use
///   `get_unchecked_mut`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn epma_avx512_short(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    // Low `rem` bits set: only those lanes are loaded in the partial block.
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);
    let tail_at = full_blocks * LANES;

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let norm = 1.0 / taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);

    let lane_ramp = _mm512_setr_pd(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

    for end in (first_valid + period + offset + 1)..data.len() {
        let window = data.as_ptr().add(end + 1 - taps);
        let mut acc = _mm512_setzero_pd();

        for blk in 0..full_blocks {
            let at = blk * LANES;
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + at as f64), lane_ramp);
            let d = _mm512_loadu_pd(window.add(at));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        if rem != 0 {
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + tail_at as f64), lane_ramp);
            let d = _mm512_maskz_loadu_pd(tail_mask, window.add(tail_at));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        *out.get_unchecked_mut(end) = _mm512_reduce_add_pd(acc) * norm;
    }
}
/// AVX-512 EPMA kernel for long windows: full 8-lane blocks are consumed four
/// at a time into four independent accumulators, leftover full blocks and the
/// masked remainder go into the first accumulator.
///
/// # Safety
/// * The CPU must support AVX-512F, AVX-512DQ and FMA.
/// * `out` must be at least as long as `data`; stores use
///   `get_unchecked_mut`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn epma_avx512_long(
    data: &[f64],
    period: usize,
    offset: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    // Largest multiple of four not exceeding `full_blocks`.
    let quad_end = full_blocks & !3;
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);
    let tail_at = full_blocks * LANES;

    let base_weight = 2.0 - (offset as f64);
    let taps_f = taps as f64;
    let norm = 1.0 / taps_f.mul_add(base_weight, 0.5 * (taps_f - 1.0) * taps_f);

    let lane_ramp = _mm512_setr_pd(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

    for end in (first_valid + period + offset + 1)..data.len() {
        let window = data.as_ptr().add(end + 1 - taps);

        let mut acc0 = _mm512_setzero_pd();
        let mut acc1 = _mm512_setzero_pd();
        let mut acc2 = _mm512_setzero_pd();
        let mut acc3 = _mm512_setzero_pd();

        for quad in (0..quad_end).step_by(4) {
            let at0 = quad * LANES;
            let at1 = (quad + 1) * LANES;
            let at2 = (quad + 2) * LANES;
            let at3 = (quad + 3) * LANES;

            let w0 = _mm512_add_pd(_mm512_set1_pd(base_weight + at0 as f64), lane_ramp);
            let w1 = _mm512_add_pd(_mm512_set1_pd(base_weight + at1 as f64), lane_ramp);
            let w2 = _mm512_add_pd(_mm512_set1_pd(base_weight + at2 as f64), lane_ramp);
            let w3 = _mm512_add_pd(_mm512_set1_pd(base_weight + at3 as f64), lane_ramp);

            let d0 = _mm512_loadu_pd(window.add(at0));
            let d1 = _mm512_loadu_pd(window.add(at1));
            let d2 = _mm512_loadu_pd(window.add(at2));
            let d3 = _mm512_loadu_pd(window.add(at3));

            acc0 = _mm512_fmadd_pd(d0, w0, acc0);
            acc1 = _mm512_fmadd_pd(d1, w1, acc1);
            acc2 = _mm512_fmadd_pd(d2, w2, acc2);
            acc3 = _mm512_fmadd_pd(d3, w3, acc3);
        }

        for blk in quad_end..full_blocks {
            let at = blk * LANES;
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + at as f64), lane_ramp);
            let d = _mm512_loadu_pd(window.add(at));
            acc0 = _mm512_fmadd_pd(d, w, acc0);
        }

        if rem != 0 {
            let w = _mm512_add_pd(_mm512_set1_pd(base_weight + tail_at as f64), lane_ramp);
            let d = _mm512_maskz_loadu_pd(tail_mask, window.add(tail_at));
            acc0 = _mm512_fmadd_pd(d, w, acc0);
        }

        let combined = _mm512_add_pd(_mm512_add_pd(acc0, acc1), _mm512_add_pd(acc2, acc3));
        *out.get_unchecked_mut(end) = _mm512_reduce_add_pd(combined) * norm;
    }
}
633
/// Incremental EPMA calculator that is fed one value at a time.
#[derive(Debug, Clone)]
pub struct EpmaStream {
    /// Window length (at least 2).
    period: usize,
    /// Weight offset (strictly below `period`).
    offset: usize,
    /// Number of weighted samples, `period - 1`.
    p1: usize,

    /// Ring buffer with `period` slots holding the most recent inputs.
    buffer: Vec<f64>,
    /// Slot that the next input is written to.
    head: usize,

    /// Total number of inputs received so far.
    seen: usize,
    /// Number of samples currently folded into the running sums (at most `p1`).
    included: usize,

    /// Running sum of the window samples and its Kahan compensation term.
    sum: f64,
    sum_c: f64,
    /// Running sum of `i * x_i` over the window (oldest sample has `i = 0`)
    /// and its Kahan compensation term.
    ramp: f64,
    ramp_c: f64,

    /// Weight of the oldest sample, `2 - offset`.
    c0: f64,
    /// Reciprocal of the sum of all window weights.
    inv_wsum: f64,
}
654
/// Adds `x` to `*sum` with Kahan compensation.
///
/// `*c` carries the rounding error lost by previous additions and is updated
/// with the error of this one.
#[inline(always)]
fn kahan_add(sum: &mut f64, c: &mut f64, x: f64) {
    let adjusted = x - *c;
    let previous = *sum;
    *sum = previous + adjusted;
    // Whatever part of `adjusted` did not make it into the new sum.
    *c = (*sum - previous) - adjusted;
}
662
663impl EpmaStream {
664 pub fn try_new(params: EpmaParams) -> Result<Self, EpmaError> {
665 let period = params.period.unwrap_or(11);
666 let offset = params.offset.unwrap_or(4);
667
668 if period < 2 {
669 return Err(EpmaError::InvalidPeriod {
670 period,
671 data_len: 0,
672 });
673 }
674 if offset >= period {
675 return Err(EpmaError::InvalidOffset { offset });
676 }
677
678 let p1 = period - 1;
679 let c0 = 2.0 - (offset as f64);
680
681 let p1f = p1 as f64;
682 let wsum = p1f.mul_add(c0, 0.5 * (p1f - 1.0) * p1f);
683 let inv_wsum = 1.0 / wsum;
684
685 Ok(Self {
686 period,
687 offset,
688 p1,
689
690 buffer: vec![0.0; period],
691 head: 0,
692
693 seen: 0,
694 included: 0,
695
696 sum: 0.0,
697 sum_c: 0.0,
698 ramp: 0.0,
699 ramp_c: 0.0,
700
701 c0,
702 inv_wsum,
703 })
704 }
705
706 #[inline(always)]
707 pub fn update(&mut self, value: f64) -> Option<f64> {
708 let p = self.period;
709 let p1m1 = (self.p1 - 1) as f64;
710
711 let idx_out = (self.head + 1) % p;
712 let x_out = if self.included == self.p1 {
713 self.buffer[idx_out]
714 } else {
715 0.0
716 };
717
718 self.buffer[self.head] = value;
719 self.head = (self.head + 1) % p;
720 self.seen += 1;
721
722 if self.included < self.p1 {
723 let m = self.included as f64;
724
725 kahan_add(&mut self.sum, &mut self.sum_c, value);
726
727 kahan_add(&mut self.ramp, &mut self.ramp_c, m * value);
728
729 self.included += 1;
730 } else {
731 let s_old = self.sum;
732
733 let delta_s = value - x_out;
734 kahan_add(&mut self.sum, &mut self.sum_c, delta_s);
735
736 let delta_r = p1m1.mul_add(value, x_out - s_old);
737 kahan_add(&mut self.ramp, &mut self.ramp_c, delta_r);
738 }
739
740 if self.seen <= (self.period + self.offset + 1) {
741 return Some(value);
742 }
743
744 let num = self.c0.mul_add(self.sum, self.ramp);
745 Some(num * self.inv_wsum)
746 }
747}
748
/// Parameter sweep for a batch run; each axis is `(start, end, step)`.
///
/// A `step` of 0 pins the axis to `start`.
#[derive(Clone, Debug)]
pub struct EpmaBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// Offset axis.
    pub offset: (usize, usize, usize),
}

impl Default for EpmaBatchRange {
    /// Periods 11 through 260 in steps of 1, offset fixed at 4.
    fn default() -> Self {
        let period = (11, 260, 1);
        let offset = (4, 4, 0);
        Self { period, offset }
    }
}
762#[derive(Clone, Debug, Default)]
763pub struct EpmaBatchBuilder {
764 range: EpmaBatchRange,
765 kernel: Kernel,
766}
767impl EpmaBatchBuilder {
768 pub fn new() -> Self {
769 Self::default()
770 }
771 pub fn kernel(mut self, k: Kernel) -> Self {
772 self.kernel = k;
773 self
774 }
775 #[inline]
776 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
777 self.range.period = (start, end, step);
778 self
779 }
780 #[inline]
781 pub fn period_static(mut self, p: usize) -> Self {
782 self.range.period = (p, p, 0);
783 self
784 }
785 #[inline]
786 pub fn offset_range(mut self, start: usize, end: usize, step: usize) -> Self {
787 self.range.offset = (start, end, step);
788 self
789 }
790 #[inline]
791 pub fn offset_static(mut self, o: usize) -> Self {
792 self.range.offset = (o, o, 0);
793 self
794 }
795 pub fn apply_slice(self, data: &[f64]) -> Result<EpmaBatchOutput, EpmaError> {
796 epma_batch_with_kernel(data, &self.range, self.kernel)
797 }
798 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<EpmaBatchOutput, EpmaError> {
799 EpmaBatchBuilder::new().kernel(k).apply_slice(data)
800 }
801 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<EpmaBatchOutput, EpmaError> {
802 let slice = source_type(c, src);
803 self.apply_slice(slice)
804 }
805 pub fn with_default_candles(c: &Candles) -> Result<EpmaBatchOutput, EpmaError> {
806 EpmaBatchBuilder::new()
807 .kernel(Kernel::Auto)
808 .apply_candles(c, "close")
809 }
810}
811
812#[derive(Clone, Debug)]
813pub struct EpmaBatchOutput {
814 pub values: Vec<f64>,
815 pub combos: Vec<EpmaParams>,
816 pub rows: usize,
817 pub cols: usize,
818}
819impl EpmaBatchOutput {
820 pub fn row_for_params(&self, p: &EpmaParams) -> Option<usize> {
821 self.combos.iter().position(|c| {
822 c.period.unwrap_or(11) == p.period.unwrap_or(11)
823 && c.offset.unwrap_or(4) == p.offset.unwrap_or(4)
824 })
825 }
826 pub fn values_for(&self, p: &EpmaParams) -> Option<&[f64]> {
827 self.row_for_params(p).map(|row| {
828 let start = row * self.cols;
829 &self.values[start..start + self.cols]
830 })
831 }
832}
833
834#[inline(always)]
835fn expand_grid(r: &EpmaBatchRange) -> Vec<EpmaParams> {
836 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
837 if step == 0 {
838 return vec![start];
839 }
840 let (lo, hi) = if start <= end {
841 (start, end)
842 } else {
843 (end, start)
844 };
845 (lo..=hi).step_by(step).collect()
846 }
847 let periods = axis_usize(r.period);
848 let offsets = axis_usize(r.offset);
849 let mut out = Vec::with_capacity(periods.len() * offsets.len());
850 for &p in &periods {
851 for &o in &offsets {
852 out.push(EpmaParams {
853 period: Some(p),
854 offset: Some(o),
855 });
856 }
857 }
858 out
859}
860
861#[inline(always)]
862pub fn epma_batch_with_kernel(
863 data: &[f64],
864 sweep: &EpmaBatchRange,
865 k: Kernel,
866) -> Result<EpmaBatchOutput, EpmaError> {
867 let kernel = match k {
868 Kernel::Auto => detect_best_batch_kernel(),
869 other if other.is_batch() => other,
870 other => return Err(EpmaError::InvalidKernelForBatch(other)),
871 };
872 let simd = match kernel {
873 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
874 Kernel::Avx512Batch => Kernel::Avx512,
875 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
876 Kernel::Avx2Batch => Kernel::Avx2,
877 Kernel::ScalarBatch => Kernel::Scalar,
878 _ => unreachable!(),
879 };
880
881 epma_batch_par_slice(data, sweep, simd)
882}
883
/// Runs the sweep over `data` row by row on the calling thread
/// (`parallel = false`). `kern` is the single-series kernel used per row.
///
/// # Errors
/// Propagates every error from `epma_batch_inner`.
#[inline(always)]
pub fn epma_batch_slice(
    data: &[f64],
    sweep: &EpmaBatchRange,
    kern: Kernel,
) -> Result<EpmaBatchOutput, EpmaError> {
    epma_batch_inner(data, sweep, kern, false)
}
/// Runs the sweep over `data` with `parallel = true`, which distributes rows
/// through rayon on non-wasm targets. `kern` is the single-series kernel used
/// per row.
///
/// # Errors
/// Propagates every error from `epma_batch_inner`.
#[inline(always)]
pub fn epma_batch_par_slice(
    data: &[f64],
    sweep: &EpmaBatchRange,
    kern: Kernel,
) -> Result<EpmaBatchOutput, EpmaError> {
    epma_batch_inner(data, sweep, kern, true)
}
900#[inline(always)]
901fn epma_batch_inner(
902 data: &[f64],
903 sweep: &EpmaBatchRange,
904 kern: Kernel,
905 parallel: bool,
906) -> Result<EpmaBatchOutput, EpmaError> {
907 let combos = expand_grid(sweep);
908 let rows = combos.len();
909 let cols = data.len();
910
911 let _total_cells = rows
912 .checked_mul(cols)
913 .ok_or(EpmaError::SizeOverflow { rows, cols })?;
914
915 let mut buf_mu = make_uninit_matrix(rows, cols);
916
917 let combos = epma_batch_inner_into_uninit(data, sweep, kern, parallel, &mut buf_mu)?;
918
919 let values = unsafe {
920 let mut buf_guard = ManuallyDrop::new(buf_mu);
921 Vec::from_raw_parts(
922 buf_guard.as_mut_ptr() as *mut f64,
923 buf_guard.len(),
924 buf_guard.capacity(),
925 )
926 };
927
928 Ok(EpmaBatchOutput {
929 values,
930 combos,
931 rows,
932 cols,
933 })
934}
935
/// Runs a batch sweep into a caller-provided, already initialised `f64`
/// buffer (used by the Python/wasm bindings) and returns the parameter
/// combination of each row.
///
/// # Errors
/// Propagates every error from `epma_batch_inner_into_uninit`.
#[cfg(any(feature = "python", feature = "wasm"))]
#[inline(always)]
pub fn epma_batch_inner_into(
    data: &[f64],
    sweep: &EpmaBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<EpmaParams>, EpmaError> {
    let len = out.len();
    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`; the view covers
    // exactly the elements of `out` and is the only access path while it
    // lives. The callee only stores `f64` values through it.
    let cells =
        unsafe { std::slice::from_raw_parts_mut(out.as_mut_ptr().cast::<MaybeUninit<f64>>(), len) };
    epma_batch_inner_into_uninit(data, sweep, kern, parallel, cells)
}
950
951#[inline(always)]
952fn epma_batch_inner_into_uninit(
953 data: &[f64],
954 sweep: &EpmaBatchRange,
955 kern: Kernel,
956 parallel: bool,
957 buf_mu: &mut [MaybeUninit<f64>],
958) -> Result<Vec<EpmaParams>, EpmaError> {
959 if data.is_empty() {
960 return Err(EpmaError::EmptyInputData);
961 }
962 let combos = expand_grid(sweep);
963 if combos.is_empty() {
964 return Err(EpmaError::InvalidRange {
965 start: 0,
966 end: 0,
967 step: 0,
968 });
969 }
970 for c in &combos {
971 let p = c.period.unwrap();
972 let o = c.offset.unwrap();
973 if p < 2 {
974 return Err(EpmaError::InvalidPeriod {
975 period: p,
976 data_len: data.len(),
977 });
978 }
979 if o >= p {
980 return Err(EpmaError::InvalidOffset { offset: o });
981 }
982 }
983 let first = data
984 .iter()
985 .position(|x| !x.is_nan())
986 .ok_or(EpmaError::AllValuesNaN)?;
987 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
988 let max_off = combos.iter().map(|c| c.offset.unwrap()).max().unwrap();
989 let needed = max_p + max_off + 1;
990 if data.len() - first < needed {
991 return Err(EpmaError::NotEnoughValidData {
992 needed,
993 valid: data.len() - first,
994 });
995 }
996 let cols = data.len();
997
998 let warm: Vec<usize> = combos
999 .iter()
1000 .map(|c| first + c.period.unwrap() + c.offset.unwrap() + 1)
1001 .collect();
1002 init_matrix_prefixes(buf_mu, cols, &warm);
1003
1004 let use_prefix = {
1005 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1006 {
1007 matches!(kern, Kernel::Avx2 | Kernel::Avx512)
1008 }
1009 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1010 {
1011 false
1012 }
1013 };
1014
1015 if use_prefix {
1016 let mut p: Vec<f64> = Vec::with_capacity(cols);
1017 let mut q: Vec<f64> = Vec::with_capacity(cols);
1018 let mut ps = 0.0f64;
1019 let mut qs = 0.0f64;
1020 for (idx, &x) in data.iter().enumerate() {
1021 ps += x;
1022 qs = (idx as f64).mul_add(x, qs);
1023 p.push(ps);
1024 q.push(qs);
1025 }
1026
1027 let do_row_prefix = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1028 let period = combos[row].period.unwrap();
1029 let offset = combos[row].offset.unwrap();
1030 let p1 = period - 1;
1031 let c0 = 2.0 - (offset as f64);
1032 let p1f = p1 as f64;
1033 let wsum = p1f.mul_add(c0, 0.5 * (p1f - 1.0) * p1f);
1034 let inv = 1.0 / wsum;
1035 let warmup = warm[row];
1036
1037 let dst = unsafe {
1038 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1039 };
1040
1041 for j in warmup..cols {
1042 let a = j + 1 - p1;
1043 let b = j;
1044
1045 let sumx = if a > 0 { p[b] - p[a - 1] } else { p[b] };
1046
1047 let sum_abs = if a > 0 { q[b] - q[a - 1] } else { q[b] };
1048
1049 let sum_ix = sum_abs - (a as f64) * sumx;
1050 let y = (c0 * sumx + sum_ix) * inv;
1051
1052 unsafe {
1053 *dst.get_unchecked_mut(j) = y;
1054 }
1055 }
1056 };
1057
1058 if parallel {
1059 #[cfg(not(target_arch = "wasm32"))]
1060 {
1061 buf_mu
1062 .par_chunks_mut(cols)
1063 .enumerate()
1064 .for_each(|(row, slice)| do_row_prefix(row, slice));
1065 }
1066 #[cfg(target_arch = "wasm32")]
1067 {
1068 for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1069 do_row_prefix(row, slice);
1070 }
1071 }
1072 } else {
1073 for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1074 do_row_prefix(row, slice);
1075 }
1076 }
1077 } else {
1078 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1079 let period = combos[row].period.unwrap();
1080 let offset = combos[row].offset.unwrap();
1081
1082 let dst =
1083 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1084
1085 match kern {
1086 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1087 Kernel::Avx512 => epma_row_avx512(data, first, period, offset, dst),
1088 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1089 Kernel::Avx2 => epma_row_avx2(data, first, period, offset, dst),
1090 _ => epma_row_scalar(data, first, period, offset, dst),
1091 }
1092 };
1093
1094 if parallel {
1095 #[cfg(not(target_arch = "wasm32"))]
1096 {
1097 buf_mu
1098 .par_chunks_mut(cols)
1099 .enumerate()
1100 .for_each(|(row, slice)| do_row(row, slice));
1101 }
1102 #[cfg(target_arch = "wasm32")]
1103 {
1104 for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1105 do_row(row, slice);
1106 }
1107 }
1108 } else {
1109 for (row, slice) in buf_mu.chunks_mut(cols).enumerate() {
1110 do_row(row, slice);
1111 }
1112 }
1113 }
1114
1115 Ok(combos)
1116}
1117
/// Batch row adapter for the scalar kernel: forwards to `epma_scalar`,
/// translating the batch argument order (`first` before `period`/`offset`).
///
/// # Safety
/// Declared `unsafe` only for signature parity with the SIMD row adapters;
/// the body calls a safe function.
#[inline(always)]
unsafe fn epma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    offset: usize,
    out: &mut [f64],
) {
    epma_scalar(data, period, offset, first, out);
}
1128
/// Batch row adapter for the AVX2 kernel: forwards to `epma_avx2`,
/// translating the batch argument order (`first` before `period`/`offset`).
///
/// # Safety
/// Same requirements as `epma_avx2`: the CPU must support AVX2 and FMA, and
/// `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn epma_row_avx2(data: &[f64], first: usize, period: usize, offset: usize, out: &mut [f64]) {
    epma_avx2(data, period, offset, first, out);
}
1134
/// Batch row adapter for the AVX-512 kernel: forwards to `epma_avx512`,
/// translating the batch argument order (`first` before `period`/`offset`).
///
/// # Safety
/// Same requirements as `epma_avx512`: the CPU must support AVX-512F,
/// AVX-512DQ and FMA, and `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn epma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    offset: usize,
    out: &mut [f64],
) {
    epma_avx512(data, period, offset, first, out);
}
1146
1147#[cfg(test)]
1148mod tests {
1149 use super::*;
1150 use crate::skip_if_unsupported;
1151 use crate::utilities::data_loader::read_candles_from_csv;
1152 use std::error::Error;
1153
1154 #[test]
1155 fn test_epma_into_matches_api() -> Result<(), Box<dyn Error>> {
1156 let mut data = Vec::with_capacity(256);
1157 data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
1158 for i in 0..253u32 {
1159 let v = (i as f64).sin() * 10.0 + (i as f64) * 0.01;
1160 data.push(v);
1161 }
1162
1163 let input = EpmaInput::from_slice(&data, EpmaParams::default());
1164
1165 let baseline = epma(&input)?.values;
1166
1167 let mut out = vec![0.0; data.len()];
1168 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1169 {
1170 epma_into(&input, &mut out)?;
1171 }
1172
1173 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1174 (a.is_nan() && b.is_nan()) || (a == b)
1175 }
1176
1177 assert_eq!(baseline.len(), out.len());
1178 for i in 0..out.len() {
1179 assert!(
1180 eq_or_both_nan(baseline[i], out[i]),
1181 "Mismatch at {}: baseline={} out={}",
1182 i,
1183 baseline[i],
1184 out[i]
1185 );
1186 }
1187 Ok(())
1188 }
1189
1190 fn check_epma_partial_params(
1191 test_name: &str,
1192 kernel: Kernel,
1193 ) -> Result<(), Box<dyn std::error::Error>> {
1194 skip_if_unsupported!(kernel, test_name);
1195 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1196 let candles = read_candles_from_csv(file_path)?;
1197
1198 let default_params = EpmaParams {
1199 period: None,
1200 offset: None,
1201 };
1202 let input = EpmaInput::from_candles(&candles, "close", default_params);
1203 let output = epma_with_kernel(&input, kernel)?;
1204 assert_eq!(output.values.len(), candles.close.len());
1205 Ok(())
1206 }
1207 fn check_epma_accuracy(
1208 test_name: &str,
1209 kernel: Kernel,
1210 ) -> Result<(), Box<dyn std::error::Error>> {
1211 skip_if_unsupported!(kernel, test_name);
1212 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1213 let candles = read_candles_from_csv(file_path)?;
1214 let default_params = EpmaParams::default();
1215 let input = EpmaInput::from_candles(&candles, "close", default_params);
1216 let result = epma_with_kernel(&input, kernel)?;
1217 let expected_last_five = [59174.48, 59201.04, 59167.60, 59200.32, 59117.04];
1218 let start_index = result.values.len().saturating_sub(5);
1219 let result_last_five = &result.values[start_index..];
1220 for (i, &value) in result_last_five.iter().enumerate() {
1221 assert!(
1222 (value - expected_last_five[i]).abs() < 1e-1,
1223 "[{}] EPMA {:?} mismatch at idx {}: got {}, expected {}",
1224 test_name,
1225 kernel,
1226 i,
1227 value,
1228 expected_last_five[i]
1229 );
1230 }
1231 Ok(())
1232 }
1233 fn check_epma_default_candles(
1234 test_name: &str,
1235 kernel: Kernel,
1236 ) -> Result<(), Box<dyn std::error::Error>> {
1237 skip_if_unsupported!(kernel, test_name);
1238 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1239 let candles = read_candles_from_csv(file_path)?;
1240 let input = EpmaInput::with_default_candles(&candles);
1241 match input.data {
1242 EpmaData::Candles { source, .. } => assert_eq!(source, "close"),
1243 _ => panic!("Expected EpmaData::Candles"),
1244 }
1245 let output = epma_with_kernel(&input, kernel)?;
1246 assert_eq!(output.values.len(), candles.close.len());
1247 Ok(())
1248 }
1249 fn check_epma_zero_period(
1250 test_name: &str,
1251 kernel: Kernel,
1252 ) -> Result<(), Box<dyn std::error::Error>> {
1253 skip_if_unsupported!(kernel, test_name);
1254 let input_data = [10.0, 20.0, 30.0];
1255 let params = EpmaParams {
1256 period: Some(0),
1257 offset: None,
1258 };
1259 let input = EpmaInput::from_slice(&input_data, params);
1260 let res = epma_with_kernel(&input, kernel);
1261 assert!(
1262 res.is_err(),
1263 "[{}] EPMA should fail with zero period",
1264 test_name
1265 );
1266 Ok(())
1267 }
1268 fn check_epma_period_exceeds_length(
1269 test_name: &str,
1270 kernel: Kernel,
1271 ) -> Result<(), Box<dyn std::error::Error>> {
1272 skip_if_unsupported!(kernel, test_name);
1273 let data_small = [10.0, 20.0, 30.0];
1274 let params = EpmaParams {
1275 period: Some(10),
1276 offset: None,
1277 };
1278 let input = EpmaInput::from_slice(&data_small, params);
1279 let res = epma_with_kernel(&input, kernel);
1280 assert!(
1281 res.is_err(),
1282 "[{}] EPMA should fail with period exceeding length",
1283 test_name
1284 );
1285 Ok(())
1286 }
1287 fn check_epma_very_small_dataset(
1288 test_name: &str,
1289 kernel: Kernel,
1290 ) -> Result<(), Box<dyn std::error::Error>> {
1291 skip_if_unsupported!(kernel, test_name);
1292 let single_point = [42.0];
1293 let params = EpmaParams {
1294 period: Some(9),
1295 offset: None,
1296 };
1297 let input = EpmaInput::from_slice(&single_point, params);
1298 let res = epma_with_kernel(&input, kernel);
1299 assert!(
1300 res.is_err(),
1301 "[{}] EPMA should fail with insufficient data",
1302 test_name
1303 );
1304 Ok(())
1305 }
1306 fn check_epma_empty_input(
1307 test_name: &str,
1308 kernel: Kernel,
1309 ) -> Result<(), Box<dyn std::error::Error>> {
1310 skip_if_unsupported!(kernel, test_name);
1311 let empty: [f64; 0] = [];
1312 let input = EpmaInput::from_slice(&empty, EpmaParams::default());
1313 let res = epma_with_kernel(&input, kernel);
1314 assert!(
1315 matches!(res, Err(EpmaError::EmptyInputData)),
1316 "[{}] EPMA should fail with empty input",
1317 test_name
1318 );
1319 Ok(())
1320 }
1321 fn check_epma_invalid_offset(
1322 test_name: &str,
1323 kernel: Kernel,
1324 ) -> Result<(), Box<dyn std::error::Error>> {
1325 skip_if_unsupported!(kernel, test_name);
1326 let data = [1.0, 2.0, 3.0, 4.0];
1327 let params = EpmaParams {
1328 period: Some(3),
1329 offset: Some(3),
1330 };
1331 let input = EpmaInput::from_slice(&data, params);
1332 let res = epma_with_kernel(&input, kernel);
1333 assert!(
1334 matches!(res, Err(EpmaError::InvalidOffset { .. })),
1335 "[{}] EPMA should fail with invalid offset",
1336 test_name
1337 );
1338 Ok(())
1339 }
1340 fn check_epma_property(
1341 test_name: &str,
1342 kernel: Kernel,
1343 ) -> Result<(), Box<dyn std::error::Error>> {
1344 use proptest::prelude::*;
1345 skip_if_unsupported!(kernel, test_name);
1346
1347 let strat = (2usize..=50).prop_flat_map(|period| {
1348 (
1349 prop::collection::vec(
1350 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1351 (period * 2 + 10)..500,
1352 ),
1353 Just(period),
1354 0usize..period,
1355 )
1356 });
1357
1358 proptest::test_runner::TestRunner::default()
1359 .run(&strat, |(data, period, offset)| {
1360 let params = EpmaParams {
1361 period: Some(period),
1362 offset: Some(offset),
1363 };
1364 let input = EpmaInput::from_slice(&data, params);
1365
1366
1367 let EpmaOutput { values: out } = epma_with_kernel(&input, kernel).unwrap();
1368
1369
1370 let EpmaOutput { values: ref_out } = epma_with_kernel(&input, Kernel::Scalar).unwrap();
1371
1372
1373 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1374
1375
1376 let warmup = first_valid + period + offset + 1;
1377
1378
1379 for i in 0..warmup.min(out.len()) {
1380 prop_assert!(
1381 out[i].is_nan(),
1382 "[{}] Expected NaN during warmup at index {}, got {}",
1383 test_name,
1384 i,
1385 out[i]
1386 );
1387 }
1388
1389
1390
1391 if warmup < out.len() && data[warmup].is_finite() {
1392
1393 let p1 = period - 1;
1394 let mut weight_sum = 0.0;
1395 for i in 0..p1 {
1396 let w = (period as i32 - i as i32 - offset as i32) as f64;
1397 weight_sum += w;
1398 }
1399
1400
1401 if weight_sum.abs() > 1e-10 {
1402 prop_assert!(
1403 !out[warmup].is_nan(),
1404 "[{}] Expected valid value at warmup index {}, got NaN",
1405 test_name,
1406 warmup
1407 );
1408 }
1409 }
1410
1411
1412
1413 let p1 = period - 1;
1414 let mut weight_sum = 0.0;
1415 for i in 0..p1 {
1416 let w = (period as i32 - i as i32 - offset as i32) as f64;
1417 weight_sum += w;
1418 }
1419
1420 if weight_sum.abs() > 1e-10 {
1421
1422 for i in warmup..data.len() {
1423 let y = out[i];
1424 prop_assert!(
1425 y.is_finite(),
1426 "[{}] EPMA output at index {} is not finite: {} (period={}, offset={}, weight_sum={})",
1427 test_name,
1428 i,
1429 y,
1430 period,
1431 offset,
1432 weight_sum
1433 );
1434 }
1435 } else {
1436
1437
1438 for i in warmup..data.len() {
1439 let both_nan = out[i].is_nan() && ref_out[i].is_nan();
1440 let both_inf = out[i].is_infinite() && ref_out[i].is_infinite();
1441 prop_assert!(
1442 both_nan || both_inf,
1443 "[{}] With weight_sum=0, expected consistent NaN or Inf at index {} but got: kernel={}, scalar={} (period={}, offset={})",
1444 test_name,
1445 i,
1446 out[i],
1447 ref_out[i],
1448 period,
1449 offset
1450 );
1451 }
1452 }
1453
1454
1455
1456
1457 if period == 2 && offset == 0 && warmup < data.len() {
1458
1459 for i in warmup..data.len() {
1460 if data[i].is_finite() {
1461
1462 prop_assert!(
1463 (out[i] - data[i]).abs() < 1e-9,
1464 "[{}] Period=2,offset=0 mismatch at {}: got {}, expected {}",
1465 test_name,
1466 i,
1467 out[i],
1468 data[i]
1469 );
1470 }
1471 }
1472 }
1473
1474
1475
1476 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && data.iter().any(|x| x.is_finite() && x.abs() > 1e-10) {
1477 let constant = *data.iter().find(|x| x.is_finite()).unwrap();
1478
1479 let p1 = period - 1;
1480 let mut weight_sum = 0.0;
1481 for i in 0..p1 {
1482 let w = (period as i32 - i as i32 - offset as i32) as f64;
1483 weight_sum += w;
1484 }
1485
1486 if weight_sum.abs() > 1e-10 {
1487 for i in warmup..data.len() {
1488 prop_assert!(
1489 (out[i] - constant).abs() < 1e-9,
1490 "[{}] Constant data mismatch at {}: got {}, expected {}",
1491 test_name,
1492 i,
1493 out[i],
1494 constant
1495 );
1496 }
1497 }
1498 }
1499
1500
1501
1502 for i in (warmup.saturating_add(1))..data.len() {
1503 let y = out[i];
1504 let r = ref_out[i];
1505
1506 if !y.is_finite() || !r.is_finite() {
1507 prop_assert!(
1508 y.to_bits() == r.to_bits(),
1509 "[{}] finite/NaN mismatch at idx {}: {} vs {}",
1510 test_name,
1511 i,
1512 y,
1513 r
1514 );
1515 continue;
1516 }
1517
1518 let ulp_diff: u64 = y.to_bits().abs_diff(r.to_bits());
1519
1520 let rel_error = if r.abs() > 1e-10 {
1521 (y - r).abs() / r.abs()
1522 } else {
1523 (y - r).abs()
1524 };
1525
1526
1527 prop_assert!(
1528 rel_error <= 1e-4 || (y - r).abs() <= 1e-9 || ulp_diff <= 100,
1529 "[{}] Kernel mismatch at idx {}: {} vs {} (ULP={}, rel_err={})",
1530 test_name,
1531 i,
1532 y,
1533 r,
1534 ulp_diff,
1535 rel_error
1536 );
1537 }
1538
1539
1540
1541 if warmup + 5 < data.len() {
1542
1543 let p1 = period - 1;
1544 let mut weights = Vec::with_capacity(p1);
1545 let mut weight_sum = 0.0;
1546 for i in 0..p1 {
1547 let w = (period as i32 - i as i32 - offset as i32) as f64;
1548 weights.push(w);
1549 weight_sum += w;
1550 }
1551
1552
1553 if weight_sum.abs() > 1e-10 {
1554
1555 for idx in [warmup + 1, warmup + 2, data.len() - 1].iter().copied() {
1556 if idx > warmup && idx < data.len() {
1557 let start = idx + 1 - p1;
1558 let mut expected_sum = 0.0;
1559 for i in 0..p1 {
1560 expected_sum += data[start + i] * weights[p1 - 1 - i];
1561 }
1562 let expected = expected_sum / weight_sum;
1563
1564
1565 if out[idx].is_finite() && expected.is_finite() {
1566
1567 let tolerance = if expected.abs() > 1000.0 {
1568 expected.abs() * 1e-12
1569 } else {
1570 1e-9
1571 };
1572 prop_assert!(
1573 (out[idx] - expected).abs() < tolerance,
1574 "[{}] EPMA formula mismatch at {}: got {}, expected {} (diff: {})",
1575 test_name,
1576 idx,
1577 out[idx],
1578 expected,
1579 (out[idx] - expected).abs()
1580 );
1581 } else {
1582
1583 prop_assert!(
1584 out[idx].is_nan() == expected.is_nan() &&
1585 out[idx].is_infinite() == expected.is_infinite(),
1586 "[{}] EPMA formula NaN/Inf mismatch at {}: got {}, expected {}",
1587 test_name,
1588 idx,
1589 out[idx],
1590 expected
1591 );
1592 }
1593 }
1594 }
1595 }
1596 }
1597
1598
1599
1600
1601 if offset == period - 1 && warmup < data.len() && weight_sum.abs() > 1e-10 {
1602
1603 for i in warmup..data.len() {
1604 prop_assert!(
1605 out[i].is_finite(),
1606 "[{}] Edge case offset={} produced non-finite at {}",
1607 test_name,
1608 offset,
1609 i
1610 );
1611 }
1612 }
1613
1614 Ok(())
1615 })
1616 .unwrap();
1617
1618 Ok(())
1619 }
1620 fn check_epma_invalid_params(
1621 test_name: &str,
1622 kernel: Kernel,
1623 ) -> Result<(), Box<dyn std::error::Error>> {
1624 skip_if_unsupported!(kernel, test_name);
1625
1626 let zero_weight_cases = vec![(4, 3), (5, 3), (6, 4), (8, 6)];
1627
1628 for (period, offset) in zero_weight_cases {
1629 let p1 = period - 1;
1630 let mut weight_sum = 0.0;
1631 for i in 0..p1 {
1632 let w = (period as i32 - i as i32 - offset as i32) as f64;
1633 weight_sum += w;
1634 }
1635
1636 let data = vec![1.0; period * 2];
1637 let params = EpmaParams {
1638 period: Some(period),
1639 offset: Some(offset),
1640 };
1641 let input = EpmaInput::from_slice(&data, params);
1642
1643 if weight_sum.abs() < 1e-10 {
1644 let out = epma_with_kernel(&input, kernel)?;
1645 let scalar_out = epma_with_kernel(&input, Kernel::Scalar)?;
1646
1647 let warmup = period + offset + 1;
1648 for i in warmup..data.len() {
1649 let both_nan = out.values[i].is_nan() && scalar_out.values[i].is_nan();
1650 let both_inf =
1651 out.values[i].is_infinite() && scalar_out.values[i].is_infinite();
1652 assert!(
1653 both_nan || both_inf,
1654 "[{}] Period={}, Offset={} (weight_sum=0) should produce consistent NaN or Inf, got kernel={}, scalar={}",
1655 test_name,
1656 period,
1657 offset,
1658 out.values[i],
1659 scalar_out.values[i]
1660 );
1661 }
1662 }
1663 }
1664
1665 Ok(())
1666 }
1667
1668 fn check_epma_reinput(
1669 test_name: &str,
1670 kernel: Kernel,
1671 ) -> Result<(), Box<dyn std::error::Error>> {
1672 skip_if_unsupported!(kernel, test_name);
1673 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1674 let candles = read_candles_from_csv(file_path)?;
1675 let first_params = EpmaParams {
1676 period: Some(9),
1677 offset: None,
1678 };
1679 let first_input = EpmaInput::from_candles(&candles, "close", first_params);
1680 let first_result = epma_with_kernel(&first_input, kernel)?;
1681 let second_params = EpmaParams {
1682 period: Some(3),
1683 offset: None,
1684 };
1685 let second_input = EpmaInput::from_slice(&first_result.values, second_params);
1686 let second_result = epma_with_kernel(&second_input, kernel)?;
1687 assert_eq!(second_result.values.len(), first_result.values.len());
1688 Ok(())
1689 }
1690 fn check_epma_nan_handling(
1691 test_name: &str,
1692 kernel: Kernel,
1693 ) -> Result<(), Box<dyn std::error::Error>> {
1694 skip_if_unsupported!(kernel, test_name);
1695 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1696 let candles = read_candles_from_csv(file_path)?;
1697 let params = EpmaParams {
1698 period: Some(11),
1699 offset: Some(4),
1700 };
1701 let input = EpmaInput::from_candles(&candles, "close", params.clone());
1702 let res = epma_with_kernel(&input, kernel)?;
1703 assert_eq!(res.values.len(), candles.close.len());
1704 if res.values.len() > 240 {
1705 for (i, &val) in res.values[240..].iter().enumerate() {
1706 assert!(
1707 !val.is_nan(),
1708 "[{}] Found unexpected NaN at out-index {}",
1709 test_name,
1710 240 + i
1711 );
1712 }
1713 }
1714 Ok(())
1715 }
1716 fn check_epma_streaming(
1717 test_name: &str,
1718 kernel: Kernel,
1719 ) -> Result<(), Box<dyn std::error::Error>> {
1720 skip_if_unsupported!(kernel, test_name);
1721 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1722 let candles = read_candles_from_csv(file_path)?;
1723 let period = 11;
1724 let offset = 4;
1725 let input = EpmaInput::from_candles(
1726 &candles,
1727 "close",
1728 EpmaParams {
1729 period: Some(period),
1730 offset: Some(offset),
1731 },
1732 );
1733 let batch_output = epma_with_kernel(&input, kernel)?.values;
1734 let mut stream = EpmaStream::try_new(EpmaParams {
1735 period: Some(period),
1736 offset: Some(offset),
1737 })?;
1738 let mut stream_values = Vec::with_capacity(candles.close.len());
1739 for &price in &candles.close {
1740 match stream.update(price) {
1741 Some(val) => stream_values.push(val),
1742 None => stream_values.push(f64::NAN),
1743 }
1744 }
1745 assert_eq!(batch_output.len(), stream_values.len());
1746 for (i, (&b, &s)) in batch_output
1747 .iter()
1748 .zip(stream_values.iter())
1749 .enumerate()
1750 .skip(period + offset + 1)
1751 {
1752 if b.is_nan() && s.is_nan() {
1753 continue;
1754 }
1755 let diff = (b - s).abs();
1756 assert!(
1757 diff < 1e-9,
1758 "[{}] EPMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1759 test_name,
1760 i,
1761 b,
1762 s,
1763 diff
1764 );
1765 }
1766 Ok(())
1767 }
1768
1769 #[cfg(debug_assertions)]
1770 fn check_epma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1771 skip_if_unsupported!(kernel, test_name);
1772
1773 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1774 let candles = read_candles_from_csv(file_path)?;
1775
1776 let test_cases = vec![
1777 EpmaParams::default(),
1778 EpmaParams {
1779 period: Some(2),
1780 offset: Some(0),
1781 },
1782 EpmaParams {
1783 period: Some(5),
1784 offset: Some(1),
1785 },
1786 EpmaParams {
1787 period: Some(10),
1788 offset: Some(3),
1789 },
1790 EpmaParams {
1791 period: Some(10),
1792 offset: Some(9),
1793 },
1794 EpmaParams {
1795 period: Some(20),
1796 offset: Some(5),
1797 },
1798 EpmaParams {
1799 period: Some(30),
1800 offset: Some(10),
1801 },
1802 EpmaParams {
1803 period: Some(15),
1804 offset: Some(14),
1805 },
1806 ];
1807
1808 for params in test_cases {
1809 let input = EpmaInput::from_candles(&candles, "close", params.clone());
1810 let output = epma_with_kernel(&input, kernel)?;
1811
1812 for (i, &val) in output.values.iter().enumerate() {
1813 if val.is_nan() {
1814 continue;
1815 }
1816
1817 let bits = val.to_bits();
1818
1819 if bits == 0x11111111_11111111 {
1820 panic!(
1821 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1822 test_name, val, bits, i, params.period, params.offset
1823 );
1824 }
1825
1826 if bits == 0x22222222_22222222 {
1827 panic!(
1828 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1829 test_name, val, bits, i, params.period, params.offset
1830 );
1831 }
1832
1833 if bits == 0x33333333_33333333 {
1834 panic!(
1835 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params period={:?}, offset={:?}",
1836 test_name, val, bits, i, params.period, params.offset
1837 );
1838 }
1839 }
1840 }
1841
1842 Ok(())
1843 }
1844
1845 #[cfg(not(debug_assertions))]
1846 fn check_epma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1847 Ok(())
1848 }
1849
1850 macro_rules! generate_all_epma_tests {
1851 ($($test_fn:ident),*) => {
1852 paste::paste! {
1853 $(
1854 #[test]
1855 fn [<$test_fn _scalar_f64>]() {
1856 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1857 }
1858 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1859 #[test]
1860 fn [<$test_fn _avx2_f64>]() {
1861 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1862 }
1863 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1864 #[test]
1865 fn [<$test_fn _avx512_f64>]() {
1866 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1867 }
1868 )*
1869
1870 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1871 $(
1872 #[test]
1873 fn [<$test_fn _simd128_f64>]() {
1874 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1875 }
1876 )*
1877 }
1878 }
1879 }
1880 generate_all_epma_tests!(
1881 check_epma_partial_params,
1882 check_epma_accuracy,
1883 check_epma_default_candles,
1884 check_epma_zero_period,
1885 check_epma_period_exceeds_length,
1886 check_epma_very_small_dataset,
1887 check_epma_empty_input,
1888 check_epma_invalid_offset,
1889 check_epma_invalid_params,
1890 check_epma_reinput,
1891 check_epma_nan_handling,
1892 check_epma_streaming,
1893 check_epma_property,
1894 check_epma_no_poison
1895 );
1896
1897 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1898 #[test]
1899 fn test_epma_simd128_correctness() {
1900 let data = vec![
1901 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1902 ];
1903 let period = 5;
1904 let offset = 2;
1905
1906 let params = EpmaParams {
1907 period: Some(period),
1908 offset: Some(offset),
1909 };
1910 let input = EpmaInput::from_slice(&data, params);
1911
1912 let mut scalar_out = vec![0.0; data.len()];
1913 epma_scalar(&data, period, offset, 0, &mut scalar_out);
1914
1915 let simd128_output = epma_with_kernel(&input, Kernel::Scalar).unwrap();
1916
1917 let warmup = period + offset + 1;
1918 for i in warmup..data.len() {
1919 assert!(
1920 (scalar_out[i] - simd128_output.values[i]).abs() < 1e-10,
1921 "SIMD128 mismatch at index {}: scalar={}, simd128={}",
1922 i,
1923 scalar_out[i],
1924 simd128_output.values[i]
1925 );
1926 }
1927 }
1928
1929 fn check_batch_default_row(
1930 test: &str,
1931 kernel: Kernel,
1932 ) -> Result<(), Box<dyn std::error::Error>> {
1933 skip_if_unsupported!(kernel, test);
1934 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1935 let c = read_candles_from_csv(file)?;
1936 let output = EpmaBatchBuilder::new()
1937 .kernel(kernel)
1938 .apply_candles(&c, "close")?;
1939 let def = EpmaParams::default();
1940 let row = output.values_for(&def).expect("default row missing");
1941 assert_eq!(row.len(), c.close.len());
1942 let expected = [59174.48, 59201.04, 59167.60, 59200.32, 59117.04];
1943 let start = row.len() - 5;
1944 for (i, &v) in row[start..].iter().enumerate() {
1945 assert!(
1946 (v - expected[i]).abs() < 1e-1,
1947 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1948 );
1949 }
1950 Ok(())
1951 }
1952
1953 #[cfg(debug_assertions)]
1954 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1955 skip_if_unsupported!(kernel, test);
1956
1957 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1958 let c = read_candles_from_csv(file)?;
1959
1960 let test_configs = vec![
1961 ((2, 5, 1), (0, 2, 1)),
1962 ((10, 20, 5), (0, 19, 3)),
1963 ((20, 30, 2), (5, 15, 5)),
1964 ((15, 25, 5), (10, 14, 2)),
1965 ((5, 10, 1), (0, 9, 1)),
1966 ];
1967
1968 for (period_range, offset_range) in test_configs {
1969 let output = EpmaBatchBuilder::new()
1970 .kernel(kernel)
1971 .period_range(period_range.0, period_range.1, period_range.2)
1972 .offset_range(offset_range.0, offset_range.1, offset_range.2)
1973 .apply_candles(&c, "close")?;
1974
1975 for (idx, &val) in output.values.iter().enumerate() {
1976 if val.is_nan() {
1977 continue;
1978 }
1979
1980 let bits = val.to_bits();
1981 let row = idx / output.cols;
1982 let col = idx % output.cols;
1983 let params = &output.combos[row];
1984
1985 if bits == 0x11111111_11111111 {
1986 panic!(
1987 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
1988 test, val, bits, row, col, params.period, params.offset
1989 );
1990 }
1991
1992 if bits == 0x22222222_22222222 {
1993 panic!(
1994 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
1995 test, val, bits, row, col, params.period, params.offset
1996 );
1997 }
1998
1999 if bits == 0x33333333_33333333 {
2000 panic!(
2001 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (params: period={:?}, offset={:?})",
2002 test, val, bits, row, col, params.period, params.offset
2003 );
2004 }
2005 }
2006 }
2007
2008 Ok(())
2009 }
2010
2011 #[cfg(not(debug_assertions))]
2012 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2013 Ok(())
2014 }
2015
 // Generates the batch-kernel variants of a check function:
 //   * `<check>_scalar`       - `Kernel::ScalarBatch`, always;
 //   * `<check>_avx2` / `<check>_avx512` - x86_64 with `nightly-avx`;
 //   * `<check>_auto_detect`  - `Kernel::Auto`, always.
 //
 // NOTE(review): the `Result` returned by the check is discarded (`let _ =`),
 // so an error propagated with `?` inside the check does not fail the
 // generated test; only panics and failed assertions do.
 macro_rules! gen_batch_tests {
     ($fn_name:ident) => {
         paste::paste! {
             #[test]
             fn [<$fn_name _scalar>]() {
                 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
             }
             #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
             #[test]
             fn [<$fn_name _avx2>]() {
                 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
             }
             #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
             #[test]
             fn [<$fn_name _avx512>]() {
                 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
             }
             #[test]
             fn [<$fn_name _auto_detect>]() {
                 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
             }
         }
     };
 }
 gen_batch_tests!(check_batch_default_row);
 gen_batch_tests!(check_batch_no_poison);
2042
2043 #[test]
2044 fn test_invalid_output_len_error() {
2045 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2046 let params = EpmaParams {
2047 period: Some(3),
2048 offset: Some(1),
2049 };
2050 let input = EpmaInput::from_slice(&data, params);
2051 let mut wrong_size_dst = vec![0.0; 3];
2052
2053 let result = epma_into_slice(&mut wrong_size_dst, &input, Kernel::Scalar);
2054 assert!(result.is_err());
2055
2056 match result {
2057 Err(EpmaError::OutputLengthMismatch { expected, got }) => {
2058 assert_eq!(expected, 5);
2059 assert_eq!(got, 3);
2060 }
2061 _ => panic!("Expected OutputLengthMismatch error"),
2062 }
2063 }
2064
2065 #[test]
2066 fn test_invalid_kernel_error() {
2067 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2068 let sweep = EpmaBatchRange {
2069 period: (3, 5, 1),
2070 offset: (1, 2, 1),
2071 };
2072
2073 let result = epma_batch_with_kernel(&data, &sweep, Kernel::Scalar);
2074 assert!(result.is_err());
2075
2076 match result {
2077 Err(EpmaError::InvalidKernelForBatch(k)) => {
2078 assert_eq!(k, Kernel::Scalar);
2079 }
2080 _ => panic!("Expected InvalidKernelForBatch error"),
2081 }
2082 }
2083}
2084
2085#[cfg(feature = "python")]
2086use crate::utilities::kernel_validation::validate_kernel;
2087#[cfg(all(feature = "python", feature = "cuda"))]
2088use numpy::PyReadonlyArray2;
2089#[cfg(feature = "python")]
2090use numpy::PyUntypedArrayMethods;
2091#[cfg(feature = "python")]
2092use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
2093#[cfg(feature = "python")]
2094use pyo3::{exceptions::PyValueError, prelude::*, types::PyDict};
2095
// Python entry point `epma(data, period, offset, kernel=None)`.
//
// Computes EPMA over a 1-D float64 NumPy array and returns a new array of the
// same length. The computation runs with the GIL released; any error from the
// Rust side is raised as `ValueError` carrying the error's display text.
// (Plain `//` comments are used so the Python-visible docstring is untouched.)
#[cfg(feature = "python")]
#[pyfunction(name = "epma")]
#[pyo3(signature = (data, period, offset, kernel=None))]
pub fn epma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    offset: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let input = EpmaInput::from_slice(
        prices,
        EpmaParams {
            period: Some(period),
            offset: Some(offset),
        },
    );
    let chosen_kernel = validate_kernel(kernel, false)?;

    // The output array is allocated uninitialised and handed to
    // `epma_into_slice` as its destination.
    // NOTE(review): soundness relies on `epma_into_slice` writing every
    // element before the array is returned - confirm against its definition.
    let result = unsafe { PyArray1::<f64>::new(py, [prices.len()], false) };
    let destination = unsafe { result.as_slice_mut()? };

    match py.allow_threads(|| epma_into_slice(destination, &input, chosen_kernel)) {
        Ok(_) => Ok(result),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2122
// Python entry point `epma_batch(data, period_range, offset_range, kernel=None)`.
//
// Expands the two ranges into parameter combinations, computes one output
// row per combination with the GIL released, and returns a dict with:
//   * "values"  - 2-D float64 array of shape (combinations, len(data));
//   * "periods" - period of each row;
//   * "offsets" - offset of each row.
// Errors from the Rust side are raised as `ValueError`.
// (Plain `//` comments are used so the Python-visible docstring is untouched.)
#[cfg(feature = "python")]
#[pyfunction(name = "epma_batch")]
#[pyo3(signature = (data, period_range, offset_range, kernel=None))]
pub fn epma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    offset_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    let slice_in = data.as_slice()?;

    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // The element count sizes an uninitialised allocation, so a wrapped
    // product must never reach it: fail with ValueError instead.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("epma_batch: rows * cols overflows usize"))?;

    // NOTE(review): the buffer is uninitialised; soundness relies on
    // `epma_batch_inner_into` writing every element - confirm against its definition.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let combos_done = py
        .allow_threads(|| {
            let actual = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel onto the per-row kernel that does the work.
            let simd = match actual {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            epma_batch_inner_into(slice_in, &sweep, simd, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos_done
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "offsets",
        combos_done
            .iter()
            .map(|p| p.offset.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict)
}
2185
// Python entry point
// `epma_cuda_batch_dev(data_f32, period_range, offset_range, device_id=0)`.
//
// Runs the batch sweep through `CudaEpma` on the requested device with the
// GIL released and returns the device array wrapped together with the CUDA
// context handle and device id. Raises `ValueError` when CUDA is unavailable
// or when the CUDA wrapper reports an error.
// (Plain `//` comments are used so the Python-visible docstring is untouched.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "epma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, offset_range, device_id=0))]
pub fn epma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    offset_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<EpmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };
    let host_data = data_f32.as_slice()?;

    let (device_array, context, device) = py.allow_threads(|| -> PyResult<_> {
        let to_py_err = |msg: String| PyValueError::new_err(msg);
        let cuda = CudaEpma::new(device_id).map_err(|e| to_py_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device_array = cuda
            .epma_batch_dev(host_data, &sweep)
            .map_err(|e| to_py_err(e.to_string()))?;
        Ok((device_array, context, device_id as u32))
    })?;

    Ok(EpmaDeviceArrayF32Py {
        inner: device_array,
        _ctx_guard: context,
        device_id: device,
    })
}
2221
// Python entry point
// `epma_cuda_many_series_one_param_dev(data_tm_f32, period, offset, device_id=0)`.
//
// Takes a 2-D float32 array, reads `shape[0]` as `rows` and `shape[1]` as
// `cols`, and runs a single (period, offset) pair over it via
// `CudaEpma::epma_many_series_one_param_time_major_dev` with the GIL
// released. Raises `ValueError` when CUDA is unavailable or when the CUDA
// wrapper reports an error.
// (Plain `//` comments are used so the Python-visible docstring is untouched.)
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "epma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, offset, device_id=0))]
pub fn epma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    offset: usize,
    device_id: usize,
) -> PyResult<EpmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let rows = shape[0];
    let cols = shape[1];
    let params = EpmaParams {
        period: Some(period),
        offset: Some(offset),
    };

    let (inner, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda = CudaEpma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        // The last argument is a borrow of `params`; the source had it
        // mis-encoded as a pilcrow sequence, which does not compile.
        let arr = cuda
            .epma_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, device_id as u32))
    })?;

    Ok(EpmaDeviceArrayF32Py {
        inner,
        _ctx_guard: ctx_guard,
        device_id: dev_id,
    })
}
2260
// Python-visible wrapper around a device-resident f32 matrix produced by the
// CUDA EPMA routines.
//
// Field order matters: Rust drops struct fields in declaration order, so
// `inner` is dropped before `_ctx_guard`.
// NOTE(review): presumably the context must outlive the device buffer -
// confirm before reordering these fields.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct EpmaDeviceArrayF32Py {
    // Device buffer plus its row/column counts.
    pub(crate) inner: DeviceArrayF32,
    // Shared handle to the CUDA context; held, never read.
    pub(crate) _ctx_guard: Arc<Context>,
    // Device id the array was created with; reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
2268
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl EpmaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict: shape (rows, cols),
    // little-endian f32, byte strides derived from `cols`, a
    // (pointer, read_only=false) data pair, and interface version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let rows = self.inner.rows;
        let cols = self.inner.cols;

        let d = PyDict::new(py);
        d.set_item("shape", (rows, cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns (device type, device id); 2 is DLPack's `kDLCUDA` device type.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_index = self.device_id as i32;
        (2, device_index)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // * `dl_device`, when it parses as an `(i32, i32)` pair, must equal
    //   `__dlpack_device__()`; otherwise a `ValueError` is raised (with a
    //   different message when `copy` is truthy). Values that do not parse
    //   are ignored.
    // * `stream` is accepted but unused.
    // * The device buffer is moved out of `self` and replaced by an empty
    //   placeholder with zero rows and columns.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested_device = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested_device {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }

        let _ = stream;

        // Swap an empty buffer into `self` so the real one can be moved out.
        let placeholder = cust::memory::DeviceBuffer::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2347
// Python-facing class, exported under the name `EpmaStream`, that owns an
// `EpmaStream` value. Only compiled when the `python` feature is enabled.
#[cfg(feature = "python")]
#[pyclass(name = "EpmaStream")]
pub struct EpmaStreamPy {
    // Wrapped streaming state; the field carries no pyo3 getter/setter
    // attribute, so it is not a Python attribute.
    stream: EpmaStream,
}
2353
#[cfg(feature = "python")]
#[pymethods]
impl EpmaStreamPy {
    // Constructor: wraps both arguments in `Some`, hands them to
    // `EpmaStream::try_new`, and re-raises any failure as a Python
    // `ValueError` carrying the error's display text.
    #[new]
    fn new(period: usize, offset: usize) -> PyResult<Self> {
        match EpmaStream::try_new(EpmaParams {
            period: Some(period),
            offset: Some(offset),
        }) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns its result
    // unchanged (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
2371
2372#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2373use serde::{Deserialize, Serialize};
2374#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2375use wasm_bindgen::prelude::*;
2376
/// Computes EPMA over `data` and returns a freshly allocated vector of the
/// same length.
///
/// The kernel is chosen by `detect_best_kernel()`. Any error reported by
/// `epma_into_slice` is converted to a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_js(data: &[f64], period: usize, offset: usize) -> Result<Vec<f64>, JsValue> {
    let input = EpmaInput::from_slice(
        data,
        EpmaParams {
            period: Some(period),
            offset: Some(offset),
        },
    );

    // Zero-filled destination, one slot per input sample.
    let mut result = vec![0.0; data.len()];
    match epma_into_slice(&mut result, &input, detect_best_kernel()) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2390
/// Batch sweep configuration deserialized from the JS `config` object passed
/// to `epma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EpmaBatchConfig {
    /// Period sweep as a `(start, end, step)` triple; copied verbatim into
    /// `EpmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// Offset sweep as a `(start, end, step)` triple; copied verbatim into
    /// `EpmaBatchRange::offset`.
    pub offset_range: (usize, usize, usize),
}
2397
/// Serializable result of a batch run, returned to JavaScript by `epma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct EpmaBatchJsOutput {
    /// Flattened output values copied from the batch result.
    /// NOTE(review): presumably row-major with `rows * cols` entries -- confirm
    /// against `epma_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combinations copied from the batch result.
    pub combos: Vec<EpmaParams>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2406
/// Runs an EPMA parameter sweep over `data`; exported to JavaScript as
/// `epma_batch`.
///
/// `config` must deserialize into [`EpmaBatchConfig`]. The result is an
/// [`EpmaBatchJsOutput`] serialized back to a `JsValue`.
///
/// # Errors
/// Returns a string `JsValue` when the config cannot be deserialized
/// (`"Invalid config: ..."`), when the batch computation fails, or when the
/// result cannot be serialized (`"Serialization error: ..."`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = epma_batch)]
pub fn epma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let EpmaBatchConfig {
        period_range,
        offset_range,
    } = serde_wasm_bindgen::from_value::<EpmaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = EpmaBatchRange {
        period: period_range,
        offset: offset_range,
    };

    let batch = match epma_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = EpmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2428
/// Allocates uninitialized storage with capacity for `len` `f64` values and
/// returns a raw pointer to it.
///
/// Ownership is deliberately leaked to the caller; the vector's destructor is
/// suppressed, so the memory stays allocated until it is released explicitly
/// (see `epma_free`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2437
/// Releases a buffer previously handed out by `epma_alloc`.
///
/// `ptr` and `len` must be exactly the pointer returned by `epma_alloc` and
/// the length that was passed to it; the buffer must not be used or freed
/// again afterwards. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_free(ptr: *mut f64, len: usize) {
    // Rebuilding a `Vec` from a null pointer is undefined behaviour, so treat
    // null as "nothing to free".
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `epma_alloc(len)`, i.e. from
    // a `Vec<f64>` created with `with_capacity(len)`, and that it has not been
    // freed yet. Reconstructing the vector and dropping it returns the memory
    // to the allocator.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2445
/// Computes EPMA for `len` values read from `in_ptr` and writes `len` results
/// to `out_ptr`.
///
/// The input and output regions may be identical or overlap in any way: in
/// that case the result is computed into a temporary buffer and copied to the
/// destination afterwards. Disjoint regions are written directly.
///
/// Both pointers must be valid, properly aligned for `f64`, and cover `len`
/// elements each.
///
/// # Errors
/// Returns `"null pointer"` if either pointer is null, or the display text of
/// the error reported by `epma_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    offset: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let params = EpmaParams {
        period: Some(period),
        offset: Some(offset),
    };

    // Treat any intersection of the two byte ranges as aliasing, not just
    // identical start addresses: holding `&[f64]` and `&mut [f64]` over
    // overlapping memory at the same time is undefined behaviour.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_ptr == out_ptr
        || (in_start < out_start.saturating_add(byte_len)
            && out_start < in_start.saturating_add(byte_len));

    if overlaps {
        let mut tmp = vec![0.0; len];
        {
            // SAFETY: `in_ptr` is non-null and the caller guarantees it covers
            // `len` readable f64 values. The shared borrow ends with this scope,
            // before the destination is borrowed mutably.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = EpmaInput::from_slice(data, params);
            epma_into_slice(&mut tmp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: `out_ptr` is non-null and the caller guarantees it covers `len`
        // writable f64 values; no other reference into that memory is live.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&tmp);
    } else {
        // SAFETY: both pointers are non-null, the caller guarantees each covers
        // `len` f64 values, and the ranges were checked above to be disjoint.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = EpmaInput::from_slice(data, params);
        epma_into_slice(out, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2482
/// Runs an EPMA parameter sweep over `len` values read from `in_ptr`, writing
/// the results to `out_ptr`, and returns the number of parameter combinations
/// (rows).
///
/// The period sweep is `(p_start, p_end, p_step)` and the offset sweep is
/// `(o_start, o_end, o_step)`. `out_ptr` must be valid for `rows * len` `f64`
/// values, where `rows` is the number of combinations produced by
/// `expand_grid` for the sweep. `in_ptr` must be valid for `len` values, and
/// the two regions must not overlap.
///
/// # Errors
/// Returns `"null pointer"` if either pointer is null,
/// `"rows * cols overflow"` if the output size is not representable, or the
/// display text of the error reported by `epma_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn epma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    p_start: usize,
    p_end: usize,
    p_step: usize,
    o_start: usize,
    o_end: usize,
    o_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = EpmaBatchRange {
        period: (p_start, p_end, p_step),
        offset: (o_start, o_end, o_step),
    };
    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    // `usize` is 32 bits on wasm32, so `rows * cols` can wrap. A wrapped length
    // would describe a slice that does not match the caller's buffer. Slices
    // must also not exceed `isize::MAX` bytes.
    let max_elems = isize::MAX as usize / std::mem::size_of::<f64>();
    let total = rows
        .checked_mul(cols)
        .filter(|&n| n <= max_elems)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // SAFETY: both pointers are non-null; the caller guarantees `in_ptr` covers
    // `len` readable values and `out_ptr` covers `rows * len` writable values,
    // with no overlap between them. `total` was checked above to be within the
    // size limit for slices.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total),
        )
    };

    epma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}