Skip to main content

vector_ta/indicators/
aso.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::convert::AsRef;
32use std::error::Error;
33use std::mem::MaybeUninit;
34use thiserror::Error;
35
36#[derive(Debug, Clone)]
37pub enum AsoData<'a> {
38    Candles {
39        candles: &'a Candles,
40        source: &'a str,
41    },
42    Slices {
43        open: &'a [f64],
44        high: &'a [f64],
45        low: &'a [f64],
46        close: &'a [f64],
47    },
48}
49
/// Result of a single-series ASO run.
///
/// Both vectors have the same length as the input close series. Entries before the
/// warmup point (`first non-NaN close + period - 1`) are the NaN prefix written at
/// allocation time; the rest are produced by the kernel.
#[derive(Debug, Clone)]
pub struct AsoOutput {
    /// Bullish sentiment series.
    pub bulls: Vec<f64>,
    /// Bearish sentiment series.
    pub bears: Vec<f64>,
}
55
/// Tunable ASO parameters.
///
/// `None` fields fall back to defaults when read through `AsoInput::get_period` (10)
/// and `AsoInput::get_mode` (0).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AsoParams {
    /// Window length in bars. The single-series path rejects 0 and values longer than
    /// the data.
    pub period: Option<usize>,
    /// Output mode: 0 = mean of the intrabar and group scores, 1 = intrabar only,
    /// 2 = group only. The single-series path rejects values above 2.
    pub mode: Option<usize>,
}
65
66impl Default for AsoParams {
67    fn default() -> Self {
68        Self {
69            period: Some(10),
70            mode: Some(0),
71        }
72    }
73}
74
/// Input bundle for the ASO indicator: the price data plus its parameters.
#[derive(Debug, Clone)]
pub struct AsoInput<'a> {
    /// Where the OHLC series come from.
    pub data: AsoData<'a>,
    /// Period/mode configuration; unset fields use the defaults in `get_period`/`get_mode`.
    pub params: AsoParams,
}
80
81impl<'a> AsRef<[f64]> for AsoInput<'a> {
82    fn as_ref(&self) -> &[f64] {
83        match &self.data {
84            AsoData::Candles { candles, source } => source_type(candles, source),
85            AsoData::Slices { close, .. } => close,
86        }
87    }
88}
89
90impl<'a> AsoInput<'a> {
91    #[inline]
92    pub fn from_candles(c: &'a Candles, s: &'a str, p: AsoParams) -> Self {
93        Self {
94            data: AsoData::Candles {
95                candles: c,
96                source: s,
97            },
98            params: p,
99        }
100    }
101
102    #[inline]
103    pub fn from_slices(
104        open: &'a [f64],
105        high: &'a [f64],
106        low: &'a [f64],
107        close: &'a [f64],
108        p: AsoParams,
109    ) -> Self {
110        Self {
111            data: AsoData::Slices {
112                open,
113                high,
114                low,
115                close,
116            },
117            params: p,
118        }
119    }
120
121    #[inline]
122    pub fn with_default_candles(c: &'a Candles) -> Self {
123        Self::from_candles(c, "close", AsoParams::default())
124    }
125
126    #[inline]
127    pub fn get_period(&self) -> usize {
128        self.params.period.unwrap_or(10)
129    }
130
131    #[inline]
132    pub fn get_mode(&self) -> usize {
133        self.params.mode.unwrap_or(0)
134    }
135}
136
/// Fluent builder for single-series ASO runs and streams.
///
/// Unset `period`/`mode` stay `None` and fall back to the `AsoInput` defaults.
#[derive(Copy, Clone, Debug)]
pub struct AsoBuilder {
    period: Option<usize>,
    mode: Option<usize>,
    // Kernel forwarded to `aso_with_kernel`; `Kernel::Auto` by default.
    kernel: Kernel,
}
143
144impl Default for AsoBuilder {
145    fn default() -> Self {
146        Self {
147            period: None,
148            mode: None,
149            kernel: Kernel::Auto,
150        }
151    }
152}
153
154impl AsoBuilder {
155    #[inline(always)]
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    #[inline(always)]
161    pub fn period(mut self, val: usize) -> Self {
162        self.period = Some(val);
163        self
164    }
165
166    #[inline(always)]
167    pub fn mode(mut self, val: usize) -> Self {
168        self.mode = Some(val);
169        self
170    }
171
172    #[inline(always)]
173    pub fn kernel(mut self, k: Kernel) -> Self {
174        self.kernel = k;
175        self
176    }
177
178    #[inline(always)]
179    pub fn apply(self, c: &Candles) -> Result<AsoOutput, AsoError> {
180        self.apply_candles(c, "close")
181    }
182
183    #[inline(always)]
184    pub fn apply_candles(self, c: &Candles, s: &str) -> Result<AsoOutput, AsoError> {
185        let p = AsoParams {
186            period: self.period,
187            mode: self.mode,
188        };
189        let i = AsoInput::from_candles(c, s, p);
190        aso_with_kernel(&i, self.kernel)
191    }
192
193    #[inline(always)]
194    pub fn apply_slices(
195        self,
196        open: &[f64],
197        high: &[f64],
198        low: &[f64],
199        close: &[f64],
200    ) -> Result<AsoOutput, AsoError> {
201        let p = AsoParams {
202            period: self.period,
203            mode: self.mode,
204        };
205        let i = AsoInput::from_slices(open, high, low, close, p);
206        aso_with_kernel(&i, self.kernel)
207    }
208
209    #[inline(always)]
210    pub fn into_stream(self) -> Result<AsoStream, AsoError> {
211        let p = AsoParams {
212            period: self.period,
213            mode: self.mode,
214        };
215        AsoStream::try_new(p)
216    }
217}
218
/// Errors returned by the ASO indicator functions.
#[derive(Debug, Error)]
pub enum AsoError {
    /// The close series has no elements.
    #[error("aso: Input data slice is empty.")]
    EmptyInputData,

    /// Every close value is NaN, so no first valid bar exists.
    #[error("aso: All values are NaN.")]
    AllValuesNaN,

    /// `period` is 0 or longer than the data.
    #[error("aso: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` bars remain from the first non-NaN close onwards.
    #[error("aso: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// `mode` is greater than 2.
    #[error("aso: Invalid mode: mode = {mode}, must be 0, 1, or 2")]
    InvalidMode { mode: usize },

    /// The open/high/low slices do not all match the close length.
    #[error("aso: Required OHLC data is missing or has mismatched lengths")]
    MissingData,

    /// Caller-provided output slices do not match the input length.
    /// `got` is the shorter of the two output lengths.
    #[error("aso: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep axis produced no values, or the combination count overflowed.
    /// In the overflow case `start`/`end` hold the two axis lengths and `step` is 0.
    #[error("aso: Invalid range: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A batch entry point received a kernel that is neither `Auto` nor a `*Batch` kernel.
    #[error("aso: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
252
/// Computes ASO bulls/bears for `input` with automatic kernel selection.
///
/// Equivalent to `aso_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn aso(input: &AsoInput) -> Result<AsoOutput, AsoError> {
    aso_with_kernel(input, Kernel::Auto)
}
257
258pub fn aso_with_kernel(input: &AsoInput, kernel: Kernel) -> Result<AsoOutput, AsoError> {
259    let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, kernel)?;
260
261    let len = close.len();
262
263    let mut bulls = alloc_with_nan_prefix(len, first + period - 1);
264    let mut bears = alloc_with_nan_prefix(len, first + period - 1);
265
266    aso_compute_into(
267        open, high, low, close, period, mode, first, chosen, &mut bulls, &mut bears,
268    );
269
270    Ok(AsoOutput { bulls, bears })
271}
272
273#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
274#[inline]
275pub fn aso_into(
276    input: &AsoInput,
277    bulls_out: &mut [f64],
278    bears_out: &mut [f64],
279) -> Result<(), AsoError> {
280    let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, Kernel::Auto)?;
281
282    if bulls_out.len() != close.len() || bears_out.len() != close.len() {
283        return Err(AsoError::OutputLengthMismatch {
284            expected: close.len(),
285            got: bulls_out.len().min(bears_out.len()),
286        });
287    }
288
289    let warm = first + period - 1;
290    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
291    let warm = warm.min(close.len());
292    for v in &mut bulls_out[..warm] {
293        *v = qnan;
294    }
295    for v in &mut bears_out[..warm] {
296        *v = qnan;
297    }
298
299    aso_compute_into(
300        open, high, low, close, period, mode, first, chosen, bulls_out, bears_out,
301    );
302
303    Ok(())
304}
305
306#[inline]
307pub fn aso_into_slices(
308    bulls_dst: &mut [f64],
309    bears_dst: &mut [f64],
310    input: &AsoInput,
311    kern: Kernel,
312) -> Result<(), AsoError> {
313    let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, kern)?;
314
315    if bulls_dst.len() != close.len() || bears_dst.len() != close.len() {
316        return Err(AsoError::OutputLengthMismatch {
317            expected: close.len(),
318            got: bulls_dst.len().min(bears_dst.len()),
319        });
320    }
321
322    aso_compute_into(
323        open, high, low, close, period, mode, first, chosen, bulls_dst, bears_dst,
324    );
325
326    let warm = first + period - 1;
327    for v in &mut bulls_dst[..warm] {
328        *v = f64::NAN;
329    }
330    for v in &mut bears_dst[..warm] {
331        *v = f64::NAN;
332    }
333
334    Ok(())
335}
336
337#[inline(always)]
338fn aso_prepare<'a>(
339    input: &'a AsoInput,
340    kernel: Kernel,
341) -> Result<
342    (
343        &'a [f64],
344        &'a [f64],
345        &'a [f64],
346        &'a [f64],
347        usize,
348        usize,
349        usize,
350        Kernel,
351    ),
352    AsoError,
353> {
354    let (open, high, low, close) = match &input.data {
355        AsoData::Candles { candles: c, .. } => (&c.open[..], &c.high[..], &c.low[..], &c.close[..]),
356        AsoData::Slices {
357            open,
358            high,
359            low,
360            close,
361        } => (*open, *high, *low, *close),
362    };
363
364    let len = close.len();
365    if len == 0 {
366        return Err(AsoError::EmptyInputData);
367    }
368
369    if open.len() != len || high.len() != len || low.len() != len {
370        return Err(AsoError::MissingData);
371    }
372
373    let first = close
374        .iter()
375        .position(|x| !x.is_nan())
376        .ok_or(AsoError::AllValuesNaN)?;
377
378    let period = input.get_period();
379    let mode = input.get_mode();
380
381    if period == 0 || period > len {
382        return Err(AsoError::InvalidPeriod {
383            period,
384            data_len: len,
385        });
386    }
387
388    if mode > 2 {
389        return Err(AsoError::InvalidMode { mode });
390    }
391
392    if len - first < period {
393        return Err(AsoError::NotEnoughValidData {
394            needed: period,
395            valid: len - first,
396        });
397    }
398
399    let mut chosen = match kernel {
400        Kernel::Auto => detect_best_kernel(),
401        k => k,
402    };
403
404    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
405    if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
406        chosen = Kernel::Avx2;
407    }
408
409    Ok((open, high, low, close, period, mode, first, chosen))
410}
411
/// Dispatches to the ASO kernel selected by `kernel`.
///
/// The kernel fills `out_bulls` / `out_bears` from index `first + period - 1` onward.
/// Kernel mapping:
/// * wasm32 with `simd128`: `Scalar`/`ScalarBatch` go to `aso_simd128`.
/// * x86_64 with `nightly-avx`: AVX2 / AVX-512 kernels go to `aso_avx2` / `aso_avx512`.
/// * otherwise: every AVX kernel falls back to `aso_scalar`.
///
/// Any other kernel value (e.g. an unresolved `Auto`) hits `unreachable!()`. Callers pass
/// the kernel already resolved by `aso_prepare`.
#[inline(always)]
fn aso_compute_into(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first: usize,
    kernel: Kernel,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    // Wraps the whole dispatch because the SIMD kernels are `unsafe fn`s.
    // NOTE(review): the AVX kernels assume the CPU supports the enabled features.
    unsafe {
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                aso_simd128(
                    open, high, low, close, period, mode, first, out_bulls, out_bears,
                );
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => aso_scalar(
                open, high, low, close, period, mode, first, out_bulls, out_bears,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => aso_avx2(
                open, high, low, close, period, mode, first, out_bulls, out_bears,
            ),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => aso_avx512(
                open, high, low, close, period, mode, first, out_bulls, out_bears,
            ),
            // Without nightly AVX support every AVX request degrades to the scalar kernel.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => aso_scalar(
                open, high, low, close, period, mode, first, out_bulls, out_bears,
            ),
            _ => unreachable!(),
        }
    }
}
456
/// Combines an intrabar score and a group score according to the ASO `mode`.
///
/// Mode `1` selects the intrabar score and `2` the group score. Every other value,
/// including the default `0`, gives their mean. Both window strategies in `aso_scalar`
/// use this, so they agree for every `mode`.
#[inline(always)]
fn aso_blend_by_mode(mode: usize, intrabar: f64, group: f64) -> f64 {
    match mode {
        1 => intrabar,
        2 => group,
        _ => 0.5 * (intrabar + group),
    }
}

