1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::mem::{ManuallyDrop, MaybeUninit};
29use thiserror::Error;
30
31#[derive(Debug, Clone)]
32pub enum ApoData<'a> {
33 Candles {
34 candles: &'a Candles,
35 source: &'a str,
36 },
37 Slice(&'a [f64]),
38}
39
40impl<'a> AsRef<[f64]> for ApoInput<'a> {
41 #[inline(always)]
42 fn as_ref(&self) -> &[f64] {
43 match &self.data {
44 ApoData::Slice(slice) => slice,
45 ApoData::Candles { candles, source } => source_type(candles, source),
46 }
47 }
48}
49
/// Result of a single-series APO run.
#[derive(Debug, Clone)]
pub struct ApoOutput {
    /// One APO value per input sample (same length as the input).
    pub values: Vec<f64>,
}

/// APO periods; `None` means "use the default" (short 10, long 20).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ApoParams {
    /// Period of the fast EMA.
    pub short_period: Option<usize>,
    /// Period of the slow EMA; must be strictly greater than `short_period`.
    pub long_period: Option<usize>,
}

impl Default for ApoParams {
    /// Classic 10 / 20 configuration.
    fn default() -> Self {
        let (short, long) = (10, 20);
        ApoParams {
            short_period: Some(short),
            long_period: Some(long),
        }
    }
}
72
73#[derive(Debug, Clone)]
74pub struct ApoInput<'a> {
75 pub data: ApoData<'a>,
76 pub params: ApoParams,
77}
78impl<'a> ApoInput<'a> {
79 #[inline]
80 pub fn from_candles(c: &'a Candles, s: &'a str, p: ApoParams) -> Self {
81 Self {
82 data: ApoData::Candles {
83 candles: c,
84 source: s,
85 },
86 params: p,
87 }
88 }
89 #[inline]
90 pub fn from_slice(sl: &'a [f64], p: ApoParams) -> Self {
91 Self {
92 data: ApoData::Slice(sl),
93 params: p,
94 }
95 }
96 #[inline]
97 pub fn with_default_candles(c: &'a Candles) -> Self {
98 Self::from_candles(c, "close", ApoParams::default())
99 }
100 #[inline]
101 pub fn get_short_period(&self) -> usize {
102 self.params.short_period.unwrap_or(10)
103 }
104 #[inline]
105 pub fn get_long_period(&self) -> usize {
106 self.params.long_period.unwrap_or(20)
107 }
108}
109
110#[derive(Debug, Error)]
111pub enum ApoError {
112 #[error("apo: Input data slice is empty.")]
113 EmptyInputData,
114 #[error("apo: All values are NaN.")]
115 AllValuesNaN,
116 #[error("apo: Invalid period: short={short}, long={long}")]
117 InvalidPeriod { short: usize, long: usize },
118 #[error("apo: short_period not less than long_period: short={short}, long={long}")]
119 ShortPeriodNotLessThanLong { short: usize, long: usize },
120 #[error("apo: Not enough valid data: needed={needed}, valid={valid}")]
121 NotEnoughValidData { needed: usize, valid: usize },
122 #[error("apo: output length mismatch: expected={expected}, got={got}")]
123 OutputLengthMismatch { expected: usize, got: usize },
124 #[error("apo: invalid range: start={start}, end={end}, step={step}")]
125 InvalidRange {
126 start: usize,
127 end: usize,
128 step: usize,
129 },
130 #[error("apo: invalid kernel for batch: {0:?}")]
131 InvalidKernelForBatch(Kernel),
132}
133
134#[derive(Copy, Clone, Debug)]
135pub struct ApoBuilder {
136 short_period: Option<usize>,
137 long_period: Option<usize>,
138 kernel: Kernel,
139}
140impl Default for ApoBuilder {
141 fn default() -> Self {
142 Self {
143 short_period: None,
144 long_period: None,
145 kernel: Kernel::Auto,
146 }
147 }
148}
149impl ApoBuilder {
150 #[inline(always)]
151 pub fn new() -> Self {
152 Self::default()
153 }
154 #[inline(always)]
155 pub fn short_period(mut self, n: usize) -> Self {
156 self.short_period = Some(n);
157 self
158 }
159 #[inline(always)]
160 pub fn long_period(mut self, n: usize) -> Self {
161 self.long_period = Some(n);
162 self
163 }
164 #[inline(always)]
165 pub fn kernel(mut self, k: Kernel) -> Self {
166 self.kernel = k;
167 self
168 }
169 #[inline(always)]
170 pub fn apply(self, c: &Candles) -> Result<ApoOutput, ApoError> {
171 let p = ApoParams {
172 short_period: self.short_period,
173 long_period: self.long_period,
174 };
175 let i = ApoInput::from_candles(c, "close", p);
176 apo_with_kernel(&i, self.kernel)
177 }
178 #[inline(always)]
179 pub fn apply_slice(self, d: &[f64]) -> Result<ApoOutput, ApoError> {
180 let p = ApoParams {
181 short_period: self.short_period,
182 long_period: self.long_period,
183 };
184 let i = ApoInput::from_slice(d, p);
185 apo_with_kernel(&i, self.kernel)
186 }
187 #[inline(always)]
188 pub fn into_stream(self) -> Result<ApoStream, ApoError> {
189 let p = ApoParams {
190 short_period: self.short_period,
191 long_period: self.long_period,
192 };
193 ApoStream::try_new(p)
194 }
195}
196
197#[inline]
198pub fn apo(input: &ApoInput) -> Result<ApoOutput, ApoError> {
199 apo_with_kernel(input, Kernel::Auto)
200}
201
202#[inline(always)]
203fn apo_prepare<'a>(
204 input: &'a ApoInput,
205 kernel: Kernel,
206) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), ApoError> {
207 let data: &[f64] = input.as_ref();
208 let len = data.len();
209 if len == 0 {
210 return Err(ApoError::EmptyInputData);
211 }
212 let first = data
213 .iter()
214 .position(|x| !x.is_nan())
215 .ok_or(ApoError::AllValuesNaN)?;
216 let short = input.get_short_period();
217 let long = input.get_long_period();
218
219 if short == 0 || long == 0 {
220 return Err(ApoError::InvalidPeriod { short, long });
221 }
222 if short >= long {
223 return Err(ApoError::ShortPeriodNotLessThanLong { short, long });
224 }
225 if (len - first) < long {
226 return Err(ApoError::NotEnoughValidData {
227 needed: long,
228 valid: len - first,
229 });
230 }
231
232 let mut chosen = match kernel {
233 Kernel::Auto => Kernel::Scalar,
234 k => k,
235 };
236
237 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238 if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx2 | Kernel::Avx512) {
239 chosen = Kernel::Scalar;
240 }
241 Ok((data, first, short, long, len, chosen))
242}
243
244#[inline(always)]
245fn apo_compute_into(
246 data: &[f64],
247 first: usize,
248 short: usize,
249 long: usize,
250 kernel: Kernel,
251 out: &mut [f64],
252) {
253 unsafe {
254 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
255 {}
256
257 match kernel {
258 Kernel::Scalar | Kernel::ScalarBatch => apo_scalar(data, short, long, first, out),
259 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260 Kernel::Avx2 | Kernel::Avx2Batch => apo_avx2(data, short, long, first, out),
261 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
262 Kernel::Avx512 | Kernel::Avx512Batch => apo_avx512(data, short, long, first, out),
263 _ => unreachable!(),
264 }
265 }
266}
267
268pub fn apo_with_kernel(input: &ApoInput, kernel: Kernel) -> Result<ApoOutput, ApoError> {
269 let (data, first, short, long, len, chosen) = apo_prepare(input, kernel)?;
270
271 let warmup_period = first;
272
273 let mut out = alloc_with_nan_prefix(len, warmup_period);
274
275 apo_compute_into(data, first, short, long, chosen, &mut out);
276
277 Ok(ApoOutput { values: out })
278}
279
280#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
281pub fn apo_into(input: &ApoInput, out: &mut [f64]) -> Result<(), ApoError> {
282 let (data, first, short, long, len, chosen) = apo_prepare(input, Kernel::Auto)?;
283 if out.len() != len {
284 return Err(ApoError::OutputLengthMismatch {
285 expected: len,
286 got: out.len(),
287 });
288 }
289
290 if first > 0 {
291 for v in &mut out[..first] {
292 *v = f64::from_bits(0x7ff8_0000_0000_0000);
293 }
294 }
295
296 apo_compute_into(data, first, short, long, chosen, out);
297 Ok(())
298}
299
/// Scalar APO kernel: `out[i] = EMA_short(i) - EMA_long(i)` for `i >= first`.
///
/// Both EMAs are seeded with `data[first]` (so `out[first] == 0.0`) and then advance with
/// `ema = alpha * price + (1 - alpha) * ema`, where `alpha = 2 / (period + 1)`.
/// `out[..first]` is left untouched.
///
/// # Panics
/// Panics if `first >= data.len()` or `out` is shorter than `data`.
#[inline(always)]
pub fn apo_scalar(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    debug_assert_eq!(out.len(), n);

    let (a_short, a_long) = (2.0 / (short as f64 + 1.0), 2.0 / (long as f64 + 1.0));
    let (keep_short, keep_long) = (1.0 - a_short, 1.0 - a_long);

    let seed = data[first];
    let mut ema_short = seed;
    let mut ema_long = seed;
    out[first] = 0.0;

    let tail = &mut out[first + 1..n];
    for (dst, &price) in tail.iter_mut().zip(&data[first + 1..]) {
        ema_short = a_short * price + keep_short * ema_short;
        ema_long = a_long * price + keep_long * ema_long;
        *dst = ema_short - ema_long;
    }
}
336
/// AVX2 APO kernel holding both EMAs in a single 256-bit register.
///
/// Lane 0 carries the short EMA and lane 1 the long EMA; lanes 2–3 duplicate them. Each step
/// applies `alpha * price + (1 - alpha) * ema` to every lane, swaps neighbouring lanes with
/// `_mm256_permute_pd(.., 0x5)` and stores lane 0 of `ema - swapped`, i.e. `short - long`.
/// `out[first]` is `0.0`; `out[..first]` is untouched.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `first < data.len()` and `out.len() >= data.len()`; indices are not bounds-checked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn apo_avx2(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let len = data.len();
    debug_assert_eq!(out.len(), len);

    let alpha_short = 2.0 / (short as f64 + 1.0);
    let alpha_long = 2.0 / (long as f64 + 1.0);
    let (decay_short, decay_long) = (1.0 - alpha_short, 1.0 - alpha_long);

    // `_mm256_set_pd` lists lanes high-to-low, so even lanes = short, odd lanes = long.
    let alpha = _mm256_set_pd(alpha_long, alpha_short, alpha_long, alpha_short);
    let decay = _mm256_set_pd(decay_long, decay_short, decay_long, decay_short);

    let mut ema = _mm256_set1_pd(*data.get_unchecked(first));
    *out.get_unchecked_mut(first) = 0.0;

    for idx in first + 1..len {
        let price = _mm256_set1_pd(*data.get_unchecked(idx));
        ema = _mm256_add_pd(_mm256_mul_pd(alpha, price), _mm256_mul_pd(decay, ema));
        let diff = _mm256_sub_pd(ema, _mm256_permute_pd(ema, 0x5));
        *out.get_unchecked_mut(idx) = _mm256_cvtsd_f64(diff);
    }
}
378
/// AVX-512F APO entry point; forwards to [`apo_avx512_short`].
///
/// # Safety
/// Same contract as `apo_avx512_short`: AVX-512F must be available, `first < data.len()` and
/// `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    apo_avx512_short(data, short, long, first, out)
}
385
/// AVX-512F APO kernel holding both EMAs in alternating lanes of a 512-bit register.
///
/// Even lanes carry the short EMA and odd lanes the long EMA. Each step applies
/// `alpha * price + (1 - alpha) * ema`, swaps adjacent lanes with
/// `_mm512_permute_pd(.., 0b0101_0101)` and stores lane 0 of `ema - swapped`, i.e.
/// `short - long`. `out[first]` is `0.0`; `out[..first]` is untouched.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `first < data.len()` and `out.len() >= data.len()`; indices are not bounds-checked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512_short(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    debug_assert_eq!(out.len(), len);

    let alpha_short = 2.0 / (short as f64 + 1.0);
    let alpha_long = 2.0 / (long as f64 + 1.0);
    let (decay_short, decay_long) = (1.0 - alpha_short, 1.0 - alpha_long);

    // `_mm512_set_pd` lists lanes high-to-low, so even lanes = short, odd lanes = long.
    let alpha = _mm512_set_pd(
        alpha_long, alpha_short, alpha_long, alpha_short, alpha_long, alpha_short, alpha_long,
        alpha_short,
    );
    let decay = _mm512_set_pd(
        decay_long, decay_short, decay_long, decay_short, decay_long, decay_short, decay_long,
        decay_short,
    );

    let mut ema = _mm512_set1_pd(*data.get_unchecked(first));
    *out.get_unchecked_mut(first) = 0.0;

    for idx in first + 1..len {
        let price = _mm512_set1_pd(*data.get_unchecked(idx));
        ema = _mm512_add_pd(_mm512_mul_pd(alpha, price), _mm512_mul_pd(decay, ema));
        let diff = _mm512_sub_pd(ema, _mm512_permute_pd(ema, 0b0101_0101));
        *out.get_unchecked_mut(idx) = _mm_cvtsd_f64(_mm512_castpd512_pd128(diff));
    }
}
436
/// "Long period" AVX-512F variant; currently the same computation as [`apo_avx512_short`].
///
/// # Safety
/// Same contract as `apo_avx512_short`: AVX-512F must be available, `first < data.len()` and
/// `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512_long(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    apo_avx512_short(data, short, long, first, out)
}
449
/// wasm `simd128` APO kernel (not currently dispatched by `apo_compute_into`).
///
/// The EMA recurrence is strictly sequential, so the two SIMD lanes hold the two EMAs rather
/// than two time steps: lane 0 = short EMA, lane 1 = long EMA. Each sample advances both lanes
/// with `alpha * price + (1 - alpha) * ema` and writes `short - long`. `out[first]` is `0.0`;
/// `out[..first]` is untouched.
///
/// # Safety
/// Kept `unsafe` for signature compatibility. Requires `first < data.len()` and
/// `out.len() >= data.len()`; indexing panics otherwise.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
// Not wired into kernel dispatch yet.
#[allow(dead_code)]
unsafe fn apo_simd128(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    let alpha_short = 2.0 / (short as f64 + 1.0);
    let alpha_long = 2.0 / (long as f64 + 1.0);

    let alpha = f64x2(alpha_short, alpha_long);
    let one_minus_alpha = f64x2(1.0 - alpha_short, 1.0 - alpha_long);

    let out = &mut out[..data.len()];
    let mut ema = f64x2_splat(data[first]);
    out[first] = 0.0;

    // The previous version updated two consecutive samples from the same prior EMA, so every
    // second output skipped one recurrence step. Here each sample uses the freshly updated EMA.
    for (dst, &price) in out[first + 1..].iter_mut().zip(&data[first + 1..]) {
        ema = f64x2_add(
            f64x2_mul(alpha, f64x2_splat(price)),
            f64x2_mul(one_minus_alpha, ema),
        );
        *dst = f64x2_extract_lane::<0>(ema) - f64x2_extract_lane::<1>(ema);
    }
}
507
508#[derive(Clone, Debug)]
509pub struct ApoStream {
510 short: usize,
511 long: usize,
512 alpha_short: f64,
513 alpha_long: f64,
514
515 oma_short: f64,
516 oma_long: f64,
517 short_ema: f64,
518 long_ema: f64,
519 filled: bool,
520 nan_leading: usize,
521 seen: usize,
522}
523
524impl ApoStream {
525 #[inline(always)]
526 pub fn try_new(params: ApoParams) -> Result<Self, ApoError> {
527 let short = params.short_period.unwrap_or(10);
528 let long = params.long_period.unwrap_or(20);
529 if short == 0 || long == 0 {
530 return Err(ApoError::InvalidPeriod { short, long });
531 }
532 if short >= long {
533 return Err(ApoError::ShortPeriodNotLessThanLong { short, long });
534 }
535
536 let alpha_short = 2.0 / (short as f64 + 1.0);
537 let alpha_long = 2.0 / (long as f64 + 1.0);
538 Ok(Self {
539 short,
540 long,
541 alpha_short,
542 alpha_long,
543 oma_short: 1.0 - alpha_short,
544 oma_long: 1.0 - alpha_long,
545 short_ema: f64::NAN,
546 long_ema: f64::NAN,
547 filled: false,
548 nan_leading: 0,
549 seen: 0,
550 })
551 }
552
553 #[inline(always)]
554 pub fn update(&mut self, price: f64) -> Option<f64> {
555 if !self.filled {
556 if price.is_nan() {
557 self.nan_leading += 1;
558 return None;
559 }
560 self.short_ema = price;
561 self.long_ema = price;
562 self.filled = true;
563 self.seen = 1;
564 return Some(0.0);
565 }
566
567 self.seen += 1;
568
569 if price.is_nan() {
570 self.short_ema = f64::NAN;
571 self.long_ema = f64::NAN;
572 return Some(f64::NAN);
573 }
574
575 let se_prev = self.short_ema;
576 let le_prev = self.long_ema;
577 self.short_ema = self.alpha_short * price + self.oma_short * se_prev;
578 self.long_ema = self.alpha_long * price + self.oma_long * le_prev;
579 Some(self.short_ema - self.long_ema)
580 }
581
582 #[inline(always)]
583 pub fn update_fastmath(&mut self, price: f64) -> Option<f64> {
584 if !self.filled {
585 if price.is_nan() {
586 self.nan_leading += 1;
587 return None;
588 }
589 self.short_ema = price;
590 self.long_ema = price;
591 self.filled = true;
592 self.seen = 1;
593 return Some(0.0);
594 }
595
596 self.seen += 1;
597
598 if price.is_nan() {
599 self.short_ema = f64::NAN;
600 self.long_ema = f64::NAN;
601 return Some(f64::NAN);
602 }
603
604 let ds = price - self.short_ema;
605 let dl = price - self.long_ema;
606 self.short_ema = ds.mul_add(self.alpha_short, self.short_ema);
607 self.long_ema = dl.mul_add(self.alpha_long, self.long_ema);
608 Some(self.short_ema - self.long_ema)
609 }
610}
611
612pub fn apo_into_slice(dst: &mut [f64], input: &ApoInput, kern: Kernel) -> Result<(), ApoError> {
613 let (data, first, short, long, len, chosen) = apo_prepare(input, kern)?;
614 if dst.len() != len {
615 return Err(ApoError::OutputLengthMismatch {
616 expected: len,
617 got: dst.len(),
618 });
619 }
620 apo_compute_into(data, first, short, long, chosen, dst);
621 for v in &mut dst[..first] {
622 *v = f64::NAN;
623 }
624 Ok(())
625}
626
/// Parameter sweep for batch APO: each axis is `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct ApoBatchRange {
    pub short: (usize, usize, usize),
    pub long: (usize, usize, usize),
}

