1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10use crate::utilities::data_loader::{source_type, Candles};
11use crate::utilities::enums::Kernel;
12use crate::utilities::helpers::{
13 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
14 make_uninit_matrix,
15};
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18use aligned_vec::{AVec, CACHELINE_ALIGN};
19#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
20use core::arch::x86_64::*;
21#[cfg(not(target_arch = "wasm32"))]
22use rayon::prelude::*;
23use std::convert::AsRef;
24use std::error::Error;
25use thiserror::Error;
26
27#[cfg(all(feature = "python", feature = "cuda"))]
28use crate::cuda::moving_averages::DeviceArrayF32;
29#[cfg(all(feature = "python", feature = "cuda"))]
30use crate::cuda::pfe_wrapper::CudaPfe;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
33#[cfg(all(feature = "python", feature = "cuda"))]
34use cust::context::Context;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use cust::memory::DeviceBuffer;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use std::sync::Arc;
39
40impl<'a> AsRef<[f64]> for PfeInput<'a> {
41 #[inline(always)]
42 fn as_ref(&self) -> &[f64] {
43 match &self.data {
44 PfeData::Slice(slice) => slice,
45 PfeData::Candles { candles, source } => source_type(candles, source),
46 }
47 }
48}
49
/// Where a PFE computation reads its price series from.
#[derive(Debug, Clone)]
pub enum PfeData<'a> {
    /// A candle set plus the name of the column to read (resolved via `source_type`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice, used as-is.
    Slice(&'a [f64]),
}
58
/// Result of a single-parameter PFE run.
#[derive(Debug, Clone)]
pub struct PfeOutput {
    /// One value per input element. The leading warm-up positions
    /// (`first_non_nan + period` of them) are NaN.
    pub values: Vec<f64>,
}
63
/// Tunable PFE parameters.
///
/// `None` means "use the default": `period` falls back to 10 and `smoothing` to 5
/// wherever the computation reads them.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct PfeParams {
    /// Look-back length in bars; must be non-zero.
    pub period: Option<usize>,
    /// EMA span applied to the raw PFE (`alpha = 2 / (smoothing + 1)`); must be non-zero.
    pub smoothing: Option<usize>,
}
73
74impl Default for PfeParams {
75 fn default() -> Self {
76 Self {
77 period: Some(10),
78 smoothing: Some(5),
79 }
80 }
81}
82
/// Input to the PFE (Polarized Fractal Efficiency) functions: a price source
/// plus the parameters to run with.
#[derive(Debug, Clone)]
pub struct PfeInput<'a> {
    /// Price series source (candles + column, or a raw slice).
    pub data: PfeData<'a>,
    /// Period / smoothing settings; unset fields use the defaults.
    pub params: PfeParams,
}
88
89impl<'a> PfeInput<'a> {
90 #[inline]
91 pub fn from_candles(c: &'a Candles, s: &'a str, p: PfeParams) -> Self {
92 Self {
93 data: PfeData::Candles {
94 candles: c,
95 source: s,
96 },
97 params: p,
98 }
99 }
100 #[inline]
101 pub fn from_slice(sl: &'a [f64], p: PfeParams) -> Self {
102 Self {
103 data: PfeData::Slice(sl),
104 params: p,
105 }
106 }
107 #[inline]
108 pub fn with_default_candles(c: &'a Candles) -> Self {
109 Self::from_candles(c, "close", PfeParams::default())
110 }
111 #[inline]
112 pub fn get_period(&self) -> usize {
113 self.params.period.unwrap_or(10)
114 }
115 #[inline]
116 pub fn get_smoothing(&self) -> usize {
117 self.params.smoothing.unwrap_or(5)
118 }
119}
120
/// Fluent builder for single-series PFE runs and streams.
///
/// Unset `period` / `smoothing` stay `None` and fall back to the defaults
/// (10 / 5) when the computation reads them. `kernel` starts as `Kernel::Auto`.
#[derive(Copy, Clone, Debug)]
pub struct PfeBuilder {
    period: Option<usize>,
    smoothing: Option<usize>,
    kernel: Kernel,
}
127
128impl Default for PfeBuilder {
129 fn default() -> Self {
130 Self {
131 period: None,
132 smoothing: None,
133 kernel: Kernel::Auto,
134 }
135 }
136}
137
138impl PfeBuilder {
139 #[inline(always)]
140 pub fn new() -> Self {
141 Self::default()
142 }
143 #[inline(always)]
144 pub fn period(mut self, n: usize) -> Self {
145 self.period = Some(n);
146 self
147 }
148 #[inline(always)]
149 pub fn smoothing(mut self, s: usize) -> Self {
150 self.smoothing = Some(s);
151 self
152 }
153 #[inline(always)]
154 pub fn kernel(mut self, k: Kernel) -> Self {
155 self.kernel = k;
156 self
157 }
158 #[inline(always)]
159 pub fn apply(self, c: &Candles) -> Result<PfeOutput, PfeError> {
160 let p = PfeParams {
161 period: self.period,
162 smoothing: self.smoothing,
163 };
164 let i = PfeInput::from_candles(c, "close", p);
165 pfe_with_kernel(&i, self.kernel)
166 }
167 #[inline(always)]
168 pub fn apply_slice(self, d: &[f64]) -> Result<PfeOutput, PfeError> {
169 let p = PfeParams {
170 period: self.period,
171 smoothing: self.smoothing,
172 };
173 let i = PfeInput::from_slice(d, p);
174 pfe_with_kernel(&i, self.kernel)
175 }
176 #[inline(always)]
177 pub fn into_stream(self) -> Result<PfeStream, PfeError> {
178 let p = PfeParams {
179 period: self.period,
180 smoothing: self.smoothing,
181 };
182 PfeStream::try_new(p)
183 }
184}
185
/// Errors returned by the PFE functions in this module.
#[derive(Debug, Error)]
pub enum PfeError {
    /// The input series has no elements.
    #[error("pfe: Input data slice is empty.")]
    EmptyInputData,
    /// Every element of the input is NaN.
    #[error("pfe: All values are NaN.")]
    AllValuesNaN,
    /// `period` is zero or larger than the series length.
    #[error("pfe: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Too few values from the first non-NaN element onward.
    #[error("pfe: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// `smoothing` is zero.
    #[error("pfe: Invalid smoothing: {smoothing}")]
    InvalidSmoothing { smoothing: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("pfe: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep is empty, or a size computation overflowed
    /// (in the overflow case the fields carry the operands, not a range).
    #[error("pfe: invalid range: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("pfe: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Catch-all for other invalid input.
    #[error("pfe: invalid input: {0}")]
    InvalidInput(&'static str),
}
211
212#[inline(always)]
213fn pfe_prepare<'a>(
214 input: &'a PfeInput,
215 k: Kernel,
216) -> Result<(&'a [f64], usize, usize, usize, Kernel), PfeError> {
217 let data: &[f64] = input.as_ref();
218 let len = data.len();
219 if len == 0 {
220 return Err(PfeError::EmptyInputData);
221 }
222 let first = data
223 .iter()
224 .position(|x| !x.is_nan())
225 .ok_or(PfeError::AllValuesNaN)?;
226 let period = input.get_period();
227 let smoothing = input.get_smoothing();
228
229 if period == 0 || period > len {
230 return Err(PfeError::InvalidPeriod {
231 period,
232 data_len: len,
233 });
234 }
235
236 if len - first < period + 1 {
237 return Err(PfeError::NotEnoughValidData {
238 needed: period + 1,
239 valid: len - first,
240 });
241 }
242 if smoothing == 0 {
243 return Err(PfeError::InvalidSmoothing { smoothing });
244 }
245
246 let chosen = match k {
247 Kernel::Auto => Kernel::Scalar,
248 other => other,
249 };
250 Ok((data, period, smoothing, first, chosen))
251}
252
253#[inline(always)]
254fn pfe_compute_into(
255 data: &[f64],
256 period: usize,
257 smoothing: usize,
258 first: usize,
259 _kernel: Kernel,
260 out: &mut [f64],
261) {
262 let len = data.len();
263 if len == 0 {
264 return;
265 }
266
267 let start = first + period;
268 if start >= len {
269 return;
270 }
271
272 let p = period as f64;
273 let p2 = p * p;
274 let alpha = 2.0 / (smoothing as f64 + 1.0);
275 let one_minus_alpha = 1.0 - alpha;
276
277 let mut seg = vec![0.0f64; period];
278 let mut head = 0usize;
279 let mut denom = 0.0f64;
280
281 let base = start - period + 1;
282 for j in 0..period {
283 let k = base + j;
284 let d = data[k] - data[k - 1];
285 let s = (d.mul_add(d, 1.0)).sqrt();
286 seg[j] = s;
287 denom += s;
288 }
289
290 let mut ema_started = false;
291 let mut ema_val = 0.0f64;
292
293 let last = len - 1;
294 for t in start..last {
295 let cur = data[t];
296 let past = data[t - period];
297 let diff = cur - past;
298
299 let long_leg = (diff.mul_add(diff, p2)).sqrt();
300
301 let raw = 100.0 * (long_leg / denom);
302 let signed = if diff > 0.0 { raw } else { -raw };
303
304 let val = if !ema_started {
305 ema_started = true;
306 ema_val = signed;
307 signed
308 } else {
309 ema_val = alpha.mul_add(signed, one_minus_alpha * ema_val);
310 ema_val
311 };
312
313 out[t] = val;
314
315 let old = seg[head];
316 let next = data[t + 1];
317 let new_d = next - cur;
318 let new_s = (new_d.mul_add(new_d, 1.0)).sqrt();
319 denom += new_s - old;
320 seg[head] = new_s;
321 head += 1;
322 if head == period {
323 head = 0;
324 }
325 }
326
327 if start <= last {
328 let cur = data[last];
329 let past = data[last - period];
330 let diff = cur - past;
331 let long_leg = (diff.mul_add(diff, p2)).sqrt();
332 let raw = 100.0 * (long_leg / denom);
333 let signed = if diff > 0.0 { raw } else { -raw };
334 out[last] = if !ema_started {
335 signed
336 } else {
337 alpha.mul_add(signed, one_minus_alpha * ema_val)
338 };
339 }
340}
341
342#[inline]
343pub fn pfe(input: &PfeInput) -> Result<PfeOutput, PfeError> {
344 pfe_with_kernel(input, Kernel::Auto)
345}
346
347pub fn pfe_with_kernel(input: &PfeInput, kernel: Kernel) -> Result<PfeOutput, PfeError> {
348 let (data, period, smoothing, first, chosen) = pfe_prepare(input, kernel)?;
349
350 let mut out = alloc_with_nan_prefix(data.len(), first + period);
351 pfe_compute_into(data, period, smoothing, first, chosen, &mut out);
352 Ok(PfeOutput { values: out })
353}
354
355#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
356#[inline]
357pub fn pfe_into(input: &PfeInput, out: &mut [f64]) -> Result<(), PfeError> {
358 let (data, period, smoothing, first, chosen) = pfe_prepare(input, Kernel::Auto)?;
359 if out.len() != data.len() {
360 return Err(PfeError::OutputLengthMismatch {
361 expected: data.len(),
362 got: out.len(),
363 });
364 }
365
366 let warm = (first + period).min(out.len());
367 for v in &mut out[..warm] {
368 *v = f64::from_bits(0x7ff8_0000_0000_0000);
369 }
370
371 pfe_compute_into(data, period, smoothing, first, chosen, out);
372 Ok(())
373}
374
375#[inline]
376pub fn pfe_into_slice(dst: &mut [f64], input: &PfeInput, k: Kernel) -> Result<(), PfeError> {
377 let (data, period, smoothing, first, chosen) = pfe_prepare(input, k)?;
378 if dst.len() != data.len() {
379 return Err(PfeError::OutputLengthMismatch {
380 expected: data.len(),
381 got: dst.len(),
382 });
383 }
384 pfe_compute_into(data, period, smoothing, first, chosen, dst);
385
386 for v in &mut dst[..(first + period)] {
387 *v = f64::NAN;
388 }
389 Ok(())
390}
391
392#[inline]
393pub fn pfe_batch_with_kernel(
394 data: &[f64],
395 sweep: &PfeBatchRange,
396 k: Kernel,
397) -> Result<PfeBatchOutput, PfeError> {
398 let kernel = match k {
399 Kernel::Auto => detect_best_batch_kernel(),
400 other if other.is_batch() => other,
401 _ => {
402 return Err(PfeError::InvalidKernelForBatch(k));
403 }
404 };
405 let simd = match kernel {
406 Kernel::Avx512Batch => Kernel::Avx512,
407 Kernel::Avx2Batch => Kernel::Avx2,
408 Kernel::ScalarBatch => Kernel::Scalar,
409 _ => unreachable!(),
410 };
411 pfe_batch_par_slice(data, sweep, simd)
412}
413
/// Parameter sweep for batch PFE.
///
/// Each axis is `(start, end, step)` and is expanded by `expand_grid`;
/// `step == 0` (or `start == end`) means the single value `start`.
#[derive(Clone, Debug)]
pub struct PfeBatchRange {
    /// Look-back period axis.
    pub period: (usize, usize, usize),
    /// EMA smoothing axis.
    pub smoothing: (usize, usize, usize),
}
419
420impl Default for PfeBatchRange {
421 fn default() -> Self {
422 Self {
423 period: (10, 259, 1),
424 smoothing: (5, 5, 0),
425 }
426 }
427}
428
/// Fluent builder for batch (parameter-sweep) PFE runs.
///
/// Starts from `PfeBatchRange::default()`. The initial kernel is `Kernel::default()`
/// (NOTE(review): presumably `Kernel::Auto`; confirm in `utilities::enums`).
#[derive(Clone, Debug, Default)]
pub struct PfeBatchBuilder {
    range: PfeBatchRange,
    kernel: Kernel,
}
434
435impl PfeBatchBuilder {
436 pub fn new() -> Self {
437 Self::default()
438 }
439 pub fn kernel(mut self, k: Kernel) -> Self {
440 self.kernel = k;
441 self
442 }
443 #[inline]
444 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
445 self.range.period = (start, end, step);
446 self
447 }
448 #[inline]
449 pub fn period_static(mut self, p: usize) -> Self {
450 self.range.period = (p, p, 0);
451 self
452 }
453 #[inline]
454 pub fn smoothing_range(mut self, start: usize, end: usize, step: usize) -> Self {
455 self.range.smoothing = (start, end, step);
456 self
457 }
458 #[inline]
459 pub fn smoothing_static(mut self, s: usize) -> Self {
460 self.range.smoothing = (s, s, 0);
461 self
462 }
463 pub fn apply_slice(self, data: &[f64]) -> Result<PfeBatchOutput, PfeError> {
464 pfe_batch_with_kernel(data, &self.range, self.kernel)
465 }
466 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<PfeBatchOutput, PfeError> {
467 PfeBatchBuilder::new().kernel(k).apply_slice(data)
468 }
469 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<PfeBatchOutput, PfeError> {
470 let slice = source_type(c, src);
471 self.apply_slice(slice)
472 }
473 pub fn with_default_candles(c: &Candles) -> Result<PfeBatchOutput, PfeError> {
474 PfeBatchBuilder::new()
475 .kernel(Kernel::Auto)
476 .apply_candles(c, "close")
477 }
478}
479
/// Result of a batch PFE sweep.
#[derive(Clone, Debug)]
pub struct PfeBatchOutput {
    /// Row-major `rows x cols` matrix; row `i` holds the series for `combos[i]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<PfeParams>,
    /// Number of rows (`combos.len()`).
    pub rows: usize,
    /// Number of columns (input series length).
    pub cols: usize,
}
487impl PfeBatchOutput {
488 pub fn row_for_params(&self, p: &PfeParams) -> Option<usize> {
489 self.combos.iter().position(|c| {
490 c.period.unwrap_or(10) == p.period.unwrap_or(10)
491 && c.smoothing.unwrap_or(5) == p.smoothing.unwrap_or(5)
492 })
493 }
494 pub fn values_for(&self, p: &PfeParams) -> Option<&[f64]> {
495 self.row_for_params(p).map(|row| {
496 let start = row * self.cols;
497 &self.values[start..start + self.cols]
498 })
499 }
500}
501
502#[inline(always)]
503fn expand_grid(r: &PfeBatchRange) -> Result<Vec<PfeParams>, PfeError> {
504 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, PfeError> {
505 if step == 0 || start == end {
506 return Ok(vec![start]);
507 }
508 if start < end {
509 let mut v = Vec::new();
510 let mut x = start;
511 let st = step.max(1);
512 loop {
513 v.push(x);
514 let next = match x.checked_add(st) {
515 Some(nx) if nx > x && nx <= end => nx,
516 _ => break,
517 };
518 x = next;
519 }
520 if v.is_empty() {
521 return Err(PfeError::InvalidRange { start, end, step });
522 }
523 return Ok(v);
524 }
525 let mut v = Vec::new();
526 let mut x = start;
527 let st = step.max(1);
528 loop {
529 v.push(x);
530 if x <= end {
531 break;
532 }
533 let next = x.saturating_sub(st);
534 if next >= x {
535 break;
536 }
537 x = next;
538 }
539 if v.is_empty() {
540 return Err(PfeError::InvalidRange { start, end, step });
541 }
542 Ok(v)
543 }
544
545 let periods = axis_usize(r.period)?;
546 let smoothings = axis_usize(r.smoothing)?;
547
548 let cap = periods
549 .len()
550 .checked_mul(smoothings.len())
551 .ok_or(PfeError::InvalidRange {
552 start: periods.len(),
553 end: smoothings.len(),
554 step: 0,
555 })?;
556
557 let mut out = Vec::with_capacity(cap);
558 for &p in &periods {
559 for &s in &smoothings {
560 out.push(PfeParams {
561 period: Some(p),
562 smoothing: Some(s),
563 });
564 }
565 }
566 Ok(out)
567}
568
569#[inline(always)]
570pub fn pfe_batch_slice(
571 data: &[f64],
572 sweep: &PfeBatchRange,
573 kern: Kernel,
574) -> Result<PfeBatchOutput, PfeError> {
575 pfe_batch_inner(data, sweep, kern, false)
576}
577
578#[inline(always)]
579pub fn pfe_batch_par_slice(
580 data: &[f64],
581 sweep: &PfeBatchRange,
582 kern: Kernel,
583) -> Result<PfeBatchOutput, PfeError> {
584 pfe_batch_inner(data, sweep, kern, true)
585}
586
587#[inline(always)]
588pub fn pfe_batch_inner_into(
589 data: &[f64],
590 sweep: &PfeBatchRange,
591 kern: Kernel,
592 parallel: bool,
593 out: &mut [f64],
594) -> Result<Vec<PfeParams>, PfeError> {
595 let combos = expand_grid(sweep)?;
596 if combos.is_empty() {
597 return Err(PfeError::InvalidRange {
598 start: sweep.period.0,
599 end: sweep.period.1,
600 step: sweep.period.2,
601 });
602 }
603
604 let first = data
605 .iter()
606 .position(|x| !x.is_nan())
607 .ok_or(PfeError::AllValuesNaN)?;
608 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
609 if data.len() - first < max_p {
610 return Err(PfeError::NotEnoughValidData {
611 needed: max_p,
612 valid: data.len() - first,
613 });
614 }
615
616 let rows = combos.len();
617 let cols = data.len();
618
619 let expected = rows.checked_mul(cols).ok_or(PfeError::InvalidRange {
620 start: rows,
621 end: cols,
622 step: 0,
623 })?;
624 if out.len() != expected {
625 return Err(PfeError::OutputLengthMismatch {
626 expected,
627 got: out.len(),
628 });
629 }
630
631 for (row, combo) in combos.iter().enumerate() {
632 let warmup = first + combo.period.unwrap();
633 let row_start = row.checked_mul(cols).ok_or(PfeError::InvalidRange {
634 start: row,
635 end: cols,
636 step: 0,
637 })?;
638 let end = row_start
639 .checked_add(warmup.min(cols))
640 .ok_or(PfeError::InvalidRange {
641 start: row_start,
642 end: warmup.min(cols),
643 step: 0,
644 })?;
645 for i in row_start..end {
646 out[i] = f64::NAN;
647 }
648 }
649
650 let mut prefix = vec![0.0f64; cols];
651 for i in 1..cols {
652 let d = data[i] - data[i - 1];
653 let s = (d.mul_add(d, 1.0)).sqrt();
654 prefix[i] = prefix[i - 1] + s;
655 }
656
657 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
658 let period = combos[row].period.unwrap();
659 let smoothing = combos[row].smoothing.unwrap();
660 let _ = pfe_row_scalar_with_prefix(data, &prefix, first, period, smoothing, out_row);
661 };
662
663 if parallel {
664 #[cfg(not(target_arch = "wasm32"))]
665 {
666 out.par_chunks_mut(cols)
667 .enumerate()
668 .for_each(|(row, slice)| do_row(row, slice));
669 }
670
671 #[cfg(target_arch = "wasm32")]
672 {
673 for (row, slice) in out.chunks_mut(cols).enumerate() {
674 do_row(row, slice);
675 }
676 }
677 } else {
678 for (row, slice) in out.chunks_mut(cols).enumerate() {
679 do_row(row, slice);
680 }
681 }
682
683 Ok(combos)
684}
685
686#[inline(always)]
687fn pfe_batch_inner(
688 data: &[f64],
689 sweep: &PfeBatchRange,
690 kern: Kernel,
691 parallel: bool,
692) -> Result<PfeBatchOutput, PfeError> {
693 let combos = expand_grid(sweep)?;
694 if combos.is_empty() {
695 return Err(PfeError::InvalidRange {
696 start: sweep.period.0,
697 end: sweep.period.1,
698 step: sweep.period.2,
699 });
700 }
701
702 let first = data
703 .iter()
704 .position(|x| !x.is_nan())
705 .ok_or(PfeError::AllValuesNaN)?;
706 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
707 if data.len() - first < max_p {
708 return Err(PfeError::NotEnoughValidData {
709 needed: max_p,
710 valid: data.len() - first,
711 });
712 }
713
714 let rows = combos.len();
715 let cols = data.len();
716 let total = rows.checked_mul(cols).ok_or(PfeError::InvalidRange {
717 start: rows,
718 end: cols,
719 step: 0,
720 })?;
721
722 let mut buf_mu = make_uninit_matrix(rows, cols);
723 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
724 init_matrix_prefixes(&mut buf_mu, cols, &warm);
725
726 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
727 let values_slice: &mut [f64] =
728 unsafe { core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, total) };
729
730 let mut prefix = vec![0.0f64; cols];
731 for i in 1..cols {
732 let d = data[i] - data[i - 1];
733 let s = (d.mul_add(d, 1.0)).sqrt();
734 prefix[i] = prefix[i - 1] + s;
735 }
736
737 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
738 let period = combos[row].period.unwrap();
739 let smoothing = combos[row].smoothing.unwrap();
740 let _ = pfe_row_scalar_with_prefix(data, &prefix, first, period, smoothing, out_row);
741 };
742
743 if parallel {
744 #[cfg(not(target_arch = "wasm32"))]
745 {
746 values_slice
747 .par_chunks_mut(cols)
748 .enumerate()
749 .for_each(|(row, slice)| do_row(row, slice));
750 }
751
752 #[cfg(target_arch = "wasm32")]
753 {
754 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
755 do_row(row, slice);
756 }
757 }
758 } else {
759 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
760 do_row(row, slice);
761 }
762 }
763
764 let values = unsafe {
765 Vec::from_raw_parts(
766 buf_guard.as_mut_ptr() as *mut f64,
767 buf_guard.len(),
768 buf_guard.capacity(),
769 )
770 };
771
772 Ok(PfeBatchOutput {
773 values,
774 combos,
775 rows,
776 cols,
777 })
778}
779
/// Scalar single-row PFE kernel that keeps the denominator (sum of
/// `sqrt(Δ² + 1)` over the last `period` one-bar steps) in a ring buffer.
///
/// Writes `out_row[first + period..data.len()]`; earlier positions are not touched.
/// The initial window is summed four segments at a time (`s0 + s1 + s2 + s3` is
/// added as a group), so its rounding can differ slightly from a strictly
/// sequential sum. A denominator `<= f64::EPSILON` gives a raw value of 0.
///
/// NOTE(review): not called anywhere in the visible code; the batch paths use
/// `pfe_row_scalar_with_prefix` instead.
///
/// # Safety
/// All indexing is unchecked. The caller must ensure `period >= 1`,
/// `first + period < data.len()` (the seeding loop reads `data[first + period]`),
/// and `out_row.len() >= data.len()`.
#[inline(always)]
unsafe fn pfe_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    smoothing: usize,
    out_row: &mut [f64],
) {
    let len = data.len();
    let start = first + period;

    let p = period as f64;
    let p2 = p * p;
    let alpha = 2.0 / (smoothing as f64 + 1.0);
    let one_minus_alpha = 1.0 - alpha;

    // Ring buffer of the last `period` segment lengths; `head` is the oldest slot.
    let mut seg = vec![0.0f64; period];
    let mut head = 0usize;
    let mut denom = 0.0f64;

    // Segments ending at indices first+1 ..= first+period seed the window.
    let base = start - period + 1;

    // Seed loop, unrolled by 4.
    let mut j = 0usize;
    let stop = period & !3;
    while j < stop {
        let k0 = base + j;
        let d0 = *data.get_unchecked(k0) - *data.get_unchecked(k0 - 1);
        let k1 = k0 + 1;
        let d1 = *data.get_unchecked(k1) - *data.get_unchecked(k1 - 1);
        let k2 = k1 + 1;
        let d2 = *data.get_unchecked(k2) - *data.get_unchecked(k2 - 1);
        let k3 = k2 + 1;
        let d3 = *data.get_unchecked(k3) - *data.get_unchecked(k3 - 1);

        let s0 = (d0.mul_add(d0, 1.0)).sqrt();
        let s1 = (d1.mul_add(d1, 1.0)).sqrt();
        let s2 = (d2.mul_add(d2, 1.0)).sqrt();
        let s3 = (d3.mul_add(d3, 1.0)).sqrt();

        *seg.get_unchecked_mut(j) = s0;
        *seg.get_unchecked_mut(j + 1) = s1;
        *seg.get_unchecked_mut(j + 2) = s2;
        *seg.get_unchecked_mut(j + 3) = s3;

        denom += s0 + s1 + s2 + s3;
        j += 4;
    }
    // Remaining 0..=3 seed segments.
    while j < period {
        let k = base + j;
        let d = *data.get_unchecked(k) - *data.get_unchecked(k - 1);
        let s = (d.mul_add(d, 1.0)).sqrt();
        *seg.get_unchecked_mut(j) = s;
        denom += s;
        j += 1;
    }

    let mut ema_started = false;
    let mut ema_val = 0.0f64;

    for t in start..len {
        let cur = *data.get_unchecked(t);
        let past = *data.get_unchecked(t - period);
        let diff = cur - past;

        // Straight-line distance over `period` bars (unit spacing per bar).
        let long_leg = (diff.mul_add(diff, p2)).sqrt();
        let raw = if denom <= f64::EPSILON {
            0.0
        } else {
            100.0 * (long_leg / denom)
        };
        // Positive only for a strictly positive net move.
        let signed = if diff > 0.0 { raw } else { -raw };

        // EMA seeded with the first raw value.
        let val = if !ema_started {
            ema_started = true;
            ema_val = signed;
            signed
        } else {
            ema_val = alpha.mul_add(signed, one_minus_alpha * ema_val);
            ema_val
        };

        *out_row.get_unchecked_mut(t) = val;

        // Slide the window: replace the oldest segment with the step into t + 1.
        if t + 1 < len {
            let old = *seg.get_unchecked(head);
            let next = *data.get_unchecked(t + 1);
            let new_d = next - cur;
            let new_s = (new_d.mul_add(new_d, 1.0)).sqrt();
            denom += new_s - old;
            *seg.get_unchecked_mut(head) = new_s;

            head += 1;
            if head == period {
                head = 0;
            }
        }
    }
}
878
/// AVX2 variant of `pfe_row_scalar`. It seeds the initial window four segments
/// at a time with 256-bit vectors, then runs the same scalar rolling loop.
///
/// The vector lanes compute `sqrt(d*d + 1)` with a separate multiply and add
/// (no fused `mul_add`), and the four lanes are summed horizontally before being
/// added to `denom`. Results can therefore differ from `pfe_row_scalar` in the
/// last bits.
///
/// NOTE(review): the function has no `#[target_feature]` attribute; callers
/// presumably must ensure the CPU supports AVX/AVX2 before calling. Confirm.
/// It is also not called anywhere in the visible code.
///
/// # Safety
/// All memory access is unchecked. The caller must ensure `period >= 1`,
/// `first + period < data.len()` and `out.len() >= data.len()`. The vector loads
/// read `data[k0 - 1 ..= k0 + 3]`, which stays within `data[first ..= first + period]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn pfe_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    smoothing: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    let start = first + period;

    let p = period as f64;
    let p2 = p * p;
    let alpha = 2.0 / (smoothing as f64 + 1.0);
    let one_minus_alpha = 1.0 - alpha;

    // Ring buffer of the last `period` segment lengths; `head` is the oldest slot.
    let mut seg = vec![0.0f64; period];
    let mut head = 0usize;

    let base = start - period + 1;
    let mut denom = 0.0f64;

    let ones = _mm256_set1_pd(1.0);

    // Seed loop: 4 segments per iteration.
    let mut j = 0usize;
    let stop = period & !3;
    while j < stop {
        let k0 = base + j;

        let prev = _mm256_loadu_pd(data.as_ptr().add(k0 - 1));
        let curr = _mm256_loadu_pd(data.as_ptr().add(k0));
        let diff = _mm256_sub_pd(curr, prev);
        let sq = _mm256_sqrt_pd(_mm256_add_pd(_mm256_mul_pd(diff, diff), ones));

        _mm256_storeu_pd(seg.as_mut_ptr().add(j), sq);

        // Horizontal sum of the 4 lanes: hadd gives pair sums in each 128-bit
        // half; adding the halves leaves the total in lane 0.
        let hadd1 = _mm256_hadd_pd(sq, sq);
        let lo = _mm256_extractf128_pd(hadd1, 0);
        let hi = _mm256_extractf128_pd(hadd1, 1);
        let sum2 = _mm_add_pd(lo, hi);
        denom += _mm_cvtsd_f64(sum2);

        j += 4;
    }
    // Scalar tail for the remaining 0..=3 seed segments.
    while j < period {
        let k = base + j;
        let d = *data.get_unchecked(k) - *data.get_unchecked(k - 1);
        let s = (d.mul_add(d, 1.0)).sqrt();
        *seg.get_unchecked_mut(j) = s;
        denom += s;
        j += 1;
    }

    let mut ema_started = false;
    let mut ema_val = 0.0f64;

    for t in start..len {
        let cur = *data.get_unchecked(t);
        let past = *data.get_unchecked(t - period);
        let diff = cur - past;

        let long_leg = (diff.mul_add(diff, p2)).sqrt();
        let raw = if denom <= f64::EPSILON {
            0.0
        } else {
            100.0 * (long_leg / denom)
        };
        let signed = if diff > 0.0 { raw } else { -raw };

        // EMA seeded with the first raw value.
        let val = if !ema_started {
            ema_started = true;
            ema_val = signed;
            signed
        } else {
            ema_val = alpha.mul_add(signed, one_minus_alpha * ema_val);
            ema_val
        };

        *out.get_unchecked_mut(t) = val;

        // Slide the window: replace the oldest segment with the step into t + 1.
        if t + 1 < len {
            let old = *seg.get_unchecked(head);
            let next = *data.get_unchecked(t + 1);
            let new_d = next - cur;
            let new_s = (new_d.mul_add(new_d, 1.0)).sqrt();
            denom += new_s - old;
            *seg.get_unchecked_mut(head) = new_s;

            head += 1;
            if head == period {
                head = 0;
            }
        }
    }
}
977
/// AVX-512 variant of `pfe_row_scalar`. It seeds the initial window eight
/// segments at a time with 512-bit vectors, then runs the same scalar rolling loop.
///
/// As in the AVX2 version, the vector lanes use a separate multiply and add (no
/// fused `mul_add`) and a horizontal reduction, so results can differ from the
/// scalar kernel in the last bits.
///
/// NOTE(review): the function has no `#[target_feature]` attribute; callers
/// presumably must ensure AVX-512F (and AVX for the 256-bit `hadd`) are
/// available. Confirm. It is also not called anywhere in the visible code.
///
/// # Safety
/// All memory access is unchecked. The caller must ensure `period >= 1`,
/// `first + period < data.len()` and `out.len() >= data.len()`. The vector loads
/// read `data[k0 - 1 ..= k0 + 7]`, which stays within `data[first ..= first + period]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn pfe_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    smoothing: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    let start = first + period;

    let p = period as f64;
    let p2 = p * p;
    let alpha = 2.0 / (smoothing as f64 + 1.0);
    let one_minus_alpha = 1.0 - alpha;

    // Ring buffer of the last `period` segment lengths; `head` is the oldest slot.
    let mut seg = vec![0.0f64; period];
    let mut head = 0usize;

    let base = start - period + 1;
    let mut denom = 0.0f64;

    let ones = _mm512_set1_pd(1.0);

    // Seed loop: 8 segments per iteration.
    let mut j = 0usize;
    let stop = period & !7;
    while j < stop {
        let k0 = base + j;

        let prev = _mm512_loadu_pd(data.as_ptr().add(k0 - 1));
        let curr = _mm512_loadu_pd(data.as_ptr().add(k0));
        let diff = _mm512_sub_pd(curr, prev);
        let sq = _mm512_sqrt_pd(_mm512_add_pd(_mm512_mul_pd(diff, diff), ones));

        _mm512_storeu_pd(seg.as_mut_ptr().add(j), sq);

        // Horizontal sum of the 8 lanes: reduce each 256-bit half to a scalar
        // (hadd + 128-bit halves), then add the two partial sums.
        let lo256 = _mm512_extractf64x4_pd(sq, 0);
        let hi256 = _mm512_extractf64x4_pd(sq, 1);
        let hadd_lo = _mm256_hadd_pd(lo256, lo256);
        let hadd_hi = _mm256_hadd_pd(hi256, hi256);
        let lo_pair = _mm_add_pd(
            _mm256_extractf128_pd(hadd_lo, 0),
            _mm256_extractf128_pd(hadd_lo, 1),
        );
        let hi_pair = _mm_add_pd(
            _mm256_extractf128_pd(hadd_hi, 0),
            _mm256_extractf128_pd(hadd_hi, 1),
        );
        denom += _mm_cvtsd_f64(lo_pair) + _mm_cvtsd_f64(hi_pair);

        j += 8;
    }
    // Scalar tail for the remaining 0..=7 seed segments.
    while j < period {
        let k = base + j;
        let d = *data.get_unchecked(k) - *data.get_unchecked(k - 1);
        let s = (d.mul_add(d, 1.0)).sqrt();
        *seg.get_unchecked_mut(j) = s;
        denom += s;
        j += 1;
    }

    let mut ema_started = false;
    let mut ema_val = 0.0f64;

    for t in start..len {
        let cur = *data.get_unchecked(t);
        let past = *data.get_unchecked(t - period);
        let diff = cur - past;

        let long_leg = (diff.mul_add(diff, p2)).sqrt();
        let raw = if denom <= f64::EPSILON {
            0.0
        } else {
            100.0 * (long_leg / denom)
        };
        let signed = if diff > 0.0 { raw } else { -raw };

        // EMA seeded with the first raw value.
        let val = if !ema_started {
            ema_started = true;
            ema_val = signed;
            signed
        } else {
            ema_val = alpha.mul_add(signed, one_minus_alpha * ema_val);
            ema_val
        };

        *out.get_unchecked_mut(t) = val;

        // Slide the window: replace the oldest segment with the step into t + 1.
        if t + 1 < len {
            let old = *seg.get_unchecked(head);
            let next = *data.get_unchecked(t + 1);
            let new_d = next - cur;
            let new_s = (new_d.mul_add(new_d, 1.0)).sqrt();
            denom += new_s - old;
            *seg.get_unchecked_mut(head) = new_s;

            head += 1;
            if head == period {
                head = 0;
            }
        }
    }
}
1084
/// Fills one batch row with EMA-smoothed PFE values, taking each denominator
/// from a precomputed prefix sum of one-bar path lengths.
///
/// For every `t` in `first + period..data.len()` the denominator is
/// `prefix[t] - prefix[t - period]`. A denominator `<= f64::EPSILON` yields a
/// raw value of 0. The sign is positive only for a strictly positive net move.
/// The EMA (`alpha = 2 / (smoothing + 1)`) is seeded with the first raw value.
/// Positions before `first + period` are not written.
///
/// # Safety
/// Indexing is unchecked. The caller must guarantee that `prefix.len()` and
/// `out_row.len()` are both at least `data.len()`. (`t - period` cannot underflow
/// because `t >= first + period`.)
#[inline(always)]
unsafe fn pfe_row_scalar_with_prefix(
    data: &[f64],
    prefix: &[f64],
    first: usize,
    period: usize,
    smoothing: usize,
    out_row: &mut [f64],
) {
    let span = period as f64;
    let span_sq = span * span;
    let alpha = 2.0 / (smoothing as f64 + 1.0);
    let keep = 1.0 - alpha;

    let mut ema: Option<f64> = None;

    for t in (first + period)..data.len() {
        let lag = t - period;
        let delta = *data.get_unchecked(t) - *data.get_unchecked(lag);
        let straight = delta.mul_add(delta, span_sq).sqrt();
        let path = *prefix.get_unchecked(t) - *prefix.get_unchecked(lag);

        let magnitude = if path <= f64::EPSILON {
            0.0
        } else {
            100.0 * (straight / path)
        };
        let signed = if delta > 0.0 { magnitude } else { -magnitude };

        let value = match ema {
            None => signed,
            Some(prev) => alpha.mul_add(signed, keep * prev),
        };
        ema = Some(value);

        *out_row.get_unchecked_mut(t) = value;
    }
}
1131
/// Incremental (bar-by-bar) PFE calculator; see `PfeStream::update`.
#[derive(Debug, Clone)]
pub struct PfeStream {
    /// Look-back length (validated non-zero in `try_new`).
    period: usize,
    /// EMA span (validated non-zero in `try_new`).
    smoothing: usize,

