1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::oscillators::ppo_wrapper::{CudaPpo, DeviceArrayF32Ppo};
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
7#[cfg(all(feature = "python", feature = "cuda"))]
8use cust::memory::DeviceBuffer;
9#[cfg(all(feature = "python", feature = "cuda"))]
10use numpy::PyUntypedArrayMethods;
11#[cfg(feature = "python")]
12use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
13#[cfg(feature = "python")]
14use pyo3::exceptions::PyValueError;
15#[cfg(feature = "python")]
16use pyo3::prelude::*;
17#[cfg(feature = "python")]
18use pyo3::types::{PyAny, PyDict};
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use serde::{Deserialize, Serialize};
21#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
22use wasm_bindgen::prelude::*;
23
24use crate::indicators::moving_averages::ma::{ma, MaData};
25use crate::utilities::data_loader::{source_type, Candles};
26use crate::utilities::enums::Kernel;
27use crate::utilities::helpers::{
28 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
29 make_uninit_matrix,
30};
31#[cfg(feature = "python")]
32use crate::utilities::kernel_validation::validate_kernel;
33#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
34use core::arch::x86_64::*;
35#[cfg(not(target_arch = "wasm32"))]
36use rayon::prelude::*;
37use std::collections::HashMap;
38use std::convert::AsRef;
39use std::error::Error;
40use std::mem::MaybeUninit;
41use thiserror::Error;
42
43impl<'a> AsRef<[f64]> for PpoInput<'a> {
44 #[inline(always)]
45 fn as_ref(&self) -> &[f64] {
46 match &self.data {
47 PpoData::Slice(slice) => slice,
48 PpoData::Candles { candles, source } => source_type(candles, source),
49 }
50 }
51}
52
/// Source of the input price series for PPO.
#[derive(Debug, Clone)]
pub enum PpoData<'a> {
    /// Read a named source column (e.g. "close") out of a candle set.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Use a caller-provided price slice directly.
    Slice(&'a [f64]),
}
61
/// Result of a single PPO computation: one value per input sample, with the
/// warmup region (before `first + slow - 1`) filled with NaN.
#[derive(Debug, Clone)]
pub struct PpoOutput {
    pub values: Vec<f64>,
}
66
/// PPO parameters; `None` fields fall back to the classic 12/26 "sma"
/// defaults (see `Default` and the `get_*` accessors on `PpoInput`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct PpoParams {
    // Fast moving-average period (default 12).
    pub fast_period: Option<usize>,
    // Slow moving-average period (default 26).
    pub slow_period: Option<usize>,
    // Moving-average type name understood by `ma` (default "sma").
    pub ma_type: Option<String>,
}
77
impl Default for PpoParams {
    /// Classic PPO defaults: fast = 12, slow = 26, simple moving average.
    fn default() -> Self {
        Self {
            fast_period: Some(12),
            slow_period: Some(26),
            ma_type: Some("sma".to_string()),
        }
    }
}
87
/// A PPO request: where the prices come from plus the parameters to use.
#[derive(Debug, Clone)]
pub struct PpoInput<'a> {
    pub data: PpoData<'a>,
    pub params: PpoParams,
}
93
94impl<'a> PpoInput<'a> {
95 #[inline]
96 pub fn from_candles(c: &'a Candles, s: &'a str, p: PpoParams) -> Self {
97 Self {
98 data: PpoData::Candles {
99 candles: c,
100 source: s,
101 },
102 params: p,
103 }
104 }
105 #[inline]
106 pub fn from_slice(sl: &'a [f64], p: PpoParams) -> Self {
107 Self {
108 data: PpoData::Slice(sl),
109 params: p,
110 }
111 }
112 #[inline]
113 pub fn with_default_candles(c: &'a Candles) -> Self {
114 Self::from_candles(c, "close", PpoParams::default())
115 }
116 #[inline]
117 pub fn get_fast_period(&self) -> usize {
118 self.params.fast_period.unwrap_or(12)
119 }
120 #[inline]
121 pub fn get_slow_period(&self) -> usize {
122 self.params.slow_period.unwrap_or(26)
123 }
124 #[inline]
125 pub fn get_ma_type(&self) -> String {
126 self.params
127 .ma_type
128 .clone()
129 .unwrap_or_else(|| "sma".to_string())
130 }
131}
132
/// Fluent builder for a single PPO run; unset fields use library defaults.
#[derive(Clone, Debug)]
pub struct PpoBuilder {
    fast_period: Option<usize>,
    slow_period: Option<usize>,
    ma_type: Option<String>,
    // Compute kernel; `Kernel::Auto` lets the dispatcher choose.
    kernel: Kernel,
}
140
impl Default for PpoBuilder {
    /// All parameters unset (deferring to `PpoParams` defaults), auto kernel.
    fn default() -> Self {
        Self {
            fast_period: None,
            slow_period: None,
            ma_type: None,
            kernel: Kernel::Auto,
        }
    }
}
151
152impl PpoBuilder {
153 #[inline(always)]
154 pub fn new() -> Self {
155 Self::default()
156 }
157 #[inline(always)]
158 pub fn fast_period(mut self, n: usize) -> Self {
159 self.fast_period = Some(n);
160 self
161 }
162 #[inline(always)]
163 pub fn slow_period(mut self, n: usize) -> Self {
164 self.slow_period = Some(n);
165 self
166 }
167 #[inline(always)]
168 pub fn ma_type<S: Into<String>>(mut self, s: S) -> Self {
169 self.ma_type = Some(s.into());
170 self
171 }
172 #[inline(always)]
173 pub fn kernel(mut self, k: Kernel) -> Self {
174 self.kernel = k;
175 self
176 }
177 #[inline(always)]
178 pub fn apply(self, c: &Candles) -> Result<PpoOutput, PpoError> {
179 let p = PpoParams {
180 fast_period: self.fast_period,
181 slow_period: self.slow_period,
182 ma_type: self.ma_type,
183 };
184 let i = PpoInput::from_candles(c, "close", p);
185 ppo_with_kernel(&i, self.kernel)
186 }
187 #[inline(always)]
188 pub fn apply_slice(self, d: &[f64]) -> Result<PpoOutput, PpoError> {
189 let p = PpoParams {
190 fast_period: self.fast_period,
191 slow_period: self.slow_period,
192 ma_type: self.ma_type,
193 };
194 let i = PpoInput::from_slice(d, p);
195 ppo_with_kernel(&i, self.kernel)
196 }
197 #[inline(always)]
198 pub fn into_stream(self) -> Result<PpoStream, PpoError> {
199 let p = PpoParams {
200 fast_period: self.fast_period,
201 slow_period: self.slow_period,
202 ma_type: self.ma_type,
203 };
204 PpoStream::try_new(p)
205 }
206}
207
/// All failure modes of the PPO entry points.
#[derive(Debug, Error)]
pub enum PpoError {
    /// Input slice / source column was empty.
    #[error("ppo: Empty data provided.")]
    EmptyInputData,
    /// No non-NaN sample exists in the input.
    #[error("ppo: All values are NaN.")]
    AllValuesNaN,
    /// A period was zero or exceeded the data length.
    #[error("ppo: Invalid period: fast = {fast}, slow = {slow}, data length = {data_len}")]
    InvalidPeriod {
        fast: usize,
        slow: usize,
        data_len: usize,
    },
    /// Fewer than `slow` non-NaN samples available after the leading NaNs.
    #[error("ppo: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Caller-provided output buffer had the wrong length.
    #[error("ppo: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep axis could not be expanded.
    #[error("ppo: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("ppo: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Catch-all for malformed arguments (e.g. grid-size overflow).
    #[error("ppo: invalid input: {0}")]
    InvalidInput(String),
    /// Propagated failure from the generic moving-average helper.
    #[error("ppo: MA error: {0}")]
    MaError(String),
}
237
/// Computes the Percentage Price Oscillator with automatic kernel selection.
///
/// Convenience wrapper over [`ppo_with_kernel`] using `Kernel::Auto`.
#[inline]
pub fn ppo(input: &PpoInput) -> Result<PpoOutput, PpoError> {
    ppo_with_kernel(input, Kernel::Auto)
}
242
243pub fn ppo_with_kernel(input: &PpoInput, kernel: Kernel) -> Result<PpoOutput, PpoError> {
244 let data: &[f64] = input.as_ref();
245 let len = data.len();
246 if len == 0 {
247 return Err(PpoError::EmptyInputData);
248 }
249 let first = data
250 .iter()
251 .position(|x| !x.is_nan())
252 .ok_or(PpoError::AllValuesNaN)?;
253
254 let fast = input.get_fast_period();
255 let slow = input.get_slow_period();
256 let ma_type = input.params.ma_type.as_deref().unwrap_or("sma");
257
258 if fast == 0 || slow == 0 || fast > len || slow > len {
259 return Err(PpoError::InvalidPeriod {
260 fast,
261 slow,
262 data_len: len,
263 });
264 }
265 if (len - first) < slow {
266 return Err(PpoError::NotEnoughValidData {
267 needed: slow,
268 valid: len - first,
269 });
270 }
271
272 let chosen = match kernel {
273 Kernel::Auto => Kernel::Scalar,
274 k => k,
275 };
276
277 let mut out = alloc_with_nan_prefix(len, first + slow - 1);
278
279 unsafe {
280 match chosen {
281 Kernel::Scalar | Kernel::ScalarBatch => {
282 ppo_scalar(data, fast, slow, ma_type, first, &mut out)
283 }
284 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
285 Kernel::Avx2 | Kernel::Avx2Batch => {
286 ppo_avx2(data, fast, slow, ma_type, first, &mut out)
287 }
288 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289 Kernel::Avx512 | Kernel::Avx512Batch => {
290 ppo_avx512(data, fast, slow, ma_type, first, &mut out)
291 }
292 _ => unreachable!(),
293 }
294 }
295
296 Ok(PpoOutput { values: out })
297}
298
299pub fn ppo_into_slice(dst: &mut [f64], input: &PpoInput, kern: Kernel) -> Result<(), PpoError> {
300 let data = input.as_ref();
301 if data.is_empty() {
302 return Err(PpoError::EmptyInputData);
303 }
304
305 let fast = input.get_fast_period();
306 let slow = input.get_slow_period();
307 let ma_type = input.params.ma_type.as_deref().unwrap_or("sma");
308
309 if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
310 return Err(PpoError::InvalidPeriod {
311 fast,
312 slow,
313 data_len: data.len(),
314 });
315 }
316 if dst.len() != data.len() {
317 return Err(PpoError::OutputLengthMismatch {
318 expected: data.len(),
319 got: dst.len(),
320 });
321 }
322
323 let first = data
324 .iter()
325 .position(|x| !x.is_nan())
326 .ok_or(PpoError::AllValuesNaN)?;
327 if data.len() - first < slow {
328 return Err(PpoError::NotEnoughValidData {
329 needed: slow,
330 valid: data.len() - first,
331 });
332 }
333
334 let chosen = match kern {
335 Kernel::Auto => Kernel::Scalar,
336 k => k,
337 };
338
339 unsafe {
340 match chosen {
341 Kernel::Scalar | Kernel::ScalarBatch => {
342 ppo_scalar(data, fast, slow, ma_type, first, dst)
343 }
344 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345 Kernel::Avx2 | Kernel::Avx2Batch => ppo_avx2(data, fast, slow, ma_type, first, dst),
346 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
347 Kernel::Avx512 | Kernel::Avx512Batch => {
348 ppo_avx512(data, fast, slow, ma_type, first, dst)
349 }
350 _ => unreachable!(),
351 }
352 }
353
354 let warmup_end = first + slow - 1;
355 for v in &mut dst[..warmup_end] {
356 *v = f64::from_bits(0x7ff8_0000_0000_0000);
357 }
358 Ok(())
359}
360
/// Writes PPO values into `out` with automatic kernel selection.
///
/// Thin wrapper over [`ppo_into_slice`]; `out` must match the data length.
/// Not compiled for wasm builds.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn ppo_into(input: &PpoInput, out: &mut [f64]) -> Result<(), PpoError> {
    ppo_into_slice(out, input, Kernel::Auto)
}
366
367#[inline]
368pub unsafe fn ppo_scalar(
369 data: &[f64],
370 fast: usize,
371 slow: usize,
372 ma_type: &str,
373 first: usize,
374 out: &mut [f64],
375) {
376 if ma_type == "ema" {
377 ppo_scalar_classic_ema(data, fast, slow, first, out);
378 return;
379 } else if ma_type == "sma" {
380 ppo_scalar_classic_sma(data, fast, slow, first, out);
381 return;
382 }
383
384 let start = first + slow - 1;
385 let fast_ma = match ma(ma_type, MaData::Slice(data), fast) {
386 Ok(v) => v,
387 Err(_) => {
388 let mut i = start;
389 while i < data.len() {
390 *out.get_unchecked_mut(i) = f64::NAN;
391 i += 1;
392 }
393 return;
394 }
395 };
396 let slow_ma = match ma(ma_type, MaData::Slice(data), slow) {
397 Ok(v) => v,
398 Err(_) => {
399 let mut i = start;
400 while i < data.len() {
401 *out.get_unchecked_mut(i) = f64::NAN;
402 i += 1;
403 }
404 return;
405 }
406 };
407
408 let n = data.len();
409 let mut i = start;
410 while i < n {
411 let sf = *slow_ma.get_unchecked(i);
412 let ff = *fast_ma.get_unchecked(i);
413 let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
414 f64::NAN
415 } else {
416 let ratio = ff / sf;
417 f64::mul_add(ratio, 100.0, -100.0)
418 };
419 *out.get_unchecked_mut(i) = y;
420 i += 1;
421 }
422}
423
/// Classic EMA-based PPO kernel: seeds both EMAs from SMAs over the warmup
/// window, then rolls them forward one sample at a time.
///
/// # Safety
/// Caller must guarantee `out.len() == data.len()`, `first + slow - 1 <
/// data.len()`, and `fast <= slow` (the `slow - fast` below underflows
/// otherwise).
#[inline]
pub unsafe fn ppo_scalar_classic_ema(
    data: &[f64],
    fast: usize,
    slow: usize,
    first: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let start_idx = first + slow - 1;

    // Standard EMA smoothing factors: alpha = 2 / (period + 1).
    let fa = 2.0 / (fast as f64 + 1.0);
    let sa = 2.0 / (slow as f64 + 1.0);
    let fb = 1.0 - fa;
    let sb = 1.0 - sa;

    // One pass over the seed window [first, first + slow): the slow sum
    // covers the whole window, the fast sum only its trailing `fast` samples.
    let mut slow_sum = 0.0f64;
    let mut fast_sum = 0.0f64;
    let overlap = slow - fast;
    let mut k = 0usize;
    while k < slow {
        let v = *data.get_unchecked(first + k);
        slow_sum += v;
        if k >= overlap {
            fast_sum += v;
        }
        k += 1;
    }

    // Seed both EMAs with their SMA averages.
    let mut fast_ema = fast_sum / (fast as f64);
    let mut slow_ema = slow_sum / (slow as f64);

    // Replay EMA updates for the fast average up to the first output index.
    // NOTE(review): the fast seed above averages the *last* `fast` samples of
    // the slow window, yet this replay starts at `first + fast` and so
    // re-applies samples that are already inside that seed window — confirm
    // this matches the reference implementation the kernel was validated
    // against.
    let mut i = first + fast;
    while i <= start_idx {
        let x = *data.get_unchecked(i);
        fast_ema = f64::mul_add(fa, x, fb * fast_ema);
        i += 1;
    }

    // First defined output: PPO = (fast / slow) * 100 - 100, NaN/zero guarded.
    *out.get_unchecked_mut(start_idx) = if slow_ema == 0.0 || slow_ema.is_nan() || fast_ema.is_nan()
    {
        f64::NAN
    } else {
        let ratio = fast_ema / slow_ema;
        f64::mul_add(ratio, 100.0, -100.0)
    };

    // Steady state: advance both EMAs by one sample and emit.
    let mut j = start_idx + 1;
    while j < n {
        let x = *data.get_unchecked(j);
        fast_ema = f64::mul_add(fa, x, fb * fast_ema);
        slow_ema = f64::mul_add(sa, x, sb * slow_ema);

        let y = if slow_ema == 0.0 || slow_ema.is_nan() || fast_ema.is_nan() {
            f64::NAN
        } else {
            let ratio = fast_ema / slow_ema;
            f64::mul_add(ratio, 100.0, -100.0)
        };
        *out.get_unchecked_mut(j) = y;
        j += 1;
    }
}
487
/// Classic SMA-based PPO kernel using two rolling window sums.
///
/// Because (fast_sum/fast) / (slow_sum/slow) == fast_sum * (slow/fast) /
/// slow_sum, a single precomputed scale factor avoids dividing each sum by
/// its own period at every step.
///
/// # Safety
/// Caller must guarantee `out.len() == data.len()`, `first + slow - 1 <
/// data.len()`, and `fast <= slow` (the `slow - fast` below underflows
/// otherwise).
#[inline]
pub unsafe fn ppo_scalar_classic_sma(
    data: &[f64],
    fast: usize,
    slow: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    let start_idx = first + slow - 1;

    // Scale factor folding both period divisions into one multiply.
    let scale = (slow as f64) / (fast as f64);

    // One pass over the seed window accumulates both sums; the fast window
    // is the trailing `fast` samples of the slow window.
    let lead = slow - fast;
    let mut sum_slow = 0.0f64;
    let mut sum_fast = 0.0f64;
    for idx in 0..slow {
        let v = *data.get_unchecked(first + idx);
        sum_slow += v;
        if idx >= lead {
            sum_fast += v;
        }
    }

    // PPO from the two window sums, NaN/zero-denominator guarded.
    let ppo_of = |f_sum: f64, s_sum: f64| -> f64 {
        if s_sum == 0.0 || s_sum.is_nan() || f_sum.is_nan() {
            f64::NAN
        } else {
            f64::mul_add((f_sum * scale) / s_sum, 100.0, -100.0)
        }
    };

    *out.get_unchecked_mut(start_idx) = ppo_of(sum_fast, sum_slow);

    // Slide both windows forward one sample at a time.
    for i in (start_idx + 1)..len {
        let incoming = *data.get_unchecked(i);
        sum_fast += incoming - *data.get_unchecked(i - fast);
        sum_slow += incoming - *data.get_unchecked(i - slow);
        *out.get_unchecked_mut(i) = ppo_of(sum_fast, sum_slow);
    }
}
543
/// AVX2 entry point — currently a placeholder that delegates to the scalar
/// kernel (the PPO recurrence is sequential, so no SIMD body exists yet).
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx2(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
556
/// AVX-512 entry point: routes to a short- or long-period variant on a
/// `slow <= 32` heuristic (both currently delegate to scalar).
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    if slow <= 32 {
        ppo_avx512_short(data, fast, slow, ma_type, first, out)
    } else {
        ppo_avx512_long(data, fast, slow, ma_type, first, out)
    }
}
573
/// Short-period AVX-512 variant — placeholder delegating to scalar.
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512_short(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
586
/// Long-period AVX-512 variant — placeholder delegating to scalar.
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ppo_avx512_long(
    data: &[f64],
    fast: usize,
    slow: usize,
    ma_type: &str,
    first: usize,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
599
/// Parameter sweep for batch PPO. Each period axis is (start, end, step);
/// step == 0 or start == end pins the axis to a single value.
#[derive(Clone, Debug)]
pub struct PpoBatchRange {
    pub fast_period: (usize, usize, usize),
    pub slow_period: (usize, usize, usize),
    // One MA type shared by every combo in the sweep.
    pub ma_type: String,
}
606
impl Default for PpoBatchRange {
    /// Fast fixed at 12; slow swept 26..=275 in steps of 1; SMA.
    fn default() -> Self {
        Self {
            fast_period: (12, 12, 0),
            slow_period: (26, 275, 1),
            ma_type: "sma".to_string(),
        }
    }
}
616
/// Fluent builder for a batch PPO sweep.
#[derive(Clone, Debug, Default)]
pub struct PpoBatchBuilder {
    range: PpoBatchRange,
    // Batch kernel selection; the derived Default yields `Kernel`'s default.
    kernel: Kernel,
}
622
impl PpoBatchBuilder {
    /// Creates a builder with the default sweep and kernel.
    pub fn new() -> Self {
        Self::default()
    }
    /// Selects the batch kernel.
    pub fn kernel(mut self, k: Kernel) -> Self {
        self.kernel = k;
        self
    }
    /// Sets the fast-period axis as (start, end, step).
    pub fn fast_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
        self.range.fast_period = (start, end, step);
        self
    }
    /// Sets the slow-period axis as (start, end, step).
    pub fn slow_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
        self.range.slow_period = (start, end, step);
        self
    }
    /// Sets the MA type shared by every combo.
    pub fn ma_type<S: Into<String>>(mut self, t: S) -> Self {
        self.range.ma_type = t.into();
        self
    }
    /// Runs the sweep over a raw price slice.
    pub fn apply_slice(self, data: &[f64]) -> Result<PpoBatchOutput, PpoError> {
        ppo_batch_with_kernel(data, &self.range, self.kernel)
    }
    /// Default sweep over a slice with an explicit kernel.
    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<PpoBatchOutput, PpoError> {
        PpoBatchBuilder::new().kernel(k).apply_slice(data)
    }
    /// Runs the sweep over a candle source column.
    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<PpoBatchOutput, PpoError> {
        let slice = source_type(c, src);
        self.apply_slice(slice)
    }
    /// Default sweep over the close column with auto kernel.
    pub fn with_default_candles(c: &Candles) -> Result<PpoBatchOutput, PpoError> {
        PpoBatchBuilder::new()
            .kernel(Kernel::Auto)
            .apply_candles(c, "close")
    }
}
659
/// Result of a batch sweep: a row-major `rows x cols` matrix where row `i`
/// holds the PPO series for `combos[i]` and `cols == data.len()`.
#[derive(Clone, Debug)]
pub struct PpoBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<PpoParams>,
    pub rows: usize,
    pub cols: usize,
}
667
668impl PpoBatchOutput {
669 pub fn row_for_params(&self, p: &PpoParams) -> Option<usize> {
670 self.combos.iter().position(|c| {
671 c.fast_period.unwrap_or(12) == p.fast_period.unwrap_or(12)
672 && c.slow_period.unwrap_or(26) == p.slow_period.unwrap_or(26)
673 && c.ma_type.as_ref().unwrap_or(&"sma".to_string())
674 == p.ma_type.as_ref().unwrap_or(&"sma".to_string())
675 })
676 }
677 pub fn values_for(&self, p: &PpoParams) -> Option<&[f64]> {
678 self.row_for_params(p).and_then(|row| {
679 let start = row.checked_mul(self.cols)?;
680 let end = start.checked_add(self.cols)?;
681 self.values.get(start..end)
682 })
683 }
684}
685
686#[inline(always)]
687fn expand_grid(r: &PpoBatchRange) -> Result<Vec<PpoParams>, PpoError> {
688 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
689 if step == 0 || start == end {
690 return vec![start];
691 }
692 if start <= end {
693 (start..=end).step_by(step).collect()
694 } else {
695 let mut v = Vec::new();
696 let mut cur = start;
697 loop {
698 v.push(cur);
699 if cur <= end {
700 break;
701 }
702 let next = match cur.checked_sub(step) {
703 Some(n) => n,
704 None => break,
705 };
706 if next < end {
707 break;
708 }
709 cur = next;
710 }
711 v
712 }
713 }
714
715 let fasts = axis_usize(r.fast_period);
716 let slows = axis_usize(r.slow_period);
717 if fasts.is_empty() {
718 let (start, end, step) = r.fast_period;
719 return Err(PpoError::InvalidRange { start, end, step });
720 }
721 if slows.is_empty() {
722 let (start, end, step) = r.slow_period;
723 return Err(PpoError::InvalidRange { start, end, step });
724 }
725
726 let ma_type = r.ma_type.clone();
727
728 let total = fasts
729 .len()
730 .checked_mul(slows.len())
731 .ok_or_else(|| PpoError::InvalidInput("fast*slow grid overflow".into()))?;
732 let mut out = Vec::with_capacity(total);
733 for &f in &fasts {
734 for &s in &slows {
735 out.push(PpoParams {
736 fast_period: Some(f),
737 slow_period: Some(s),
738 ma_type: Some(ma_type.clone()),
739 });
740 }
741 }
742 Ok(out)
743}
744
745#[inline(always)]
746pub fn ppo_batch_with_kernel(
747 data: &[f64],
748 sweep: &PpoBatchRange,
749 k: Kernel,
750) -> Result<PpoBatchOutput, PpoError> {
751 let kernel = match k {
752 Kernel::Auto => detect_best_batch_kernel(),
753 other if other.is_batch() => other,
754 other => return Err(PpoError::InvalidKernelForBatch(other)),
755 };
756 let simd = match kernel {
757 Kernel::Avx512Batch => Kernel::Avx512,
758 Kernel::Avx2Batch => Kernel::Avx2,
759 Kernel::ScalarBatch => Kernel::Scalar,
760 _ => unreachable!(),
761 };
762 ppo_batch_par_slice(data, sweep, simd)
763}
764
/// Sequential (single-threaded) batch PPO over a parameter sweep.
#[inline(always)]
pub fn ppo_batch_slice(
    data: &[f64],
    sweep: &PpoBatchRange,
    kern: Kernel,
) -> Result<PpoBatchOutput, PpoError> {
    ppo_batch_inner(data, sweep, kern, false)
}
773
/// Parallel batch PPO over a parameter sweep (rayon; sequential on wasm).
#[inline(always)]
pub fn ppo_batch_par_slice(
    data: &[f64],
    sweep: &PpoBatchRange,
    kern: Kernel,
) -> Result<PpoBatchOutput, PpoError> {
    ppo_batch_inner(data, sweep, kern, true)
}
782
783#[inline(always)]
784fn ppo_batch_inner_into(
785 data: &[f64],
786 sweep: &PpoBatchRange,
787 kern: Kernel,
788 parallel: bool,
789 out: &mut [f64],
790) -> Result<Vec<PpoParams>, PpoError> {
791 let combos = expand_grid(sweep)?;
792 if combos.is_empty() {
793 let (start, end, step) = sweep.fast_period;
794 return Err(PpoError::InvalidRange { start, end, step });
795 }
796
797 let first = data
798 .iter()
799 .position(|x| !x.is_nan())
800 .ok_or(PpoError::AllValuesNaN)?;
801 let max_slow = combos.iter().map(|c| c.slow_period.unwrap()).max().unwrap();
802 if data.len() - first < max_slow {
803 return Err(PpoError::NotEnoughValidData {
804 needed: max_slow,
805 valid: data.len() - first,
806 });
807 }
808
809 let rows = combos.len();
810 let cols = data.len();
811 let expected = rows.checked_mul(cols).ok_or_else(|| {
812 PpoError::InvalidInput("rows*cols overflow in ppo_batch_inner_into".into())
813 })?;
814 if out.len() != expected {
815 return Err(PpoError::OutputLengthMismatch {
816 expected,
817 got: out.len(),
818 });
819 }
820
821 let out_uninit: &mut [MaybeUninit<f64>] = unsafe {
822 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
823 };
824
825 let ma_type = sweep.ma_type.as_str();
826 let use_cached = ma_type == "sma" || ma_type == "ema";
827 let mut ma_cache: HashMap<usize, Vec<f64>> = HashMap::new();
828 if use_cached {
829 let mut uniq: Vec<usize> = Vec::new();
830 for c in &combos {
831 let f = c.fast_period.unwrap();
832 let s = c.slow_period.unwrap();
833 if !uniq.contains(&f) {
834 uniq.push(f);
835 }
836 if !uniq.contains(&s) {
837 uniq.push(s);
838 }
839 }
840 for &p in &uniq {
841 if let Ok(v) = ma(ma_type, MaData::Slice(data), p) {
842 ma_cache.insert(p, v);
843 } else {
844 let mut v = vec![f64::NAN; data.len()];
845 ma_cache.insert(p, v);
846 }
847 }
848 }
849
850 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
851 let p = &combos[row];
852 let out_row: &mut [f64] =
853 std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
854 if use_cached {
855 let fast = p.fast_period.unwrap();
856 let slow = p.slow_period.unwrap();
857 let fast_ma = ma_cache.get(&fast).unwrap();
858 let slow_ma = ma_cache.get(&slow).unwrap();
859 let mut i = first + slow - 1;
860 let n = data.len();
861 while i < n {
862 let sf = *slow_ma.get_unchecked(i);
863 let ff = *fast_ma.get_unchecked(i);
864 let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
865 f64::NAN
866 } else {
867 let ratio = ff / sf;
868 f64::mul_add(ratio, 100.0, -100.0)
869 };
870 *out_row.get_unchecked_mut(i) = y;
871 i += 1;
872 }
873 } else {
874 match kern {
875 Kernel::Scalar => ppo_row_scalar(
876 data,
877 first,
878 p.fast_period.unwrap(),
879 p.slow_period.unwrap(),
880 p.ma_type.as_ref().unwrap(),
881 out_row,
882 ),
883 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
884 Kernel::Avx2 => ppo_row_avx2(
885 data,
886 first,
887 p.fast_period.unwrap(),
888 p.slow_period.unwrap(),
889 p.ma_type.as_ref().unwrap(),
890 out_row,
891 ),
892 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
893 Kernel::Avx512 => ppo_row_avx512(
894 data,
895 first,
896 p.fast_period.unwrap(),
897 p.slow_period.unwrap(),
898 p.ma_type.as_ref().unwrap(),
899 out_row,
900 ),
901 _ => unreachable!(),
902 }
903 }
904 };
905
906 if parallel {
907 #[cfg(not(target_arch = "wasm32"))]
908 {
909 out_uninit
910 .par_chunks_mut(cols)
911 .enumerate()
912 .for_each(|(row, slice)| do_row(row, slice));
913 }
914 #[cfg(target_arch = "wasm32")]
915 {
916 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
917 do_row(row, slice);
918 }
919 }
920 } else {
921 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
922 do_row(row, slice);
923 }
924 }
925
926 Ok(combos)
927}
928
/// Core batch driver: evaluates every combo in `sweep` over `data` into a
/// freshly allocated row-major `rows x cols` matrix.
///
/// The matrix starts uninitialized; each row's warmup prefix is NaN-filled
/// via `init_matrix_prefixes`, and the row kernels write every remaining
/// cell, so the whole buffer is initialized before it is returned.
#[inline(always)]
fn ppo_batch_inner(
    data: &[f64],
    sweep: &PpoBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<PpoBatchOutput, PpoError> {
    let combos = expand_grid(sweep)?;
    // Defensive: expand_grid currently always yields at least one combo.
    if combos.is_empty() {
        let (start, end, step) = sweep.fast_period;
        return Err(PpoError::InvalidRange { start, end, step });
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(PpoError::AllValuesNaN)?;
    // Validate against the largest slow period so every combo has enough data.
    let max_slow = combos.iter().map(|c| c.slow_period.unwrap()).max().unwrap();
    if data.len() - first < max_slow {
        return Err(PpoError::NotEnoughValidData {
            needed: max_slow,
            valid: data.len() - first,
        });
    }
    let rows = combos.len();
    let cols = data.len();
    // Overflow guard only; the product itself is not otherwise used.
    let _total = rows
        .checked_mul(cols)
        .ok_or_else(|| PpoError::InvalidInput("rows*cols overflow in ppo_batch_inner".into()))?;

    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Per-row warmup length: cells before `first + slow - 1` stay NaN.
    let warmup_periods: Vec<usize> = combos
        .iter()
        .map(|c| first + c.slow_period.unwrap() - 1)
        .collect();

    init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);

    // ManuallyDrop so the uninit buffer is not dropped as Vec<MaybeUninit<_>>;
    // ownership is reclaimed as Vec<f64> via from_raw_parts below.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY: f64 and MaybeUninit<f64> share layout; prefixes are already
    // initialized and the row kernels fill the rest before the Vec is read.
    let values: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    // For "sma"/"ema" the MA series depends only on the period, so compute
    // each unique period once and share it across rows.
    let ma_type = sweep.ma_type.as_str();
    let use_cached = ma_type == "sma" || ma_type == "ema";
    let mut ma_cache: HashMap<usize, Vec<f64>> = HashMap::new();
    if use_cached {
        let mut uniq: Vec<usize> = Vec::new();
        for c in &combos {
            let f = c.fast_period.unwrap();
            let s = c.slow_period.unwrap();
            if !uniq.contains(&f) {
                uniq.push(f);
            }
            if !uniq.contains(&s) {
                uniq.push(s);
            }
        }
        for &p in &uniq {
            if let Ok(v) = ma(ma_type, MaData::Slice(data), p) {
                ma_cache.insert(p, v);
            } else {
                // MA failed for this period: cache a NaN series so affected
                // cells come out NaN instead of aborting the whole batch.
                let v = vec![f64::NAN; data.len()];
                ma_cache.insert(p, v);
            }
        }
    }

    // Fills one output row for `combos[row]`.
    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
        let p = &combos[row];
        if use_cached {
            let fast = p.fast_period.unwrap();
            let slow = p.slow_period.unwrap();
            let fast_ma = ma_cache.get(&fast).unwrap();
            let slow_ma = ma_cache.get(&slow).unwrap();
            let mut i = first + slow - 1;
            let n = data.len();
            while i < n {
                let sf = *slow_ma.get_unchecked(i);
                let ff = *fast_ma.get_unchecked(i);
                let y = if sf == 0.0 || sf.is_nan() || ff.is_nan() {
                    f64::NAN
                } else {
                    let ratio = ff / sf;
                    f64::mul_add(ratio, 100.0, -100.0)
                };
                *out_row.get_unchecked_mut(i) = y;
                i += 1;
            }
        } else {
            match kern {
                Kernel::Scalar => ppo_row_scalar(
                    data,
                    first,
                    p.fast_period.unwrap(),
                    p.slow_period.unwrap(),
                    p.ma_type.as_ref().unwrap(),
                    out_row,
                ),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2 => ppo_row_avx2(
                    data,
                    first,
                    p.fast_period.unwrap(),
                    p.slow_period.unwrap(),
                    p.ma_type.as_ref().unwrap(),
                    out_row,
                ),
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512 => ppo_row_avx512(
                    data,
                    first,
                    p.fast_period.unwrap(),
                    p.slow_period.unwrap(),
                    p.ma_type.as_ref().unwrap(),
                    out_row,
                ),
                _ => unreachable!(),
            }
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            values
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // wasm has no worker threads here: degrade to the sequential loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // SAFETY: reclaims the ManuallyDrop allocation as Vec<f64>; same pointer,
    // length, and capacity, and every element is initialized by now.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(PpoBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1087
1088#[inline(always)]
1089pub unsafe fn ppo_row_scalar(
1090 data: &[f64],
1091 first: usize,
1092 fast: usize,
1093 slow: usize,
1094 ma_type: &str,
1095 out: &mut [f64],
1096) {
1097 if ma_type == "ema" {
1098 ppo_row_scalar_classic_ema(data, first, fast, slow, out);
1099 } else if ma_type == "sma" {
1100 ppo_row_scalar_classic_sma(data, first, fast, slow, out);
1101 } else {
1102 ppo_scalar(data, fast, slow, ma_type, first, out);
1103 }
1104}
1105
/// Row-kernel adapter: forwards to [`ppo_scalar_classic_ema`] with the
/// argument order used by the batch dispatchers.
///
/// # Safety
/// Same contract as [`ppo_scalar_classic_ema`].
#[inline(always)]
pub unsafe fn ppo_row_scalar_classic_ema(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    out: &mut [f64],
) {
    ppo_scalar_classic_ema(data, fast, slow, first, out);
}
1116
/// Row-kernel adapter: forwards to [`ppo_scalar_classic_sma`] with the
/// argument order used by the batch dispatchers.
///
/// # Safety
/// Same contract as [`ppo_scalar_classic_sma`].
#[inline(always)]
pub unsafe fn ppo_row_scalar_classic_sma(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    out: &mut [f64],
) {
    ppo_scalar_classic_sma(data, fast, slow, first, out);
}
1127
/// AVX2 row kernel — placeholder delegating to the scalar kernel.
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx2(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1140
/// AVX-512 row kernel: routes on the same `slow <= 32` heuristic as the
/// single-series path (both variants currently delegate to scalar).
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    if slow <= 32 {
        ppo_row_avx512_short(data, first, fast, slow, ma_type, out)
    } else {
        ppo_row_avx512_long(data, first, fast, slow, ma_type, out)
    }
}
1157
/// Short-period AVX-512 row variant — placeholder delegating to scalar.
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512_short(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1170
/// Long-period AVX-512 row variant — placeholder delegating to scalar.
///
/// # Safety
/// Same contract as [`ppo_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ppo_row_avx512_long(
    data: &[f64],
    first: usize,
    fast: usize,
    slow: usize,
    ma_type: &str,
    out: &mut [f64],
) {
    ppo_scalar(data, fast, slow, ma_type, first, out)
}
1183
/// Incremental, one-value-at-a-time PPO.
pub struct PpoStream {
    fast_period: usize,
    slow_period: usize,
    // MA type name; drives which `StreamMode` was chosen at construction.
    ma_type: String,
    mode: StreamMode,
}
1190
/// Per-MA-type streaming strategy.
#[derive(Debug)]
enum StreamMode {
    // O(1)-per-tick rolling-sum state for "sma".
    Sma(SmaState),

    // O(1)-per-tick exponential state for "ema".
    Ema(EmaState),

    // Fallback for any other MA type: buffer the whole series and recompute.
    Generic { data: Vec<f64> },
}
1199
/// Rolling-sum state for streaming SMA-based PPO.
#[derive(Debug)]
struct SmaState {
    // Set once the first non-NaN sample arrives.
    started: bool,
    fast: usize,
    slow: usize,
    // Precomputed slow/fast ratio so the PPO can be formed from raw sums.
    k: f64,

    fast_sum: f64,
    slow_sum: f64,
    // Circular buffers holding the last `fast` / `slow` samples.
    fast_buf: Vec<f64>,
    slow_buf: Vec<f64>,
    // Next overwrite positions within the circular buffers.
    i_fast: usize,
    i_slow: usize,
    // How many slots of each buffer have been filled so far (warmup).
    filled_fast: usize,
    filled_slow: usize,
}
1216
/// Exponential-average state for streaming EMA-based PPO.
#[derive(Debug)]
struct EmaState {
    // Set once the first non-NaN sample arrives.
    started: bool,
    fast: usize,
    slow: usize,

    // Smoothing factors: fa/sa = alpha, fb/sb = 1 - alpha.
    fa: f64,
    fb: f64,
    sa: f64,
    sb: f64,

    fast_ema: f64,
    slow_ema: f64,
    // True once the EMAs have been seeded from the warmup window.
    seeded: bool,

    // Warmup sample buffer used for the initial seeding.
    warm: Vec<f64>,
}
1234
impl PpoStream {
    /// Builds a streaming PPO from `params` (defaults: 12/26 "sma").
    ///
    /// "sma" and "ema" get O(1)-per-tick incremental states; any other MA
    /// type falls back to a generic mode that buffers the full series.
    pub fn try_new(params: PpoParams) -> Result<Self, PpoError> {
        let fast = params.fast_period.unwrap_or(12);
        let slow = params.slow_period.unwrap_or(26);
        let ma_type = params.ma_type.clone().unwrap_or_else(|| "sma".to_string());

        let mode = match ma_type.as_str() {
            "sma" => StreamMode::Sma(SmaState {
                started: false,
                fast,
                slow,
                // (fast_sum/fast)/(slow_sum/slow) == fast_sum * k / slow_sum.
                k: slow as f64 / fast as f64,
                fast_sum: 0.0,
                slow_sum: 0.0,
                fast_buf: Vec::with_capacity(fast),
                slow_buf: Vec::with_capacity(slow),
                i_fast: 0,
                i_slow: 0,
                filled_fast: 0,
                filled_slow: 0,
            }),
            "ema" => {
                // Standard EMA smoothing factors: alpha = 2 / (period + 1).
                let fa = 2.0 / (fast as f64 + 1.0);
                let sa = 2.0 / (slow as f64 + 1.0);
                StreamMode::Ema(EmaState {
                    started: false,
                    fast,
                    slow,
                    fa,
                    fb: 1.0 - fa,
                    sa,
                    sb: 1.0 - sa,
                    fast_ema: f64::NAN,
                    slow_ema: f64::NAN,
                    seeded: false,
                    warm: Vec::with_capacity(slow),
                })
            }
            _ => StreamMode::Generic { data: Vec::new() },
        };

        Ok(Self {
            fast_period: fast,
            slow_period: slow,
            ma_type,
            mode,
        })
    }

    /// Pushes one price; returns `Some(ppo)` once enough data is available.
    ///
    /// NOTE(review): the generic fallback recomputes both MAs over the whole
    /// buffered history on every tick — O(n) per update, O(n^2) per stream.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        match &mut self.mode {
            StreamMode::Sma(state) => update_sma(state, value),
            StreamMode::Ema(state) => update_ema(state, value),

            StreamMode::Generic { data } => {
                data.push(value);
                if data.len() < self.slow_period {
                    return None;
                }

                let fast_ma = ma(&self.ma_type, MaData::Slice(&data), self.fast_period).ok()?;
                let slow_ma = ma(&self.ma_type, MaData::Slice(&data), self.slow_period).ok()?;
                let ff = *fast_ma.last()?;
                let sf = *slow_ma.last()?;
                if ff.is_nan() || sf.is_nan() || sf == 0.0 {
                    Some(f64::NAN)
                } else {
                    let ratio = ff / sf;
                    Some(f64::mul_add(ratio, 100.0, -100.0))
                }
            }
        }
    }
}
1310
/// One streaming step of SMA-mode PPO.
///
/// Maintains rolling sums over two ring buffers and returns
/// `100 * ((fast_sum * k) / slow_sum - 1)` where `k = slow / fast`, i.e.
/// `100 * (fast_mean / slow_mean - 1)`.
///
/// Returns `None` while only leading NaNs have been seen, or until the slow
/// window fills; afterwards always `Some(..)`, yielding NaN when either sum
/// is NaN or the slow sum is exactly zero.
#[inline(always)]
fn update_sma(s: &mut SmaState, x: f64) -> Option<f64> {
    // Skip leading NaNs entirely; the first finite sample starts the windows.
    // NOTE(review): NaNs arriving *after* the start are pushed into the
    // buffers and poison the sums until they rotate out — presumably this
    // mirrors the batch kernels' NaN handling; confirm against ppo_scalar.
    if !s.started {
        if x.is_nan() {
            return None;
        }
        s.started = true;
    }

    // Fast window: grow until `fast` samples are buffered, then switch to an
    // O(1) ring replacement that updates the sum incrementally.
    if s.filled_fast < s.fast {
        s.fast_buf.push(x);
        s.fast_sum += x;
        s.filled_fast += 1;
    } else {
        let old = s.fast_buf[s.i_fast];
        s.fast_sum += x - old;
        s.fast_buf[s.i_fast] = x;
        s.i_fast = (s.i_fast + 1) % s.fast;
    }

    // Slow window: same scheme; warm-up ends the first time it fills.
    if s.filled_slow < s.slow {
        s.slow_buf.push(x);
        s.slow_sum += x;
        s.filled_slow += 1;
        if s.filled_slow < s.slow {
            return None;
        }
    } else {
        let old = s.slow_buf[s.i_slow];
        s.slow_sum += x - old;
        s.slow_buf[s.i_slow] = x;
        s.i_slow = (s.i_slow + 1) % s.slow;
    }

    let slow_sum = s.slow_sum;
    let fast_sum = s.fast_sum;
    if slow_sum == 0.0 || slow_sum.is_nan() || fast_sum.is_nan() {
        Some(f64::NAN)
    } else {
        // (fast_sum / fast) / (slow_sum / slow), folded into one division
        // via the precomputed k = slow / fast.
        let ratio = (fast_sum * s.k) / slow_sum;
        Some(f64::mul_add(ratio, 100.0, -100.0))
    }
}
1354
1355#[inline(always)]
1356fn update_ema(e: &mut EmaState, x: f64) -> Option<f64> {
1357 if !e.started {
1358 if x.is_nan() {
1359 return None;
1360 }
1361 e.started = true;
1362 }
1363
1364 if !e.seeded {
1365 e.warm.push(x);
1366 if e.warm.len() < e.slow {
1367 return None;
1368 }
1369
1370 let mut slow_sum = 0.0f64;
1371 for &v in &e.warm {
1372 slow_sum += v;
1373 }
1374 e.slow_ema = slow_sum / e.slow as f64;
1375
1376 let mut fast_sum = 0.0f64;
1377 let start = e.slow - e.fast;
1378 for &v in &e.warm[start..] {
1379 fast_sum += v;
1380 }
1381 e.fast_ema = fast_sum / e.fast as f64;
1382
1383 for &v in &e.warm[e.fast..] {
1384 e.fast_ema = f64::mul_add(e.fa, v, e.fb * e.fast_ema);
1385 }
1386
1387 e.seeded = true;
1388
1389 let sf = e.slow_ema;
1390 let ff = e.fast_ema;
1391 return if sf == 0.0 || sf.is_nan() || ff.is_nan() {
1392 Some(f64::NAN)
1393 } else {
1394 let ratio = ff / sf;
1395 Some(f64::mul_add(ratio, 100.0, -100.0))
1396 };
1397 }
1398
1399 e.fast_ema = f64::mul_add(e.fa, x, e.fb * e.fast_ema);
1400 e.slow_ema = f64::mul_add(e.sa, x, e.sb * e.slow_ema);
1401
1402 let sf = e.slow_ema;
1403 let ff = e.fast_ema;
1404 if sf == 0.0 || sf.is_nan() || ff.is_nan() {
1405 Some(f64::NAN)
1406 } else {
1407 let ratio = ff / sf;
1408 Some(f64::mul_add(ratio, 100.0, -100.0))
1409 }
1410}
1411
1412#[cfg(test)]
1413mod tests {
1414 use super::*;
1415 use crate::skip_if_unsupported;
1416 use crate::utilities::data_loader::read_candles_from_csv;
1417
1418 fn check_ppo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1419 skip_if_unsupported!(kernel, test_name);
1420 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1421 let candles = read_candles_from_csv(file_path)?;
1422 let default_params = PpoParams {
1423 fast_period: None,
1424 slow_period: None,
1425 ma_type: None,
1426 };
1427 let input = PpoInput::from_candles(&candles, "close", default_params);
1428 let output = ppo_with_kernel(&input, kernel)?;
1429 assert_eq!(output.values.len(), candles.close.len());
1430 Ok(())
1431 }
1432
1433 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1434 #[test]
1435 fn test_ppo_into_matches_api() -> Result<(), Box<dyn Error>> {
1436 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1437 let candles = read_candles_from_csv(file_path)?;
1438 let input = PpoInput::from_candles(&candles, "close", PpoParams::default());
1439
1440 let baseline = ppo(&input)?.values;
1441
1442 let mut out = vec![0.0; candles.close.len()];
1443 ppo_into(&input, &mut out)?;
1444
1445 assert_eq!(baseline.len(), out.len());
1446
1447 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1448 (a.is_nan() && b.is_nan()) || (a == b)
1449 }
1450
1451 for i in 0..out.len() {
1452 assert!(
1453 eq_or_both_nan(baseline[i], out[i]),
1454 "Mismatch at index {}: baseline={}, into={}",
1455 i,
1456 baseline[i],
1457 out[i]
1458 );
1459 }
1460
1461 Ok(())
1462 }
1463
    /// Golden-value regression: default-parameter PPO on the reference candle
    /// file must reproduce the pinned last five outputs to 1e-7 on every kernel.
    fn check_ppo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = PpoInput::from_candles(&candles, "close", PpoParams::default());
        let result = ppo_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        // Reference values captured on this fixed data set.
        let expected_last_five = [
            -0.8532313608928664,
            -0.8537562894550523,
            -0.6821291938174874,
            -0.5620008722078592,
            -0.4101724140910927,
        ];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-7,
                "[{}] PPO {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }
1493
1494 fn check_ppo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1495 skip_if_unsupported!(kernel, test_name);
1496 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1497 let candles = read_candles_from_csv(file_path)?;
1498 let input = PpoInput::with_default_candles(&candles);
1499 match input.data {
1500 PpoData::Candles { source, .. } => assert_eq!(source, "close"),
1501 _ => panic!("Expected PpoData::Candles"),
1502 }
1503 let output = ppo_with_kernel(&input, kernel)?;
1504 assert_eq!(output.values.len(), candles.close.len());
1505 Ok(())
1506 }
1507
1508 fn check_ppo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1509 skip_if_unsupported!(kernel, test_name);
1510 let input_data = [10.0, 20.0, 30.0];
1511 let params = PpoParams {
1512 fast_period: Some(0),
1513 slow_period: None,
1514 ma_type: None,
1515 };
1516 let input = PpoInput::from_slice(&input_data, params);
1517 let res = ppo_with_kernel(&input, kernel);
1518 assert!(
1519 res.is_err(),
1520 "[{}] PPO should fail with zero fast period",
1521 test_name
1522 );
1523 Ok(())
1524 }
1525
1526 fn check_ppo_period_exceeds_length(
1527 test_name: &str,
1528 kernel: Kernel,
1529 ) -> Result<(), Box<dyn Error>> {
1530 skip_if_unsupported!(kernel, test_name);
1531 let data_small = [10.0, 20.0, 30.0];
1532 let params = PpoParams {
1533 fast_period: Some(12),
1534 slow_period: Some(26),
1535 ma_type: None,
1536 };
1537 let input = PpoInput::from_slice(&data_small, params);
1538 let res = ppo_with_kernel(&input, kernel);
1539 assert!(
1540 res.is_err(),
1541 "[{}] PPO should fail with period exceeding length",
1542 test_name
1543 );
1544 Ok(())
1545 }
1546
1547 fn check_ppo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1548 skip_if_unsupported!(kernel, test_name);
1549 let single_point = [42.0];
1550 let params = PpoParams {
1551 fast_period: Some(12),
1552 slow_period: Some(26),
1553 ma_type: None,
1554 };
1555 let input = PpoInput::from_slice(&single_point, params);
1556 let res = ppo_with_kernel(&input, kernel);
1557 assert!(
1558 res.is_err(),
1559 "[{}] PPO should fail with insufficient data",
1560 test_name
1561 );
1562 Ok(())
1563 }
1564
1565 fn check_ppo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1566 skip_if_unsupported!(kernel, test_name);
1567 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1568 let candles = read_candles_from_csv(file_path)?;
1569 let input = PpoInput::from_candles(
1570 &candles,
1571 "close",
1572 PpoParams {
1573 fast_period: Some(12),
1574 slow_period: Some(26),
1575 ma_type: Some("sma".to_string()),
1576 },
1577 );
1578 let res = ppo_with_kernel(&input, kernel)?;
1579 assert_eq!(res.values.len(), candles.close.len());
1580 if res.values.len() > 30 {
1581 for (i, &val) in res.values[30..].iter().enumerate() {
1582 assert!(
1583 !val.is_nan(),
1584 "[{}] Found unexpected NaN at out-index {}",
1585 test_name,
1586 30 + i
1587 );
1588 }
1589 }
1590 Ok(())
1591 }
1592
    /// Streaming/batch parity: feeding closes one at a time through
    /// `PpoStream` (SMA mode) must match the batch kernel to 1e-9, with the
    /// same NaN warm-up prefix.
    fn check_ppo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let fast = 12;
        let slow = 26;
        let ma_type = "sma".to_string();
        let input = PpoInput::from_candles(
            &candles,
            "close",
            PpoParams {
                fast_period: Some(fast),
                slow_period: Some(slow),
                ma_type: Some(ma_type.clone()),
            },
        );
        let batch_output = ppo_with_kernel(&input, kernel)?.values;
        let mut stream = PpoStream::try_new(PpoParams {
            fast_period: Some(fast),
            slow_period: Some(slow),
            ma_type: Some(ma_type),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        // `None` during warm-up maps to NaN so indexes line up with batch.
        for &price in &candles.close {
            match stream.update(price) {
                Some(ppo_val) => stream_values.push(ppo_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] PPO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }
1640
1641 #[cfg(debug_assertions)]
1642 fn check_ppo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1643 skip_if_unsupported!(kernel, test_name);
1644
1645 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1646 let candles = read_candles_from_csv(file_path)?;
1647
1648 let test_params = vec![
1649 PpoParams::default(),
1650 PpoParams {
1651 fast_period: Some(2),
1652 slow_period: Some(3),
1653 ma_type: Some("sma".to_string()),
1654 },
1655 PpoParams {
1656 fast_period: Some(5),
1657 slow_period: Some(10),
1658 ma_type: Some("sma".to_string()),
1659 },
1660 PpoParams {
1661 fast_period: Some(12),
1662 slow_period: Some(26),
1663 ma_type: Some("ema".to_string()),
1664 },
1665 PpoParams {
1666 fast_period: Some(12),
1667 slow_period: Some(26),
1668 ma_type: Some("wma".to_string()),
1669 },
1670 PpoParams {
1671 fast_period: Some(20),
1672 slow_period: Some(40),
1673 ma_type: Some("sma".to_string()),
1674 },
1675 PpoParams {
1676 fast_period: Some(50),
1677 slow_period: Some(100),
1678 ma_type: Some("sma".to_string()),
1679 },
1680 PpoParams {
1681 fast_period: Some(10),
1682 slow_period: Some(11),
1683 ma_type: Some("sma".to_string()),
1684 },
1685 PpoParams {
1686 fast_period: Some(3),
1687 slow_period: Some(21),
1688 ma_type: Some("ema".to_string()),
1689 },
1690 PpoParams {
1691 fast_period: Some(7),
1692 slow_period: Some(14),
1693 ma_type: Some("wma".to_string()),
1694 },
1695 PpoParams {
1696 fast_period: Some(9),
1697 slow_period: Some(21),
1698 ma_type: Some("sma".to_string()),
1699 },
1700 PpoParams {
1701 fast_period: Some(100),
1702 slow_period: Some(200),
1703 ma_type: Some("ema".to_string()),
1704 },
1705 ];
1706
1707 for (param_idx, params) in test_params.iter().enumerate() {
1708 let input = PpoInput::from_candles(&candles, "close", params.clone());
1709 let output = ppo_with_kernel(&input, kernel)?;
1710
1711 for (i, &val) in output.values.iter().enumerate() {
1712 if val.is_nan() {
1713 continue;
1714 }
1715
1716 let bits = val.to_bits();
1717
1718 if bits == 0x11111111_11111111 {
1719 panic!(
1720 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1721 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1722 test_name,
1723 val,
1724 bits,
1725 i,
1726 params.fast_period.unwrap_or(12),
1727 params.slow_period.unwrap_or(26),
1728 params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1729 param_idx
1730 );
1731 }
1732
1733 if bits == 0x22222222_22222222 {
1734 panic!(
1735 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1736 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1737 test_name,
1738 val,
1739 bits,
1740 i,
1741 params.fast_period.unwrap_or(12),
1742 params.slow_period.unwrap_or(26),
1743 params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1744 param_idx
1745 );
1746 }
1747
1748 if bits == 0x33333333_33333333 {
1749 panic!(
1750 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1751 with params: fast_period={}, slow_period={}, ma_type={:?} (param set {})",
1752 test_name,
1753 val,
1754 bits,
1755 i,
1756 params.fast_period.unwrap_or(12),
1757 params.slow_period.unwrap_or(26),
1758 params.ma_type.as_ref().unwrap_or(&"sma".to_string()),
1759 param_idx
1760 );
1761 }
1762 }
1763 }
1764
1765 Ok(())
1766 }
1767
    // Release builds: the poison sentinels are only written by the debug-mode
    // allocation helpers, so the scan is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_ppo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1772
    /// Property-based checks (behind `proptest`): warm-up NaNs, agreement with
    /// the defining formula `100 * (fast_ma - slow_ma) / slow_ma`, cross-kernel
    /// equality within 1e-9 or 4 ULPs of the scalar kernel, sign behavior, and
    /// coarse magnitude bounds derived from windowed volatility.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_ppo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::indicators::moving_averages::ma::{ma, MaData};
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick slow in 2..=64, data of length slow..400 of positive
        // finite values, fast in 2..=slow, SMA only.
        let strat = (2usize..=64).prop_flat_map(|slow_period| {
            (
                prop::collection::vec(
                    (10f64..100000f64)
                        .prop_filter("positive finite", |x| x.is_finite() && *x > 0.0),
                    slow_period..400,
                ),
                2usize..=slow_period,
                Just(slow_period),
                Just("sma"),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, fast_period, slow_period, ma_type)| {
                let params = PpoParams {
                    fast_period: Some(fast_period),
                    slow_period: Some(slow_period),
                    ma_type: Some(ma_type.to_string()),
                };
                let input = PpoInput::from_slice(&data, params);

                let PpoOutput { values: out } = ppo_with_kernel(&input, kernel).unwrap();

                let PpoOutput { values: ref_out } =
                    ppo_with_kernel(&input, Kernel::Scalar).unwrap();

                let fast_ma = ma(&ma_type, MaData::Slice(&data), fast_period).unwrap();
                let slow_ma = ma(&ma_type, MaData::Slice(&data), slow_period).unwrap();

                // Warm-up region must be all-NaN.
                for i in 0..(slow_period - 1).min(data.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in (slow_period - 1)..data.len() {
                    let y = out[i];
                    let r = ref_out[i];
                    let fast_val = fast_ma[i];
                    let slow_val = slow_ma[i];

                    if !fast_val.is_nan() && !slow_val.is_nan() && slow_val != 0.0 {
                        let expected_ppo = 100.0 * (fast_val - slow_val) / slow_val;

                        if y.is_finite() && expected_ppo.is_finite() {
                            prop_assert!(
                                (y - expected_ppo).abs() < 1e-9,
                                "PPO formula mismatch at index {}: got {}, expected {} (fast={}, slow={})",
                                i, y, expected_ppo, fast_val, slow_val
                            );
                        }
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    // Non-finite values must match the scalar kernel bit-for-bit.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "kernel mismatch idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );

                    // Sign of PPO must follow the fast/slow ordering.
                    if y.is_finite()
                        && fast_val.is_finite()
                        && slow_val.is_finite()
                        && slow_val > 0.0
                    {
                        if fast_val > slow_val {
                            prop_assert!(
                                y > 0.0,
                                "PPO should be positive when fast > slow at index {}",
                                i
                            );
                        } else if fast_val < slow_val {
                            prop_assert!(
                                y < 0.0,
                                "PPO should be negative when fast < slow at index {}",
                                i
                            );
                        } else {
                            prop_assert!(
                                y.abs() < 1e-9,
                                "PPO should be ~0 when fast == slow at index {}",
                                i
                            );
                        }
                    }

                    if fast_period == slow_period && y.is_finite() {
                        prop_assert!(
                            y.abs() < 1e-9,
                            "PPO should be ~0 when fast_period == slow_period at index {}: got {}",
                            i,
                            y
                        );
                    }

                    if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && y.is_finite() {
                        prop_assert!(
                            y.abs() < 1e-6,
                            "PPO should be ~0 for constant data at index {}: got {}",
                            i,
                            y
                        );
                    }

                    // Coarse magnitude bound from the max/min ratio over the
                    // slow window (volatility ratio).
                    if y.is_finite() {
                        let window_start = i.saturating_sub(slow_period - 1);
                        let window = &data[window_start..=i];
                        let min_val = window.iter().cloned().fold(f64::INFINITY, f64::min);
                        let max_val = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                        let volatility_ratio = if min_val > 0.0 {
                            max_val / min_val
                        } else {
                            1.0
                        };

                        let max_expected_ppo = 100.0 * (volatility_ratio - 1.0);

                        prop_assert!(
                            y.abs() <= max_expected_ppo * 1.5,
                            "PPO exceeds expected bounds at index {}: got {}%, max expected ~{}% (volatility ratio {})",
                            i, y, max_expected_ppo, volatility_ratio
                        );
                    }

                    if slow_val.abs() < 1e-10 && slow_val != 0.0 {
                        prop_assert!(
                            y.is_nan() || y.abs() > 1000.0,
                            "PPO should be NaN or very large when slow_ma ~0 at index {}: slow_ma={}, ppo={}",
                            i, slow_val, y
                        );
                    }
                }

                let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0]);
                let is_monotonic_decreasing = data.windows(2).all(|w| w[1] <= w[0]);

                // For monotonic series, the trailing PPO average must carry
                // the matching sign (within numerical slack).
                if (is_monotonic_increasing || is_monotonic_decreasing)
                    && data.len() > slow_period * 2
                {
                    let last_values = &out[out.len() - slow_period / 2..];
                    let valid_last: Vec<f64> = last_values
                        .iter()
                        .filter(|x| x.is_finite())
                        .cloned()
                        .collect();

                    if valid_last.len() > 2 {
                        if is_monotonic_increasing && fast_period < slow_period {
                            let avg = valid_last.iter().sum::<f64>() / valid_last.len() as f64;
                            prop_assert!(
                                avg > -1e-6,
                                "PPO should be positive for monotonic increasing data: avg={}",
                                avg
                            );
                        } else if is_monotonic_decreasing && fast_period < slow_period {
                            let avg = valid_last.iter().sum::<f64>() / valid_last.len() as f64;
                            prop_assert!(
                                avg < 1e-6,
                                "PPO should be negative for monotonic decreasing data: avg={}",
                                avg
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1975
    /// Expands each listed checker into per-kernel `#[test]` functions:
    /// a `_scalar_f64` test always, plus `_avx2_f64` / `_avx512_f64`
    /// variants behind `nightly-avx` on x86_64.
    macro_rules! generate_all_ppo_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }
1999
    // Instantiate every checker for each single-series kernel variant.
    generate_all_ppo_tests!(
        check_ppo_partial_params,
        check_ppo_accuracy,
        check_ppo_default_candles,
        check_ppo_zero_period,
        check_ppo_period_exceeds_length,
        check_ppo_very_small_dataset,
        check_ppo_nan_handling,
        check_ppo_streaming,
        check_ppo_no_poison
    );

    // Property tests are opt-in via the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_ppo_tests!(check_ppo_property);
2014
    /// The batch builder with no explicit ranges must expose a row for the
    /// default (12/26/"sma") combo, matching the single-run golden values.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = PpoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = PpoParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        // Same golden values as check_ppo_accuracy.
        let expected = [
            -0.8532313608928664,
            -0.8537562894550523,
            -0.6821291938174874,
            -0.5620008722078592,
            -0.4101724140910927,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-7,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }
2041
    /// Expands one batch checker into `#[test]` functions for ScalarBatch,
    /// Avx2Batch/Avx512Batch (behind `nightly-avx` on x86_64), and Auto.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
2062
2063 #[cfg(debug_assertions)]
2064 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2065 skip_if_unsupported!(kernel, test);
2066
2067 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2068 let c = read_candles_from_csv(file)?;
2069
2070 let test_configs = vec![
2071 (2, 10, 2, 12, 30, 3, "sma"),
2072 (5, 25, 5, 26, 50, 5, "sma"),
2073 (30, 60, 15, 65, 100, 10, "sma"),
2074 (2, 5, 1, 6, 10, 1, "ema"),
2075 (10, 20, 2, 21, 40, 4, "wma"),
2076 (3, 9, 3, 12, 21, 3, "sma"),
2077 (7, 14, 7, 21, 28, 7, "ema"),
2078 ];
2079
2080 for (cfg_idx, &(f_start, f_end, f_step, s_start, s_end, s_step, ma_type)) in
2081 test_configs.iter().enumerate()
2082 {
2083 let output = PpoBatchBuilder::new()
2084 .kernel(kernel)
2085 .fast_period_range(f_start, f_end, f_step)
2086 .slow_period_range(s_start, s_end, s_step)
2087 .ma_type(ma_type)
2088 .apply_candles(&c, "close")?;
2089
2090 for (idx, &val) in output.values.iter().enumerate() {
2091 if val.is_nan() {
2092 continue;
2093 }
2094
2095 let bits = val.to_bits();
2096 let row = idx / output.cols;
2097 let col = idx % output.cols;
2098 let combo = &output.combos[row];
2099
2100 if bits == 0x11111111_11111111 {
2101 panic!(
2102 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2103 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2104 test,
2105 cfg_idx,
2106 val,
2107 bits,
2108 row,
2109 col,
2110 idx,
2111 combo.fast_period.unwrap_or(12),
2112 combo.slow_period.unwrap_or(26),
2113 combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2114 );
2115 }
2116
2117 if bits == 0x22222222_22222222 {
2118 panic!(
2119 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2120 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2121 test,
2122 cfg_idx,
2123 val,
2124 bits,
2125 row,
2126 col,
2127 idx,
2128 combo.fast_period.unwrap_or(12),
2129 combo.slow_period.unwrap_or(26),
2130 combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2131 );
2132 }
2133
2134 if bits == 0x33333333_33333333 {
2135 panic!(
2136 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2137 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, ma_type={:?}",
2138 test,
2139 cfg_idx,
2140 val,
2141 bits,
2142 row,
2143 col,
2144 idx,
2145 combo.fast_period.unwrap_or(12),
2146 combo.slow_period.unwrap_or(26),
2147 combo.ma_type.as_ref().unwrap_or(&"sma".to_string())
2148 );
2149 }
2150 }
2151 }
2152
2153 Ok(())
2154 }
2155
    // Release builds: the poison sentinels are only written by the debug-mode
    // allocation helpers, so the scan is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2160
    // Instantiate the batch checkers for every batch kernel plus auto-detect.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2163}
2164
/// Python binding: compute PPO over a 1-D f64 array, returning a same-length
/// array. `kernel` selects a SIMD path by name; `None` means auto-detect.
#[cfg(feature = "python")]
#[pyfunction(name = "ppo")]
#[pyo3(signature = (data, fast_period=None, slow_period=None, ma_type=None, kernel=None))]
pub fn ppo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    fast_period: Option<usize>,
    slow_period: Option<usize>,
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    // Validate the kernel name up front; the `false` flag presumably selects
    // non-batch kernels — confirm against validate_kernel.
    let kern = validate_kernel(kernel, false)?;

    let params = PpoParams {
        fast_period,
        slow_period,
        ma_type: ma_type.map(|s| s.to_string()),
    };
    let input = PpoInput::from_slice(slice_in, params);

    // Release the GIL while the Rust kernel runs.
    let result_vec: Vec<f64> = py
        .allow_threads(|| ppo_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
2192
/// Python wrapper around [`PpoStream`] for tick-by-tick PPO updates.
#[cfg(feature = "python")]
#[pyclass(name = "PpoStream")]
pub struct PpoStreamPy {
    stream: PpoStream, // underlying Rust streaming state machine
}
2198
#[cfg(feature = "python")]
#[pymethods]
impl PpoStreamPy {
    /// Create a streaming PPO; `None` arguments use the Rust-side defaults
    /// (fast=12, slow=26, ma_type="sma").
    // NOTE(review): Option-typed #[new] arguments without an explicit
    // #[pyo3(signature = (...))] rely on pyo3's implicit trailing-None
    // defaults; confirm this still holds for the pyo3 version in Cargo.toml.
    #[new]
    fn new(
        fast_period: Option<usize>,
        slow_period: Option<usize>,
        ma_type: Option<&str>,
    ) -> PyResult<Self> {
        let params = PpoParams {
            fast_period,
            slow_period,
            ma_type: ma_type.map(|s| s.to_string()),
        };
        let stream =
            PpoStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PpoStreamPy { stream })
    }

    /// Push one price; returns the PPO value, or `None` during warm-up.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2222
/// CUDA-resident f32 PPO batch result exposed to Python via
/// `__cuda_array_interface__` and the DLPack protocol. `unsendable`: the
/// pyo3 class may not be moved across Python threads.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct PpoDeviceArrayF32Py {
    pub(crate) inner: DeviceArrayF32Ppo, // device buffer + (rows, cols) + CUDA ctx handle
}
2228
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl PpoDeviceArrayF32Py {
    /// Numba/CuPy-style `__cuda_array_interface__` (version 3) describing the
    /// dense row-major (rows, cols) f32 device buffer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Strides are expressed in BYTES for a dense row-major layout.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only = false).
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: (2 = kDLCUDA, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    /// DLPack export. On success this MOVES the device buffer into the
    /// returned capsule; `self` is left holding an empty placeholder buffer.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // If the consumer requests a specific device it must match ours;
        // cross-device copies are not implemented.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(pyo3::exceptions::PyBufferError::new_err(
                            "__dlpack__: requested device does not match producer buffer",
                        ));
                    }
                }
            }
        }

        // copy=True is unsupported: the export below transfers ownership.
        if let Some(copy_obj) = copy.as_ref() {
            let do_copy: bool = copy_obj.extract::<bool>(py)?;
            if do_copy {
                return Err(pyo3::exceptions::PyBufferError::new_err(
                    "__dlpack__(copy=True) not supported for ppo CUDA buffers",
                ));
            }
        }

        // Per the DLPack CUDA semantics, stream=0 is ambiguous and rejected.
        if let Some(s) = stream.as_ref() {
            if let Ok(i) = s.extract::<i64>(py) {
                if i == 0 {
                    return Err(PyValueError::new_err(
                        "__dlpack__: stream 0 is disallowed for CUDA",
                    ));
                }
            }
        }

        // Swap the real buffer out of `self`, leaving a zero-sized stand-in
        // (same device id / context), so the capsule can own the memory.
        let dev_id_u32 = self.inner.device_id;
        let ctx_clone = self.inner.ctx.clone();
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Ppo {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx: ctx_clone,
                device_id: dev_id_u32,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2327
/// Python binding: sweep PPO over a grid of (fast, slow) periods and return a
/// dict with the (rows x cols) value matrix plus per-row parameter arrays.
#[cfg(feature = "python")]
#[pyfunction(name = "ppo_batch")]
#[pyo3(signature = (data, fast_period_range, slow_period_range, ma_type=None, kernel=None))]
pub fn ppo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;

    let sweep = PpoBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        ma_type: ma_type.unwrap_or("sma").to_string(),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow in ppo_batch_py"))?;
    // Output is allocated uninitialized; warm-up prefixes are seeded below.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // The first non-NaN index depends only on the input, so scan once instead
    // of once per combo (previously O(rows * cols) redundant work).
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.slow_period.unwrap() - 1)
        .collect();

    // SAFETY: `slice_out` is a freshly allocated buffer of exactly
    // `slice_out.len()` f64 slots; viewing it as MaybeUninit is sound and
    // init_matrix_prefixes writes each row's NaN warm-up prefix.
    unsafe {
        let mu: &mut [MaybeUninit<f64>] = std::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            slice_out.len(),
        );
        init_matrix_prefixes(mu, cols, &warm);
    }

    let kern = validate_kernel(kernel, true)?;
    let combos = py
        .allow_threads(|| {
            // Resolve Auto to the best batch kernel, then map batch kernels
            // onto their single-row SIMD counterparts for the inner loop.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                k if !k.is_batch() => k,
                _ => Kernel::Scalar,
            };
            ppo_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|p| p.fast_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|p| p.slow_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ma_types",
        combos
            .iter()
            .map(|p| p.ma_type.as_ref().unwrap().clone())
            .collect::<Vec<_>>(),
    )?;
    Ok(dict)
}
2418
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ppo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, fast_period_range, slow_period_range, ma_type="sma", device_id=0))]
/// Runs a PPO parameter sweep on the GPU and returns a device-resident
/// result handle plus a dict describing each parameter combination.
pub fn ppo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    ma_type: &str,
    device_id: usize,
) -> PyResult<(PpoDeviceArrayF32Py, Bound<'py, PyDict>)> {
    // Bail out early when no usable CUDA device is present.
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let input = data_f32.as_slice()?;
    let range_spec = PpoBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        ma_type: ma_type.to_string(),
    };

    // Launch the batch kernel with the GIL released; output stays on device.
    let (handle, param_combos) = py
        .allow_threads(|| {
            CudaPpo::new(device_id).and_then(|ctx| ctx.ppo_batch_dev(input, &range_spec))
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Gather per-combination metadata into plain vectors first, then expose
    // them to Python as numpy arrays / a list of strings.
    let fast: Vec<u64> = param_combos
        .iter()
        .map(|c| c.fast_period.unwrap() as u64)
        .collect();
    let slow: Vec<u64> = param_combos
        .iter()
        .map(|c| c.slow_period.unwrap() as u64)
        .collect();
    let kinds: Vec<String> = param_combos
        .iter()
        .map(|c| c.ma_type.as_ref().unwrap().clone())
        .collect();

    let meta = PyDict::new(py);
    meta.set_item("fast_periods", fast.into_pyarray(py))?;
    meta.set_item("slow_periods", slow.into_pyarray(py))?;
    meta.set_item("ma_types", kinds)?;
    Ok((PpoDeviceArrayF32Py { inner: handle }, meta))
}
2470
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ppo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, fast_period, slow_period, ma_type="sma", device_id=0))]
/// Computes PPO with one fixed parameter set across many time-major series
/// on the GPU, returning a device-resident array handle.
pub fn ppo_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
    device_id: usize,
) -> PyResult<PpoDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    // Defensive rank check before flattening the matrix.
    let dims = data_tm_f32.shape();
    if dims.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let (n_rows, n_cols) = (dims[0], dims[1]);
    let series = data_tm_f32.as_slice()?;
    let cfg = PpoParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        ma_type: Some(ma_type.to_string()),
    };
    // Run off the GIL; the kernel consumes the flat time-major buffer.
    let dev_out = py
        .allow_threads(|| {
            CudaPpo::new(device_id).and_then(|ctx| {
                ctx.ppo_many_series_one_param_time_major_dev(series, n_cols, n_rows, &cfg)
            })
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PpoDeviceArrayF32Py { inner: dev_out })
}
2506
#[cfg(feature = "python")]
/// Registers every PPO entry point (functions and classes) on the given
/// Python module; CUDA bindings are added only when the feature is enabled.
pub fn register_ppo_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    // Core CPU bindings.
    m.add_function(wrap_pyfunction!(ppo_py, m)?)?;
    m.add_function(wrap_pyfunction!(ppo_batch_py, m)?)?;
    m.add_class::<PpoStreamPy>()?;

    // GPU bindings, compiled in only with the `cuda` feature.
    #[cfg(feature = "cuda")]
    {
        m.add_class::<PpoDeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(ppo_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(ppo_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
2520
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Safe-slice wasm entry point: computes PPO over `data` and returns a new
/// vector of the same length.
pub fn ppo_js(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    // Assemble the parameter set, then run PPO into a freshly zeroed buffer.
    let cfg = PpoParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        ma_type: Some(ma_type.to_string()),
    };
    let input = PpoInput::from_slice(data, cfg);
    let mut result = vec![0.0; data.len()];
    match ppo_into_slice(&mut result, &input, detect_best_kernel()) {
        Ok(()) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2540
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates a zero-initialized `f64` buffer of `len` elements on the Rust
/// heap and hands its raw pointer to JS. Reclaim it with `ppo_free(ptr, len)`.
///
/// Fix over the previous version: `Vec::with_capacity(len)` only guarantees a
/// capacity of *at least* `len`, but `ppo_free` rebuilds the `Vec` with
/// `Vec::from_raw_parts(ptr, len, len)` — deallocating with a capacity that
/// differs from the true allocation is undefined behavior per the
/// `Vec::from_raw_parts` contract. `into_boxed_slice` shrinks the allocation
/// so capacity == len exactly, making the paired free sound. Zeroing also
/// means JS callers can never read uninitialized memory.
pub fn ppo_alloc(len: usize) -> *mut f64 {
    let boxed: Box<[f64]> = vec![0.0_f64; len].into_boxed_slice();
    // Ownership transfers to the caller; `ppo_free` reconstructs and drops it.
    // For len == 0 this is a dangling (non-null, aligned) pointer, which the
    // matching Vec::from_raw_parts(ptr, 0, 0) in ppo_free handles correctly.
    Box::into_raw(boxed) as *mut f64
}
2549
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously handed out by `ppo_alloc`.
///
/// `len` must be the same value that was passed to `ppo_alloc`; passing a
/// different length deallocates with the wrong layout (undefined behavior).
pub fn ppo_free(ptr: *mut f64, len: usize) {
    // Null is tolerated so JS callers can free unconditionally.
    if !ptr.is_null() {
        // SAFETY: requires that `ptr` came from `ppo_alloc(len)` with this
        // exact `len` and has not been freed already. Rebuilding the Vec and
        // immediately dropping it returns the allocation to the allocator.
        // NOTE(review): this also assumes the original allocation's capacity
        // equals `len` — confirm `ppo_alloc` guarantees that.
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
2559
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Raw-pointer wasm entry point: computes PPO over `len` input values at
/// `in_ptr` and writes `len` output values to `out_ptr`.
///
/// Supports exact in-place use (`in_ptr == out_ptr`) by staging through a
/// temporary buffer. NOTE(review): partially overlapping (but non-identical)
/// input/output ranges are not detected and would alias — confirm callers
/// never pass such ranges.
pub fn ppo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    fast_period: usize,
    slow_period: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ppo_into"));
    }
    // SAFETY: caller must guarantee both pointers are valid for `len` f64s
    // (e.g. produced by `ppo_alloc(len)`) and remain alive for this call.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = PpoParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            ma_type: Some(ma_type.to_string()),
        };
        let input = PpoInput::from_slice(data, params);
        if in_ptr == out_ptr {
            // In-place call: compute into a scratch buffer first so the input
            // slice is not clobbered while the kernel is still reading it.
            let mut tmp = vec![0.0; len];
            ppo_into_slice(&mut tmp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&tmp);
        } else {
            // Distinct buffers: write results directly into the output.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            ppo_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2595
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// JS-facing configuration for a PPO batch sweep, deserialized from the
/// config object passed to `ppo_batch`.
pub struct PpoBatchConfig {
    // (start, end, step) sweep over the fast MA period.
    pub fast_period_range: (usize, usize, usize),
    // (start, end, step) sweep over the slow MA period.
    pub slow_period_range: (usize, usize, usize),
    // Moving-average type name forwarded to the batch kernel (e.g. "sma").
    pub ma_type: String,
}
2603
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Serializable result of a PPO batch sweep returned to JS by `ppo_batch`.
pub struct PpoBatchJsOutput {
    // Row-major matrix of PPO values, flattened to rows * cols entries.
    pub values: Vec<f64>,
    // Parameter combination used for each row, in row order.
    pub combos: Vec<PpoParams>,
    // Number of parameter combinations (matrix rows).
    pub rows: usize,
    // Length of the input series (matrix columns).
    pub cols: usize,
}
2612
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "ppo_batch")]
/// JS-facing batch sweep: takes the input series and a config object,
/// returns a serialized `PpoBatchJsOutput`.
pub fn ppo_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    // Decode the JS config object into the typed sweep description.
    let parsed: PpoBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = PpoBatchRange {
        fast_period: parsed.fast_period_range,
        slow_period: parsed.slow_period_range,
        ma_type: parsed.ma_type,
    };
    // Run the sweep with an auto-detected kernel, then re-shape the result
    // into the serializable JS-facing struct.
    let batch = ppo_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    serde_wasm_bindgen::to_value(&PpoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2634
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Raw-pointer batch sweep: computes PPO for every (fast, slow) combination
/// in the given ranges, writing a row-major rows*len matrix to `out_ptr`.
/// Returns the number of rows (parameter combinations) written.
pub fn ppo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    fast_start: usize,
    fast_end: usize,
    fast_step: usize,
    slow_start: usize,
    slow_end: usize,
    slow_step: usize,
    ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ppo_batch_into"));
    }
    // SAFETY: caller must guarantee `in_ptr` is valid for `len` f64s and
    // `out_ptr` is valid for rows * len f64s, where rows is the combination
    // count returned by this function; the two ranges must not overlap.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let sweep = PpoBatchRange {
            fast_period: (fast_start, fast_end, fast_step),
            slow_period: (slow_start, slow_end, slow_step),
            ma_type: ma_type.to_string(),
        };
        // Expand the ranges up front so the required output size is known
        // before touching `out_ptr`.
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        // Guard against rows * cols wrapping on 32-bit wasm.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow in ppo_batch_into"))?;
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        ppo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}