Skip to main content

vector_ta/indicators/
apo.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
24use core::arch::x86_64::*;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::mem::{ManuallyDrop, MaybeUninit};
29use thiserror::Error;
30
/// Where an APO calculation reads its input series from.
#[derive(Debug, Clone)]
pub enum ApoData<'a> {
    /// A candle set plus the name of the field to read; the field is
    /// resolved through `source_type` when the input is borrowed as a slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice, used as-is.
    Slice(&'a [f64]),
}
39
40impl<'a> AsRef<[f64]> for ApoInput<'a> {
41    #[inline(always)]
42    fn as_ref(&self) -> &[f64] {
43        match &self.data {
44            ApoData::Slice(slice) => slice,
45            ApoData::Candles { candles, source } => source_type(candles, source),
46        }
47    }
48}
49
/// Result of a single-series APO calculation.
#[derive(Debug, Clone)]
pub struct ApoOutput {
    /// Output series produced by the selected kernel (short EMA minus long EMA).
    pub values: Vec<f64>,
}
54
/// APO periods. A `None` field is replaced by the default (10 for the short
/// period, 20 for the long period) wherever the parameters are read.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct ApoParams {
    /// Period of the short (fast) EMA.
    pub short_period: Option<usize>,
    /// Period of the long (slow) EMA.
    pub long_period: Option<usize>,
}
impl Default for ApoParams {
    /// Explicit defaults: short = 10, long = 20.
    fn default() -> Self {
        Self {
            short_period: Some(10),
            long_period: Some(20),
        }
    }
}
72
73#[derive(Debug, Clone)]
74pub struct ApoInput<'a> {
75    pub data: ApoData<'a>,
76    pub params: ApoParams,
77}
78impl<'a> ApoInput<'a> {
79    #[inline]
80    pub fn from_candles(c: &'a Candles, s: &'a str, p: ApoParams) -> Self {
81        Self {
82            data: ApoData::Candles {
83                candles: c,
84                source: s,
85            },
86            params: p,
87        }
88    }
89    #[inline]
90    pub fn from_slice(sl: &'a [f64], p: ApoParams) -> Self {
91        Self {
92            data: ApoData::Slice(sl),
93            params: p,
94        }
95    }
96    #[inline]
97    pub fn with_default_candles(c: &'a Candles) -> Self {
98        Self::from_candles(c, "close", ApoParams::default())
99    }
100    #[inline]
101    pub fn get_short_period(&self) -> usize {
102        self.params.short_period.unwrap_or(10)
103    }
104    #[inline]
105    pub fn get_long_period(&self) -> usize {
106        self.params.long_period.unwrap_or(20)
107    }
108}
109
/// Errors reported by the APO single-series, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum ApoError {
    /// The input series has length zero.
    #[error("apo: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input series.
    #[error("apo: All values are NaN.")]
    AllValuesNaN,
    /// At least one of the two periods is zero.
    #[error("apo: Invalid period: short={short}, long={long}")]
    InvalidPeriod { short: usize, long: usize },
    /// The short period is greater than or equal to the long period.
    #[error("apo: short_period not less than long_period: short={short}, long={long}")]
    ShortPeriodNotLessThanLong { short: usize, long: usize },
    /// Fewer elements remain from the first non-NaN value onward than the
    /// long period requires.
    #[error("apo: Not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer does not have the required length.
    #[error("apo: output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep produced no usable (short, long) combination, or the
    /// rows-by-columns size of the batch result overflowed `usize`.
    #[error("apo: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("apo: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
