Skip to main content

vector_ta/indicators/
ppo.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::oscillators::ppo_wrapper::{CudaPpo, DeviceArrayF32Ppo};
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::memory::DeviceBuffer;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyAny, PyDict};
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use serde::{Deserialize, Serialize};
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use wasm_bindgen::prelude::*;
23
24use crate::indicators::moving_averages::ma::{ma, MaData};
25use crate::utilities::data_loader::{source_type, Candles};
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
29    make_uninit_matrix,
30};
31#[cfg(feature = "python")]
32use crate::utilities::kernel_validation::validate_kernel;
33#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
34use core::arch::x86_64::*;
35#[cfg(not(target_arch = "wasm32"))]
36use rayon::prelude::*;
37use std::collections::HashMap;
38use std::convert::AsRef;
39use std::error::Error;
40use std::mem::MaybeUninit;
41use thiserror::Error;
42
43impl<'a> AsRef<[f64]> for PpoInput<'a> {
44    #[inline(always)]
45    fn as_ref(&self) -> &[f64] {
46        match &self.data {
47            PpoData::Slice(slice) => slice,
48            PpoData::Candles { candles, source } => source_type(candles, source),
49        }
50    }
51}
52
53#[derive(Debug, Clone)]
54pub enum PpoData<'a> {
55    Candles {
56        candles: &'a Candles,
57        source: &'a str,
58    },
59    Slice(&'a [f64]),
60}
61
/// Result of a single PPO computation.
#[derive(Clone, Debug)]
pub struct PpoOutput {
    /// One value per input sample; the single-series entry points fill the
    /// leading `first + slow - 1` entries with NaN.
    pub values: Vec<f64>,
}
66
/// Tunable PPO parameters. A `None` field falls back to the indicator's
/// default (12 / 26 / `"sma"`) wherever the parameters are read.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct PpoParams {
    /// Window of the fast moving average.
    pub fast_period: Option<usize>,
    /// Window of the slow moving average.
    pub slow_period: Option<usize>,
    /// Name of the moving average to use (e.g. `"sma"`, `"ema"`).
    pub ma_type: Option<String>,
}