/// Scalar kernel for the ASO bulls/bears series.
///
/// For every bar `i >= first_val + period - 1` this computes two scores:
/// * an intrabar score from the bar's own open/high/low/close, scaled by `50 / (high - low)`
///   (the range is replaced by 1.0 when it is zero);
/// * a group score over the window `[i + 1 - period, i]`, using the window's lowest low,
///   highest high and the open of its first bar, scaled by the window range the same way.
///
/// The two scores are combined per `mode` (1 = intrabar, 2 = group, anything else =
/// mean). The kernel then writes the running mean of the most recent
/// `min(period, bars computed so far)` combined values to `out_bulls[i]` / `out_bears[i]`.
///
/// Indices below `first_val + period - 1` are not written; callers pre-fill them.
/// Windows of up to 64 bars are rescanned on each bar. Longer windows use monotonic
/// index deques for amortised O(1) min/max.
///
/// A `period` of 0 (or an empty `close`) is a no-op.
///
/// # Panics
/// Panics if `open`, `high`, `low`, `out_bulls` or `out_bears` is shorter than `close`.
/// The loops access them without bounds checks for every index below `close.len()`.
#[inline]
pub fn aso_scalar(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    let len = close.len();
    // `period == 0` would underflow `warm` below and divide by zero in the ring-index math.
    if len == 0 || period == 0 {
        return;
    }

    // Enforce the length contract the `get_unchecked` accesses below rely on.
    assert!(
        open.len() >= len && high.len() >= len && low.len() >= len,
        "aso_scalar: open/high/low must be at least as long as close"
    );
    assert!(
        out_bulls.len() >= len && out_bears.len() >= len,
        "aso_scalar: output slices must be at least as long as close"
    );

    let warm = first_val + period - 1;

    const DEQUE_THRESHOLD: usize = 64;
    if period <= DEQUE_THRESHOLD {
        // Short windows: rescan the window for its extremes on every bar.
        let mut ring_b = vec![0.0; period];
        let mut ring_e = vec![0.0; period];
        let mut sum_b = 0.0;
        let mut sum_e = 0.0;
        let mut head = 0usize;
        let mut filled = 0usize;

        for i in first_val..len {
            let intrarange = high[i] - low[i];
            let k1 = if intrarange == 0.0 { 1.0 } else { intrarange };
            let intrabarbulls = (((close[i] - low[i]) + (high[i] - open[i])) * 50.0) / k1;
            let intrabarbears = (((high[i] - close[i]) + (open[i] - low[i])) * 50.0) / k1;

            if i >= warm {
                let start = i + 1 - period;

                let mut gl = f64::MAX;
                let mut gh = f64::MIN;
                for j in start..=i {
                    // SAFETY: start <= j <= i < len <= low.len(), high.len() (asserted above).
                    let lj = unsafe { *low.get_unchecked(j) };
                    let hj = unsafe { *high.get_unchecked(j) };
                    if lj < gl {
                        gl = lj;
                    }
                    if hj > gh {
                        gh = hj;
                    }
                }
                // SAFETY: start <= i < len <= open.len() (asserted above).
                let gopen = unsafe { *open.get_unchecked(start) };
                let gr = gh - gl;
                let k2 = if gr == 0.0 { 1.0 } else { gr };

                let groupbulls = (((close[i] - gl) + (gh - gopen)) * 50.0) / k2;
                let groupbears = (((gh - close[i]) + (gopen - gl)) * 50.0) / k2;

                let b = aso_blend_by_mode(mode, intrabarbulls, groupbulls);
                let e = aso_blend_by_mode(mode, intrabarbears, groupbears);

                // Running mean over the last `period` combined values via a ring buffer.
                let old_b = if filled == period { ring_b[head] } else { 0.0 };
                let old_e = if filled == period { ring_e[head] } else { 0.0 };
                sum_b += b - old_b;
                sum_e += e - old_e;
                ring_b[head] = b;
                ring_e[head] = e;
                head = (head + 1) % period;
                if filled < period {
                    filled += 1;
                }

                let n = filled;
                // SAFETY: i < len <= out_bulls.len(), out_bears.len() (asserted above).
                unsafe {
                    *out_bulls.get_unchecked_mut(i) = sum_b / n as f64;
                    *out_bears.get_unchecked_mut(i) = sum_e / n as f64;
                }
            }
        }
        return;
    }

    // Long windows: monotonic deques (circular buffers of indices) track the window
    // minimum of `low` and maximum of `high`.
    let mut ring_b = vec![0.0f64; period];
    let mut ring_e = vec![0.0f64; period];
    let mut sum_b = 0.0f64;
    let mut sum_e = 0.0f64;
    let mut rhead = 0usize;
    let mut filled = 0usize;

    let mut dq_min = vec![0usize; period];
    let mut dq_max = vec![0usize; period];
    let mut min_head = 0usize;
    let mut min_tail = 0usize;
    let mut min_len = 0usize;
    let mut max_head = 0usize;
    let mut max_tail = 0usize;
    let mut max_len = 0usize;

    for i in first_val..len {
        // SAFETY: i < len and every input slice has at least len elements (asserted above).
        let oi = unsafe { *open.get_unchecked(i) };
        let hi = unsafe { *high.get_unchecked(i) };
        let li = unsafe { *low.get_unchecked(i) };
        let ci = unsafe { *close.get_unchecked(i) };

        // Pop indices whose low is >= the new low; they can never be the window minimum.
        while min_len > 0 {
            let back = if min_tail == 0 {
                period - 1
            } else {
                min_tail - 1
            };
            // SAFETY: back < period == dq_min.len(); stored indices are < i < len.
            let j = unsafe { *dq_min.get_unchecked(back) };
            let lj = unsafe { *low.get_unchecked(j) };
            if li <= lj {
                min_tail = back;
                min_len -= 1;
            } else {
                break;
            }
        }
        if min_len == period {
            min_head += 1;
            if min_head == period {
                min_head = 0;
            }
            min_len -= 1;
        }
        // SAFETY: min_tail < period == dq_min.len().
        unsafe {
            *dq_min.get_unchecked_mut(min_tail) = i;
        }
        min_tail += 1;
        if min_tail == period {
            min_tail = 0;
        }
        min_len += 1;

        // Pop indices whose high is <= the new high; they can never be the window maximum.
        while max_len > 0 {
            let back = if max_tail == 0 {
                period - 1
            } else {
                max_tail - 1
            };
            // SAFETY: back < period == dq_max.len(); stored indices are < i < len.
            let j = unsafe { *dq_max.get_unchecked(back) };
            let hj = unsafe { *high.get_unchecked(j) };
            if hi >= hj {
                max_tail = back;
                max_len -= 1;
            } else {
                break;
            }
        }
        if max_len == period {
            max_head += 1;
            if max_head == period {
                max_head = 0;
            }
            max_len -= 1;
        }
        // SAFETY: max_tail < period == dq_max.len().
        unsafe {
            *dq_max.get_unchecked_mut(max_tail) = i;
        }
        max_tail += 1;
        if max_tail == period {
            max_tail = 0;
        }
        max_len += 1;

        if i >= warm {
            let start = i + 1 - period;

            // Evict indices that have slid out of the window [start, i].
            while min_len > 0 && unsafe { *dq_min.get_unchecked(min_head) } < start {
                min_head += 1;
                if min_head == period {
                    min_head = 0;
                }
                min_len -= 1;
            }
            while max_len > 0 && unsafe { *dq_max.get_unchecked(max_head) } < start {
                max_head += 1;
                if max_head == period {
                    max_head = 0;
                }
                max_len -= 1;
            }

            debug_assert!(min_len > 0 && max_len > 0);
            // SAFETY: deque heads are < period and hold indices in [start, i] < len.
            let gl = unsafe {
                let idx = *dq_min.get_unchecked(min_head);
                *low.get_unchecked(idx)
            };
            let gh = unsafe {
                let idx = *dq_max.get_unchecked(max_head);
                *high.get_unchecked(idx)
            };
            // SAFETY: start <= i < len <= open.len().
            let gopen = unsafe { *open.get_unchecked(start) };

            let intrarange = hi - li;
            let inv_k1 = if intrarange != 0.0 {
                1.0 / intrarange
            } else {
                1.0
            };
            let scale1 = 50.0 * inv_k1;
            let intrabarbulls = ((ci - li) + (hi - oi)) * scale1;
            let intrabarbears = ((hi - ci) + (oi - li)) * scale1;

            let gr = gh - gl;
            let inv_k2 = if gr != 0.0 { 1.0 / gr } else { 1.0 };
            let scale2 = 50.0 * inv_k2;
            let groupbulls = ((ci - gl) + (gh - gopen)) * scale2;
            let groupbears = ((gh - ci) + (gopen - gl)) * scale2;

            let b = aso_blend_by_mode(mode, intrabarbulls, groupbulls);
            let e = aso_blend_by_mode(mode, intrabarbears, groupbears);

            // SAFETY: rhead < period == ring_b.len() == ring_e.len().
            let old_b = if filled == period {
                unsafe { *ring_b.get_unchecked(rhead) }
            } else {
                0.0
            };
            let old_e = if filled == period {
                unsafe { *ring_e.get_unchecked(rhead) }
            } else {
                0.0
            };
            sum_b += b - old_b;
            sum_e += e - old_e;
            unsafe {
                *ring_b.get_unchecked_mut(rhead) = b;
                *ring_e.get_unchecked_mut(rhead) = e;
            }
            rhead += 1;
            if rhead == period {
                rhead = 0;
            }
            if filled < period {
                filled += 1;
            }

            let n = filled as f64;
            // SAFETY: i < len <= out_bulls.len(), out_bears.len() (asserted above).
            unsafe {
                *out_bulls.get_unchecked_mut(i) = sum_b / n;
                *out_bears.get_unchecked_mut(i) = sum_e / n;
            }
        }
    }
}
723
/// WebAssembly SIMD128 entry point for the ASO kernel.
///
/// This currently delegates to `aso_scalar`; there is no vectorised path yet. Bring
/// `core::arch::wasm32` intrinsics into scope only once one is written, so simd128
/// builds stay free of `unused_imports` warnings.
///
/// # Safety
/// Declared `unsafe` like the AVX kernels. The current body performs no unsafe
/// operations itself.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn aso_simd128(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(
        open, high, low, close, period, mode, first_val, out_bulls, out_bears,
    );
}
743
/// AVX2/FMA entry point for the ASO kernel.
///
/// The body simply calls `aso_scalar`. The function is compiled with `avx2,fma` enabled
/// via `#[target_feature]`, so the compiler may use those instructions for inlined code.
///
/// # Safety
/// The executing CPU must support AVX2 and FMA, as required by `#[target_feature]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn aso_avx2(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(
        open, high, low, close, period, mode, first_val, out_bulls, out_bears,
    );
}
761
/// AVX-512F/FMA entry point for the ASO kernel.
///
/// The body simply calls `aso_scalar`. The function is compiled with `avx512f,fma`
/// enabled via `#[target_feature]`, so the compiler may use those instructions for
/// inlined code.
///
/// # Safety
/// The executing CPU must support AVX-512F and FMA, as required by `#[target_feature]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn aso_avx512(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(
        open, high, low, close, period, mode, first_val, out_bulls, out_bears,
    );
}
779
/// Parameter sweep for batch ASO runs.
///
/// Each axis is an inclusive `(start, end, step)` triple:
/// * `step == 0` or `start == end` yields just `start`;
/// * `start > end` walks downward.
#[derive(Clone, Debug)]
pub struct AsoBatchRange {
    /// Window lengths to evaluate.
    pub period: (usize, usize, usize),
    /// Output modes to evaluate.
    pub mode: (usize, usize, usize),
}
785
786impl Default for AsoBatchRange {
787    fn default() -> Self {
788        Self {
789            period: (10, 259, 1),
790            mode: (0, 0, 0),
791        }
792    }
793}
794
/// Fluent builder for batch ASO runs over a parameter sweep.
#[derive(Clone, Debug, Default)]
pub struct AsoBatchBuilder {
    // Sweep to expand into (period, mode) combinations.
    range: AsoBatchRange,
    // Requested kernel; single-series kernels are promoted to batch kernels on apply.
    kernel: Kernel,
}
800
801impl AsoBatchBuilder {
802    pub fn new() -> Self {
803        Self::default()
804    }
805
806    pub fn kernel(mut self, k: Kernel) -> Self {
807        self.kernel = k;
808        self
809    }
810
811    pub fn period_range(mut self, s: usize, e: usize, st: usize) -> Self {
812        self.range.period = (s, e, st);
813        self
814    }
815
816    pub fn period_static(mut self, p: usize) -> Self {
817        self.range.period = (p, p, 1);
818        self
819    }
820
821    pub fn mode_range(mut self, s: usize, e: usize, st: usize) -> Self {
822        self.range.mode = (s, e, st);
823        self
824    }
825
826    pub fn mode_static(mut self, m: usize) -> Self {
827        self.range.mode = (m, m, 1);
828        self
829    }
830
831    pub fn apply_candles(self, c: &Candles) -> Result<AsoBatchOutput, AsoError> {
832        let k = match self.kernel {
833            Kernel::Scalar => Kernel::ScalarBatch,
834            Kernel::Avx2 => Kernel::Avx2Batch,
835            Kernel::Avx512 => Kernel::Avx512Batch,
836            other => other,
837        };
838        aso_batch_with_kernel(&c.open, &c.high, &c.low, &c.close, &self.range, k)
839    }
840
841    pub fn apply_slices(
842        self,
843        o: &[f64],
844        h: &[f64],
845        l: &[f64],
846        c: &[f64],
847    ) -> Result<AsoBatchOutput, AsoError> {
848        let k = match self.kernel {
849            Kernel::Scalar => Kernel::ScalarBatch,
850            Kernel::Avx2 => Kernel::Avx2Batch,
851            Kernel::Avx512 => Kernel::Avx512Batch,
852            other => other,
853        };
854        aso_batch_with_kernel(o, h, l, c, &self.range, k)
855    }
856
857    pub fn with_default_candles(c: &Candles) -> Result<AsoBatchOutput, AsoError> {
858        Self::default().apply_candles(c)
859    }
860
861    pub fn with_default_slices(
862        o: &[f64],
863        h: &[f64],
864        l: &[f64],
865        c: &[f64],
866        k: Kernel,
867    ) -> Result<AsoBatchOutput, AsoError> {
868        Self::new().kernel(k).apply_slices(o, h, l, c)
869    }
870}
871
/// Result of a batch ASO run.
///
/// `bulls` and `bears` are row-major `rows × cols` matrices. Row `r` holds the series
/// computed with `combos[r]`.
#[derive(Clone, Debug)]
pub struct AsoBatchOutput {
    /// Bullish series, one row per combination.
    pub bulls: Vec<f64>,
    /// Bearish series, one row per combination.
    pub bears: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<AsoParams>,
    /// Number of combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
880
881impl AsoBatchOutput {
882    #[inline]
883    pub fn bulls_row(&self, row: usize) -> &[f64] {
884        let s = row * self.cols;
885        &self.bulls[s..s + self.cols]
886    }
887
888    #[inline]
889    pub fn bears_row(&self, row: usize) -> &[f64] {
890        let s = row * self.cols;
891        &self.bears[s..s + self.cols]
892    }
893
894    #[inline]
895    pub fn row_for_params(&self, p: &AsoParams) -> Option<usize> {
896        self.combos
897            .iter()
898            .position(|c| c.period == p.period && c.mode == p.mode)
899    }
900
901    #[inline]
902    pub fn values_for(&self, p: &AsoParams) -> Option<(&[f64], &[f64])> {
903        self.row_for_params(p)
904            .map(|row| (self.bulls_row(row), self.bears_row(row)))
905    }
906}
907
908#[inline(always)]
909fn expand_grid_aso(r: &AsoBatchRange) -> Result<Vec<AsoParams>, AsoError> {
910    fn axis_usize((s, e, st): (usize, usize, usize)) -> Result<Vec<usize>, AsoError> {
911        if st == 0 || s == e {
912            return Ok(vec![s]);
913        }
914        let mut v = Vec::new();
915        if s < e {
916            let mut cur = s;
917            while cur <= e {
918                v.push(cur);
919                let next = cur.saturating_add(st);
920                if next == cur {
921                    break;
922                }
923                cur = next;
924            }
925        } else {
926            let mut cur = s;
927            while cur >= e {
928                v.push(cur);
929                let next = cur.saturating_sub(st);
930                if next == cur {
931                    break;
932                }
933                cur = next;
934                if cur == 0 && e > 0 {
935                    break;
936                }
937            }
938        }
939        if v.is_empty() {
940            return Err(AsoError::InvalidRange {
941                start: s,
942                end: e,
943                step: st,
944            });
945        }
946        Ok(v)
947    }
948
949    let ps = axis_usize(r.period)?;
950    let ms = axis_usize(r.mode)?;
951    let total = ps
952        .len()
953        .checked_mul(ms.len())
954        .ok_or(AsoError::InvalidRange {
955            start: ps.len(),
956            end: ms.len(),
957            step: 0,
958        })?;
959    let mut out = Vec::with_capacity(total);
960    for &p in &ps {
961        for &m in &ms {
962            out.push(AsoParams {
963                period: Some(p),
964                mode: Some(m),
965            });
966        }
967    }
968    Ok(out)
969}
970
971pub fn aso_batch_with_kernel(
972    open: &[f64],
973    high: &[f64],
974    low: &[f64],
975    close: &[f64],
976    sweep: &AsoBatchRange,
977    k: Kernel,
978) -> Result<AsoBatchOutput, AsoError> {
979    let combos = expand_grid_aso(sweep)?;
980    let rows = combos.len();
981    let cols = close.len();
982
983    if cols == 0 {
984        return Err(AsoError::EmptyInputData);
985    }
986    if open.len() != cols || high.len() != cols || low.len() != cols {
987        return Err(AsoError::MissingData);
988    }
989
990    let first = close
991        .iter()
992        .position(|x| !x.is_nan())
993        .ok_or(AsoError::AllValuesNaN)?;
994    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
995
996    if cols - first < max_p {
997        return Err(AsoError::NotEnoughValidData {
998            needed: max_p,
999            valid: cols - first,
1000        });
1001    }
1002
1003    let mut bulls_mu = make_uninit_matrix(rows, cols);
1004    let mut bears_mu = make_uninit_matrix(rows, cols);
1005
1006    let warm: Vec<usize> = combos
1007        .iter()
1008        .map(|c| first + c.period.unwrap() - 1)
1009        .collect();
1010    init_matrix_prefixes(&mut bulls_mu, cols, &warm);
1011    init_matrix_prefixes(&mut bears_mu, cols, &warm);
1012
1013    let mut guard_b = core::mem::ManuallyDrop::new(bulls_mu);
1014    let mut guard_e = core::mem::ManuallyDrop::new(bears_mu);
1015    let bulls_out: &mut [f64] =
1016        unsafe { core::slice::from_raw_parts_mut(guard_b.as_mut_ptr() as *mut f64, guard_b.len()) };
1017    let bears_out: &mut [f64] =
1018        unsafe { core::slice::from_raw_parts_mut(guard_e.as_mut_ptr() as *mut f64, guard_e.len()) };
1019
1020    let mut actual = match k {
1021        Kernel::Auto => detect_best_batch_kernel(),
1022        Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => k,
1023        other => return Err(AsoError::InvalidKernelForBatch(other)),
1024    };
1025
1026    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1027    if matches!(k, Kernel::Auto) && matches!(actual, Kernel::Avx2Batch | Kernel::Avx512Batch) {
1028        actual = Kernel::ScalarBatch;
1029    }
1030
1031    let do_row = |row: usize, bulls_row: &mut [f64], bears_row: &mut [f64]| {
1032        let p = combos[row].period.unwrap();
1033        let m = combos[row].mode.unwrap();
1034        unsafe {
1035            match actual {
1036                Kernel::Scalar | Kernel::ScalarBatch => {
1037                    aso_scalar(open, high, low, close, p, m, first, bulls_row, bears_row)
1038                }
1039                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1040                Kernel::Avx2 | Kernel::Avx2Batch => {
1041                    aso_avx2(open, high, low, close, p, m, first, bulls_row, bears_row)
1042                }
1043                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1044                Kernel::Avx512 | Kernel::Avx512Batch => {
1045                    aso_avx512(open, high, low, close, p, m, first, bulls_row, bears_row)
1046                }
1047                #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1048                Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1049                    aso_scalar(open, high, low, close, p, m, first, bulls_row, bears_row)
1050                }
1051                Kernel::Auto => unreachable!(),
1052            }
1053        }
1054    };
1055
1056    #[cfg(not(target_arch = "wasm32"))]
1057    {
1058        bulls_out
1059            .chunks_mut(cols)
1060            .zip(bears_out.chunks_mut(cols))
1061            .enumerate()
1062            .par_bridge()
1063            .for_each(|(row, (b, e))| do_row(row, b, e));
1064    }
1065    #[cfg(target_arch = "wasm32")]
1066    {
1067        for (row, (b, e)) in bulls_out
1068            .chunks_mut(cols)
1069            .zip(bears_out.chunks_mut(cols))
1070            .enumerate()
1071        {
1072            do_row(row, b, e);
1073        }
1074    }
1075
1076    let bulls = unsafe {
1077        Vec::from_raw_parts(
1078            guard_b.as_mut_ptr() as *mut f64,
1079            guard_b.len(),
1080            guard_b.capacity(),
1081        )
1082    };
1083    let bears = unsafe {
1084        Vec::from_raw_parts(
1085            guard_e.as_mut_ptr() as *mut f64,
1086            guard_e.len(),
1087            guard_e.capacity(),
1088        )
1089    };
1090
1091    Ok(AsoBatchOutput {
1092        bulls,
1093        bears,
1094        combos,
1095        rows,
1096        cols,
1097    })
1098}
1099
1100pub fn aso_batch_slice(
1101    open: &[f64],
1102    high: &[f64],
1103    low: &[f64],
1104    close: &[f64],
1105    sweep: &AsoBatchRange,
1106    kern: Kernel,
1107) -> Result<AsoBatchOutput, AsoError> {
1108    let k = match kern {
1109        Kernel::Scalar => Kernel::ScalarBatch,
1110        Kernel::Avx2 => Kernel::Avx2Batch,
1111        Kernel::Avx512 => Kernel::Avx512Batch,
1112        other => other,
1113    };
1114    aso_batch_with_kernel(open, high, low, close, sweep, k)
1115}
1116
1117#[cfg(not(target_arch = "wasm32"))]
1118pub fn aso_batch_par_slice(
1119    open: &[f64],
1120    high: &[f64],
1121    low: &[f64],
1122    close: &[f64],
1123    sweep: &AsoBatchRange,
1124    kern: Kernel,
1125) -> Result<AsoBatchOutput, AsoError> {
1126    let k = match kern {
1127        Kernel::Scalar => Kernel::ScalarBatch,
1128        Kernel::Avx2 => Kernel::Avx2Batch,
1129        Kernel::Avx512 => Kernel::Avx512Batch,
1130        other => other,
1131    };
1132    aso_batch_with_kernel(open, high, low, close, sweep, k)
1133}
1134
/// wasm32 variant of the parallel batch entry point.
///
/// There is no rayon on wasm32, so this shares the same batch path as
/// `aso_batch_slice`. It applies the same single-series to batch kernel promotion as
/// the native `aso_batch_par_slice`. Without it, `Kernel::Scalar` (and the SIMD
/// single-series kernels) were forwarded unchanged to the batch path, which only
/// accepts `Auto` or batch kernels.
///
/// # Errors
/// Propagates any error returned by `aso_batch_with_kernel`.
#[cfg(target_arch = "wasm32")]
pub fn aso_batch_par_slice(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AsoBatchRange,
    kern: Kernel,
) -> Result<AsoBatchOutput, AsoError> {
    let k = match kern {
        Kernel::Scalar => Kernel::ScalarBatch,
        Kernel::Avx2 => Kernel::Avx2Batch,
        Kernel::Avx512 => Kernel::Avx512Batch,
        other => other,
    };
    aso_batch_with_kernel(open, high, low, close, sweep, k)
}
1146
1147#[inline(always)]
1148fn aso_batch_inner_into(
1149    open: &[f64],
1150    high: &[f64],
1151    low: &[f64],
1152    close: &[f64],
1153    sweep: &AsoBatchRange,
1154    kern: Kernel,
1155    parallel: bool,
1156    out_bulls: &mut [f64],
1157    out_bears: &mut [f64],
1158) -> Result<Vec<AsoParams>, AsoError> {
1159    let combos = expand_grid_aso(sweep)?;
1160    if combos.is_empty() {
1161        return Err(AsoError::InvalidRange {
1162            start: 0,
1163            end: 0,
1164            step: 0,
1165        });
1166    }
1167
1168    let cols = close.len();
1169    if cols == 0 {
1170        return Err(AsoError::EmptyInputData);
1171    }
1172    if open.len() != cols || high.len() != cols || low.len() != cols {
1173        return Err(AsoError::MissingData);
1174    }
1175    let rows = combos.len();
1176    let total = rows.checked_mul(cols).ok_or(AsoError::InvalidRange {
1177        start: rows,
1178        end: cols,
1179        step: 0,
1180    })?;
1181    if out_bulls.len() != total || out_bears.len() != total {
1182        return Err(AsoError::OutputLengthMismatch {
1183            expected: total,
1184            got: out_bulls.len().min(out_bears.len()),
1185        });
1186    }
1187
1188    let first = close
1189        .iter()
1190        .position(|x| !x.is_nan())
1191        .ok_or(AsoError::AllValuesNaN)?;
1192    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1193    if cols - first < max_p {
1194        return Err(AsoError::NotEnoughValidData {
1195            needed: max_p,
1196            valid: cols - first,
1197        });
1198    }
1199
1200    let mut b_mu = unsafe {
1201        core::slice::from_raw_parts_mut(
1202            out_bulls.as_mut_ptr() as *mut MaybeUninit<f64>,
1203            out_bulls.len(),
1204        )
1205    };
1206    let mut e_mu = unsafe {
1207        core::slice::from_raw_parts_mut(
1208            out_bears.as_mut_ptr() as *mut MaybeUninit<f64>,
1209            out_bears.len(),
1210        )
1211    };
1212    let warm: Vec<usize> = combos
1213        .iter()
1214        .map(|c| first + c.period.unwrap() - 1)
1215        .collect();
1216    init_matrix_prefixes(&mut b_mu, cols, &warm);
1217    init_matrix_prefixes(&mut e_mu, cols, &warm);
1218
1219    let actual = match kern {
1220        Kernel::Auto => detect_best_batch_kernel(),
1221        Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => kern,
1222        other => return Err(AsoError::InvalidKernelForBatch(other)),
1223    };
1224
1225    let do_row = |row: usize, br: &mut [MaybeUninit<f64>], er: &mut [MaybeUninit<f64>]| unsafe {
1226        let p = combos[row].period.unwrap();
1227        let m = combos[row].mode.unwrap();
1228
1229        let b = core::slice::from_raw_parts_mut(br.as_mut_ptr() as *mut f64, br.len());
1230        let e = core::slice::from_raw_parts_mut(er.as_mut_ptr() as *mut f64, er.len());
1231
1232        match actual {
1233            Kernel::ScalarBatch => aso_scalar(open, high, low, close, p, m, first, b, e),
1234            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1235            Kernel::Avx2Batch => aso_avx2(open, high, low, close, p, m, first, b, e),
1236            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1237            Kernel::Avx512Batch => aso_avx512(open, high, low, close, p, m, first, b, e),
1238            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1239            Kernel::Avx2Batch | Kernel::Avx512Batch => {
1240                aso_scalar(open, high, low, close, p, m, first, b, e)
1241            }
1242            Kernel::Auto | Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => unreachable!(),
1243        }
1244    };
1245
1246    if parallel {
1247        #[cfg(not(target_arch = "wasm32"))]
1248        {
1249            use rayon::prelude::*;
1250            b_mu.chunks_mut(cols)
1251                .zip(e_mu.chunks_mut(cols))
1252                .enumerate()
1253                .par_bridge()
1254                .for_each(|(row, (br, er))| do_row(row, br, er));
1255        }
1256        #[cfg(target_arch = "wasm32")]
1257        {
1258            for (row, (br, er)) in b_mu.chunks_mut(cols).zip(e_mu.chunks_mut(cols)).enumerate() {
1259                do_row(row, br, er);
1260            }
1261        }
1262    } else {
1263        for (row, (br, er)) in b_mu.chunks_mut(cols).zip(e_mu.chunks_mut(cols)).enumerate() {
1264            do_row(row, br, er);
1265        }
1266    }
1267
1268    Ok(combos)
1269}
1270
/// Incremental (bar-by-bar) ASO state.
///
/// Keeps the last `period` bars in fixed-size rings. It tracks the window's lowest
/// low and highest high with monotonic deques, so `update` performs no allocation.
/// Construct it with `AsoStream::try_new`.
#[derive(Debug, Clone)]
pub struct AsoStream {
    // Rings of the last `period` open/high/low/close values, written at slot `i % period`.
    o: Vec<f64>,
    h: Vec<f64>,
    l: Vec<f64>,
    c: Vec<f64>,

