1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14use std::convert::AsRef;
15use thiserror::Error;
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19impl<'a> AsRef<[f64]> for PmaInput<'a> {
20 #[inline(always)]
21 fn as_ref(&self) -> &[f64] {
22 match &self.data {
23 PmaData::Slice(slice) => slice,
24 PmaData::Candles { candles, source } => source_type(candles, source),
25 }
26 }
27}
28
/// Borrowed input series for PMA: either a candle set plus a source
/// selector (e.g. "close") or a raw `f64` slice.
#[derive(Debug, Clone)]
pub enum PmaData<'a> {
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    Slice(&'a [f64]),
}
37
/// Result of a PMA computation. Both vectors have the input's length, with
/// NaN over the warmup prefix.
#[derive(Debug, Clone)]
pub struct PmaOutput {
    pub predict: Vec<f64>, // 2*WMA7(data) - WMA7(WMA7(data)) forecast line
    pub trigger: Vec<f64>, // 4-point weighted average of the last predictions
}
43
/// Parameter set for PMA. The indicator uses fixed window sizes (7/7/4), so
/// there are currently no tunable parameters; the type exists for API
/// symmetry with other indicators. `Default` is derived instead of the
/// previous hand-written impl (clippy::derivable_impls).
#[derive(Debug, Clone, Default)]
pub struct PmaParams;
52
/// Full input bundle for the PMA entry points: the data source plus the
/// (currently empty) parameter set.
#[derive(Debug, Clone)]
pub struct PmaInput<'a> {
    pub data: PmaData<'a>,
    pub params: PmaParams,
}
58
59impl<'a> PmaInput<'a> {
60 #[inline]
61 pub fn from_candles(c: &'a Candles, s: &'a str, p: PmaParams) -> Self {
62 Self {
63 data: PmaData::Candles {
64 candles: c,
65 source: s,
66 },
67 params: p,
68 }
69 }
70 #[inline]
71 pub fn from_slice(sl: &'a [f64], p: PmaParams) -> Self {
72 Self {
73 data: PmaData::Slice(sl),
74 params: p,
75 }
76 }
77 #[inline]
78 pub fn with_default_candles(c: &'a Candles) -> Self {
79 Self::from_candles(c, "close", PmaParams::default())
80 }
81}
82
/// Fluent builder for one-shot PMA runs; only the kernel is configurable.
#[derive(Copy, Clone, Debug)]
pub struct PmaBuilder {
    kernel: Kernel, // compute kernel to dispatch to (Auto = pick for me)
}
87
88impl Default for PmaBuilder {
89 fn default() -> Self {
90 Self {
91 kernel: Kernel::Auto,
92 }
93 }
94}
95
96impl PmaBuilder {
97 #[inline(always)]
98 pub fn new() -> Self {
99 Self::default()
100 }
101 #[inline(always)]
102 pub fn kernel(mut self, k: Kernel) -> Self {
103 self.kernel = k;
104 self
105 }
106
107 #[inline(always)]
108 pub fn apply(self, c: &Candles) -> Result<PmaOutput, PmaError> {
109 let i = PmaInput::from_candles(c, "close", PmaParams::default());
110 pma_with_kernel(&i, self.kernel)
111 }
112
113 #[inline(always)]
114 pub fn apply_slice(self, d: &[f64]) -> Result<PmaOutput, PmaError> {
115 let i = PmaInput::from_slice(d, PmaParams::default());
116 pma_with_kernel(&i, self.kernel)
117 }
118
119 #[inline(always)]
120 pub fn into_stream(self) -> Result<PmaStream, PmaError> {
121 PmaStream::try_new(PmaParams::default())
122 }
123}
124
/// Errors produced by the PMA APIs.
#[derive(Debug, Error)]
pub enum PmaError {
    /// Input slice was empty.
    #[error("pma: Empty data provided.")]
    EmptyInputData,
    /// Every input value was NaN, so no warmup window exists.
    #[error("pma: All values are NaN.")]
    AllValuesNaN,
    /// Fewer than `needed` non-NaN trailing samples were available.
    #[error("pma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Reserved for period validation (PMA itself has fixed windows).
    #[error("pma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Caller-provided output buffers did not match the input length.
    #[error("pma: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Reserved for sweep-range validation.
    #[error("pma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch-only entry point.
    #[error("pma: invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// rows * cols overflowed `usize`.
    #[error("pma: size overflow computing rows*cols: rows = {rows}, cols = {cols}")]
    SizeOverflow { rows: usize, cols: usize },
}
148
149#[inline(always)]
150fn pma_first_valid_idx(data: &[f64]) -> Result<usize, PmaError> {
151 if data.is_empty() {
152 return Err(PmaError::EmptyInputData);
153 }
154 let first = data
155 .iter()
156 .position(|x| !x.is_nan())
157 .ok_or(PmaError::AllValuesNaN)?;
158 let valid = data.len() - first;
159 if valid < 7 {
160 return Err(PmaError::NotEnoughValidData { needed: 7, valid });
161 }
162 Ok(first)
163}
164
/// Computes PMA with automatic kernel selection; see [`pma_with_kernel`].
#[inline]
pub fn pma(input: &PmaInput) -> Result<PmaOutput, PmaError> {
    pma_with_kernel(input, Kernel::Auto)
}
169
170pub fn pma_with_kernel(input: &PmaInput, kernel: Kernel) -> Result<PmaOutput, PmaError> {
171 let data: &[f64] = match &input.data {
172 PmaData::Candles { candles, source } => source_type(candles, source),
173 PmaData::Slice(sl) => sl,
174 };
175
176 let first = pma_first_valid_idx(data)?;
177
178 let chosen = match kernel {
179 Kernel::Auto => Kernel::Scalar,
180 other => other,
181 };
182
183 unsafe {
184 match chosen {
185 Kernel::Scalar | Kernel::ScalarBatch => pma_scalar(data, first),
186 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
187 Kernel::Avx2 | Kernel::Avx2Batch => pma_avx2(data, first),
188 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
189 Kernel::Avx512 | Kernel::Avx512Batch => pma_avx512(data, first),
190 _ => unreachable!(),
191 }
192 }
193}
194
/// Scalar PMA kernel: a single O(n) pass producing `predict` and `trigger`.
///
/// Per bar `j` (starting at `first_valid_idx + 6`):
/// * `wma1` = 7-point weighted MA of the input (weights 1..=7, newest = 7),
/// * `wma2` = 7-point weighted MA of the `wma1` series,
/// * `predict[j] = 2 * wma1 - wma2`,
/// * `trigger[j]` = 4-point weighted MA of `predict` (weights 1..=4, / 10),
///   valid from `first_valid_idx + 9`, NaN before that.
///
/// The rolling weighted sums use the two-accumulator recurrence: with `A`
/// the plain window sum and `S` the weighted window sum,
/// `S' = 7*x_new + S - A_old` followed by `A' = A + x_new - x_old`.
/// `x_ring`/`w_ring`/`p_ring` hold the trailing 7 inputs, 7 wma1 values
/// and 4 predictions respectively.
#[inline]
pub fn pma_scalar(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    let n = data.len();
    let warmup_period = first_valid_idx + 7;
    let mut predict = alloc_with_nan_prefix(n, warmup_period);
    let mut trigger = alloc_with_nan_prefix(n, warmup_period);

    // Not enough room for even one full 7-sample window: all-NaN output.
    if n <= first_valid_idx + 6 {
        return Ok(PmaOutput { predict, trigger });
    }

    const INV_28: f64 = 1.0 / 28.0; // 28 = 1+2+...+7, the WMA(7) weight sum
    const INV_10: f64 = 1.0 / 10.0; // 10 = 1+2+3+4, the WMA(4) weight sum

    let mut x_ring = [0.0_f64; 7]; // trailing raw samples
    let mut w_ring = [0.0_f64; 7]; // trailing wma1 values
    let mut p_ring = [0.0_f64; 4]; // trailing predictions
    let mut x_head = 0usize;
    let mut w_head = 0usize;
    let mut p_head = 0usize;

    let mut A = 0.0_f64;  // rolling plain sum of x
    let mut S = 0.0_f64;  // rolling weighted sum of x
    let mut A1 = 0.0_f64; // rolling plain sum of wma1
    let mut S1 = 0.0_f64; // rolling weighted sum of wma1
    let mut A2 = 0.0_f64; // rolling plain sum of predictions
    let mut T = 0.0_f64;  // rolling weighted sum of predictions

    // First bar with a complete 7-sample window.
    let j0 = first_valid_idx + 6;

    // SAFETY: every pointer read is at an index <= j < n, the ring indices
    // are kept in range by the explicit wrap-arounds, and both output
    // vectors have length n.
    unsafe {
        let dp = data.as_ptr();

        // Seed the window with the first 7 valid samples.
        let x0 = *dp.add(j0 - 6);
        let x1 = *dp.add(j0 - 5);
        let x2 = *dp.add(j0 - 4);
        let x3 = *dp.add(j0 - 3);
        let x4 = *dp.add(j0 - 2);
        let x5 = *dp.add(j0 - 1);
        let x6 = *dp.add(j0 - 0);

        x_ring[0] = x0;
        x_ring[1] = x1;
        x_ring[2] = x2;
        x_ring[3] = x3;
        x_ring[4] = x4;
        x_ring[5] = x5;
        x_ring[6] = x6;

        A = ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + x6);

        let s01 = x0.mul_add(1.0, 2.0 * x1);
        let s23 = (3.0 * x2) + (4.0 * x3);
        let s45 = (5.0 * x4) + (6.0 * x5);
        S = (s01 + s23) + s45 + 7.0 * x6;

        let mut w1 = S * INV_28;

        let old_A1 = A1;
        let old_w = w_ring[w_head];
        S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
        A1 = A1 + w1 - old_w;
        w_ring[w_head] = w1;
        w_head += 1;
        if w_head == 7 {
            w_head = 0;
        }

        let mut w2 = S1 * INV_28;
        let mut pr = (2.0_f64).mul_add(w1, -w2);
        *predict.get_unchecked_mut(j0) = pr;

        let old_A2 = A2;
        let old_p = p_ring[p_head];
        T = (4.0_f64).mul_add(pr, T) - old_A2;
        A2 = A2 + pr - old_p;
        p_ring[p_head] = pr;
        p_head += 1;
        if p_head == 4 {
            p_head = 0;
        }
        // Fewer than 4 predictions exist yet, so the trigger is undefined.
        *trigger.get_unchecked_mut(j0) = f64::NAN;

        let mut j = j0 + 1;
        while j < n {
            let x_new = *dp.add(j);
            let x_old = x_ring[x_head];
            let old_A = A;

            // Slide the input window: update plain then weighted sum.
            A = A + x_new - x_old;
            S = (7.0_f64).mul_add(x_new, S) - old_A;

            x_ring[x_head] = x_new;
            x_head += 1;
            if x_head == 7 {
                x_head = 0;
            }

            w1 = S * INV_28;

            // Slide the wma1 window with the same recurrence.
            let old_A1 = A1;
            let w_old = w_ring[w_head];
            S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
            A1 = A1 + w1 - w_old;

            w_ring[w_head] = w1;
            w_head += 1;
            if w_head == 7 {
                w_head = 0;
            }

            w2 = S1 * INV_28;

            pr = (2.0_f64).mul_add(w1, -w2);
            *predict.get_unchecked_mut(j) = pr;

            // Slide the 4-point prediction window feeding the trigger.
            let old_A2 = A2;
            let p_old = p_ring[p_head];
            T = (4.0_f64).mul_add(pr, T) - old_A2;
            A2 = A2 + pr - p_old;

            p_ring[p_head] = pr;
            p_head += 1;
            if p_head == 4 {
                p_head = 0;
            }

            // Trigger becomes valid once 4 predictions have been produced.
            if j >= first_valid_idx + 9 {
                *trigger.get_unchecked_mut(j) = T * INV_10;
            } else {
                *trigger.get_unchecked_mut(j) = f64::NAN;
            }

            j += 1;
        }
    }

    Ok(PmaOutput { predict, trigger })
}
334
/// Core PMA recurrence writing into caller-provided slices.
///
/// Same algorithm as [`pma_scalar`] (see there for the math). `_kernel` is
/// accepted for signature compatibility but ignored — only the scalar path
/// exists. Writes indices `first_valid_idx + 6 ..` of both outputs and
/// leaves everything before that untouched, so callers must pre-seed the
/// warmup prefix themselves. Both slices must be at least `data.len()`
/// long (writes are unchecked).
#[inline(always)]
fn pma_compute_into(
    data: &[f64],
    first_valid_idx: usize,
    _kernel: Kernel,
    predict_out: &mut [f64],
    trigger_out: &mut [f64],
) {
    let n = data.len();
    // No complete 7-sample window: nothing to write.
    if n <= first_valid_idx + 6 {
        return;
    }

    const INV_28: f64 = 1.0 / 28.0; // WMA(7) weight sum
    const INV_10: f64 = 1.0 / 10.0; // WMA(4) weight sum

    let mut x_ring = [0.0_f64; 7]; // trailing raw samples
    let mut w_ring = [0.0_f64; 7]; // trailing wma1 values
    let mut p_ring = [0.0_f64; 4]; // trailing predictions
    let mut x_head = 0usize;
    let mut w_head = 0usize;
    let mut p_head = 0usize;

    let mut A = 0.0_f64;  // rolling plain sum of x
    let mut S = 0.0_f64;  // rolling weighted sum of x
    let mut A1 = 0.0_f64; // rolling plain sum of wma1
    let mut S1 = 0.0_f64; // rolling weighted sum of wma1
    let mut A2 = 0.0_f64; // rolling plain sum of predictions
    let mut T = 0.0_f64;  // rolling weighted sum of predictions

    // First bar with a complete 7-sample window.
    let j0 = first_valid_idx + 6;

    // SAFETY: every read is at an index <= j < n, ring indices are wrapped
    // explicitly, and callers provide output slices of length >= n.
    unsafe {
        let dp = data.as_ptr();

        // Seed the window with the first 7 valid samples.
        let x0 = *dp.add(j0 - 6);
        let x1 = *dp.add(j0 - 5);
        let x2 = *dp.add(j0 - 4);
        let x3 = *dp.add(j0 - 3);
        let x4 = *dp.add(j0 - 2);
        let x5 = *dp.add(j0 - 1);
        let x6 = *dp.add(j0 - 0);

        x_ring[0] = x0;
        x_ring[1] = x1;
        x_ring[2] = x2;
        x_ring[3] = x3;
        x_ring[4] = x4;
        x_ring[5] = x5;
        x_ring[6] = x6;

        A = ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + x6);

        let s01 = x0.mul_add(1.0, 2.0 * x1);
        let s23 = (3.0 * x2) + (4.0 * x3);
        let s45 = (5.0 * x4) + (6.0 * x5);
        S = (s01 + s23) + s45 + 7.0 * x6;

        let mut w1 = S * INV_28;

        let old_A1 = A1;
        let old_w = w_ring[w_head];
        S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
        A1 = A1 + w1 - old_w;
        w_ring[w_head] = w1;
        w_head += 1;
        if w_head == 7 {
            w_head = 0;
        }

        let mut w2 = S1 * INV_28;
        let mut pr = (2.0_f64).mul_add(w1, -w2);
        *predict_out.get_unchecked_mut(j0) = pr;

        let old_A2 = A2;
        let old_p = p_ring[p_head];
        T = (4.0_f64).mul_add(pr, T) - old_A2;
        A2 = A2 + pr - old_p;
        p_ring[p_head] = pr;
        p_head += 1;
        if p_head == 4 {
            p_head = 0;
        }

        // Fewer than 4 predictions exist yet, so the trigger is undefined.
        *trigger_out.get_unchecked_mut(j0) = f64::NAN;

        let mut j = j0 + 1;
        while j < n {
            let x_new = *dp.add(j);
            let x_old = x_ring[x_head];
            let old_A = A;

            // Slide the input window: update plain then weighted sum.
            A = A + x_new - x_old;
            S = (7.0_f64).mul_add(x_new, S) - old_A;

            x_ring[x_head] = x_new;
            x_head += 1;
            if x_head == 7 {
                x_head = 0;
            }

            w1 = S * INV_28;

            // Slide the wma1 window with the same recurrence.
            let old_A1 = A1;
            let w_old = w_ring[w_head];
            S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
            A1 = A1 + w1 - w_old;

            w_ring[w_head] = w1;
            w_head += 1;
            if w_head == 7 {
                w_head = 0;
            }

            w2 = S1 * INV_28;
            pr = (2.0_f64).mul_add(w1, -w2);

            *predict_out.get_unchecked_mut(j) = pr;

            // Slide the 4-point prediction window feeding the trigger.
            let old_A2 = A2;
            let p_old = p_ring[p_head];
            T = (4.0_f64).mul_add(pr, T) - old_A2;
            A2 = A2 + pr - p_old;

            p_ring[p_head] = pr;
            p_head += 1;
            if p_head == 4 {
                p_head = 0;
            }

            // Trigger becomes valid once 4 predictions have been produced.
            if j >= first_valid_idx + 9 {
                *trigger_out.get_unchecked_mut(j) = T * INV_10;
            } else {
                *trigger_out.get_unchecked_mut(j) = f64::NAN;
            }

            j += 1;
        }
    }
}
475
/// AVX-512 entry point; currently delegates to the scalar kernel (no SIMD
/// implementation has been written yet).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}
481
/// AVX2 entry point; currently delegates to the scalar kernel.
/// NOTE(review): unlike `pma_avx512`, this function is not gated on
/// `cfg(all(feature = "nightly-avx", target_arch = "x86_64"))` — confirm
/// whether the missing gate is intentional.
#[inline]
pub fn pma_avx2(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}
486
/// AVX-512 short-input variant; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512_short(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}
492
/// AVX-512 long-input variant; currently delegates to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512_long(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}
498
499#[inline]
500pub fn pma_batch_with_kernel(
501 data: &[f64],
502 sweep: &PmaBatchRange,
503 k: Kernel,
504) -> Result<PmaBatchOutput, PmaError> {
505 let kernel = match k {
506 Kernel::Auto => detect_best_batch_kernel(),
507 other if other.is_batch() => other,
508 other => return Err(PmaError::InvalidKernelForBatch(other)),
509 };
510 let simd = match kernel {
511 Kernel::Avx512Batch => Kernel::Avx512,
512 Kernel::Avx2Batch => Kernel::Avx2,
513 Kernel::ScalarBatch => Kernel::Scalar,
514 _ => unreachable!(),
515 };
516 pma_batch_par_slice(data, sweep, simd)
517}
518
519#[inline]
520pub fn pma_batch_unified_with_kernel(
521 data: &[f64],
522 k: Kernel,
523) -> Result<PmaBatchOutputUnified, PmaError> {
524 let kernel = match k {
525 Kernel::Auto => detect_best_batch_kernel(),
526 other if other.is_batch() => other,
527 _ => Kernel::ScalarBatch,
528 };
529 pma_batch_unified_inner(data, kernel)
530}
531
/// Builds the unified 2×n batch matrix: row 0 = predict, row 1 = trigger.
///
/// The buffer is allocated uninitialized, each row's warmup prefix is
/// NaN-seeded, and `pma_compute_into` then writes every remaining cell of
/// both rows before the buffer is exposed as `Vec<f64>`.
#[inline]
fn pma_batch_unified_inner(data: &[f64], kern: Kernel) -> Result<PmaBatchOutputUnified, PmaError> {
    let first = pma_first_valid_idx(data)?;

    let rows = 2usize;
    let cols = data.len();
    // Overflow guard for rows * cols; the product itself is unused.
    let _ = rows
        .checked_mul(cols)
        .ok_or(PmaError::SizeOverflow { rows, cols })?;

    let mut buf_mu = make_uninit_matrix(rows, cols);
    // NaN-fill the first `first + 6` cells of each row (warmup prefix).
    let warm = [first + 7 - 1; 2];
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // SAFETY: the prefix was initialized above and `pma_compute_into`
    // writes all remaining cells of both rows; viewing MaybeUninit<f64> as
    // f64 is sound once every cell is written before being read.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    let outf: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };

    let (row0, row1) = outf.split_at_mut(cols);
    pma_compute_into(
        data,
        first,
        // Map batch kernel tags to their single-series equivalents.
        match kern {
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            _ => Kernel::Scalar,
        },
        row0,
        row1,
    );

    // SAFETY: rebuilds a Vec over the allocation `guard` owns; since
    // `guard` is ManuallyDrop it will not free the memory, so the
    // allocation is released exactly once (by `values`).
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    Ok(PmaBatchOutputUnified { values, rows, cols })
}
573
/// Incremental (one sample at a time) PMA state.
#[derive(Debug, Clone)]
pub struct PmaStream {
    buffer: [f64; 7], // ring of the last 7 raw samples
    wma1: [f64; 7],   // ring of the last 7 first-stage WMA values
    idx: usize,       // ring head: next write slot == oldest element
    filled7: bool,    // true once 7 samples have been seen