impl Default for PpoParams {
    /// Fast 12, slow 26, simple moving average.
    fn default() -> Self {
        PpoParams {
            fast_period: Some(12),
            slow_period: Some(26),
            ma_type: Some(String::from("sma")),
        }
    }
}
87
88#[derive(Debug, Clone)]
89pub struct PpoInput<'a> {
90    pub data: PpoData<'a>,
91    pub params: PpoParams,
92}
93
94impl<'a> PpoInput<'a> {
95    #[inline]
96    pub fn from_candles(c: &'a Candles, s: &'a str, p: PpoParams) -> Self {
97        Self {
98            data: PpoData::Candles {
99                candles: c,
100                source: s,
101            },
102            params: p,
103        }
104    }
105    #[inline]
106    pub fn from_slice(sl: &'a [f64], p: PpoParams) -> Self {
107        Self {
108            data: PpoData::Slice(sl),
109            params: p,
110        }
111    }
112    #[inline]
113    pub fn with_default_candles(c: &'a Candles) -> Self {
114        Self::from_candles(c, "close", PpoParams::default())
115    }
116    #[inline]
117    pub fn get_fast_period(&self) -> usize {
118        self.params.fast_period.unwrap_or(12)
119    }
120    #[inline]
121    pub fn get_slow_period(&self) -> usize {
122        self.params.slow_period.unwrap_or(26)
123    }
124    #[inline]
125    pub fn get_ma_type(&self) -> String {
126        self.params
127            .ma_type
128            .clone()
129            .unwrap_or_else(|| "sma".to_string())
130    }
131}
132
133#[derive(Clone, Debug)]
134pub struct PpoBuilder {
135    fast_period: Option<usize>,
136    slow_period: Option<usize>,
137    ma_type: Option<String>,
138    kernel: Kernel,
139}
140
141impl Default for PpoBuilder {
142    fn default() -> Self {
143        Self {
144            fast_period: None,
145            slow_period: None,
146            ma_type: None,
147            kernel: Kernel::Auto,
148        }
149    }
150}
151
152impl PpoBuilder {
153    #[inline(always)]
154    pub fn new() -> Self {
155        Self::default()
156    }
157    #[inline(always)]
158    pub fn fast_period(mut self, n: usize) -> Self {
159        self.fast_period = Some(n);
160        self
161    }
162    #[inline(always)]
163    pub fn slow_period(mut self, n: usize) -> Self {
164        self.slow_period = Some(n);
165        self
166    }
167    #[inline(always)]
168    pub fn ma_type<S: Into<String>>(mut self, s: S) -> Self {
169        self.ma_type = Some(s.into());
170        self
171    }
172    #[inline(always)]
173    pub fn kernel(mut self, k: Kernel) -> Self {
174        self.kernel = k;
175        self
176    }
177    #[inline(always)]
178    pub fn apply(self, c: &Candles) -> Result<PpoOutput, PpoError> {
179        let p = PpoParams {
180            fast_period: self.fast_period,
181            slow_period: self.slow_period,
182            ma_type: self.ma_type,
183        };
184        let i = PpoInput::from_candles(c, "close", p);
185        ppo_with_kernel(&i, self.kernel)
186    }
187    #[inline(always)]
188    pub fn apply_slice(self, d: &[f64]) -> Result<PpoOutput, PpoError> {
189        let p = PpoParams {
190            fast_period: self.fast_period,
191            slow_period: self.slow_period,
192            ma_type: self.ma_type,
193        };
194        let i = PpoInput::from_slice(d, p);
195        ppo_with_kernel(&i, self.kernel)
196    }
197    #[inline(always)]
198    pub fn into_stream(self) -> Result<PpoStream, PpoError> {
199        let p = PpoParams {
200            fast_period: self.fast_period,
201            slow_period: self.slow_period,
202            ma_type: self.ma_type,
203        };
204        PpoStream::try_new(p)
205    }
206}
207
208#[derive(Debug, Error)]
209pub enum PpoError {
210    #[error("ppo: Empty data provided.")]
211    EmptyInputData,
212    #[error("ppo: All values are NaN.")]
213    AllValuesNaN,
214    #[error("ppo: Invalid period: fast = {fast}, slow = {slow}, data length = {data_len}")]
215    InvalidPeriod {
216        fast: usize,
217        slow: usize,
218        data_len: usize,
219    },
220    #[error("ppo: Not enough valid data: needed = {needed}, valid = {valid}")]
221    NotEnoughValidData { needed: usize, valid: usize },
222    #[error("ppo: output length mismatch: expected = {expected}, got = {got}")]
223    OutputLengthMismatch { expected: usize, got: usize },
224    #[error("ppo: Invalid range: start = {start}, end = {end}, step = {step}")]
225    InvalidRange {
226        start: usize,
227        end: usize,
228        step: usize,
229    },
230    #[error("ppo: invalid kernel for batch path: {0:?}")]
231    InvalidKernelForBatch(Kernel),
232    #[error("ppo: invalid input: {0}")]
233    InvalidInput(String),
234    #[error("ppo: MA error: {0}")]
235    MaError(String),
236}
237
238#[inline]
239pub fn ppo(input: &PpoInput) -> Result<PpoOutput, PpoError> {
240    ppo_with_kernel(input, Kernel::Auto)
241}
242
243pub fn ppo_with_kernel(input: &PpoInput, kernel: Kernel) -> Result<PpoOutput, PpoError> {
244    let data: &[f64] = input.as_ref();
245    let len = data.len();
246    if len == 0 {
247        return Err(PpoError::EmptyInputData);
248    }
249    let first = data
250        .iter()
251        .position(|x| !x.is_nan())
252        .ok_or(PpoError::AllValuesNaN)?;
253
254    let fast = input.get_fast_period();
255    let slow = input.get_slow_period();
256    let ma_type = input.params.ma_type.as_deref().unwrap_or("sma");
257
258    if fast == 0 || slow == 0 || fast > len || slow > len {
259        return Err(PpoError::InvalidPeriod {
260            fast,
261            slow,
262            data_len: len,
263        });
264    }
265    if (len - first) < slow {
266        return Err(PpoError::NotEnoughValidData {
267            needed: slow,
268            valid: len - first,
269        });
270    }
271
272    let chosen = match kernel {
273        Kernel::Auto => Kernel::Scalar,
274        k => k,
275    };
276
277    let mut out = alloc_with_nan_prefix(len, first + slow - 1);
278
279    unsafe {
280        match chosen {
281            Kernel::Scalar | Kernel::ScalarBatch => {
282                ppo_scalar(data, fast, slow, ma_type, first, &mut out)
283            }
284            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
285            Kernel::Avx2 | Kernel::Avx2Batch => {
286                ppo_avx2(data, fast, slow, ma_type, first, &mut out)
287            }
288            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289            Kernel::Avx512 | Kernel::Avx512Batch => {
290                ppo_avx512(data, fast, slow, ma_type, first, &mut out)
291            }
292            _ => unreachable!(),
293        }
294    }
295
296    Ok(PpoOutput { values: out })
297}
298
299pub fn ppo_into_slice(dst: &mut [f64], input: &PpoInput, kern: Kernel) -> Result<(), PpoError> {
300    let data = input.as_ref();
301    if data.is_empty() {
302        return Err(PpoError::EmptyInputData);
303    }
304
305    let fast = input.get_fast_period();
306    let slow = input.get_slow_period();
307    let ma_type = input.params.ma_type.as_deref().unwrap_or("sma");
308
309    if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
310        return Err(PpoError::InvalidPeriod {
311            fast,
312            slow,
313            data_len: data.len(),
314        });
315    }
316    if dst.len() != data.len() {
317        return Err(PpoError::OutputLengthMismatch {
318            expected: data.len(),
319            got: dst.len(),
320        });
321    }
322
323    let first = data
324        .iter()
325        .position(|x| !x.is_nan())
326        .ok_or(PpoError::AllValuesNaN)?;
327    if data.len() - first < slow {
328        return Err(PpoError::NotEnoughValidData {
329            needed: slow,
330            valid: data.len() - first,
331        });
332    }
333
334    let chosen = match kern {
335        Kernel::Auto => Kernel::Scalar,
336        k => k,
337    };
338
339    unsafe {
340        match chosen {
341            Kernel::Scalar | Kernel::ScalarBatch => {
342                ppo_scalar(data, fast, slow, ma_type, first, dst)
343            }
344            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345            Kernel::Avx2 | Kernel::Avx2Batch => ppo_avx2(data, fast, slow, ma_type, first, dst),
346            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
347            Kernel::Avx512 | Kernel::Avx512Batch => {
348                ppo_avx512(data, fast, slow, ma_type, first, dst)
349            }
350            _ => unreachable!(),
351        }
352    }
353
354    let warmup_end = first + slow - 1;
355    for v in &mut dst[..warmup_end] {
356        *v = f64::from_bits(0x7ff8_0000_0000_0000);
357    }
358    Ok(())
359}
360
361#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
362#[inline]
363pub fn ppo_into(input: &PpoInput, out: &mut [f64]) -> Result<(), PpoError> {
364    ppo_into_slice(out, input, Kernel::Auto)
365}
366
367#[inline]
368pub unsafe fn ppo_scalar(
369    data: &[f64],
370    fast: usize,
371    slow: usize,
372    ma_type: &str,
373    first: usize,
374    out: &mut [f64],
375) {
376    if ma_type == "ema" {
377        ppo_scalar_classic_ema(data, fast, slow, first, out);
378        return;
379    } else if ma_type == "sma" {
380        ppo_scalar_classic_sma(data, fast, slow, first, out);
381        return;
382    }
383
384    let start = first + slow - 1;
385    let fast_ma = match ma(ma_type, MaData::Slice(data), fast) {
386        Ok(v) => v,
387        Err(_) => {
388            let mut i = start;
389            while i < data.len() {
390                *out.get_unchecked_mut(i) = f64::NAN;
391                i += 1;
392            }
393            return;
394        }
395    };
396    let slow_ma = match ma(ma_type, MaData::Slice(data), slow) {
397        Ok(v) => v,
398        Err(_) => {
399            let mut i = start;
400            while i < data.len() {
401                *out.get_unchecked_mut(i) = f64::NAN;
402                i += 1;
403            }
404            return;
405        }
406    };
407
408    let n = data.len();
409    let mut i = start;
410    while i < n {
411        let sf = *slow_ma.get_unchecked(i);
412        let ff = *fast_ma.get_unchecked(i);
413        let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
414            f64::NAN
415        } else {
416            let ratio = ff / sf;
417            f64::mul_add(ratio, 100.0, -100.0)
418        };
419        *out.get_unchecked_mut(i) = y;
420        i += 1;
421    }
422}
423
/// Single-pass EMA-based PPO: writes `(fast_ema / slow_ema) * 100 - 100`
/// into `out[first + slow - 1 ..]`.
///
/// Both EMAs use the smoothing factor `2 / (period + 1)`. The slow EMA is
/// seeded with the mean of the first `slow` valid samples.
///
/// * `fast <= slow`: the fast EMA is seeded with the mean of the last `fast`
///   samples of that window and then advanced over
///   `data[first + fast ..= first + slow - 1]`.
/// * `fast > slow`: the fast EMA is seeded with the mean of the first `fast`
///   valid samples; indices before `first + fast - 1` receive NaN. This case
///   previously underflowed `slow - fast` (panic in debug builds, a
///   zero-seeded fast EMA in release builds).
///
/// An index yields NaN when the slow EMA is zero or either EMA is NaN.
///
/// # Safety
/// Uses unchecked indexing. The caller must guarantee `fast >= 1`,
/// `slow >= 1`, `first + slow <= data.len()` and `out.len() >= data.len()`.
#[inline]
pub unsafe fn ppo_scalar_classic_ema(
    data: &[f64],
    fast: usize,
    slow: usize,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let start_idx = first + slow - 1;

    let fa = 2.0 / (fast as f64 + 1.0);
    let sa = 2.0 / (slow as f64 + 1.0);
    let fb = 1.0 - fa;
    let sb = 1.0 - sa;

    // PPO value for one pair of averages; NaN when it cannot be formed.
    let ppo_value = |fast_ema: f64, slow_ema: f64| -> f64 {
        if slow_ema == 0.0 || slow_ema.is_nan() || fast_ema.is_nan() {
            f64::NAN
        } else {
            f64::mul_add(fast_ema / slow_ema, 100.0, -100.0)
        }
    };

    let mut slow_sum = 0.0f64;
    let mut fast_sum = 0.0f64;
    // Saturating: when `fast > slow` the fast window is not full yet, so every
    // sample of the slow window also belongs to the fast seed sum.
    let overlap = slow.saturating_sub(fast);
    let mut k = 0usize;
    while k < slow {
        let v = *data.get_unchecked(first + k);
        slow_sum += v;
        if k >= overlap {
            fast_sum += v;
        }
        k += 1;
    }

    let mut slow_ema = slow_sum / (slow as f64);

    if fast > slow {
        // The fast seed needs `fast` samples; keep accumulating until then.
        let fast_ready = first + fast - 1;
        let mut fast_ema = f64::NAN;
        *out.get_unchecked_mut(start_idx) = f64::NAN;

        let mut j = start_idx + 1;
        while j < n {
            let x = *data.get_unchecked(j);
            slow_ema = f64::mul_add(sa, x, sb * slow_ema);
            if j <= fast_ready {
                fast_sum += x;
                if j == fast_ready {
                    fast_ema = fast_sum / (fast as f64);
                }
            } else {
                fast_ema = f64::mul_add(fa, x, fb * fast_ema);
            }
            // `fast_ema` stays NaN until seeded, so `ppo_value` emits NaN.
            *out.get_unchecked_mut(j) = ppo_value(fast_ema, slow_ema);
            j += 1;
        }
        return;
    }

    let mut fast_ema = fast_sum / (fast as f64);

    // NOTE(review): this catch-up replays `data[first + fast ..= start_idx]`,
    // which overlaps the window used for the seed above. Kept as-is to
    // preserve existing outputs; confirm it is intended.
    let mut i = first + fast;
    while i <= start_idx {
        let x = *data.get_unchecked(i);
        fast_ema = f64::mul_add(fa, x, fb * fast_ema);
        i += 1;
    }

    *out.get_unchecked_mut(start_idx) = ppo_value(fast_ema, slow_ema);

    let mut j = start_idx + 1;
    while j < n {
        let x = *data.get_unchecked(j);
        fast_ema = f64::mul_add(fa, x, fb * fast_ema);
        slow_ema = f64::mul_add(sa, x, sb * slow_ema);
        *out.get_unchecked_mut(j) = ppo_value(fast_ema, slow_ema);
        j += 1;
    }
}
487
/// Single-pass SMA-based PPO: writes `(fast_sma / slow_sma) * 100 - 100`
/// into `out[first + slow - 1 ..]` using two rolling sums.
///
/// The ratio of means is computed as `(fast_sum * (slow / fast)) / slow_sum`.
/// An index yields NaN when the slow sum is zero or either sum is NaN.
///
/// When `fast > slow`, the fast window is not full at `first + slow - 1`, so
/// indices before `first + fast - 1` receive NaN and nothing is subtracted
/// from the fast sum until the window is full. This case previously
/// underflowed `slow - fast` and `i - fast`, which panics in debug builds and
/// is an out-of-bounds unchecked read in release builds.
///
/// # Safety
/// Uses unchecked indexing. The caller must guarantee `fast >= 1`,
/// `slow >= 1`, `first + slow <= data.len()` and `out.len() >= data.len()`.
#[inline]
pub unsafe fn ppo_scalar_classic_sma(
    data: &[f64],
    fast: usize,
    slow: usize,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let start_idx = first + slow - 1;
    // First index at which the fast window holds `fast` valid samples.
    let fast_ready = first + fast - 1;

    let k = (slow as f64) / (fast as f64);

    // PPO value for one pair of rolling sums; NaN when it cannot be formed.
    let ppo_value = |fast_sum: f64, slow_sum: f64| -> f64 {
        if slow_sum == 0.0 || slow_sum.is_nan() || fast_sum.is_nan() {
            f64::NAN
        } else {
            let ratio = (fast_sum * k) / slow_sum;
            f64::mul_add(ratio, 100.0, -100.0)
        }
    };

    let mut slow_sum = 0.0f64;
    let mut fast_sum = 0.0f64;
    // Saturating: when `fast > slow` every sample of the slow window also
    // belongs to the (still filling) fast window.
    let overlap = slow.saturating_sub(fast);
    let mut t = 0usize;
    while t < slow {
        let v = *data.get_unchecked(first + t);
        slow_sum += v;
        if t >= overlap {
            fast_sum += v;
        }
        t += 1;
    }

    *out.get_unchecked_mut(start_idx) = if start_idx < fast_ready {
        f64::NAN
    } else {
        ppo_value(fast_sum, slow_sum)
    };

    let mut i = start_idx + 1;
    while i < n {
        let add = *data.get_unchecked(i);

        // `i >= first + slow` here, so `i - slow` cannot underflow.
        let sub_slow = *data.get_unchecked(i - slow);
        slow_sum += add - sub_slow;

        if i > fast_ready {
            // `i >= first + fast`, so `i - fast` cannot underflow.
            let sub_fast = *data.get_unchecked(i - fast);
            fast_sum += add - sub_fast;
        } else {
            // Fast window still filling: nothing leaves it yet.
            fast_sum += add;
        }

        *out.get_unchecked_mut(i) = if i < fast_ready {
            f64::NAN
        } else {
            ppo_value(fast_sum, slow_sum)
        };

        i += 1;
    }
}
543
/// AVX2 entry point. Currently forwards unchanged to [`ppo_scalar`].
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx2(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body here: the scalar kernel does the work.
    ppo_scalar(data, fast, slow, ma_type, first, out);
}
556
/// AVX-512 entry point: picks the short variant for `slow <= 32` and the long
/// variant otherwise.
///
/// # Safety
/// Same contract as [`ppo_scalar`], which both variants forward to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    match slow {
        0..=32 => ppo_avx512_short(data, fast, slow, ma_type, first, out),
        _ => ppo_avx512_long(data, fast, slow, ma_type, first, out),
    }
}
573
/// AVX-512 variant selected for `slow <= 32`. Currently forwards unchanged to
/// [`ppo_scalar`].
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512_short(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body here: the scalar kernel does the work.
    ppo_scalar(data, fast, slow, ma_type, first, out);
}
586
/// AVX-512 variant selected for `slow > 32`. Currently forwards unchanged to
/// [`ppo_scalar`].
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512_long(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body here: the scalar kernel does the work.
    ppo_scalar(data, fast, slow, ma_type, first, out);
}
599
/// Parameter sweep for a batch PPO run.
///
/// Each period axis is a `(start, end, step)` triple; a zero step or
/// `start == end` collapses the axis to the single value `start`.
#[derive(Debug, Clone)]
pub struct PpoBatchRange {
    /// Sweep of fast periods.
    pub fast_period: (usize, usize, usize),
    /// Sweep of slow periods.
    pub slow_period: (usize, usize, usize),
    /// Moving-average name shared by every combination.
    pub ma_type: String,
}