    // Rings of the last `period` per-bar bulls/bears values, written at `head_be`.
    rb: Vec<f64>,
    re: Vec<f64>,
    // Running sums of the values currently stored in `rb` / `re`.
    sum_b: f64,
    sum_e: f64,
    // Next write slot in `rb` / `re`.
    head_be: usize,
    // Number of valid entries in `rb` / `re`; saturates at `period`.
    filled_be: usize,

    // Circular deque of (absolute bar index, low) whose lows increase from head to
    // tail; the head entry is the window's minimum low.
    dq_min_idx: Vec<usize>,
    dq_min_val: Vec<f64>,
    // Front slot, next push slot, and entry count of the min deque.
    min_head: usize,
    min_tail: usize,
    min_len: usize,

    // Circular deque of (absolute bar index, high) whose highs decrease from head to
    // tail; the head entry is the window's maximum high.
    dq_max_idx: Vec<usize>,
    dq_max_val: Vec<f64>,
    // Front slot, next push slot, and entry count of the max deque.
    max_head: usize,
    max_tail: usize,
    max_len: usize,

    period: usize,
    // 0 = average of intrabar and group components, 1 = intrabar only, 2 = group only.
    mode: usize,
    // Number of bars fed so far (also the absolute index of the next bar).
    i: usize,
    // True once at least `period` bars have been fed.
    ready: bool,
}
1302
impl AsoStream {
    /// Builds an empty streaming ASO state.
    ///
    /// `params.period` defaults to 10 and `params.mode` to 0 when `None`. All rings
    /// and deques are pre-sized to `period`.
    ///
    /// # Errors
    /// - `AsoError::InvalidPeriod` if `period == 0`. `data_len` is reported as 0
    ///   because a stream has no input length.
    /// - `AsoError::InvalidMode` if `mode > 2`.
    #[inline]
    pub fn try_new(params: AsoParams) -> Result<Self, AsoError> {
        let period = params.period.unwrap_or(10);
        let mode = params.mode.unwrap_or(0);

        if period == 0 {
            return Err(AsoError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        if mode > 2 {
            return Err(AsoError::InvalidMode { mode });
        }

        Ok(Self {
            o: vec![0.0; period],
            h: vec![0.0; period],
            l: vec![0.0; period],
            c: vec![0.0; period],

            rb: vec![0.0; period],
            re: vec![0.0; period],
            sum_b: 0.0,
            sum_e: 0.0,
            head_be: 0,
            filled_be: 0,

            dq_min_idx: vec![0usize; period],
            dq_min_val: vec![0.0; period],
            min_head: 0,
            min_tail: 0,
            min_len: 0,

            dq_max_idx: vec![0usize; period],
            dq_max_val: vec![0.0; period],
            max_head: 0,
            max_tail: 0,
            max_len: 0,

            period,
            mode,
            i: 0,
            ready: false,
        })
    }

    /// Returns `1 / x`, or `1.0` when `x` is exactly zero, so a zero price range
    /// yields a finite scale instead of a division by zero.
    #[inline(always)]
    fn inv_or_one(x: f64) -> f64 {
        if x != 0.0 {
            x.recip()
        } else {
            1.0
        }
    }

    /// Feeds one OHLC bar and returns `Some((bulls, bears))` once `period` bars have
    /// been seen. The first `period - 1` calls return `None`.
    ///
    /// Each bar's value combines an intrabar component (from this bar's OHLC) and a
    /// group component (this bar vs. the window's open, lowest low and highest high),
    /// selected by `mode`. The returned values are the mean of the stored per-bar
    /// values. Until `period` of them exist, the mean is over fewer than `period` bars.
    ///
    /// NOTE(review): a NaN per-bar value enters `sum_b`/`sum_e` and is never removed
    /// by the add/subtract update, so later outputs stay NaN. Confirm this matches the
    /// batch kernels.
    #[inline]
    pub fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let p = self.period;
        let i = self.i;
        let idx = i % p;

        // Store the bar in the OHLC rings.
        self.o[idx] = open;
        self.h[idx] = high;
        self.l[idx] = low;
        self.c[idx] = close;

        // Rolling min of `low`: drop tail entries the new low makes irrelevant.
        while self.min_len > 0 {
            let back = if self.min_tail == 0 {
                p - 1
            } else {
                self.min_tail - 1
            };
            if low <= self.dq_min_val[back] {
                self.min_tail = back;
                self.min_len -= 1;
            } else {
                break;
            }
        }
        // Fixed-capacity ring: if still full, evict the oldest (front) entry to make room.
        if self.min_len == p {
            self.min_head += 1;
            if self.min_head == p {
                self.min_head = 0;
            }
            self.min_len -= 1;
        }
        self.dq_min_idx[self.min_tail] = i;
        self.dq_min_val[self.min_tail] = low;
        self.min_tail += 1;
        if self.min_tail == p {
            self.min_tail = 0;
        }
        self.min_len += 1;

        // Rolling max of `high`: same scheme with the comparison reversed.
        while self.max_len > 0 {
            let back = if self.max_tail == 0 {
                p - 1
            } else {
                self.max_tail - 1
            };
            if high >= self.dq_max_val[back] {
                self.max_tail = back;
                self.max_len -= 1;
            } else {
                break;
            }
        }
        if self.max_len == p {
            self.max_head += 1;
            if self.max_head == p {
                self.max_head = 0;
            }
            self.max_len -= 1;
        }
        self.dq_max_idx[self.max_tail] = i;
        self.dq_max_val[self.max_tail] = high;
        self.max_tail += 1;
        if self.max_tail == p {
            self.max_tail = 0;
        }
        self.max_len += 1;

        self.i = i + 1;
        if self.i >= p {
            self.ready = true;
        }
        if !self.ready {
            return None;
        }

        // Window covers absolute bar indices [start_abs, i]; expire older deque fronts.
        let start_abs = self.i - p;

        while self.min_len > 0 && self.dq_min_idx[self.min_head] < start_abs {
            self.min_head += 1;
            if self.min_head == p {
                self.min_head = 0;
            }
            self.min_len -= 1;
        }
        while self.max_len > 0 && self.dq_max_idx[self.max_head] < start_abs {
            self.max_head += 1;
            if self.max_head == p {
                self.max_head = 0;
            }
            self.max_len -= 1;
        }

        debug_assert!(self.min_len > 0 && self.max_len > 0);
        let gl = self.dq_min_val[self.min_head];
        let gh = self.dq_max_val[self.max_head];

        // The slot after the current one holds the window's oldest bar; its open is the group open.
        let oldest_ring = if idx + 1 == p { 0 } else { idx + 1 };
        let gopen = self.o[oldest_ring];

        // Intrabar component, scaled by this bar's high-low range.
        let intrarange = high - low;
        let scale1 = 50.0 * Self::inv_or_one(intrarange);
        let intrabarbulls = ((close - low) + (high - open)) * scale1;
        let intrabarbears = ((high - close) + (open - low)) * scale1;

        // Group component, scaled by the window's highest-high minus lowest-low.
        let gr = gh - gl;
        let scale2 = 50.0 * Self::inv_or_one(gr);
        let groupbulls = ((close - gl) + (gh - gopen)) * scale2;
        let groupbears = ((gh - close) + (gopen - gl)) * scale2;

        // mode 0: average of both components; 1: intrabar only; otherwise (2): group only.
        let b = match self.mode {
            0 => 0.5 * (intrabarbulls + groupbulls),
            1 => intrabarbulls,
            _ => groupbulls,
        };
        let e = match self.mode {
            0 => 0.5 * (intrabarbears + groupbears),
            1 => intrabarbears,
            _ => groupbears,
        };

        // Running sums over the last `period` per-bar values. Nothing is evicted
        // until the ring is full.
        let old_b = if self.filled_be == p {
            self.rb[self.head_be]
        } else {
            0.0
        };
        let old_e = if self.filled_be == p {
            self.re[self.head_be]
        } else {
            0.0
        };

        self.sum_b += b - old_b;
        self.sum_e += e - old_e;

        self.rb[self.head_be] = b;
        self.re[self.head_be] = e;

        self.head_be += 1;
        if self.head_be == p {
            self.head_be = 0;
        }
        if self.filled_be < p {
            self.filled_be += 1;
        }

        // Mean over the stored values (fewer than `period` during the first outputs).
        let n = self.filled_be as f64;
        Some((self.sum_b / n, self.sum_e / n))
    }
}
1509
/// Python binding: computes ASO over four equal-length float64 OHLC arrays.
///
/// Returns `(bulls, bears)` as NumPy arrays. `period`, `mode` and `kernel` are
/// optional. Raises `ValueError` for mismatched lengths, an invalid kernel, or any
/// error from the computation.
#[cfg(feature = "python")]
#[pyfunction(name = "aso")]
#[pyo3(signature = (open, high, low, close, period=None, mode=None, kernel=None))]
pub fn aso_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    mode: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );

    let n = o.len();
    let same_len = [h.len(), l.len(), c.len()].iter().all(|&len| len == n);
    if !same_len {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let input = AsoInput::from_slices(o, h, l, c, AsoParams { period, mode });

    // Run the computation with the GIL released.
    let AsoOutput { bulls, bears } = py
        .allow_threads(|| aso_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((bulls.into_pyarray(py), bears.into_pyarray(py)))
}
1544
/// Python binding: batch ASO over `(start, end, step)` period and mode ranges.
///
/// Returns a dict containing:
/// - `"bulls"` and `"bears"`: `(rows, cols)` float64 arrays with one row per
///   parameter combination.
/// - `"periods"` and `"modes"`: each row's parameters.
///
/// Raises `ValueError` for mismatched OHLC lengths, any error from `validate_kernel`,
/// an invalid sweep, size overflow, or any error from the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "aso_batch")]
#[pyo3(signature = (open, high, low, close, period_range, mode_range, kernel=None))]
pub fn aso_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    mode_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    // Same up-front check and message as the single-series `aso` binding, instead of
    // surfacing the batch path's generic `AsoError::MissingData`.
    if h.len() != o.len() || l.len() != o.len() || c.len() != o.len() {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }
    // Validate the kernel before allocating the output arrays.
    let kern = validate_kernel(kernel, true)?;

    let sweep = AsoBatchRange {
        period: period_range,
        mode: mode_range,
    };
    let combos = expand_grid_aso(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = c.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("size overflow"))?;

    // SAFETY: the arrays are freshly allocated, uninitialised and owned only by this
    // function. `aso_batch_inner_into` is relied upon to fill them before they reach
    // Python; on error they are dropped without being read.
    let bulls_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bears_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let b = unsafe { bulls_arr.as_slice_mut()? };
    let e = unsafe { bears_arr.as_slice_mut()? };

    py.allow_threads(|| {
        let simd = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        aso_batch_inner_into(o, h, l, c, &sweep, simd, true, b, e)
    })
    .map_err(|er| PyValueError::new_err(er.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("bulls", bulls_arr.reshape((rows, cols))?)?;
    d.set_item("bears", bears_arr.reshape((rows, cols))?)?;
    d.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    d.set_item(
        "modes",
        combos
            .iter()
            .map(|p| p.mode.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(d)
}
1612
1613#[cfg(all(feature = "python", feature = "cuda"))]
1614use crate::cuda::cuda_available;
1615#[cfg(all(feature = "python", feature = "cuda"))]
1616use crate::cuda::moving_averages::DeviceArrayF32;
1617#[cfg(all(feature = "python", feature = "cuda"))]
1618use crate::cuda::oscillators::CudaAso;
1619#[cfg(all(feature = "python", feature = "cuda"))]
1620use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1621#[cfg(all(feature = "python", feature = "cuda"))]
1622use cust::context::Context as CudaContext;
1623#[cfg(all(feature = "python", feature = "cuda"))]
1624use cust::memory::DeviceBuffer;
1625#[cfg(all(feature = "python", feature = "cuda"))]
1626use std::sync::Arc;
1627
/// Python handle to an ASO float32 output matrix (`rows` x `cols`) held in a CUDA
/// device buffer.
///
/// Exposes `__cuda_array_interface__` and a one-shot `__dlpack__` export.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AsoDeviceArrayF32Py {
    /// Device buffer; `None` once ownership has been handed off through `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    /// CUDA context held alongside `buf`. Presumably this keeps the context alive
    /// while the buffer exists (TODO confirm).
    pub(crate) _ctx: Arc<CudaContext>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1637
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AsoDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Describes a C-contiguous `(rows, cols)` float32 (`"<f4"`) array. For an empty
    /// shape the data pointer is reported as 0. Raises `ValueError` if the buffer has
    /// already been handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        // Row-major byte strides: a row is `cols` f32 values.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = if self.rows == 0 || self.cols == 0 {
            0usize
        } else {
            self.buf
                .as_ref()
                .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
                .as_device_ptr()
                .as_raw() as usize
        };
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. The type is fixed at 2,
    /// presumably DLPack's CUDA device type (TODO confirm against the exporter).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a 2-D DLPack capsule via `export_f32_cuda_dlpack_2d`.
    ///
    /// Ownership transfer:
    /// - `buf` is taken, so only the first call succeeds.
    /// - After export, `__cuda_array_interface__` on a non-empty array fails.
    ///
    /// Parameters:
    /// - `stream` is accepted but ignored.
    /// - If `dl_device` names a different device, a `ValueError` is raised. Cross-device
    ///   `copy=True` is not implemented.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Only an explicit, parseable (type, id) request that differs is rejected;
        // anything unparseable is ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "__dlpack__(copy=True) not implemented for ASO device handle",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for ASO tensor"));
                    }
                }
            }
        }
        let _ = stream;

        // Move the buffer out; a second call finds `None` and fails.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1714