    pred4: [f64; 4],   // ring of the last 4 predictions
    pred_idx: usize,   // ring head for pred4
    pred_filled: bool, // true once 4 predictions exist
}
585
impl PmaStream {
    /// Creates an empty stream; `update` returns `None` until 7 samples
    /// have been pushed.
    pub fn try_new(_params: PmaParams) -> Result<Self, PmaError> {
        Ok(Self {
            buffer: [f64::NAN; 7],
            wma1: [0.0; 7],
            idx: 0,
            filled7: false,
            pred4: [f64::NAN; 4],
            pred_idx: 0,
            pred_filled: false,
        })
    }
    /// Pushes one sample and returns `(predict, trigger)`.
    ///
    /// Returns `None` for the first 6 samples; `trigger` is NaN until 4
    /// predictions exist. NOTE(review): unlike the batch path, leading NaN
    /// inputs are not skipped here — confirm callers feed clean data.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.buffer[self.idx] = value;
        self.idx = (self.idx + 1) % 7;
        if !self.filled7 && self.idx == 0 {
            self.filled7 = true;
        }
        if !self.filled7 {
            return None;
        }

        // After the increment, `idx` points at the oldest element, so
        // s(0) is oldest and s(6) is the sample just written.
        let s = |k: usize| self.buffer[(self.idx + k) % 7];
        // First-stage WMA(7): newest sample gets weight 7, /28 weight sum.
        let wma1_j =
            (7.0 * s(6) + 6.0 * s(5) + 5.0 * s(4) + 4.0 * s(3) + 3.0 * s(2) + 2.0 * s(1) + s(0))
                / 28.0;
        self.wma1[self.idx] = wma1_j;