impl Default for PpoBatchRange {
    /// Fast fixed at 12, slow swept from 26 to 275 in steps of 1, `"sma"`.
    fn default() -> Self {
        PpoBatchRange {
            fast_period: (12, 12, 0),
            slow_period: (26, 275, 1),
            ma_type: String::from("sma"),
        }
    }
}
616
617#[derive(Clone, Debug, Default)]
618pub struct PpoBatchBuilder {
619    range: PpoBatchRange,
620    kernel: Kernel,
621}
622
623impl PpoBatchBuilder {
624    pub fn new() -> Self {
625        Self::default()
626    }
627    pub fn kernel(mut self, k: Kernel) -> Self {
628        self.kernel = k;
629        self
630    }
631    pub fn fast_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
632        self.range.fast_period = (start, end, step);
633        self
634    }
635    pub fn slow_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
636        self.range.slow_period = (start, end, step);
637        self
638    }
639    pub fn ma_type<S: Into<String>>(mut self, t: S) -> Self {
640        self.range.ma_type = t.into();
641        self
642    }
643    pub fn apply_slice(self, data: &[f64]) -> Result<PpoBatchOutput, PpoError> {
644        ppo_batch_with_kernel(data, &self.range, self.kernel)
645    }
646    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<PpoBatchOutput, PpoError> {
647        PpoBatchBuilder::new().kernel(k).apply_slice(data)
648    }
649    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<PpoBatchOutput, PpoError> {
650        let slice = source_type(c, src);
651        self.apply_slice(slice)
652    }
653    pub fn with_default_candles(c: &Candles) -> Result<PpoBatchOutput, PpoError> {
654        PpoBatchBuilder::new()
655            .kernel(Kernel::Auto)
656            .apply_candles(c, "close")
657    }
658}
659
660#[derive(Clone, Debug)]
661pub struct PpoBatchOutput {
662    pub values: Vec<f64>,
663    pub combos: Vec<PpoParams>,
664    pub rows: usize,
665    pub cols: usize,
666}
667
668impl PpoBatchOutput {
669    pub fn row_for_params(&self, p: &PpoParams) -> Option<usize> {
670        self.combos.iter().position(|c| {
671            c.fast_period.unwrap_or(12) == p.fast_period.unwrap_or(12)
672                && c.slow_period.unwrap_or(26) == p.slow_period.unwrap_or(26)
673                && c.ma_type.as_ref().unwrap_or(&"sma".to_string())
674                    == p.ma_type.as_ref().unwrap_or(&"sma".to_string())
675        })
676    }
677    pub fn values_for(&self, p: &PpoParams) -> Option<&[f64]> {
678        self.row_for_params(p).and_then(|row| {
679            let start = row.checked_mul(self.cols)?;
680            let end = start.checked_add(self.cols)?;
681            self.values.get(start..end)
682        })
683    }
684}
685
686#[inline(always)]
687fn expand_grid(r: &PpoBatchRange) -> Result<Vec<PpoParams>, PpoError> {
688    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
689        if step == 0 || start == end {
690            return vec![start];
691        }
692        if start <= end {
693            (start..=end).step_by(step).collect()
694        } else {
695            let mut v = Vec::new();
696            let mut cur = start;
697            loop {
698                v.push(cur);
699                if cur <= end {
700                    break;
701                }
702                let next = match cur.checked_sub(step) {
703                    Some(n) => n,
704                    None => break,
705                };
706                if next < end {
707                    break;
708                }
709                cur = next;
710            }
711            v
712        }
713    }
714
715    let fasts = axis_usize(r.fast_period);
716    let slows = axis_usize(r.slow_period);
717    if fasts.is_empty() {
718        let (start, end, step) = r.fast_period;
719        return Err(PpoError::InvalidRange { start, end, step });
720    }
721    if slows.is_empty() {
722        let (start, end, step) = r.slow_period;
723        return Err(PpoError::InvalidRange { start, end, step });
724    }
725
726    let ma_type = r.ma_type.clone();
727
728    let total = fasts
729        .len()
730        .checked_mul(slows.len())
731        .ok_or_else(|| PpoError::InvalidInput("fast*slow grid overflow".into()))?;
732    let mut out = Vec::with_capacity(total);
733    for &f in &fasts {
734        for &s in &slows {
735            out.push(PpoParams {
736                fast_period: Some(f),
737                slow_period: Some(s),
738                ma_type: Some(ma_type.clone()),
739            });
740        }
741    }
742    Ok(out)
743}
744
745#[inline(always)]
746pub fn ppo_batch_with_kernel(
747    data: &[f64],
748    sweep: &PpoBatchRange,
749    k: Kernel,
750) -> Result<PpoBatchOutput, PpoError> {
751    let kernel = match k {
752        Kernel::Auto => detect_best_batch_kernel(),
753        other if other.is_batch() => other,
754        other => return Err(PpoError::InvalidKernelForBatch(other)),
755    };
756    let simd = match kernel {
757        Kernel::Avx512Batch => Kernel::Avx512,
758        Kernel::Avx2Batch => Kernel::Avx2,
759        Kernel::ScalarBatch => Kernel::Scalar,
760        _ => unreachable!(),
761    };
762    ppo_batch_par_slice(data, sweep, simd)
763}
764
765#[inline(always)]
766pub fn ppo_batch_slice(
767    data: &[f64],
768    sweep: &PpoBatchRange,
769    kern: Kernel,
770) -> Result<PpoBatchOutput, PpoError> {
771    ppo_batch_inner(data, sweep, kern, false)
772}
773
774#[inline(always)]
775pub fn ppo_batch_par_slice(
776    data: &[f64],
777    sweep: &PpoBatchRange,
778    kern: Kernel,
779) -> Result<PpoBatchOutput, PpoError> {
780    ppo_batch_inner(data, sweep, kern, true)
781}
782
783#[inline(always)]
784fn ppo_batch_inner_into(
785    data: &[f64],
786    sweep: &PpoBatchRange,
787    kern: Kernel,
788    parallel: bool,
789    out: &mut [f64],
790) -> Result<Vec<PpoParams>, PpoError> {
791    let combos = expand_grid(sweep)?;
792    if combos.is_empty() {
793        let (start, end, step) = sweep.fast_period;
794        return Err(PpoError::InvalidRange { start, end, step });
795    }
796
797    let first = data
798        .iter()
799        .position(|x| !x.is_nan())
800        .ok_or(PpoError::AllValuesNaN)?;
801    let max_slow = combos.iter().map(|c| c.slow_period.unwrap()).max().unwrap();
802    if data.len() - first < max_slow {
803        return Err(PpoError::NotEnoughValidData {
804            needed: max_slow,
805            valid: data.len() - first,
806        });
807    }
808
809    let rows = combos.len();
810    let cols = data.len();
811    let expected = rows.checked_mul(cols).ok_or_else(|| {
812        PpoError::InvalidInput("rows*cols overflow in ppo_batch_inner_into".into())
813    })?;
814    if out.len() != expected {
815        return Err(PpoError::OutputLengthMismatch {
816            expected,
817            got: out.len(),
818        });
819    }
820
821    let out_uninit: &mut [MaybeUninit<f64>] = unsafe {
822        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
823    };
824
825    let ma_type = sweep.ma_type.as_str();
826    let use_cached = ma_type == "sma" || ma_type == "ema";
827    let mut ma_cache: HashMap<usize, Vec<f64>> = HashMap::new();
828    if use_cached {
829        let mut uniq: Vec<usize> = Vec::new();
830        for c in &combos {
831            let f = c.fast_period.unwrap();
832            let s = c.slow_period.unwrap();
833            if !uniq.contains(&f) {
834                uniq.push(f);
835            }
836            if !uniq.contains(&s) {
837                uniq.push(s);
838            }
839        }
840        for &p in &uniq {
841            if let Ok(v) = ma(ma_type, MaData::Slice(data), p) {
842                ma_cache.insert(p, v);
843            } else {
844                let mut v = vec![f64::NAN; data.len()];
845                ma_cache.insert(p, v);
846            }
847        }
848    }
849
850    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
851        let p = &combos[row];
852        let out_row: &mut [f64] =
853            std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
854        if use_cached {
855            let fast = p.fast_period.unwrap();
856            let slow = p.slow_period.unwrap();
857            let fast_ma = ma_cache.get(&fast).unwrap();
858            let slow_ma = ma_cache.get(&slow).unwrap();
859            let mut i = first + slow - 1;
860            let n = data.len();
861            while i < n {
862                let sf = *slow_ma.get_unchecked(i);
863                let ff = *fast_ma.get_unchecked(i);
864                let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
865                    f64::NAN
866                } else {
867                    let ratio = ff / sf;
868                    f64::mul_add(ratio, 100.0, -100.0)
869                };
870                *out_row.get_unchecked_mut(i) = y;
871                i += 1;
872            }
873        } else {
874            match kern {
875                Kernel::Scalar => ppo_row_scalar(
876                    data,
877                    first,
878                    p.fast_period.unwrap(),
879                    p.slow_period.unwrap(),
880                    p.ma_type.as_ref().unwrap(),
881                    out_row,
882                ),
883                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
884                Kernel::Avx2 => ppo_row_avx2(
885                    data,
886                    first,
887                    p.fast_period.unwrap(),
888                    p.slow_period.unwrap(),
889                    p.ma_type.as_ref().unwrap(),
890                    out_row,
891                ),
892                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
893                Kernel::Avx512 => ppo_row_avx512(
894                    data,
895                    first,
896                    p.fast_period.unwrap(),
897                    p.slow_period.unwrap(),
898                    p.ma_type.as_ref().unwrap(),
899                    out_row,
900                ),
901                _ => unreachable!(),
902            }
903        }
904    };
905
906    if parallel {
907        #[cfg(not(target_arch = "wasm32"))]
908        {
909            out_uninit
910                .par_chunks_mut(cols)
911                .enumerate()
912                .for_each(|(row, slice)| do_row(row, slice));
913        }
914        #[cfg(target_arch = "wasm32")]
915        {
916            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
917                do_row(row, slice);
918            }
919        }
920    } else {
921        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
922            do_row(row, slice);
923        }
924    }
925
926    Ok(combos)
927}
928
929#[inline(always)]
930fn ppo_batch_inner(
931    data: &[f64],
932    sweep: &PpoBatchRange,
933    kern: Kernel,
934    parallel: bool,
935) -> Result<PpoBatchOutput, PpoError> {
936    let combos = expand_grid(sweep)?;
937    if combos.is_empty() {
938        let (start, end, step) = sweep.fast_period;
939        return Err(PpoError::InvalidRange { start, end, step });
940    }
941    let first = data
942        .iter()
943        .position(|x| !x.is_nan())
944        .ok_or(PpoError::AllValuesNaN)?;
945    let max_slow = combos.iter().map(|c| c.slow_period.unwrap()).max().unwrap();
946    if data.len() - first < max_slow {
947        return Err(PpoError::NotEnoughValidData {
948            needed: max_slow,
949            valid: data.len() - first,
950        });
951    }
952    let rows = combos.len();
953    let cols = data.len();
954    let _total = rows
955        .checked_mul(cols)
956        .ok_or_else(|| PpoError::InvalidInput("rows*cols overflow in ppo_batch_inner".into()))?;
957
958    let mut buf_mu = make_uninit_matrix(rows, cols);
959
960    let warmup_periods: Vec<usize> = combos
961        .iter()
962        .map(|c| first + c.slow_period.unwrap() - 1)
963        .collect();
964
965    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
966
967    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
968    let values: &mut [f64] = unsafe {
969        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
970    };
971
972    let ma_type = sweep.ma_type.as_str();
973    let use_cached = ma_type == "sma" || ma_type == "ema";
974    let mut ma_cache: HashMap<usize, Vec<f64>> = HashMap::new();
975    if use_cached {
976        let mut uniq: Vec<usize> = Vec::new();
977        for c in &combos {
978            let f = c.fast_period.unwrap();
979            let s = c.slow_period.unwrap();
980            if !uniq.contains(&f) {
981                uniq.push(f);
982            }
983            if !uniq.contains(&s) {
984                uniq.push(s);
985            }
986        }
987        for &p in &uniq {
988            if let Ok(v) = ma(ma_type, MaData::Slice(data), p) {
989                ma_cache.insert(p, v);
990            } else {
991                let v = vec![f64::NAN; data.len()];
992                ma_cache.insert(p, v);
993            }
994        }
995    }
996
997    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
998        let p = &combos[row];
999        if use_cached {
1000            let fast = p.fast_period.unwrap();
1001            let slow = p.slow_period.unwrap();
1002            let fast_ma = ma_cache.get(&fast).unwrap();
1003            let slow_ma = ma_cache.get(&slow).unwrap();
1004            let mut i = first + slow - 1;
1005            let n = data.len();
1006            while i < n {
1007                let sf = *slow_ma.get_unchecked(i);
1008                let ff = *fast_ma.get_unchecked(i);
1009                let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
1010                    f64::NAN
1011                } else {
1012                    let ratio = ff / sf;
1013                    f64::mul_add(ratio, 100.0, -100.0)
1014                };
1015                *out_row.get_unchecked_mut(i) = y;
1016                i += 1;
1017            }
1018        } else {
1019            match kern {
1020                Kernel::Scalar => ppo_row_scalar(
1021                    data,
1022                    first,
1023                    p.fast_period.unwrap(),
1024                    p.slow_period.unwrap(),
1025                    p.ma_type.as_ref().unwrap(),
1026                    out_row,
1027                ),
1028                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1029                Kernel::Avx2 => ppo_row_avx2(
1030                    data,
1031                    first,
1032                    p.fast_period.unwrap(),
1033                    p.slow_period.unwrap(),
1034                    p.ma_type.as_ref().unwrap(),
1035                    out_row,
1036                ),
1037                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1038                Kernel::Avx512 => ppo_row_avx512(
1039                    data,
1040                    first,
1041                    p.fast_period.unwrap(),
1042                    p.slow_period.unwrap(),
1043                    p.ma_type.as_ref().unwrap(),
1044                    out_row,
1045                ),
1046                _ => unreachable!(),
1047            }
1048        }
1049    };
1050
1051    if parallel {
1052        #[cfg(not(target_arch = "wasm32"))]
1053        {
1054            values
1055                .par_chunks_mut(cols)
1056                .enumerate()
1057                .for_each(|(row, slice)| do_row(row, slice));
1058        }
1059
1060        #[cfg(target_arch = "wasm32")]
1061        {
1062            for (row, slice) in values.chunks_mut(cols).enumerate() {
1063                do_row(row, slice);
1064            }
1065        }
1066    } else {
1067        for (row, slice) in values.chunks_mut(cols).enumerate() {
1068            do_row(row, slice);
1069        }
1070    }
1071
1072    let values = unsafe {
1073        Vec::from_raw_parts(
1074            buf_guard.as_mut_ptr() as *mut f64,
1075            buf_guard.len(),
1076            buf_guard.capacity(),
1077        )
1078    };
1079
1080    Ok(PpoBatchOutput {
1081        values,
1082        combos,
1083        rows,
1084        cols,
1085    })
1086}
1087
1088#[inline(always)]
1089pub unsafe fn ppo_row_scalar(
1090    data: &[f64],
1091    first: usize,
1092    fast: usize,
1093    slow: usize,
1094    ma_type: &str,
1095    out: &mut [f64],
1096) {
1097    if ma_type == "ema" {
1098        ppo_row_scalar_classic_ema(data, first, fast, slow, out);
1099    } else if ma_type == "sma" {
1100        ppo_row_scalar_classic_sma(data, first, fast, slow, out);
1101    } else {
1102        ppo_scalar(data, fast, slow, ma_type, first, out);
1103    }
1104}
1105
1106#[inline(always)]
1107pub unsafe fn ppo_row_scalar_classic_ema(
1108    data: &[f64],
1109    first: usize,
1110    fast: usize,
1111    slow: usize,
1112    out: &mut [f64],
1113) {
1114    ppo_scalar_classic_ema(data, fast, slow, first, out);
1115}
1116
1117#[inline(always)]
1118pub unsafe fn ppo_row_scalar_classic_sma(
1119    data: &[f64],
1120    first: usize,
1121    fast: usize,
1122    slow: usize,
1123    out: &mut [f64],
1124) {
1125    ppo_scalar_classic_sma(data, fast, slow, first, out);
1126}
1127
/// AVX2 row entry point.
///
/// No AVX2-specific code exists here: the call is forwarded unchanged to the
/// generic `ppo_scalar` routine, so results match the scalar path.
///
/// # Safety
/// Forwards directly to `ppo_scalar`.
/// NOTE(review): that helper's requirements on the slices are not visible
/// in this part of the file — confirm before calling directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx2(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    // Scalar fallback; the row API passes `first` before the periods.
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1140
/// AVX-512 row entry point.
///
/// Chooses between the "short" and "long" variants by the slow period:
/// `slow <= 32` uses `ppo_row_avx512_short`, anything larger uses
/// `ppo_row_avx512_long`.
///
/// # Safety
/// Forwards to the two variants above; their requirements apply here too.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    match slow {
        0..=32 => ppo_row_avx512_short(data, first, fast, slow, ma_type, out),
        _ => ppo_row_avx512_long(data, first, fast, slow, ma_type, out),
    }
}
1157
/// AVX-512 row variant selected for short slow periods.
///
/// Currently a plain forward to the generic `ppo_scalar` routine.
///
/// # Safety
/// Forwards directly to `ppo_scalar`.
/// NOTE(review): that helper's requirements on the slices are not visible
/// in this part of the file — confirm before calling directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512_short(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    // Scalar fallback; the row API passes `first` before the periods.
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1170
/// AVX-512 row variant selected for long slow periods.
///
/// Currently a plain forward to the generic `ppo_scalar` routine.
///
/// # Safety
/// Forwards directly to `ppo_scalar`.
/// NOTE(review): that helper's requirements on the slices are not visible
/// in this part of the file — confirm before calling directly.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512_long(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    // Scalar fallback; the row API passes `first` before the periods.
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1183
1184pub struct PpoStream {
1185    fast_period: usize,
1186    slow_period: usize,
1187    ma_type: String,
1188    mode: StreamMode,
1189}
1190
1191#[derive(Debug)]
1192enum StreamMode {
1193    Sma(SmaState),
1194
1195    Ema(EmaState),
1196
1197    Generic { data: Vec<f64> },
1198}
1199
/// Rolling-window state for the SMA streaming mode.
#[derive(Debug)]
struct SmaState {
    /// Set once the first non-NaN sample has been seen.
    started: bool,
    /// Fast window length.
    fast: usize,
    /// Slow window length.
    slow: usize,
    /// `slow / fast`; turns the ratio of window sums into a ratio of means.
    k: f64,