133
134#[derive(Copy, Clone, Debug)]
135pub struct ApoBuilder {
136    short_period: Option<usize>,
137    long_period: Option<usize>,
138    kernel: Kernel,
139}
140impl Default for ApoBuilder {
141    fn default() -> Self {
142        Self {
143            short_period: None,
144            long_period: None,
145            kernel: Kernel::Auto,
146        }
147    }
148}
149impl ApoBuilder {
150    #[inline(always)]
151    pub fn new() -> Self {
152        Self::default()
153    }
154    #[inline(always)]
155    pub fn short_period(mut self, n: usize) -> Self {
156        self.short_period = Some(n);
157        self
158    }
159    #[inline(always)]
160    pub fn long_period(mut self, n: usize) -> Self {
161        self.long_period = Some(n);
162        self
163    }
164    #[inline(always)]
165    pub fn kernel(mut self, k: Kernel) -> Self {
166        self.kernel = k;
167        self
168    }
169    #[inline(always)]
170    pub fn apply(self, c: &Candles) -> Result<ApoOutput, ApoError> {
171        let p = ApoParams {
172            short_period: self.short_period,
173            long_period: self.long_period,
174        };
175        let i = ApoInput::from_candles(c, "close", p);
176        apo_with_kernel(&i, self.kernel)
177    }
178    #[inline(always)]
179    pub fn apply_slice(self, d: &[f64]) -> Result<ApoOutput, ApoError> {
180        let p = ApoParams {
181            short_period: self.short_period,
182            long_period: self.long_period,
183        };
184        let i = ApoInput::from_slice(d, p);
185        apo_with_kernel(&i, self.kernel)
186    }
187    #[inline(always)]
188    pub fn into_stream(self) -> Result<ApoStream, ApoError> {
189        let p = ApoParams {
190            short_period: self.short_period,
191            long_period: self.long_period,
192        };
193        ApoStream::try_new(p)
194    }
195}
196
/// Computes APO for `input` by delegating to `apo_with_kernel` with
/// `Kernel::Auto`; errors are passed through unchanged.
#[inline]
pub fn apo(input: &ApoInput) -> Result<ApoOutput, ApoError> {
    apo_with_kernel(input, Kernel::Auto)
}
201
202#[inline(always)]
203fn apo_prepare<'a>(
204    input: &'a ApoInput,
205    kernel: Kernel,
206) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), ApoError> {
207    let data: &[f64] = input.as_ref();
208    let len = data.len();
209    if len == 0 {
210        return Err(ApoError::EmptyInputData);
211    }
212    let first = data
213        .iter()
214        .position(|x| !x.is_nan())
215        .ok_or(ApoError::AllValuesNaN)?;
216    let short = input.get_short_period();
217    let long = input.get_long_period();
218
219    if short == 0 || long == 0 {
220        return Err(ApoError::InvalidPeriod { short, long });
221    }
222    if short >= long {
223        return Err(ApoError::ShortPeriodNotLessThanLong { short, long });
224    }
225    if (len - first) < long {
226        return Err(ApoError::NotEnoughValidData {
227            needed: long,
228            valid: len - first,
229        });
230    }
231
232    let mut chosen = match kernel {
233        Kernel::Auto => Kernel::Scalar,
234        k => k,
235    };
236
237    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238    if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx2 | Kernel::Avx512) {
239        chosen = Kernel::Scalar;
240    }
241    Ok((data, first, short, long, len, chosen))
242}
243
244#[inline(always)]
245fn apo_compute_into(
246    data: &[f64],
247    first: usize,
248    short: usize,
249    long: usize,
250    kernel: Kernel,
251    out: &mut [f64],
252) {
253    unsafe {
254        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
255        {}
256
257        match kernel {
258            Kernel::Scalar | Kernel::ScalarBatch => apo_scalar(data, short, long, first, out),
259            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
260            Kernel::Avx2 | Kernel::Avx2Batch => apo_avx2(data, short, long, first, out),
261            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
262            Kernel::Avx512 | Kernel::Avx512Batch => apo_avx512(data, short, long, first, out),
263            _ => unreachable!(),
264        }
265    }
266}
267
268pub fn apo_with_kernel(input: &ApoInput, kernel: Kernel) -> Result<ApoOutput, ApoError> {
269    let (data, first, short, long, len, chosen) = apo_prepare(input, kernel)?;
270
271    let warmup_period = first;
272
273    let mut out = alloc_with_nan_prefix(len, warmup_period);
274
275    apo_compute_into(data, first, short, long, chosen, &mut out);
276
277    Ok(ApoOutput { values: out })
278}
279
280#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
281pub fn apo_into(input: &ApoInput, out: &mut [f64]) -> Result<(), ApoError> {
282    let (data, first, short, long, len, chosen) = apo_prepare(input, Kernel::Auto)?;
283    if out.len() != len {
284        return Err(ApoError::OutputLengthMismatch {
285            expected: len,
286            got: out.len(),
287        });
288    }
289
290    if first > 0 {
291        for v in &mut out[..first] {
292            *v = f64::from_bits(0x7ff8_0000_0000_0000);
293        }
294    }
295
296    apo_compute_into(data, first, short, long, chosen, out);
297    Ok(())
298}
299
/// Scalar APO kernel: difference between a short and a long EMA.
///
/// Both EMAs are seeded with `data[first]`, so `out[first]` is `0.0`; each
/// later element is `short_ema - long_ema` after folding in `data[i]` with
/// smoothing factor `2 / (period + 1)`. Elements of `out` before `first`
/// are left untouched.
///
/// # Panics
/// Panics if `first` is out of bounds for `data` or `out`, or if `out` is
/// shorter than `data`.
#[inline(always)]
pub fn apo_scalar(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    let k_fast = 2.0 / (short as f64 + 1.0);
    let k_slow = 2.0 / (long as f64 + 1.0);
    let keep_fast = 1.0 - k_fast;
    let keep_slow = 1.0 - k_slow;

    let n = data.len();
    debug_assert_eq!(out.len(), n);

    let mut fast = data[first];
    let mut slow = fast;
    out[first] = 0.0;

    // The recurrence is strictly sequential, so a plain loop visits the same
    // elements in the same order as an unrolled one.
    for idx in (first + 1)..n {
        let price = data[idx];
        fast = k_fast * price + keep_fast * fast;
        slow = k_slow * price + keep_slow * slow;
        out[idx] = fast - slow;
    }
}
336
/// AVX2 APO kernel.
///
/// Keeps the short and long EMA side by side in one vector and advances
/// both with a single multiply/multiply/add per input element; one output
/// element is produced per iteration.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `first` must be less than `data.len()`, and `out` must be at least as
///   long as `data`; all indexing is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn apo_avx2(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let alpha_s = 2.0 / (short as f64 + 1.0);
    let alpha_l = 2.0 / (long as f64 + 1.0);
    let oma_s = 1.0 - alpha_s;
    let oma_l = 1.0 - alpha_l;

    let n = data.len();
    debug_assert_eq!(out.len(), n);

    let mut i = first;
    let x0 = *data.get_unchecked(i);

    // Both EMAs are seeded with the first valid price.
    let mut ema = _mm256_set_pd(x0, x0, x0, x0);

    // `_mm256_set_pd` takes its arguments from the highest lane down, so
    // lane 0 carries the short-period factor and lane 1 the long-period one
    // (the pattern repeats in lanes 2 and 3).
    let a = _mm256_set_pd(alpha_l, alpha_s, alpha_l, alpha_s);
    let oma = _mm256_set_pd(oma_l, oma_s, oma_l, oma_s);

    // short EMA == long EMA at the seed, so the oscillator starts at zero.
    *out.get_unchecked_mut(i) = 0.0;
    i += 1;

    while i < n {
        let p = _mm256_set1_pd(*data.get_unchecked(i));

        // ema = alpha * price + (1 - alpha) * ema, per lane.
        let t1 = _mm256_mul_pd(a, p);
        let t2 = _mm256_mul_pd(oma, ema);
        ema = _mm256_add_pd(t1, t2);

        // Swap neighbouring lanes so that lane 0 of `diff` is short - long.
        let swapped = _mm256_permute_pd(ema, 0x5);
        let diff = _mm256_sub_pd(ema, swapped);

        // Extract lane 0.
        let apo_val = _mm256_cvtsd_f64(diff);
        *out.get_unchecked_mut(i) = apo_val;

        i += 1;
    }
}
378
/// AVX-512 APO kernel entry point; forwards to `apo_avx512_short`.
///
/// # Safety
/// Same requirements as `apo_avx512_short`: the CPU must support AVX-512F,
/// `first < data.len()`, and `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    apo_avx512_short(data, short, long, first, out);
}
385
/// AVX-512 APO kernel.
///
/// Same scheme as the AVX2 kernel: the short and long EMA live in
/// neighbouring lanes of one vector and are advanced together; one output
/// element is produced per iteration.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `first` must be less than `data.len()`, and `out` must be at least as
///   long as `data`; all indexing is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512_short(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let alpha_s = 2.0 / (short as f64 + 1.0);
    let alpha_l = 2.0 / (long as f64 + 1.0);
    let oma_s = 1.0 - alpha_s;
    let oma_l = 1.0 - alpha_l;

    let n = data.len();
    debug_assert_eq!(out.len(), n);

    let mut i = first;
    let x0 = *data.get_unchecked(i);

    // Both EMAs are seeded with the first valid price.
    let mut ema = _mm512_set_pd(x0, x0, x0, x0, x0, x0, x0, x0);

    // `_mm512_set_pd` takes its arguments from the highest lane down, so
    // even lanes carry the short-period factor and odd lanes the long one.
    let a = _mm512_set_pd(
        alpha_l, alpha_s, alpha_l, alpha_s, alpha_l, alpha_s, alpha_l, alpha_s,
    );
    let oma = _mm512_set_pd(oma_l, oma_s, oma_l, oma_s, oma_l, oma_s, oma_l, oma_s);

    // short EMA == long EMA at the seed, so the oscillator starts at zero.
    *out.get_unchecked_mut(i) = 0.0;
    i += 1;

    while i < n {
        let p = _mm512_set1_pd(*data.get_unchecked(i));

        // ema = alpha * price + (1 - alpha) * ema, per lane.
        let t1 = _mm512_mul_pd(a, p);
        let t2 = _mm512_mul_pd(oma, ema);
        ema = _mm512_add_pd(t1, t2);

        // Swap neighbouring lanes so that lane 0 of `diff` is short - long.
        let swapped = _mm512_permute_pd(ema, 0b01010101);
        let diff = _mm512_sub_pd(ema, swapped);

        // Extract lane 0.
        let low128 = _mm512_castpd512_pd128(diff);
        let apo_val = _mm_cvtsd_f64(low128);
        *out.get_unchecked_mut(i) = apo_val;

        i += 1;
    }
}
436
/// Long-period AVX-512 variant; currently identical to the short-period
/// kernel, to which it forwards.
///
/// # Safety
/// Same requirements as `apo_avx512_short`: the CPU must support AVX-512F,
/// `first < data.len()`, and `out` must be at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_avx512_long(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    apo_avx512_short(data, short, long, first, out);
}
449
/// WebAssembly SIMD128 APO kernel.
///
/// The EMA recurrence is sequential in time, so two consecutive prices
/// cannot be advanced from the same previous EMA. Instead the two
/// independent series are packed into one vector (lane 0 = short EMA,
/// lane 1 = long EMA) and advanced together, one price per iteration, which
/// matches the arithmetic of the scalar kernel. `out[first]` is `0.0`;
/// elements before `first` are left untouched.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD kernels; every slice
/// access in the body is bounds-checked.
///
/// # Panics
/// Panics if `first` is out of bounds for `data` or `out`, or if `out` is
/// shorter than `data`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
#[allow(dead_code)] // not wired into the kernel dispatch
unsafe fn apo_simd128(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    use core::arch::wasm32::*;

    let alpha_short = 2.0 / (short as f64 + 1.0);
    let alpha_long = 2.0 / (long as f64 + 1.0);

    let one_minus_alpha_short = 1.0 - alpha_short;
    let one_minus_alpha_long = 1.0 - alpha_long;

    let n = data.len();
    debug_assert_eq!(out.len(), n);

    // Both EMAs are seeded with the first valid price.
    let seed = data[first];
    let mut ema = f64x2(seed, seed);
    let alpha = f64x2(alpha_short, alpha_long);
    let one_minus_alpha = f64x2(one_minus_alpha_short, one_minus_alpha_long);

    out[first] = 0.0;

    for i in (first + 1)..n {
        let price = f64x2_splat(data[i]);
        // ema = alpha * price + (1 - alpha) * ema, for both series at once.
        ema = f64x2_add(f64x2_mul(alpha, price), f64x2_mul(one_minus_alpha, ema));
        out[i] = f64x2_extract_lane::<0>(ema) - f64x2_extract_lane::<1>(ema);
    }
}
507
508#[derive(Clone, Debug)]
509pub struct ApoStream {
510    short: usize,
511    long: usize,
512    alpha_short: f64,
513    alpha_long: f64,
514
515    oma_short: f64,
516    oma_long: f64,
517    short_ema: f64,
518    long_ema: f64,
519    filled: bool,
520    nan_leading: usize,
521    seen: usize,
522}
523
524impl ApoStream {
525    #[inline(always)]
526    pub fn try_new(params: ApoParams) -> Result<Self, ApoError> {
527        let short = params.short_period.unwrap_or(10);
528        let long = params.long_period.unwrap_or(20);
529        if short == 0 || long == 0 {
530            return Err(ApoError::InvalidPeriod { short, long });
531        }
532        if short >= long {
533            return Err(ApoError::ShortPeriodNotLessThanLong { short, long });
534        }
535
536        let alpha_short = 2.0 / (short as f64 + 1.0);
537        let alpha_long = 2.0 / (long as f64 + 1.0);
538        Ok(Self {
539            short,
540            long,
541            alpha_short,
542            alpha_long,
543            oma_short: 1.0 - alpha_short,
544            oma_long: 1.0 - alpha_long,
545            short_ema: f64::NAN,
546            long_ema: f64::NAN,
547            filled: false,
548            nan_leading: 0,
549            seen: 0,
550        })
551    }
552
553    #[inline(always)]
554    pub fn update(&mut self, price: f64) -> Option<f64> {
555        if !self.filled {
556            if price.is_nan() {
557                self.nan_leading += 1;
558                return None;
559            }
560            self.short_ema = price;
561            self.long_ema = price;
562            self.filled = true;
563            self.seen = 1;
564            return Some(0.0);
565        }
566
567        self.seen += 1;
568
569        if price.is_nan() {
570            self.short_ema = f64::NAN;
571            self.long_ema = f64::NAN;
572            return Some(f64::NAN);
573        }
574
575        let se_prev = self.short_ema;
576        let le_prev = self.long_ema;
577        self.short_ema = self.alpha_short * price + self.oma_short * se_prev;
578        self.long_ema = self.alpha_long * price + self.oma_long * le_prev;
579        Some(self.short_ema - self.long_ema)
580    }
581
582    #[inline(always)]
583    pub fn update_fastmath(&mut self, price: f64) -> Option<f64> {
584        if !self.filled {
585            if price.is_nan() {
586                self.nan_leading += 1;
587                return None;
588            }
589            self.short_ema = price;
590            self.long_ema = price;
591            self.filled = true;
592            self.seen = 1;
593            return Some(0.0);
594        }
595
596        self.seen += 1;
597
598        if price.is_nan() {
599            self.short_ema = f64::NAN;
600            self.long_ema = f64::NAN;
601            return Some(f64::NAN);
602        }
603
604        let ds = price - self.short_ema;
605        let dl = price - self.long_ema;
606        self.short_ema = ds.mul_add(self.alpha_short, self.short_ema);
607        self.long_ema = dl.mul_add(self.alpha_long, self.long_ema);
608        Some(self.short_ema - self.long_ema)
609    }
610}
611
612pub fn apo_into_slice(dst: &mut [f64], input: &ApoInput, kern: Kernel) -> Result<(), ApoError> {
613    let (data, first, short, long, len, chosen) = apo_prepare(input, kern)?;
614    if dst.len() != len {
615        return Err(ApoError::OutputLengthMismatch {
616            expected: len,
617            got: dst.len(),
618        });
619    }
620    apo_compute_into(data, first, short, long, chosen, dst);
621    for v in &mut dst[..first] {
622        *v = f64::NAN;
623    }
624    Ok(())
625}
626
/// Parameter sweep for batch APO: one `(start, end, step)` triple per period.
///
/// A step of zero, or `start == end`, selects the single value `start`.
#[derive(Clone, Debug)]
pub struct ApoBatchRange {
    /// `(start, end, step)` for the short period.
    pub short: (usize, usize, usize),
    /// `(start, end, step)` for the long period.
    pub long: (usize, usize, usize),
}
impl Default for ApoBatchRange {
    /// Fixed short period of 10; long period swept from 20 to 269 in steps of 1.
    fn default() -> Self {
        Self {
            short: (10, 10, 0),
            long: (20, 269, 1),
        }
    }
}
640
641#[derive(Clone, Debug, Default)]
642pub struct ApoBatchBuilder {
643    range: ApoBatchRange,
644    kernel: Kernel,
645}
646impl ApoBatchBuilder {
647    pub fn new() -> Self {
648        Self::default()
649    }
650    pub fn kernel(mut self, k: Kernel) -> Self {
651        self.kernel = k;
652        self
653    }
654    pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
655        self.range.short = (start, end, step);
656        self
657    }
658    pub fn short_static(mut self, s: usize) -> Self {
659        self.range.short = (s, s, 0);
660        self
661    }
662    pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
663        self.range.long = (start, end, step);
664        self
665    }
666    pub fn long_static(mut self, s: usize) -> Self {
667        self.range.long = (s, s, 0);
668        self
669    }
670    pub fn apply_slice(self, data: &[f64]) -> Result<ApoBatchOutput, ApoError> {
671        apo_batch_with_kernel(data, &self.range, self.kernel)
672    }
673    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<ApoBatchOutput, ApoError> {
674        ApoBatchBuilder::new().kernel(k).apply_slice(data)
675    }
676    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<ApoBatchOutput, ApoError> {
677        let slice = source_type(c, src);
678        self.apply_slice(slice)
679    }
680    pub fn with_default_candles(c: &Candles) -> Result<ApoBatchOutput, ApoError> {
681        ApoBatchBuilder::new()
682            .kernel(Kernel::Auto)
683            .apply_candles(c, "close")
684    }
685}
686
687#[derive(Clone, Debug)]
688pub struct ApoBatchOutput {
689    pub values: Vec<f64>,
690    pub combos: Vec<ApoParams>,
691    pub rows: usize,
692    pub cols: usize,
693}
694impl ApoBatchOutput {
695    pub fn row_for_params(&self, p: &ApoParams) -> Option<usize> {
696        self.combos.iter().position(|c| {
697            c.short_period.unwrap_or(10) == p.short_period.unwrap_or(10)
698                && c.long_period.unwrap_or(20) == p.long_period.unwrap_or(20)
699        })
700    }
701    pub fn values_for(&self, p: &ApoParams) -> Option<&[f64]> {
702        self.row_for_params(p).map(|row| {
703            let start = row * self.cols;
704            &self.values[start..start + self.cols]
705        })
706    }
707}
708
709#[inline(always)]
710fn expand_grid(r: &ApoBatchRange) -> Result<Vec<ApoParams>, ApoError> {
711    fn axis((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, ApoError> {
712        if step == 0 || start == end {
713            return Ok(vec![start]);
714        }
715        let mut v = Vec::new();
716        if start < end {
717            let mut cur = start;
718            while cur <= end {
719                v.push(cur);
720                match cur.checked_add(step) {
721                    Some(n) => cur = n,
722                    None => break,
723                }
724            }
725        } else {
726            let mut cur = start;
727            while cur >= end {
728                v.push(cur);
729                if let Some(n) = cur.checked_sub(step) {
730                    cur = n;
731                } else {
732                    break;
733                }
734                if cur == usize::MAX {
735                    break;
736                }
737            }
738        }
739        if v.is_empty() {
740            return Err(ApoError::InvalidRange { start, end, step });
741        }
742        Ok(v)
743    }
744    let shorts = axis(r.short)?;
745    let longs = axis(r.long)?;
746    let mut out = Vec::with_capacity(shorts.len().saturating_mul(longs.len()));
747    for &s in &shorts {
748        for &l in &longs {
749            if s < l && s > 0 && l > 0 {
750                out.push(ApoParams {
751                    short_period: Some(s),
752                    long_period: Some(l),
753                });
754            }
755        }
756    }
757    Ok(out)
758}
759
760#[inline(always)]
761pub fn apo_batch_with_kernel(
762    data: &[f64],
763    sweep: &ApoBatchRange,
764    k: Kernel,
765) -> Result<ApoBatchOutput, ApoError> {
766    let kernel = match k {
767        Kernel::Auto => detect_best_batch_kernel(),
768        other if other.is_batch() => other,
769        other => return Err(ApoError::InvalidKernelForBatch(other)),
770    };
771    apo_batch_par_slice(data, sweep, kernel)
772}
773
/// Runs a batch sweep over `data` one row at a time: forwards to
/// `apo_batch_inner` with `parallel = false`.
#[inline(always)]
pub fn apo_batch_slice(
    data: &[f64],
    sweep: &ApoBatchRange,
    kern: Kernel,
) -> Result<ApoBatchOutput, ApoError> {
    apo_batch_inner(data, sweep, kern, false)
}
782
/// Runs a batch sweep over `data` with row-level parallelism requested:
/// forwards to `apo_batch_inner` with `parallel = true`.
#[inline(always)]
pub fn apo_batch_par_slice(
    data: &[f64],
    sweep: &ApoBatchRange,
    kern: Kernel,
) -> Result<ApoBatchOutput, ApoError> {
    apo_batch_inner(data, sweep, kern, true)
}
791
792#[inline(always)]
793fn apo_batch_inner(
794    data: &[f64],
795    sweep: &ApoBatchRange,
796    kern: Kernel,
797    parallel: bool,
798) -> Result<ApoBatchOutput, ApoError> {
799    let combos = expand_grid(sweep)?;
800    if combos.is_empty() {
801        return Err(ApoError::InvalidRange {
802            start: sweep.short.0,
803            end: sweep.short.1,
804            step: sweep.short.2,
805        });
806    }
807    let first = data
808        .iter()
809        .position(|x| !x.is_nan())
810        .ok_or(ApoError::AllValuesNaN)?;
811    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
812    if data.len() - first < max_long {
813        return Err(ApoError::NotEnoughValidData {
814            needed: max_long,
815            valid: data.len() - first,
816        });
817    }
818
819    let rows = combos.len();
820    let cols = data.len();
821
822    let _ = rows.checked_mul(cols).ok_or(ApoError::InvalidRange {
823        start: rows,
824        end: cols,
825        step: 0,
826    })?;
827
828    let mut buf_mu = make_uninit_matrix(rows, cols);
829
830    let warm: Vec<usize> = combos.iter().map(|_c| first).collect();
831
832    init_matrix_prefixes(&mut buf_mu, cols, &warm);
833
834    let mut buf_guard = ManuallyDrop::new(buf_mu);
835    let values: &mut [f64] = unsafe {
836        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
837    };
838
839    match kern {
840        Kernel::Scalar | Kernel::ScalarBatch => {
841            let do_row = |row: usize, out_row: &mut [f64]| unsafe {
842                let s = combos[row].short_period.unwrap();
843                let l = combos[row].long_period.unwrap();
844                apo_row_scalar(data, first, s, l, out_row)
845            };
846            if parallel {
847                #[cfg(not(target_arch = "wasm32"))]
848                {
849                    values
850                        .par_chunks_mut(cols)
851                        .enumerate()
852                        .for_each(|(row, slice)| do_row(row, slice));
853                }
854                #[cfg(target_arch = "wasm32")]
855                {
856                    for (row, slice) in values.chunks_mut(cols).enumerate() {
857                        do_row(row, slice);
858                    }
859                }
860            } else {
861                for (row, slice) in values.chunks_mut(cols).enumerate() {
862                    do_row(row, slice);
863                }
864            }
865        }
866        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
867        Kernel::Avx2 => {
868            let do_row = |row: usize, out_row: &mut [f64]| unsafe {
869                let s = combos[row].short_period.unwrap();
870                let l = combos[row].long_period.unwrap();
871                apo_row_avx2(data, first, s, l, out_row)
872            };
873            if parallel {
874                #[cfg(not(target_arch = "wasm32"))]
875                values
876                    .par_chunks_mut(cols)
877                    .enumerate()
878                    .for_each(|(row, slice)| do_row(row, slice));
879                #[cfg(target_arch = "wasm32")]
880                for (row, slice) in values.chunks_mut(cols).enumerate() {
881                    do_row(row, slice);
882                }
883            } else {
884                for (row, slice) in values.chunks_mut(cols).enumerate() {
885                    do_row(row, slice);
886                }
887            }
888        }
889        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
890        Kernel::Avx512 => {
891            let do_row = |row: usize, out_row: &mut [f64]| unsafe {
892                let s = combos[row].short_period.unwrap();
893                let l = combos[row].long_period.unwrap();
894                apo_row_avx512(data, first, s, l, out_row)
895            };
896            if parallel {
897                #[cfg(not(target_arch = "wasm32"))]
898                values
899                    .par_chunks_mut(cols)
900                    .enumerate()
901                    .for_each(|(row, slice)| do_row(row, slice));
902                #[cfg(target_arch = "wasm32")]
903                for (row, slice) in values.chunks_mut(cols).enumerate() {
904                    do_row(row, slice);
905                }
906            } else {
907                for (row, slice) in values.chunks_mut(cols).enumerate() {
908                    do_row(row, slice);
909                }
910            }
911        }
912        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
913        Kernel::Avx2Batch => {
914            const LANES: usize = 4;
915            let blocks = (rows + LANES - 1) / LANES;
916            let do_block = |b: usize, blk: &mut [f64]| unsafe {
917                let start_row = b * LANES;
918                let end_row = usize::min(start_row + LANES, rows);
919                apo_batch_rows_avx2(data, first, cols, &combos[start_row..end_row], blk);
920            };
921            if parallel {
922                #[cfg(not(target_arch = "wasm32"))]
923                values
924                    .par_chunks_mut(cols * LANES)
925                    .enumerate()
926                    .for_each(|(b, blk)| do_block(b, blk));
927                #[cfg(target_arch = "wasm32")]
928                for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
929                    do_block(b, blk);
930                }
931            } else {
932                for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
933                    do_block(b, blk);
934                }
935            }
936        }
937        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
938        Kernel::Avx512Batch => {
939            const LANES: usize = 8;
940            let blocks = (rows + LANES - 1) / LANES;
941            let do_block = |b: usize, blk: &mut [f64]| unsafe {
942                let start_row = b * LANES;
943                let end_row = usize::min(start_row + LANES, rows);
944                apo_batch_rows_avx512(data, first, cols, &combos[start_row..end_row], blk);
945            };
946            if parallel {
947                #[cfg(not(target_arch = "wasm32"))]
948                values
949                    .par_chunks_mut(cols * LANES)
950                    .enumerate()
951                    .for_each(|(b, blk)| do_block(b, blk));
952                #[cfg(target_arch = "wasm32")]
953                for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
954                    do_block(b, blk);
955                }
956            } else {
957                for (b, blk) in values.chunks_mut(cols * LANES).enumerate() {
958                    do_block(b, blk);
959                }
960            }
961        }
962        _ => unreachable!(),
963    }
964
965    let values = unsafe {
966        Vec::from_raw_parts(
967            buf_guard.as_mut_ptr() as *mut f64,
968            buf_guard.len(),
969            buf_guard.capacity(),
970        )
971    };
972
973    Ok(ApoBatchOutput {
974        values,
975        combos,
976        rows,
977        cols,
978    })
979}
980
981#[inline(always)]
982fn apo_batch_inner_into(
983    data: &[f64],
984    sweep: &ApoBatchRange,
985    kern: Kernel,
986    parallel: bool,
987    out: &mut [f64],
988) -> Result<Vec<ApoParams>, ApoError> {
989    let combos = expand_grid(sweep)?;
990    if combos.is_empty() {
991        return Err(ApoError::InvalidRange {
992            start: sweep.short.0,
993            end: sweep.short.1,
994            step: sweep.short.2,
995        });
996    }
997
998    let first = data
999        .iter()
1000        .position(|x| !x.is_nan())
1001        .ok_or(ApoError::AllValuesNaN)?;
1002    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
1003    if data.len() - first < max_long {
1004        return Err(ApoError::NotEnoughValidData {
1005            needed: max_long,
1006            valid: data.len() - first,
1007        });
1008    }
1009
1010    let rows = combos.len();
1011    let cols = data.len();
1012    let expected = rows.checked_mul(cols).ok_or(ApoError::InvalidRange {
1013        start: rows,
1014        end: cols,
1015        step: 0,
1016    })?;
1017    if out.len() != expected {
1018        return Err(ApoError::OutputLengthMismatch {
1019            expected,
1020            got: out.len(),
1021        });
1022    }
1023
1024    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1025        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1026    };
1027    let warm = vec![first; rows];
1028    init_matrix_prefixes(out_mu, cols, &warm);
1029
1030    match kern {
1031        Kernel::Scalar | Kernel::ScalarBatch => {
1032            let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1033                let s = combos[row].short_period.unwrap();
1034                let l = combos[row].long_period.unwrap();
1035                let dst: &mut [f64] =
1036                    core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1037                apo_row_scalar(data, first, s, l, dst)
1038            };
1039            if parallel {
1040                #[cfg(not(target_arch = "wasm32"))]
1041                out_mu
1042                    .par_chunks_mut(cols)
1043                    .enumerate()
1044                    .for_each(|(r, s)| do_row(r, s));
1045                #[cfg(target_arch = "wasm32")]
1046                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1047                    do_row(r, s);
1048                }
1049            } else {
1050                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1051                    do_row(r, s);
1052                }
1053            }
1054        }
1055        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1056        Kernel::Avx2 => {
1057            let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1058                let s = combos[row].short_period.unwrap();
1059                let l = combos[row].long_period.unwrap();
1060                let dst: &mut [f64] =
1061                    core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1062                apo_row_avx2(data, first, s, l, dst)
1063            };
1064            if parallel {
1065                #[cfg(not(target_arch = "wasm32"))]
1066                out_mu
1067                    .par_chunks_mut(cols)
1068                    .enumerate()
1069                    .for_each(|(r, s)| do_row(r, s));
1070                #[cfg(target_arch = "wasm32")]
1071                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1072                    do_row(r, s);
1073                }
1074            } else {
1075                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1076                    do_row(r, s);
1077                }
1078            }
1079        }
1080        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1081        Kernel::Avx512 => {
1082            let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1083                let s = combos[row].short_period.unwrap();
1084                let l = combos[row].long_period.unwrap();
1085                let dst: &mut [f64] =
1086                    core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1087                apo_row_avx512(data, first, s, l, dst)
1088            };
1089            if parallel {
1090                #[cfg(not(target_arch = "wasm32"))]
1091                out_mu
1092                    .par_chunks_mut(cols)
1093                    .enumerate()
1094                    .for_each(|(r, s)| do_row(r, s));
1095                #[cfg(target_arch = "wasm32")]
1096                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1097                    do_row(r, s);
1098                }
1099            } else {
1100                for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1101                    do_row(r, s);
1102                }
1103            }
1104        }
1105        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1106        Kernel::Avx2Batch => {
1107            const LANES: usize = 4;
1108            let do_block = |b: usize, blk_mu: &mut [MaybeUninit<f64>]| unsafe {
1109                let start_row = b * LANES;
1110                let end_row = usize::min(start_row + LANES, rows);
1111                let blk: &mut [f64] =
1112                    core::slice::from_raw_parts_mut(blk_mu.as_mut_ptr() as *mut f64, blk_mu.len());
1113                apo_batch_rows_avx2(data, first, cols, &combos[start_row..end_row], blk);
1114            };
1115            if parallel {
1116                #[cfg(not(target_arch = "wasm32"))]
1117                out_mu
1118                    .par_chunks_mut(cols * LANES)
1119                    .enumerate()
1120                    .for_each(|(b, blk)| do_block(b, blk));
1121                #[cfg(target_arch = "wasm32")]
1122                for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1123                    do_block(b, blk);
1124                }
1125            } else {
1126                for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1127                    do_block(b, blk);
1128                }
1129            }
1130        }
1131        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1132        Kernel::Avx512Batch => {
1133            const LANES: usize = 8;
1134            let do_block = |b: usize, blk_mu: &mut [MaybeUninit<f64>]| unsafe {
1135                let start_row = b * LANES;
1136                let end_row = usize::min(start_row + LANES, rows);
1137                let blk: &mut [f64] =
1138                    core::slice::from_raw_parts_mut(blk_mu.as_mut_ptr() as *mut f64, blk_mu.len());
1139                apo_batch_rows_avx512(data, first, cols, &combos[start_row..end_row], blk);
1140            };
1141            if parallel {
1142                #[cfg(not(target_arch = "wasm32"))]
1143                out_mu
1144                    .par_chunks_mut(cols * LANES)
1145                    .enumerate()
1146                    .for_each(|(b, blk)| do_block(b, blk));
1147                #[cfg(target_arch = "wasm32")]
1148                for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1149                    do_block(b, blk);
1150                }
1151            } else {
1152                for (b, blk) in out_mu.chunks_mut(cols * LANES).enumerate() {
1153                    do_block(b, blk);
1154                }
1155            }
1156        }
1157        _ => unreachable!(),
1158    }
1159
1160    Ok(combos)
1161}
1162
1163#[inline(always)]
1164pub unsafe fn apo_row_scalar(
1165    data: &[f64],
1166    first: usize,
1167    short: usize,
1168    long: usize,
1169    out: &mut [f64],
1170) {
1171    apo_scalar(data, short, long, first, out)
1172}
/// AVX2 per-row entry point used by the batch drivers.
///
/// Forwards to `apo_avx2`, moving `first` behind the two periods to match that
/// function's argument order.
///
/// # Safety
/// The running CPU must support AVX2. Any further preconditions come from
/// `apo_avx2` and are not visible from this wrapper.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn apo_row_avx2(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    apo_avx2(data, short, long, first, out);
}
/// AVX-512 per-row entry point used by the batch drivers.
///
/// Routes to the "short" variant when `long` is at most 32 and to the "long"
/// variant otherwise.
///
/// # Safety
/// The running CPU must support AVX-512F. Any further preconditions come from
/// the two variants it calls.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    match long {
        0..=32 => apo_row_avx512_short(data, first, short, long, out),
        _ => apo_row_avx512_long(data, first, short, long, out),
    }
}
/// AVX-512 per-row wrapper for the "short" variant.
///
/// Forwards to `apo_avx512_short`, moving `first` behind the two periods to
/// match that function's argument order.
///
/// # Safety
/// The running CPU must support AVX-512F. Any further preconditions come from
/// `apo_avx512_short` and are not visible from this wrapper.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512_short(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    apo_avx512_short(data, short, long, first, out);
}
/// AVX-512 per-row wrapper for the "long" variant.
///
/// Forwards to `apo_avx512_long`, moving `first` behind the two periods to
/// match that function's argument order.
///
/// # Safety
/// The running CPU must support AVX-512F. Any further preconditions come from
/// `apo_avx512_long` and are not visible from this wrapper.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn apo_row_avx512_long(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    apo_avx512_long(data, short, long, first, out);
}
1219
/// Computes up to four APO rows at once, one parameter combination per AVX2
/// lane.
///
/// Both EMAs of every lane are seeded with `data[first]`, so column `first` of
/// each row is written as `0.0`. For each later column `i` the EMAs are
/// advanced with `alpha = 2 / (period + 1)` and `short_ema - long_ema` is
/// stored at `out_block[j * cols + i]` for lane/row `j`. Columns before
/// `first` are not touched.
///
/// Lanes beyond `combos_block.len()` use `alpha = 0`, so they simply hold the
/// seed value and are never stored. A period left as `None` falls back to
/// 10 (short) or 20 (long).
///
/// # Safety
/// - The running CPU must support AVX2.
/// - `first < cols` and `cols <= data.len()` (samples are read unchecked).
/// - `out_block.len() >= combos_block.len() * cols` (results are written
///   unchecked).
///
/// # Panics
/// Panics if `combos_block` holds more than four combinations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn apo_batch_rows_avx2(
    data: &[f64],
    first: usize,
    cols: usize,
    combos_block: &[ApoParams],
    out_block: &mut [f64],
) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;
    let l = combos_block.len();

    // Make the caller contract checkable in debug builds; release builds keep
    // the unchecked fast path.
    debug_assert!(l <= LANES, "at most {} rows per AVX2 block", LANES);
    debug_assert!(first < cols, "first valid index must lie inside the row");
    debug_assert!(cols <= data.len(), "cols must not exceed the input length");
    debug_assert!(out_block.len() >= l * cols, "output block too small");

    let mut alpha_short = [0.0f64; LANES];
    let mut alpha_long = [0.0f64; LANES];
    let mut keep_short = [1.0f64; LANES];
    let mut keep_long = [1.0f64; LANES];
    for (j, p) in combos_block.iter().enumerate() {
        let a_s = 2.0 / (p.short_period.unwrap_or(10) as f64 + 1.0);
        let a_l = 2.0 / (p.long_period.unwrap_or(20) as f64 + 1.0);
        alpha_short[j] = a_s;
        alpha_long[j] = a_l;
        keep_short[j] = 1.0 - a_s;
        keep_long[j] = 1.0 - a_l;
    }
    let a_s = _mm256_loadu_pd(alpha_short.as_ptr());
    let a_l = _mm256_loadu_pd(alpha_long.as_ptr());
    let o_s = _mm256_loadu_pd(keep_short.as_ptr());
    let o_l = _mm256_loadu_pd(keep_long.as_ptr());

    // SAFETY: `first < cols <= data.len()` per the caller contract.
    let x0 = *data.get_unchecked(first);
    let mut se = _mm256_set1_pd(x0);
    let mut le = _mm256_set1_pd(x0);

    // Both EMAs start equal, so the oscillator is exactly zero at `first`.
    for j in 0..l {
        // SAFETY: `j * cols + first < l * cols <= out_block.len()`.
        *out_block.get_unchecked_mut(j * cols + first) = 0.0;
    }

    // Scratch for scattering one vector of results across the row-major block.
    let mut lane_out = [0.0f64; LANES];
    for i in (first + 1)..cols {
        // SAFETY: `i < cols <= data.len()`.
        let p = _mm256_set1_pd(*data.get_unchecked(i));

        // Separate multiply and add (no FMA) keeps the rounding unchanged.
        se = _mm256_add_pd(_mm256_mul_pd(a_s, p), _mm256_mul_pd(o_s, se));
        le = _mm256_add_pd(_mm256_mul_pd(a_l, p), _mm256_mul_pd(o_l, le));

        _mm256_storeu_pd(lane_out.as_mut_ptr(), _mm256_sub_pd(se, le));
        for (j, &v) in lane_out[..l].iter().enumerate() {
            // SAFETY: `j < l` and `i < cols`, so the index is below `l * cols`.
            *out_block.get_unchecked_mut(j * cols + i) = v;
        }
    }
}
1278
/// Computes up to eight APO rows at once, one parameter combination per
/// AVX-512 lane.
///
/// Both EMAs of every lane are seeded with `data[first]`, so column `first` of
/// each row is written as `0.0`. For each later column `i` the EMAs are
/// advanced with `alpha = 2 / (period + 1)` and `short_ema - long_ema` is
/// stored at `out_block[j * cols + i]` for lane/row `j`. Columns before
/// `first` are not touched.
///
/// Lanes beyond `combos_block.len()` use `alpha = 0`, so they simply hold the
/// seed value and are never stored. A period left as `None` falls back to
/// 10 (short) or 20 (long).
///
/// # Safety
/// - The running CPU must support AVX-512F.
/// - `first < cols` and `cols <= data.len()` (samples are read unchecked).
/// - `out_block.len() >= combos_block.len() * cols` (results are written
///   unchecked).
///
/// # Panics
/// Panics if `combos_block` holds more than eight combinations.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn apo_batch_rows_avx512(
    data: &[f64],
    first: usize,
    cols: usize,
    combos_block: &[ApoParams],
    out_block: &mut [f64],
) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;
    let l = combos_block.len();

    // Make the caller contract checkable in debug builds; release builds keep
    // the unchecked fast path.
    debug_assert!(l <= LANES, "at most {} rows per AVX-512 block", LANES);
    debug_assert!(first < cols, "first valid index must lie inside the row");
    debug_assert!(cols <= data.len(), "cols must not exceed the input length");
    debug_assert!(out_block.len() >= l * cols, "output block too small");

    let mut alpha_short = [0.0f64; LANES];
    let mut alpha_long = [0.0f64; LANES];
    let mut keep_short = [1.0f64; LANES];
    let mut keep_long = [1.0f64; LANES];
    for (j, p) in combos_block.iter().enumerate() {
        let a_s = 2.0 / (p.short_period.unwrap_or(10) as f64 + 1.0);
        let a_l = 2.0 / (p.long_period.unwrap_or(20) as f64 + 1.0);
        alpha_short[j] = a_s;
        alpha_long[j] = a_l;
        keep_short[j] = 1.0 - a_s;
        keep_long[j] = 1.0 - a_l;
    }
    let a_s = _mm512_loadu_pd(alpha_short.as_ptr());
    let a_l = _mm512_loadu_pd(alpha_long.as_ptr());
    let o_s = _mm512_loadu_pd(keep_short.as_ptr());
    let o_l = _mm512_loadu_pd(keep_long.as_ptr());

    // SAFETY: `first < cols <= data.len()` per the caller contract.
    let x0 = *data.get_unchecked(first);
    let mut se = _mm512_set1_pd(x0);
    let mut le = _mm512_set1_pd(x0);

    // Both EMAs start equal, so the oscillator is exactly zero at `first`.
    for j in 0..l {
        // SAFETY: `j * cols + first < l * cols <= out_block.len()`.
        *out_block.get_unchecked_mut(j * cols + first) = 0.0;
    }

    // Scratch for scattering one vector of results across the row-major block.
    let mut lane_out = [0.0f64; LANES];
    for i in (first + 1)..cols {
        // SAFETY: `i < cols <= data.len()`.
        let p = _mm512_set1_pd(*data.get_unchecked(i));

        // Separate multiply and add (no FMA) keeps the rounding unchanged.
        se = _mm512_add_pd(_mm512_mul_pd(a_s, p), _mm512_mul_pd(o_s, se));
        le = _mm512_add_pd(_mm512_mul_pd(a_l, p), _mm512_mul_pd(o_l, le));

        _mm512_storeu_pd(lane_out.as_mut_ptr(), _mm512_sub_pd(se, le));
        for (j, &v) in lane_out[..l].iter().enumerate() {
            // SAFETY: `j < l` and `i < cols`, so the index is below `l * cols`.
            *out_block.get_unchecked_mut(j * cols + i) = v;
        }
    }
}
1344
1345#[cfg(test)]
1346mod tests {
1347    use super::*;
1348    use crate::skip_if_unsupported;
1349    use crate::utilities::data_loader::read_candles_from_csv;
1350
1351    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1352    #[test]
1353    fn test_apo_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1354        let mut data: Vec<f64> = Vec::with_capacity(256);
1355        for _ in 0..5 {
1356            data.push(f64::NAN);
1357        }
1358        for i in 0..251 {
1359            let x = i as f64;
1360            data.push(100.0 + 0.1 * x + (x * 0.05).sin());
1361        }
1362
1363        let input = ApoInput::from_slice(&data, ApoParams::default());
1364
1365        let baseline = apo(&input)?.values;
1366
1367        let mut out = vec![0.0; data.len()];
1368        apo_into(&input, &mut out)?;
1369
1370        assert_eq!(baseline.len(), out.len());
1371
1372        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1373            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1374        }
1375
1376        for (i, (a, b)) in baseline
1377            .iter()
1378            .copied()
1379            .zip(out.iter().copied())
1380            .enumerate()
1381        {
1382            assert!(
1383                eq_or_both_nan(a, b),
1384                "mismatch at index {}: api={} into={}",
1385                i,
1386                a,
1387                b
1388            );
1389        }
1390        Ok(())
1391    }
1392
1393    fn check_apo_partial_params(
1394        test_name: &str,
1395        kernel: Kernel,
1396    ) -> Result<(), Box<dyn std::error::Error>> {
1397        skip_if_unsupported!(kernel, test_name);
1398        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1399        let candles = read_candles_from_csv(file_path)?;
1400        let default_params = ApoParams {
1401            short_period: None,
1402            long_period: None,
1403        };
1404        let input = ApoInput::from_candles(&candles, "close", default_params);
1405        let output = apo_with_kernel(&input, kernel)?;
1406        assert_eq!(output.values.len(), candles.close.len());
1407        Ok(())
1408    }
1409
1410    fn check_apo_accuracy(
1411        test_name: &str,
1412        kernel: Kernel,
1413    ) -> Result<(), Box<dyn std::error::Error>> {
1414        skip_if_unsupported!(kernel, test_name);
1415        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1416        let candles = read_candles_from_csv(file_path)?;
1417        let input = ApoInput::with_default_candles(&candles);
1418        let result = apo_with_kernel(&input, kernel)?;
1419        let expected_last_five = [
1420            -429.80100015922653,
1421            -401.64149983850075,
1422            -386.13569657357584,
1423            -357.92775222467753,
1424            -374.13870680232503,
1425        ];
1426        let start_index = result.values.len().saturating_sub(5);
1427        let result_last_five = &result.values[start_index..];
1428        for (i, &value) in result_last_five.iter().enumerate() {
1429            assert!(
1430                (value - expected_last_five[i]).abs() < 1e-1,
1431                "[{}] APO value mismatch at index {}: expected {}, got {}",
1432                test_name,
1433                i,
1434                expected_last_five[i],
1435                value
1436            );
1437        }
1438        Ok(())
1439    }
1440
1441    fn check_apo_default_candles(
1442        test_name: &str,
1443        kernel: Kernel,
1444    ) -> Result<(), Box<dyn std::error::Error>> {
1445        skip_if_unsupported!(kernel, test_name);
1446        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1447        let candles = read_candles_from_csv(file_path)?;
1448        let input = ApoInput::with_default_candles(&candles);
1449        match input.data {
1450            ApoData::Candles { source, .. } => assert_eq!(source, "close"),
1451            _ => panic!("Expected ApoData::Candles"),
1452        }
1453        let output = apo_with_kernel(&input, kernel)?;
1454        assert_eq!(output.values.len(), candles.close.len());
1455        Ok(())
1456    }
1457
1458    fn check_apo_zero_period(
1459        test_name: &str,
1460        kernel: Kernel,
1461    ) -> Result<(), Box<dyn std::error::Error>> {
1462        skip_if_unsupported!(kernel, test_name);
1463        let input_data = [10.0, 20.0, 30.0];
1464        let params = ApoParams {
1465            short_period: Some(0),
1466            long_period: Some(20),
1467        };
1468        let input = ApoInput::from_slice(&input_data, params);
1469        let res = apo_with_kernel(&input, kernel);
1470        assert!(
1471            res.is_err(),
1472            "[{}] APO should fail with zero period",
1473            test_name
1474        );
1475        Ok(())
1476    }
1477
1478    fn check_apo_empty_input(
1479        test_name: &str,
1480        kernel: Kernel,
1481    ) -> Result<(), Box<dyn std::error::Error>> {
1482        skip_if_unsupported!(kernel, test_name);
1483        let empty_data: Vec<f64> = vec![];
1484        let params = ApoParams::default();
1485        let input = ApoInput::from_slice(&empty_data, params);
1486        let result = apo_with_kernel(&input, kernel);
1487        assert!(result.is_err());
1488        Ok(())
1489    }
1490
1491    fn check_apo_streaming(
1492        test_name: &str,
1493        kernel: Kernel,
1494    ) -> Result<(), Box<dyn std::error::Error>> {
1495        skip_if_unsupported!(kernel, test_name);
1496        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1497        let candles = read_candles_from_csv(file_path)?;
1498        let params = ApoParams::default();
1499
1500        let input = ApoInput::from_candles(&candles, "close", params.clone());
1501        let batch_result = apo_with_kernel(&input, kernel)?;
1502
1503        let mut stream = ApoStream::try_new(params)?;
1504        let mut streaming_results = vec![];
1505
1506        for &close in &candles.close {
1507            if let Some(val) = stream.update(close) {
1508                streaming_results.push(val);
1509            } else {
1510                streaming_results.push(f64::NAN);
1511            }
1512        }
1513
1514        assert_eq!(batch_result.values.len(), streaming_results.len());
1515        let first_valid = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
1516
1517        for i in first_valid..batch_result.values.len() {
1518            if !batch_result.values[i].is_nan() && !streaming_results[i].is_nan() {
1519                let diff = (batch_result.values[i] - streaming_results[i]).abs();
1520                assert!(
1521                    diff < 1e-10,
1522                    "Streaming mismatch at index {}: batch={}, stream={}",
1523                    i,
1524                    batch_result.values[i],
1525                    streaming_results[i]
1526                );
1527            }
1528        }
1529        Ok(())
1530    }
1531
1532    fn check_apo_period_invalid(
1533        test_name: &str,
1534        kernel: Kernel,
1535    ) -> Result<(), Box<dyn std::error::Error>> {
1536        skip_if_unsupported!(kernel, test_name);
1537        let data_small = [10.0, 20.0, 30.0];
1538        let params = ApoParams {
1539            short_period: Some(20),
1540            long_period: Some(10),
1541        };
1542        let input = ApoInput::from_slice(&data_small, params);
1543        let res = apo_with_kernel(&input, kernel);
1544        assert!(
1545            res.is_err(),
1546            "[{}] APO should fail with invalid period",
1547            test_name
1548        );
1549        Ok(())
1550    }
1551
1552    fn check_apo_very_small_dataset(
1553        test_name: &str,
1554        kernel: Kernel,
1555    ) -> Result<(), Box<dyn std::error::Error>> {
1556        skip_if_unsupported!(kernel, test_name);
1557        let single_point = [42.0];
1558        let params = ApoParams {
1559            short_period: Some(9),
1560            long_period: Some(10),
1561        };
1562        let input = ApoInput::from_slice(&single_point, params);
1563        let res = apo_with_kernel(&input, kernel);
1564        assert!(
1565            res.is_err(),
1566            "[{}] APO should fail with insufficient data",
1567            test_name
1568        );
1569        Ok(())
1570    }
1571
1572    fn check_apo_reinput(
1573        test_name: &str,
1574        kernel: Kernel,
1575    ) -> Result<(), Box<dyn std::error::Error>> {
1576        skip_if_unsupported!(kernel, test_name);
1577        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1578        let candles = read_candles_from_csv(file_path)?;
1579        let first_params = ApoParams {
1580            short_period: Some(10),
1581            long_period: Some(20),
1582        };
1583        let first_input = ApoInput::from_candles(&candles, "close", first_params);
1584        let first_result = apo_with_kernel(&first_input, kernel)?;
1585        let second_params = ApoParams {
1586            short_period: Some(5),
1587            long_period: Some(15),
1588        };
1589        let second_input = ApoInput::from_slice(&first_result.values, second_params);
1590        let second_result = apo_with_kernel(&second_input, kernel)?;
1591        assert_eq!(second_result.values.len(), first_result.values.len());
1592        Ok(())
1593    }
1594
1595    fn check_apo_nan_handling(
1596        test_name: &str,
1597        kernel: Kernel,
1598    ) -> Result<(), Box<dyn std::error::Error>> {
1599        skip_if_unsupported!(kernel, test_name);
1600        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1601        let candles = read_candles_from_csv(file_path)?;
1602        let input = ApoInput::from_candles(
1603            &candles,
1604            "close",
1605            ApoParams {
1606                short_period: Some(10),
1607                long_period: Some(20),
1608            },
1609        );
1610        let res = apo_with_kernel(&input, kernel)?;
1611        assert_eq!(res.values.len(), candles.close.len());
1612        if res.values.len() > 30 {
1613            for (i, &val) in res.values[30..].iter().enumerate() {
1614                assert!(
1615                    !val.is_nan(),
1616                    "[{}] Found unexpected NaN at out-index {}",
1617                    test_name,
1618                    30 + i
1619                );
1620            }
1621        }
1622        Ok(())
1623    }
1624
1625    #[cfg(debug_assertions)]
1626    fn check_apo_no_poison(
1627        test_name: &str,
1628        kernel: Kernel,
1629    ) -> Result<(), Box<dyn std::error::Error>> {
1630        skip_if_unsupported!(kernel, test_name);
1631        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1632        let candles = read_candles_from_csv(file_path)?;
1633
1634        let input = ApoInput::from_candles(&candles, "close", ApoParams::default());
1635        let output = apo_with_kernel(&input, kernel)?;
1636
1637        for (i, &val) in output.values.iter().enumerate() {
1638            if val.is_nan() {
1639                continue;
1640            }
1641
1642            let bits = val.to_bits();
1643
1644            if bits == 0x11111111_11111111 {
1645                panic!(
1646                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {}",
1647                    test_name, val, bits, i
1648                );
1649            }
1650
1651            if bits == 0x22222222_22222222 {
1652                panic!(
1653                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {}",
1654                    test_name, val, bits, i
1655                );
1656            }
1657
1658            if bits == 0x33333333_33333333 {
1659                panic!(
1660                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {}",
1661                    test_name, val, bits, i
1662                );
1663            }
1664        }
1665
1666        Ok(())
1667    }
1668
1669    #[cfg(not(debug_assertions))]
1670    fn check_apo_no_poison(
1671        _test_name: &str,
1672        _kernel: Kernel,
1673    ) -> Result<(), Box<dyn std::error::Error>> {
1674        Ok(())
1675    }
1676
1677    #[cfg(feature = "proptest")]
1678    #[allow(clippy::float_cmp)]
1679    fn check_apo_property(
1680        test_name: &str,
1681        kernel: Kernel,
1682    ) -> Result<(), Box<dyn std::error::Error>> {
1683        use proptest::prelude::*;
1684        skip_if_unsupported!(kernel, test_name);
1685
1686        let random_data_strat = (3usize..=20, 10usize..=50)
1687            .prop_filter("short < long", |(s, l)| s < l)
1688            .prop_flat_map(|(short_period, long_period)| {
1689                let len = long_period * 2..400;
1690                (
1691                    prop::collection::vec(
1692                        (10f64..10000f64).prop_filter("finite", |x| x.is_finite()),
1693                        len,
1694                    ),
1695                    Just(short_period),
1696                    Just(long_period),
1697                    Just("random"),
1698                )
1699            });
1700
1701        let constant_data_strat = (3usize..=20, 10usize..=50)
1702            .prop_filter("short < long", |(s, l)| s < l)
1703            .prop_flat_map(|(short_period, long_period)| {
1704                let len = long_period * 2..200;
1705                (
1706                    prop::collection::vec(Just(100.0f64), len),
1707                    Just(short_period),
1708                    Just(long_period),
1709                    Just("constant"),
1710                )
1711            });
1712
1713        let trending_data_strat = (3usize..=20, 10usize..=50)
1714            .prop_filter("short < long", |(s, l)| s < l)
1715            .prop_flat_map(|(short_period, long_period)| {
1716                let len = long_period * 2..200;
1717                (
1718                    (50..150usize).prop_flat_map(move |size| {
1719                        (0.1f64..5.0).prop_map(move |slope| {
1720                            (0..size)
1721                                .map(|i| 100.0 + slope * i as f64)
1722                                .collect::<Vec<f64>>()
1723                        })
1724                    }),
1725                    Just(short_period),
1726                    Just(long_period),
1727                    Just("trending"),
1728                )
1729            });
1730
1731        let strat = prop_oneof![random_data_strat, constant_data_strat, trending_data_strat,];
1732
1733        proptest::test_runner::TestRunner::default()
1734            .run(&strat, |(data, short_period, long_period, data_type)| {
1735                let params = ApoParams {
1736                    short_period: Some(short_period),
1737                    long_period: Some(long_period),
1738                };
1739                let input = ApoInput::from_slice(&data, params.clone());
1740
1741                let result = apo_with_kernel(&input, kernel);
1742                prop_assert!(result.is_ok(), "APO computation failed: {:?}", result);
1743
1744                let ApoOutput { values: out } = result.unwrap();
1745
1746                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");
1747
1748                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1749                if first_valid < data.len() {
1750                    prop_assert!(
1751                        out[first_valid].abs() < 1e-10,
1752                        "First APO value should be 0, got {} at index {}",
1753                        out[first_valid],
1754                        first_valid
1755                    );
1756                }
1757
1758                for i in first_valid..out.len() {
1759                    prop_assert!(
1760                        out[i].is_finite(),
1761                        "APO output at index {} should be finite, got {}",
1762                        i,
1763                        out[i]
1764                    );
1765                }
1766
1767                let data_min = data
1768                    .iter()
1769                    .filter(|x| x.is_finite())
1770                    .fold(f64::INFINITY, |a, &b| a.min(b));
1771                let data_max = data
1772                    .iter()
1773                    .filter(|x| x.is_finite())
1774                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1775                let data_range = data_max - data_min;
1776
1777                let apo_bound = data_range * 0.3;
1778
1779                for i in first_valid..out.len() {
1780                    prop_assert!(
1781                        out[i].abs() <= apo_bound,
1782                        "APO value at index {} exceeds expected bound: {} > {}",
1783                        i,
1784                        out[i].abs(),
1785                        apo_bound
1786                    );
1787                }
1788
1789                match data_type {
1790                    "constant" => {
1791                        for i in first_valid..out.len() {
1792                            prop_assert!(
1793                                out[i].abs() < 1e-9,
1794                                "APO should be ~0 for constant data, got {} at index {}",
1795                                out[i],
1796                                i
1797                            );
1798                        }
1799                    }
1800                    "trending" => {
1801                        if data.len() > long_period * 2 {
1802                            let check_start = first_valid + long_period;
1803                            let check_end = out.len();
1804                            if check_start < check_end {
1805                                let is_increasing = data[first_valid] < data[data.len() - 1];
1806
1807                                let positive_count = out[check_start..check_end]
1808                                    .iter()
1809                                    .filter(|&&v| v > 0.0)
1810                                    .count();
1811                                let total_count = check_end - check_start;
1812
1813                                if is_increasing {
1814                                    prop_assert!(
1815										positive_count > total_count / 2,
1816										"APO should be mostly positive for uptrend, got {} positive out of {}",
1817										positive_count,
1818										total_count
1819									);
1820                                } else {
1821                                    prop_assert!(
1822										positive_count < total_count / 2,
1823										"APO should be mostly negative for downtrend, got {} positive out of {}",
1824										positive_count,
1825										total_count
1826									);
1827                                }
1828                            }
1829                        }
1830                    }
1831                    _ => {}
1832                }
1833
1834                if data.len() >= 3 && first_valid + 2 < data.len() {
1835                    let alpha_short = 2.0 / (short_period as f64 + 1.0);
1836                    let alpha_long = 2.0 / (long_period as f64 + 1.0);
1837
1838                    let mut short_ema = data[first_valid];
1839                    let mut long_ema = data[first_valid];
1840                    let expected_first = 0.0;
1841                    prop_assert!(
1842                        (out[first_valid] - expected_first).abs() < 1e-9,
1843                        "First value mismatch: expected {}, got {}",
1844                        expected_first,
1845                        out[first_valid]
1846                    );
1847
1848                    if first_valid + 1 < data.len() {
1849                        let price = data[first_valid + 1];
1850                        short_ema = alpha_short * price + (1.0 - alpha_short) * short_ema;
1851                        long_ema = alpha_long * price + (1.0 - alpha_long) * long_ema;
1852                        let expected_second = short_ema - long_ema;
1853                        prop_assert!(
1854                            (out[first_valid + 1] - expected_second).abs() < 1e-9,
1855                            "Second value mismatch: expected {}, got {}",
1856                            expected_second,
1857                            out[first_valid + 1]
1858                        );
1859                    }
1860
1861                    if first_valid + 2 < data.len() {
1862                        let price = data[first_valid + 2];
1863                        short_ema = alpha_short * price + (1.0 - alpha_short) * short_ema;
1864                        long_ema = alpha_long * price + (1.0 - alpha_long) * long_ema;
1865                        let expected_third = short_ema - long_ema;
1866                        prop_assert!(
1867                            (out[first_valid + 2] - expected_third).abs() < 1e-9,
1868                            "Third value mismatch: expected {}, got {}",
1869                            expected_third,
1870                            out[first_valid + 2]
1871                        );
1872                    }
1873                }
1874
1875                let ref_output = apo_with_kernel(&input, Kernel::Scalar);
1876                prop_assert!(ref_output.is_ok(), "Reference scalar computation failed");
1877                let ApoOutput { values: ref_out } = ref_output.unwrap();
1878
1879                for (i, (&val, &ref_val)) in out.iter().zip(ref_out.iter()).enumerate() {
1880                    if !val.is_finite() || !ref_val.is_finite() {
1881                        prop_assert_eq!(
1882                            val.is_nan(),
1883                            ref_val.is_nan(),
1884                            "NaN mismatch at index {}: kernel={}, scalar={}",
1885                            i,
1886                            val,
1887                            ref_val
1888                        );
1889                    } else {
1890                        let diff = (val - ref_val).abs();
1891                        let ulp_diff = val.to_bits().abs_diff(ref_val.to_bits());
1892                        prop_assert!(
1893                            diff <= 1e-9 || ulp_diff <= 4,
1894                            "Kernel mismatch at index {}: {} vs {} (diff: {}, ULP: {})",
1895                            i,
1896                            val,
1897                            ref_val,
1898                            diff,
1899                            ulp_diff
1900                        );
1901                    }
1902                }
1903
1904                prop_assert!(
1905                    short_period < long_period,
1906                    "Short period must be less than long period"
1907                );
1908
1909                let mut stream = ApoStream::try_new(params).unwrap();
1910                let mut stream_values = Vec::new();
1911                for &price in &data {
1912                    if let Some(val) = stream.update(price) {
1913                        stream_values.push(val);
1914                    } else {
1915                        stream_values.push(f64::NAN);
1916                    }
1917                }
1918
1919                for i in first_valid..out.len() {
1920                    if out[i].is_finite() && stream_values[i].is_finite() {
1921                        let diff = (out[i] - stream_values[i]).abs();
1922                        prop_assert!(
1923                            diff < 1e-10,
1924                            "Streaming mismatch at index {}: batch={}, stream={}, diff={}",
1925                            i,
1926                            out[i],
1927                            stream_values[i],
1928                            diff
1929                        );
1930                    }
1931                }
1932
1933                Ok(())
1934            })
1935            .map_err(|e| e.into())
1936    }
1937
1938    #[cfg(not(feature = "proptest"))]
1939    fn check_apo_property(
1940        test_name: &str,
1941        kernel: Kernel,
1942    ) -> Result<(), Box<dyn std::error::Error>> {
1943        skip_if_unsupported!(kernel, test_name);
1944        Ok(())
1945    }
1946
1947    macro_rules! generate_all_apo_tests {
1948        ($($test_fn:ident),*) => {
1949            paste::paste! {
1950                $(
1951                    #[test]
1952                    fn [<$test_fn _scalar_f64>]() {
1953                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1954                    }
1955                )*
1956                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1957                $(
1958                    #[test]
1959                    fn [<$test_fn _avx2_f64>]() {
1960                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1961                    }
1962                    #[test]
1963                    fn [<$test_fn _avx512_f64>]() {
1964                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1965                    }
1966                )*
1967                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1968                $(
1969                    #[test]
1970                    fn [<$test_fn _simd128_f64>]() {
1971                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1972                    }
1973                )*
1974            }
1975        }
1976    }
1977
1978    generate_all_apo_tests!(
1979        check_apo_partial_params,
1980        check_apo_accuracy,
1981        check_apo_default_candles,
1982        check_apo_zero_period,
1983        check_apo_empty_input,
1984        check_apo_streaming,
1985        check_apo_period_invalid,
1986        check_apo_very_small_dataset,
1987        check_apo_reinput,
1988        check_apo_nan_handling,
1989        check_apo_no_poison,
1990        check_apo_property
1991    );
1992
1993    fn check_batch_default_row(
1994        test: &str,
1995        kernel: Kernel,
1996    ) -> Result<(), Box<dyn std::error::Error>> {
1997        skip_if_unsupported!(kernel, test);
1998        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1999        let c = read_candles_from_csv(file)?;
2000        let output = ApoBatchBuilder::new()
2001            .kernel(kernel)
2002            .apply_candles(&c, "close")?;
2003        let def = ApoParams::default();
2004        let row = output.values_for(&def).expect("default row missing");
2005        assert_eq!(row.len(), c.close.len());
2006        Ok(())
2007    }
2008
2009    #[cfg(debug_assertions)]
2010    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2011        skip_if_unsupported!(kernel, test);
2012
2013        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2014        let c = read_candles_from_csv(file)?;
2015
2016        let test_configs = vec![
2017            (2, 10, 2, 15, 30, 5),
2018            (5, 25, 5, 30, 50, 10),
2019            (10, 20, 5, 25, 45, 10),
2020            (12, 12, 0, 26, 26, 0),
2021            (3, 9, 3, 10, 20, 5),
2022        ];
2023
2024        for (cfg_idx, &(s_start, s_end, s_step, l_start, l_end, l_step)) in
2025            test_configs.iter().enumerate()
2026        {
2027            let output = ApoBatchBuilder::new()
2028                .kernel(kernel)
2029                .short_range(s_start, s_end, s_step)
2030                .long_range(l_start, l_end, l_step)
2031                .apply_candles(&c, "close")?;
2032
2033            for (idx, &val) in output.values.iter().enumerate() {
2034                if val.is_nan() {
2035                    continue;
2036                }
2037
2038                let bits = val.to_bits();
2039                let row = idx / output.cols;
2040                let col = idx % output.cols;
2041                let combo = &output.combos[row];
2042
2043                if bits == 0x11111111_11111111 {
2044                    panic!(
2045                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2046						 at row {} col {} (flat index {}) with params: short={}, long={}",
2047                        test,
2048                        cfg_idx,
2049                        val,
2050                        bits,
2051                        row,
2052                        col,
2053                        idx,
2054                        combo.short_period.unwrap_or(12),
2055                        combo.long_period.unwrap_or(26)
2056                    );
2057                }
2058
2059                if bits == 0x22222222_22222222 {
2060                    panic!(
2061                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2062						 at row {} col {} (flat index {}) with params: short={}, long={}",
2063                        test,
2064                        cfg_idx,
2065                        val,
2066                        bits,
2067                        row,
2068                        col,
2069                        idx,
2070                        combo.short_period.unwrap_or(12),
2071                        combo.long_period.unwrap_or(26)
2072                    );
2073                }
2074
2075                if bits == 0x33333333_33333333 {
2076                    panic!(
2077                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2078						 at row {} col {} (flat index {}) with params: short={}, long={}",
2079                        test,
2080                        cfg_idx,
2081                        val,
2082                        bits,
2083                        row,
2084                        col,
2085                        idx,
2086                        combo.short_period.unwrap_or(12),
2087                        combo.long_period.unwrap_or(26)
2088                    );
2089                }
2090            }
2091        }
2092
2093        Ok(())
2094    }
2095
2096    #[cfg(not(debug_assertions))]
2097    fn check_batch_no_poison(
2098        _test: &str,
2099        _kernel: Kernel,
2100    ) -> Result<(), Box<dyn std::error::Error>> {
2101        Ok(())
2102    }
2103
2104    macro_rules! gen_batch_tests {
2105        ($fn_name:ident) => {
2106            paste::paste! {
2107                #[test] fn [<$fn_name _scalar>]()      {
2108                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2109                }
2110                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2111                #[test] fn [<$fn_name _avx2>]()        {
2112                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2113                }
2114                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2115                #[test] fn [<$fn_name _avx512>]()      {
2116                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2117                }
2118                #[test] fn [<$fn_name _auto_detect>]() {
2119                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2120                }
2121            }
2122        };
2123    }
2124    gen_batch_tests!(check_batch_default_row);
2125    gen_batch_tests!(check_batch_no_poison);
2126
2127    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2128    #[test]
2129    fn test_apo_simd128_correctness() {
2130        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2131        let short_period = 3;
2132        let long_period = 5;
2133        let params = ApoParams {
2134            short_period: Some(short_period),
2135            long_period: Some(long_period),
2136        };
2137        let input = ApoInput::from_slice(&data, params);
2138
2139        let scalar_output = apo_with_kernel(&input, Kernel::Scalar).unwrap();
2140
2141        let mut pure_scalar_output = vec![f64::NAN; data.len()];
2142        let first = 0;
2143        unsafe {
2144            apo_scalar(
2145                &data,
2146                short_period,
2147                long_period,
2148                first,
2149                &mut pure_scalar_output,
2150            );
2151        }
2152
2153        assert_eq!(scalar_output.values.len(), pure_scalar_output.len());
2154        for (i, (simd_val, scalar_val)) in scalar_output
2155            .values
2156            .iter()
2157            .zip(pure_scalar_output.iter())
2158            .enumerate()
2159        {
2160            if scalar_val.is_nan() {
2161                assert!(simd_val.is_nan(), "SIMD128 NaN mismatch at index {}", i);
2162            } else {
2163                assert!(
2164                    (scalar_val - simd_val).abs() < 1e-10,
2165                    "SIMD128 mismatch at index {}: scalar={}, simd128={}",
2166                    i,
2167                    scalar_val,
2168                    simd_val
2169                );
2170            }
2171        }
2172    }
2173}
2174
// Python binding `apo(data, short_period=10, long_period=20, kernel=None)`.
//
// Borrows `data` as an `f64` slice, validates the kernel name, runs
// `apo_with_kernel` with the GIL released and returns the values as a NumPy
// array. An indicator error is reported as `ValueError` carrying the error's
// display text.
#[cfg(feature = "python")]
#[pyfunction(name = "apo")]
#[pyo3(signature = (data, short_period=10, long_period=20, kernel=None))]
pub fn apo_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = ApoInput::from_slice(
        prices,
        ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        },
    );

    // Only the pure-Rust computation runs without the GIL.
    let computed = py.allow_threads(|| apo_with_kernel(&input, chosen_kernel).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2202
// Python class `ApoStream`: a thin wrapper that owns a Rust `ApoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "ApoStream")]
pub struct ApoStreamPy {
    stream: ApoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl ApoStreamPy {
    // Constructor: builds the wrapped stream from the two periods. A failure
    // from `ApoStream::try_new` becomes a `ValueError` with the error's text.
    #[new]
    fn new(short_period: usize, long_period: usize) -> PyResult<Self> {
        let params = ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        };
        match ApoStream::try_new(params) {
            Ok(stream) => Ok(ApoStreamPy { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2227
// Python binding `apo_batch(data, short_period_range, long_period_range, kernel=None)`.
//
// Expands the two `(start, end, step)` ranges into a parameter grid, allocates
// one flat NumPy buffer of `rows * cols` values, lets `apo_batch_inner_into`
// fill it with the GIL released, and returns a dict with:
//   * "values"        - the buffer reshaped to `(rows, cols)`
//   * "short_periods" - short period of each row, as u64
//   * "long_periods"  - long period of each row, as u64
// Grid, overflow and indicator errors are reported as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "apo_batch")]
#[pyo3(signature = (data, short_period_range, long_period_range, kernel=None))]
pub fn apo_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let requested = validate_kernel(kernel, true)?;

    let sweep = ApoBatchRange {
        short: short_period_range,
        long: long_period_range,
    };

    // The grid is expanded up front only to size the output buffer.
    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if grid.is_empty() {
        return Err(PyValueError::new_err("No valid parameter combinations"));
    }
    let (rows, cols) = (grid.len(), prices.len());
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows * cols overflow")),
    };

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // Every row gets the same prefix length: the index of the first non-NaN
    // input (0 when there is none).
    let first = prices.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm = vec![first; rows];
    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
        core::slice::from_raw_parts_mut(
            slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
            slice_out.len(),
        )
    };
    init_matrix_prefixes(out_mu, cols, &warm);

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Map the batch kernel onto the per-row kernel passed down.
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                _ => Kernel::Scalar,
            };
            apo_batch_inner_into(prices, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let short_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.short_period.unwrap() as u64)
        .collect();
    dict.set_item("short_periods", short_periods.into_pyarray(py))?;

    let long_periods: Vec<u64> = combos
        .iter()
        .map(|p| p.long_period.unwrap() as u64)
        .collect();
    dict.set_item("long_periods", long_periods.into_pyarray(py))?;

    Ok(dict)
}
2307
2308#[cfg(all(feature = "python", feature = "cuda"))]
2309use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2310#[cfg(all(feature = "python", feature = "cuda"))]
2311use cust::context::Context as CudaContext;
2312#[cfg(all(feature = "python", feature = "cuda"))]
2313use std::sync::Arc;
2314
// Python class `DeviceArrayF32Apo`: holds an optional `f32` device array
// together with the CUDA context handle and device id it was created with.
// `inner` is emptied by `__dlpack__`, after which the array-interface getter
// reports that the buffer has been exported.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Apo", unsendable)]
pub struct DeviceArrayF32ApoPy {
    pub(crate) inner: Option<crate::cuda::moving_averages::apo_wrapper::DeviceArrayF32>,
    stream_handle: usize,
    _ctx_guard: Arc<CudaContext>,
    _device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32ApoPy {
    // Direct construction from Python is rejected with `TypeError`.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // Builds the array-interface dict (shape, "<f4" typestr, byte strides,
    // data pointer, version 3). The pointer is reported as 0 for an empty
    // array. Fails with `ValueError` once the buffer has been taken.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(array) => array,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * elem_bytes, elem_bytes))?;

        let is_empty = inner.rows.saturating_mul(inner.cols) == 0;
        let ptr_val: usize = if is_empty {
            0
        } else {
            inner.buf.as_device_ptr().as_raw() as usize
        };
        d.set_item("data", (ptr_val, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`; the type code is the constant 2
    // (presumably DLPack's CUDA device type - confirm against dlpack.h).
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self._device_id as i32))
    }

    // Hands the device buffer over to `export_f32_cuda_dlpack_2d`. May succeed
    // only once, because the buffer is moved out of `inner`.
    //
    // When `dl_device` parses as an `(i32, i32)` pair that differs from this
    // array's device, the call fails with `ValueError`; the message depends on
    // whether `copy` extracts to `true`. `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A `dl_device` that is absent or not an (i32, i32) pair is ignored.
        let requested_device = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested_device {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }
        let _ = stream;

        let inner = match self.inner.take() {
            Some(array) => array,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };
        let crate::cuda::moving_averages::apo_wrapper::DeviceArrayF32 {
            buf, rows, cols, ..
        } = inner;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2402
// Python binding `apo_cuda_batch_dev(data_f32, short_range, long_range, device_id)`.
//
// Runs the CUDA batch sweep with the GIL released and wraps the resulting
// device array in a `DeviceArrayF32ApoPy`. Fails with `ValueError` when CUDA
// is unavailable or when creating the CUDA object / running the sweep fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "apo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, short_range=(10,10,0), long_range=(20,20,0), device_id=0))]
pub fn apo_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    short_range: (usize, usize, usize),
    long_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32ApoPy> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaApo;

    // Turns any displayable error into a Python `ValueError`.
    fn py_err<E: std::fmt::Display>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = ApoBatchRange {
        short: short_range,
        long: long_range,
    };

    let device_array = py.allow_threads(|| {
        let cuda = CudaApo::new(device_id).map_err(py_err)?;
        cuda.apo_batch_dev(prices, &sweep).map_err(py_err)
    })?;

    // Read the context and device id before the array is moved into the wrapper.
    let ctx = device_array.ctx();
    let dev_id = device_array.device_id();
    Ok(DeviceArrayF32ApoPy {
        inner: Some(device_array),
        stream_handle: 0,
        _ctx_guard: ctx,
        _device_id: dev_id,
    })
}
2437
// Python binding
// `apo_cuda_many_series_one_param_dev(data_tm_f32, short_period, long_period, device_id)`.
//
// Validates the periods (both non-zero, short strictly below long), then runs
// the time-major many-series CUDA routine with the GIL released and wraps the
// result in a `DeviceArrayF32ApoPy`. The 2-D input's first dimension is passed
// as `rows` and the second as `cols`. All failures surface as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "apo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, short_period, long_period, device_id=0))]
pub fn apo_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    short_period: usize,
    long_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32ApoPy> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaApo;
    use numpy::PyUntypedArrayMethods;

    // Turns any displayable error into a Python `ValueError`.
    fn py_err<E: std::fmt::Display>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let periods_ok = short_period > 0 && long_period > 0 && short_period < long_period;
    if !periods_ok {
        return Err(PyValueError::new_err("invalid short/long period"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = ApoParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
    };

    let device_array = py.allow_threads(|| {
        let cuda = CudaApo::new(device_id).map_err(py_err)?;
        cuda.apo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(py_err)
    })?;

    // Read the context and device id before the array is moved into the wrapper.
    let ctx = device_array.ctx();
    let dev_id = device_array.device_id();
    Ok(DeviceArrayF32ApoPy {
        inner: Some(device_array),
        stream_handle: 0,
        _ctx_guard: ctx,
        _device_id: dev_id,
    })
}
2478
// WASM binding: computes APO for `data` with the given periods and returns a
// freshly allocated vector of the same length. Indicator errors are returned
// as a `JsValue` string holding the error's display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_js(data: &[f64], short_period: usize, long_period: usize) -> Result<Vec<f64>, JsValue> {
    let input = ApoInput::from_slice(
        data,
        ApoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match apo_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2495
// WASM helper: reserves room for `len` `f64` values and returns the raw
// pointer. The backing `Vec` is never dropped here, so the allocation stays
// alive after the call and must be released by the caller.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor, leaking the buffer on purpose.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2504
// WASM helper: releases a buffer of `len` `f64` values by rebuilding the
// owning `Vec` and dropping it.
//
// `ptr` must be null or a pointer obtained from `apo_alloc` called with the
// same `len`. A null pointer is ignored; previously it was handed straight to
// `Vec::from_raw_parts`, which is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is non-null and, per the contract above, came from a
    // `Vec<f64>` allocated with capacity `len`.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2512
// WASM binding: computes APO over `len` values read from `in_ptr` and writes
// `len` results to `out_ptr`.
//
// Null pointers are rejected with an error string. When both pointers are
// equal the result is computed into a temporary buffer and copied back
// afterwards, so the input is not overwritten while it is still being read.
// Both pointers are expected to reference at least `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period: usize,
    long_period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to apo_into"));
    }
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = ApoInput::from_slice(
            data,
            ApoParams {
                short_period: Some(short_period),
                long_period: Some(long_period),
            },
        );

        let in_place = in_ptr == out_ptr;
        if in_place {
            // Same buffer for input and output: go through a scratch vector.
            let mut scratch = vec![0.0; len];
            apo_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            apo_into_slice(dst, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2547
// WASM binding: runs a batch sweep over the given short/long
// `(start, end, step)` ranges with `Kernel::Scalar` and returns the flat value
// buffer of the batch output. Errors are returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_js(
    data: &[f64],
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let short = (short_period_start, short_period_end, short_period_step);
    let long = (long_period_start, long_period_end, long_period_step);
    let sweep = ApoBatchRange { short, long };

    match apo_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2568
// WASM binding: expands the short/long `(start, end, step)` ranges into the
// parameter grid and returns it flattened as
// `[short_0, long_0, short_1, long_1, ...]`, each period converted to `f64`.
// Grid expansion errors are returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_metadata_js(
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let short = (short_period_start, short_period_end, short_period_step);
    let long = (long_period_start, long_period_end, long_period_step);
    let sweep = ApoBatchRange { short, long };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Two entries per combination: short period first, then long period.
    let mut metadata = Vec::with_capacity(combos.len() * 2);
    metadata.extend(combos.into_iter().flat_map(|combo| {
        [
            combo.short_period.unwrap() as f64,
            combo.long_period.unwrap() as f64,
        ]
    }));

    Ok(metadata)
}
2594
// WASM binding: runs a batch sweep over `len` values read from `in_ptr` and
// writes the row-major result (`rows * len` values) to `out_ptr`.
//
// Returns the number of rows, i.e. the number of parameter combinations in
// the expanded grid. Null pointers, grid errors, an overflowing output size
// and indicator errors are all returned as a `JsValue` string.
//
// `in_ptr` is expected to reference `len` values and `out_ptr` to have room
// for `rows * len` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn apo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_start: usize,
    short_end: usize,
    short_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to apo_batch_into"));
    }
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let sweep = ApoBatchRange {
            short: (short_start, short_end, short_step),
            long: (long_start, long_end, long_step),
        };
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;

        // A wrapped product would describe an output slice shorter than the
        // data the sweep produces, so refuse instead of multiplying unchecked
        // (usize is 32 bits wide on wasm32).
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        apo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}
2627
// Configuration object accepted by the JS `apo_batch` entry point: one
// `(start, end, step)` triple per period.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ApoBatchConfig {
    pub short_period_range: (usize, usize, usize),
    pub long_period_range: (usize, usize, usize),
}

// Result object returned by the JS `apo_batch` entry point: the flat value
// buffer, the parameter combination of every row, and the matrix dimensions.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct ApoBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<ApoParams>,
    pub rows: usize,
    pub cols: usize,
}

// WASM binding exported to JS as `apo_batch`: deserializes `config` into an
// `ApoBatchConfig`, runs the batch sweep with the kernel chosen by
// `detect_best_kernel()`, and serializes an `ApoBatchJsOutput` back to JS.
// Config, indicator and serialization errors are returned as `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = apo_batch)]
pub fn apo_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let ApoBatchConfig {
        short_period_range,
        long_period_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = ApoBatchRange {
        short: short_period_range,
        long: long_period_range,
    };
    let batch = apo_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = ApoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}