impl Default for ApoBatchRange {
    /// Fixed short period 10; long periods 20 through 269 in steps of 1.
    fn default() -> Self {
        let short = (10, 10, 0);
        let long = (20, 269, 1);
        ApoBatchRange { short, long }
    }
}
640
641#[derive(Clone, Debug, Default)]
642pub struct ApoBatchBuilder {
643 range: ApoBatchRange,
644 kernel: Kernel,
645}
646impl ApoBatchBuilder {
647 pub fn new() -> Self {
648 Self::default()
649 }
650 pub fn kernel(mut self, k: Kernel) -> Self {
651 self.kernel = k;
652 self
653 }
654 pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
655 self.range.short = (start, end, step);
656 self
657 }
658 pub fn short_static(mut self, s: usize) -> Self {
659 self.range.short = (s, s, 0);
660 self
661 }
662 pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
663 self.range.long = (start, end, step);
664 self
665 }
666 pub fn long_static(mut self, s: usize) -> Self {
667 self.range.long = (s, s, 0);
668 self
669 }
670 pub fn apply_slice(self, data: &[f64]) -> Result<ApoBatchOutput, ApoError> {
671 apo_batch_with_kernel(data, &self.range, self.kernel)
672 }
673 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ApoBatchOutput, ApoError> {
674 ApoBatchBuilder::new().kernel(k).apply_slice(data)
675 }
676 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ApoBatchOutput, ApoError> {
677 let slice = source_type(c, src);
678 self.apply_slice(slice)
679 }
680 pub fn with_default_candles(c: &Candles) -> Result<ApoBatchOutput, ApoError> {
681 ApoBatchBuilder::new()
682 .kernel(Kernel::Auto)
683 .apply_candles(c, "close")
684 }
685}
686
687#[derive(Clone, Debug)]
688pub struct ApoBatchOutput {
689 pub values: Vec<f64>,
690 pub combos: Vec<ApoParams>,
691 pub rows: usize,
692 pub cols: usize,
693}
694impl ApoBatchOutput {
695 pub fn row_for_params(&self, p: &ApoParams) -> Option<usize> {
696 self.combos.iter().position(|c| {
697 c.short_period.unwrap_or(10) == p.short_period.unwrap_or(10)
698 && c.long_period.unwrap_or(20) == p.long_period.unwrap_or(20)
699 })
700 }
701 pub fn values_for(&self, p: &ApoParams) -> Option<&[f64]> {
702 self.row_for_params(p).map(|row| {
703 let start = row * self.cols;
704 &self.values[start..start + self.cols]
705 })
706 }
707}
708
709#[inline(always)]
710fn expand_grid(r: &ApoBatchRange) -> Result<Vec<ApoParams>, ApoError> {
711 fn axis((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ApoError> {
712 if step == 0 || start == end {
713 return Ok(vec![start]);
714 }
715 let mut v = Vec::new();
716 if start < end {
717 let mut cur = start;
718 while cur <= end {
719 v.push(cur);
720 match cur.checked_add(step) {
721 Some(n) => cur = n,
722 None => break,
723 }
724 }
725 } else {
726 let mut cur = start;
727 while cur >= end {
728 v.push(cur);
729 if let Some(n) = cur.checked_sub(step) {
730 cur = n;
731 } else {
732 break;
733 }
734 if cur == usize::MAX {
735 break;
736 }
737 }
738 }
739 if v.is_empty() {
740 return Err(ApoError::InvalidRange { start, end, step });
741 }
742 Ok(v)
743 }
744 let shorts = axis(r.short)?;
745 let longs = axis(r.long)?;
746 let mut out = Vec::with_capacity(shorts.len().saturating_mul(longs.len()));
747 for &s in &shorts {
748 for &l in &longs {
749 if s < l && s > 0 && l > 0 {
750 out.push(ApoParams {
751 short_period: Some(s),
752 long_period: Some(l),
753 });
754 }
755 }
756 }
757 Ok(out)
758}
759
760#[inline(always)]
761pub fn apo_batch_with_kernel(
762 data: &[f64],
763 sweep: &ApoBatchRange,
764 k: Kernel,
765) -> Result<ApoBatchOutput, ApoError> {
766 let kernel = match k {
767 Kernel::Auto => detect_best_batch_kernel(),
768 other if other.is_batch() => other,
769 other => return Err(ApoError::InvalidKernelForBatch(other)),
770 };
771 apo_batch_par_slice(data, sweep, kernel)
772}
773
774#[inline(always)]
775pub fn apo_batch_slice(
776 data: &[f64],
777 sweep: &ApoBatchRange,
778 kern: Kernel,
779) -> Result<ApoBatchOutput, ApoError> {
780 apo_batch_inner(data, sweep, kern, false)
781}
782
783#[inline(always)]
784pub fn apo_batch_par_slice(
785 data: &[f64],
786 sweep: &ApoBatchRange,
787 kern: Kernel,
788) -> Result<ApoBatchOutput, ApoError> {
789 apo_batch_inner(data, sweep, kern, true)
790}
791
792#[inline(always)]
793fn apo_batch_inner(
794 data: &[f64],
795 sweep: &ApoBatchRange,
796 kern: Kernel,
797 parallel: bool,
798) -> Result<ApoBatchOutput, ApoError> {
799 let combos = expand_grid(sweep)?;
800 if combos.is_empty() {
801 return Err(ApoError::InvalidRange {
802 start: sweep.short.0,
803 end: sweep.short.1,
804 step: sweep.short.2,
805 });
806 }
807 let first = data
808 .iter()
809 .position(|x| !x.is_nan())
810 .ok_or(ApoError::AllValuesNaN)?;
811 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
812 if data.len() - first < max_long {
813 return Err(ApoError::NotEnoughValidData {
814 needed: max_long,
815 valid: data.len() - first,
816 });
817 }
818
819 let rows = combos.len();
820 let cols = data.len();
821
822 let _ = rows.checked_mul(cols).ok_or(ApoError::InvalidRange {
823 start: rows,
824 end: cols,
825 step: 0,
826 })?;
827
828 let mut buf_mu = make_uninit_matrix(rows, cols);
829
830 let warm: Vec<usize> = combos.iter().map(|_c| first).collect();
831
832 init_matrix_prefixes(&mut buf_mu, cols, &warm);
833
834 let mut buf_guard = ManuallyDrop::new(buf_mu);
835 let values: &mut [f64] = unsafe {
836 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
837 };
838
839 match kern {
840 Kernel::Scalar | Kernel::ScalarBatch => {
841 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
842 let s = combos[row].short_period.unwrap();
843 let l = combos[row].long_period.unwrap();
844 apo_row_scalar(data, first, s, l, out_row)
845 };
846 if parallel {
847 #[cfg(not(target_arch = "wasm32"))]
848 {
849 values
850 .par_chunks_mut(cols)
851 .enumerate()
852 .for_each(|(row, slice)| do_row(row, slice));
853 }
854 #[cfg(target_arch = "wasm32")]
855 {
856 for (row, slice) in values.chunks_mut(cols).enumerate() {
857 do_row(row, slice);
858 }
859 }
860 } else {
861 for (row, slice) in values.chunks_mut(cols).enumerate() {
862 do_row(row, slice);
863 }
864 }
865 }
866 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
867 Kernel::Avx2 => {
868 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
869 let s = combos[row].short_period.unwrap();
870 let l = combos[row].long_period.unwrap();
871 apo_row_avx2(data, first, s, l, out_row)
872 };
873 if parallel {
874 #[cfg(not(target_arch = "wasm32"))]
875 values
876 .par_chunks_mut(cols)
877 .enumerate()
878 .for_each(|(row, slice)| do_row(row, slice));
879 #[cfg(target_arch = "wasm32")]
880 for (row, slice) in values.chunks_mut(cols).enumerate() {
881 do_row(row, slice);
882 }
883 } else {
884 for (row, slice) in values.chunks_mut(cols).enumerate() {
885 do_row(row, slice);
886 }
887 }
888 }
889 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
890 Kernel::Avx512 => {
891 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
892 let s = combos[row].short_period.unwrap();
893 let l = combos[row].long_period.unwrap();
894 apo_row_avx512(data, first, s, l, out_row)
895 };
896 if parallel {
897 #[cfg(not(target_arch = "wasm32"))]
898 values
899 .par_chunks_mut(cols)
900 .enumerate()
901 .for_each(|(row, slice)| do_row(row, slice));
902 #[cfg(target_arch = "wasm32")]
903 for (row, slice) in values.chunks_mut(cols).enumerate() {
904 do_row(row, slice);
905 }
906 } else {
907 for (row, slice) in values.chunks_mut(cols).enumerate() {
908 do_row(row, slice);
909 }
910 }
911 }
912 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
913 Kernel::Avx2Batch => {
914 const LANES: usize = 4;
915 let blocks = (rows + LANES - 1) / LANES;
916 let do_block = |b: usize, blk: &mut [f64]| unsafe {
917 let start_row = b * LANES;
918 let end_row = usize::min(start_row + LANES, rows);
919 apo_batch_rows_avx2(data, first, cols, &combos[start_row..end_row], blk);
920 };
921 if parallel {
922 #[cfg(not(target_arch = "wasm32"))]
923 values
924 .par_chunks_mut(cols * LANES)
925 .enumerate()
926 .for_each(|(b, blk)| do_block(b, blk));
927 #[cfg(target_arch = "wasm32")]
928 for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
929 do_block(b, blk);
930 }
931 } else {
932 for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
933 do_block(b, blk);
934 }
935 }
936 }
937 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
938 Kernel::Avx512Batch => {
939 const LANES: usize = 8;
940 let blocks = (rows + LANES - 1) / LANES;
941 let do_block = |b: usize, blk: &mut [f64]| unsafe {
942 let start_row = b * LANES;
943 let end_row = usize::min(start_row + LANES, rows);
944 apo_batch_rows_avx512(data, first, cols, &combos[start_row..end_row], blk);
945 };
946 if parallel {
947 #[cfg(not(target_arch = "wasm32"))]
948 values
949 .par_chunks_mut(cols * LANES)
950 .enumerate()
951 .for_each(|(b, blk)| do_block(b, blk));
952 #[cfg(target_arch = "wasm32")]
953 for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
954 do_block(b, blk);
955 }
956 } else {
957 for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
958 do_block(b, blk);
959 }
960 }
961 }
962 _ => unreachable!(),
963 }
964
965 let values = unsafe {
966 Vec::from_raw_parts(
967 buf_guard.as_mut_ptr() as *mut f64,
968 buf_guard.len(),
969 buf_guard.capacity(),
970 )
971 };
972
973 Ok(ApoBatchOutput {
974 values,
975 combos,
976 rows,
977 cols,
978 })
979}
980
981#[inline(always)]
982fn apo_batch_inner_into(
983 data: &[f64],
984 sweep: &ApoBatchRange,
985 kern: Kernel,
986 parallel: bool,
987 out: &mut [f64],
988) -> Result<Vec<ApoParams>, ApoError> {
989 let combos = expand_grid(sweep)?;
990 if combos.is_empty() {
991 return Err(ApoError::InvalidRange {
992 start: sweep.short.0,
993 end: sweep.short.1,
994 step: sweep.short.2,
995 });
996 }
997
998 let first = data
999 .iter()
1000 .position(|x| !x.is_nan())
1001 .ok_or(ApoError::AllValuesNaN)?;
1002 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1003 if data.len() - first < max_long {
1004 return Err(ApoError::NotEnoughValidData {
1005 needed: max_long,
1006 valid: data.len() - first,
1007 });
1008 }
1009
1010 let rows = combos.len();
1011 let cols = data.len();
1012 let expected = rows.checked_mul(cols).ok_or(ApoError::InvalidRange {
1013 start: rows,
1014 end: cols,
1015 step: 0,
1016 })?;
1017 if out.len() != expected {
1018 return Err(ApoError::OutputLengthMismatch {
1019 expected,
1020 got: out.len(),
1021 });
1022 }
1023
1024 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1025 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1026 };
1027 let warm = vec![first; rows];
1028 init_matrix_prefixes(out_mu, cols, &warm);
1029
1030 match kern {
1031 Kernel::Scalar | Kernel::ScalarBatch => {
1032 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1033 let s = combos[row].short_period.unwrap();
1034 let l = combos[row].long_period.unwrap();
1035 let dst: &mut [f64] =
1036 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1037 apo_row_scalar(data, first, s, l, dst)
1038 };
1039 if parallel {
1040 #[cfg(not(target_arch = "wasm32"))]
1041 out_mu
1042 .par_chunks_mut(cols)
1043 .enumerate()
1044 .for_each(|(r, s)| do_row(r, s));
1045 #[cfg(target_arch = "wasm32")]
1046 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1047 do_row(r, s);
1048 }
1049 } else {
1050 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1051 do_row(r, s);
1052 }
1053 }
1054 }
1055 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1056 Kernel::Avx2 => {
1057 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1058 let s = combos[row].short_period.unwrap();
1059 let l = combos[row].long_period.unwrap();
1060 let dst: &mut [f64] =
1061 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1062 apo_row_avx2(data, first, s, l, dst)
1063 };
1064 if parallel {
1065 #[cfg(not(target_arch = "wasm32"))]
1066 out_mu
1067 .par_chunks_mut(cols)
1068 .enumerate()
1069 .for_each(|(r, s)| do_row(r, s));
1070 #[cfg(target_arch = "wasm32")]
1071 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1072 do_row(r, s);
1073 }
1074 } else {
1075 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1076 do_row(r, s);
1077 }
1078 }
1079 }
1080 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1081 Kernel::Avx512 => {
1082 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1083 let s = combos[row].short_period.unwrap();
1084 let l = combos[row].long_period.unwrap();
1085 let dst: &mut [f64] =
1086 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1087 apo_row_avx512(data, first, s, l, dst)
1088 };
1089 if parallel {
1090 #[cfg(not(target_arch = "wasm32"))]
1091 out_mu
1092 .par_chunks_mut(cols)
1093 .enumerate()
1094 .for_each(|(r, s)| do_row(r, s));
1095 #[cfg(target_arch = "wasm32")]
1096 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1097 do_row(r, s);
1098 }
1099 } else {
1100 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1101 do_row(r, s);
1102 }
1103 }
1104 }
1105 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1106 Kernel::Avx2Batch => {
1107 const LANES: usize = 4;
1108 let do_block = |b: usize, blk_mu: &mut [MaybeUninit<f64>]| unsafe {
1109 let start_row = b * LANES;
1110 let end_row = usize::min(start_row + LANES, rows);
1111 let blk: &mut [f64] =
1112 core::slice::from_raw_parts_mut(blk_mu.as_mut_ptr() as *mut f64, blk_mu.len());
1113 apo_batch_rows_avx2(data, first, cols, &combos[start_row..end_row], blk);
1114 };
1115 if parallel {
1116 #[cfg(not(target_arch = "wasm32"))]
1117 out_mu
1118 .par_chunks_mut(cols * LANES)
1119 .enumerate()
1120 .for_each(|(b, blk)| do_block(b, blk));
1121 #[cfg(target_arch = "wasm32")]
1122 for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1123 do_block(b, blk);
1124 }
1125 } else {
1126 for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1127 do_block(b, blk);
1128 }
1129 }
1130 }
1131 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1132 Kernel::Avx512Batch => {
1133 const LANES: usize = 8;
1134 let do_block = |b: usize, blk_mu: &mut [MaybeUninit<f64>]| unsafe {
1135 let start_row = b * LANES;
1136 let end_row = usize::min(start_row + LANES, rows);
1137 let blk: &mut [f64] =
1138 core::slice::from_raw_parts_mut(blk_mu.as_mut_ptr() as *mut f64, blk_mu.len());
1139 apo_batch_rows_avx512(data, first, cols, &combos[start_row..end_row], blk);
1140 };
1141 if parallel {
1142 #[cfg(not(target_arch = "wasm32"))]
1143 out_mu
1144 .par_chunks_mut(cols * LANES)
1145 .enumerate()
1146 .for_each(|(b, blk)| do_block(b, blk));
1147 #[cfg(target_arch = "wasm32")]
1148 for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1149 do_block(b, blk);
1150 }
1151 } else {
1152 for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1153 do_block(b, blk);
1154 }
1155 }
1156 }
1157 _ => unreachable!(),
1158 }
1159
1160 Ok(combos)
1161}
1162
1163#[inline(always)]
1164pub unsafe fn apo_row_scalar(
1165 data: &[f64],
1166 first: usize,
1167 short: usize,
1168 long: usize,
1169 out: &mut [f64],
1170) {
1171 apo_scalar(data, short, long, first, out)
1172}
/// AVX2 APO kernel for a single output row.
///
/// Reorders the row-kernel arguments `(data, first, short, long, out)` into the
/// `apo_avx2(data, short, long, first, out)` order and delegates.
///
/// # Safety
/// The CPU must support AVX2, and the caller must uphold the contract of
/// `apo_avx2` (not visible here — presumably `first < data.len()` and
/// `out.len() >= data.len()`; confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn apo_row_avx2(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    let (short_period, long_period) = (short, long);
    apo_avx2(data, short_period, long_period, first, out)
}
/// AVX-512 APO kernel for a single output row.
///
/// Chooses between two AVX-512 implementations by the long period: periods up
/// to and including 32 use `apo_row_avx512_short`, longer ones use
/// `apo_row_avx512_long`.
///
/// # Safety
/// The CPU must support AVX-512F, and the caller must uphold the contract of
/// the selected `apo_avx512_*` kernel (not visible here — presumably
/// `first < data.len()` and `out.len() >= data.len()`; confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    match long {
        0..=32 => apo_row_avx512_short(data, first, short, long, out),
        _ => apo_row_avx512_long(data, first, short, long, out),
    }
}
/// AVX-512 row kernel variant selected by `apo_row_avx512` when `long <= 32`.
///
/// Reorders the row-kernel arguments and delegates to `apo_avx512_short`.
///
/// # Safety
/// The CPU must support AVX-512F, and the caller must uphold the contract of
/// `apo_avx512_short` (not visible here; confirm its bounds requirements).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512_short(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
    let (short_period, long_period) = (short, long);
    apo_avx512_short(data, short_period, long_period, first, out)
}
/// AVX-512 row kernel variant selected by `apo_row_avx512` when `long > 32`.
///
/// Reorders the row-kernel arguments and delegates to `apo_avx512_long`.
///
/// # Safety
/// The CPU must support AVX-512F, and the caller must uphold the contract of
/// `apo_avx512_long` (not visible here; confirm its bounds requirements).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512_long(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
    let (short_period, long_period) = (short, long);
    apo_avx512_long(data, short_period, long_period, first, out)
}
1219
/// AVX2 batch kernel: computes APO for up to four parameter rows at once.
///
/// Each 256-bit lane carries one `(short, long)` combination. The price at
/// each index is broadcast to all lanes and both EMAs advance with
/// `ema = alpha * price + (1 - alpha) * ema`, where `alpha = 2 / (period + 1)`,
/// seeded with `data[first]`. Row `j` of `out_block` (stride `cols`) receives
/// `short_ema - long_ema` for indices `first..cols`, with exactly `0.0` at
/// `first`. Indices before `first` are not written.
///
/// # Panics
/// If `combos_block.len() > 4` (the per-lane parameter arrays are fixed-size).
///
/// # Safety
/// The CPU must support AVX2, `first < cols <= data.len()`, and
/// `out_block.len() >= combos_block.len() * cols`; all data/output accesses
/// are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn apo_batch_rows_avx2(
    data: &[f64],
    first: usize,
    cols: usize,
    combos_block: &[ApoParams],
    out_block: &mut [f64],
) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;

    let rows = combos_block.len();
    debug_assert!(rows <= LANES);
    debug_assert!(first < cols && cols <= data.len());
    debug_assert!(out_block.len() >= rows * cols);

    // Unused lanes keep alpha = 0 and (1 - alpha) = 1, so they just hold the
    // seed value; their results are never stored.
    let mut alpha_s = [0.0f64; LANES];
    let mut alpha_l = [0.0f64; LANES];
    let mut decay_s = [1.0f64; LANES];
    let mut decay_l = [1.0f64; LANES];
    for (j, p) in combos_block.iter().enumerate() {
        // Fallbacks match `ApoParams::default()` (10 / 20).
        let s = p.short_period.unwrap_or(10);
        let g = p.long_period.unwrap_or(20);
        alpha_s[j] = 2.0 / (s as f64 + 1.0);
        alpha_l[j] = 2.0 / (g as f64 + 1.0);
        decay_s[j] = 1.0 - alpha_s[j];
        decay_l[j] = 1.0 - alpha_l[j];
    }
    // Lane k of each vector holds element k of the corresponding array.
    let a_s = _mm256_loadu_pd(alpha_s.as_ptr());
    let a_l = _mm256_loadu_pd(alpha_l.as_ptr());
    let o_s = _mm256_loadu_pd(decay_s.as_ptr());
    let o_l = _mm256_loadu_pd(decay_l.as_ptr());

    let x0 = *data.get_unchecked(first);
    let mut se = _mm256_set1_pd(x0);
    let mut le = _mm256_set1_pd(x0);

    // Both EMAs start at data[first], so the first output is exactly zero.
    for j in 0..rows {
        *out_block.get_unchecked_mut(j * cols + first) = 0.0;
    }

    // Scratch buffer for scattering lanes back to their rows; fully
    // overwritten on every iteration.
    let mut lane_buf = [0.0f64; LANES];
    for i in (first + 1)..cols {
        let p = _mm256_set1_pd(*data.get_unchecked(i));
        se = _mm256_add_pd(_mm256_mul_pd(a_s, p), _mm256_mul_pd(o_s, se));
        le = _mm256_add_pd(_mm256_mul_pd(a_l, p), _mm256_mul_pd(o_l, le));

        _mm256_storeu_pd(lane_buf.as_mut_ptr(), _mm256_sub_pd(se, le));
        for j in 0..rows {
            *out_block.get_unchecked_mut(j * cols + i) = lane_buf[j];
        }
    }
}
1278
/// AVX-512 batch kernel: computes APO for up to eight parameter rows at once.
///
/// Each 512-bit lane carries one `(short, long)` combination. The price at
/// each index is broadcast to all lanes and both EMAs advance with
/// `ema = alpha * price + (1 - alpha) * ema`, where `alpha = 2 / (period + 1)`,
/// seeded with `data[first]`. Row `j` of `out_block` (stride `cols`) receives
/// `short_ema - long_ema` for indices `first..cols`, with exactly `0.0` at
/// `first`. Indices before `first` are not written.
///
/// # Panics
/// If `combos_block.len() > 8` (the per-lane parameter arrays are fixed-size).
///
/// # Safety
/// The CPU must support AVX-512F, `first < cols <= data.len()`, and
/// `out_block.len() >= combos_block.len() * cols`; all data/output accesses
/// are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn apo_batch_rows_avx512(
    data: &[f64],
    first: usize,
    cols: usize,
    combos_block: &[ApoParams],
    out_block: &mut [f64],
) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let rows = combos_block.len();
    debug_assert!(rows <= LANES);
    debug_assert!(first < cols && cols <= data.len());
    debug_assert!(out_block.len() >= rows * cols);

    // Unused lanes keep alpha = 0 and (1 - alpha) = 1, so they just hold the
    // seed value; their results are never stored.
    let mut alpha_s = [0.0f64; LANES];
    let mut alpha_l = [0.0f64; LANES];
    let mut decay_s = [1.0f64; LANES];
    let mut decay_l = [1.0f64; LANES];
    for (j, p) in combos_block.iter().enumerate() {
        // Fallbacks match `ApoParams::default()` (10 / 20).
        let s = p.short_period.unwrap_or(10);
        let g = p.long_period.unwrap_or(20);
        alpha_s[j] = 2.0 / (s as f64 + 1.0);
        alpha_l[j] = 2.0 / (g as f64 + 1.0);
        decay_s[j] = 1.0 - alpha_s[j];
        decay_l[j] = 1.0 - alpha_l[j];
    }
    // Lane k of each vector holds element k of the corresponding array.
    let a_s = _mm512_loadu_pd(alpha_s.as_ptr());
    let a_l = _mm512_loadu_pd(alpha_l.as_ptr());
    let o_s = _mm512_loadu_pd(decay_s.as_ptr());
    let o_l = _mm512_loadu_pd(decay_l.as_ptr());

    let x0 = *data.get_unchecked(first);
    let mut se = _mm512_set1_pd(x0);
    let mut le = _mm512_set1_pd(x0);

    // Both EMAs start at data[first], so the first output is exactly zero.
    for j in 0..rows {
        *out_block.get_unchecked_mut(j * cols + first) = 0.0;
    }

    // Scratch buffer for scattering lanes back to their rows; fully
    // overwritten on every iteration.
    let mut lane_buf = [0.0f64; LANES];
    for i in (first + 1)..cols {
        let p = _mm512_set1_pd(*data.get_unchecked(i));
        se = _mm512_add_pd(_mm512_mul_pd(a_s, p), _mm512_mul_pd(o_s, se));
        le = _mm512_add_pd(_mm512_mul_pd(a_l, p), _mm512_mul_pd(o_l, le));

        _mm512_storeu_pd(lane_buf.as_mut_ptr(), _mm512_sub_pd(se, le));
        for j in 0..rows {
            *out_block.get_unchecked_mut(j * cols + i) = lane_buf[j];
        }
    }
}
1344
#[cfg(test)]
mod tests {
    // Unit, property and batch tests for the APO indicator.
    //
    // Every generated `#[test]` wrapper returns the helper's `Result`, so an
    // `Err` (missing fixture, indicator error, or a failed proptest run) fails
    // the test instead of being silently discarded.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// `apo_into` must reproduce the allocating `apo` API value-for-value,
    /// including the NaN prefix produced by leading NaN inputs.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_apo_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let mut data: Vec<f64> = Vec::with_capacity(256);
        for _ in 0..5 {
            data.push(f64::NAN);
        }
        for i in 0..251 {
            let x = i as f64;
            data.push(100.0 + 0.1 * x + (x * 0.05).sin());
        }