/// Python binding: runs the ASO parameter sweep on a CUDA device.
///
/// Returns the bulls and bears matrices as device-resident handles
/// `(bulls, bears)`. Inputs must be float32 OHLC arrays of equal, non-zero length.
/// Both handles share the context from `CudaAso::context_arc` and the device id from
/// `CudaAso::device_id`.
///
/// Raises `ValueError` when CUDA is unavailable, inputs are empty or mismatched, or
/// the device computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "aso_cuda_batch_dev")]
#[pyo3(signature = (open, high, low, close, period_range, mode_range, device_id=0))]
pub fn aso_cuda_batch_dev_py(
    py: Python<'_>,
    open: PyReadonlyArray1<'_, f32>,
    high: PyReadonlyArray1<'_, f32>,
    low: PyReadonlyArray1<'_, f32>,
    close: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    mode_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(AsoDeviceArrayF32Py, AsoDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let n = o.len();
    let shapes_ok = n != 0 && [h.len(), l.len(), c.len()].iter().all(|&len| len == n);
    if !shapes_ok {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }

    let sweep = AsoBatchRange {
        period: period_range,
        mode: mode_range,
    };
    // Device setup, upload and kernel execution run with the GIL released.
    let (bulls, bears, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaAso::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let pair = engine
            .aso_batch_dev(o, h, l, c, &sweep)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok((pair.0, pair.1, engine.context_arc(), engine.device_id()))
    })?;

    let bears_handle = AsoDeviceArrayF32Py {
        buf: Some(bears.buf),
        rows: bears.rows,
        cols: bears.cols,
        _ctx: Arc::clone(&ctx),
        device_id: dev_id,
    };
    let bulls_handle = AsoDeviceArrayF32Py {
        buf: Some(bulls.buf),
        rows: bulls.rows,
        cols: bulls.cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((bulls_handle, bears_handle))
}
1766
/// Python binding: computes ASO with one `(period, mode)` over many series on a CUDA
/// device.
///
/// Inputs are flattened `cols * rows` float32 arrays in the time-major layout
/// expected by `CudaAso::aso_many_series_one_param_time_major_dev` (presumably `rows`
/// = time steps and `cols` = series; TODO confirm). Returns device-resident handles
/// `(bulls, bears)` that share one CUDA context.
///
/// Raises `ValueError` when CUDA is unavailable, sizes overflow or mismatch,
/// `mode > 2`, or the device computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "aso_cuda_many_series_one_param_dev")]
#[pyo3(signature = (open_tm, high_tm, low_tm, close_tm, cols, rows, period, mode, device_id=0))]
pub fn aso_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    open_tm: PyReadonlyArray1<'_, f32>,
    high_tm: PyReadonlyArray1<'_, f32>,
    low_tm: PyReadonlyArray1<'_, f32>,
    close_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    mode: usize,
    device_id: usize,
) -> PyResult<(AsoDeviceArrayF32Py, AsoDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (o, h, l, c) = (
        open_tm.as_slice()?,
        high_tm.as_slice()?,
        low_tm.as_slice()?,
        close_tm.as_slice()?,
    );
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("size overflow"))?;
    let n = o.len();
    let sizes_ok = expected == n && [h.len(), l.len(), c.len()].iter().all(|&len| len == n);
    if !sizes_ok {
        return Err(PyValueError::new_err("mismatched input sizes"));
    }
    if mode > 2 {
        return Err(PyValueError::new_err("invalid mode"));
    }

    // Device setup, upload and kernel execution run with the GIL released.
    let (bulls, bears, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaAso::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let pair = engine
            .aso_many_series_one_param_time_major_dev(o, h, l, c, cols, rows, period, mode)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok((pair.0, pair.1, engine.context_arc(), engine.device_id()))
    })?;

    let bears_handle = AsoDeviceArrayF32Py {
        buf: Some(bears.buf),
        rows: bears.rows,
        cols: bears.cols,
        _ctx: Arc::clone(&ctx),
        device_id: dev_id,
    };
    let bulls_handle = AsoDeviceArrayF32Py {
        buf: Some(bulls.buf),
        rows: bulls.rows,
        cols: bulls.cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((bulls_handle, bears_handle))
}
1822
/// Python-facing `AsoStream` class wrapping the Rust `AsoStream` state machine.
#[cfg(feature = "python")]
#[pyclass(name = "AsoStream")]
pub struct AsoStreamPy {
    // The wrapped streaming state; every Python call forwards to it.
    stream: AsoStream,
}
1828
#[cfg(feature = "python")]
#[pymethods]
impl AsoStreamPy {
    /// Creates a streaming ASO calculator.
    ///
    /// `period` and `mode` are optional keyword arguments. `None` falls back to the
    /// defaults applied by `AsoStream::try_new`. The explicit signature keeps
    /// `AsoStream()` working without relying on PyO3's deprecated implicit defaults
    /// for trailing `Option` arguments. Raises `ValueError` if the parameters are
    /// rejected.
    #[new]
    #[pyo3(signature = (period=None, mode=None))]
    fn new(period: Option<usize>, mode: Option<usize>) -> PyResult<Self> {
        let params = AsoParams { period, mode };
        let stream =
            AsoStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(AsoStreamPy { stream })
    }