    /// Running sum over the fast window.
    fast_sum: f64,
    /// Running sum over the slow window.
    slow_sum: f64,
    /// Ring buffer holding the samples of the fast window.
    fast_buf: Vec<f64>,
    /// Ring buffer holding the samples of the slow window.
    slow_buf: Vec<f64>,
    /// Next slot to overwrite in `fast_buf` once it is full.
    i_fast: usize,
    /// Next slot to overwrite in `slow_buf` once it is full.
    i_slow: usize,
    /// Number of samples pushed into `fast_buf` while it was filling.
    filled_fast: usize,
    /// Number of samples pushed into `slow_buf` while it was filling.
    filled_slow: usize,
}
1216
/// Recursive state for the EMA streaming mode.
#[derive(Debug)]
struct EmaState {
    /// Set once the first non-NaN sample has been seen.
    started: bool,
    /// Fast EMA period.
    fast: usize,
    /// Slow EMA period.
    slow: usize,

    /// Fast smoothing factor, `2 / (fast + 1)`.
    fa: f64,
    /// `1 - fa`.
    fb: f64,
    /// Slow smoothing factor, `2 / (slow + 1)`.
    sa: f64,
    /// `1 - sa`.
    sb: f64,

    /// Current fast EMA value (NaN until seeded).
    fast_ema: f64,
    /// Current slow EMA value (NaN until seeded).
    slow_ema: f64,
    /// True once both EMAs have been initialised from the warm-up samples.
    seeded: bool,