        let input = ApoInput::from_slice(&data, ApoParams::default());

        let baseline = apo(&input)?.values;

        let mut out = vec![0.0; data.len()];
        apo_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for (i, (a, b)) in baseline
            .iter()
            .copied()
            .zip(out.iter().copied())
            .enumerate()
        {
            assert!(
                eq_or_both_nan(a, b),
                "mismatch at index {}: api={} into={}",
                i,
                a,
                b
            );
        }
        Ok(())
    }

    /// Params with both periods unset must fall back to defaults and produce
    /// one output per input.
    fn check_apo_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = ApoParams {
            short_period: None,
            long_period: None,
        };
        let input = ApoInput::from_candles(&candles, "close", default_params);
        let output = apo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Last five default-parameter values on the candle fixture must match
    /// reference values within 0.1.
    fn check_apo_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = ApoInput::with_default_candles(&candles);
        let result = apo_with_kernel(&input, kernel)?;
        let expected_last_five = [
            -429.80100015922653,
            -401.64149983850075,
            -386.13569657357584,
            -357.92775222467753,
            -374.13870680232503,
        ];
        let start_index = result.values.len().saturating_sub(5);
        let result_last_five = &result.values[start_index..];
        for (i, &value) in result_last_five.iter().enumerate() {
            assert!(
                (value - expected_last_five[i]).abs() < 1e-1,
                "[{}] APO value mismatch at index {}: expected {}, got {}",
                test_name,
                i,
                expected_last_five[i],
                value
            );
        }
        Ok(())
    }

    /// `with_default_candles` must select the "close" source.
    fn check_apo_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = ApoInput::with_default_candles(&candles);
        match input.data {
            ApoData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected ApoData::Candles"),
        }
        let output = apo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero short period must be rejected.
    fn check_apo_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = ApoParams {
            short_period: Some(0),
            long_period: Some(20),
        };
        let input = ApoInput::from_slice(&input_data, params);
        let res = apo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] APO should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// Empty input must be rejected.
    fn check_apo_empty_input(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty_data: Vec<f64> = vec![];
        let params = ApoParams::default();
        let input = ApoInput::from_slice(&empty_data, params);
        let result = apo_with_kernel(&input, kernel);
        assert!(result.is_err());
        Ok(())
    }

    /// `ApoStream` fed one close at a time must agree with the batch result
    /// wherever both are non-NaN.
    fn check_apo_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = ApoParams::default();

        let input = ApoInput::from_candles(&candles, "close", params.clone());
        let batch_result = apo_with_kernel(&input, kernel)?;

        let mut stream = ApoStream::try_new(params)?;
        let mut streaming_results = vec![];

        for &close in &candles.close {
            if let Some(val) = stream.update(close) {
                streaming_results.push(val);
            } else {
                streaming_results.push(f64::NAN);
            }
        }

        assert_eq!(batch_result.values.len(), streaming_results.len());
        let first_valid = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);

        for i in first_valid..batch_result.values.len() {
            if !batch_result.values[i].is_nan() && !streaming_results[i].is_nan() {
                let diff = (batch_result.values[i] - streaming_results[i]).abs();
                assert!(
                    diff < 1e-10,
                    "Streaming mismatch at index {}: batch={}, stream={}",
                    i,
                    batch_result.values[i],
                    streaming_results[i]
                );
            }
        }
        Ok(())
    }

    /// `short_period >= long_period` must be rejected.
    fn check_apo_period_invalid(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = ApoParams {
            short_period: Some(20),
            long_period: Some(10),
        };
        let input = ApoInput::from_slice(&data_small, params);
        let res = apo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] APO should fail with invalid period",
            test_name
        );
        Ok(())
    }

    /// A single data point with periods 9/10 must be rejected.
    fn check_apo_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = ApoParams {
            short_period: Some(9),
            long_period: Some(10),
        };
        let input = ApoInput::from_slice(&single_point, params);
        let res = apo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] APO should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeding APO output back in as input must succeed and keep the length.
    fn check_apo_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = ApoParams {
            short_period: Some(10),
            long_period: Some(20),
        };
        let first_input = ApoInput::from_candles(&candles, "close", first_params);
        let first_result = apo_with_kernel(&first_input, kernel)?;
        let second_params = ApoParams {
            short_period: Some(5),
            long_period: Some(15),
        };
        let second_input = ApoInput::from_slice(&first_result.values, second_params);
        let second_result = apo_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// With 10/20 periods on the candle fixture, no output past index 30 may be NaN.
    fn check_apo_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = ApoInput::from_candles(
            &candles,
            "close",
            ApoParams {
                short_period: Some(10),
                long_period: Some(20),
            },
        );
        let res = apo_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 30 {
            for (i, &val) in res.values[30..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    30 + i
                );
            }
        }
        Ok(())
    }

    /// Debug builds only: no output may carry one of the allocation-helper
    /// poison bit patterns.
    #[cfg(debug_assertions)]
    fn check_apo_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = ApoInput::from_candles(&candles, "close", ApoParams::default());
        let output = apo_with_kernel(&input, kernel)?;

        for (i, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
                    test_name, val, bits, i
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
                    test_name, val, bits, i
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
                    test_name, val, bits, i
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_apo_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Property test over random, constant and linearly trending series.
    ///
    /// Checks output length, a zero first value, finiteness, the range bound,
    /// data-shape-specific behaviour, the first three values against a
    /// hand-rolled EMA, agreement with the scalar kernel, and agreement with
    /// `ApoStream`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_apo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let random_data_strat = (3usize..=20, 10usize..=50)
            .prop_filter("short < long", |(s, l)| s < l)
            .prop_flat_map(|(short_period, long_period)| {
                let len = long_period * 2..400;
                (
                    prop::collection::vec(
                        (10f64..10000f64).prop_filter("finite", |x| x.is_finite()),
                        len,
                    ),
                    Just(short_period),
                    Just(long_period),
                    Just("random"),
                )
            });

        let constant_data_strat = (3usize..=20, 10usize..=50)
            .prop_filter("short < long", |(s, l)| s < l)
            .prop_flat_map(|(short_period, long_period)| {
                let len = long_period * 2..200;
                (
                    prop::collection::vec(Just(100.0f64), len),
                    Just(short_period),
                    Just(long_period),
                    Just("constant"),
                )
            });

        // Trending series length is drawn independently (50..150) of the periods.
        let trending_data_strat = (3usize..=20, 10usize..=50)
            .prop_filter("short < long", |(s, l)| s < l)
            .prop_flat_map(|(short_period, long_period)| {
                (
                    (50..150usize).prop_flat_map(move |size| {
                        (0.1f64..5.0).prop_map(move |slope| {
                            (0..size)
                                .map(|i| 100.0 + slope * i as f64)
                                .collect::<Vec<f64>>()
                        })
                    }),
                    Just(short_period),
                    Just(long_period),
                    Just("trending"),
                )
            });

        let strat = prop_oneof![random_data_strat, constant_data_strat, trending_data_strat,];

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, short_period, long_period, data_type)| {
                let params = ApoParams {
                    short_period: Some(short_period),
                    long_period: Some(long_period),
                };
                let input = ApoInput::from_slice(&data, params.clone());

                let result = apo_with_kernel(&input, kernel);
                prop_assert!(result.is_ok(), "APO computation failed: {:?}", result);

                let ApoOutput { values: out } = result.unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                if first_valid < data.len() {
                    prop_assert!(
                        out[first_valid].abs() < 1e-10,
                        "First APO value should be 0, got {} at index {}",
                        out[first_valid],
                        first_valid
                    );
                }

                for i in first_valid..out.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "APO output at index {} should be finite, got {}",
                        i,
                        out[i]
                    );
                }

                let data_min = data
                    .iter()
                    .filter(|x| x.is_finite())
                    .fold(f64::INFINITY, |a, &b| a.min(b));
                let data_max = data
                    .iter()
                    .filter(|x| x.is_finite())
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let data_range = data_max - data_min;

                // Both EMAs are convex combinations of the inputs (weights alpha
                // and 1 - alpha, seeded with data[first_valid]), so each lies in
                // [data_min, data_max] and |short - long| <= data_range. A small
                // relative tolerance absorbs floating-point rounding. (A tighter
                // fraction of the range does not hold: e.g. with short = 3 the
                // second output alone is ~0.46 * (x1 - x0).)
                let apo_bound = data_range + 1e-9 * data_max.abs().max(1.0);

                for i in first_valid..out.len() {
                    prop_assert!(
                        out[i].abs() <= apo_bound,
                        "APO value at index {} exceeds expected bound: {} > {}",
                        i,
                        out[i].abs(),
                        apo_bound
                    );
                }

                match data_type {
                    "constant" => {
                        for i in first_valid..out.len() {
                            prop_assert!(
                                out[i].abs() < 1e-9,
                                "APO should be ~0 for constant data, got {} at index {}",
                                out[i],
                                i
                            );
                        }
                    }
                    "trending" => {
                        if data.len() > long_period * 2 {
                            let check_start = first_valid + long_period;
                            let check_end = out.len();
                            if check_start < check_end {
                                let is_increasing = data[first_valid] < data[data.len() - 1];

                                let positive_count = out[check_start..check_end]
                                    .iter()
                                    .filter(|&&v| v > 0.0)
                                    .count();
                                let total_count = check_end - check_start;

                                if is_increasing {
                                    prop_assert!(
                                        positive_count > total_count / 2,
                                        "APO should be mostly positive for uptrend, got {} positive out of {}",
                                        positive_count,
                                        total_count
                                    );
                                } else {
                                    prop_assert!(
                                        positive_count < total_count / 2,
                                        "APO should be mostly negative for downtrend, got {} positive out of {}",
                                        positive_count,
                                        total_count
                                    );
                                }
                            }
                        }
                    }
                    _ => {}
                }

                // Reference EMA recursion for the first three outputs.
                if data.len() >= 3 && first_valid + 2 < data.len() {
                    let alpha_short = 2.0 / (short_period as f64 + 1.0);
                    let alpha_long = 2.0 / (long_period as f64 + 1.0);

                    let mut short_ema = data[first_valid];
                    let mut long_ema = data[first_valid];
                    let expected_first = 0.0;
                    prop_assert!(
                        (out[first_valid] - expected_first).abs() < 1e-9,
                        "First value mismatch: expected {}, got {}",
                        expected_first,
                        out[first_valid]
                    );

                    if first_valid + 1 < data.len() {
                        let price = data[first_valid + 1];
                        short_ema = alpha_short * price + (1.0 - alpha_short) * short_ema;
                        long_ema = alpha_long * price + (1.0 - alpha_long) * long_ema;
                        let expected_second = short_ema - long_ema;
                        prop_assert!(
                            (out[first_valid + 1] - expected_second).abs() < 1e-9,
                            "Second value mismatch: expected {}, got {}",
                            expected_second,
                            out[first_valid + 1]
                        );
                    }

                    if first_valid + 2 < data.len() {
                        let price = data[first_valid + 2];
                        short_ema = alpha_short * price + (1.0 - alpha_short) * short_ema;
                        long_ema = alpha_long * price + (1.0 - alpha_long) * long_ema;
                        let expected_third = short_ema - long_ema;
                        prop_assert!(
                            (out[first_valid + 2] - expected_third).abs() < 1e-9,
                            "Third value mismatch: expected {}, got {}",
                            expected_third,
                            out[first_valid + 2]
                        );
                    }
                }

                let ref_output = apo_with_kernel(&input, Kernel::Scalar);
                prop_assert!(ref_output.is_ok(), "Reference scalar computation failed");
                let ApoOutput { values: ref_out } = ref_output.unwrap();

                for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
                    if !val.is_finite() || !ref_val.is_finite() {
                        prop_assert_eq!(
                            val.is_nan(),
                            ref_val.is_nan(),
                            "NaN mismatch at index {}: kernel={}, scalar={}",
                            i,
                            val,
                            ref_val
                        );
                    } else {
                        let diff = (val - ref_val).abs();
                        let ulp_diff = val.to_bits().abs_diff(ref_val.to_bits());
                        prop_assert!(
                            diff <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (diff: {}, ULP: {})",
                            i,
                            val,
                            ref_val,
                            diff,
                            ulp_diff
                        );
                    }
                }

                prop_assert!(
                    short_period < long_period,
                    "Short period must be less than long period"
                );

                let mut stream = ApoStream::try_new(params).unwrap();
                let mut stream_values = Vec::new();
                for &price in &data {
                    if let Some(val) = stream.update(price) {
                        stream_values.push(val);
                    } else {
                        stream_values.push(f64::NAN);
                    }
                }

                for i in first_valid..out.len() {
                    if out[i].is_finite() && stream_values[i].is_finite() {
                        let diff = (out[i] - stream_values[i]).abs();
                        prop_assert!(
                            diff < 1e-10,
                            "Streaming mismatch at index {}: batch={}, stream={}, diff={}",
                            i,
                            out[i],
                            stream_values[i],
                            diff
                        );
                    }
                }

                Ok(())
            })
            .map_err(|e| e.into())
    }

    #[cfg(not(feature = "proptest"))]
    fn check_apo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }

    // Generates one `#[test]` per (helper, kernel) pair. The `cfg` attribute is
    // placed on each generated fn: an attribute written before a `$(...)*`
    // repetition would only attach to the first item of the expansion.
    macro_rules! generate_all_apo_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
                    #[test]
                    fn [<$test_fn _simd128_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar)
                    }
                )*
            }
        }
    }

    generate_all_apo_tests!(
        check_apo_partial_params,
        check_apo_accuracy,
        check_apo_default_candles,
        check_apo_zero_period,
        check_apo_empty_input,
        check_apo_streaming,
        check_apo_period_invalid,
        check_apo_very_small_dataset,
        check_apo_reinput,
        check_apo_nan_handling,
        check_apo_no_poison,
        check_apo_property
    );

    /// The default builder sweep must contain the default-parameter row with
    /// one value per candle.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = ApoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = ApoParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds only: no batch output cell may carry an allocation-helper
    /// poison bit pattern, across several sweep configurations.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        // (short_start, short_end, short_step, long_start, long_end, long_step)
        let test_configs = vec![
            (2, 10, 2, 15, 30, 5),
            (5, 25, 5, 30, 50, 10),
            (10, 20, 5, 25, 45, 10),
            (12, 12, 0, 26, 26, 0),
            (3, 9, 3, 10, 20, 5),
        ];

        for (cfg_idx, &(s_start, s_end, s_step, l_start, l_end, l_step)) in
            test_configs.iter().enumerate()
        {
            let output = ApoBatchBuilder::new()
                .kernel(kernel)
                .short_range(s_start, s_end, s_step)
                .long_range(l_start, l_end, l_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: short={}, long={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.short_period.unwrap_or(12),
                        combo.long_period.unwrap_or(26)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: short={}, long={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.short_period.unwrap_or(12),
                        combo.long_period.unwrap_or(26)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: short={}, long={}",
                        test,
                        cfg_idx,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.short_period.unwrap_or(12),
                        combo.long_period.unwrap_or(26)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Generates one `#[test]` per batch kernel for a batch helper; each test
    // returns the helper's `Result` so errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);

    /// On wasm32 with simd128, the dispatched scalar kernel must match a direct
    /// `apo_scalar` call.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    #[test]
    fn test_apo_simd128_correctness() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let short_period = 3;
        let long_period = 5;
        let params = ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        };
        let input = ApoInput::from_slice(&data, params);

        let scalar_output = apo_with_kernel(&input, Kernel::Scalar).unwrap();

        let mut pure_scalar_output = vec![f64::NAN; data.len()];
        let first = 0;
        unsafe {
            apo_scalar(
                &data,
                short_period,
                long_period,
                first,
                &mut pure_scalar_output,
            );
        }

        assert_eq!(scalar_output.values.len(), pure_scalar_output.len());
        for (i, (simd_val, scalar_val)) in scalar_output
            .values
            .iter()
            .zip(pure_scalar_output.iter())
            .enumerate()
        {
            if scalar_val.is_nan() {
                assert!(simd_val.is_nan(), "SIMD128 NaN mismatch at index {}", i);
            } else {
                assert!(
                    (scalar_val - simd_val).abs() < 1e-10,
                    "SIMD128 mismatch at index {}: scalar={}, simd128={}",
                    i,
                    scalar_val,
                    simd_val
                );
            }
        }
    }
}
2174
/// Absolute Price Oscillator: short-period EMA minus long-period EMA of `data`.
///
/// Returns a new float64 array the same length as `data`. Raises ValueError if
/// the parameters or input are rejected.
#[cfg(feature = "python")]
#[pyfunction(name = "apo")]
#[pyo3(signature = (data, short_period=10, long_period=20, kernel=None))]
pub fn apo_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Order matters for error reporting: the slice conversion is checked before
    // the kernel name.
    let values = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = ApoInput::from_slice(
        values,
        ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        },
    );

    // The numeric work runs with the GIL released.
    let computed = py.allow_threads(|| apo_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(series) => Ok(series.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2202
/// Incremental APO calculator (exposed to Python as `ApoStream`).
// Thin wrapper that owns a Rust `ApoStream`; all state lives in `stream`.
#[cfg(feature = "python")]
#[pyclass(name = "ApoStream")]
pub struct ApoStreamPy {
    stream: ApoStream,
}
2208
#[cfg(feature = "python")]
#[pymethods]
impl ApoStreamPy {
    /// Create a streaming APO calculator with the given short and long periods.
    ///
    /// Raises ValueError if the periods are rejected.
    #[new]
    fn new(short_period: usize, long_period: usize) -> PyResult<Self> {
        ApoStream::try_new(ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feed one value and return the current APO, or None when the stream
    /// yields no value for this step.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
2227
/// Compute APO for every (short, long) combination in the given ranges.
///
/// Each range is a `(start, end, step)` tuple. Returns a dict with `values`
/// (a rows x len(data) float64 matrix, one row per combination),
/// `short_periods` and `long_periods` (one entry per row).
#[cfg(feature = "python")]
#[pyfunction(name = "apo_batch")]
#[pyo3(signature = (data, short_period_range, long_period_range, kernel=None))]
pub fn apo_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let input = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = ApoBatchRange {
        short: short_period_range,
        long: long_period_range,
    };
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if grid.is_empty() {
        return Err(PyValueError::new_err("No valid parameter combinations"));
    }
    let (rows, cols) = (grid.len(), input.len());

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows * cols overflow")),
    };

    // The array starts uninitialized; the warmup prefixes are filled below and
    // the remaining cells by `apo_batch_inner_into`.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    // Every row shares the same warmup length: the index of the first non-NaN input.
    let first = input.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm = vec![first; rows];
    {
        // Short-lived MaybeUninit view of the same buffer; it is not used again
        // once `out_slice` is handed to the compute closure.
        let out_mu: &mut [MaybeUninit<f64>] = unsafe {
            core::slice::from_raw_parts_mut(
                out_slice.as_mut_ptr() as *mut MaybeUninit<f64>,
                out_slice.len(),
            )
        };
        init_matrix_prefixes(out_mu, cols, &warm);
    }

    let grid = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            // Map the batch kernel to the corresponding per-row SIMD kernel.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                _ => Kernel::Scalar,
            };
            apo_batch_inner_into(input, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let shorts: Vec<u64> = grid
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    dict.set_item("short_periods", shorts.into_pyarray(py))?;
    let longs: Vec<u64> = grid
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    dict.set_item("long_periods", longs.into_pyarray(py))?;
    Ok(dict)
}
2307
2308#[cfg(all(feature = "python", feature = "cuda"))]
2309use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2310#[cfg(all(feature = "python", feature = "cuda"))]
2311use cust::context::Context as CudaContext;
2312#[cfg(all(feature = "python", feature = "cuda"))]
2313use std::sync::Arc;
2314
/// Python handle (`DeviceArrayF32Apo`) to an APO result stored in CUDA device memory.
///
/// The buffer is exposed through `__cuda_array_interface__` and the DLPack
/// protocol (`__dlpack__` / `__dlpack_device__`). The class is declared
/// `unsendable`, so PyO3 only allows access from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Apo", unsendable)]
pub struct DeviceArrayF32ApoPy {
    /// The device buffer. It becomes `None` once `__dlpack__` has moved it out.
    pub(crate) inner: Option<crate::cuda::moving_averages::apo_wrapper::DeviceArrayF32>,
    /// Set to 0 by the constructors in this file and not read by the methods here.
    /// NOTE(review): confirm whether a real stream handle is meant to be tracked.
    stream_handle: usize,
    /// CUDA context taken from the device array. It is never read, so it is
    /// presumably held to keep the context alive as long as the buffer.
    _ctx_guard: Arc<CudaContext>,
    /// Device ordinal. It is reported by `__dlpack_device__` and passed to the DLPack export.
    _device_id: u32,
}
2323
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32ApoPy {
    /// Direct construction from Python is not supported and always raises
    /// `TypeError`. Instances come from the CUDA binding functions instead.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// It describes a 2-D `(rows, cols)` little-endian float32 array with
    /// row-major byte strides. `data` is `(device_pointer, false)`, and the
    /// pointer is reported as 0 when the array has no elements.
    ///
    /// Raises `ValueError` after the buffer has been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        let itemsize = std::mem::size_of::<f32>();
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one row advances `cols` elements, one column advances one element.
        d.set_item("strides", (inner.cols * itemsize, itemsize))?;
        let size = inner.rows.saturating_mul(inner.cols);
        let ptr_val: usize = if size == 0 {
            0
        } else {
            inner.buf.as_device_ptr().as_raw() as usize
        };
        // The second tuple element is the interface's read-only flag (false means writable).
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. Device type 2 is
    /// `kDLCUDA` in DLPack's `DLDeviceType` enumeration.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self._device_id as i32))
    }

    /// Exports the buffer as a DLPack capsule via `export_f32_cuda_dlpack_2d`.
    ///
    /// This consumes the wrapped buffer. A second call raises `ValueError`, and
    /// so does any later `__cuda_array_interface__` access.
    ///
    /// * `stream`: accepted for protocol compatibility but not used by this method.
    /// * `max_version`: forwarded unchanged to the export helper.
    /// * `dl_device`: if it extracts as `(int, int)` and differs from this array's
    ///   device, raises `ValueError`, since cross-device copies are not implemented.
    ///   Values that do not extract as `(int, int)` are ignored.
    /// * `copy`: only consulted to pick the error message on a device mismatch.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        // Device on which the buffer was allocated, as (device_type, device_id).
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // NOTE(review): the consumer's stream is ignored; this method performs no
        // stream synchronisation itself.
        let _ = stream;

        // Move the buffer out; afterwards `inner` is `None` and further exports fail.
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
        let crate::cuda::moving_averages::apo_wrapper::DeviceArrayF32 {
            buf, rows, cols, ..
        } = inner;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2402
/// Python binding `apo_cuda_batch_dev(data_f32, short_range=(10,10,0), long_range=(20,20,0), device_id=0)`.
///
/// Runs the APO parameter sweep described by the `(start, end, step)` ranges on
/// the GPU and returns the result as a device-resident `DeviceArrayF32Apo`.
/// The GIL is released while the CUDA work runs.
///
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "apo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, short_range=(10,10,0), long_range=(20,20,0), device_id=0))]
pub fn apo_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    short_range: (usize, usize, usize),
    long_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32ApoPy> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaApo;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host = data_f32.as_slice()?;
    let sweep = ApoBatchRange {
        short: short_range,
        long: long_range,
    };

    let device_array = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaApo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = engine
            .apo_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(out)
    })?;

    // Context and device id are read before `device_array` is moved into the handle.
    Ok(DeviceArrayF32ApoPy {
        _ctx_guard: device_array.ctx(),
        _device_id: device_array.device_id(),
        stream_handle: 0,
        inner: Some(device_array),
    })
}
2437
/// Python binding `apo_cuda_many_series_one_param_dev(data_tm_f32, short_period, long_period, device_id=0)`.
///
/// Computes APO for a single `(short_period, long_period)` pair over many series
/// on the GPU and returns the result as a device-resident `DeviceArrayF32Apo`.
///
/// `data_tm_f32` is a 2-D float32 array read as time-major: `shape[0]` is the
/// number of time steps and `shape[1]` the number of series. It must be
/// C-contiguous so that its flat buffer really is laid out one time step per row.
///
/// Raises `ValueError` in these cases:
/// * CUDA is unavailable,
/// * a period is 0 or `short_period >= long_period`,
/// * the array is not C-contiguous,
/// * the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "apo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, short_period, long_period, device_id=0))]
pub fn apo_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    short_period: usize,
    long_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32ApoPy> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaApo;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if short_period == 0 || long_period == 0 || short_period >= long_period {
        return Err(PyValueError::new_err("invalid short/long period"));
    }
    // `as_slice` also succeeds for Fortran-ordered arrays. Their flat buffer is
    // series-major and would be silently misread as time-major below.
    if !data_tm_f32.is_c_contiguous() {
        return Err(PyValueError::new_err(
            "data_tm_f32 must be a C-contiguous (time-major) array",
        ));
    }
    let flat = data_tm_f32.as_slice()?;
    // Time-major layout: rows are time steps, cols are series.
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = ApoParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
    };
    let inner = py.allow_threads(|| {
        let cuda = CudaApo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.apo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let ctx = inner.ctx();
    let dev_id = inner.device_id();
    Ok(DeviceArrayF32ApoPy {
        inner: Some(inner),
        stream_handle: 0,
        _ctx_guard: ctx,
        _device_id: dev_id,
    })
}
2478
/// Computes APO over `data` with the given periods.
///
/// Returns a new vector the same length as `data`, filled by `apo_into_slice`
/// with automatic kernel selection. Errors are surfaced to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_js(data: &[f64], short_period: usize, long_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ApoInput::from_slice(
        data,
        ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        },
    );
    let mut values = vec![0.0; data.len()];
    match apo_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2495