    /// Feeds one OHLC bar. Returns `(bulls, bears)` once enough bars have been seen,
    /// otherwise `None`.
    fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        self.stream.update(open, high, low, close)
    }
}
1844
/// Single-series ASO result returned to JavaScript.
///
/// `values` is a flattened `rows` x `cols` matrix. The `aso` binding emits
/// `rows = 2`: the bulls row first, then the bears row, each `cols` values long.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AsoResult {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
1852
/// WebAssembly binding: computes ASO over OHLC slices.
///
/// Returns a serialized `AsoResult` with `rows = 2`: row 0 is bulls and row 1 is
/// bears, each `close.len()` values wide. `period` defaults to 10 and `mode` to 0.
///
/// Errors (as JS strings) when:
/// - the OHLC lengths differ, or the input is empty;
/// - `mode > 2`, `period == 0`, or `period > len`;
/// - `close` is all NaN, or too few values follow its first non-NaN entry.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "aso")]
pub fn aso_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: Option<usize>,
    mode: Option<usize>,
) -> Result<JsValue, JsValue> {
    let len = close.len();
    if open.len() != len || high.len() != len || low.len() != len {
        return Err(JsValue::from_str(
            "All OHLC arrays must have the same length",
        ));
    }
    // Report empty input explicitly (as `aso_batch` does) rather than letting the NaN
    // scan below fail with a misleading "All values NaN".
    if len == 0 {
        return Err(JsValue::from_str("Empty input"));
    }
    let p = period.unwrap_or(10);
    let m = mode.unwrap_or(0);
    if m > 2 {
        return Err(JsValue::from_str("Invalid mode"));
    }
    // Cheap parameter validation before scanning the data.
    if p == 0 || p > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let first = close
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All values NaN"))?;
    if len - first < p {
        return Err(JsValue::from_str("Not enough valid data"));
    }

    let mut mu = make_uninit_matrix(2, len);
    let warm = first + p - 1;
    init_matrix_prefixes(&mut mu, len, &[warm, warm]);

    // Take manual ownership of the allocation so it can be viewed as `f64` and later
    // rebuilt into a `Vec<f64>` without copying.
    let mut guard = core::mem::ManuallyDrop::new(mu);
    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the slice covers
    // exactly the elements owned by `guard`.
    let dst: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
    let (bulls_dst, bears_dst) = dst.split_at_mut(len);

    let chosen = detect_best_kernel();
    unsafe {
        aso_compute_into(
            open, high, low, close, p, m, first, chosen, bulls_dst, bears_dst,
        );
    }

    // SAFETY: pointer, length and capacity come from the allocation held by `guard`,
    // which is never dropped, so ownership is transferred exactly once.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    let out = AsoResult {
        values,
        rows: 2,
        cols: len,
    };
    serde_wasm_bindgen::to_value(&out).map_err(|e| JsValue::from_str(&e.to_string()))
}
1916
/// WebAssembly raw-pointer binding: computes ASO for `len` bars into caller-owned
/// buffers.
///
/// All six pointers must be non-null and valid for `len` `f64` values (inputs for
/// reads, outputs for writes). An output buffer may overlap an input buffer, e.g. for
/// in-place use; the inputs are then copied before any output slice is created. The
/// two output buffers must not overlap each other.
///
/// Errors (as JS strings) on null pointers, overlapping outputs, or any error from
/// `aso_into_slices`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    bulls_ptr: *mut f64,
    bears_ptr: *mut f64,
    len: usize,
    period: usize,
    mode: usize,
) -> Result<(), JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|p| p.is_null())
        || [bulls_ptr, bears_ptr].iter().any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte-range overlap test for two `len`-element buffers; empty ranges never overlap.
    let bytes = len.saturating_mul(core::mem::size_of::<f64>());
    let overlaps = |a: usize, b: usize| {
        bytes != 0 && a < b.saturating_add(bytes) && b < a.saturating_add(bytes)
    };
    let (b_addr, e_addr) = (bulls_ptr as usize, bears_ptr as usize);
    if overlaps(b_addr, e_addr) {
        return Err(JsValue::from_str("bulls and bears outputs must not overlap"));
    }
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&p| overlaps(p as usize, b_addr) || overlaps(p as usize, e_addr));

    // Owned input copies, populated only when an output overlaps an input. Holding a
    // `&[f64]` and a `&mut [f64]` over the same memory is undefined behaviour, and
    // output writes could clobber input values that are still needed.
    let copies: [Vec<f64>; 4];
    let (o, h, l, c) = if aliased {
        // SAFETY: each input pointer is valid for `len` reads (caller contract). These
        // temporary shared slices end before any output slice is created.
        copies = unsafe {
            [
                std::slice::from_raw_parts(open_ptr, len).to_vec(),
                std::slice::from_raw_parts(high_ptr, len).to_vec(),
                std::slice::from_raw_parts(low_ptr, len).to_vec(),
                std::slice::from_raw_parts(close_ptr, len).to_vec(),
            ]
        };
        (
            copies[0].as_slice(),
            copies[1].as_slice(),
            copies[2].as_slice(),
            copies[3].as_slice(),
        )
    } else {
        // SAFETY: valid for `len` reads (caller contract), and no output overlaps them.
        unsafe {
            (
                std::slice::from_raw_parts(open_ptr, len),
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(close_ptr, len),
            )
        }
    };

    // SAFETY: the outputs are non-null and valid for `len` writes (caller contract).
    // They were checked above not to overlap each other or any input slice still
    // borrowed here.
    unsafe {
        let bulls = std::slice::from_raw_parts_mut(bulls_ptr, len);
        let bears = std::slice::from_raw_parts_mut(bears_ptr, len);

        let input = AsoInput::from_slices(
            o,
            h,
            l,
            c,
            AsoParams {
                period: Some(period),
                mode: Some(mode),
            },
        );

        aso_into_slices(bulls, bears, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1961
/// WebAssembly raw-pointer binding: batch ASO over a `(start, end, step)` period and
/// mode sweep, writing row-major `rows x len` matrices into caller-owned buffers.
///
/// Input pointers must be valid for `len` reads. Output pointers must be valid for
/// `rows * len` writes, where `rows` is the number of sweep combinations. An output
/// may overlap an input (the inputs are then copied first), but the two outputs must
/// not overlap each other. Returns `rows` on success.
///
/// Errors (as JS strings) on null pointers, an invalid sweep, size overflow,
/// overlapping outputs, or any error from the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    mode_start: usize,
    mode_end: usize,
    mode_step: usize,
    bulls_out: *mut f64,
    bears_out: *mut f64,
) -> Result<usize, JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr, bulls_out, bears_out]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = AsoBatchRange {
        period: (period_start, period_end, period_step),
        mode: (mode_start, mode_end, mode_step),
    };

    let combos = expand_grid_aso(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Byte-range overlap test; empty ranges never overlap anything.
    let elem = core::mem::size_of::<f64>();
    let (in_bytes, out_bytes) = (len.saturating_mul(elem), total.saturating_mul(elem));
    let overlaps = |a: usize, a_bytes: usize, b: usize, b_bytes: usize| {
        a_bytes != 0
            && b_bytes != 0
            && a < b.saturating_add(b_bytes)
            && b < a.saturating_add(a_bytes)
    };
    let (b_addr, e_addr) = (bulls_out as usize, bears_out as usize);
    if overlaps(b_addr, out_bytes, e_addr, out_bytes) {
        return Err(JsValue::from_str("bulls and bears outputs must not overlap"));
    }
    let aliased = [open_ptr, high_ptr, low_ptr, close_ptr].iter().any(|&p| {
        overlaps(p as usize, in_bytes, b_addr, out_bytes)
            || overlaps(p as usize, in_bytes, e_addr, out_bytes)
    });

    // Owned input copies, populated only when an output overlaps an input. Holding a
    // `&[f64]` and a `&mut [f64]` over the same memory is undefined behaviour, and
    // output writes could clobber input values that are still needed.
    let copies: [Vec<f64>; 4];
    let (o, h, l, c) = if aliased {
        // SAFETY: each input pointer is valid for `len` reads (caller contract). These
        // temporary shared slices end before any output slice is created.
        copies = unsafe {
            [
                std::slice::from_raw_parts(open_ptr, len).to_vec(),
                std::slice::from_raw_parts(high_ptr, len).to_vec(),
                std::slice::from_raw_parts(low_ptr, len).to_vec(),
                std::slice::from_raw_parts(close_ptr, len).to_vec(),
            ]
        };
        (
            copies[0].as_slice(),
            copies[1].as_slice(),
            copies[2].as_slice(),
            copies[3].as_slice(),
        )
    } else {
        // SAFETY: valid for `len` reads (caller contract), and no output overlaps them.
        unsafe {
            (
                std::slice::from_raw_parts(open_ptr, len),
                std::slice::from_raw_parts(high_ptr, len),
                std::slice::from_raw_parts(low_ptr, len),
                std::slice::from_raw_parts(close_ptr, len),
            )
        }
    };

    // SAFETY: the outputs are non-null and valid for `total` writes (caller contract).
    // They were checked above not to overlap each other or any input slice still
    // borrowed here.
    unsafe {
        let b = std::slice::from_raw_parts_mut(bulls_out, total);
        let e = std::slice::from_raw_parts_mut(bears_out, total);

        aso_batch_inner_into(o, h, l, c, &sweep, detect_best_batch_kernel(), false, b, e)
            .map_err(|er| JsValue::from_str(&er.to_string()))?;
    }
    Ok(rows)
}
2011
/// JS-side configuration object for the `aso_batch` binding, deserialised with
/// `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AsoBatchConfig {
    /// `(start, end, step)` sweep for the period parameter.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` sweep for the mode parameter.
    pub mode_range: (usize, usize, usize),
}
2018
/// Result object handed back to JS by the `aso_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AsoBatchJsOutput {
    /// Row-major matrix as populated by `aso_batch`: the bulls rows for every combination
    /// first, then the bears rows in the same combination order.
    pub values: Vec<f64>,
    /// Parameter combination per bulls row (and, by position, per bears row).
    pub combos: Vec<AsoParams>,
    /// Total number of rows in `values` (as populated by `aso_batch`, twice the combo count).
    pub rows: usize,
    /// Number of samples per row (the input length).
    pub cols: usize,
}
2027
/// JS batch entry point (`aso_batch`): evaluates every `(period, mode)` combination from a
/// `{ period_range, mode_range }` config object.
///
/// The returned object's `values` holds `rows * cols` numbers, with `rows = 2 * combos.len()`.
/// It contains all bulls rows first, followed by all bears rows in the same combination order.
///
/// # Errors
/// Returns a string `JsValue` for:
/// - a malformed config or an invalid parameter grid;
/// - empty input, mismatched OHLC lengths, or an all-NaN `close`;
/// - size overflow;
/// - any failure reported by the batch computation or by serialisation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "aso_batch")]
pub fn aso_batch_unified_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: AsoBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = AsoBatchRange {
        period: cfg.period_range,
        mode: cfg.mode_range,
    };
    let combos = expand_grid_aso(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = close.len();
    if cols == 0 {
        return Err(JsValue::from_str("Empty input"));
    }
    if open.len() != cols || high.len() != cols || low.len() != cols {
        return Err(JsValue::from_str("OHLC length mismatch"));
    }

    // Validate sizes *before* allocating: the buffer holds a `rows x cols` bulls block
    // followed by a `rows x cols` bears block.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;
    if total.checked_mul(2).is_none() {
        return Err(JsValue::from_str("size overflow"));
    }

    let first = close
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All values NaN"))?;

    // Per-combination warmup. Saturating arithmetic avoids underflow for a zero period, and
    // clamping to `cols` lets `aso_batch_inner_into` report an oversized period as an error
    // instead of the prefix initialisation indexing past the row.
    let row_warms: Vec<usize> = combos
        .iter()
        .map(|c| first.saturating_add(c.period.unwrap()).saturating_sub(1).min(cols))
        .collect();
    // The matrix is block-ordered (all bulls rows, then all bears rows), so the prefix list
    // repeats the per-combination warmups in that order rather than interleaving them.
    let warms = row_warms.repeat(2);

    let mut mu = make_uninit_matrix(rows * 2, cols);
    init_matrix_prefixes(&mut mu, cols, &warms);

    {
        // SAFETY: `mu` owns `rows * 2 * cols` contiguous f64-sized slots; this view is only used
        // inside this scope while `mu` is alive and not otherwise touched.
        let dst: &mut [f64] =
            unsafe { core::slice::from_raw_parts_mut(mu.as_mut_ptr() as *mut f64, mu.len()) };
        let (bulls_dst, bears_dst) = dst.split_at_mut(total);

        // On error `mu` is still an ordinary Vec and is freed normally (no leak).
        aso_batch_inner_into(
            open,
            high,
            low,
            close,
            &sweep,
            detect_best_batch_kernel(),
            false,
            bulls_dst,
            bears_dst,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    // Transfer ownership of the buffer to a Vec<f64> without copying.
    let values = {
        let mut mu = core::mem::ManuallyDrop::new(mu);
        // SAFETY: same allocation, length and capacity; MaybeUninit<f64> and f64 share layout.
        // Assumes `init_matrix_prefixes` and `aso_batch_inner_into` together write every slot.
        unsafe { Vec::from_raw_parts(mu.as_mut_ptr() as *mut f64, mu.len(), mu.capacity()) }
    };

    let out = AsoBatchJsOutput {
        values,
        combos,
        rows: rows * 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2099
/// Reserves an uninitialised buffer with room for `len` f64s and hands its pointer to JS.
///
/// The allocation is intentionally not freed here; release it with `aso_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_alloc(len: usize) -> *mut f64 {
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2108
/// Releases a buffer previously obtained from `aso_alloc(len)`.
///
/// A null `ptr` is ignored, so freeing an absent allocation is a no-op rather than
/// undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` is non-null, and the caller guarantees it came from `aso_alloc(len)`, which
    // allocates with capacity exactly `len`. Reconstructing with length 0 only frees the
    // allocation, so this stays sound even if the buffer was never fully written.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2116
2117#[cfg(test)]
2118mod tests {
2119    use super::*;
2120    use crate::skip_if_unsupported;
2121    use crate::utilities::data_loader::read_candles_from_csv;
2122    #[cfg(feature = "proptest")]
2123    use proptest::prelude::*;
2124    use std::error::Error;
2125
2126    fn check_aso_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2127        skip_if_unsupported!(kernel, test_name);
2128
2129        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2130        let candles = read_candles_from_csv(file_path)?;
2131
2132        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
2133        let result = aso_with_kernel(&input, kernel)?;
2134
2135        let expected_bulls = [
2136            48.48594883,
2137            46.37206396,
2138            47.20522805,
2139            46.83750720,
2140            43.28268188,
2141        ];
2142
2143        let expected_bears = [
2144            51.51405117,
2145            53.62793604,
2146            52.79477195,
2147            53.16249280,
2148            56.71731812,
2149        ];
2150
2151        let start = result.bulls.len().saturating_sub(5);
2152        for (i, (&bull_val, &bear_val)) in result.bulls[start..]
2153            .iter()
2154            .zip(result.bears[start..].iter())
2155            .enumerate()
2156        {
2157            let bull_diff = (bull_val - expected_bulls[i]).abs();
2158            let bear_diff = (bear_val - expected_bears[i]).abs();
2159
2160            assert!(
2161                bull_diff < 1e-6,
2162                "[{}] ASO Bulls {:?} mismatch at idx {}: got {}, expected {}",
2163                test_name,
2164                kernel,
2165                i,
2166                bull_val,
2167                expected_bulls[i]
2168            );
2169
2170            assert!(
2171                bear_diff < 1e-6,
2172                "[{}] ASO Bears {:?} mismatch at idx {}: got {}, expected {}",
2173                test_name,
2174                kernel,
2175                i,
2176                bear_val,
2177                expected_bears[i]
2178            );
2179        }
2180        Ok(())
2181    }
2182
2183    fn check_aso_slice_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2184        skip_if_unsupported!(kernel, test_name);
2185
2186        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2187        let candles = read_candles_from_csv(file_path)?;
2188
2189        let input = AsoInput::from_slices(
2190            &candles.open,
2191            &candles.high,
2192            &candles.low,
2193            &candles.close,
2194            AsoParams::default(),
2195        );
2196        let result = aso_with_kernel(&input, kernel)?;
2197
2198        assert_eq!(result.bulls.len(), candles.close.len());
2199        assert_eq!(result.bears.len(), candles.close.len());
2200
2201        Ok(())
2202    }
2203
2204    fn check_aso_into_slices(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2205        skip_if_unsupported!(kernel, test_name);
2206
2207        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2208        let candles = read_candles_from_csv(file_path)?;
2209
2210        let mut bulls = vec![0.0; candles.close.len()];
2211        let mut bears = vec![0.0; candles.close.len()];
2212
2213        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
2214        aso_into_slices(&mut bulls, &mut bears, &input, kernel)?;
2215
2216        for i in 0..9 {
2217            assert!(bulls[i].is_nan());
2218            assert!(bears[i].is_nan());
2219        }
2220
2221        assert!(!bulls[20].is_nan());
2222        assert!(!bears[20].is_nan());
2223
2224        Ok(())
2225    }
2226
2227    fn check_aso_batch(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2228        skip_if_unsupported!(kernel, test_name);
2229
2230        let open = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
2231        let high = vec![15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0];
2232        let low = vec![5.0, 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0];
2233        let close = vec![12.0, 22.0, 32.0, 42.0, 52.0, 62.0, 72.0, 82.0, 92.0, 102.0];
2234
2235        let sweep = AsoBatchRange {
2236            period: (3, 5, 1),
2237            mode: (0, 2, 1),
2238        };
2239
2240        let result = aso_batch_with_kernel(&open, &high, &low, &close, &sweep, kernel)?;
2241
2242        assert_eq!(result.rows, 9);
2243        assert_eq!(result.cols, 10);
2244        assert_eq!(result.bulls.len(), 90);
2245        assert_eq!(result.bears.len(), 90);
2246        assert_eq!(result.combos.len(), 9);
2247
2248        Ok(())
2249    }
2250
2251    fn check_aso_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2252        skip_if_unsupported!(kernel, test_name);
2253
2254        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2255        let candles = read_candles_from_csv(file_path)?;
2256
2257        let default_params = AsoParams {
2258            period: None,
2259            mode: None,
2260        };
2261        let input = AsoInput::from_candles(&candles, "close", default_params);
2262        let output = aso_with_kernel(&input, kernel)?;
2263        assert_eq!(output.bulls.len(), candles.close.len());
2264        assert_eq!(output.bears.len(), candles.close.len());
2265
2266        Ok(())
2267    }
2268
2269    fn check_aso_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2270        skip_if_unsupported!(kernel, test_name);
2271
2272        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2273        let candles = read_candles_from_csv(file_path)?;
2274
2275        let input = AsoInput::with_default_candles(&candles);
2276        let output = aso_with_kernel(&input, kernel)?;
2277        assert_eq!(output.bulls.len(), candles.close.len());
2278        assert_eq!(output.bears.len(), candles.close.len());
2279
2280        Ok(())
2281    }
2282
2283    fn check_aso_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2284        skip_if_unsupported!(kernel, test_name);
2285
2286        let open = vec![10.0, 20.0, 30.0];
2287        let high = vec![15.0, 25.0, 35.0];
2288        let low = vec![8.0, 18.0, 28.0];
2289        let close = vec![12.0, 22.0, 32.0];
2290
2291        let params = AsoParams {
2292            period: Some(0),
2293            mode: None,
2294        };
2295        let input = AsoInput::from_slices(&open, &high, &low, &close, params);
2296        let res = aso_with_kernel(&input, kernel);
2297        assert!(
2298            res.is_err(),
2299            "[{}] ASO should fail with zero period",
2300            test_name
2301        );
2302        Ok(())
2303    }
2304
2305    fn check_aso_period_exceeds_length(
2306        test_name: &str,
2307        kernel: Kernel,
2308    ) -> Result<(), Box<dyn Error>> {
2309        skip_if_unsupported!(kernel, test_name);
2310
2311        let open = vec![10.0, 20.0, 30.0];
2312        let high = vec![15.0, 25.0, 35.0];
2313        let low = vec![8.0, 18.0, 28.0];
2314        let close = vec![12.0, 22.0, 32.0];
2315
2316        let params = AsoParams {
2317            period: Some(10),
2318            mode: None,
2319        };
2320        let input = AsoInput::from_slices(&open, &high, &low, &close, params);
2321        let res = aso_with_kernel(&input, kernel);
2322        assert!(
2323            res.is_err(),
2324            "[{}] ASO should fail with period exceeding length",
2325            test_name
2326        );
2327        Ok(())
2328    }
2329
2330    fn check_aso_invalid_mode(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2331        skip_if_unsupported!(kernel, test_name);
2332
2333        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2334        let candles = read_candles_from_csv(file_path)?;
2335
2336        let params = AsoParams {
2337            period: Some(10),
2338            mode: Some(3),
2339        };
2340        let input = AsoInput::from_candles(&candles, "close", params);
2341        let res = aso_with_kernel(&input, kernel);
2342        assert!(
2343            res.is_err(),
2344            "[{}] ASO should fail with invalid mode",
2345            test_name
2346        );
2347        Ok(())
2348    }
2349
2350    fn check_aso_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2351        skip_if_unsupported!(kernel, test_name);
2352
2353        let empty: Vec<f64> = vec![];
2354        let params = AsoParams::default();
2355        let input = AsoInput::from_slices(&empty, &empty, &empty, &empty, params);
2356        let res = aso_with_kernel(&input, kernel);
2357        assert!(
2358            res.is_err(),
2359            "[{}] ASO should fail with empty input",
2360            test_name
2361        );
2362        Ok(())
2363    }
2364
2365    fn check_aso_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2366        skip_if_unsupported!(kernel, test_name);
2367
2368        let nan_data = vec![f64::NAN, f64::NAN, f64::NAN];
2369        let params = AsoParams::default();
2370        let input = AsoInput::from_slices(&nan_data, &nan_data, &nan_data, &nan_data, params);
2371        let res = aso_with_kernel(&input, kernel);
2372        assert!(
2373            res.is_err(),
2374            "[{}] ASO should fail with all NaN values",
2375            test_name
2376        );
2377        Ok(())
2378    }
2379
2380    fn check_aso_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2381        skip_if_unsupported!(kernel, test_name);
2382        let single_point = [42.0];
2383        let params = AsoParams {
2384            period: Some(10),
2385            mode: None,
2386        };
2387        let input = AsoInput::from_slices(
2388            &single_point,
2389            &single_point,
2390            &single_point,
2391            &single_point,
2392            params,
2393        );
2394        let res = aso_with_kernel(&input, kernel);
2395        assert!(
2396            res.is_err(),
2397            "[{}] ASO should fail with insufficient data",
2398            test_name
2399        );
2400        Ok(())
2401    }
2402
2403    fn check_aso_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2404        skip_if_unsupported!(kernel, test_name);
2405        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2406        let candles = read_candles_from_csv(file_path)?;
2407
2408        let first_params = AsoParams {
2409            period: Some(10),
2410            mode: Some(0),
2411        };
2412        let first_input = AsoInput::from_candles(&candles, "close", first_params);
2413        let first_result = aso_with_kernel(&first_input, kernel)?;
2414
2415        let second_params = AsoParams {
2416            period: Some(10),
2417            mode: Some(0),
2418        };
2419        let second_input = AsoInput::from_slices(
2420            &first_result.bulls,
2421            &first_result.bulls,
2422            &first_result.bulls,
2423            &first_result.bulls,
2424            second_params,
2425        );
2426        let second_result = aso_with_kernel(&second_input, kernel)?;
2427
2428        assert_eq!(second_result.bulls.len(), first_result.bulls.len());
2429        assert_eq!(second_result.bears.len(), first_result.bears.len());
2430
2431        if second_result.bulls.len() > 30 {
2432            assert!(!second_result.bulls[30].is_nan());
2433            assert!(!second_result.bears[30].is_nan());
2434        }
2435
2436        Ok(())
2437    }
2438
2439    fn check_aso_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2440        skip_if_unsupported!(kernel, test_name);
2441        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2442        let candles = read_candles_from_csv(file_path)?;
2443
2444        let input = AsoInput::from_candles(
2445            &candles,
2446            "close",
2447            AsoParams {
2448                period: Some(10),
2449                mode: Some(0),
2450            },
2451        );
2452        let res = aso_with_kernel(&input, kernel)?;
2453        assert_eq!(res.bulls.len(), candles.close.len());
2454        assert_eq!(res.bears.len(), candles.close.len());
2455
2456        if res.bulls.len() > 240 {
2457            for (i, (&bull_val, &bear_val)) in res.bulls[240..]
2458                .iter()
2459                .zip(res.bears[240..].iter())
2460                .enumerate()
2461            {
2462                assert!(
2463                    !bull_val.is_nan(),
2464                    "[{}] Found unexpected NaN in bulls at out-index {}",
2465                    test_name,
2466                    240 + i
2467                );
2468                assert!(
2469                    !bear_val.is_nan(),
2470                    "[{}] Found unexpected NaN in bears at out-index {}",
2471                    test_name,
2472                    240 + i
2473                );
2474            }
2475        }
2476        Ok(())
2477    }
2478
2479    fn check_aso_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2480        skip_if_unsupported!(kernel, test_name);
2481
2482        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2483        let candles = read_candles_from_csv(file_path)?;
2484
2485        let period = 10;
2486        let mode = 0;
2487
2488        let input = AsoInput::from_candles(
2489            &candles,
2490            "close",
2491            AsoParams {
2492                period: Some(period),
2493                mode: Some(mode),
2494            },
2495        );
2496        let batch_output = aso_with_kernel(&input, kernel)?;
2497
2498        let mut stream = AsoStream::try_new(AsoParams {
2499            period: Some(period),
2500            mode: Some(mode),
2501        })?;
2502
2503        let mut stream_bulls = Vec::with_capacity(candles.close.len());
2504        let mut stream_bears = Vec::with_capacity(candles.close.len());
2505
2506        for i in 0..candles.close.len() {
2507            match stream.update(
2508                candles.open[i],
2509                candles.high[i],
2510                candles.low[i],
2511                candles.close[i],
2512            ) {
2513                Some((bull, bear)) => {
2514                    stream_bulls.push(bull);
2515                    stream_bears.push(bear);
2516                }
2517                None => {
2518                    stream_bulls.push(f64::NAN);
2519                    stream_bears.push(f64::NAN);
2520                }
2521            }
2522        }
2523
2524        assert_eq!(batch_output.bulls.len(), stream_bulls.len());
2525        assert_eq!(batch_output.bears.len(), stream_bears.len());
2526
2527        for (i, ((&batch_bull, &stream_bull), (&batch_bear, &stream_bear))) in batch_output
2528            .bulls
2529            .iter()
2530            .zip(stream_bulls.iter())
2531            .zip(batch_output.bears.iter().zip(stream_bears.iter()))
2532            .enumerate()
2533        {
2534            if batch_bull.is_nan() && stream_bull.is_nan() {
2535                continue;
2536            }
2537            if batch_bear.is_nan() && stream_bear.is_nan() {
2538                continue;
2539            }
2540
2541            if i >= period {
2542                if !batch_bull.is_nan() && !stream_bull.is_nan() {
2543                    assert!(
2544                        stream_bull >= -1e-9 && stream_bull <= 100.0 + 1e-9,
2545                        "[{}] ASO streaming bulls out of range at idx {}: {}",
2546                        test_name,
2547                        i,
2548                        stream_bull
2549                    );
2550                }
2551                if !batch_bear.is_nan() && !stream_bear.is_nan() {
2552                    assert!(
2553                        stream_bear >= -1e-9 && stream_bear <= 100.0 + 1e-9,
2554                        "[{}] ASO streaming bears out of range at idx {}: {}",
2555                        test_name,
2556                        i,
2557                        stream_bear
2558                    );
2559                }
2560
2561                if mode != 0 && !stream_bull.is_nan() && !stream_bear.is_nan() {
2562                    let sum = stream_bull + stream_bear;
2563                    assert!(
2564                        (sum - 100.0).abs() < 1e-9,
2565                        "[{}] ASO streaming bulls + bears != 100 at idx {} (mode {}): {} + {} = {}",
2566                        test_name,
2567                        i,
2568                        mode,
2569                        stream_bull,
2570                        stream_bear,
2571                        sum
2572                    );
2573                }
2574            }
2575        }
2576        Ok(())
2577    }
2578
2579    #[cfg(debug_assertions)]
2580    fn check_aso_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2581        skip_if_unsupported!(kernel, test_name);
2582
2583        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2584        let candles = read_candles_from_csv(file_path)?;
2585
2586        let test_params = vec![
2587            AsoParams::default(),
2588            AsoParams {
2589                period: Some(5),
2590                mode: Some(0),
2591            },
2592            AsoParams {
2593                period: Some(5),
2594                mode: Some(1),
2595            },
2596            AsoParams {
2597                period: Some(5),
2598                mode: Some(2),
2599            },
2600            AsoParams {
2601                period: Some(10),
2602                mode: Some(0),
2603            },
2604            AsoParams {
2605                period: Some(10),
2606                mode: Some(1),
2607            },
2608            AsoParams {
2609                period: Some(10),
2610                mode: Some(2),
2611            },
2612            AsoParams {
2613                period: Some(20),
2614                mode: Some(0),
2615            },
2616            AsoParams {
2617                period: Some(20),
2618                mode: Some(1),
2619            },
2620            AsoParams {
2621                period: Some(20),
2622                mode: Some(2),
2623            },
2624            AsoParams {
2625                period: Some(2),
2626                mode: Some(0),
2627            },
2628            AsoParams {
2629                period: Some(50),
2630                mode: Some(1),
2631            },
2632            AsoParams {
2633                period: Some(100),
2634                mode: Some(2),
2635            },
2636        ];
2637
2638        for (param_idx, params) in test_params.iter().enumerate() {
2639            let input = AsoInput::from_candles(&candles, "close", params.clone());
2640            let output = aso_with_kernel(&input, kernel)?;
2641
2642            for (i, (&bull_val, &bear_val)) in
2643                output.bulls.iter().zip(output.bears.iter()).enumerate()
2644            {
2645                if bull_val.is_nan() || bear_val.is_nan() {
2646                    continue;
2647                }
2648
2649                let bull_bits = bull_val.to_bits();
2650                let bear_bits = bear_val.to_bits();
2651
2652                for (val, bits, name) in [
2653                    (bull_val, bull_bits, "bulls"),
2654                    (bear_val, bear_bits, "bears"),
2655                ] {
2656                    if bits == 0x11111111_11111111 {
2657                        panic!(
2658                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in {} at index {} \
2659                            with params: period={}, mode={}",
2660                            test_name,
2661                            val,
2662                            bits,
2663                            name,
2664                            i,
2665                            params.period.unwrap_or(10),
2666                            params.mode.unwrap_or(0)
2667                        );
2668                    }
2669
2670                    if bits == 0x22222222_22222222 {
2671                        panic!(
2672                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in {} at index {} \
2673                            with params: period={}, mode={}",
2674                            test_name,
2675                            val,
2676                            bits,
2677                            name,
2678                            i,
2679                            params.period.unwrap_or(10),
2680                            params.mode.unwrap_or(0)
2681                        );
2682                    }
2683
2684                    if bits == 0x33333333_33333333 {
2685                        panic!(
2686                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in {} at index {} \
2687                            with params: period={}, mode={}",
2688                            test_name,
2689                            val,
2690                            bits,
2691                            name,
2692                            i,
2693                            params.period.unwrap_or(10),
2694                            params.mode.unwrap_or(0)
2695                        );
2696                    }
2697                }
2698            }
2699        }
2700
2701        Ok(())
2702    }
2703
    /// Non-debug stand-in for the poison-value check, which is only compiled under
    /// `debug_assertions`. It always succeeds; presumably it exists so callers can reference
    /// `check_aso_no_poison` in every build profile (confirm against the test-generation macro).
    #[cfg(not(debug_assertions))]
    fn check_aso_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2708