    /// EMA weight `2 / (smoothing + 1)` and its complement.
    alpha: f64,
    one_minus_alpha: f64,
    /// `period²`, the squared horizontal leg of the straight-line distance.
    p2: f64,

    /// Ring of the last `period + 1` prices; `price_pos` is the most recent slot.
    prices: Vec<f64>,
    price_pos: usize,
    /// Number of prices stored, saturating at `period + 1`.
    price_count: usize,

    /// Ring of the last `period` one-bar path lengths `sqrt(Δ² + 1)`;
    /// `seg_pos` is the next slot to overwrite.
    seg: Vec<f64>,
    seg_pos: usize,
    /// Number of segments accumulated so far, saturating at `period`.
    seg_count: usize,
    /// Running sum of `seg` (the PFE denominator).
    short_sum: f64,

    /// Previous price, used to form the next segment.
    last_price: f64,
    have_last: bool,

    /// EMA state; `started` is false until the first output.
    ema_val: f64,
    started: bool,
}
1156
1157impl PfeStream {
1158 pub fn try_new(params: PfeParams) -> Result<Self, PfeError> {
1159 let period = params.period.unwrap_or(10);
1160 if period == 0 {
1161 return Err(PfeError::InvalidPeriod {
1162 period,
1163 data_len: 0,
1164 });
1165 }
1166 let smoothing = params.smoothing.unwrap_or(5);
1167 if smoothing == 0 {
1168 return Err(PfeError::InvalidSmoothing { smoothing });
1169 }
1170
1171 let cap_prices = period + 1;
1172 let cap_seg = period;
1173 let alpha = 2.0 / (smoothing as f64 + 1.0);
1174
1175 Ok(Self {
1176 period,
1177 smoothing,
1178 alpha,
1179 one_minus_alpha: 1.0 - alpha,
1180 p2: (period as f64) * (period as f64),
1181
1182 prices: vec![0.0; cap_prices],
1183 price_pos: cap_prices - 1,
1184 price_count: 0,
1185
1186 seg: if cap_seg > 0 {
1187 vec![0.0; cap_seg]
1188 } else {
1189 Vec::new()
1190 },
1191 seg_pos: 0,
1192 seg_count: 0,
1193 short_sum: 0.0,
1194
1195 last_price: 0.0,
1196 have_last: false,
1197
1198 ema_val: 0.0,
1199 started: false,
1200 })
1201 }
1202
1203 #[inline(always)]
1204 pub fn update(&mut self, price: f64) -> Option<f64> {
1205 if self.have_last {
1206 let d = price - self.last_price;
1207
1208 let new_s = (d.mul_add(d, 1.0)).sqrt();
1209
1210 if self.seg_count < self.period {
1211 self.short_sum += new_s;
1212 if self.period > 0 {
1213 self.seg[self.seg_pos] = new_s;
1214 self.seg_pos += 1;
1215 if self.seg_pos == self.period {
1216 self.seg_pos = 0;
1217 }
1218 }
1219 self.seg_count += 1;
1220 } else if self.period > 0 {
1221 let old = self.seg[self.seg_pos];
1222 self.short_sum += new_s - old;
1223 self.seg[self.seg_pos] = new_s;
1224 self.seg_pos += 1;
1225 if self.seg_pos == self.period {
1226 self.seg_pos = 0;
1227 }
1228 }
1229 }
1230 self.last_price = price;
1231 self.have_last = true;
1232
1233 let cap_prices = self.period + 1;
1234 self.price_pos += 1;
1235 if self.price_pos == cap_prices {
1236 self.price_pos = 0;
1237 }
1238 self.prices[self.price_pos] = price;
1239 if self.price_count < cap_prices {
1240 self.price_count += 1;
1241 }
1242
1243 if self.price_count < cap_prices {
1244 return None;
1245 }
1246
1247 let past_idx = self.price_pos + 1 - ((self.price_pos + 1) / cap_prices) * cap_prices;
1248 let past = self.prices[past_idx];
1249 let diff = price - past;
1250
1251 let long_leg = (diff.mul_add(diff, self.p2)).sqrt();
1252 let denom = self.short_sum;
1253
1254 let inv = if denom <= f64::EPSILON {
1255 0.0
1256 } else {
1257 1.0 / denom
1258 };
1259 let raw = 100.0 * long_leg * inv;
1260 let signed = if diff > 0.0 { raw } else { -raw };
1261
1262 let out = if !self.started {
1263 self.started = true;
1264 self.ema_val = signed;
1265 signed
1266 } else {
1267 self.ema_val = self
1268 .alpha
1269 .mul_add(signed, self.one_minus_alpha * self.ema_val);
1270 self.ema_val
1271 };
1272
1273 Some(out)
1274 }
1275}
1276
/// Python-visible wrapper around a PFE result held in CUDA device memory.
///
/// Exposes `__cuda_array_interface__` and the DLPack protocol
/// (`__dlpack__` / `__dlpack_device__`) so GPU array libraries can consume
/// the buffer. Declared `unsendable`, so PyO3 only allows it to be used from
/// the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "PfeDeviceArrayF32", unsendable)]
pub struct PfeDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Shared handle to the CUDA context the buffer was created in; held so
    /// the context is not released while this wrapper is alive.
    _ctx_guard: Arc<Context>,
    /// Device ordinal reported by `__dlpack_device__`.
    _device_id: u32,
}
1284
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl PfeDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) for the buffer: a row-major
    /// `(rows, cols)` float32 matrix, reported as writable (read-only flag `false`).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let elem = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * elem, elem))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple `(device_type, device_id)`; type 2 is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        const KDL_CUDA: i32 = 2;
        (KDL_CUDA, self._device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving ownership out of `self`.
    ///
    /// After this call the wrapper holds an empty 0x0 placeholder. A
    /// `dl_device` that parses as `(int, int)` and differs from this array's
    /// device raises `ValueError`; cross-device copies are not supported.
    /// `stream` is accepted but not used here.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // Only a `dl_device` that extracts as an (i32, i32) pair is checked.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = matches!(
                    copy.as_ref().map(|c| c.extract::<bool>(py)),
                    Some(Ok(true))
                );
                let msg = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        let _ = stream;

        // Swap in an empty placeholder so the real allocation can be handed to the capsule.
        let placeholder = DeviceArrayF32 {
            buf: DeviceBuffer::from_slice(&[])
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
        };
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(&mut self.inner, placeholder);

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1353
#[cfg(all(feature = "python", feature = "cuda"))]
impl PfeDeviceArrayF32Py {
    /// Wraps a device array computed on the Rust side. It also keeps a shared
    /// handle to its CUDA context and the ordinal of the device that holds it.
    pub fn new_from_rust(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            _device_id: device_id,
            _ctx_guard: ctx_guard,
            inner,
        }
    }
}
1364
/// Computes PFE for a contiguous 1-D float64 array.
///
/// Returns a new array of the same length. Raises `ValueError` if the kernel
/// name or the parameters are rejected.
#[cfg(feature = "python")]
#[pyfunction(name = "pfe")]
#[pyo3(signature = (data, period, smoothing, kernel=None))]
pub fn pfe_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    smoothing: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = PfeInput::from_slice(
        prices,
        PfeParams {
            period: Some(period),
            smoothing: Some(smoothing),
        },
    );

    // The computation runs with the GIL released.
    let values = py
        .allow_threads(|| pfe_with_kernel(&input, chosen))
        .map(|out| out.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1391
/// Runs PFE over every `(period, smoothing)` combination in the given ranges.
///
/// Returns a dict with `values` (a `(rows, cols)` float64 array, one row per
/// combination) plus `periods` and `smoothings` (uint64 arrays, one entry per row).
#[cfg(feature = "python")]
#[pyfunction(name = "pfe_batch")]
#[pyo3(signature = (data, period_range, smoothing_range, kernel=None))]
pub fn pfe_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    smoothing_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let prices = data.as_slice()?;
    let sweep = PfeBatchRange {
        period: period_range,
        smoothing: smoothing_range,
    };

    // One output row per parameter combination, one column per input sample.
    let n_rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let n_cols = prices.len();
    let n_cells = n_rows
        .checked_mul(n_cols)
        .ok_or_else(|| PyValueError::new_err("pfe_batch: rows*cols overflow"))?;

    // Allocated uninitialized; `pfe_batch_inner_into` is expected to fill every
    // cell, including the warm-up prefix (NOTE(review): confirm).
    let flat = unsafe { PyArray1::<f64>::new(py, [n_cells], false) };
    let out = unsafe { flat.as_slice_mut()? };

    let requested = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let resolved = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // The per-row routine takes the non-batch variant of the kernel.
            let row_kernel = match resolved {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            pfe_batch_inner_into(prices, &sweep, row_kernel, true, out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("values", flat.reshape((n_rows, n_cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    result.set_item("periods", periods.into_pyarray(py))?;
    let smoothings: Vec<u64> = combos
        .iter()
        .map(|c| c.smoothing.unwrap() as u64)
        .collect();
    result.set_item("smoothings", smoothings.into_pyarray(py))?;

    Ok(result)
}
1461
/// Runs the PFE parameter sweep on a CUDA device for float32 input.
///
/// The result stays in device memory and is returned as a `PfeDeviceArrayF32`.
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pfe_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, smoothing_range, device_id=0))]
pub fn pfe_cuda_batch_dev_py(
    py: Python<'_>,
    data: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    smoothing_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<PfeDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data.as_slice()?;
    let sweep = PfeBatchRange {
        period: period_range,
        smoothing: smoothing_range,
    };

    // GPU work runs with the GIL released.
    let (device_out, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaPfe::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_out = engine
            .pfe_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_out, engine.context_arc(), engine.device_id()))
    })?;

    Ok(PfeDeviceArrayF32Py::new_from_rust(device_out, ctx, dev))
}
1490
/// Computes PFE on a CUDA device for many series that share one
/// `(period, smoothing)` pair.
///
/// `data_tm` is a flat float32 buffer whose layout is described by `cols` and
/// `rows`. The wrapper method name suggests time-major order (presumably
/// `data_tm[t * cols + series]`; confirm against `CudaPfe`). The result stays
/// in device memory. Raises `ValueError` when CUDA is unavailable or the CUDA
/// wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "pfe_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, smoothing, device_id=0))]
pub fn pfe_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    smoothing: usize,
    device_id: usize,
) -> PyResult<PfeDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host_tm = data_tm.as_slice()?;

    // GPU work runs with the GIL released.
    let (device_out, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaPfe::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_out = engine
            .pfe_many_series_one_param_time_major_dev(host_tm, cols, rows, period, smoothing)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_out, engine.context_arc(), engine.device_id()))
    })?;

    Ok(PfeDeviceArrayF32Py::new_from_rust(device_out, ctx, dev))
}
1517
/// Python handle to the incremental `PfeStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "PfeStream")]
pub struct PfeStreamPy {
    stream: PfeStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl PfeStreamPy {
    /// Builds a stream for `period` and `smoothing`. Raises `ValueError` when
    /// `PfeStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize, smoothing: usize) -> PyResult<Self> {
        PfeStream::try_new(PfeParams {
            period: Some(period),
            smoothing: Some(smoothing),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value. Returns `None` during warm-up, otherwise the latest PFE.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1542
1543#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1544use serde::{Deserialize, Serialize};
1545#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1546use wasm_bindgen::prelude::*;
1547
/// JS entry point: computes PFE for `data` and returns a new array of equal length.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pfe_js(data: &[f64], period: usize, smoothing: usize) -> Result<Vec<f64>, JsValue> {
    let input = PfeInput::from_slice(
        data,
        PfeParams {
            period: Some(period),
            smoothing: Some(smoothing),
        },
    );
    let mut values = vec![0.0; data.len()];
    match pfe_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1561
/// Reserves uninitialized space for `len` f64 values and returns its address to JS.
///
/// The buffer is intentionally leaked; release it with `pfe_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pfe_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1570
/// Releases a buffer previously returned by `pfe_alloc(len)`.
///
/// `len` must be the value originally passed to `pfe_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pfe_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller passes a pointer obtained from `pfe_alloc(len)`, i.e. the
    // buffer of a `Vec<f64>` created with capacity exactly `len`. The length is 0
    // because the contents may never have been written, and `Vec::from_raw_parts`
    // requires the first `length` elements to be initialized. Dropping the Vec
    // frees the same `len`-element allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1578
/// Computes PFE from `len` values at `in_ptr` and writes `len` results to `out_ptr`.
///
/// Both pointers must be valid for `len` f64 elements. When the two buffers
/// overlap, whether fully or partially, the result is computed into a
/// temporary buffer first. This keeps the input from being aliased by the
/// output while it is read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pfe_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    smoothing: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Compare byte ranges rather than only start addresses: a partially
    // overlapping output would otherwise alias the input slice.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    // SAFETY: the caller guarantees both pointers are valid for `len` elements;
    // overlapping buffers are never borrowed as `&` and `&mut` at the same time.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = PfeParams {
            period: Some(period),
            smoothing: Some(smoothing),
        };
        let input = PfeInput::from_slice(data, params);
        if overlaps {
            let mut tmp = vec![0.0; len];
            pfe_into_slice(&mut tmp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            pfe_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1611
/// Sweep configuration accepted by the JS `pfe_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PfeBatchConfig {
    /// `(start, end, step)` range for the PFE period.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` range for the smoothing parameter.
    pub smoothing_range: (usize, usize, usize),
}
1618
/// Batch result serialized back to JS by `pfe_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct PfeBatchJsOutput {
    /// Flattened `rows x cols` matrix, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<PfeParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1627
/// JS `pfe_batch`: runs a PFE parameter sweep described by a `PfeBatchConfig`
/// object and returns a `PfeBatchJsOutput` object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = pfe_batch)]
pub fn pfe_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: PfeBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = PfeBatchRange {
        period: parsed.period_range,
        smoothing: parsed.smoothing_range,
    };

    let batch = pfe_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let payload = PfeBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1648
/// Runs a PFE parameter sweep over `len` inputs at `in_ptr` and writes the
/// `rows x len` result matrix (row-major, one row per combination) to `out_ptr`.
///
/// `out_ptr` must be valid for `rows * len` f64 elements. Returns the number of
/// rows written. If the output region overlaps the input, the matrix is built
/// in a temporary buffer and then copied. Writing in place would otherwise
/// clobber input values that are still being read.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn pfe_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    p_start: usize,
    p_end: usize,
    p_step: usize,
    s_start: usize,
    s_end: usize,
    s_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }
    // SAFETY: the caller guarantees `in_ptr` is valid for `len` elements and
    // `out_ptr` for `rows * len` elements; overlapping regions are never borrowed
    // as `&` and `&mut` at the same time.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let sweep = PfeBatchRange {
            period: (p_start, p_end, p_step),
            smoothing: (s_start, s_end, s_step),
        };
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let total = len
            .checked_mul(rows)
            .ok_or_else(|| JsValue::from_str("pfe_batch_into: rows*cols overflow"))?;

        // Byte ranges: input spans `len` values, output spans `total` values.
        let elem = std::mem::size_of::<f64>();
        let in_start = in_ptr as usize;
        let out_start = out_ptr as usize;
        let overlaps = in_start < out_start.saturating_add(total.saturating_mul(elem))
            && out_start < in_start.saturating_add(len.saturating_mul(elem));

        let kernel = detect_best_kernel();
        let combos = if overlaps {
            let mut tmp = vec![0.0; total];
            let combos = pfe_batch_inner_into(data, &sweep, kernel, false, &mut tmp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&tmp);
            combos
        } else {
            pfe_batch_inner_into(
                data,
                &sweep,
                kernel,
                false,
                std::slice::from_raw_parts_mut(out_ptr, total),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?
        };
        Ok(combos.len())
    }
}
1687
1688#[cfg(test)]
1689mod tests {
1690 use super::*;
1691 use crate::skip_if_unsupported;
1692 use crate::utilities::data_loader::read_candles_from_csv;
1693
1694 fn check_pfe_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1695 skip_if_unsupported!(kernel, test_name);
1696 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1697 let candles = read_candles_from_csv(file_path)?;
1698
1699 let default_params = PfeParams {
1700 period: None,
1701 smoothing: None,
1702 };
1703 let input = PfeInput::from_candles(&candles, "close", default_params);
1704 let output = pfe_with_kernel(&input, kernel)?;
1705 assert_eq!(output.values.len(), candles.close.len());
1706
1707 Ok(())
1708 }
1709
1710 fn check_pfe_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1711 skip_if_unsupported!(kernel, test_name);
1712 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1713 let candles = read_candles_from_csv(file_path)?;
1714 let close_prices = &candles.close;
1715
1716 let params = PfeParams {
1717 period: Some(10),
1718 smoothing: Some(5),
1719 };
1720 let input = PfeInput::from_candles(&candles, "close", params);
1721 let pfe_result = pfe_with_kernel(&input, kernel)?;
1722
1723 assert_eq!(pfe_result.values.len(), close_prices.len());
1724
1725 let expected_last_five_pfe = [
1726 -13.03562252,
1727 -11.93979855,
1728 -9.94609862,
1729 -9.73372410,
1730 -14.88374798,
1731 ];
1732 let start_index = pfe_result.values.len() - 5;
1733 let result_last_five_pfe = &pfe_result.values[start_index..];
1734 for (i, &value) in result_last_five_pfe.iter().enumerate() {
1735 let expected_value = expected_last_five_pfe[i];
1736 assert!(
1737 (value - expected_value).abs() < 1e-8,
1738 "[{}] PFE mismatch at idx {}: got {}, expected {}",
1739 test_name,
1740 i,
1741 value,
1742 expected_value
1743 );
1744 }
1745
1746 for i in 0..(10 - 1) {
1747 assert!(pfe_result.values[i].is_nan());
1748 }
1749
1750 Ok(())
1751 }
1752
1753 fn check_pfe_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1754 skip_if_unsupported!(kernel, test_name);
1755 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1756 let candles = read_candles_from_csv(file_path)?;
1757
1758 let input = PfeInput::with_default_candles(&candles);
1759 match input.data {
1760 PfeData::Candles { source, .. } => assert_eq!(source, "close"),
1761 _ => panic!("Expected PfeData::Candles"),
1762 }
1763 let output = pfe_with_kernel(&input, kernel)?;
1764 assert_eq!(output.values.len(), candles.close.len());
1765
1766 Ok(())
1767 }
1768
1769 fn check_pfe_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1770 skip_if_unsupported!(kernel, test_name);
1771 let input_data = [10.0, 20.0, 30.0];
1772 let params = PfeParams {
1773 period: Some(0),
1774 smoothing: Some(5),
1775 };
1776 let input = PfeInput::from_slice(&input_data, params);
1777 let res = pfe_with_kernel(&input, kernel);
1778 assert!(
1779 res.is_err(),
1780 "[{}] PFE should fail with zero period",
1781 test_name
1782 );
1783 Ok(())
1784 }
1785
1786 fn check_pfe_period_exceeds_length(
1787 test_name: &str,
1788 kernel: Kernel,
1789 ) -> Result<(), Box<dyn Error>> {
1790 skip_if_unsupported!(kernel, test_name);
1791 let data_small = [10.0, 20.0, 30.0];
1792 let params = PfeParams {
1793 period: Some(10),
1794 smoothing: Some(2),
1795 };
1796 let input = PfeInput::from_slice(&data_small, params);
1797 let res = pfe_with_kernel(&input, kernel);
1798 assert!(
1799 res.is_err(),
1800 "[{}] PFE should fail with period exceeding length",
1801 test_name
1802 );
1803 Ok(())
1804 }
1805
1806 fn check_pfe_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1807 skip_if_unsupported!(kernel, test_name);
1808 let single_point = [42.0];
1809 let params = PfeParams {
1810 period: Some(10),
1811 smoothing: Some(2),
1812 };
1813 let input = PfeInput::from_slice(&single_point, params);
1814 let res = pfe_with_kernel(&input, kernel);
1815 assert!(
1816 res.is_err(),
1817 "[{}] PFE should fail with insufficient data",
1818 test_name
1819 );
1820 Ok(())
1821 }
1822
1823 fn check_pfe_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1824 skip_if_unsupported!(kernel, test_name);
1825 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1826 let candles = read_candles_from_csv(file_path)?;
1827
1828 let first_params = PfeParams {
1829 period: Some(10),
1830 smoothing: Some(5),
1831 };
1832 let first_input = PfeInput::from_candles(&candles, "close", first_params);
1833 let first_result = pfe_with_kernel(&first_input, kernel)?;
1834
1835 let second_params = PfeParams {
1836 period: Some(10),
1837 smoothing: Some(5),
1838 };
1839 let second_input = PfeInput::from_slice(&first_result.values, second_params);
1840 let second_result = pfe_with_kernel(&second_input, kernel)?;
1841
1842 assert_eq!(second_result.values.len(), first_result.values.len());
1843 for i in 20..second_result.values.len() {
1844 assert!(
1845 !second_result.values[i].is_nan(),
1846 "[{}] Expected no NaN after index 20, but found NaN at index {}",
1847 test_name,
1848 i
1849 );
1850 }
1851 Ok(())
1852 }
1853
1854 fn check_pfe_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1855 skip_if_unsupported!(kernel, test_name);
1856 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1857 let candles = read_candles_from_csv(file_path)?;
1858
1859 let input = PfeInput::from_candles(
1860 &candles,
1861 "close",
1862 PfeParams {
1863 period: Some(10),
1864 smoothing: Some(5),
1865 },
1866 );
1867 let res = pfe_with_kernel(&input, kernel)?;
1868 assert_eq!(res.values.len(), candles.close.len());
1869 if res.values.len() > 240 {
1870 for (i, &val) in res.values[240..].iter().enumerate() {
1871 assert!(
1872 !val.is_nan(),
1873 "[{}] Found unexpected NaN at out-index {}",
1874 test_name,
1875 240 + i
1876 );
1877 }
1878 }
1879 Ok(())
1880 }
1881
1882 fn check_pfe_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1883 skip_if_unsupported!(kernel, test_name);
1884
1885 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1886 let candles = read_candles_from_csv(file_path)?;
1887
1888 let period = 10;
1889 let smoothing = 5;
1890
1891 let input = PfeInput::from_candles(
1892 &candles,
1893 "close",
1894 PfeParams {
1895 period: Some(period),
1896 smoothing: Some(smoothing),
1897 },
1898 );
1899 let batch_output = pfe_with_kernel(&input, kernel)?.values;
1900
1901 let mut stream = PfeStream::try_new(PfeParams {
1902 period: Some(period),
1903 smoothing: Some(smoothing),
1904 })?;
1905
1906 let mut stream_values = Vec::with_capacity(candles.close.len());
1907 for &price in &candles.close {
1908 match stream.update(price) {
1909 Some(val) => stream_values.push(val),
1910 None => stream_values.push(f64::NAN),
1911 }
1912 }
1913
1914 assert_eq!(batch_output.len(), stream_values.len());
1915 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1916 if b.is_nan() && s.is_nan() {
1917 continue;
1918 }
1919 let diff = (b - s).abs();
1920 assert!(
1921 diff < 1e-9,
1922 "[{}] PFE streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1923 test_name,
1924 i,
1925 b,
1926 s,
1927 diff
1928 );
1929 }
1930 Ok(())
1931 }
1932
1933 #[test]
1934 fn test_pfe_into_matches_api() -> Result<(), Box<dyn Error>> {
1935 let mut data = vec![0.0f64; 256];
1936 for (i, v) in data.iter_mut().enumerate() {
1937 let x = i as f64;
1938 *v = 0.05 * x + (x * 0.1).sin() * 2.0 + (x * 0.03).cos();
1939 }
1940
1941 let input = PfeInput::from_slice(&data, PfeParams::default());
1942 let baseline = pfe(&input)?.values;
1943
1944 let mut into_out = vec![0.0f64; data.len()];
1945 pfe_into(&input, &mut into_out)?;
1946 assert_eq!(baseline.len(), into_out.len());
1947
1948 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1949 (a.is_nan() && b.is_nan()) || (a == b)
1950 }
1951 for i in 0..baseline.len() {
1952 assert!(
1953 eq_or_both_nan(baseline[i], into_out[i]),
1954 "Mismatch at index {}: into={}, api={}",
1955 i,
1956 into_out[i],
1957 baseline[i]
1958 );
1959 }
1960 Ok(())
1961 }
1962
    /// Debug builds only: runs PFE on close prices for a spread of
    /// `(period, smoothing)` pairs. It panics if any non-NaN output equals one
    /// of the sentinel bit patterns (0x1111…, 0x2222…, 0x3333…) that the
    /// panic messages associate with `alloc_with_nan_prefix`,
    /// `init_matrix_prefixes` and `make_uninit_matrix`. Hitting one means an
    /// output cell was never written.
    #[cfg(debug_assertions)]
    fn check_pfe_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        // Mix of defaults, very short periods, long periods, and smoothing > period.
        let test_params = vec![
            PfeParams::default(),
            PfeParams {
                period: Some(2),
                smoothing: Some(1),
            },
            PfeParams {
                period: Some(5),
                smoothing: Some(2),
            },
            PfeParams {
                period: Some(10),
                smoothing: Some(3),
            },
            PfeParams {
                period: Some(14),
                smoothing: Some(5),
            },
            PfeParams {
                period: Some(20),
                smoothing: Some(5),
            },
            PfeParams {
                period: Some(20),
                smoothing: Some(10),
            },
            PfeParams {
                period: Some(50),
                smoothing: Some(15),
            },
            PfeParams {
                period: Some(100),
                smoothing: Some(20),
            },
            PfeParams {
                period: Some(3),
                smoothing: Some(30),
            },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = PfeInput::from_candles(&candles, "close", params.clone());
            let output = pfe_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                // NaN is the legitimate warm-up value; only finite/inf values are inspected.
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, smoothing={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        params.smoothing.unwrap_or(5),
                        param_idx
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={}, smoothing={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        params.smoothing.unwrap_or(5),
                        param_idx
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, smoothing={} (param set {})",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(10),
                        params.smoothing.unwrap_or(5),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_pfe_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2072
    /// Property test (only with the `proptest` feature).
    ///
    /// Generates positive price series of length `period + 50 .. 400`, with
    /// period 2..=64 and smoothing 1..=20. Each series is run through `kernel`
    /// and through `Kernel::Scalar`. Per generated case it checks:
    /// - output length and NaN over the first `period` indices;
    /// - identical NaN positions, and values within 1e-9 or 4 ULP of the scalar
    ///   kernel, with every finite value inside [-100, 100];
    /// - fixed scenarios: a straight upward line stays above 50, a zigzag is less
    ///   efficient than a straight line, smoothing does not raise variance beyond
    ///   1.5x the unsmoothed variance (under the listed guards), period 2 stays
    ///   bounded, and constant prices yield -100.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_pfe_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Prices drawn from three magnitude bands; length always exceeds period + 50.
        let strat = (2usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    prop::strategy::Union::new(vec![
                        (0.01f64..10.0f64).boxed(),
                        (10.0f64..1000.0f64).boxed(),
                        (1000.0f64..100000.0f64).boxed(),
                    ])
                    .prop_filter("finite", |x| x.is_finite()),
                    period + 50..400,
                ),
                Just(period),
                1usize..=20,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, smoothing)| {
                let params = PfeParams {
                    period: Some(period),
                    smoothing: Some(smoothing),
                };
                let input = PfeInput::from_slice(&data, params);

                let PfeOutput { values: out } = pfe_with_kernel(&input, kernel).unwrap();
                let PfeOutput { values: ref_out } =
                    pfe_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(
                    out.len(),
                    data.len(),
                    "[{}] Output length mismatch",
                    test_name
                );

                // Warm-up: indices 0..period carry no value.
                for i in 0..period {
                    prop_assert!(
                        out[i].is_nan(),
                        "[{}] Expected NaN at index {} during warmup (period={})",
                        test_name,
                        i,
                        period
                    );
                }

                // Cross-kernel agreement and value bounds after warm-up.
                for i in period..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if y.is_nan() != r.is_nan() {
                        prop_assert!(
                            false,
                            "[{}] NaN mismatch at index {}: {} vs {}",
                            test_name,
                            i,
                            y,
                            r
                        );
                    }

                    if y.is_finite() {
                        prop_assert!(
                            y >= -100.0 && y <= 100.0,
                            "[{}] PFE value {} at index {} out of bounds [-100, 100]",
                            test_name,
                            y,
                            i
                        );

                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "[{}] Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
                            test_name,
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Fixed scenario: a monotonically rising line is highly efficient.
                let straight_up: Vec<f64> = (0..100).map(|i| 100.0 + i as f64).collect();
                let straight_params = PfeParams {
                    period: Some(10),
                    smoothing: Some(1),
                };
                let straight_input = PfeInput::from_slice(&straight_up, straight_params);
                if let Ok(straight_out) = pfe_with_kernel(&straight_input, kernel) {
                    for i in 15..straight_out.values.len() {
                        if straight_out.values[i].is_finite() {
                            prop_assert!(
                                straight_out.values[i] > 50.0,
                                "[{}] Straight line up should have high positive efficiency, got {} at index {}",
                                test_name,
                                straight_out.values[i],
                                i
                            );
                        }
                    }
                }

                // Fixed scenario: an up/down zigzag is less efficient than the straight line.
                let zigzag: Vec<f64> = (0..100)
                    .map(|i| {
                        if i % 2 == 0 {
                            100.0 + (i as f64)
                        } else {
                            100.0 + (i as f64) - 5.0
                        }
                    })
                    .collect();
                let zigzag_params = PfeParams {
                    period: Some(10),
                    smoothing: Some(1),
                };
                let zigzag_input = PfeInput::from_slice(&zigzag, zigzag_params.clone());
                let straight_input2 = PfeInput::from_slice(&straight_up, zigzag_params);

                if let (Ok(zigzag_out), Ok(straight_out2)) = (
                    pfe_with_kernel(&zigzag_input, kernel),
                    pfe_with_kernel(&straight_input2, kernel),
                ) {
                    let zigzag_avg: f64 = zigzag_out.values[20..50]
                        .iter()
                        .filter(|x| x.is_finite())
                        .map(|x| x.abs())
                        .sum::<f64>()
                        / 30.0;
                    let straight_avg: f64 = straight_out2.values[20..50]
                        .iter()
                        .filter(|x| x.is_finite())
                        .map(|x| x.abs())
                        .sum::<f64>()
                        / 30.0;

                    prop_assert!(
                        zigzag_avg < straight_avg,
                        "[{}] Zigzag pattern should have lower efficiency ({}) than straight line ({})",
                        test_name,
                        zigzag_avg,
                        straight_avg
                    );
                }

                // Smoothing should not inflate variance (skipped for extreme jumps,
                // near-constant output, or heavy smoothing).
                if smoothing > 1 && data.len() > period + 30 {
                    let unsmoothed_params = PfeParams {
                        period: Some(period),
                        smoothing: Some(1),
                    };
                    let unsmoothed_input = PfeInput::from_slice(&data, unsmoothed_params);

                    if let Ok(unsmoothed_out) = pfe_with_kernel(&unsmoothed_input, kernel) {
                        let window_start = period + 10;
                        let window_end = (period + 30).min(out.len());

                        if window_end > window_start {
                            let smoothed_variance =
                                calculate_variance(&out[window_start..window_end]);
                            let unsmoothed_variance = calculate_variance(
                                &unsmoothed_out.values[window_start..window_end],
                            );

                            let mut has_extreme_jumps = false;
                            for i in 1..data.len() {
                                let ratio = if data[i - 1] != 0.0 {
                                    (data[i] / data[i - 1]).abs()
                                } else {
                                    f64::INFINITY
                                };

                                if ratio > 100.0 || ratio < 0.01 {
                                    has_extreme_jumps = true;
                                    break;
                                }
                            }

                            if smoothed_variance.is_finite()
                                && unsmoothed_variance.is_finite()
                                && unsmoothed_variance > 1e-6
                                && !has_extreme_jumps
                                && smoothing <= 10
                            {
                                prop_assert!(
                                    smoothed_variance <= unsmoothed_variance * 1.5,
                                    "[{}] Smoothed variance ({}) should be <= 1.5x unsmoothed variance ({})",
                                    test_name,
                                    smoothed_variance,
                                    unsmoothed_variance
                                );
                            }
                        }
                    }
                }

                // Shortest allowed period still yields bounded values.
                if period == 2 {
                    for i in 2..out.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                out[i] >= -100.0 && out[i] <= 100.0,
                                "[{}] Period=2 should still produce valid bounded values, got {} at index {}",
                                test_name,
                                out[i],
                                i
                            );
                        }
                    }
                }

                // Flat prices: zero net change is reported with the negative sign, i.e. -100.
                let constant: Vec<f64> = vec![500.0; 50];
                let const_params = PfeParams {
                    period: Some(10),
                    smoothing: Some(1),
                };
                let const_input = PfeInput::from_slice(&constant, const_params);
                if let Ok(const_out) = pfe_with_kernel(&const_input, kernel) {
                    for i in 15..const_out.values.len() {
                        if const_out.values[i].is_finite() {
                            prop_assert!(
                                (const_out.values[i] - (-100.0)).abs() < 1e-6,
                                "[{}] Constant prices should produce exactly -100, got {} at index {}",
                                test_name,
                                const_out.values[i],
                                i
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        /// Population variance of the finite entries; NaN if there are none.
        fn calculate_variance(values: &[f64]) -> f64 {
            let finite_values: Vec<f64> =
                values.iter().filter(|x| x.is_finite()).copied().collect();

            if finite_values.is_empty() {
                return f64::NAN;
            }

            let mean = finite_values.iter().sum::<f64>() / finite_values.len() as f64;
            let variance = finite_values
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f64>()
                / finite_values.len() as f64;
            variance
        }

        Ok(())
    }
2336
2337 macro_rules! generate_all_pfe_tests {
2338 ($($test_fn:ident),*) => {
2339 paste::paste! {
2340 $(
2341 #[test]
2342 fn [<$test_fn _scalar_f64>]() {
2343 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2344 }
2345 )*
2346 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2347 $(
2348 #[test]
2349 fn [<$test_fn _avx2_f64>]() {
2350 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2351 }
2352 #[test]
2353 fn [<$test_fn _avx512_f64>]() {
2354 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2355 }
2356 )*
2357 }
2358 }
2359 }
2360
    // Per-kernel tests for every single-series check above.
    generate_all_pfe_tests!(
        check_pfe_partial_params,
        check_pfe_accuracy,
        check_pfe_default_candles,
        check_pfe_zero_period,
        check_pfe_period_exceeds_length,
        check_pfe_very_small_dataset,
        check_pfe_reinput,
        check_pfe_nan_handling,
        check_pfe_streaming,
        check_pfe_no_poison
    );

    // The property test exists only when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_pfe_tests!(check_pfe_property);
2376
2377 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2378 skip_if_unsupported!(kernel, test);
2379
2380 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2381 let c = read_candles_from_csv(file)?;
2382
2383 let output = PfeBatchBuilder::new()
2384 .kernel(kernel)
2385 .apply_candles(&c, "close")?;
2386
2387 let def = PfeParams::default();
2388 let row = output.values_for(&def).expect("default row missing");
2389
2390 assert_eq!(row.len(), c.close.len());
2391
2392 let expected = [
2393 -13.03562252,
2394 -11.93979855,
2395 -9.94609862,
2396 -9.73372410,
2397 -14.88374798,
2398 ];
2399 let start = row.len() - 5;
2400 for (i, &v) in row[start..].iter().enumerate() {
2401 assert!(
2402 (v - expected[i]).abs() < 1e-8,
2403 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2404 );
2405 }
2406 Ok(())
2407 }
2408
2409 #[cfg(debug_assertions)]
2410 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2411 skip_if_unsupported!(kernel, test);
2412
2413 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2414 let c = read_candles_from_csv(file)?;
2415
2416 let test_configs = vec![
2417 (2, 10, 2, 1, 5, 1),
2418 (5, 25, 5, 2, 10, 2),
2419 (30, 60, 15, 5, 20, 5),
2420 (2, 5, 1, 1, 3, 1),
2421 (10, 10, 0, 1, 20, 1),
2422 (2, 100, 10, 5, 5, 0),
2423 (14, 21, 7, 3, 9, 3),
2424 (50, 100, 25, 10, 30, 10),
2425 ];
2426
2427 for (cfg_idx, &(p_start, p_end, p_step, s_start, s_end, s_step)) in
2428 test_configs.iter().enumerate()
2429 {
2430 let output = PfeBatchBuilder::new()
2431 .kernel(kernel)
2432 .period_range(p_start, p_end, p_step)
2433 .smoothing_range(s_start, s_end, s_step)
2434 .apply_candles(&c, "close")?;
2435
2436 for (idx, &val) in output.values.iter().enumerate() {
2437 if val.is_nan() {
2438 continue;
2439 }
2440
2441 let bits = val.to_bits();
2442 let row = idx / output.cols;
2443 let col = idx % output.cols;
2444 let combo = &output.combos[row];
2445
2446 if bits == 0x11111111_11111111 {
2447 panic!(
2448 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2449 at row {} col {} (flat index {}) with params: period={}, smoothing={}",
2450 test,
2451 cfg_idx,
2452 val,
2453 bits,
2454 row,
2455 col,
2456 idx,
2457 combo.period.unwrap_or(10),
2458 combo.smoothing.unwrap_or(5)
2459 );
2460 }
2461
2462 if bits == 0x22222222_22222222 {
2463 panic!(
2464 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2465 at row {} col {} (flat index {}) with params: period={}, smoothing={}",
2466 test,
2467 cfg_idx,
2468 val,
2469 bits,
2470 row,
2471 col,
2472 idx,
2473 combo.period.unwrap_or(10),
2474 combo.smoothing.unwrap_or(5)
2475 );
2476 }
2477
2478 if bits == 0x33333333_33333333 {
2479 panic!(
2480 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2481 at row {} col {} (flat index {}) with params: period={}, smoothing={}",
2482 test,
2483 cfg_idx,
2484 val,
2485 bits,
2486 row,
2487 col,
2488 idx,
2489 combo.period.unwrap_or(10),
2490 combo.smoothing.unwrap_or(5)
2491 );
2492 }
2493 }
2494 }
2495
2496 Ok(())
2497 }
2498
2499 #[cfg(not(debug_assertions))]
2500 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2501 Ok(())
2502 }
2503
2504 macro_rules! gen_batch_tests {
2505 ($fn_name:ident) => {
2506 paste::paste! {
2507 #[test] fn [<$fn_name _scalar>]() {
2508 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2509 }
2510 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2511 #[test] fn [<$fn_name _avx2>]() {
2512 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2513 }
2514 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2515 #[test] fn [<$fn_name _avx512>]() {
2516 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2517 }
2518 #[test] fn [<$fn_name _auto_detect>]() {
2519 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2520 }
2521 }
2522 };
2523 }
2524 gen_batch_tests!(check_batch_default_row);
2525 gen_batch_tests!(check_batch_no_poison);
2526}