Skip to main content

vector_ta/indicators/
pma.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14use std::convert::AsRef;
15use thiserror::Error;
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use wasm_bindgen::prelude::*;
18
19impl<'a> AsRef<[f64]> for PmaInput<'a> {
20    #[inline(always)]
21    fn as_ref(&self) -> &[f64] {
22        match &self.data {
23            PmaData::Slice(slice) => slice,
24            PmaData::Candles { candles, source } => source_type(candles, source),
25        }
26    }
27}
28
/// Where the PMA input values come from.
#[derive(Debug, Clone)]
pub enum PmaData<'a> {
    /// A candle set plus the name of the column to read (resolved with `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice used as-is.
    Slice(&'a [f64]),
}
37
/// Result of a single-series PMA run.
///
/// Both vectors have the same length as the input; warmup positions hold NaN.
#[derive(Debug, Clone)]
pub struct PmaOutput {
    /// `2 * wma1 - wma2` line.
    pub predict: Vec<f64>,
    /// 4-tap weighted average of `predict`.
    pub trigger: Vec<f64>,
}
43
/// Parameters for PMA.
///
/// PMA uses fixed window lengths (7 for both weighted averages, 4 for the
/// trigger), so there is nothing to configure. The type exists so PMA fits the
/// same input/builder shape as the rest of the API. `Default` is derived
/// rather than hand-written, since the manual impl only returned `Self`.
#[derive(Debug, Clone, Default)]
pub struct PmaParams;
52
/// Input to the PMA functions: a data source plus its (empty) parameter set.
#[derive(Debug, Clone)]
pub struct PmaInput<'a> {
    pub data: PmaData<'a>,
    pub params: PmaParams,
}
58
59impl<'a> PmaInput<'a> {
60    #[inline]
61    pub fn from_candles(c: &'a Candles, s: &'a str, p: PmaParams) -> Self {
62        Self {
63            data: PmaData::Candles {
64                candles: c,
65                source: s,
66            },
67            params: p,
68        }
69    }
70    #[inline]
71    pub fn from_slice(sl: &'a [f64], p: PmaParams) -> Self {
72        Self {
73            data: PmaData::Slice(sl),
74            params: p,
75        }
76    }
77    #[inline]
78    pub fn with_default_candles(c: &'a Candles) -> Self {
79        Self::from_candles(c, "close", PmaParams::default())
80    }
81}
82
83#[derive(Copy, Clone, Debug)]
84pub struct PmaBuilder {
85    kernel: Kernel,
86}
87
88impl Default for PmaBuilder {
89    fn default() -> Self {
90        Self {
91            kernel: Kernel::Auto,
92        }
93    }
94}
95
96impl PmaBuilder {
97    #[inline(always)]
98    pub fn new() -> Self {
99        Self::default()
100    }
101    #[inline(always)]
102    pub fn kernel(mut self, k: Kernel) -> Self {
103        self.kernel = k;
104        self
105    }
106
107    #[inline(always)]
108    pub fn apply(self, c: &Candles) -> Result<PmaOutput, PmaError> {
109        let i = PmaInput::from_candles(c, "close", PmaParams::default());
110        pma_with_kernel(&i, self.kernel)
111    }
112
113    #[inline(always)]
114    pub fn apply_slice(self, d: &[f64]) -> Result<PmaOutput, PmaError> {
115        let i = PmaInput::from_slice(d, PmaParams::default());
116        pma_with_kernel(&i, self.kernel)
117    }
118
119    #[inline(always)]
120    pub fn into_stream(self) -> Result<PmaStream, PmaError> {
121        PmaStream::try_new(PmaParams::default())
122    }
123}
124
/// Errors returned by the PMA functions.
#[derive(Debug, Error)]
pub enum PmaError {
    /// The input series has length 0.
    #[error("pma: Empty data provided.")]
    EmptyInputData,
    /// Every input value is NaN.
    #[error("pma: All values are NaN.")]
    AllValuesNaN,
    /// Fewer than `needed` values remain from the first non-NaN value to the
    /// end of the series (interior NaNs are still counted as "valid" here).
    #[error("pma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// NOTE(review): not constructed by the PMA code visible in this module
    /// (PMA has no period); presumably kept for parity with other indicators.
    #[error("pma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// An output buffer's length differs from the input length.
    #[error("pma: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// NOTE(review): not constructed by the PMA code visible in this module.
    #[error("pma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("pma: invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// `rows * cols` does not fit in `usize`.
    #[error("pma: size overflow computing rows*cols: rows = {rows}, cols = {cols}")]
    SizeOverflow { rows: usize, cols: usize },
}
148
149#[inline(always)]
150fn pma_first_valid_idx(data: &[f64]) -> Result<usize, PmaError> {
151    if data.is_empty() {
152        return Err(PmaError::EmptyInputData);
153    }
154    let first = data
155        .iter()
156        .position(|x| !x.is_nan())
157        .ok_or(PmaError::AllValuesNaN)?;
158    let valid = data.len() - first;
159    if valid < 7 {
160        return Err(PmaError::NotEnoughValidData { needed: 7, valid });
161    }
162    Ok(first)
163}
164
/// Computes PMA with `Kernel::Auto`.
///
/// # Errors
/// Same as `pma_with_kernel`.
#[inline]
pub fn pma(input: &PmaInput) -> Result<PmaOutput, PmaError> {
    pma_with_kernel(input, Kernel::Auto)
}
169
170pub fn pma_with_kernel(input: &PmaInput, kernel: Kernel) -> Result<PmaOutput, PmaError> {
171    let data: &[f64] = match &input.data {
172        PmaData::Candles { candles, source } => source_type(candles, source),
173        PmaData::Slice(sl) => sl,
174    };
175
176    let first = pma_first_valid_idx(data)?;
177
178    let chosen = match kernel {
179        Kernel::Auto => Kernel::Scalar,
180        other => other,
181    };
182
183    unsafe {
184        match chosen {
185            Kernel::Scalar | Kernel::ScalarBatch => pma_scalar(data, first),
186            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
187            Kernel::Avx2 | Kernel::Avx2Batch => pma_avx2(data, first),
188            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
189            Kernel::Avx512 | Kernel::Avx512Batch => pma_avx512(data, first),
190            _ => unreachable!(),
191        }
192    }
193}
194
195#[inline]
196pub fn pma_scalar(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
197    let n = data.len();
198    let warmup_period = first_valid_idx + 7;
199    let mut predict = alloc_with_nan_prefix(n, warmup_period);
200    let mut trigger = alloc_with_nan_prefix(n, warmup_period);
201
202    if n <= first_valid_idx + 6 {
203        return Ok(PmaOutput { predict, trigger });
204    }
205
206    const INV_28: f64 = 1.0 / 28.0;
207    const INV_10: f64 = 1.0 / 10.0;
208
209    let mut x_ring = [0.0_f64; 7];
210    let mut w_ring = [0.0_f64; 7];
211    let mut p_ring = [0.0_f64; 4];
212    let mut x_head = 0usize;
213    let mut w_head = 0usize;
214    let mut p_head = 0usize;
215
216    let mut A = 0.0_f64;
217    let mut S = 0.0_f64;
218    let mut A1 = 0.0_f64;
219    let mut S1 = 0.0_f64;
220    let mut A2 = 0.0_f64;
221    let mut T = 0.0_f64;
222
223    let j0 = first_valid_idx + 6;
224
225    unsafe {
226        let dp = data.as_ptr();
227
228        let x0 = *dp.add(j0 - 6);
229        let x1 = *dp.add(j0 - 5);
230        let x2 = *dp.add(j0 - 4);
231        let x3 = *dp.add(j0 - 3);
232        let x4 = *dp.add(j0 - 2);
233        let x5 = *dp.add(j0 - 1);
234        let x6 = *dp.add(j0 - 0);
235
236        x_ring[0] = x0;
237        x_ring[1] = x1;
238        x_ring[2] = x2;
239        x_ring[3] = x3;
240        x_ring[4] = x4;
241        x_ring[5] = x5;
242        x_ring[6] = x6;
243
244        A = ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + x6);
245
246        let s01 = x0.mul_add(1.0, 2.0 * x1);
247        let s23 = (3.0 * x2) + (4.0 * x3);
248        let s45 = (5.0 * x4) + (6.0 * x5);
249        S = (s01 + s23) + s45 + 7.0 * x6;
250
251        let mut w1 = S * INV_28;
252
253        let old_A1 = A1;
254        let old_w = w_ring[w_head];
255        S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
256        A1 = A1 + w1 - old_w;
257        w_ring[w_head] = w1;
258        w_head += 1;
259        if w_head == 7 {
260            w_head = 0;
261        }
262
263        let mut w2 = S1 * INV_28;
264        let mut pr = (2.0_f64).mul_add(w1, -w2);
265        *predict.get_unchecked_mut(j0) = pr;
266
267        let old_A2 = A2;
268        let old_p = p_ring[p_head];
269        T = (4.0_f64).mul_add(pr, T) - old_A2;
270        A2 = A2 + pr - old_p;
271        p_ring[p_head] = pr;
272        p_head += 1;
273        if p_head == 4 {
274            p_head = 0;
275        }
276        *trigger.get_unchecked_mut(j0) = f64::NAN;
277
278        let mut j = j0 + 1;
279        while j < n {
280            let x_new = *dp.add(j);
281            let x_old = x_ring[x_head];
282            let old_A = A;
283
284            A = A + x_new - x_old;
285            S = (7.0_f64).mul_add(x_new, S) - old_A;
286
287            x_ring[x_head] = x_new;
288            x_head += 1;
289            if x_head == 7 {
290                x_head = 0;
291            }
292
293            w1 = S * INV_28;
294
295            let old_A1 = A1;
296            let w_old = w_ring[w_head];
297            S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
298            A1 = A1 + w1 - w_old;
299
300            w_ring[w_head] = w1;
301            w_head += 1;
302            if w_head == 7 {
303                w_head = 0;
304            }
305
306            w2 = S1 * INV_28;
307
308            pr = (2.0_f64).mul_add(w1, -w2);
309            *predict.get_unchecked_mut(j) = pr;
310
311            let old_A2 = A2;
312            let p_old = p_ring[p_head];
313            T = (4.0_f64).mul_add(pr, T) - old_A2;
314            A2 = A2 + pr - p_old;
315
316            p_ring[p_head] = pr;
317            p_head += 1;
318            if p_head == 4 {
319                p_head = 0;
320            }
321
322            if j >= first_valid_idx + 9 {
323                *trigger.get_unchecked_mut(j) = T * INV_10;
324            } else {
325                *trigger.get_unchecked_mut(j) = f64::NAN;
326            }
327
328            j += 1;
329        }
330    }
331
332    Ok(PmaOutput { predict, trigger })
333}
334
335#[inline(always)]
336fn pma_compute_into(
337    data: &[f64],
338    first_valid_idx: usize,
339    _kernel: Kernel,
340    predict_out: &mut [f64],
341    trigger_out: &mut [f64],
342) {
343    let n = data.len();
344    if n <= first_valid_idx + 6 {
345        return;
346    }
347
348    const INV_28: f64 = 1.0 / 28.0;
349    const INV_10: f64 = 1.0 / 10.0;
350
351    let mut x_ring = [0.0_f64; 7];
352    let mut w_ring = [0.0_f64; 7];
353    let mut p_ring = [0.0_f64; 4];
354    let mut x_head = 0usize;
355    let mut w_head = 0usize;
356    let mut p_head = 0usize;
357
358    let mut A = 0.0_f64;
359    let mut S = 0.0_f64;
360    let mut A1 = 0.0_f64;
361    let mut S1 = 0.0_f64;
362    let mut A2 = 0.0_f64;
363    let mut T = 0.0_f64;
364
365    let j0 = first_valid_idx + 6;
366
367    unsafe {
368        let dp = data.as_ptr();
369
370        let x0 = *dp.add(j0 - 6);
371        let x1 = *dp.add(j0 - 5);
372        let x2 = *dp.add(j0 - 4);
373        let x3 = *dp.add(j0 - 3);
374        let x4 = *dp.add(j0 - 2);
375        let x5 = *dp.add(j0 - 1);
376        let x6 = *dp.add(j0 - 0);
377
378        x_ring[0] = x0;
379        x_ring[1] = x1;
380        x_ring[2] = x2;
381        x_ring[3] = x3;
382        x_ring[4] = x4;
383        x_ring[5] = x5;
384        x_ring[6] = x6;
385
386        A = ((x0 + x1) + (x2 + x3)) + ((x4 + x5) + x6);
387
388        let s01 = x0.mul_add(1.0, 2.0 * x1);
389        let s23 = (3.0 * x2) + (4.0 * x3);
390        let s45 = (5.0 * x4) + (6.0 * x5);
391        S = (s01 + s23) + s45 + 7.0 * x6;
392
393        let mut w1 = S * INV_28;
394
395        let old_A1 = A1;
396        let old_w = w_ring[w_head];
397        S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
398        A1 = A1 + w1 - old_w;
399        w_ring[w_head] = w1;
400        w_head += 1;
401        if w_head == 7 {
402            w_head = 0;
403        }
404
405        let mut w2 = S1 * INV_28;
406        let mut pr = (2.0_f64).mul_add(w1, -w2);
407        *predict_out.get_unchecked_mut(j0) = pr;
408
409        let old_A2 = A2;
410        let old_p = p_ring[p_head];
411        T = (4.0_f64).mul_add(pr, T) - old_A2;
412        A2 = A2 + pr - old_p;
413        p_ring[p_head] = pr;
414        p_head += 1;
415        if p_head == 4 {
416            p_head = 0;
417        }
418
419        *trigger_out.get_unchecked_mut(j0) = f64::NAN;
420
421        let mut j = j0 + 1;
422        while j < n {
423            let x_new = *dp.add(j);
424            let x_old = x_ring[x_head];
425            let old_A = A;
426
427            A = A + x_new - x_old;
428            S = (7.0_f64).mul_add(x_new, S) - old_A;
429
430            x_ring[x_head] = x_new;
431            x_head += 1;
432            if x_head == 7 {
433                x_head = 0;
434            }
435
436            w1 = S * INV_28;
437
438            let old_A1 = A1;
439            let w_old = w_ring[w_head];
440            S1 = (7.0_f64).mul_add(w1, S1) - old_A1;
441            A1 = A1 + w1 - w_old;
442
443            w_ring[w_head] = w1;
444            w_head += 1;
445            if w_head == 7 {
446                w_head = 0;
447            }
448
449            w2 = S1 * INV_28;
450            pr = (2.0_f64).mul_add(w1, -w2);
451
452            *predict_out.get_unchecked_mut(j) = pr;
453
454            let old_A2 = A2;
455            let p_old = p_ring[p_head];
456            T = (4.0_f64).mul_add(pr, T) - old_A2;
457            A2 = A2 + pr - p_old;
458
459            p_ring[p_head] = pr;
460            p_head += 1;
461            if p_head == 4 {
462                p_head = 0;
463            }
464
465            if j >= first_valid_idx + 9 {
466                *trigger_out.get_unchecked_mut(j) = T * INV_10;
467            } else {
468                *trigger_out.get_unchecked_mut(j) = f64::NAN;
469            }
470
471            j += 1;
472        }
473    }
474}
475
// The SIMD-named entry points below have no vectorised implementation yet.
// Each forwards to `pma_scalar` and returns exactly its result.
// NOTE(review): `pma_avx2` is the only one not gated on `nightly-avx`/x86_64;
// adding the gate now would break any caller on other builds.