2709    #[cfg(feature = "proptest")]
2710    #[allow(clippy::float_cmp)]
2711    fn check_aso_property(
2712        test_name: &str,
2713        kernel: Kernel,
2714    ) -> Result<(), Box<dyn std::error::Error>> {
2715        use proptest::prelude::*;
2716        skip_if_unsupported!(kernel, test_name);
2717
2718        let strat = (2usize..=50).prop_flat_map(|period| {
2719            (
2720                prop::collection::vec(
2721                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2722                    period..400,
2723                ),
2724                Just(period),
2725                0usize..=2,
2726            )
2727        });
2728
2729        proptest::test_runner::TestRunner::default()
2730            .run(&strat, |(data, period, mode)| {
2731                let params = AsoParams {
2732                    period: Some(period),
2733                    mode: Some(mode),
2734                };
2735
2736                let mut open = Vec::with_capacity(data.len());
2737                let mut high = Vec::with_capacity(data.len());
2738                let mut low = Vec::with_capacity(data.len());
2739                let mut close = Vec::with_capacity(data.len());
2740
2741                for &val in &data {
2742                    let spread = val.abs() * 0.1 + 1.0;
2743                    open.push(val);
2744                    high.push(val + spread);
2745                    low.push(val - spread);
2746                    close.push(val + spread * 0.5);
2747                }
2748
2749                let input = AsoInput::from_slices(&open, &high, &low, &close, params);
2750
2751                let AsoOutput {
2752                    bulls: out_bulls,
2753                    bears: out_bears,
2754                } = aso_with_kernel(&input, kernel).unwrap();
2755                let AsoOutput {
2756                    bulls: ref_bulls,
2757                    bears: ref_bears,
2758                } = aso_with_kernel(&input, Kernel::Scalar).unwrap();
2759
2760                for i in (period - 1)..data.len() {
2761                    let bull = out_bulls[i];
2762                    let bear = out_bears[i];
2763                    let ref_bull = ref_bulls[i];
2764                    let ref_bear = ref_bears[i];
2765
2766                    if !bull.is_nan() && !bear.is_nan() {
2767                        let sum = bull + bear;
2768                        prop_assert!(
2769                            (sum - 100.0).abs() < 1e-9,
2770                            "idx {}: bulls + bears = {} + {} = {}, expected 100",
2771                            i,
2772                            bull,
2773                            bear,
2774                            sum
2775                        );
2776                    }
2777
2778                    if !bull.is_nan() {
2779                        prop_assert!(
2780                            bull >= -1e-9 && bull <= 100.0 + 1e-9,
2781                            "idx {}: bull {} out of range [0, 100]",
2782                            i,
2783                            bull
2784                        );
2785                    }
2786                    if !bear.is_nan() {
2787                        prop_assert!(
2788                            bear >= -1e-9 && bear <= 100.0 + 1e-9,
2789                            "idx {}: bear {} out of range [0, 100]",
2790                            i,
2791                            bear
2792                        );
2793                    }
2794
2795                    let bull_bits = bull.to_bits();
2796                    let bear_bits = bear.to_bits();
2797                    let ref_bull_bits = ref_bull.to_bits();
2798                    let ref_bear_bits = ref_bear.to_bits();
2799
2800                    if !bull.is_finite() || !ref_bull.is_finite() {
2801                        prop_assert!(
2802                            bull_bits == ref_bull_bits,
2803                            "bull finite/NaN mismatch idx {}: {} vs {}",
2804                            i,
2805                            bull,
2806                            ref_bull
2807                        );
2808                    } else {
2809                        let ulp_diff: u64 = bull_bits.abs_diff(ref_bull_bits);
2810                        prop_assert!(
2811                            (bull - ref_bull).abs() <= 1e-9 || ulp_diff <= 4,
2812                            "bull mismatch idx {}: {} vs {} (ULP={})",
2813                            i,
2814                            bull,
2815                            ref_bull,
2816                            ulp_diff
2817                        );
2818                    }
2819
2820                    if !bear.is_finite() || !ref_bear.is_finite() {
2821                        prop_assert!(
2822                            bear_bits == ref_bear_bits,
2823                            "bear finite/NaN mismatch idx {}: {} vs {}",
2824                            i,
2825                            bear,
2826                            ref_bear
2827                        );
2828                    } else {
2829                        let ulp_diff: u64 = bear_bits.abs_diff(ref_bear_bits);
2830                        prop_assert!(
2831                            (bear - ref_bear).abs() <= 1e-9 || ulp_diff <= 4,
2832                            "bear mismatch idx {}: {} vs {} (ULP={})",
2833                            i,
2834                            bear,
2835                            ref_bear,
2836                            ulp_diff
2837                        );
2838                    }
2839                }
2840                Ok(())
2841            })
2842            .unwrap();
2843
2844        Ok(())
2845    }
2846
2847    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2848        skip_if_unsupported!(kernel, test_name);
2849
2850        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2851        let candles = read_candles_from_csv(file_path)?;
2852
2853        let output = AsoBatchBuilder::new()
2854            .kernel(kernel)
2855            .apply_candles(&candles)?;
2856
2857        let default_params = AsoParams::default();
2858        let default_row_idx = output
2859            .combos
2860            .iter()
2861            .position(|p| p.period == default_params.period && p.mode == default_params.mode)
2862            .expect("default row missing");
2863
2864        let bulls_row = output.bulls_row(default_row_idx);
2865        let bears_row = output.bears_row(default_row_idx);
2866
2867        assert_eq!(bulls_row.len(), candles.close.len());
2868        assert_eq!(bears_row.len(), candles.close.len());
2869
2870        let expected_bulls = [
2871            48.48594883,
2872            46.37206396,
2873            47.20522805,
2874            46.83750720,
2875            43.28268188,
2876        ];
2877        let expected_bears = [
2878            51.51405117,
2879            53.62793604,
2880            52.79477195,
2881            53.16249280,
2882            56.71731812,
2883        ];
2884
2885        let start = bulls_row.len() - 5;
2886        for (i, (&bull, &bear)) in bulls_row[start..]
2887            .iter()
2888            .zip(bears_row[start..].iter())
2889            .enumerate()
2890        {
2891            assert!(
2892                (bull - expected_bulls[i]).abs() < 1e-6,
2893                "[{}] default-row bulls mismatch at idx {}: {} vs {}",
2894                test_name,
2895                i,
2896                bull,
2897                expected_bulls[i]
2898            );
2899            assert!(
2900                (bear - expected_bears[i]).abs() < 1e-6,
2901                "[{}] default-row bears mismatch at idx {}: {} vs {}",
2902                test_name,
2903                i,
2904                bear,
2905                expected_bears[i]
2906            );
2907        }
2908        Ok(())
2909    }
2910
2911    fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2912        skip_if_unsupported!(kernel, test_name);
2913
2914        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2915        let candles = read_candles_from_csv(file_path)?;
2916
2917        let output = AsoBatchBuilder::new()
2918            .kernel(kernel)
2919            .period_range(10, 20, 2)
2920            .mode_range(0, 2, 1)
2921            .apply_candles(&candles)?;
2922
2923        let expected_combos = 6 * 3;
2924        assert_eq!(output.combos.len(), expected_combos);
2925        assert_eq!(output.rows, expected_combos);
2926        assert_eq!(output.cols, candles.close.len());
2927
2928        let mut found_combos = 0;
2929        for period in (10..=20).step_by(2) {
2930            for mode in 0..=2 {
2931                let found = output
2932                    .combos
2933                    .iter()
2934                    .any(|c| c.period == Some(period) && c.mode == Some(mode));
2935                assert!(
2936                    found,
2937                    "[{}] Missing combo: period={}, mode={}",
2938                    test_name, period, mode
2939                );
2940                if found {
2941                    found_combos += 1;
2942                }
2943            }
2944        }
2945        assert_eq!(found_combos, expected_combos);
2946
2947        Ok(())
2948    }
2949
2950    #[cfg(debug_assertions)]
2951    fn check_batch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2952        skip_if_unsupported!(kernel, test_name);
2953
2954        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2955        let candles = read_candles_from_csv(file_path)?;
2956
2957        let test_configs = vec![
2958            (2, 10, 2, 0, 2, 1),
2959            (5, 25, 5, 0, 0, 1),
2960            (10, 10, 1, 0, 2, 1),
2961            (2, 5, 1, 1, 1, 1),
2962            (30, 60, 15, 2, 2, 1),
2963            (9, 15, 3, 0, 2, 2),
2964            (8, 12, 1, 0, 1, 1),
2965        ];
2966
2967        for (cfg_idx, &(p_start, p_end, p_step, m_start, m_end, m_step)) in
2968            test_configs.iter().enumerate()
2969        {
2970            let output = AsoBatchBuilder::new()
2971                .kernel(kernel)
2972                .period_range(p_start, p_end, p_step)
2973                .mode_range(m_start, m_end, m_step)
2974                .apply_candles(&candles)?;
2975
2976            for (row_idx, combo) in output.combos.iter().enumerate() {
2977                let bulls_row = output.bulls_row(row_idx);
2978                let bears_row = output.bears_row(row_idx);
2979
2980                for (col_idx, (&bull_val, &bear_val)) in
2981                    bulls_row.iter().zip(bears_row.iter()).enumerate()
2982                {
2983                    if bull_val.is_nan() || bear_val.is_nan() {
2984                        continue;
2985                    }
2986
2987                    let bull_bits = bull_val.to_bits();
2988                    let bear_bits = bear_val.to_bits();
2989
2990                    for (val, bits, name) in [
2991                        (bull_val, bull_bits, "bulls"),
2992                        (bear_val, bear_bits, "bears"),
2993                    ] {
2994                        if bits == 0x11111111_11111111 {
2995                            panic!(
2996                                "[{}] Config {}: Found alloc_with_nan_prefix poison {} (0x{:016X}) in {} \
2997                                at row {} col {} (period={}, mode={})",
2998                                test_name, cfg_idx, val, bits, name, row_idx, col_idx,
2999                                combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3000                            );
3001                        }
3002
3003                        if bits == 0x22222222_22222222 {
3004                            panic!(
3005                                "[{}] Config {}: Found init_matrix_prefixes poison {} (0x{:016X}) in {} \
3006                                at row {} col {} (period={}, mode={})",
3007                                test_name, cfg_idx, val, bits, name, row_idx, col_idx,
3008                                combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3009                            );
3010                        }
3011
3012                        if bits == 0x33333333_33333333 {
3013                            panic!(
3014                                "[{}] Config {}: Found make_uninit_matrix poison {} (0x{:016X}) in {} \
3015                                at row {} col {} (period={}, mode={})",
3016                                test_name, cfg_idx, val, bits, name, row_idx, col_idx,
3017                                combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3018                            );
3019                        }
3020                    }
3021                }
3022            }
3023        }
3024
3025        Ok(())
3026    }
3027
3028    #[cfg(not(debug_assertions))]
3029    fn check_batch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3030        Ok(())
3031    }
3032
3033    macro_rules! generate_all_aso_tests {
3034        ($($test_fn:ident),*) => {
3035            paste::paste! {
3036                $(
3037                    #[test]
3038                    fn [<$test_fn _scalar_f64>]() {
3039                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3040                    }
3041                )*
3042                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3043                $(
3044                    #[test]
3045                    fn [<$test_fn _avx2_f64>]() {
3046                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3047                    }
3048                    #[test]
3049                    fn [<$test_fn _avx512_f64>]() {
3050                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3051                    }
3052                )*
3053                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3054                $(
3055                    #[test]
3056                    fn [<$test_fn _simd128_f64>]() {
3057                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
3058                    }
3059                )*
3060            }
3061        }
3062    }
3063
3064    generate_all_aso_tests!(
3065        check_aso_accuracy,
3066        check_aso_slice_input,
3067        check_aso_into_slices,
3068        check_aso_batch,
3069        check_aso_partial_params,
3070        check_aso_default_candles,
3071        check_aso_zero_period,
3072        check_aso_period_exceeds_length,
3073        check_aso_invalid_mode,
3074        check_aso_empty_input,
3075        check_aso_all_nan,
3076        check_aso_very_small_dataset,
3077        check_aso_reinput,
3078        check_aso_nan_handling,
3079        check_aso_streaming,
3080        check_aso_no_poison
3081    );
3082
3083    #[cfg(feature = "proptest")]
3084    generate_all_aso_tests!(check_aso_property);
3085
3086    macro_rules! gen_batch_tests {
3087        ($fn_name:ident) => {
3088            paste::paste! {
3089                #[test] fn [<$fn_name _scalar>]() {
3090                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3091                }
3092                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3093                #[test] fn [<$fn_name _avx2>]() {
3094                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3095                }
3096                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3097                #[test] fn [<$fn_name _avx512>]() {
3098                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3099                }
3100                #[test] fn [<$fn_name _auto_detect>]() {
3101                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3102                }
3103            }
3104        };
3105    }
3106
3107    gen_batch_tests!(check_batch_default_row);
3108    gen_batch_tests!(check_batch_sweep);
3109    gen_batch_tests!(check_batch_no_poison);
3110
3111    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3112    #[test]
3113    fn test_aso_into_matches_api() -> Result<(), Box<dyn Error>> {
3114        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3115        let candles = read_candles_from_csv(file_path)?;
3116
3117        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
3118
3119        let base = aso(&input)?;
3120
3121        let mut bulls = vec![0.0; candles.close.len()];
3122        let mut bears = vec![0.0; candles.close.len()];
3123        aso_into(&input, &mut bulls, &mut bears)?;
3124
3125        assert_eq!(bulls.len(), base.bulls.len());
3126        assert_eq!(bears.len(), base.bears.len());
3127
3128        fn eq_or_both_nan(a: f64, b: f64) -> bool {
3129            (a.is_nan() && b.is_nan()) || (a == b)
3130        }
3131
3132        for i in 0..bulls.len() {
3133            assert!(
3134                eq_or_both_nan(bulls[i], base.bulls[i]),
3135                "bulls mismatch at {}: got {}, expected {}",
3136                i,
3137                bulls[i],
3138                base.bulls[i]
3139            );
3140            assert!(
3141                eq_or_both_nan(bears[i], base.bears[i]),
3142                "bears mismatch at {}: got {}, expected {}",
3143                i,
3144                bears[i],
3145                base.bears[i]
3146            );
3147        }
3148
3149        Ok(())
3150    }
3151
3152    #[test]
3153    fn test_new_api_features() -> Result<(), Box<dyn Error>> {
3154        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3155        let candles = read_candles_from_csv(file_path)?;
3156
3157        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
3158        let _data_ref: &[f64] = input.as_ref();
3159
3160        let builder = AsoBatchBuilder::new().period_static(10).mode_static(0);
3161        let output = builder.apply_candles(&candles)?;
3162        assert_eq!(output.combos.len(), 1);
3163
3164        let output2 = AsoBatchBuilder::with_default_candles(&candles)?;
3165        assert!(output2.combos.len() > 0);
3166
3167        let output3 = AsoBatchBuilder::with_default_slices(
3168            &candles.open,
3169            &candles.high,
3170            &candles.low,
3171            &candles.close,
3172            Kernel::Scalar,
3173        )?;
3174        assert!(output3.combos.len() > 0);
3175
3176        let params = AsoParams::default();
3177        if let Some(row) = output2.row_for_params(&params) {
3178            let bulls_row = output2.bulls_row(row);
3179            let bears_row = output2.bears_row(row);
3180            assert_eq!(bulls_row.len(), candles.close.len());
3181            assert_eq!(bears_row.len(), candles.close.len());
3182        }
3183
3184        if let Some((bulls, bears)) = output2.values_for(&params) {
3185            assert_eq!(bulls.len(), candles.close.len());
3186            assert_eq!(bears.len(), candles.close.len());
3187        }
3188
3189        let sweep = AsoBatchRange::default();
3190        let output4 = aso_batch_slice(
3191            &candles.open,
3192            &candles.high,
3193            &candles.low,
3194            &candles.close,
3195            &sweep,
3196            Kernel::Scalar,
3197        )?;
3198        assert!(output4.combos.len() > 0);
3199
3200        let output5 = aso_batch_par_slice(
3201            &candles.open,
3202            &candles.high,
3203            &candles.low,
3204            &candles.close,
3205            &sweep,
3206            Kernel::Scalar,
3207        )?;
3208        assert_eq!(output4.combos.len(), output5.combos.len());
3209
3210        let input_high = AsoInput::from_candles(&candles, "high", AsoParams::default());
3211        let _output_high = aso(&input_high)?;
3212
3213        Ok(())
3214    }
3215}