/// Allocates an uninitialised buffer with room for `len` f64 values in wasm
/// memory and returns a pointer to it.
///
/// Ownership passes to the caller, who must release it with `apo_free(ptr, len)`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2504
/// Releases a buffer previously returned by `apo_alloc(len)`, using the same `len`.
///
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `apo_alloc(len)`, i.e. a leaked `Vec<f64>` with
    // capacity `len`. Rebuilding it with length 0 frees that allocation without
    // claiming the elements are initialised; they may never have been written.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2512
/// Raw-pointer APO for JS callers that manage wasm memory themselves.
///
/// Reads `len` f64 values from `in_ptr` and writes `len` APO values to `out_ptr`.
/// The input and output buffers may be the same buffer or may overlap; in that
/// case the result is computed into a scratch buffer and copied out afterwards.
///
/// Returns an error for null pointers or when the APO computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period: usize,
    long_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to apo_into"));
    }

    // Detect any overlap, not just identical start pointers. JS views into the
    // same wasm memory may be offset from each other, and holding `&[f64]` and
    // `&mut [f64]` over shared bytes would be undefined behaviour.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlapping = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span));

    let params = ApoParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
    };

    // SAFETY: the caller guarantees that `in_ptr` and `out_ptr` each point to `len`
    // valid f64 values. When the buffers overlap, the mutable output view is only
    // created after the computation over the input has finished.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = ApoInput::from_slice(data, params);

        if overlapping {
            let mut tmp = vec![0.0; len];
            apo_into_slice(&mut tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            apo_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2547
/// Batch APO over `data` for every `(short, long)` combination in the given
/// `(start, end, step)` ranges, using the scalar kernel.
///
/// Returns the flattened `values` of `apo_batch_inner`. `apo_batch_metadata_js`
/// with the same ranges lists each combination's periods, presumably in the
/// same row order, since both go through the same grid expansion.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_js(
    data: &[f64],
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let short = (short_period_start, short_period_end, short_period_step);
    let long = (long_period_start, long_period_end, long_period_step);

    match apo_batch_inner(data, &ApoBatchRange { short, long }, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2568
/// Lists the parameter combinations produced by the given ranges as a flat
/// `[short0, long0, short1, long1, ...]` vector of f64 (two entries per combination).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_metadata_js(
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let grid = expand_grid(&ApoBatchRange {
        short: (short_period_start, short_period_end, short_period_step),
        long: (long_period_start, long_period_end, long_period_step),
    })
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(grid
        .into_iter()
        .flat_map(|c| [c.short_period.unwrap() as f64, c.long_period.unwrap() as f64])
        .collect())
}
2594
/// Raw-pointer batch APO for JS callers that manage wasm memory themselves.
///
/// Reads `len` f64 values from `in_ptr` and expands the `(start, end, step)`
/// ranges into parameter combinations. It then writes a `rows x len` matrix to
/// `out_ptr`, one row per combination, and returns `rows`.
///
/// The caller must size the output for `rows * len` values. `rows` can be
/// obtained beforehand as half the length of `apo_batch_metadata_js`'s result.
/// Returns an error for null pointers, invalid ranges, a `rows * len` overflow,
/// or a failing batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_start: usize,
    short_end: usize,
    short_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to apo_batch_into"));
    }
    let sweep = ApoBatchRange {
        short: (short_start, short_end, short_step),
        long: (long_start, long_end, long_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // `usize` is 32 bits on wasm32, so a large grid over a long series can overflow.
    // An unchecked product would wrap and build an undersized output slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    // SAFETY: the caller guarantees `in_ptr` points to `len` readable f64 values and
    // `out_ptr` to `total` writable f64 values, and is responsible for keeping the
    // two buffers disjoint.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        apo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
2627
/// Configuration object accepted by the JS `apo_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ApoBatchConfig {
    /// Short-period sweep as `(start, end, step)`, deserialized from a 3-element array.
    pub short_period_range: (usize, usize, usize),
    /// Long-period sweep as `(start, end, step)`, deserialized from a 3-element array.
    pub long_period_range: (usize, usize, usize),
}
2634
/// Result object returned to JS by the `apo_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ApoBatchJsOutput {
    /// Flattened output values as produced by `apo_batch_inner`
    /// (presumably `rows * cols` entries, one row per combination; confirm).
    pub values: Vec<f64>,
    /// Parameter combination for each output row.
    pub combos: Vec<ApoParams>,
    /// Number of output rows.
    pub rows: usize,
    /// Number of output columns.
    pub cols: usize,
}
2643
/// JS `apo_batch(data, config)`: batch APO driven by an `ApoBatchConfig` object.
///
/// Returns a serialized `ApoBatchJsOutput` (`values`, `combos`, `rows`, `cols`).
/// Fails with `"Invalid config: ..."` when `config` cannot be deserialized, with
/// the computation's error text when the batch fails, and with
/// `"Serialization error: ..."` when the result cannot be converted back to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = apo_batch)]
pub fn apo_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<ApoBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let batch = apo_batch_inner(
        data,
        &ApoBatchRange {
            short: parsed.short_period_range,
            long: parsed.long_period_range,
        },
        detect_best_kernel(),
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&ApoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}