    /// Samples collected before seeding; seeding happens at `slow` samples.
    warm: Vec<f64>,
}
1234
1235impl PpoStream {
1236    pub fn try_new(params: PpoParams) -> Result<Self, PpoError> {
1237        let fast = params.fast_period.unwrap_or(12);
1238        let slow = params.slow_period.unwrap_or(26);
1239        let ma_type = params.ma_type.clone().unwrap_or_else(|| "sma".to_string());
1240
1241        let mode = match ma_type.as_str() {
1242            "sma" => StreamMode::Sma(SmaState {
1243                started: false,
1244                fast,
1245                slow,
1246                k: slow as f64 / fast as f64,
1247                fast_sum: 0.0,
1248                slow_sum: 0.0,
1249                fast_buf: Vec::with_capacity(fast),
1250                slow_buf: Vec::with_capacity(slow),
1251                i_fast: 0,
1252                i_slow: 0,
1253                filled_fast: 0,
1254                filled_slow: 0,
1255            }),
1256            "ema" => {
1257                let fa = 2.0 / (fast as f64 + 1.0);
1258                let sa = 2.0 / (slow as f64 + 1.0);
1259                StreamMode::Ema(EmaState {
1260                    started: false,
1261                    fast,
1262                    slow,
1263                    fa,
1264                    fb: 1.0 - fa,
1265                    sa,
1266                    sb: 1.0 - sa,
1267                    fast_ema: f64::NAN,
1268                    slow_ema: f64::NAN,
1269                    seeded: false,
1270                    warm: Vec::with_capacity(slow),
1271                })
1272            }
1273            _ => StreamMode::Generic { data: Vec::new() },
1274        };
1275
1276        Ok(Self {
1277            fast_period: fast,
1278            slow_period: slow,
1279            ma_type,
1280            mode,
1281        })
1282    }
1283
1284    #[inline(always)]
1285    pub fn update(&mut self, value: f64) -> Option<f64> {
1286        match &mut self.mode {
1287            StreamMode::Sma(state) => update_sma(state, value),
1288            StreamMode::Ema(state) => update_ema(state, value),
1289
1290            StreamMode::Generic { data } => {
1291                data.push(value);
1292                if data.len() < self.slow_period {
1293                    return None;
1294                }
1295
1296                let fast_ma = ma(&self.ma_type, MaData::Slice(&data), self.fast_period).ok()?;
1297                let slow_ma = ma(&self.ma_type, MaData::Slice(&data), self.slow_period).ok()?;
1298                let ff = *fast_ma.last()?;
1299                let sf = *slow_ma.last()?;
1300                if ff.is_nan() || sf.is_nan() || sf == 0.0 {
1301                    Some(f64::NAN)
1302                } else {
1303                    let ratio = ff / sf;
1304                    Some(f64::mul_add(ratio, 100.0, -100.0))
1305                }
1306            }
1307        }
1308    }
1309}
1310
1311#[inline(always)]
1312fn update_sma(s: &mut SmaState, x: f64) -> Option<f64> {
1313    if !s.started {
1314        if x.is_nan() {
1315            return None;
1316        }
1317        s.started = true;
1318    }
1319
1320    if s.filled_fast < s.fast {
1321        s.fast_buf.push(x);
1322        s.fast_sum += x;
1323        s.filled_fast += 1;
1324    } else {
1325        let old = s.fast_buf[s.i_fast];
1326        s.fast_sum += x - old;
1327        s.fast_buf[s.i_fast] = x;
1328        s.i_fast = (s.i_fast + 1) % s.fast;
1329    }
1330
1331    if s.filled_slow < s.slow {
1332        s.slow_buf.push(x);
1333        s.slow_sum += x;
1334        s.filled_slow += 1;
1335        if s.filled_slow < s.slow {
1336            return None;
1337        }
1338    } else {
1339        let old = s.slow_buf[s.i_slow];
1340        s.slow_sum += x - old;
1341        s.slow_buf[s.i_slow] = x;
1342        s.i_slow = (s.i_slow + 1) % s.slow;
1343    }
1344
1345    let slow_sum = s.slow_sum;
1346    let fast_sum = s.fast_sum;
1347    if slow_sum == 0.0 || slow_sum.is_nan() || fast_sum.is_nan() {
1348        Some(f64::NAN)
1349    } else {
1350        let ratio = (fast_sum * s.k) / slow_sum;
1351        Some(f64::mul_add(ratio, 100.0, -100.0))
1352    }
1353}
1354
1355#[inline(always)]
1356fn update_ema(e: &mut EmaState, x: f64) -> Option<f64> {
1357    if !e.started {
1358        if x.is_nan() {
1359            return None;
1360        }
1361        e.started = true;
1362    }
1363
1364    if !e.seeded {
1365        e.warm.push(x);
1366        if e.warm.len() < e.slow {
1367            return None;
1368        }
1369
1370        let mut slow_sum = 0.0f64;
1371        for &v in &e.warm {
1372            slow_sum += v;
1373        }
1374        e.slow_ema = slow_sum / e.slow as f64;
1375
1376        let mut fast_sum = 0.0f64;
1377        let start = e.slow - e.fast;
1378        for &v in &e.warm[start..] {
1379            fast_sum += v;
1380        }
1381        e.fast_ema = fast_sum / e.fast as f64;
1382
1383        for &v in &e.warm[e.fast..] {
1384            e.fast_ema = f64::mul_add(e.fa, v, e.fb * e.fast_ema);
1385        }
1386
1387        e.seeded = true;
1388
1389        let sf = e.slow_ema;
1390        let ff = e.fast_ema;
1391        return if sf == 0.0 || sf.is_nan() || ff.is_nan() {
1392            Some(f64::NAN)
1393        } else {
1394            let ratio = ff / sf;
1395            Some(f64::mul_add(ratio, 100.0, -100.0))
1396        };
1397    }
1398
1399    e.fast_ema = f64::mul_add(e.fa, x, e.fb * e.fast_ema);
1400    e.slow_ema = f64::mul_add(e.sa, x, e.sb * e.slow_ema);
1401
1402    let sf = e.slow_ema;
1403    let ff = e.fast_ema;
1404    if sf == 0.0 || sf.is_nan() || ff.is_nan() {
1405        Some(f64::NAN)
1406    } else {
1407        let ratio = ff / sf;
1408        Some(f64::mul_add(ratio, 100.0, -100.0))
1409    }
1410}
1411
1412#[cfg(test)]
1413mod tests {
1414    use super::*;
1415    use crate::skip_if_unsupported;
1416    use crate::utilities::data_loader::read_candles_from_csv;
1417
1418    fn check_ppo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1419        skip_if_unsupported!(kernel, test_name);
1420        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1421        let candles = read_candles_from_csv(file_path)?;
1422        let default_params = PpoParams {
1423            fast_period: None,
1424            slow_period: None,
1425            ma_type: None,
1426        };
1427        let input = PpoInput::from_candles(&candles, "close", default_params);
1428        let output = ppo_with_kernel(&input, kernel)?;
1429        assert_eq!(output.values.len(), candles.close.len());
1430        Ok(())
1431    }
1432
1433    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1434    #[test]
1435    fn test_ppo_into_matches_api() -> Result<(), Box<dyn Error>> {
1436        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1437        let candles = read_candles_from_csv(file_path)?;
1438        let input = PpoInput::from_candles(&candles, "close", PpoParams::default());
1439
1440        let baseline = ppo(&input)?.values;
1441
1442        let mut out = vec![0.0; candles.close.len()];
1443        ppo_into(&input, &mut out)?;
1444
1445        assert_eq!(baseline.len(), out.len());
1446
1447        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1448            (a.is_nan() && b.is_nan()) || (a == b)
1449        }
1450
1451        for i in 0..out.len() {
1452            assert!(
1453                eq_or_both_nan(baseline[i], out[i]),
1454                "Mismatch at index {}: baseline={}, into={}",
1455                i,
1456                baseline[i],
1457                out[i]
1458            );
1459        }
1460
1461        Ok(())
1462    }
1463
1464    fn check_ppo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1465        skip_if_unsupported!(kernel, test_name);
1466        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1467        let candles = read_candles_from_csv(file_path)?;
1468        let input = PpoInput::from_candles(&candles, "close", PpoParams::default());
1469        let result = ppo_with_kernel(&input, kernel)?;
1470        assert_eq!(result.values.len(), candles.close.len());
1471        let expected_last_five = [
1472            -0.8532313608928664,
1473            -0.8537562894550523,
1474            -0.6821291938174874,
1475            -0.5620008722078592,
1476            -0.4101724140910927,
1477        ];
1478        let start = result.values.len().saturating_sub(5);
1479        for (i, &val) in result.values[start..].iter().enumerate() {
1480            let diff = (val - expected_last_five[i]).abs();
1481            assert!(
1482                diff < 1e-7,
1483                "[{}] PPO {:?} mismatch at idx {}: got {}, expected {}",
1484                test_name,
1485                kernel,
1486                i,
1487                val,
1488                expected_last_five[i]
1489            );
1490        }
1491        Ok(())
1492    }
1493
1494    fn check_ppo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1495        skip_if_unsupported!(kernel, test_name);
1496        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1497        let candles = read_candles_from_csv(file_path)?;
1498        let input = PpoInput::with_default_candles(&candles);
1499        match input.data {
1500            PpoData::Candles { source, .. } => assert_eq!(source, "close"),
1501            _ => panic!("Expected PpoData::Candles"),
1502        }
1503        let output = ppo_with_kernel(&input, kernel)?;
1504        assert_eq!(output.values.len(), candles.close.len());
1505        Ok(())
1506    }
1507
1508    fn check_ppo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1509        skip_if_unsupported!(kernel, test_name);
1510        let input_data = [10.0, 20.0, 30.0];
1511        let params = PpoParams {
1512            fast_period: Some(0),
1513            slow_period: None,
1514            ma_type: None,
1515        };
1516        let input = PpoInput::from_slice(&input_data, params);
1517        let res = ppo_with_kernel(&input, kernel);
1518        assert!(
1519            res.is_err(),
1520            "[{}] PPO should fail with zero fast period",
1521            test_name
1522        );
1523        Ok(())
1524    }
1525
1526    fn check_ppo_period_exceeds_length(
1527        test_name: &str,
1528        kernel: Kernel,
1529    ) -> Result<(), Box<dyn Error>> {
1530        skip_if_unsupported!(kernel, test_name);
1531        let data_small = [10.0, 20.0, 30.0];
1532        let params = PpoParams {
1533            fast_period: Some(12),
1534            slow_period: Some(26),
1535            ma_type: None,
1536        };
1537        let input = PpoInput::from_slice(&data_small, params);
1538        let res = ppo_with_kernel(&input, kernel);
1539        assert!(
1540            res.is_err(),
1541            "[{}] PPO should fail with period exceeding length",
1542            test_name
1543        );
1544        Ok(())
1545    }
1546
1547    fn check_ppo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1548        skip_if_unsupported!(kernel, test_name);
1549        let single_point = [42.0];
1550        let params = PpoParams {
1551            fast_period: Some(12),
1552            slow_period: Some(26),
1553            ma_type: None,
1554        };
1555        let input = PpoInput::from_slice(&single_point, params);
1556        let res = ppo_with_kernel(&input, kernel);
1557        assert!(
1558            res.is_err(),
1559            "[{}] PPO should fail with insufficient data",
1560            test_name
1561        );
1562        Ok(())
1563    }
1564
1565    fn check_ppo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1566        skip_if_unsupported!(kernel, test_name);
1567        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1568        let candles = read_candles_from_csv(file_path)?;
1569        let input = PpoInput::from_candles(
1570            &candles,
1571            "close",
1572            PpoParams {
1573                fast_period: Some(12),
1574                slow_period: Some(26),
1575                ma_type: Some("sma".to_string()),
1576            },
1577        );
1578        let res = ppo_with_kernel(&input, kernel)?;
1579        assert_eq!(res.values.len(), candles.close.len());
1580        if res.values.len() > 30 {
1581            for (i, &val) in res.values[30..].iter().enumerate() {
1582                assert!(
1583                    !val.is_nan(),
1584                    "[{}] Found unexpected NaN at out-index {}",
1585                    test_name,
1586                    30 + i
1587                );
1588            }
1589        }
1590        Ok(())
1591    }
1592
1593    fn check_ppo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1594        skip_if_unsupported!(kernel, test_name);
1595        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1596        let candles = read_candles_from_csv(file_path)?;
1597        let fast = 12;
1598        let slow = 26;
1599        let ma_type = "sma".to_string();
1600        let input = PpoInput::from_candles(
1601            &candles,
1602            "close",
1603            PpoParams {
1604                fast_period: Some(fast),
1605                slow_period: Some(slow),
1606                ma_type: Some(ma_type.clone()),
1607            },
1608        );
1609        let batch_output = ppo_with_kernel(&input, kernel)?.values;
1610        let mut stream = PpoStream::try_new(PpoParams {
1611            fast_period: Some(fast),
1612            slow_period: Some(slow),
1613            ma_type: Some(ma_type),
1614        })?;
1615        let mut stream_values = Vec::with_capacity(candles.close.len());
1616        for &price in &candles.close {
1617            match stream.update(price) {
1618                Some(ppo_val) => stream_values.push(ppo_val),
1619                None => stream_values.push(f64::NAN),
1620            }
1621        }
1622        assert_eq!(batch_output.len(), stream_values.len());
1623        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1624            if b.is_nan() && s.is_nan() {
1625                continue;
1626            }
1627            let diff = (b - s).abs();
1628            assert!(
1629                diff < 1e-9,
1630                "[{}] PPO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1631                test_name,
1632                i,
1633                b,
1634                s,
1635                diff
1636            );
1637        }
1638        Ok(())
1639    }
1640
1641    #[cfg(debug_assertions)]
1642    fn check_ppo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1643        skip_if_unsupported!(kernel, test_name);
1644
1645        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1646        let candles = read_candles_from_csv(file_path)?;
1647
1648        let test_params = vec![
1649            PpoParams::default(),
1650            PpoParams {
1651                fast_period: Some(2),
1652                slow_period: Some(3),
1653                ma_type: Some("sma".to_string()),
1654            },
1655            PpoParams {
1656                fast_period: Some(5),
1657                slow_period: Some(10),
1658                ma_type: Some("sma".to_string()),
1659            },
1660            PpoParams {
1661                fast_period: Some(12),
1662                slow_period: Some(26),
1663                ma_type: Some("ema".to_string()),
1664            },
1665            PpoParams {
1666                fast_period: Some(12),
1667                slow_period: Some(26),
1668                ma_type: Some("wma".to_string()),
1669            },
1670            PpoParams {
1671                fast_period: Some(20),
1672                slow_period: Some(40),
1673                ma_type: Some("sma".to_string()),
1674            },
1675            PpoParams {
1676                fast_period: Some(50),
1677                slow_period: Some(100),
1678                ma_type: Some("sma".to_string()),
1679            },
1680            PpoParams {
1681                fast_period: Some(10),
1682                slow_period: Some(11),
1683                ma_type: Some("sma".to_string()),
1684            },
1685            PpoParams {
1686                fast_period: Some(3),
1687                slow_period: Some(21),
1688                ma_type: Some("ema".to_string()),
1689            },
1690            PpoParams {
1691                fast_period: Some(7),
1692                slow_period: Some(14),
1693                ma_type: Some("wma".to_string()),
1694            },
1695            PpoParams {
1696                fast_period: Some(9),
1697                slow_period: Some(21),
1698                ma_type: Some("sma".to_string()),
1699            },
1700            PpoParams {
1701                fast_period: Some(100),
1702                slow_period: Some(200),
1703                ma_type: Some("ema".to_string()),
1704            },
1705        ];
1706
1707        for (param_idx, params) in test_params.iter().enumerate() {
1708            let input = PpoInput::from_candles(&candles, "close", params.clone());
1709            let output = ppo_with_kernel(&input, kernel)?;
1710
1711            for (i, &val) in output.values.iter().enumerate() {
1712                if val.is_nan() {
1713                    continue;
1714                }
1715
1716                let bits = val.to_bits();
1717
1718                if bits == 0x11111111_11111111 {
1719                    panic!(
1720                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1721						 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1722                        test_name,
1723                        val,
1724                        bits,
1725                        i,
1726                        params.fast_period.unwrap_or(12),
1727                        params.slow_period.unwrap_or(26),
1728                        params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1729                        param_idx
1730                    );
1731                }
1732
1733                if bits == 0x22222222_22222222 {
1734                    panic!(
1735                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1736						 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1737                        test_name,
1738                        val,
1739                        bits,
1740                        i,
1741                        params.fast_period.unwrap_or(12),
1742                        params.slow_period.unwrap_or(26),
1743                        params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1744                        param_idx
1745                    );
1746                }
1747
1748                if bits == 0x33333333_33333333 {
1749                    panic!(
1750                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1751						 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1752                        test_name,
1753                        val,
1754                        bits,
1755                        i,
1756                        params.fast_period.unwrap_or(12),
1757                        params.slow_period.unwrap_or(26),
1758                        params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1759                        param_idx
1760                    );
1761                }
1762            }
1763        }
1764
1765        Ok(())
1766    }
1767
1768    #[cfg(not(debug_assertions))]
1769    fn check_ppo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1770        Ok(())
1771    }
1772
1773    #[cfg(feature = "proptest")]
1774    #[allow(clippy::float_cmp)]
1775    fn check_ppo_property(
1776        test_name: &str,
1777        kernel: Kernel,
1778    ) -> Result<(), Box<dyn std::error::Error>> {
1779        use crate::indicators::moving_averages::ma::{ma, MaData};
1780        use proptest::prelude::*;
1781        skip_if_unsupported!(kernel, test_name);
1782
1783        let strat = (2usize..=64).prop_flat_map(|slow_period| {
1784            (
1785                prop::collection::vec(
1786                    (10f64..100000f64)
1787                        .prop_filter("positive finite", |x| x.is_finite() && *x > 0.0),
1788                    slow_period..400,
1789                ),
1790                2usize..=slow_period,
1791                Just(slow_period),
1792                Just("sma"),
1793            )
1794        });
1795
1796        proptest::test_runner::TestRunner::default()
1797            .run(&strat, |(data, fast_period, slow_period, ma_type)| {
1798                let params = PpoParams {
1799                    fast_period: Some(fast_period),
1800                    slow_period: Some(slow_period),
1801                    ma_type: Some(ma_type.to_string()),
1802                };
1803                let input = PpoInput::from_slice(&data, params);
1804
1805                let PpoOutput { values: out } = ppo_with_kernel(&input, kernel).unwrap();
1806
1807                let PpoOutput { values: ref_out } =
1808                    ppo_with_kernel(&input, Kernel::Scalar).unwrap();
1809
1810                let fast_ma = ma(&ma_type, MaData::Slice(&data), fast_period).unwrap();
1811                let slow_ma = ma(&ma_type, MaData::Slice(&data), slow_period).unwrap();
1812
1813                for i in 0..(slow_period - 1).min(data.len()) {
1814                    prop_assert!(
1815                        out[i].is_nan(),
1816                        "Expected NaN during warmup at index {}, got {}",
1817                        i,
1818                        out[i]
1819                    );
1820                }
1821
1822                for i in (slow_period - 1)..data.len() {
1823                    let y = out[i];
1824                    let r = ref_out[i];
1825                    let fast_val = fast_ma[i];
1826                    let slow_val = slow_ma[i];
1827
1828                    if !fast_val.is_nan() && !slow_val.is_nan() && slow_val != 0.0 {
1829                        let expected_ppo = 100.0 * (fast_val - slow_val) / slow_val;
1830
1831                        if y.is_finite() && expected_ppo.is_finite() {
1832                            prop_assert!(
1833								(y - expected_ppo).abs() < 1e-9,
1834								"PPO formula mismatch at index {}: got {}, expected {} (fast={}, slow={})",
1835								i, y, expected_ppo, fast_val, slow_val
1836							);
1837                        }
1838                    }
1839
1840                    let y_bits = y.to_bits();
1841                    let r_bits = r.to_bits();
1842
1843                    if !y.is_finite() || !r.is_finite() {
1844                        prop_assert!(
1845                            y.to_bits() == r.to_bits(),
1846                            "finite/NaN mismatch idx {}: {} vs {}",
1847                            i,
1848                            y,
1849                            r
1850                        );
1851                        continue;
1852                    }
1853
1854                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1855                    prop_assert!(
1856                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1857                        "kernel mismatch idx {}: {} vs {} (ULP={})",
1858                        i,
1859                        y,
1860                        r,
1861                        ulp_diff
1862                    );
1863
1864                    if y.is_finite()
1865                        && fast_val.is_finite()
1866                        && slow_val.is_finite()
1867                        && slow_val > 0.0
1868                    {
1869                        if fast_val > slow_val {
1870                            prop_assert!(
1871                                y > 0.0,
1872                                "PPO should be positive when fast > slow at index {}",
1873                                i
1874                            );
1875                        } else if fast_val < slow_val {
1876                            prop_assert!(
1877                                y < 0.0,
1878                                "PPO should be negative when fast < slow at index {}",
1879                                i
1880                            );
1881                        } else {
1882                            prop_assert!(
1883                                y.abs() < 1e-9,
1884                                "PPO should be ~0 when fast == slow at index {}",
1885                                i
1886                            );
1887                        }
1888                    }
1889
1890                    if fast_period == slow_period && y.is_finite() {
1891                        prop_assert!(
1892                            y.abs() < 1e-9,
1893                            "PPO should be ~0 when fast_period == slow_period at index {}: got {}",
1894                            i,
1895                            y
1896                        );
1897                    }
1898
1899                    if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && y.is_finite() {
1900                        prop_assert!(
1901                            y.abs() < 1e-6,
1902                            "PPO should be ~0 for constant data at index {}: got {}",
1903                            i,
1904                            y
1905                        );
1906                    }
1907
1908                    if y.is_finite() {
1909                        let window_start = i.saturating_sub(slow_period - 1);
1910                        let window = &data[window_start..=i];
1911                        let min_val = window.iter().cloned().fold(f64::INFINITY, f64::min);
1912                        let max_val = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1913                        let volatility_ratio = if min_val > 0.0 {
1914                            max_val / min_val
1915                        } else {
1916                            1.0
1917                        };
1918
1919                        let max_expected_ppo = 100.0 * (volatility_ratio - 1.0);
1920
1921                        prop_assert!(
1922							y.abs() <= max_expected_ppo * 1.5,
1923							"PPO exceeds expected bounds at index {}: got {}%, max expected ~{}% (volatility ratio {})",
1924							i, y, max_expected_ppo, volatility_ratio
1925						);
1926                    }
1927
1928                    if slow_val.abs() < 1e-10 && slow_val != 0.0 {
1929                        prop_assert!(
1930							y.is_nan() || y.abs() > 1000.0,
1931							"PPO should be NaN or very large when slow_ma ~0 at index {}: slow_ma={}, ppo={}",
1932							i, slow_val, y
1933						);
1934                    }
1935                }
1936
1937                let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0]);
1938                let is_monotonic_decreasing = data.windows(2).all(|w| w[1] <= w[0]);
1939
1940                if (is_monotonic_increasing || is_monotonic_decreasing)
1941                    && data.len() > slow_period * 2
1942                {
1943                    let last_values = &out[out.len() - slow_period / 2..];
1944                    let valid_last: Vec<f64> = last_values
1945                        .iter()
1946                        .filter(|x| x.is_finite())
1947                        .cloned()
1948                        .collect();
1949
1950                    if valid_last.len() > 2 {
1951                        if is_monotonic_increasing && fast_period < slow_period {
1952                            let avg = valid_last.iter().sum::<f64>() / valid_last.len() as f64;
1953                            prop_assert!(
1954                                avg > -1e-6,
1955                                "PPO should be positive for monotonic increasing data: avg={}",
1956                                avg
1957                            );
1958                        } else if is_monotonic_decreasing && fast_period < slow_period {
1959                            let avg = valid_last.iter().sum::<f64>() / valid_last.len() as f64;
1960                            prop_assert!(
1961                                avg < 1e-6,
1962                                "PPO should be negative for monotonic decreasing data: avg={}",
1963                                avg
1964                            );
1965                        }
1966                    }
1967                }
1968
1969                Ok(())
1970            })
1971            .unwrap();
1972
1973        Ok(())
1974    }
1975
// Generates the per-kernel `#[test]` wrappers for every single-series check function.
//
// Each `$test_fn` is invoked as `$test_fn(test_name, kernel)`. The wrappers return the
// check's `Result`, so an `Err` coming out of a check fails the test instead of being
// thrown away. The AVX2/AVX512 wrappers carry their own `#[cfg]`: an attribute written
// in front of a `$( ... )*` repetition only attaches to the first item it expands to.
macro_rules! generate_all_ppo_tests {
    ($($test_fn:ident),*) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)?;
                    Ok(())
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)?;
                    Ok(())
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)?;
                    Ok(())
                }
            )*
        }
    }
}
1999
// Instantiate the per-kernel `#[test]` wrappers for each single-series check function.
generate_all_ppo_tests!(
    check_ppo_partial_params,
    check_ppo_accuracy,
    check_ppo_default_candles,
    check_ppo_zero_period,
    check_ppo_period_exceeds_length,
    check_ppo_very_small_dataset,
    check_ppo_nan_handling,
    check_ppo_streaming,
    check_ppo_no_poison
);