/// AVX-512 entry point; currently forwards to `pma_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}

/// AVX2 entry point; currently forwards to `pma_scalar` (available on all builds).
#[inline]
pub fn pma_avx2(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}

/// AVX-512 "short" variant; currently forwards to `pma_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512_short(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}

/// AVX-512 "long" variant; currently forwards to `pma_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn pma_avx512_long(data: &[f64], first_valid_idx: usize) -> Result<PmaOutput, PmaError> {
    pma_scalar(data, first_valid_idx)
}
498
499#[inline]
500pub fn pma_batch_with_kernel(
501    data: &[f64],
502    sweep: &PmaBatchRange,
503    k: Kernel,
504) -> Result<PmaBatchOutput, PmaError> {
505    let kernel = match k {
506        Kernel::Auto => detect_best_batch_kernel(),
507        other if other.is_batch() => other,
508        other => return Err(PmaError::InvalidKernelForBatch(other)),
509    };
510    let simd = match kernel {
511        Kernel::Avx512Batch => Kernel::Avx512,
512        Kernel::Avx2Batch => Kernel::Avx2,
513        Kernel::ScalarBatch => Kernel::Scalar,
514        _ => unreachable!(),
515    };
516    pma_batch_par_slice(data, sweep, simd)
517}
518
519#[inline]
520pub fn pma_batch_unified_with_kernel(
521    data: &[f64],
522    k: Kernel,
523) -> Result<PmaBatchOutputUnified, PmaError> {
524    let kernel = match k {
525        Kernel::Auto => detect_best_batch_kernel(),
526        other if other.is_batch() => other,
527        _ => Kernel::ScalarBatch,
528    };
529    pma_batch_unified_inner(data, kernel)
530}
531
/// Computes PMA into one row-major 2 x `data.len()` buffer.
///
/// Row 0 holds predict and row 1 holds trigger. The batch kernel is mapped to
/// its per-row flavour before being passed to `pma_compute_into`.
///
/// # Errors
/// Validation errors from `pma_first_valid_idx`, or `PmaError::SizeOverflow`
/// if `2 * data.len()` overflows `usize`.
#[inline]
fn pma_batch_unified_inner(data: &[f64], kern: Kernel) -> Result<PmaBatchOutputUnified, PmaError> {
    // Guarantees data.len() >= first + 7, so `pma_compute_into` below does not
    // take its early return and writes every index from `first + 6` onward.
    let first = pma_first_valid_idx(data)?;

    let rows = 2usize;
    let cols = data.len();
    // Overflow guard only; the product itself is not used.
    let _ = rows
        .checked_mul(cols)
        .ok_or(PmaError::SizeOverflow { rows, cols })?;

    // NOTE(review): assumes `make_uninit_matrix` returns an uninitialised
    // rows*cols buffer and `init_matrix_prefixes` writes NaN into the first
    // `warm[r]` cells of each row; confirm against utilities::helpers.
    let mut buf_mu = make_uninit_matrix(rows, cols);
    // `first + 6` is the first index `pma_compute_into` writes, so every cell
    // before it must be initialised here.
    let warm = [first + 7 - 1; 2];
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Ownership of the allocation is handed to the `Vec` rebuilt below; the
    // ManuallyDrop wrapper prevents a second free of the original buffer.
    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
    // Reinterprets the buffer as f64. Cells not yet initialised are written by
    // `pma_compute_into` before the buffer is read.
    let outf: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };

    let (row0, row1) = outf.split_at_mut(cols);
    pma_compute_into(
        data,
        first,
        match kern {
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            _ => Kernel::Scalar,
        },
        row0,
        row1,
    );

    // Every cell is now initialised (prefix + computed tail), so the buffer
    // can be adopted as a Vec<f64> with the original length/capacity.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    Ok(PmaBatchOutputUnified { values, rows, cols })
}
573
574#[derive(Debug, Clone)]
575pub struct PmaStream {
576    buffer: [f64; 7],
577    wma1: [f64; 7],
578    idx: usize,
579    filled7: bool,
580
581    pred4: [f64; 4],
582    pred_idx: usize,
583    pred_filled: bool,
584}
585
586impl PmaStream {
587    pub fn try_new(_params: PmaParams) -> Result<Self, PmaError> {
588        Ok(Self {
589            buffer: [f64::NAN; 7],
590            wma1: [0.0; 7],
591            idx: 0,
592            filled7: false,
593            pred4: [f64::NAN; 4],
594            pred_idx: 0,
595            pred_filled: false,
596        })
597    }
598    #[inline(always)]
599    pub fn update(&mut self, value: f64) -> Option<(f64, f64)> {
600        self.buffer[self.idx] = value;
601        self.idx = (self.idx + 1) % 7;
602        if !self.filled7 && self.idx == 0 {
603            self.filled7 = true;
604        }
605        if !self.filled7 {
606            return None;
607        }
608
609        let s = |k: usize| self.buffer[(self.idx + k) % 7];
610        let wma1_j =
611            (7.0 * s(6) + 6.0 * s(5) + 5.0 * s(4) + 4.0 * s(3) + 3.0 * s(2) + 2.0 * s(1) + s(0))
612                / 28.0;
613        self.wma1[self.idx] = wma1_j;
614
615        let w = |k: usize| self.wma1[(self.idx + k) % 7];
616        let wma2 =
617            (7.0 * w(6) + 6.0 * w(5) + 5.0 * w(4) + 4.0 * w(3) + 3.0 * w(2) + 2.0 * w(1) + w(0))
618                / 28.0;
619
620        let predict = 2.0 * wma1_j - wma2;
621
622        self.pred4[self.pred_idx] = predict;
623        self.pred_idx = (self.pred_idx + 1) % 4;
624        if !self.pred_filled && self.pred_idx == 0 {
625            self.pred_filled = true;
626        }
627
628        let trigger = if self.pred_filled {
629            let t3 = self.pred4[(self.pred_idx + 3) % 4];
630            let t2 = self.pred4[(self.pred_idx + 2) % 4];
631            let t1 = self.pred4[(self.pred_idx + 1) % 4];
632            let t0 = self.pred4[(self.pred_idx + 0) % 4];
633            (4.0 * t3 + 3.0 * t2 + 2.0 * t1 + t0) / 10.0
634        } else {
635            f64::NAN
636        };
637
638        Some((predict, trigger))
639    }
640}
641
/// Parameter sweep for the PMA batch API.
///
/// PMA has no tunable parameters, so there is only ever one combination.
/// `dummy` is not read by the batch functions in this module. `Default` is
/// derived and yields `(0, 0, 0)`, exactly what the former hand-written impl
/// returned.
#[derive(Clone, Debug, Default)]
pub struct PmaBatchRange {
    pub dummy: (usize, usize, usize),
}
652
653#[derive(Clone, Debug, Default)]
654pub struct PmaBatchBuilder {
655    range: PmaBatchRange,
656    kernel: Kernel,
657}
658
659impl PmaBatchBuilder {
660    pub fn new() -> Self {
661        Self::default()
662    }
663    pub fn kernel(mut self, k: Kernel) -> Self {
664        self.kernel = k;
665        self
666    }
667    #[inline]
668    pub fn apply_slice(self, data: &[f64]) -> Result<PmaBatchOutput, PmaError> {
669        pma_batch_with_kernel(data, &self.range, self.kernel)
670    }
671    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<PmaBatchOutput, PmaError> {
672        PmaBatchBuilder::new().kernel(k).apply_slice(data)
673    }
674    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<PmaBatchOutput, PmaError> {
675        let slice = source_type(c, src);
676        self.apply_slice(slice)
677    }
678    pub fn with_default_candles(c: &Candles) -> Result<PmaBatchOutput, PmaError> {
679        PmaBatchBuilder::new()
680            .kernel(Kernel::Auto)
681            .apply_candles(c, "close")
682    }
683}
684
685#[derive(Clone, Debug)]
686pub struct PmaBatchOutput {
687    pub predict: Vec<f64>,
688    pub trigger: Vec<f64>,
689    pub rows: usize,
690    pub cols: usize,
691}
692impl PmaBatchOutput {
693    pub fn values_for(&self, _dummy: &PmaParams) -> Option<(&[f64], &[f64])> {
694        Some((&self.predict[..], &self.trigger[..]))
695    }
696}
697
/// Row-major `rows x cols` matrix from the unified batch path.
///
/// `pma_batch_unified_inner` fills row 0 with predict and row 1 with trigger
/// (`rows == 2`, `cols ==` input length).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct PmaBatchOutputUnified {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
708
709#[inline(always)]
710pub fn expand_grid(_r: &PmaBatchRange) -> Vec<PmaParams> {
711    vec![PmaParams {}]
712}
713
/// Sequential batch entry point; `kern` is forwarded unchanged to `pma_batch_inner`.
#[inline(always)]
pub fn pma_batch_slice(
    data: &[f64],
    sweep: &PmaBatchRange,
    kern: Kernel,
) -> Result<PmaBatchOutput, PmaError> {
    pma_batch_inner(data, sweep, kern, false)
}