        // Second-stage WMA(7) over the wma1 ring, same weighting.
        let w = |k: usize| self.wma1[(self.idx + k) % 7];
        let wma2 =
            (7.0 * w(6) + 6.0 * w(5) + 5.0 * w(4) + 4.0 * w(3) + 3.0 * w(2) + 2.0 * w(1) + w(0))
                / 28.0;

        let predict = 2.0 * wma1_j - wma2;

        self.pred4[self.pred_idx] = predict;
        self.pred_idx = (self.pred_idx + 1) % 4;
        if !self.pred_filled && self.pred_idx == 0 {
            self.pred_filled = true;
        }

        // WMA(4) of the predictions (newest weight 4, /10 weight sum),
        // once 4 predictions are available.
        let trigger = if self.pred_filled {
            let t3 = self.pred4[(self.pred_idx + 3) % 4];
            let t2 = self.pred4[(self.pred_idx + 2) % 4];
            let t1 = self.pred4[(self.pred_idx + 1) % 4];
            let t0 = self.pred4[(self.pred_idx + 0) % 4];
            (4.0 * t3 + 3.0 * t2 + 2.0 * t1 + t0) / 10.0
        } else {
            f64::NAN
        };

        Some((predict, trigger))
    }
}
641
/// Parameter-sweep description for the batch API. PMA has no tunable
/// parameters, so this is a placeholder kept for API symmetry with other
/// indicators. `Default` is derived instead of the previous hand-written
/// impl (clippy::derivable_impls) — `(usize, usize, usize)` defaults to
/// `(0, 0, 0)`, matching the old behavior.
#[derive(Clone, Debug, Default)]
pub struct PmaBatchRange {
    /// Unused placeholder sweep bounds (start, end, step).
    pub dummy: (usize, usize, usize),
}
652
/// Fluent builder for batch PMA runs.
#[derive(Clone, Debug, Default)]
pub struct PmaBatchBuilder {
    range: PmaBatchRange, // placeholder sweep (PMA has no parameters)
    kernel: Kernel,       // batch kernel selection
}
658
659impl PmaBatchBuilder {
660 pub fn new() -> Self {
661 Self::default()
662 }
663 pub fn kernel(mut self, k: Kernel) -> Self {
664 self.kernel = k;
665 self
666 }
667 #[inline]
668 pub fn apply_slice(self, data: &[f64]) -> Result<PmaBatchOutput, PmaError> {
669 pma_batch_with_kernel(data, &self.range, self.kernel)
670 }
671 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<PmaBatchOutput, PmaError> {
672 PmaBatchBuilder::new().kernel(k).apply_slice(data)
673 }
674 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<PmaBatchOutput, PmaError> {
675 let slice = source_type(c, src);
676 self.apply_slice(slice)
677 }
678 pub fn with_default_candles(c: &Candles) -> Result<PmaBatchOutput, PmaError> {
679 PmaBatchBuilder::new()
680 .kernel(Kernel::Auto)
681 .apply_candles(c, "close")
682 }
683}
684
/// Batch output: one predict/trigger row pair (PMA has a single parameter
/// combination, so `rows` is always 1).
#[derive(Clone, Debug)]
pub struct PmaBatchOutput {
    pub predict: Vec<f64>,
    pub trigger: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
692impl PmaBatchOutput {
693 pub fn values_for(&self, _dummy: &PmaParams) -> Option<(&[f64], &[f64])> {
694 Some((&self.predict[..], &self.trigger[..]))
695 }
696}
697
/// Batch output as one flat row-major matrix: row 0 = predict, row 1 =
/// trigger (`rows` is 2, `cols` is the input length).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct PmaBatchOutputUnified {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
708
709#[inline(always)]
710pub fn expand_grid(_r: &PmaBatchRange) -> Vec<PmaParams> {
711 vec![PmaParams {}]
712}
713
/// Sequential batch entry point; see [`pma_batch_inner`].
#[inline(always)]
pub fn pma_batch_slice(
    data: &[f64],
    sweep: &PmaBatchRange,
    kern: Kernel,
) -> Result<PmaBatchOutput, PmaError> {
    pma_batch_inner(data, sweep, kern, false)
}
722
/// "Parallel" batch entry point. Currently identical to the sequential
/// path: `pma_batch_inner` ignores the parallel flag because PMA produces
/// only a single row.
#[inline(always)]
pub fn pma_batch_par_slice(
    data: &[f64],
    sweep: &PmaBatchRange,
    kern: Kernel,
) -> Result<PmaBatchOutput, PmaError> {
    pma_batch_inner(data, sweep, kern, true)
}
731
/// Shared batch implementation. PMA has no parameter sweep, so the output
/// is always a single row regardless of `_sweep`; `_parallel` is accepted
/// for API symmetry but unused.
///
/// `kern` must already be a single-series kernel (Scalar/Avx2/Avx512);
/// batch or Auto tags fall through to `unreachable!`.
#[inline(always)]
fn pma_batch_inner(
    data: &[f64],
    _sweep: &PmaBatchRange,
    kern: Kernel,
    _parallel: bool,
) -> Result<PmaBatchOutput, PmaError> {
    let first = pma_first_valid_idx(data)?;
    let out = match kern {
        Kernel::Scalar => pma_scalar(data, first)?,
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 => pma_avx2(data, first)?,
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 => pma_avx512(data, first)?,
        _ => unreachable!(),
    };
    Ok(PmaBatchOutput {
        predict: out.predict,
        trigger: out.trigger,
        rows: 1,
        cols: data.len(),
    })
}
755
/// Row-level scalar kernel used by the batch matrix writers. The stride,
/// dummy pointer and `inv_n` parameters exist only to match the shared
/// row-kernel signature and are ignored.
///
/// # Safety
/// Marked `unsafe` to match the row-kernel ABI. The body only calls the
/// safe `pma_compute_into`, but callers must still provide output slices
/// at least `data.len()` long, since writes downstream are unchecked.
#[inline(always)]
pub unsafe fn pma_row_scalar(
    data: &[f64],
    first: usize,
    _stride: usize,
    _dummy: *const f64,
    _inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_compute_into(data, first, Kernel::Scalar, out_predict, out_trigger);
}
768
/// AVX2 row kernel; currently delegates to the scalar row kernel.
///
/// # Safety
/// Same contract as [`pma_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx2(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}
782
/// AVX-512 row kernel; currently delegates to the scalar row kernel.
///
/// # Safety
/// Same contract as [`pma_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}
796
/// AVX-512 short-input row kernel; currently delegates to the scalar row
/// kernel.
///
/// # Safety
/// Same contract as [`pma_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512_short(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}
810
/// AVX-512 long-input row kernel; currently delegates to the scalar row
/// kernel.
///
/// # Safety
/// Same contract as [`pma_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512_long(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}
824
825#[inline]
826pub fn pma_into_slice(
827 predict_dst: &mut [f64],
828 trigger_dst: &mut [f64],
829 input: &PmaInput,
830 kern: Kernel,
831) -> Result<(), PmaError> {
832 let data = input.as_ref();
833
834 if predict_dst.len() != data.len() || trigger_dst.len() != data.len() {
835 return Err(PmaError::OutputLengthMismatch {
836 expected: data.len(),
837 got: predict_dst.len().min(trigger_dst.len()),
838 });
839 }
840
841 let first = pma_first_valid_idx(data)?;
842
843 let chosen = match kern {
844 Kernel::Auto => Kernel::Scalar,
845 k => k,
846 };
847
848 pma_compute_into(data, first, chosen, predict_dst, trigger_dst);
849
850 let warm_end = first + 7 - 1;
851 for v in &mut predict_dst[..warm_end] {
852 *v = f64::NAN;
853 }
854 for v in &mut trigger_dst[..warm_end] {
855 *v = f64::NAN;
856 }
857
858 Ok(())
859}
860
/// Native in-place API: computes PMA into the provided buffers using
/// automatic kernel selection. See [`pma_into_slice`] for the contract.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn pma_into(
    input: &PmaInput,
    predict_out: &mut [f64],
    trigger_out: &mut [f64],
) -> Result<(), PmaError> {
    pma_into_slice(predict_out, trigger_out, input, Kernel::Auto)
}
870
/// WASM one-shot API: returns a flat vector of `2 * data.len()` values —
/// the first half is the predict row, the second half the trigger row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_js(data: &[f64]) -> Result<Vec<f64>, JsValue> {
    let input = PmaInput::from_slice(data, PmaParams {});
    let rows = 2usize;
    let cols = data.len();
    // Overflow guard for rows * cols, mirroring the native batch path.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str(&PmaError::SizeOverflow { rows, cols }.to_string()))?;
    let mut values = vec![0.0; total];
    {
        let (pred, trig) = values.split_at_mut(cols);
        pma_into_slice(pred, trig, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(values)
}
889
/// WASM pointer-based in-place API: writes predict/trigger into the given
/// buffers of length `len`. Falls back to temporaries when any of the
/// pointers alias, so the input is never clobbered mid-computation.
///
/// The caller (JS side) must pass pointers obtained from `pma_alloc` (or
/// otherwise valid for `len` f64 values).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_into(
    in_ptr: *const f64,
    predict_ptr: *mut f64,
    trigger_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || predict_ptr.is_null() || trigger_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the wasm-bindgen contract requires the caller to pass
    // pointers valid for `len` f64 values; aliasing is handled below.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = PmaParams {};
        let input = PmaInput::from_slice(data, params);

        // Any pointer overlap would create aliasing &/&mut slices; detect
        // exact aliasing and route through temporaries in that case.
        let need_temp =
            in_ptr == predict_ptr || in_ptr == trigger_ptr || predict_ptr == trigger_ptr;

        if need_temp {
            let mut temp_predict = vec![0.0; len];
            let mut temp_trigger = vec![0.0; len];

            pma_into_slice(&mut temp_predict, &mut temp_trigger, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let predict_out = std::slice::from_raw_parts_mut(predict_ptr, len);
            let trigger_out = std::slice::from_raw_parts_mut(trigger_ptr, len);

            predict_out.copy_from_slice(&temp_predict);
            trigger_out.copy_from_slice(&temp_trigger);
        } else {
            let predict_out = std::slice::from_raw_parts_mut(predict_ptr, len);
            let trigger_out = std::slice::from_raw_parts_mut(trigger_ptr, len);

            pma_into_slice(predict_out, trigger_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
933
/// Allocates an uninitialized f64 buffer of `len` elements for the WASM
/// pointer API. Ownership passes to the caller, who must release it via
/// `pma_free` with the same `len`.
/// NOTE(review): `pma_free` reconstructs the Vec with capacity == len;
/// this relies on `with_capacity(len)` allocating exactly `len` — confirm
/// against the `Vec::from_raw_parts` contract.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec); // leak intentionally; freed by pma_free
    ptr
}
942
/// Releases a buffer previously returned by `pma_alloc`. `len` must match
/// the allocation's length; null pointers are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: per the API contract, `ptr`/`len` came from `pma_alloc`,
        // so rebuilding the Vec here frees the original allocation.
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
952
/// Placeholder batch configuration for the WASM API (PMA has no tunable
/// parameters).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaBatchConfig {
    pub dummy: Option<usize>,
}
958
/// Serializable unified-matrix output for the WASM API (row-major values,
/// `rows` x `cols`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaJsOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
966
/// Serializable two-row batch output for the WASM API.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaBatchJsOutput {
    pub predict: Vec<f64>,
    pub trigger: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
975
/// WASM batch API: computes PMA and returns a serialized
/// [`PmaBatchJsOutput`] (single row pair; PMA has no parameter sweep).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_batch(data: &[f64]) -> Result<JsValue, JsValue> {
    let input = PmaInput::from_slice(data, PmaParams {});
    let mut predict = vec![0.0; data.len()];
    let mut trigger = vec![0.0; data.len()];

    pma_into_slice(&mut predict, &mut trigger, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let output = PmaBatchJsOutput {
        predict,
        trigger,
        rows: 1,
        cols: data.len(),
    };

    serde_wasm_bindgen::to_value(&output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
996
/// WASM unified-matrix API: writes `2 * len` values into `out_ptr`
/// (row 0 = predict, row 1 = trigger) and returns the row count.
///
/// The caller must pass an output buffer valid for `2 * len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_unified_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    let rows = 2usize;
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str(&PmaError::SizeOverflow { rows, cols }.to_string()))?;
    // SAFETY: the wasm-bindgen contract requires `in_ptr` valid for `len`
    // reads and `out_ptr` valid for `total` writes.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let input = PmaInput::from_slice(data, PmaParams {});
        let (pred, trig) = out.split_at_mut(cols);
        pma_into_slice(pred, trig, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1022
/// WASM batch in-place API; identical to `pma_into` but returns the row
/// count (always 1 — a single predict/trigger pair).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_batch_into(
    in_ptr: *const f64,
    predict_ptr: *mut f64,
    trigger_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    pma_into(in_ptr, predict_ptr, trigger_ptr, len)?;
    Ok(1)
}
1034
/// WASM wrapper around the incremental [`PmaStream`].
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub struct PmaStreamWasm {
    stream: PmaStream,
}
1040
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
impl PmaStreamWasm {
    /// Creates an empty streaming PMA instance.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<PmaStreamWasm, JsValue> {
        let params = PmaParams {};
        let stream = PmaStream::try_new(params).map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(PmaStreamWasm { stream })
    }

    /// Pushes one sample; returns `[predict, trigger]`, or `[NaN, NaN]`
    /// while the 7-sample warmup is still filling.
    pub fn update(&mut self, value: f64) -> Result<Vec<f64>, JsValue> {
        match self.stream.update(value) {
            Some((predict, trigger)) => Ok(vec![predict, trigger]),
            None => Ok(vec![f64::NAN, f64::NAN]),
        }
    }
}
1058
1059#[cfg(feature = "python")]
1060use crate::utilities::kernel_validation::validate_kernel;
1061#[cfg(all(feature = "python", feature = "cuda"))]
1062use numpy::PyUntypedArrayMethods;
1063#[cfg(feature = "python")]
1064use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1065#[cfg(feature = "python")]
1066use pyo3::exceptions::PyValueError;
1067#[cfg(feature = "python")]
1068use pyo3::prelude::*;
1069#[cfg(feature = "python")]
1070use pyo3::types::PyDict;
1071
1072#[cfg(all(feature = "python", feature = "cuda"))]
1073use crate::cuda::{cuda_available, moving_averages::CudaPma};
1074#[cfg(all(feature = "python", feature = "cuda"))]
1075use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
1076
/// Python binding: computes PMA over a 1-D float64 array and returns the
/// `(predict, trigger)` pair as NumPy arrays. Releases the GIL while
/// computing.
#[cfg(feature = "python")]
#[pyfunction(name = "pma")]
#[pyo3(signature = (data, kernel=None))]
pub fn pma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = PmaInput::from_slice(slice_in, PmaParams {});

    let out = py
        .allow_threads(|| pma_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((out.predict.into_pyarray(py), out.trigger.into_pyarray(py)))
}
1096
/// Python wrapper around the incremental [`PmaStream`].
#[cfg(feature = "python")]
#[pyclass(name = "PmaStream")]
pub struct PmaStreamPy {
    stream: PmaStream,
}
1102
#[cfg(feature = "python")]
#[pymethods]
impl PmaStreamPy {
    /// Creates an empty streaming PMA instance.
    #[new]
    fn new() -> PyResult<Self> {
        let params = PmaParams {};
        let stream =
            PmaStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PmaStreamPy { stream })
    }

    /// Pushes one sample; returns `(predict, trigger)` or `None` during
    /// the 7-sample warmup.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        self.stream.update(value)
    }
}
1118
/// Python batch binding: returns a dict with a `(2, len)` float64 matrix
/// (`values`: row 0 = predict, row 1 = trigger) plus `rows`/`cols`.
///
/// The NumPy array is created uninitialized; the warmup prefixes are
/// NaN-seeded and every remaining cell is written by `pma_compute_into`
/// before the array is handed back. Computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "pma_batch")]
#[pyo3(signature = (data, kernel=None))]
pub fn pma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let (rows, cols) = (2usize, slice_in.len());
    let size = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err(PmaError::SizeOverflow { rows, cols }.to_string()))?;

    // SAFETY: `new` creates an uninitialized array; every element is
    // initialized below (prefixes first, then the compute pass).
    let values_arr = unsafe { PyArray1::<f64>::new(py, [size], false) };
    let values_slice = unsafe { values_arr.as_slice_mut()? };

    py.allow_threads(|| -> PyResult<()> {
        let first =
            pma_first_valid_idx(slice_in).map_err(|e| PyValueError::new_err(e.to_string()))?;

        // NaN-seed the warmup prefix of each row before the compute pass.
        let warm = first + 7 - 1;
        let warm_prefixes = [warm; 2];
        // SAFETY: reinterpreting &mut [f64] as &mut [MaybeUninit<f64>] is
        // sound (same layout); used only to reuse the prefix initializer.
        let values_mu: &mut [core::mem::MaybeUninit<f64>] = unsafe {
            core::slice::from_raw_parts_mut(
                values_slice.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
                values_slice.len(),
            )
        };
        init_matrix_prefixes(values_mu, cols, &warm_prefixes);

        let (row0, row1) = values_slice.split_at_mut(cols);
        pma_compute_into(
            slice_in,
            first,
            // Map batch/auto kernel tags to single-series equivalents.
            match kern {
                Kernel::Auto => Kernel::Scalar,
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => Kernel::Scalar,
            },
            row0,
            row1,
        );
        Ok(())
    })?;

    let dict = PyDict::new(py);
    dict.set_item("values", values_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1174
/// Python CUDA binding: runs the batch PMA kernel on `device_id` and
/// returns the predict/trigger pair as device-resident f32 arrays
/// (no host copy). Runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, device_id=0))]
pub fn pma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice_in = data_f32.as_slice()?;
    let sweep = PmaBatchRange::default();
    let pair = py.allow_threads(|| {
        let cuda = CudaPma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.pma_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let predict = make_device_array_py(device_id, pair.predict)?;
    let trigger = make_device_array_py(device_id, pair.trigger)?;
    Ok((predict, trigger))
}
1197
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, device_id=0))]
pub fn pma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    // Apply one PMA parameter set to many series stored time-major on the GPU.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let dims = data_tm_f32.shape();
    if dims.len() != 2 {
        return Err(PyValueError::new_err("expected time-major 2D array"));
    }
    let (rows, cols) = (dims[0], dims[1]);
    let flat = data_tm_f32.as_slice()?;
    // Release the GIL for the duration of the CUDA work.
    let outputs = py.allow_threads(|| {
        let ctx = CudaPma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        ctx.pma_many_series_one_param_time_major_dev(flat, cols, rows)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok((
        make_device_array_py(device_id, outputs.predict)?,
        make_device_array_py(device_id, outputs.trigger)?,
    ))
}
1225
1226#[cfg(test)]
1227mod tests {
1228 use super::*;
1229 use crate::skip_if_unsupported;
1230 use crate::utilities::data_loader::read_candles_from_csv;
1231
1232 fn check_pma_default_candles(
1233 test_name: &str,
1234 kernel: Kernel,
1235 ) -> Result<(), Box<dyn std::error::Error>> {
1236 skip_if_unsupported!(kernel, test_name);
1237 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1238 let candles = read_candles_from_csv(file_path)?;
1239 let input = PmaInput::with_default_candles(&candles);
1240 let output = pma_with_kernel(&input, kernel)?;
1241 assert_eq!(output.predict.len(), candles.close.len());
1242 assert_eq!(output.trigger.len(), candles.close.len());
1243 Ok(())
1244 }
1245
1246 fn check_pma_with_slice(
1247 test_name: &str,
1248 kernel: Kernel,
1249 ) -> Result<(), Box<dyn std::error::Error>> {
1250 skip_if_unsupported!(kernel, test_name);
1251 let data = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0];
1252 let input = PmaInput::from_slice(&data, PmaParams {});
1253 let output = pma_with_kernel(&input, kernel)?;
1254 assert_eq!(output.predict.len(), data.len());
1255 assert_eq!(output.trigger.len(), data.len());
1256 Ok(())
1257 }
1258
1259 fn check_pma_not_enough_data(
1260 test_name: &str,
1261 kernel: Kernel,
1262 ) -> Result<(), Box<dyn std::error::Error>> {
1263 skip_if_unsupported!(kernel, test_name);
1264 let data = [10.0, 20.0, 30.0];
1265 let input = PmaInput::from_slice(&data, PmaParams {});
1266 let result = pma_with_kernel(&input, kernel);
1267 assert!(result.is_err(), "Expected error for not enough data");
1268 Ok(())
1269 }
1270
1271 fn check_pma_all_values_nan(
1272 test_name: &str,
1273 kernel: Kernel,
1274 ) -> Result<(), Box<dyn std::error::Error>> {
1275 skip_if_unsupported!(kernel, test_name);
1276 let data = [f64::NAN, f64::NAN, f64::NAN];
1277 let input = PmaInput::from_slice(&data, PmaParams {});
1278 let result = pma_with_kernel(&input, kernel);
1279 assert!(result.is_err(), "Expected error for all values NaN");
1280 Ok(())
1281 }
1282
1283 fn check_pma_expected_values(
1284 test_name: &str,
1285 kernel: Kernel,
1286 ) -> Result<(), Box<dyn std::error::Error>> {
1287 skip_if_unsupported!(kernel, test_name);
1288 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1289 let candles = read_candles_from_csv(file_path)?;
1290 let input = PmaInput::from_candles(&candles, "hl2", PmaParams {});
1291 let result = pma_with_kernel(&input, kernel)?;
1292
1293 assert_eq!(
1294 result.predict.len(),
1295 candles.close.len(),
1296 "Predict length mismatch"
1297 );
1298 assert_eq!(
1299 result.trigger.len(),
1300 candles.close.len(),
1301 "Trigger length mismatch"
1302 );
1303
1304 let expected_predict = [
1305 59208.18749999999,
1306 59233.83609693878,
1307 59213.19132653061,
1308 59199.002551020414,
1309 58993.318877551,
1310 ];
1311 let expected_trigger = [
1312 59157.70790816327,
1313 59208.60076530612,
1314 59218.6763392857,
1315 59211.1443877551,
1316 59123.05019132652,
1317 ];
1318
1319 assert!(
1320 result.predict.len() >= 5,
1321 "Output length too short for checking"
1322 );
1323 let start_idx = result.predict.len() - 5;
1324 for i in 0..5 {
1325 let calc_val = result.predict[start_idx + i];
1326 let exp_val = expected_predict[i];
1327 assert!(
1328 (calc_val - exp_val).abs() < 1e-1,
1329 "Mismatch in predict at index {}: expected {}, got {}",
1330 start_idx + i,
1331 exp_val,
1332 calc_val
1333 );
1334 }
1335 for i in 0..5 {
1336 let calc_val = result.trigger[start_idx + i];
1337 let exp_val = expected_trigger[i];
1338 assert!(
1339 (calc_val - exp_val).abs() < 1e-1,
1340 "Mismatch in trigger at index {}: expected {}, got {}",
1341 start_idx + i,
1342 exp_val,
1343 calc_val
1344 );
1345 }
1346 Ok(())
1347 }
1348
1349 macro_rules! generate_all_pma_tests {
1350 ($($test_fn:ident),*) => {
1351 paste::paste! {
1352 $(
1353 #[test]
1354 fn [<$test_fn _scalar_f64>]() {
1355 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1356 }
1357 )*
1358 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1359 $(
1360 #[test]
1361 fn [<$test_fn _avx2_f64>]() {
1362 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1363 }
1364 #[test]
1365 fn [<$test_fn _avx512_f64>]() {
1366 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1367 }
1368 )*
1369 }
1370 }
1371 }
1372
1373 #[cfg(debug_assertions)]
1374 fn check_pma_no_poison(
1375 test_name: &str,
1376 kernel: Kernel,
1377 ) -> Result<(), Box<dyn std::error::Error>> {
1378 skip_if_unsupported!(kernel, test_name);
1379
1380 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1381 let candles = read_candles_from_csv(file_path)?;
1382
1383 let test_sources = vec![
1384 "close", "open", "high", "low", "hl2", "hlc3", "ohlc4", "volume",
1385 ];
1386
1387 for (source_idx, source) in test_sources.iter().enumerate() {
1388 let input = PmaInput::from_candles(&candles, source, PmaParams {});
1389 let output = pma_with_kernel(&input, kernel)?;
1390
1391 for (i, &val) in output.predict.iter().enumerate() {
1392 if val.is_nan() {
1393 continue;
1394 }
1395
1396 let bits = val.to_bits();
1397
1398 if bits == 0x11111111_11111111 {
1399 panic!(
1400 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1401 in predict array with source: {} (source set {})",
1402 test_name, val, bits, i, source, source_idx
1403 );
1404 }
1405
1406 if bits == 0x22222222_22222222 {
1407 panic!(
1408 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1409 in predict array with source: {} (source set {})",
1410 test_name, val, bits, i, source, source_idx
1411 );
1412 }
1413
1414 if bits == 0x33333333_33333333 {
1415 panic!(
1416 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1417 in predict array with source: {} (source set {})",
1418 test_name, val, bits, i, source, source_idx
1419 );
1420 }
1421 }
1422
1423 for (i, &val) in output.trigger.iter().enumerate() {
1424 if val.is_nan() {
1425 continue;
1426 }
1427
1428 let bits = val.to_bits();
1429
1430 if bits == 0x11111111_11111111 {
1431 panic!(
1432 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1433 in trigger array with source: {} (source set {})",
1434 test_name, val, bits, i, source, source_idx
1435 );
1436 }
1437
1438 if bits == 0x22222222_22222222 {
1439 panic!(
1440 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1441 in trigger array with source: {} (source set {})",
1442 test_name, val, bits, i, source, source_idx
1443 );
1444 }
1445
1446 if bits == 0x33333333_33333333 {
1447 panic!(
1448 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1449 in trigger array with source: {} (source set {})",
1450 test_name, val, bits, i, source, source_idx
1451 );
1452 }
1453 }
1454 }
1455
1456 Ok(())
1457 }
1458
    // Release-build stub: the poison sentinels are only written by the debug
    // allocation helpers, so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_pma_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1466
1467 #[cfg(feature = "proptest")]
1468 #[allow(clippy::float_cmp)]
1469 fn check_pma_property(
1470 test_name: &str,
1471 kernel: Kernel,
1472 ) -> Result<(), Box<dyn std::error::Error>> {
1473 use proptest::prelude::*;
1474 skip_if_unsupported!(kernel, test_name);
1475
1476 let strat = prop::collection::vec(
1477 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1478 7..400,
1479 );
1480
1481 proptest::test_runner::TestRunner::default().run(&strat, |data| {
1482 let input = PmaInput::from_slice(&data, PmaParams {});
1483
1484 let result = pma_with_kernel(&input, kernel)?;
1485 let ref_result = pma_with_kernel(&input, Kernel::Scalar)?;
1486
1487 prop_assert_eq!(result.predict.len(), data.len());
1488 prop_assert_eq!(result.trigger.len(), data.len());
1489 prop_assert_eq!(ref_result.predict.len(), data.len());
1490 prop_assert_eq!(ref_result.trigger.len(), data.len());
1491
1492 let warmup_period = 7;
1493
1494 for i in 0..warmup_period {
1495 prop_assert!(
1496 result.predict[i].is_nan(),
1497 "Expected NaN in predict warmup at index {}",
1498 i
1499 );
1500 prop_assert!(
1501 result.trigger[i].is_nan(),
1502 "Expected NaN in trigger warmup at index {}",
1503 i
1504 );
1505 }
1506
1507 if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
1508 && data.len() >= warmup_period
1509 {
1510 for i in warmup_period..data.len() {
1511 if result.predict[i].is_finite() {
1512 prop_assert!(
1513 (result.predict[i] - data[0]).abs() < 1e-9,
1514 "Constant data test failed: predict[{}] = {} should be close to {}",
1515 i,
1516 result.predict[i],
1517 data[0]
1518 );
1519 }
1520 }
1521 }
1522
1523 for i in warmup_period..data.len() {
1524 if result.predict[i].is_finite() && ref_result.predict[i].is_finite() {
1525 let diff_predict = (result.predict[i] - ref_result.predict[i]).abs();
1526 prop_assert!(
1527 diff_predict < 1e-10,
1528 "Predict mismatch at index {}: kernel={}, scalar={}, diff={}",
1529 i,
1530 result.predict[i],
1531 ref_result.predict[i],
1532 diff_predict
1533 );
1534 } else {
1535 prop_assert_eq!(
1536 result.predict[i].is_nan(),
1537 ref_result.predict[i].is_nan(),
1538 "NaN mismatch in predict at index {}",
1539 i
1540 );
1541 }
1542
1543 if result.trigger[i].is_finite() && ref_result.trigger[i].is_finite() {
1544 let diff_trigger = (result.trigger[i] - ref_result.trigger[i]).abs();
1545 prop_assert!(
1546 diff_trigger < 1e-10,
1547 "Trigger mismatch at index {}: kernel={}, scalar={}, diff={}",
1548 i,
1549 result.trigger[i],
1550 ref_result.trigger[i],
1551 diff_trigger
1552 );
1553 } else {
1554 prop_assert_eq!(
1555 result.trigger[i].is_nan(),
1556 ref_result.trigger[i].is_nan(),
1557 "NaN mismatch in trigger at index {}",
1558 i
1559 );
1560 }
1561
1562 if i >= warmup_period && result.predict[i].is_finite() {
1563 let window_start = i.saturating_sub(6);
1564 let window_data = &data[window_start..=i];
1565 let min_val = window_data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
1566 let max_val = window_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1567
1568 let tolerance = (max_val - min_val).abs() * 0.1 + 1e-9;
1569 prop_assert!(
1570 result.predict[i] >= min_val - tolerance
1571 && result.predict[i] <= max_val + tolerance,
1572 "Predict value {} at index {} outside bounds [{}, {}] with tolerance {}",
1573 result.predict[i],
1574 i,
1575 min_val - tolerance,
1576 max_val + tolerance,
1577 tolerance
1578 );
1579 }
1580
1581 if i == warmup_period && i >= 6 {
1582 let wma1_expected = (7.0 * data[i]
1583 + 6.0 * data[i - 1]
1584 + 5.0 * data[i - 2]
1585 + 4.0 * data[i - 3]
1586 + 3.0 * data[i - 4]
1587 + 2.0 * data[i - 5]
1588 + data[i - 6])
1589 / 28.0;
1590
1591 if result.predict[i].is_finite() {
1592 let window_start = i.saturating_sub(6);
1593 let window = &data[window_start..=i];
1594 let window_avg = window.iter().sum::<f64>() / window.len() as f64;
1595 let min = window.iter().fold(f64::INFINITY, |a, &b| a.min(b));
1596 let max = window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1597 prop_assert!(
1598 (result.predict[i] - window_avg).abs() < (max - min).abs() + 1e-9,
1599 "Predict value {} at index {} seems unrelated to window average {}",
1600 result.predict[i],
1601 i,
1602 window_avg
1603 );
1604 }
1605 }
1606
1607 if i >= warmup_period + 3
1608 && result.trigger[i].is_finite()
1609 && result.predict[i].is_finite()
1610 {
1611 if result.predict[i - 1].is_finite()
1612 && result.predict[i - 2].is_finite()
1613 && result.predict[i - 3].is_finite()
1614 {
1615 let expected_trigger = (4.0 * result.predict[i]
1616 + 3.0 * result.predict[i - 1]
1617 + 2.0 * result.predict[i - 2]
1618 + result.predict[i - 3])
1619 / 10.0;
1620 let trigger_diff = (result.trigger[i] - expected_trigger).abs();
1621 prop_assert!(
1622 trigger_diff < 1e-10,
1623 "Trigger calculation error at index {}: expected {}, got {}, diff={}",
1624 i,
1625 expected_trigger,
1626 result.trigger[i],
1627 trigger_diff
1628 );
1629 }
1630 }
1631 }
1632
1633 if data.len() == 7 {
1634 prop_assert!(
1635 result.predict[6].is_finite(),
1636 "With exactly 7 points, predict[6] should be finite but got NaN"
1637 );
1638 }
1639
1640 Ok(())
1641 })?;
1642
1643 Ok(())
1644 }
1645
    // Instantiate scalar/AVX2/AVX512 test variants for every checker above.
    generate_all_pma_tests!(
        check_pma_default_candles,
        check_pma_with_slice,
        check_pma_not_enough_data,
        check_pma_all_values_nan,
        check_pma_expected_values,
        check_pma_no_poison
    );

    // The property-based checker only exists when proptest is enabled.
    #[cfg(feature = "proptest")]
    generate_all_pma_tests!(check_pma_property);
1657
1658 fn check_batch_default_row(
1659 test: &str,
1660 kernel: Kernel,
1661 ) -> Result<(), Box<dyn std::error::Error>> {
1662 skip_if_unsupported!(kernel, test);
1663
1664 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1665 let c = read_candles_from_csv(file)?;
1666 let output = PmaBatchBuilder::new()
1667 .kernel(kernel)
1668 .apply_candles(&c, "close")?;
1669
1670 assert_eq!(output.rows, 1, "Expected exactly 1 row");
1671 assert_eq!(output.cols, c.close.len());
1672 assert_eq!(output.predict.len(), c.close.len());
1673 assert_eq!(output.trigger.len(), c.close.len());
1674
1675 let input = PmaInput::from_candles(&c, "close", PmaParams::default());
1676 let expected = pma_with_kernel(&input, kernel)?;
1677
1678 for (i, (&a, &b)) in output
1679 .predict
1680 .iter()
1681 .zip(expected.predict.iter())
1682 .enumerate()
1683 {
1684 if a.is_nan() && b.is_nan() {
1685 continue;
1686 }
1687 assert!(
1688 (a - b).abs() < 1e-12,
1689 "[{test}] predict mismatch at idx {i}: batch={}, direct={}",
1690 a,
1691 b
1692 );
1693 }
1694 for (i, (&a, &b)) in output
1695 .trigger
1696 .iter()
1697 .zip(expected.trigger.iter())
1698 .enumerate()
1699 {
1700 if a.is_nan() && b.is_nan() {
1701 continue;
1702 }
1703 assert!(
1704 (a - b).abs() < 1e-12,
1705 "[{test}] trigger mismatch at idx {i}: batch={}, direct={}",
1706 a,
1707 b
1708 );
1709 }
1710 Ok(())
1711 }
1712
    // Expands a batch checker into per-kernel #[test] functions: scalar and
    // auto-detect always; AVX2/AVX512 only on nightly-avx x86_64 builds.
    // Note the `#[cfg]` gate is repeated on each AVX item so every generated
    // function is individually gated.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
1733 #[cfg(debug_assertions)]
1734 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1735 skip_if_unsupported!(kernel, test);
1736
1737 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1738 let c = read_candles_from_csv(file)?;
1739
1740 let test_sources = vec!["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];
1741
1742 for (source_idx, source) in test_sources.iter().enumerate() {
1743 let output = PmaBatchBuilder::new()
1744 .kernel(kernel)
1745 .apply_candles(&c, source)?;
1746
1747 for (idx, &val) in output.predict.iter().enumerate() {
1748 if val.is_nan() {
1749 continue;
1750 }
1751
1752 let bits = val.to_bits();
1753
1754 if bits == 0x11111111_11111111 {
1755 panic!(
1756 "[{}] Source {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1757 at index {} in predict array with source: {}",
1758 test, source_idx, val, bits, idx, source
1759 );
1760 }
1761
1762 if bits == 0x22222222_22222222 {
1763 panic!(
1764 "[{}] Source {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1765 at index {} in predict array with source: {}",
1766 test, source_idx, val, bits, idx, source
1767 );
1768 }
1769
1770 if bits == 0x33333333_33333333 {
1771 panic!(
1772 "[{}] Source {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1773 at index {} in predict array with source: {}",
1774 test, source_idx, val, bits, idx, source
1775 );
1776 }
1777 }
1778
1779 for (idx, &val) in output.trigger.iter().enumerate() {
1780 if val.is_nan() {
1781 continue;
1782 }
1783
1784 let bits = val.to_bits();
1785
1786 if bits == 0x11111111_11111111 {
1787 panic!(
1788 "[{}] Source {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1789 at index {} in trigger array with source: {}",
1790 test, source_idx, val, bits, idx, source
1791 );
1792 }
1793
1794 if bits == 0x22222222_22222222 {
1795 panic!(
1796 "[{}] Source {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1797 at index {} in trigger array with source: {}",
1798 test, source_idx, val, bits, idx, source
1799 );
1800 }
1801
1802 if bits == 0x33333333_33333333 {
1803 panic!(
1804 "[{}] Source {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1805 at index {} in trigger array with source: {}",
1806 test, source_idx, val, bits, idx, source
1807 );
1808 }
1809 }
1810 }
1811
1812 Ok(())
1813 }
1814
    // Release-build stub: the poison sentinels are only written by the debug
    // allocation helpers, so there is nothing to scan for here.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1822
    // Instantiate per-kernel batch test variants for the checkers above.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1825
1826 #[test]
1827 fn test_pma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1828 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1829 let candles = read_candles_from_csv(file_path)?;
1830 let input = PmaInput::with_default_candles(&candles);
1831
1832 let base = pma_with_kernel(&input, Kernel::Auto)?;
1833
1834 let n = candles.close.len();
1835 let mut out_predict = vec![0.0; n];
1836 let mut out_trigger = vec![0.0; n];
1837
1838 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1839 {
1840 pma_into(&input, &mut out_predict, &mut out_trigger)?;
1841 }
1842
1843 assert_eq!(base.predict.len(), out_predict.len());
1844 assert_eq!(base.trigger.len(), out_trigger.len());
1845
1846 fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
1847 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1848 }
1849
1850 for i in 0..n {
1851 assert!(
1852 eq_or_both_nan_eps(base.predict[i], out_predict[i]),
1853 "predict mismatch at {i}: api={}, into={}",
1854 base.predict[i],
1855 out_predict[i]
1856 );
1857 assert!(
1858 eq_or_both_nan_eps(base.trigger[i], out_trigger[i]),
1859 "trigger mismatch at {i}: api={}, into={}",
1860 base.trigger[i],
1861 out_trigger[i]
1862 );
1863 }
1864
1865 Ok(())
1866 }
1867}