// The property-based check is only instantiated when the `proptest` feature is enabled.
#[cfg(feature = "proptest")]
generate_all_ppo_tests!(check_ppo_property);
2014
2015    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2016        skip_if_unsupported!(kernel, test);
2017        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2018        let c = read_candles_from_csv(file)?;
2019        let output = PpoBatchBuilder::new()
2020            .kernel(kernel)
2021            .apply_candles(&c, "close")?;
2022        let def = PpoParams::default();
2023        let row = output.values_for(&def).expect("default row missing");
2024        assert_eq!(row.len(), c.close.len());
2025        let expected = [
2026            -0.8532313608928664,
2027            -0.8537562894550523,
2028            -0.6821291938174874,
2029            -0.5620008722078592,
2030            -0.4101724140910927,
2031        ];
2032        let start = row.len() - 5;
2033        for (i, &v) in row[start..].iter().enumerate() {
2034            assert!(
2035                (v - expected[i]).abs() < 1e-7,
2036                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2037            );
2038        }
2039        Ok(())
2040    }
2041
// Generates the `#[test]` wrappers for one batch check function: scalar-batch,
// AVX2-batch and AVX512-batch (the latter two only with `nightly-avx` on x86_64) and
// an auto-detect variant.
//
// `$fn_name` is invoked as `$fn_name(test_name, kernel)`. The wrappers return the
// check's `Result`, so an `Err` from the check fails the test instead of being dropped.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)?;
                Ok(())
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)?;
                Ok(())
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)?;
                Ok(())
            }

            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)?;
                Ok(())
            }
        }
    };
}
2062
2063    #[cfg(debug_assertions)]
2064    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2065        skip_if_unsupported!(kernel, test);
2066
2067        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2068        let c = read_candles_from_csv(file)?;
2069
2070        let test_configs = vec![
2071            (2, 10, 2, 12, 30, 3, "sma"),
2072            (5, 25, 5, 26, 50, 5, "sma"),
2073            (30, 60, 15, 65, 100, 10, "sma"),
2074            (2, 5, 1, 6, 10, 1, "ema"),
2075            (10, 20, 2, 21, 40, 4, "wma"),
2076            (3, 9, 3, 12, 21, 3, "sma"),
2077            (7, 14, 7, 21, 28, 7, "ema"),
2078        ];
2079
2080        for (cfg_idx, &(f_start, f_end, f_step, s_start, s_end, s_step, ma_type)) in
2081            test_configs.iter().enumerate()
2082        {
2083            let output = PpoBatchBuilder::new()
2084                .kernel(kernel)
2085                .fast_period_range(f_start, f_end, f_step)
2086                .slow_period_range(s_start, s_end, s_step)
2087                .ma_type(ma_type)
2088                .apply_candles(&c, "close")?;
2089
2090            for (idx, &val) in output.values.iter().enumerate() {
2091                if val.is_nan() {
2092                    continue;
2093                }
2094
2095                let bits = val.to_bits();
2096                let row = idx / output.cols;
2097                let col = idx % output.cols;
2098                let combo = &output.combos[row];
2099
2100                if bits == 0x11111111_11111111 {
2101                    panic!(
2102                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2103						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2104                        test,
2105                        cfg_idx,
2106                        val,
2107                        bits,
2108                        row,
2109                        col,
2110                        idx,
2111                        combo.fast_period.unwrap_or(12),
2112                        combo.slow_period.unwrap_or(26),
2113                        combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2114                    );
2115                }
2116
2117                if bits == 0x22222222_22222222 {
2118                    panic!(
2119                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2120						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2121                        test,
2122                        cfg_idx,
2123                        val,
2124                        bits,
2125                        row,
2126                        col,
2127                        idx,
2128                        combo.fast_period.unwrap_or(12),
2129                        combo.slow_period.unwrap_or(26),
2130                        combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2131                    );
2132                }
2133
2134                if bits == 0x33333333_33333333 {
2135                    panic!(
2136                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2137						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2138                        test,
2139                        cfg_idx,
2140                        val,
2141                        bits,
2142                        row,
2143                        col,
2144                        idx,
2145                        combo.fast_period.unwrap_or(12),
2146                        combo.slow_period.unwrap_or(26),
2147                        combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2148                    );
2149                }
2150            }
2151        }
2152
2153        Ok(())
2154    }
2155
// Release builds compile this no-op in place of the debug-only poison scan, so the
// `gen_batch_tests!(check_batch_no_poison)` invocation still has a function to call.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
2160
// Instantiate the scalar / AVX2 / AVX512 / auto-detect wrappers for each batch check.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
2163}
2164
// Python entry point `ppo(data, fast_period=None, slow_period=None, ma_type=None, kernel=None)`.
//
// Borrows `data` as a slice, resolves the kernel name through `validate_kernel`, runs
// `ppo_with_kernel` inside `allow_threads`, and returns the values as a new 1-D NumPy
// array. A failure of the computation is raised as `ValueError` carrying the error text.
#[cfg(feature = "python")]
#[pyfunction(name = "ppo")]
#[pyo3(signature = (data, fast_period=None, slow_period=None, ma_type=None, kernel=None))]
pub fn ppo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    fast_period: Option<usize>,
    slow_period: Option<usize>,
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let selected_kernel = validate_kernel(kernel, false)?;

    let input = PpoInput::from_slice(
        prices,
        PpoParams {
            fast_period,
            slow_period,
            ma_type: ma_type.map(String::from),
        },
    );

    let computed = py.allow_threads(|| ppo_with_kernel(&input, selected_kernel).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2192
// Python class exposed under the name `PpoStream`; it owns the Rust `PpoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "PpoStream")]
pub struct PpoStreamPy {
    // Wrapped streaming state that the `update` method forwards to.
    stream: PpoStream,
}
2198
#[cfg(feature = "python")]
#[pymethods]
impl PpoStreamPy {
    // Constructor for the Python `PpoStream` class.
    //
    // All three arguments are optional and default to `None`; the explicit signature
    // mirrors `ppo_py` instead of relying on PyO3's implicit default for trailing
    // `Option<T>` arguments. The values are forwarded unchanged inside `PpoParams`;
    // how `None` is resolved is up to `PpoStream::try_new`. A construction failure is
    // raised as `ValueError` carrying the error text.
    #[new]
    #[pyo3(signature = (fast_period=None, slow_period=None, ma_type=None))]
    fn new(
        fast_period: Option<usize>,
        slow_period: Option<usize>,
        ma_type: Option<&str>,
    ) -> PyResult<Self> {
        let params = PpoParams {
            fast_period,
            slow_period,
            ma_type: ma_type.map(|s| s.to_string()),
        };
        let stream =
            PpoStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PpoStreamPy { stream })
    }

    // Feeds one sample to the wrapped stream and returns whatever it yields
    // (`None` is surfaced to Python as `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2222
// Python handle around a `DeviceArrayF32Ppo`. It is declared with module name
// `ta_indicators.cuda` and marked `unsendable`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct PpoDeviceArrayF32Py {
    // Device buffer plus its rows/cols/context/device id; `__dlpack__` moves it out.
    pub(crate) inner: DeviceArrayF32Ppo,
}
2228
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl PpoDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict: a `(rows, cols)` shape, `"<f4"`
    // typestr, byte strides of `(cols * 4, 4)` (row-major, contiguous), the
    // `(device pointer, false)` data tuple and `version` 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(2, device_id)`. NOTE(review): 2 is presumably the DLPack CUDA device
    // type code; confirm against the DLPack device enum.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Exports the buffer through `export_f32_cuda_dlpack_2d`.
    //
    // The device buffer is moved out of `self` for the export and replaced by an empty
    // 0x0 placeholder, so after a successful call this object no longer holds the data.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        // A `dl_device` that parses as an `(i32, i32)` pair must match this buffer's
        // device; a value that does not parse is ignored. On mismatch the error type
        // depends on `copy`: `ValueError` if it is `True`, `BufferError` otherwise.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(pyo3::exceptions::PyBufferError::new_err(
                            "__dlpack__: requested device does not match producer buffer",
                        ));
                    }
                }
            }
        }

        // `copy=True` is never supported; a `copy` value that is not a bool propagates
        // the extraction error.
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract::<bool>(py)?;
            if do_copy {
                return Err(pyo3::exceptions::PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for ppo CUDA buffers",
                ));
            }
        }

        // An integer `stream` equal to 0 is rejected; any other `stream` value is not
        // acted on in this function.
        if let Some(s) = stream.as_ref() {
            if let Ok(i) = s.extract::<i64>(py) {
                if i == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__: stream 0 is disallowed for CUDA",
                    ));
                }
            }
        }

        // Swap an empty placeholder (same context and device id, 0 rows/cols) into
        // `self.inner` so the real buffer can be taken by value.
        let dev_id_u32 = self.inner.device_id;
        let ctx_clone = self.inner.ctx.clone();
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Ppo {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx: ctx_clone,
                device_id: dev_id_u32,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // The buffer is passed by value together with its shape and device id.
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2327
// Python entry point `ppo_batch(data, fast_period_range, slow_period_range, ma_type=None, kernel=None)`.
//
// Expands the `(start, end, step)` ranges into parameter combinations, computes one
// output row per combination directly into a NumPy buffer and returns a dict with
// `values` (reshaped to `(rows, cols)`), `fast_periods`, `slow_periods` and `ma_types`.
// `ma_type` falls back to `"sma"` when omitted. Grid-expansion and computation errors
// are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ppo_batch")]
#[pyo3(signature = (data, fast_period_range, slow_period_range, ma_type=None, kernel=None))]
pub fn ppo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;

    let sweep = PpoBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        ma_type: ma_type.unwrap_or("sma").to_string(),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in ppo_batch_py"))?;
    // The array is allocated without initialisation; the NaN prefix fill below and,
    // presumably, `ppo_batch_inner_into` write every element before it is returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // The first non-NaN index does not depend on the combination, so scan only once.
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    // Per-row warm-up length `first + slow - 1`, computed with saturating arithmetic
    // and clamped to the row length so that a zero or oversized slow period cannot
    // underflow or ask for a prefix longer than the row. Such parameters are expected
    // to be rejected by `ppo_batch_inner_into` below and surface as `ValueError`.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| {
            let slow = c.slow_period.unwrap();
            first.saturating_add(slow.saturating_sub(1)).min(cols)
        })
        .collect();

    unsafe {
        // View the freshly allocated f64 buffer as `MaybeUninit<f64>` for the prefix fill.
        let mu: &mut [MaybeUninit<f64>] = std::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            slice_out.len(),
        );
        init_matrix_prefixes(mu, cols, &warm);
    }

    let kern = validate_kernel(kernel, true)?;
    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel onto the per-row kernel handed to the inner routine.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                k if !k.is_batch() => k,
                _ => Kernel::Scalar,
            };
            ppo_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|p| p.fast_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|p| p.slow_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ma_types",
        combos
            .iter()
            .map(|p| p.ma_type.as_ref().unwrap().clone())
            .collect::<Vec<_>>(),
    )?;
    Ok(dict)
}
2418
// Python entry point `ppo_cuda_batch_dev(data_f32, fast_period_range, slow_period_range, ma_type="sma", device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is false.
// Otherwise it creates a `CudaPpo` for `device_id`, runs `ppo_batch_dev` inside
// `allow_threads`, and returns the device array handle together with a dict of
// `fast_periods`, `slow_periods` and `ma_types` for the computed combinations.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ppo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, fast_period_range, slow_period_range, ma_type="sma", device_id=0))]
pub fn ppo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    ma_type: &str,
    device_id: usize,
) -> PyResult<(PpoDeviceArrayF32Py, Bound<'py, PyDict>)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = PpoBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        ma_type: ma_type.to_string(),
    };

    let launched = py.allow_threads(|| {
        CudaPpo::new(device_id).and_then(|cuda| cuda.ppo_batch_dev(prices, &sweep))
    });
    let (inner, combos) = match launched {
        Ok(pair) => pair,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    // Collect the per-combination metadata columns in a single pass.
    let mut fast_periods: Vec<u64> = Vec::with_capacity(combos.len());
    let mut slow_periods: Vec<u64> = Vec::with_capacity(combos.len());
    let mut ma_types: Vec<String> = Vec::with_capacity(combos.len());
    for combo in combos.iter() {
        fast_periods.push(combo.fast_period.unwrap() as u64);
        slow_periods.push(combo.slow_period.unwrap() as u64);
        ma_types.push(combo.ma_type.as_ref().unwrap().clone());
    }

    let meta = PyDict::new(py);
    meta.set_item("fast_periods", fast_periods.into_pyarray(py))?;
    meta.set_item("slow_periods", slow_periods.into_pyarray(py))?;
    meta.set_item("ma_types", ma_types)?;

    Ok((PpoDeviceArrayF32Py { inner }, meta))
}
2470
// Python entry point `ppo_cuda_many_series_one_param_dev(data_tm_f32, fast_period, slow_period, ma_type="sma", device_id=0)`.
//
// Fails with `ValueError` when CUDA is unavailable or the array is not 2-D. Otherwise
// the flattened array is handed to `ppo_many_series_one_param_time_major_dev`, with the
// second dimension passed before the first, and the resulting device array is returned.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ppo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, fast_period, slow_period, ma_type="sma", device_id=0))]
pub fn ppo_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
    device_id: usize,
) -> PyResult<PpoDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match data_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("expected 2D array")),
    };
    let flat = data_tm_f32.as_slice()?;

    let params = PpoParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        ma_type: Some(ma_type.to_string()),
    };

    let launched = py.allow_threads(|| {
        CudaPpo::new(device_id)
            .and_then(|cuda| cuda.ppo_many_series_one_param_time_major_dev(flat, cols, rows, &params))
    });
    match launched {
        Ok(inner) => Ok(PpoDeviceArrayF32Py { inner }),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2506
// Adds the PPO Python bindings to module `m`: the `ppo` and `ppo_batch` functions and
// the `PpoStream` class, plus the CUDA device-array class and the two CUDA functions
// when the `cuda` feature is enabled.
#[cfg(feature = "python")]
pub fn register_ppo_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(ppo_py, m)?)?;
    m.add_function(wrap_pyfunction!(ppo_batch_py, m)?)?;
    m.add_class::<PpoStreamPy>()?;
    #[cfg(feature = "cuda")]
    {
        m.add_class::<PpoDeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(ppo_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(ppo_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
2520
// WASM entry point: computes PPO for `data` with explicit periods and MA type and
// returns a vector of the same length. Errors are converted to a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ppo_js(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let input = PpoInput::from_slice(
        data,
        PpoParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            ma_type: Some(ma_type.to_string()),
        },
    );

    // Zero-filled scratch that `ppo_into_slice` writes the result into.
    let mut values = vec![0.0; data.len()];
    match ppo_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2540
// Allocates capacity for `len` f64 values and returns the raw pointer without
// freeing it; the allocation is expected to be released later through `ppo_free`
// with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ppo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the vector's destructor from running, leaking the buffer
    // on purpose so the caller owns it.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2549
// Releases a buffer previously obtained from `ppo_alloc(len)`. A null `ptr` is ignored.
//
// The caller must pass the pointer returned by `ppo_alloc` together with the same
// `len`, and must not use the pointer afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ppo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` comes from a `Vec<f64>` allocated with
    // capacity `len`. The vector is rebuilt with length 0 rather than `len`:
    // `Vec::from_raw_parts` requires the first `length` elements to be initialised,
    // which is not guaranteed for a buffer created with `with_capacity`. Dropping the
    // vector frees the whole allocation either way.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2559
// Computes PPO for the `len` values at `in_ptr` and writes `len` results to `out_ptr`.
//
// Returns an error string when either pointer is null or the computation fails.
// The two buffers may alias: when their byte ranges overlap in any way (not only
// when the pointers are equal) the result is computed into a temporary vector first
// and copied out afterwards, so no shared and mutable view of the same memory is
// used at the same time.
//
// The caller must make sure both pointers are valid for `len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ppo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ppo_into"));
    }

    // Half-open byte ranges of the two buffers; saturating so a huge `len` cannot wrap.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = PpoParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            ma_type: Some(ma_type.to_string()),
        };
        let input = PpoInput::from_slice(data, params);
        if overlaps {
            let mut tmp = vec![0.0; len];
            ppo_into_slice(&mut tmp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input view is not read past this point, so the mutable view of the
            // (overlapping) output can be created now.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            ppo_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2595
// Batch configuration deserialized from the JS `config` value in `ppo_batch_unified_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PpoBatchConfig {
    // Copied into `PpoBatchRange::fast_period`; `(start, end, step)` as in `ppo_batch_into`.
    pub fast_period_range: (usize, usize, usize),
    // Copied into `PpoBatchRange::slow_period`; `(start, end, step)` as in `ppo_batch_into`.
    pub slow_period_range: (usize, usize, usize),
    // Moving-average name copied into `PpoBatchRange::ma_type`.
    pub ma_type: String,
}
2603
// Serialized result of `ppo_batch_unified_js`; every field is copied from the output
// of `ppo_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PpoBatchJsOutput {
    // Flat output values. NOTE(review): presumably row-major with `rows * cols` entries.
    pub values: Vec<f64>,
    // Parameter combinations reported by the batch routine.
    pub combos: Vec<PpoParams>,
    pub rows: usize,
    pub cols: usize,
}
2612
// WASM entry point exported as `ppo_batch`: deserializes a `PpoBatchConfig` from
// `config`, runs `ppo_batch_inner` over `data` and returns the result serialized as a
// `PpoBatchJsOutput`. Every failure is reported as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "ppo_batch")]
pub fn ppo_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = match serde_wasm_bindgen::from_value::<PpoBatchConfig>(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = PpoBatchRange {
        fast_period: parsed.fast_period_range,
        slow_period: parsed.slow_period_range,
        ma_type: parsed.ma_type,
    };

    let batch = match ppo_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = PpoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2634
// Runs the PPO batch sweep over the `len` values at `in_ptr` and writes
// `rows * len` results to `out_ptr`, where `rows` is the number of parameter
// combinations produced by `expand_grid`. Returns `rows`.
//
// Returns an error string when either pointer is null, the grid cannot be expanded,
// `rows * len` overflows, or the computation fails. If the input and output byte
// ranges overlap, the input is copied to a temporary vector first, so the input is
// neither aliased by the mutable output view nor overwritten while it is being read.
//
// The caller must make sure `in_ptr` is valid for `len` f64 values and `out_ptr` for
// `rows * len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ppo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    fast_start: usize,
    fast_end: usize,
    fast_step: usize,
    slow_start: usize,
    slow_end: usize,
    slow_step: usize,
    ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ppo_batch_into"));
    }

    let sweep = PpoBatchRange {
        fast_period: (fast_start, fast_end, fast_step),
        slow_period: (slow_start, slow_end, slow_step),
        ma_type: ma_type.to_string(),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in ppo_batch_into"))?;

    // Half-open byte ranges of both buffers; saturating so bogus sizes cannot wrap.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    unsafe {
        if overlaps {
            // Detach the input from the shared memory before the output is borrowed mutably.
            let owned: Vec<f64> = std::slice::from_raw_parts(in_ptr, len).to_vec();
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            ppo_batch_inner_into(owned.as_slice(), &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            ppo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}