/// "Parallel" batch entry point. With a single parameter set there is nothing
/// to parallelise, and `pma_batch_inner` ignores the flag.
#[inline(always)]
pub fn pma_batch_par_slice(
    data: &[f64],
    sweep: &PmaBatchRange,
    kern: Kernel,
) -> Result<PmaBatchOutput, PmaError> {
    pma_batch_inner(data, sweep, kern, true)
}
731
732#[inline(always)]
733fn pma_batch_inner(
734    data: &[f64],
735    _sweep: &PmaBatchRange,
736    kern: Kernel,
737    _parallel: bool,
738) -> Result<PmaBatchOutput, PmaError> {
739    let first = pma_first_valid_idx(data)?;
740    let out = match kern {
741        Kernel::Scalar => pma_scalar(data, first)?,
742        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
743        Kernel::Avx2 => pma_avx2(data, first)?,
744        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
745        Kernel::Avx512 => pma_avx512(data, first)?,
746        _ => unreachable!(),
747    };
748    Ok(PmaBatchOutput {
749        predict: out.predict,
750        trigger: out.trigger,
751        rows: 1,
752        cols: data.len(),
753    })
754}
755
/// Per-row kernel: writes predict/trigger for `data` into the given slices.
///
/// `_stride`, `_dummy` and `_inv_n` are ignored; they appear to exist only to
/// match a shared row-kernel signature (NOTE(review): confirm with callers).
///
/// # Safety
/// When `data.len() > first + 6`, both output slices must be at least
/// `data.len()` long. `pma_compute_into` writes every index from `first + 6`
/// to `data.len() - 1` with unchecked stores. Indices below `first + 6` are
/// not written.
#[inline(always)]
pub unsafe fn pma_row_scalar(
    data: &[f64],
    first: usize,
    _stride: usize,
    _dummy: *const f64,
    _inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_compute_into(data, first, Kernel::Scalar, out_predict, out_trigger);
}

/// AVX2 row kernel; forwards to `pma_row_scalar`.
///
/// # Safety
/// Same contract as `pma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx2(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}

/// AVX-512 row kernel; forwards to `pma_row_scalar`.
///
/// # Safety
/// Same contract as `pma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}

/// AVX-512 "short" row kernel; forwards to `pma_row_scalar`.
///
/// # Safety
/// Same contract as `pma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512_short(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}

/// AVX-512 "long" row kernel; forwards to `pma_row_scalar`.
///
/// # Safety
/// Same contract as `pma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn pma_row_avx512_long(
    data: &[f64],
    first: usize,
    stride: usize,
    dummy: *const f64,
    inv_n: f64,
    out_predict: &mut [f64],
    out_trigger: &mut [f64],
) {
    pma_row_scalar(data, first, stride, dummy, inv_n, out_predict, out_trigger);
}
824
825#[inline]
826pub fn pma_into_slice(
827    predict_dst: &mut [f64],
828    trigger_dst: &mut [f64],
829    input: &PmaInput,
830    kern: Kernel,
831) -> Result<(), PmaError> {
832    let data = input.as_ref();
833
834    if predict_dst.len() != data.len() || trigger_dst.len() != data.len() {
835        return Err(PmaError::OutputLengthMismatch {
836            expected: data.len(),
837            got: predict_dst.len().min(trigger_dst.len()),
838        });
839    }
840
841    let first = pma_first_valid_idx(data)?;
842
843    let chosen = match kern {
844        Kernel::Auto => Kernel::Scalar,
845        k => k,
846    };
847
848    pma_compute_into(data, first, chosen, predict_dst, trigger_dst);
849
850    let warm_end = first + 7 - 1;
851    for v in &mut predict_dst[..warm_end] {
852        *v = f64::NAN;
853    }
854    for v in &mut trigger_dst[..warm_end] {
855        *v = f64::NAN;
856    }
857
858    Ok(())
859}
860
/// Native (non-wasm) convenience wrapper around `pma_into_slice` with `Kernel::Auto`.
///
/// # Errors
/// Same as `pma_into_slice`.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn pma_into(
    input: &PmaInput,
    predict_out: &mut [f64],
    trigger_out: &mut [f64],
) -> Result<(), PmaError> {
    pma_into_slice(predict_out, trigger_out, input, Kernel::Auto)
}
870
// JS entry point. Returns a flat `2 * data.len()` array: predict values
// followed by trigger values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_js(data: &[f64]) -> Result<Vec<f64>, JsValue> {
    let rows = 2usize;
    let cols = data.len();
    let total = match rows.checked_mul(cols) {
        Some(len) => len,
        None => {
            let err = PmaError::SizeOverflow { rows, cols };
            return Err(JsValue::from_str(&err.to_string()));
        }
    };

    let input = PmaInput::from_slice(data, PmaParams {});
    let mut values = vec![0.0; total];
    let (predict_row, trigger_row) = values.split_at_mut(cols);
    if let Err(e) = pma_into_slice(predict_row, trigger_row, &input, detect_best_kernel()) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    Ok(values)
}
889
// wasm pointer-based variant: reads `len` values from `in_ptr` and writes
// `len` predict values to `predict_ptr` and `len` trigger values to
// `trigger_ptr`.
// Caller contract (not checked beyond null): each pointer addresses at least
// `len` f64 values, e.g. buffers obtained from `pma_alloc(len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_into(
    in_ptr: *const f64,
    predict_ptr: *mut f64,
    trigger_ptr: *mut f64,
    len: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || predict_ptr.is_null() || trigger_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = PmaParams {};
        let input = PmaInput::from_slice(data, params);

        // Only exact pointer equality is detected. NOTE(review): partially
        // overlapping buffers are not handled.
        let need_temp =
            in_ptr == predict_ptr || in_ptr == trigger_ptr || predict_ptr == trigger_ptr;

        if need_temp {
            // Compute into scratch buffers first so the input is fully read
            // before any output is written over it.
            let mut temp_predict = vec![0.0; len];
            let mut temp_trigger = vec![0.0; len];

            pma_into_slice(&mut temp_predict, &mut temp_trigger, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let predict_out = std::slice::from_raw_parts_mut(predict_ptr, len);
            let trigger_out = std::slice::from_raw_parts_mut(trigger_ptr, len);

            predict_out.copy_from_slice(&temp_predict);
            trigger_out.copy_from_slice(&temp_trigger);
        } else {
            let predict_out = std::slice::from_raw_parts_mut(predict_ptr, len);
            let trigger_out = std::slice::from_raw_parts_mut(trigger_ptr, len);

            pma_into_slice(predict_out, trigger_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
933
// Allocates room for `len` f64 values in wasm memory and leaks it to the JS
// caller. The contents are uninitialised; release with `pma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec);
    ptr
}

// Frees a buffer from `pma_alloc`; `len` must be the value passed to `pma_alloc`.
// NOTE(review): rebuilding with capacity == len assumes `with_capacity(len)`
// allocated exactly `len`; `Vec::from_raw_parts` requires the true capacity.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
952
// JS batch configuration. PMA has no parameters; `dummy` is not read by the
// visible code in this module.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaBatchConfig {
    pub dummy: Option<usize>,
}

// Flat `rows x cols` matrix for JS.
// NOTE(review): not constructed by the visible code in this module.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaJsOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}

// Batch result serialised by `pma_batch`: separate predict/trigger arrays.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PmaBatchJsOutput {
    pub predict: Vec<f64>,
    pub trigger: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
975
// JS batch entry point: serialises `{ predict, trigger, rows: 1, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_batch(data: &[f64]) -> Result<JsValue, JsValue> {
    let cols = data.len();
    let input = PmaInput::from_slice(data, PmaParams {});
    let mut predict = vec![0.0; cols];
    let mut trigger = vec![0.0; cols];

    if let Err(e) = pma_into_slice(&mut predict, &mut trigger, &input, detect_best_kernel()) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let output = PmaBatchJsOutput {
        predict,
        trigger,
        rows: 1,
        cols,
    };
    serde_wasm_bindgen::to_value(&output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
996
// Writes predict (first `len` values) then trigger (next `len` values) into
// `out_ptr` and returns the row count (2).
// Caller contract: `in_ptr` addresses `len` f64 values, `out_ptr` addresses
// `2 * len` f64 values, and the two regions do not overlap (not checked).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_unified_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    let rows = 2usize;
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str(&PmaError::SizeOverflow { rows, cols }.to_string()))?;
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let input = PmaInput::from_slice(data, PmaParams {});
        let (pred, trig) = out.split_at_mut(cols);
        pma_into_slice(pred, trig, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1022
// Batch form of the wasm `pma_into`; PMA produces a single row, so returns 1.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pma_batch_into(
    in_ptr: *const f64,
    predict_ptr: *mut f64,
    trigger_ptr: *mut f64,
    len: usize,
) -> Result<usize, JsValue> {
    pma_into(in_ptr, predict_ptr, trigger_ptr, len).map(|()| 1)
}
1034
// JS wrapper around `PmaStream`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub struct PmaStreamWasm {
    stream: PmaStream,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
impl PmaStreamWasm {
    #[wasm_bindgen(constructor)]
    pub fn new() -> Result<PmaStreamWasm, JsValue> {
        PmaStream::try_new(PmaParams {})
            .map(|stream| PmaStreamWasm { stream })
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    // Returns `[predict, trigger]`; both NaN while the stream is warming up.
    pub fn update(&mut self, value: f64) -> Result<Vec<f64>, JsValue> {
        let (predict, trigger) = self
            .stream
            .update(value)
            .unwrap_or((f64::NAN, f64::NAN));
        Ok(vec![predict, trigger])
    }
}
1058
1059#[cfg(feature = "python")]
1060use crate::utilities::kernel_validation::validate_kernel;
1061#[cfg(all(feature = "python", feature = "cuda"))]
1062use numpy::PyUntypedArrayMethods;
1063#[cfg(feature = "python")]
1064use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1065#[cfg(feature = "python")]
1066use pyo3::exceptions::PyValueError;
1067#[cfg(feature = "python")]
1068use pyo3::prelude::*;
1069#[cfg(feature = "python")]
1070use pyo3::types::PyDict;
1071
1072#[cfg(all(feature = "python", feature = "cuda"))]
1073use crate::cuda::{cuda_available, moving_averages::CudaPma};
1074#[cfg(all(feature = "python", feature = "cuda"))]
1075use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
1076
/// Python binding: computes PMA over a 1-D float64 array.
///
/// Returns `(predict, trigger)`, each the same length as `data`. `kernel` is an optional
/// single-series kernel name checked by `validate_kernel`. The computation itself runs
/// with the GIL released.
///
/// Raises `ValueError` when the kernel name is rejected or the PMA computation fails;
/// errors from viewing `data` as a slice are propagated as-is.
#[cfg(feature = "python")]
#[pyfunction(name = "pma")]
#[pyo3(signature = (data, kernel=None))]
pub fn pma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let values = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let pma_input = PmaInput::from_slice(values, PmaParams {});

    let PmaOutput { predict, trigger } = py
        .allow_threads(|| pma_with_kernel(&pma_input, chosen_kernel))
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    Ok((predict.into_pyarray(py), trigger.into_pyarray(py)))
}
1096
/// Python-facing wrapper (exposed to Python as `PmaStream`) around [`PmaStream`] for
/// incremental, one-value-at-a-time PMA updates.
#[cfg(feature = "python")]
#[pyclass(name = "PmaStream")]
pub struct PmaStreamPy {
    // The Rust streaming state all calls delegate to.
    stream: PmaStream,
}
1102
#[cfg(feature = "python")]
#[pymethods]
impl PmaStreamPy {
    /// Builds a stream with the (parameterless) default PMA configuration.
    ///
    /// Raises `ValueError` if `PmaStream::try_new` fails.
    #[new]
    fn new() -> PyResult<Self> {
        PmaStream::try_new(PmaParams {})
            .map(|stream| Self { stream })
            .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    /// Pushes one sample; returns `(predict, trigger)`, or `None` when the stream
    /// yields no value for this sample.
    fn update(&mut self, value: f64) -> Option<(f64, f64)> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1118
/// Python binding: computes PMA for `data` in the crate's batch (matrix) layout.
///
/// Returns a dict with:
/// - `"values"`: a `(2, len(data))` float64 array. Row 0 and row 1 are the two output
///   buffers handed to `pma_compute_into`, presumably predict then trigger.
/// - `"rows"`: always `2`.
/// - `"cols"`: `len(data)`.
///
/// `kernel` must name a batch kernel (`validate_kernel(.., true)`). `Kernel::Auto` is
/// resolved through `detect_best_batch_kernel()`, so the default call benefits from
/// CPU detection. The chosen batch kernel is then mapped to its per-row kernel.
///
/// Raises `ValueError` if:
/// - the kernel name is rejected,
/// - `2 * len(data)` overflows, or
/// - `pma_first_valid_idx` rejects the input.
#[cfg(feature = "python")]
#[pyfunction(name = "pma_batch")]
#[pyo3(signature = (data, kernel=None))]
pub fn pma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let (rows, cols) = (2usize, slice_in.len());
    let size = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err(PmaError::SizeOverflow { rows, cols }.to_string()))?;

    // Previously `Auto` was pinned to the scalar path here. Resolve it through detection
    // instead; the single-series `pma` binding forwards `Auto` to `pma_with_kernel`
    // rather than forcing scalar.
    // NOTE(review): assumes detect_best_batch_kernel() only returns kernels that
    // pma_compute_into can run in this build — confirm.
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };
    // Each output row is produced by the single-series kernel matching the batch kernel.
    let row_kernel = match batch_kernel {
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::Avx512Batch => Kernel::Avx512,
        _ => Kernel::Scalar,
    };

    // SAFETY: the array is freshly allocated and its element contents start
    // uninitialized. Every element is expected to be written below before the array
    // is returned:
    // - the warm-up prefix by `init_matrix_prefixes`,
    // - the remainder presumably by `pma_compute_into`.
    let values_arr = unsafe { PyArray1::<f64>::new(py, [size], false) };
    // SAFETY: `values_arr` has not been handed to Python yet, so this mutable slice is
    // the only view of its buffer while we fill it.
    let values_slice = unsafe { values_arr.as_slice_mut()? };

    py.allow_threads(|| -> PyResult<()> {
        let first =
            pma_first_valid_idx(slice_in).map_err(|e| PyValueError::new_err(e.to_string()))?;

        // Index of the first output slot past the 7-sample warm-up; both rows share it.
        let warm = first + 7 - 1;
        let warm_prefixes = [warm; 2];
        // View the (not yet fully written) buffer as MaybeUninit for the prefix helper.
        let values_mu: &mut [core::mem::MaybeUninit<f64>] = unsafe {
            core::slice::from_raw_parts_mut(
                values_slice.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
                values_slice.len(),
            )
        };
        init_matrix_prefixes(values_mu, cols, &warm_prefixes);

        let (row0, row1) = values_slice.split_at_mut(cols);
        pma_compute_into(slice_in, first, row_kernel, row0, row1);
        Ok(())
    })?;

    let dict = PyDict::new(py);
    dict.set_item("values", values_arr.reshape((rows, cols))?)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1174
/// Python binding: runs the PMA batch on a CUDA device and returns the
/// `(predict, trigger)` results as device-resident arrays.
///
/// `data_f32` is a 1-D float32 price series; the sweep is `PmaBatchRange::default()`.
/// Device work runs with the GIL released.
///
/// Raises `ValueError` when CUDA is unavailable, or when creating the CUDA context or
/// running the kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, device_id=0))]
pub fn pma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data_f32.as_slice()?;
    let sweep = PmaBatchRange::default();

    let outputs = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaPma::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        engine
            .pma_batch_dev(prices, &sweep)
            .map_err(|err| PyValueError::new_err(err.to_string()))
    })?;

    Ok((
        make_device_array_py(device_id, outputs.predict)?,
        make_device_array_py(device_id, outputs.trigger)?,
    ))
}
1197
/// Python binding: runs PMA over many series at once on a CUDA device.
///
/// `data_tm_f32` is a 2-D float32 array in time-major layout: its first dimension is
/// passed as the row count and its second as the column count to
/// `pma_many_series_one_param_time_major_dev`. It must be viewable as a contiguous
/// slice. Returns the `(predict, trigger)` results as device-resident arrays.
///
/// Raises `ValueError` when CUDA is unavailable, the shape is not 2-D, or the CUDA
/// context or kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, device_id=0))]
pub fn pma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (rows, cols) = match *data_tm_f32.shape() {
        [r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected time-major 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;

    let outputs = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaPma::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        engine
            .pma_many_series_one_param_time_major_dev(flat, cols, rows)
            .map_err(|err| PyValueError::new_err(err.to_string()))
    })?;

    Ok((
        make_device_array_py(device_id, outputs.predict)?,
        make_device_array_py(device_id, outputs.trigger)?,
    ))
}
1225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture shared by every file-backed test in this module.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Poison bit patterns the debug-only `*_no_poison` checks look for, paired with the
    /// allocation helper each one is attributed to. Finding one of them in an output
    /// means that output slot was never overwritten.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Panics if any non-NaN entry of `values` carries one of the `POISON_PATTERNS`.
    ///
    /// `array_name`, `source` and `source_idx` only feed the panic message.
    #[cfg(debug_assertions)]
    fn assert_no_poison(
        test_name: &str,
        array_name: &str,
        source: &str,
        source_idx: usize,
        values: &[f64],
    ) {
        for (i, &val) in values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }
            let bits = val.to_bits();
            if let Some(&(_, origin)) = POISON_PATTERNS
                .iter()
                .find(|(pattern, _)| *pattern == bits)
            {
                panic!(
                    "[{}] Found {} poison value {} (0x{:016X}) at index {} in {} array \
                     with source: {} (source set {})",
                    test_name, origin, val, bits, i, array_name, source, source_idx
                );
            }
        }
    }

    fn check_pma_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = PmaInput::with_default_candles(&candles);
        let output = pma_with_kernel(&input, kernel)?;
        assert_eq!(output.predict.len(), candles.close.len());
        assert_eq!(output.trigger.len(), candles.close.len());
        Ok(())
    }

    fn check_pma_with_slice(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0];
        let input = PmaInput::from_slice(&data, PmaParams {});
        let output = pma_with_kernel(&input, kernel)?;
        assert_eq!(output.predict.len(), data.len());
        assert_eq!(output.trigger.len(), data.len());
        Ok(())
    }

    fn check_pma_not_enough_data(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0];
        let input = PmaInput::from_slice(&data, PmaParams {});
        let result = pma_with_kernel(&input, kernel);
        assert!(result.is_err(), "Expected error for not enough data");
        Ok(())
    }

    fn check_pma_all_values_nan(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [f64::NAN, f64::NAN, f64::NAN];
        let input = PmaInput::from_slice(&data, PmaParams {});
        let result = pma_with_kernel(&input, kernel);
        assert!(result.is_err(), "Expected error for all values NaN");
        Ok(())
    }

    /// Compares the last five `predict`/`trigger` values on the `hl2` source against
    /// reference values, with an absolute tolerance of 0.1.
    fn check_pma_expected_values(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = PmaInput::from_candles(&candles, "hl2", PmaParams {});
        let result = pma_with_kernel(&input, kernel)?;

        assert_eq!(
            result.predict.len(),
            candles.close.len(),
            "Predict length mismatch"
        );
        assert_eq!(
            result.trigger.len(),
            candles.close.len(),
            "Trigger length mismatch"
        );

        let expected_predict = [
            59208.18749999999,
            59233.83609693878,
            59213.19132653061,
            59199.002551020414,
            58993.318877551,
        ];
        let expected_trigger = [
            59157.70790816327,
            59208.60076530612,
            59218.6763392857,
            59211.1443877551,
            59123.05019132652,
        ];

        assert!(
            result.predict.len() >= 5,
            "Output length too short for checking"
        );
        let start_idx = result.predict.len() - 5;
        for (i, &exp_val) in expected_predict.iter().enumerate() {
            let calc_val = result.predict[start_idx + i];
            assert!(
                (calc_val - exp_val).abs() < 1e-1,
                "Mismatch in predict at index {}: expected {}, got {}",
                start_idx + i,
                exp_val,
                calc_val
            );
        }
        for (i, &exp_val) in expected_trigger.iter().enumerate() {
            let calc_val = result.trigger[start_idx + i];
            assert!(
                (calc_val - exp_val).abs() < 1e-1,
                "Mismatch in trigger at index {}: expected {}, got {}",
                start_idx + i,
                exp_val,
                calc_val
            );
        }
        Ok(())
    }

    // Generates one `#[test]` per check function and kernel.
    //
    // The SIMD variants carry the `nightly-avx` cfg individually. Placing a single
    // `#[cfg]` in front of a `$(...)*` repetition attaches it only to the first item
    // the repetition expands to, which left every other SIMD test ungated.
    //
    // NOTE(review): each check's `Result` is discarded with `let _ =`. Errors raised
    // via `?` (e.g. a missing CSV, a kernel error, or a proptest `prop_assert!` failure)
    // therefore do not fail the test; only panics do. Consider `.unwrap()` once the
    // checks (notably `check_pma_property`) are confirmed to pass.
    macro_rules! generate_all_pma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }

    /// Debug builds only: no output slot for any candle source may still hold an
    /// allocation-poison pattern.
    #[cfg(debug_assertions)]
    fn check_pma_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let test_sources = [
            "close", "open", "high", "low", "hl2", "hlc3", "ohlc4", "volume",
        ];

        for (source_idx, source) in test_sources.iter().enumerate() {
            let input = PmaInput::from_candles(&candles, source, PmaParams {});
            let output = pma_with_kernel(&input, kernel)?;
            assert_no_poison(test_name, "predict", source, source_idx, &output.predict);
            assert_no_poison(test_name, "trigger", source, source_idx, &output.trigger);
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_pma_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Property test: length, warm-up, SIMD-vs-scalar agreement, and trigger-formula
    /// consistency on random finite inputs.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_pma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = prop::collection::vec(
            (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
            7..400,
        );

        proptest::test_runner::TestRunner::default().run(&strat, |data| {
            let input = PmaInput::from_slice(&data, PmaParams {});

            let result = pma_with_kernel(&input, kernel)?;
            let ref_result = pma_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(result.predict.len(), data.len());
            prop_assert_eq!(result.trigger.len(), data.len());
            prop_assert_eq!(ref_result.predict.len(), data.len());
            prop_assert_eq!(ref_result.trigger.len(), data.len());

            let warmup_period = 7;

            // NOTE(review): this loop requires predict[6] to be NaN, while the
            // `data.len() == 7` check at the end requires predict[6] to be finite.
            // Both cannot hold — confirm the intended warm-up length.
            for i in 0..warmup_period {
                prop_assert!(
                    result.predict[i].is_nan(),
                    "Expected NaN in predict warmup at index {}",
                    i
                );
                prop_assert!(
                    result.trigger[i].is_nan(),
                    "Expected NaN in trigger warmup at index {}",
                    i
                );
            }

            if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
                && data.len() >= warmup_period
            {
                for i in warmup_period..data.len() {
                    if result.predict[i].is_finite() {
                        prop_assert!(
                            (result.predict[i] - data[0]).abs() < 1e-9,
                            "Constant data test failed: predict[{}] = {} should be close to {}",
                            i,
                            result.predict[i],
                            data[0]
                        );
                    }
                }
            }

            for i in warmup_period..data.len() {
                if result.predict[i].is_finite() && ref_result.predict[i].is_finite() {
                    let diff_predict = (result.predict[i] - ref_result.predict[i]).abs();
                    prop_assert!(
                        diff_predict < 1e-10,
                        "Predict mismatch at index {}: kernel={}, scalar={}, diff={}",
                        i,
                        result.predict[i],
                        ref_result.predict[i],
                        diff_predict
                    );
                } else {
                    prop_assert_eq!(
                        result.predict[i].is_nan(),
                        ref_result.predict[i].is_nan(),
                        "NaN mismatch in predict at index {}",
                        i
                    );
                }

                if result.trigger[i].is_finite() && ref_result.trigger[i].is_finite() {
                    let diff_trigger = (result.trigger[i] - ref_result.trigger[i]).abs();
                    prop_assert!(
                        diff_trigger < 1e-10,
                        "Trigger mismatch at index {}: kernel={}, scalar={}, diff={}",
                        i,
                        result.trigger[i],
                        ref_result.trigger[i],
                        diff_trigger
                    );
                } else {
                    prop_assert_eq!(
                        result.trigger[i].is_nan(),
                        ref_result.trigger[i].is_nan(),
                        "NaN mismatch in trigger at index {}",
                        i
                    );
                }

                if result.predict[i].is_finite() {
                    let window_start = i.saturating_sub(6);
                    let window_data = &data[window_start..=i];
                    let min_val = window_data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                    let max_val = window_data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

                    let tolerance = (max_val - min_val).abs() * 0.1 + 1e-9;
                    prop_assert!(
                        result.predict[i] >= min_val - tolerance
                            && result.predict[i] <= max_val + tolerance,
                        "Predict value {} at index {} outside bounds [{}, {}] with tolerance {}",
                        result.predict[i],
                        i,
                        min_val - tolerance,
                        max_val + tolerance,
                        tolerance
                    );
                }

                if i == warmup_period && result.predict[i].is_finite() {
                    let window_start = i.saturating_sub(6);
                    let window = &data[window_start..=i];
                    let window_avg = window.iter().sum::<f64>() / window.len() as f64;
                    let min = window.iter().fold(f64::INFINITY, |a, &b| a.min(b));
                    let max = window.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                    prop_assert!(
                        (result.predict[i] - window_avg).abs() < (max - min).abs() + 1e-9,
                        "Predict value {} at index {} seems unrelated to window average {}",
                        result.predict[i],
                        i,
                        window_avg
                    );
                }

                // trigger should be the 4-3-2-1 weighted average of the last four predicts.
                if i >= warmup_period + 3
                    && result.trigger[i].is_finite()
                    && result.predict[i].is_finite()
                    && result.predict[i - 1].is_finite()
                    && result.predict[i - 2].is_finite()
                    && result.predict[i - 3].is_finite()
                {
                    let expected_trigger = (4.0 * result.predict[i]
                        + 3.0 * result.predict[i - 1]
                        + 2.0 * result.predict[i - 2]
                        + result.predict[i - 3])
                        / 10.0;
                    let trigger_diff = (result.trigger[i] - expected_trigger).abs();
                    prop_assert!(
                        trigger_diff < 1e-10,
                        "Trigger calculation error at index {}: expected {}, got {}, diff={}",
                        i,
                        expected_trigger,
                        result.trigger[i],
                        trigger_diff
                    );
                }
            }

            if data.len() == 7 {
                prop_assert!(
                    result.predict[6].is_finite(),
                    "With exactly 7 points, predict[6] should be finite but got NaN"
                );
            }

            Ok(())
        })?;

        Ok(())
    }

    generate_all_pma_tests!(
        check_pma_default_candles,
        check_pma_with_slice,
        check_pma_not_enough_data,
        check_pma_all_values_nan,
        check_pma_expected_values,
        check_pma_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_pma_tests!(check_pma_property);

    /// The batch builder must yield exactly one row, and that row must equal the
    /// single-series API output for the same kernel (within 1e-12).
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;
        let output = PmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        assert_eq!(output.rows, 1, "Expected exactly 1 row");
        assert_eq!(output.cols, c.close.len());
        assert_eq!(output.predict.len(), c.close.len());
        assert_eq!(output.trigger.len(), c.close.len());

        let input = PmaInput::from_candles(&c, "close", PmaParams::default());
        let expected = pma_with_kernel(&input, kernel)?;

        for (i, (&a, &b)) in output
            .predict
            .iter()
            .zip(expected.predict.iter())
            .enumerate()
        {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-12,
                "[{test}] predict mismatch at idx {i}: batch={}, direct={}",
                a,
                b
            );
        }
        for (i, (&a, &b)) in output
            .trigger
            .iter()
            .zip(expected.trigger.iter())
            .enumerate()
        {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-12,
                "[{test}] trigger mismatch at idx {i}: batch={}, direct={}",
                a,
                b
            );
        }
        Ok(())
    }

    // Generates one `#[test]` per batch kernel (SIMD variants gated individually)
    // plus an `Auto` variant.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }

    /// Debug builds only: batch outputs for every price source must contain no
    /// allocation-poison patterns.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;
        let test_sources = ["close", "open", "high", "low", "hl2", "hlc3", "ohlc4"];

        for (source_idx, source) in test_sources.iter().enumerate() {
            let output = PmaBatchBuilder::new()
                .kernel(kernel)
                .apply_candles(&c, source)?;
            assert_no_poison(test, "predict", source, source_idx, &output.predict);
            assert_no_poison(test, "trigger", source, source_idx, &output.trigger);
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    // The slice-based `pma_into(&input, &mut [f64], &mut [f64])` only exists outside
    // wasm+`wasm` builds. The whole test is gated so it cannot compare against
    // never-written zero buffers there.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_pma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = PmaInput::with_default_candles(&candles);

        let base = pma_with_kernel(&input, Kernel::Auto)?;

        let n = candles.close.len();
        let mut out_predict = vec![0.0; n];
        let mut out_trigger = vec![0.0; n];
        pma_into(&input, &mut out_predict, &mut out_trigger)?;

        assert_eq!(base.predict.len(), out_predict.len());
        assert_eq!(base.trigger.len(), out_trigger.len());

        fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }

        for i in 0..n {
            assert!(
                eq_or_both_nan_eps(base.predict[i], out_predict[i]),
                "predict mismatch at {i}: api={}, into={}",
                base.predict[i],
                out_predict[i]
            );
            assert!(
                eq_or_both_nan_eps(base.trigger[i], out_trigger[i]),
                "trigger mismatch at {i}: api={}, into={}",
                base.trigger[i],
                out_trigger[i]
            );
        }

        Ok(())
    }
}