Skip to main content

vector_ta/indicators/
ao.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::error::Error;
30use std::mem::{ManuallyDrop, MaybeUninit};
31use thiserror::Error;
32
33#[cfg(all(feature = "python", feature = "cuda"))]
34use crate::cuda::CudaAo;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use cust::context::Context;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::memory::DeviceBuffer;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41
42impl<'a> AsRef<[f64]> for AoInput<'a> {
43    #[inline(always)]
44    fn as_ref(&self) -> &[f64] {
45        match &self.data {
46            AoData::Slice(slice) => slice,
47            AoData::Candles { candles, source } => source_type(candles, source),
48        }
49    }
50}
51
52#[derive(Debug, Clone)]
53pub enum AoData<'a> {
54    Candles {
55        candles: &'a Candles,
56        source: &'a str,
57    },
58    Slice(&'a [f64]),
59}
60
/// Result of a single oscillator run.
///
/// `values` has one entry per input element; the warm-up prefix is NaN.
#[derive(Debug, Clone)]
pub struct AoOutput {
    /// Oscillator value per input index.
    pub values: Vec<f64>,
}
65
/// Window lengths for the oscillator.
///
/// A `None` period falls back to the default (5 for the short window, 34 for
/// the long one) wherever it is read.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AoParams {
    /// Length of the fast (short) moving-average window.
    pub short_period: Option<usize>,
    /// Length of the slow (long) moving-average window; must exceed the short one.
    pub long_period: Option<usize>,
}

impl Default for AoParams {
    /// The 5 / 34 configuration.
    fn default() -> Self {
        let short_period = Some(5);
        let long_period = Some(34);
        AoParams {
            short_period,
            long_period,
        }
    }
}
84
85#[derive(Debug, Clone)]
86pub struct AoInput<'a> {
87    pub data: AoData<'a>,
88    pub params: AoParams,
89}
90
91impl<'a> AoInput<'a> {
92    #[inline]
93    pub fn from_candles(c: &'a Candles, s: &'a str, p: AoParams) -> Self {
94        Self {
95            data: AoData::Candles {
96                candles: c,
97                source: s,
98            },
99            params: p,
100        }
101    }
102    #[inline]
103    pub fn from_slice(sl: &'a [f64], p: AoParams) -> Self {
104        Self {
105            data: AoData::Slice(sl),
106            params: p,
107        }
108    }
109    #[inline]
110    pub fn with_default_candles(c: &'a Candles) -> Self {
111        Self::from_candles(c, "hl2", AoParams::default())
112    }
113    #[inline]
114    pub fn get_short(&self) -> usize {
115        self.params.short_period.unwrap_or(5)
116    }
117    #[inline]
118    pub fn get_long(&self) -> usize {
119        self.params.long_period.unwrap_or(34)
120    }
121}
122
123#[derive(Copy, Clone, Debug)]
124pub struct AoBuilder {
125    short_period: Option<usize>,
126    long_period: Option<usize>,
127    kernel: Kernel,
128}
129
130impl Default for AoBuilder {
131    fn default() -> Self {
132        Self {
133            short_period: None,
134            long_period: None,
135            kernel: Kernel::Auto,
136        }
137    }
138}
139
140impl AoBuilder {
141    #[inline(always)]
142    pub fn new() -> Self {
143        Self::default()
144    }
145    #[inline(always)]
146    pub fn short_period(mut self, n: usize) -> Self {
147        self.short_period = Some(n);
148        self
149    }
150    #[inline(always)]
151    pub fn long_period(mut self, n: usize) -> Self {
152        self.long_period = Some(n);
153        self
154    }
155    #[inline(always)]
156    pub fn kernel(mut self, k: Kernel) -> Self {
157        self.kernel = k;
158        self
159    }
160
161    #[inline(always)]
162    pub fn apply(self, c: &Candles) -> Result<AoOutput, AoError> {
163        let p = AoParams {
164            short_period: self.short_period,
165            long_period: self.long_period,
166        };
167        let i = AoInput::from_candles(c, "hl2", p);
168        ao_with_kernel(&i, self.kernel)
169    }
170    #[inline(always)]
171    pub fn apply_slice(self, d: &[f64]) -> Result<AoOutput, AoError> {
172        let p = AoParams {
173            short_period: self.short_period,
174            long_period: self.long_period,
175        };
176        let i = AoInput::from_slice(d, p);
177        ao_with_kernel(&i, self.kernel)
178    }
179    #[inline(always)]
180    pub fn into_stream(self) -> Result<AoStream, AoError> {
181        let p = AoParams {
182            short_period: self.short_period,
183            long_period: self.long_period,
184        };
185        AoStream::try_new(p)
186    }
187}
188
189#[derive(Debug, Error)]
190pub enum AoError {
191    #[error("ao: Input data slice is empty.")]
192    EmptyInputData,
193    #[error("ao: All values are NaN.")]
194    AllValuesNaN,
195    #[error("ao: Invalid periods: short={short}, long={long}")]
196    InvalidPeriods { short: usize, long: usize },
197    #[error("ao: Short period must be less than long period: short={short}, long={long}")]
198    ShortPeriodNotLess { short: usize, long: usize },
199    #[error("ao: Not enough valid data: needed = {needed}, valid = {valid}")]
200    NotEnoughValidData { needed: usize, valid: usize },
201    #[error("ao: High and low arrays must have same length: high={high_len}, low={low_len}")]
202    MismatchedArrayLengths { high_len: usize, low_len: usize },
203    #[error("ao: output length mismatch: expected = {expected}, got = {got}")]
204    OutputLengthMismatch { expected: usize, got: usize },
205    #[error("ao: invalid kernel for batch: {0:?}")]
206    InvalidKernelForBatch(Kernel),
207    #[error("ao: invalid range: start={start}, end={end}, step={step}")]
208    InvalidRange {
209        start: usize,
210        end: usize,
211        step: usize,
212    },
213    #[error("ao: invalid input: {0}")]
214    InvalidInput(String),
215}
216
217#[inline]
218pub fn ao(input: &AoInput) -> Result<AoOutput, AoError> {
219    ao_with_kernel(input, Kernel::Auto)
220}
221
222#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
223pub fn ao_into(input: &AoInput, out: &mut [f64]) -> Result<(), AoError> {
224    let (data, short, long, first, len) = ao_prepare(input)?;
225    if out.len() != len {
226        return Err(AoError::OutputLengthMismatch {
227            expected: len,
228            got: out.len(),
229        });
230    }
231
232    let warmup_end = first + long - 1;
233    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
234    for v in &mut out[..warmup_end.min(len)] {
235        *v = qnan;
236    }
237
238    let chosen = detect_best_kernel();
239    unsafe {
240        match chosen {
241            Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, out),
242            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
243            Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, out),
244            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
245            Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, out),
246            _ => unreachable!(),
247        }
248    }
249
250    Ok(())
251}
252
253#[inline]
254pub fn ao_into_slice(dst: &mut [f64], input: &AoInput, kern: Kernel) -> Result<(), AoError> {
255    let (data, short, long, first, len) = ao_prepare(input)?;
256    if dst.len() != len {
257        return Err(AoError::OutputLengthMismatch {
258            expected: len,
259            got: dst.len(),
260        });
261    }
262
263    let chosen = match kern {
264        Kernel::Auto => detect_best_kernel(),
265        k => k,
266    };
267
268    unsafe {
269        match chosen {
270            Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, dst),
271            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
272            Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, dst),
273            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
274            Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, dst),
275            _ => unreachable!(),
276        }
277    }
278
279    let warmup_end = first + long - 1;
280    for v in &mut dst[..warmup_end] {
281        *v = f64::NAN;
282    }
283    Ok(())
284}
285
286#[inline(always)]
287fn ao_prepare<'a>(input: &'a AoInput) -> Result<(&'a [f64], usize, usize, usize, usize), AoError> {
288    let data = input.as_ref();
289    if data.is_empty() {
290        return Err(AoError::EmptyInputData);
291    }
292
293    let first = data
294        .iter()
295        .position(|x| !x.is_nan())
296        .ok_or(AoError::AllValuesNaN)?;
297    let len = data.len();
298    let short = input.get_short();
299    let long = input.get_long();
300
301    if short == 0 || long == 0 {
302        return Err(AoError::InvalidPeriods { short, long });
303    }
304    if short >= long {
305        return Err(AoError::ShortPeriodNotLess { short, long });
306    }
307    if len - first < long {
308        return Err(AoError::NotEnoughValidData {
309            needed: long,
310            valid: len - first,
311        });
312    }
313    Ok((data, short, long, first, len))
314}
315
316pub fn ao_with_kernel(input: &AoInput, kernel: Kernel) -> Result<AoOutput, AoError> {
317    let (data, short, long, first, len) = ao_prepare(input)?;
318
319    let chosen = match kernel {
320        Kernel::Auto => detect_best_kernel(),
321        other => other,
322    };
323
324    let warmup_period = first + long - 1;
325
326    let mut out = alloc_with_nan_prefix(len, warmup_period);
327
328    unsafe {
329        match chosen {
330            Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, &mut out),
331            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
332            Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, &mut out),
333            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
334            Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, &mut out),
335            _ => unreachable!(),
336        }
337    }
338
339    Ok(AoOutput { values: out })
340}
341
342#[inline]
343pub fn compute_hl2(high: &[f64], low: &[f64]) -> Result<Vec<f64>, AoError> {
344    if high.len() != low.len() {
345        return Err(AoError::MismatchedArrayLengths {
346            high_len: high.len(),
347            low_len: low.len(),
348        });
349    }
350    if high.is_empty() {
351        return Err(AoError::EmptyInputData);
352    }
353
354    let mut out = alloc_with_nan_prefix(high.len(), 0);
355
356    for i in 0..high.len() {
357        unsafe {
358            *out.get_unchecked_mut(i) = (*high.get_unchecked(i) + *low.get_unchecked(i)) * 0.5;
359        }
360    }
361    Ok(out)
362}
363
/// Scalar reference kernel.
///
/// For every index `i` from `first + long - 1` to the end of `data`, it writes
/// `mean(data[i+1-short..=i]) - mean(data[i+1-long..=i])` into `out[i]`,
/// using rolling sums.
///
/// * `data` – input series; values before `first` are ignored.
/// * `short`, `long` – window lengths, expected `1 <= short <= long`.
/// * `first` – index of the first valid (non-NaN) sample.
/// * `out` – destination. Only `out[first + long - 1..data.len()]` is written;
///   the warm-up prefix is left for the caller to fill.
///
/// Degenerate arguments write nothing and return. These are `short == 0`,
/// `short > long`, or fewer than `long` samples from `first` onwards.
///
/// # Panics
/// If there is something to write and `out` is shorter than `data`. The
/// previous raw-pointer version wrote out of bounds in that case.
#[inline(always)]
pub fn ao_scalar(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if short == 0 || short > long {
        return;
    }
    // Index of the first complete long window (long >= 1 here, so no underflow).
    let warm = match first.checked_add(long - 1) {
        Some(w) if w < len => w,
        _ => return,
    };
    assert!(
        out.len() >= len,
        "ao_scalar: out.len() must be >= data.len()"
    );

    let inv_s = 1.0 / (short as f64);
    let inv_l = 1.0 / (long as f64);

    // Seed each sum with all but the newest element of its first window,
    // accumulating left to right like the rolling loop below.
    let mut long_sum = data[first..warm].iter().fold(0.0f64, |acc, &x| acc + x);
    let short_start = warm + 1 - short;
    let mut short_sum = data[short_start..warm].iter().fold(0.0f64, |acc, &x| acc + x);

    // For output i: `heads` yields data[i], `tails_long` data[i + 1 - long] and
    // `tails_short` data[i + 1 - short] (the values leaving each window).
    let n = len - warm;
    let heads = &data[warm..len];
    let tails_long = &data[first..first + n];
    let tails_short = &data[short_start..short_start + n];

    for (((dst, &v), &old_l), &old_s) in out[warm..len]
        .iter_mut()
        .zip(heads)
        .zip(tails_long)
        .zip(tails_short)
    {
        long_sum += v;
        short_sum += v;
        *dst = short_sum.mul_add(inv_s, -long_sum * inv_l);
        long_sum -= old_l;
        short_sum -= old_s;
    }
}
443
/// AVX-512 entry point: dispatches on the long window length to the short or
/// long variant. At present both variants delegate to `ao_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ao_avx512(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    // SAFETY: both variants are declared `unsafe` but only forward their
    // arguments to the safe `ao_scalar`.
    unsafe {
        match long {
            0..=32 => ao_avx512_short(data, short, long, first, out),
            _ => ao_avx512_long(data, short, long, first, out),
        }
    }
}
453
/// AVX2 entry point.
///
/// There is no dedicated AVX2 rolling kernel yet, so this forwards to
/// [`ao_scalar`]. That function is safe, so no `unsafe` block is needed; the
/// previous one triggered `unused_unsafe`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ao_avx2(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    ao_scalar(data, short, long, first, out)
}
459
/// AVX-512 variant used when `long <= 32`; currently forwards to `ao_scalar`.
///
/// # Safety
/// No requirements beyond those of `ao_scalar`. The `unsafe` qualifier only
/// keeps the signature in line with the other SIMD entry points.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ao_avx512_short(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out);
}
471
/// AVX-512 variant used when `long > 32`; currently forwards to `ao_scalar`.
///
/// # Safety
/// No requirements beyond those of `ao_scalar`. The `unsafe` qualifier only
/// keeps the signature in line with the other SIMD entry points.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ao_avx512_long(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out);
}
483
/// Prefix-sum formulation of the oscillator using AVX2 + FMA, four outputs per
/// vector iteration.
///
/// The function first builds `pref[k] = sum(data[first..first + k])`. Each
/// window sum is then the difference of two prefix entries:
/// `short_sum = pref[k] - pref[k - short]` and
/// `long_sum = pref[k] - pref[k - long]`. Differencing large running totals can
/// lose precision on long series, so results may drift slightly from the
/// rolling-sum `ao_scalar`.
///
/// No caller is visible in this module's dispatch code (`ao_avx2` forwards to
/// `ao_scalar`), hence the `dead_code` allowance.
///
/// # Safety
/// * The running CPU must support AVX2 and FMA.
/// * `1 <= short <= long`; otherwise the prefix reads go before the buffer.
/// * `out.len() >= data.len()`. `out[first + long - 1..data.len()]` is written
///   through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[allow(dead_code)]
unsafe fn ao_avx2_prefixsum(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    if len == 0 {
        return;
    }
    let warm = first + long - 1;
    if warm >= len {
        return;
    }

    // pref[k] holds the sum of the first k valid samples (pref[0] == 0).
    let mut pref = Vec::<f64>::with_capacity(len - first + 1);
    pref.push(0.0);
    let mut acc = 0.0f64;
    for &x in &data[first..] {
        acc += x;
        pref.push(acc);
    }
    let pref_ptr = pref.as_ptr();

    let n_out = len - warm;
    let out_base = out.as_mut_ptr().add(warm);

    // Hoisted once; the vector and scalar tail paths share the same reciprocals.
    let inv_s_scalar = 1.0 / (short as f64);
    let inv_l_scalar = 1.0 / (long as f64);
    let inv_s = _mm256_set1_pd(inv_s_scalar);
    let inv_l = _mm256_set1_pd(inv_l_scalar);

    let cur0 = pref_ptr.add(long);
    let mut cur = cur0;
    let mut prev_s = cur0.sub(short);
    let mut prev_l = cur0.sub(long);
    let mut dst = out_base;

    let vec_chunks = n_out / 4;
    for _ in 0..vec_chunks {
        let pc = _mm256_loadu_pd(cur);
        let ps = _mm256_loadu_pd(prev_s);
        let pl = _mm256_loadu_pd(prev_l);

        let short_sum = _mm256_sub_pd(pc, ps);
        let long_sum = _mm256_sub_pd(pc, pl);

        // short_sum * inv_s - long_sum * inv_l
        let z = _mm256_mul_pd(long_sum, inv_l);
        let ao = _mm256_fmsub_pd(short_sum, inv_s, z);

        _mm256_storeu_pd(dst, ao);

        cur = cur.add(4);
        prev_s = prev_s.add(4);
        prev_l = prev_l.add(4);
        dst = dst.add(4);
    }

    // Remaining 0..=3 outputs, same formula in scalar form.
    let tail = n_out & 3;
    for t in 0..tail {
        let i_cur = long + (vec_chunks * 4 + t);
        let pc = *pref_ptr.add(i_cur);
        let ps = *pref_ptr.add(i_cur - short);
        let pl = *pref_ptr.add(i_cur - long);
        let y = pc - ps;
        let z = pc - pl;
        *dst.add(t) = y.mul_add(inv_s_scalar, -(z * inv_l_scalar));
    }
}
558
/// Prefix-sum formulation of the oscillator using AVX-512F + FMA, eight
/// outputs per vector iteration.
///
/// The function first builds `pref[k] = sum(data[first..first + k])`. Each
/// window sum is then the difference of two prefix entries. Differencing large
/// running totals can lose precision on long series, so results may drift
/// slightly from the rolling-sum `ao_scalar`.
///
/// No caller is visible in this module's dispatch code (`ao_avx512` forwards to
/// variants that call `ao_scalar`), hence the `dead_code` allowance.
///
/// # Safety
/// * The running CPU must support AVX-512F and FMA.
/// * `1 <= short <= long`; otherwise the prefix reads go before the buffer.
/// * `out.len() >= data.len()`. `out[first + long - 1..data.len()]` is written
///   through raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[allow(dead_code)]
unsafe fn ao_avx512_prefixsum(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    if len == 0 {
        return;
    }
    let warm = first + long - 1;
    if warm >= len {
        return;
    }

    // pref[k] holds the sum of the first k valid samples (pref[0] == 0).
    let mut pref = Vec::<f64>::with_capacity(len - first + 1);
    pref.push(0.0);
    let mut acc = 0.0f64;
    for &x in &data[first..] {
        acc += x;
        pref.push(acc);
    }
    let pref_ptr = pref.as_ptr();

    let n_out = len - warm;
    let out_base = out.as_mut_ptr().add(warm);

    // Hoisted once; the vector and scalar tail paths share the same reciprocals.
    let inv_s_scalar = 1.0 / (short as f64);
    let inv_l_scalar = 1.0 / (long as f64);
    let inv_s = _mm512_set1_pd(inv_s_scalar);
    let inv_l = _mm512_set1_pd(inv_l_scalar);

    let cur0 = pref_ptr.add(long);
    let mut cur = cur0;
    let mut prev_s = cur0.sub(short);
    let mut prev_l = cur0.sub(long);
    let mut dst = out_base;

    let vec_chunks = n_out / 8;
    for _ in 0..vec_chunks {
        let pc = _mm512_loadu_pd(cur);
        let ps = _mm512_loadu_pd(prev_s);
        let pl = _mm512_loadu_pd(prev_l);

        let short_sum = _mm512_sub_pd(pc, ps);
        let long_sum = _mm512_sub_pd(pc, pl);

        // short_sum * inv_s - long_sum * inv_l
        let z = _mm512_mul_pd(long_sum, inv_l);
        let ao = _mm512_fmsub_pd(short_sum, inv_s, z);

        _mm512_storeu_pd(dst, ao);

        cur = cur.add(8);
        prev_s = prev_s.add(8);
        prev_l = prev_l.add(8);
        dst = dst.add(8);
    }

    // Remaining 0..=7 outputs, same formula in scalar form.
    let tail = n_out & 7;
    for t in 0..tail {
        let i_cur = long + (vec_chunks * 8 + t);
        let pc = *pref_ptr.add(i_cur);
        let ps = *pref_ptr.add(i_cur - short);
        let pl = *pref_ptr.add(i_cur - long);
        let y = pc - ps;
        let z = pc - pl;
        *dst.add(t) = y.mul_add(inv_s_scalar, -(z * inv_l_scalar));
    }
}
633
634#[derive(Debug, Clone)]
635pub struct AoStream {
636    short: usize,
637    long: usize,
638    inv_short: f64,
639    inv_long: f64,
640
641    buf: Vec<f64>,
642
643    head: usize,
644
645    filled: usize,
646
647    short_sum: f64,
648    long_sum: f64,
649}
650
651impl AoStream {
652    pub fn try_new(params: AoParams) -> Result<Self, AoError> {
653        let short = params.short_period.unwrap_or(5);
654        let long = params.long_period.unwrap_or(34);
655        if short == 0 || long == 0 {
656            return Err(AoError::InvalidPeriods { short, long });
657        }
658        if short >= long {
659            return Err(AoError::ShortPeriodNotLess { short, long });
660        }
661        Ok(Self {
662            short,
663            long,
664            inv_short: 1.0 / (short as f64),
665            inv_long: 1.0 / (long as f64),
666            buf: vec![0.0; long],
667            head: 0,
668            filled: 0,
669            short_sum: 0.0,
670            long_sum: 0.0,
671        })
672    }
673
674    #[inline(always)]
675    pub fn update(&mut self, value: f64) -> Option<f64> {
676        let old_long = if self.filled == self.long {
677            self.buf[self.head]
678        } else {
679            0.0
680        };
681
682        let old_short = if self.filled >= self.short {
683            let idx = if self.head >= self.short {
684                self.head - self.short
685            } else {
686                self.head + self.long - self.short
687            };
688            self.buf[idx]
689        } else {
690            0.0
691        };
692
693        self.buf[self.head] = value;
694
695        self.head += 1;
696        if self.head == self.long {
697            self.head = 0;
698        }
699
700        if self.filled < self.long {
701            self.filled += 1;
702        }
703
704        self.long_sum += value - old_long;
705        self.short_sum += value - old_short;
706
707        if self.filled == self.long {
708            Some(
709                self.short_sum
710                    .mul_add(self.inv_short, -(self.long_sum * self.inv_long)),
711            )
712        } else {
713            None
714        }
715    }
716
717    #[inline(always)]
718    pub fn reset(&mut self) {
719        self.head = 0;
720        self.filled = 0;
721        self.short_sum = 0.0;
722        self.long_sum = 0.0;
723        for x in &mut self.buf {
724            *x = 0.0;
725        }
726    }
727}
728
/// Parameter grid for batch runs.
///
/// Each axis is `(start, end, step)`. A `step` of 0, or `start == end`, selects
/// just `start`.
#[derive(Clone, Debug)]
pub struct AoBatchRange {
    pub short_period: (usize, usize, usize),
    pub long_period: (usize, usize, usize),
}

impl Default for AoBatchRange {
    /// Short period fixed at 5; long period swept over 34..=283 in steps of 1.
    fn default() -> Self {
        let short_period = (5, 5, 0);
        let long_period = (34, 283, 1);
        AoBatchRange {
            short_period,
            long_period,
        }
    }
}
743
744#[derive(Clone, Debug, Default)]
745pub struct AoBatchBuilder {
746    range: AoBatchRange,
747    kernel: Kernel,
748}
749
750impl AoBatchBuilder {
751    pub fn new() -> Self {
752        Self::default()
753    }
754    pub fn kernel(mut self, k: Kernel) -> Self {
755        self.kernel = k;
756        self
757    }
758    pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
759        self.range.short_period = (start, end, step);
760        self
761    }
762    pub fn short_static(mut self, v: usize) -> Self {
763        self.range.short_period = (v, v, 0);
764        self
765    }
766    pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
767        self.range.long_period = (start, end, step);
768        self
769    }
770    pub fn long_static(mut self, v: usize) -> Self {
771        self.range.long_period = (v, v, 0);
772        self
773    }
774    pub fn apply_slice(self, data: &[f64]) -> Result<AoBatchOutput, AoError> {
775        ao_batch_with_kernel(data, &self.range, self.kernel)
776    }
777    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<AoBatchOutput, AoError> {
778        AoBatchBuilder::new().kernel(k).apply_slice(data)
779    }
780    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<AoBatchOutput, AoError> {
781        let slice = source_type(c, src);
782        self.apply_slice(slice)
783    }
784    pub fn with_default_candles(c: &Candles) -> Result<AoBatchOutput, AoError> {
785        AoBatchBuilder::new()
786            .kernel(Kernel::Auto)
787            .apply_candles(c, "hl2")
788    }
789}
790
791pub fn ao_batch_with_kernel(
792    data: &[f64],
793    sweep: &AoBatchRange,
794    k: Kernel,
795) -> Result<AoBatchOutput, AoError> {
796    let kernel = match k {
797        Kernel::Auto => detect_best_batch_kernel(),
798        other if other.is_batch() => other,
799        other => return Err(AoError::InvalidKernelForBatch(other)),
800    };
801    let simd = match kernel {
802        Kernel::Avx512Batch => Kernel::Avx512,
803        Kernel::Avx2Batch => Kernel::Avx2,
804        Kernel::ScalarBatch => Kernel::Scalar,
805        _ => unreachable!(),
806    };
807    ao_batch_par_slice(data, sweep, simd)
808}
809
810#[derive(Clone, Debug)]
811pub struct AoBatchOutput {
812    pub values: Vec<f64>,
813    pub combos: Vec<AoParams>,
814    pub rows: usize,
815    pub cols: usize,
816}
817impl AoBatchOutput {
818    pub fn row_for_params(&self, p: &AoParams) -> Option<usize> {
819        self.combos.iter().position(|c| {
820            c.short_period.unwrap_or(5) == p.short_period.unwrap_or(5)
821                && c.long_period.unwrap_or(34) == p.long_period.unwrap_or(34)
822        })
823    }
824    pub fn values_for(&self, p: &AoParams) -> Option<&[f64]> {
825        self.row_for_params(p).map(|row| {
826            let start = row * self.cols;
827            &self.values[start..start + self.cols]
828        })
829    }
830}
831
832#[inline(always)]
833fn expand_grid_checked(r: &AoBatchRange) -> Result<Vec<AoParams>, AoError> {
834    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, AoError> {
835        if step == 0 {
836            return Ok(vec![start]);
837        }
838        if start == end {
839            return Ok(vec![start]);
840        }
841        let mut out = Vec::new();
842        if start < end {
843            let mut v = start;
844            while v <= end {
845                out.push(v);
846                match v.checked_add(step) {
847                    Some(next) => {
848                        if next == v {
849                            break;
850                        }
851                        v = next;
852                    }
853                    None => break,
854                }
855            }
856        } else {
857            let mut v = start;
858            while v >= end {
859                out.push(v);
860                if v < end + step {
861                    break;
862                }
863                v -= step;
864                if v == 0 {
865                    break;
866                }
867            }
868        }
869        if out.is_empty() {
870            return Err(AoError::InvalidRange { start, end, step });
871        }
872        Ok(out)
873    }
874    let shorts = axis_usize(r.short_period)?;
875    let longs = axis_usize(r.long_period)?;
876
877    let cap = shorts
878        .len()
879        .checked_mul(longs.len())
880        .ok_or_else(|| AoError::InvalidInput("rows*cols overflow".into()))?;
881    let mut out = Vec::with_capacity(cap);
882    for &s in &shorts {
883        for &l in &longs {
884            if s < l && s > 0 && l > 0 {
885                out.push(AoParams {
886                    short_period: Some(s),
887                    long_period: Some(l),
888                });
889            }
890        }
891    }
892    if out.is_empty() {
893        return Err(AoError::InvalidInput(
894            "no valid parameter combinations".into(),
895        ));
896    }
897    Ok(out)
898}
899
900#[inline(always)]
901pub fn ao_batch_slice(
902    data: &[f64],
903    sweep: &AoBatchRange,
904    kern: Kernel,
905) -> Result<AoBatchOutput, AoError> {
906    ao_batch_inner(data, sweep, kern, false)
907}
908#[inline(always)]
909pub fn ao_batch_par_slice(
910    data: &[f64],
911    sweep: &AoBatchRange,
912    kern: Kernel,
913) -> Result<AoBatchOutput, AoError> {
914    ao_batch_inner(data, sweep, kern, true)
915}
916#[inline(always)]
917fn ao_batch_inner_into(
918    data: &[f64],
919    sweep: &AoBatchRange,
920    kern: Kernel,
921    parallel: bool,
922    out: &mut [f64],
923) -> Result<Vec<AoParams>, AoError> {
924    let combos = expand_grid_checked(sweep)?;
925    let first = data
926        .iter()
927        .position(|x| !x.is_nan())
928        .ok_or(AoError::AllValuesNaN)?;
929    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
930    if data.len() - first < max_long {
931        return Err(AoError::NotEnoughValidData {
932            needed: max_long,
933            valid: data.len() - first,
934        });
935    }
936    let rows = combos.len();
937    let cols = data.len();
938    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
939        let short = combos[row].short_period.unwrap();
940        let long = combos[row].long_period.unwrap();
941        match kern {
942            Kernel::Scalar => ao_row_scalar(data, first, short, long, out_row),
943            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
944            Kernel::Avx2 => ao_row_avx2(data, first, short, long, out_row),
945            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
946            Kernel::Avx512 => ao_row_avx512(data, first, short, long, out_row),
947            _ => unreachable!(),
948        }
949    };
950    if parallel {
951        #[cfg(not(target_arch = "wasm32"))]
952        {
953            out.par_chunks_mut(cols)
954                .enumerate()
955                .for_each(|(row, slice)| do_row(row, slice));
956        }
957
958        #[cfg(target_arch = "wasm32")]
959        {
960            for (row, slice) in out.chunks_mut(cols).enumerate() {
961                do_row(row, slice);
962            }
963        }
964    } else {
965        for (row, slice) in out.chunks_mut(cols).enumerate() {
966            do_row(row, slice);
967        }
968    }
969    Ok(combos)
970}
971
972#[inline(always)]
973fn ao_batch_inner(
974    data: &[f64],
975    sweep: &AoBatchRange,
976    kern: Kernel,
977    parallel: bool,
978) -> Result<AoBatchOutput, AoError> {
979    let combos = expand_grid_checked(sweep)?;
980    let rows = combos.len();
981    let cols = data.len();
982
983    if cols == 0 {
984        return Err(AoError::EmptyInputData);
985    }
986
987    let first = data
988        .iter()
989        .position(|x| !x.is_nan())
990        .ok_or(AoError::AllValuesNaN)?;
991    let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
992    if data.len() - first < max_long {
993        return Err(AoError::NotEnoughValidData {
994            needed: max_long,
995            valid: data.len() - first,
996        });
997    }
998
999    let _total = rows
1000        .checked_mul(cols)
1001        .ok_or_else(|| AoError::InvalidInput("rows*cols overflow".into()))?;
1002
1003    let mut buf_mu = make_uninit_matrix(rows, cols);
1004
1005    let warm: Vec<usize> = combos
1006        .iter()
1007        .map(|c| first + c.long_period.unwrap() - 1)
1008        .collect();
1009
1010    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1011
1012    let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
1013    let values_ptr = buf_guard.as_mut_ptr() as *mut f64;
1014    let values_len = buf_guard.len();
1015    let values_cap = buf_guard.capacity();
1016
1017    let values = unsafe {
1018        let slice = std::slice::from_raw_parts_mut(values_ptr, values_len);
1019
1020        ao_batch_inner_into(data, sweep, kern, parallel, slice)?;
1021
1022        Vec::from_raw_parts(values_ptr, values_len, values_cap)
1023    };
1024
1025    Ok(AoBatchOutput {
1026        values,
1027        combos,
1028        rows,
1029        cols,
1030    })
1031}
1032#[inline(always)]
1033unsafe fn ao_row_scalar(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
1034    ao_scalar(data, short, long, first, out)
1035}
/// AVX2 row kernel for batch runs; currently forwards to [`ao_scalar`].
///
/// # Safety
/// No requirements beyond `ao_scalar`'s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx2(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
    ao_scalar(data, short, long, first, out);
}
/// AVX-512 row entry point.
///
/// Routes to `ao_row_avx512_short` when `long <= 32` and to
/// `ao_row_avx512_long` otherwise. Both variants currently forward to
/// `ao_scalar`, so the split is only a dispatch hook for future kernels.
///
/// # Safety
/// Callers must uphold whatever preconditions `ao_scalar` imposes (presumably
/// `out.len() == data.len()` and `first + long <= data.len()` — TODO confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    match long {
        0..=32 => ao_row_avx512_short(data, first, short, long, out),
        _ => ao_row_avx512_long(data, first, short, long, out),
    }
}
/// AVX-512 row kernel selected by `ao_row_avx512` when `long <= 32`.
///
/// No specialized implementation exists yet; this forwards to `ao_scalar`.
///
/// # Safety
/// Callers must uphold whatever preconditions `ao_scalar` imposes (presumably
/// `out.len() == data.len()` and `first + long <= data.len()` — TODO confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512_short(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    let (short_len, long_len) = (short, long);
    ao_scalar(data, short_len, long_len, first, out);
}
/// AVX-512 row kernel selected by `ao_row_avx512` when `long > 32`.
///
/// No specialized implementation exists yet; this forwards to `ao_scalar`.
///
/// # Safety
/// Callers must uphold whatever preconditions `ao_scalar` imposes (presumably
/// `out.len() == data.len()` and `first + long <= data.len()` — TODO confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512_long(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    let periods = (short, long);
    ao_scalar(data, periods.0, periods.1, first, out);
}
1078
#[cfg(test)]
mod tests {
    //! Tests for the AO indicator: single-series API parity, per-kernel checks,
    //! batch sweeps, debug-build poison detection and (behind the `proptest`
    //! feature) property-based comparison against the scalar kernel.

    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    use paste::paste;

    /// Reference 4h candle data shared by the CSV-backed tests.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// AO values expected for the last five bars of `CANDLES_CSV` with default
    /// parameters on the `hl2` source (compared with a 0.1 tolerance).
    const EXPECTED_LAST_FIVE: [f64; 5] =
        [-1671.3, -1401.6706, -1262.3559, -1178.4941, -1157.4118];

    /// Bit patterns that mark uninitialized memory from each debug-build
    /// allocation helper. Seeing one in an output means that slot was
    /// presumably never overwritten with a computed value.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Returns the name of the helper whose poison pattern equals `bits`, if any.
    #[cfg(debug_assertions)]
    fn poison_source(bits: u64) -> Option<&'static str> {
        POISON_PATTERNS
            .iter()
            .find(|&&(pattern, _)| pattern == bits)
            .map(|&(_, helper)| helper)
    }

    #[test]
    fn test_ao_into_matches_api() {
        let len = 256;
        let mut data = Vec::with_capacity(len);
        for i in 0..len {
            let x = i as f64;
            data.push((x * 0.01).mul_add(1.0, (x * 0.0314159).sin()));
        }

        let input = AoInput::from_slice(&data, AoParams::default());

        let base = ao(&input).expect("ao() should succeed");

        let mut out = vec![0.0f64; len];
        ao_into(&input, &mut out).expect("ao_into() should succeed");

        assert_eq!(base.values.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b)
        }

        for i in 0..len {
            assert!(
                eq_or_both_nan(base.values[i], out[i]),
                "Mismatch at index {}: got {}, expected {}",
                i,
                out[i],
                base.values[i]
            );
        }
    }

    /// Leaving `long_period` unset must still produce one output per bar.
    fn check_ao_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let partial_params = AoParams {
            short_period: Some(3),
            long_period: None,
        };
        let input = AoInput::from_candles(&candles, "hl2", partial_params);
        let result = ao_with_kernel(&input, kernel)?;
        assert_eq!(result.values.len(), candles.close.len());
        Ok(())
    }

    /// The last five default-parameter values match `EXPECTED_LAST_FIVE`.
    fn check_ao_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = AoInput::with_default_candles(&candles);
        let result = ao_with_kernel(&input, kernel)?;
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - EXPECTED_LAST_FIVE[i]).abs();
            assert!(
                diff < 1e-1,
                "[{}] AO {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                EXPECTED_LAST_FIVE[i]
            );
        }
        Ok(())
    }

    /// `with_default_candles` selects the `hl2` source and yields a full-length output.
    fn check_ao_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = AoInput::with_default_candles(&candles);
        match input.data {
            AoData::Candles { source, .. } => assert_eq!(source, "hl2"),
            _ => panic!("Expected AoData::Candles"),
        }
        let output = ao_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero short period must be rejected.
    fn check_ao_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = AoParams {
            short_period: Some(0),
            long_period: Some(34),
        };
        let input = AoInput::from_slice(&input_data, params);
        let res = ao_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] AO should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// Periods longer than the input must be rejected.
    fn check_ao_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = AoParams {
            short_period: Some(5),
            long_period: Some(10),
        };
        let input = AoInput::from_slice(&data_small, params);
        let res = ao_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] AO should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A single data point is never enough for the default-sized windows.
    fn check_ao_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = AoParams {
            short_period: Some(5),
            long_period: Some(34),
        };
        let input = AoInput::from_slice(&single_point, params);
        let res = ao_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] AO should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// AO output (with its NaN warm-up prefix) can be fed back in as input.
    fn check_ao_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let first_params = AoParams {
            short_period: Some(5),
            long_period: Some(34),
        };
        let first_input = AoInput::from_candles(&candles, "hl2", first_params);
        let first_result = ao_with_kernel(&first_input, kernel)?;
        let second_params = AoParams {
            short_period: Some(3),
            long_period: Some(10),
        };
        let second_input = AoInput::from_slice(&first_result.values, second_params);
        let second_result = ao_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    /// Past index 240 (well beyond the warm-up) no NaN may appear.
    fn check_ao_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = AoInput::from_candles(
            &candles,
            "hl2",
            AoParams {
                short_period: Some(5),
                long_period: Some(34),
            },
        );
        let res = ao_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    /// Debug builds only: no output slot may still hold an allocation poison pattern.
    #[cfg(debug_assertions)]
    fn check_ao_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = vec![
            AoParams::default(),
            AoParams {
                short_period: Some(2),
                long_period: Some(10),
            },
            AoParams {
                short_period: Some(3),
                long_period: Some(20),
            },
            AoParams {
                short_period: Some(10),
                long_period: Some(50),
            },
            AoParams {
                short_period: Some(20),
                long_period: Some(100),
            },
            AoParams {
                short_period: Some(5),
                long_period: Some(200),
            },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = AoInput::from_candles(&candles, "hl2", params.clone());
            let result = ao_with_kernel(&input, kernel)?;

            for (i, &val) in result.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(helper) = poison_source(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: short={}, long={} (param set {})",
                        test_name,
                        helper,
                        val,
                        bits,
                        i,
                        params.short_period.unwrap_or(5),
                        params.long_period.unwrap_or(34),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds carry no poison patterns, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_ao_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Property test: warm-up NaNs, finiteness, range bounds, monotonic-trend
    /// sign, linear-data stability and agreement with the scalar kernel.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_ao_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=50).prop_flat_map(|short_period| {
            ((short_period + 1)..=100).prop_flat_map(move |long_period| {
                (
                    prop::collection::vec(
                        (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                        long_period..400,
                    ),
                    Just(short_period),
                    Just(long_period),
                )
            })
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, short_period, long_period)| {
                let params = AoParams {
                    short_period: Some(short_period),
                    long_period: Some(long_period),
                };
                let input = AoInput::from_slice(&data, params);

                let AoOutput { values: out } = ao_with_kernel(&input, kernel).unwrap();

                let AoOutput { values: ref_out } = ao_with_kernel(&input, Kernel::Scalar).unwrap();

                for i in 0..(long_period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in (long_period - 1)..data.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "Expected finite value after warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() >= long_period
                {
                    for i in (long_period - 1)..data.len() {
                        prop_assert!(
                            out[i].abs() < 1e-9,
                            "For constant data, AO should be 0 at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                // Both SMAs average values from the long window, so their
                // difference is bounded by that window's range.
                if data.len() >= long_period {
                    for i in (long_period - 1)..data.len() {
                        let long_start = i + 1 - long_period;

                        let long_window = &data[long_start..=i];
                        let window_min = long_window.iter().cloned().fold(f64::INFINITY, f64::min);
                        let window_max = long_window
                            .iter()
                            .cloned()
                            .fold(f64::NEG_INFINITY, f64::max);

                        let theoretical_max = window_max - window_min;

                        prop_assert!(
                            out[i].abs() <= theoretical_max + 1e-9,
                            "AO value {} at index {} exceeds theoretical max {}",
                            out[i],
                            i,
                            theoretical_max
                        );
                    }
                }

                if short_period == 1 && long_period == 2 && data.len() >= 2 {
                    for i in 1..data.len() {
                        let expected = data[i] - (data[i] + data[i - 1]) / 2.0;
                        let actual = out[i];
                        prop_assert!(
                            (actual - expected).abs() < 1e-9,
                            "Special case (short=1, long=2) mismatch at index {}: expected {}, got {}",
                            i,
                            expected,
                            actual
                        );
                    }
                }

                let is_increasing = data.windows(2).all(|w| w[1] > w[0] + 1e-10);
                let is_decreasing = data.windows(2).all(|w| w[1] < w[0] - 1e-10);

                if is_increasing && data.len() >= long_period {
                    for i in (long_period - 1)..data.len() {
                        prop_assert!(
                            out[i] > -1e-9,
                            "For strictly increasing data, AO should be positive at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                if is_decreasing && data.len() >= long_period {
                    for i in (long_period - 1)..data.len() {
                        prop_assert!(
                            out[i] < 1e-9,
                            "For strictly decreasing data, AO should be negative at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                if data.len() >= 3 {
                    let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
                    let is_linear = diffs.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);

                    if is_linear && data.len() >= long_period + 10 {
                        let stable_start = long_period + 5;
                        if stable_start < data.len() - 1 {
                            let stable_values = &out[stable_start..];
                            if stable_values.len() >= 2 {
                                let first_stable = stable_values[0];
                                for (idx, &val) in stable_values.iter().enumerate() {
                                    prop_assert!(
                                        (val - first_stable).abs() < 1e-8,
                                        "For linear data, AO should stabilize. Value at {} differs from stable value: {} vs {}",
                                        stable_start + idx,
                                        val,
                                        first_stable
                                    );
                                }
                            }
                        }
                    }
                }

                for i in 0..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    if y.is_nan() || r.is_nan() {
                        prop_assert!(
                            y.is_nan() && r.is_nan(),
                            "NaN mismatch at index {}: kernel={:?} is_nan={}, scalar is_nan={}",
                            i,
                            kernel,
                            y.is_nan(),
                            r.is_nan()
                        );
                        continue;
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at index {}: kernel={:?} value={}, scalar value={}, \
                         diff={}, ULP diff={}",
                        i,
                        kernel,
                        y,
                        r,
                        (y - r).abs(),
                        ulp_diff
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    // Generates one `#[test]` per kernel for each single-series check.
    // - The check's `Result` is unwrapped, so an `Err` fails the test instead of
    //   being silently discarded.
    // - SIMD variants are gated item by item. An attribute written in front of a
    //   `$(...)*` repetition is emitted once and attaches only to the first
    //   generated item.
    macro_rules! generate_all_ao_tests {
        ($($test_fn:ident),*) => {
            paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }
    generate_all_ao_tests!(
        check_ao_partial_params,
        check_ao_accuracy,
        check_ao_default_candles,
        check_ao_zero_period,
        check_ao_period_exceeds_length,
        check_ao_very_small_dataset,
        check_ao_reinput,
        check_ao_nan_handling,
        check_ao_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_ao_tests!(check_ao_property);

    #[test]
    fn test_output_len_mismatch_error() {
        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let params = AoParams {
            short_period: Some(2),
            long_period: Some(3),
        };
        let input = AoInput::from_slice(&data, params);

        let mut wrong_sized_buf = vec![0.0; 10];

        let result = ao_into_slice(&mut wrong_sized_buf, &input, Kernel::Auto);
        assert!(result.is_err());

        if let Err(AoError::OutputLengthMismatch { expected, got }) = result {
            assert_eq!(expected, 5);
            assert_eq!(got, 10);
        } else {
            panic!("Expected OutputLengthMismatch error");
        }
    }

    #[test]
    fn test_invalid_kernel_error() {
        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let sweep = AoBatchRange::default();

        let result = ao_batch_with_kernel(&data, &sweep, Kernel::Scalar);
        assert!(result.is_err());

        if let Err(AoError::InvalidKernelForBatch(kernel)) = result {
            assert!(matches!(kernel, Kernel::Scalar));
        } else {
            panic!("Expected InvalidKernelForBatch error");
        }
    }

    /// The batch row for `AoParams::default()` matches the single-series reference values.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let c = read_candles_from_csv(CANDLES_CSV)?;
        let output = AoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "hl2")?;
        let def = AoParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        let expected = EXPECTED_LAST_FIVE;
        let start = row.len().saturating_sub(5);
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    /// Debug builds only: no cell of any batch sweep may hold a poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        // (short_start, short_end, short_step, long_start, long_end, long_step)
        let test_configs = vec![
            (2, 10, 2, 15, 40, 5),
            (5, 20, 5, 30, 60, 10),
            (10, 30, 10, 40, 100, 20),
            (3, 3, 0, 10, 50, 10),
            (5, 15, 5, 34, 34, 0),
        ];

        for (cfg_idx, &(short_start, short_end, short_step, long_start, long_end, long_step)) in
            test_configs.iter().enumerate()
        {
            let sweep = AoBatchRange {
                short_period: (short_start, short_end, short_step),
                long_period: (long_start, long_end, long_step),
            };

            let output = ao_batch_with_kernel(source_type(&c, "hl2"), &sweep, kernel)?;

            for (row, combo) in output.combos.iter().enumerate() {
                let row_start = row * output.cols;
                let row_end = row_start + output.cols;
                let row_values = &output.values[row_start..row_end];

                for (col, &val) in row_values.iter().enumerate() {
                    if val.is_nan() {
                        continue;
                    }

                    let bits = val.to_bits();
                    let idx = row * output.cols + col;

                    if let Some(helper) = poison_source(bits) {
                        panic!(
                            "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                             at row {} col {} (flat index {}) with params: short={}, long={}",
                            test,
                            cfg_idx,
                            helper,
                            val,
                            bits,
                            row,
                            col,
                            idx,
                            combo.short_period.unwrap_or(5),
                            combo.long_period.unwrap_or(34)
                        );
                    }
                }
            }
        }

        Ok(())
    }

    /// Release builds carry no poison patterns, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates one `#[test]` per batch kernel (plus auto-detect) for a batch
    // check, unwrapping its `Result` so errors fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1764
// Python binding: Awesome Oscillator over the midpoint series built by
// `compute_hl2` from `high` and `low` (presumably (high + low) / 2).
//
// `kernel` is an optional kernel name checked by `validate_kernel` in
// single-series mode. The computation runs inside `py.allow_threads`, i.e. with
// the GIL released. Any `AoError` is raised as a Python `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ao")]
#[pyo3(signature = (high, low, short_period, long_period, kernel=None))]
pub fn ao_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let (high_slice, low_slice) = (high.as_slice()?, low.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let outcome = py.allow_threads(|| -> Result<Vec<f64>, AoError> {
        let midpoints = compute_hl2(high_slice, low_slice)?;
        let input = AoInput::from_slice(
            &midpoints,
            AoParams {
                short_period: Some(short_period),
                long_period: Some(long_period),
            },
        );
        let output = ao_with_kernel(&input, kern)?;
        Ok(output.values)
    });

    match outcome {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1799
// Python-facing streaming AO: feed one (high, low) bar at a time via `update`.
#[cfg(feature = "python")]
#[pyclass(name = "AoStream")]
pub struct AoStreamPy {
    // Underlying single-series stream; it receives bar midpoints.
    stream: AoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AoStreamPy {
    // Builds a stream for the given periods. Any `AoStream::try_new` error is
    // raised as a Python `ValueError`.
    #[new]
    fn new(short_period: usize, long_period: usize) -> PyResult<Self> {
        AoStream::try_new(AoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        })
        .map(|stream| AoStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes the bar midpoint `(high + low) / 2` into the stream and returns
    // whatever `AoStream::update` yields for it (presumably `None` until the
    // warm-up window is filled — confirm in `AoStream`).
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update((high + low) / 2.0)
    }
}
1824
// Python binding: computes AO for every (short, long) combination expanded from
// the two `(start, end, step)` ranges over the `compute_hl2` midpoint series.
//
// Returns a dict with:
//   "values"        -> float64 array reshaped to (rows, cols), one row per combo
//   "short_periods" -> uint64 array with each row's short period
//   "long_periods"  -> uint64 array with each row's long period
//
// All `AoError`s (bad ranges, empty / all-NaN / too-short input, non-batch
// kernel) are raised as Python `ValueError`s instead of panicking.
#[cfg(feature = "python")]
#[pyfunction(name = "ao_batch")]
#[pyo3(signature = (high, low, short_period_range, long_period_range, kernel=None))]
pub fn ao_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    let kern = validate_kernel(kernel, true)?;

    let sweep = AoBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
    };

    let combos = expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = high_slice.len();

    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array starts uninitialized. It is only handed to Python after
    // `ao_batch_inner_into` succeeds. By then the warm-up prefixes have been
    // written by `init_matrix_prefixes`, and the remaining cells presumably by
    // `ao_batch_inner_into` (confirm there).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [expected], false) };
    // SAFETY: `out_arr` was just created here; no other view of its data exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| -> Result<Vec<AoParams>, AoError> {
            let hl2 = compute_hl2(high_slice, low_slice)?;

            // Validate before touching the output buffer, mirroring
            // `ao_batch_inner`. `init_matrix_prefixes` runs ahead of
            // `ao_batch_inner_into`, so without these checks empty, all-NaN or
            // too-short input would reach it with warm-up lengths that do not
            // fit the row.
            if hl2.is_empty() {
                return Err(AoError::EmptyInputData);
            }
            let first = hl2
                .iter()
                .position(|x| !x.is_nan())
                .ok_or(AoError::AllValuesNaN)?;
            let valid = hl2.len() - first;
            let max_long = combos
                .iter()
                .map(|c| c.long_period.unwrap())
                .max()
                .unwrap_or(0);
            if valid < max_long {
                return Err(AoError::NotEnoughValidData {
                    needed: max_long,
                    valid,
                });
            }

            // Map the batch kernel to its per-row counterpart. Anything else is a
            // caller error rather than an impossible state, so report it the same
            // way `ao_batch_with_kernel` does instead of panicking.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(AoError::InvalidKernelForBatch(other)),
            };

            let warm: Vec<usize> = combos
                .iter()
                .map(|c| first + c.long_period.unwrap() - 1)
                .collect();

            // SAFETY: `f64` and `MaybeUninit<f64>` have identical size and
            // alignment, and the reinterpreted slice covers exactly
            // `slice_out`'s memory. `slice_mu` is not used after
            // `init_matrix_prefixes` returns.
            let slice_mu = unsafe {
                std::slice::from_raw_parts_mut(
                    slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                    slice_out.len(),
                )
            };

            init_matrix_prefixes(slice_mu, cols, &warm);

            ao_batch_inner_into(&hl2, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "short_periods",
        combos
            .iter()
            .map(|p| p.short_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "long_periods",
        combos
            .iter()
            .map(|p| p.long_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1914
1915#[cfg(all(feature = "python", feature = "cuda"))]
1916#[pyclass(module = "ta_indicators.cuda", unsendable)]
1917pub struct DeviceArrayF32AoPy {
1918    pub(crate) buf: Option<DeviceBuffer<f32>>,
1919    pub(crate) rows: usize,
1920    pub(crate) cols: usize,
1921    pub(crate) _ctx: Arc<Context>,
1922    pub(crate) device_id: u32,
1923}
1924
1925#[cfg(all(feature = "python", feature = "cuda"))]
1926#[pymethods]
1927impl DeviceArrayF32AoPy {
1928    #[getter]
1929    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
1930        let d = PyDict::new(py);
1931        d.set_item("shape", (self.rows, self.cols))?;
1932        d.set_item("typestr", "<f4")?;
1933        d.set_item(
1934            "strides",
1935            (
1936                self.cols * std::mem::size_of::<f32>(),
1937                std::mem::size_of::<f32>(),
1938            ),
1939        )?;
1940        let ptr = self
1941            .buf
1942            .as_ref()
1943            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
1944            .as_device_ptr()
1945            .as_raw() as usize;
1946        d.set_item("data", (ptr, false))?;
1947
1948        d.set_item("version", 3)?;
1949        Ok(d)
1950    }
1951
1952    fn __dlpack_device__(&self) -> (i32, i32) {
1953        (2, self.device_id as i32)
1954    }
1955
1956    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
1957    fn __dlpack__<'py>(
1958        &mut self,
1959        py: Python<'py>,
1960        stream: Option<pyo3::PyObject>,
1961        max_version: Option<pyo3::PyObject>,
1962        dl_device: Option<pyo3::PyObject>,
1963        copy: Option<pyo3::PyObject>,
1964    ) -> PyResult<PyObject> {
1965        let (kdl, alloc_dev) = self.__dlpack_device__();
1966        if let Some(dev_obj) = dl_device.as_ref() {
1967            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
1968                if dev_ty != kdl || dev_id != alloc_dev {
1969                    let wants_copy = copy
1970                        .as_ref()
1971                        .and_then(|c| c.extract::<bool>(py).ok())
1972                        .unwrap_or(false);
1973                    if wants_copy {
1974                        return Err(PyValueError::new_err(
1975                            "device copy not implemented for __dlpack__",
1976                        ));
1977                    } else {
1978                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
1979                    }
1980                }
1981            }
1982        }
1983
1984        if let Some(obj) = stream.as_ref() {
1985            if let Ok(i) = obj.extract::<i64>(py) {
1986                if i == 0 {
1987                    return Err(PyValueError::new_err(
1988                        "__dlpack__: stream 0 is disallowed for CUDA",
1989                    ));
1990                }
1991            }
1992        }
1993
1994        let buf = self
1995            .buf
1996            .take()
1997            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
1998
1999        let rows = self.rows;
2000        let cols = self.cols;
2001
2002        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
2003
2004        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
2005    }
2006}
2007
2008#[cfg(all(feature = "python", feature = "cuda"))]
2009#[pyfunction(name = "ao_cuda_batch_dev")]
2010#[pyo3(signature = (high, low, short_period_range, long_period_range, device_id=0))]
2011pub fn ao_cuda_batch_dev_py(
2012    py: Python<'_>,
2013    high: numpy::PyReadonlyArray1<'_, f32>,
2014    low: numpy::PyReadonlyArray1<'_, f32>,
2015    short_period_range: (usize, usize, usize),
2016    long_period_range: (usize, usize, usize),
2017    device_id: usize,
2018) -> PyResult<DeviceArrayF32AoPy> {
2019    use crate::cuda::cuda_available;
2020    if !cuda_available() {
2021        return Err(PyValueError::new_err("CUDA not available"));
2022    }
2023    let high_slice = high.as_slice()?;
2024    let low_slice = low.as_slice()?;
2025    if high_slice.len() != low_slice.len() {
2026        return Err(PyValueError::new_err("high/low length mismatch"));
2027    }
2028    let mut hl2_f32 = vec![0f32; high_slice.len()];
2029    for i in 0..high_slice.len() {
2030        let h = high_slice[i];
2031        let l = low_slice[i];
2032        hl2_f32[i] = (h + l) * 0.5;
2033    }
2034    let sweep = AoBatchRange {
2035        short_period: short_period_range,
2036        long_period: long_period_range,
2037    };
2038    let inner = py.allow_threads(|| {
2039        let cuda = CudaAo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2040        cuda.ao_batch_dev(&hl2_f32, &sweep)
2041            .map_err(|e| PyValueError::new_err(e.to_string()))
2042    })?;
2043    let crate::cuda::oscillators::ao_wrapper::DeviceArrayF32Ao {
2044        buf,
2045        rows,
2046        cols,
2047        ctx,
2048        device_id: dev_id,
2049    } = inner;
2050    Ok(DeviceArrayF32AoPy {
2051        buf: Some(buf),
2052        rows,
2053        cols,
2054        _ctx: ctx,
2055        device_id: dev_id,
2056    })
2057}
2058
2059#[cfg(all(feature = "python", feature = "cuda"))]
2060#[pyfunction(name = "ao_cuda_many_series_one_param_dev")]
2061#[pyo3(signature = (high_tm, low_tm, cols, rows, short_period, long_period, device_id=0))]
2062pub fn ao_cuda_many_series_one_param_dev_py(
2063    py: Python<'_>,
2064    high_tm: numpy::PyReadonlyArray1<'_, f32>,
2065    low_tm: numpy::PyReadonlyArray1<'_, f32>,
2066    cols: usize,
2067    rows: usize,
2068    short_period: usize,
2069    long_period: usize,
2070    device_id: usize,
2071) -> PyResult<DeviceArrayF32AoPy> {
2072    use crate::cuda::cuda_available;
2073    if !cuda_available() {
2074        return Err(PyValueError::new_err("CUDA not available"));
2075    }
2076    let high_slice = high_tm.as_slice()?;
2077    let low_slice = low_tm.as_slice()?;
2078    let expected = cols
2079        .checked_mul(rows)
2080        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
2081    if high_slice.len() != expected || low_slice.len() != expected {
2082        return Err(PyValueError::new_err("time-major input length mismatch"));
2083    }
2084    let mut hl2_f32 = vec![0f32; expected];
2085    for i in 0..expected {
2086        hl2_f32[i] = (high_slice[i] + low_slice[i]) * 0.5;
2087    }
2088    let inner = py.allow_threads(|| {
2089        let cuda = CudaAo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2090        cuda.ao_many_series_one_param_time_major_dev(
2091            &hl2_f32,
2092            cols,
2093            rows,
2094            short_period,
2095            long_period,
2096        )
2097        .map_err(|e| PyValueError::new_err(e.to_string()))
2098    })?;
2099    let crate::cuda::oscillators::ao_wrapper::DeviceArrayF32Ao {
2100        buf,
2101        rows,
2102        cols,
2103        ctx,
2104        device_id: dev_id,
2105    } = inner;
2106    Ok(DeviceArrayF32AoPy {
2107        buf: Some(buf),
2108        rows,
2109        cols,
2110        _ctx: ctx,
2111        device_id: dev_id,
2112    })
2113}
2114
2115#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2116#[wasm_bindgen]
2117pub fn ao_js(
2118    high: &[f64],
2119    low: &[f64],
2120    short_period: usize,
2121    long_period: usize,
2122) -> Result<Vec<f64>, JsValue> {
2123    let hl2 = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;
2124
2125    let params = AoParams {
2126        short_period: Some(short_period),
2127        long_period: Some(long_period),
2128    };
2129    let input = AoInput::from_slice(&hl2, params);
2130
2131    ao_with_kernel(&input, Kernel::Auto)
2132        .map(|o| o.values)
2133        .map_err(|e| JsValue::from_str(&e.to_string()))
2134}
2135
2136#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2137#[wasm_bindgen]
2138pub fn ao_batch_js(
2139    high: &[f64],
2140    low: &[f64],
2141    short_start: usize,
2142    short_end: usize,
2143    short_step: usize,
2144    long_start: usize,
2145    long_end: usize,
2146    long_step: usize,
2147) -> Result<Vec<f64>, JsValue> {
2148    let hl2 = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;
2149
2150    let sweep = AoBatchRange {
2151        short_period: (short_start, short_end, short_step),
2152        long_period: (long_start, long_end, long_step),
2153    };
2154
2155    ao_batch_inner(&hl2, &sweep, Kernel::Scalar, false)
2156        .map(|output| output.values)
2157        .map_err(|e| JsValue::from_str(&e.to_string()))
2158}
2159
2160#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2161#[wasm_bindgen]
2162pub fn ao_batch_metadata_js(
2163    short_start: usize,
2164    short_end: usize,
2165    short_step: usize,
2166    long_start: usize,
2167    long_end: usize,
2168    long_step: usize,
2169) -> Result<Vec<f64>, JsValue> {
2170    let sweep = AoBatchRange {
2171        short_period: (short_start, short_end, short_step),
2172        long_period: (long_start, long_end, long_step),
2173    };
2174
2175    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
2176    let mut metadata = Vec::with_capacity(combos.len() * 2);
2177
2178    for combo in combos {
2179        metadata.push(combo.short_period.unwrap() as f64);
2180        metadata.push(combo.long_period.unwrap() as f64);
2181    }
2182
2183    Ok(metadata)
2184}
2185
2186#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2187#[derive(Serialize, Deserialize)]
2188pub struct AoBatchConfig {
2189    pub short_period_range: (usize, usize, usize),
2190    pub long_period_range: (usize, usize, usize),
2191}
2192
2193#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2194#[derive(Serialize, Deserialize)]
2195pub struct AoBatchJsOutput {
2196    pub values: Vec<f64>,
2197    pub combos: Vec<AoParams>,
2198    pub rows: usize,
2199    pub cols: usize,
2200}
2201
2202#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2203#[wasm_bindgen(js_name = ao_batch)]
2204pub fn ao_batch_unified_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
2205    let config: AoBatchConfig = serde_wasm_bindgen::from_value(config)
2206        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
2207
2208    let hl2 = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;
2209
2210    let sweep = AoBatchRange {
2211        short_period: config.short_period_range,
2212        long_period: config.long_period_range,
2213    };
2214
2215    let output = ao_batch_inner(&hl2, &sweep, Kernel::Scalar, false)
2216        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2217
2218    let js_output = AoBatchJsOutput {
2219        values: output.values,
2220        combos: output.combos,
2221        rows: output.rows,
2222        cols: output.cols,
2223    };
2224
2225    serde_wasm_bindgen::to_value(&js_output)
2226        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
2227}
2228
2229#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2230#[wasm_bindgen]
2231pub fn ao_alloc(len: usize) -> *mut f64 {
2232    let mut vec = Vec::<f64>::with_capacity(len);
2233    let ptr = vec.as_mut_ptr();
2234    std::mem::forget(vec);
2235    ptr
2236}
2237
2238#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2239#[wasm_bindgen]
2240pub fn ao_free(ptr: *mut f64, len: usize) {
2241    if !ptr.is_null() {
2242        unsafe {
2243            let _ = Vec::from_raw_parts(ptr, len, len);
2244        }
2245    }
2246}
2247
2248#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2249#[wasm_bindgen]
2250pub fn ao_into(
2251    in_high_ptr: *const f64,
2252    in_low_ptr: *const f64,
2253    out_ptr: *mut f64,
2254    len: usize,
2255    short_period: usize,
2256    long_period: usize,
2257) -> Result<(), JsValue> {
2258    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
2259        return Err(JsValue::from_str("Null pointer provided"));
2260    }
2261
2262    #[inline(always)]
2263    unsafe fn overlaps(a: *const f64, b: *const f64, n: usize) -> bool {
2264        let as_ = a as usize;
2265        let ae = as_.wrapping_add(n * core::mem::size_of::<f64>());
2266        let bs_ = b as usize;
2267        let be = bs_.wrapping_add(n * core::mem::size_of::<f64>());
2268        !(ae <= bs_ || be <= as_)
2269    }
2270
2271    unsafe {
2272        let high = std::slice::from_raw_parts(in_high_ptr, len);
2273        let low = std::slice::from_raw_parts(in_low_ptr, len);
2274        let out = std::slice::from_raw_parts_mut(out_ptr, len);
2275
2276        let alias = overlaps(in_high_ptr, out_ptr as *const f64, len)
2277            || overlaps(in_low_ptr, out_ptr as *const f64, len);
2278
2279        let hl2 = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;
2280        let params = AoParams {
2281            short_period: Some(short_period),
2282            long_period: Some(long_period),
2283        };
2284        let input = AoInput::from_slice(&hl2, params);
2285
2286        if alias {
2287            let mut tmp = vec![0.0; len];
2288            ao_into_slice(&mut tmp, &input, Kernel::Auto)
2289                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2290            out.copy_from_slice(&tmp);
2291        } else {
2292            ao_into_slice(out, &input, Kernel::Auto)
2293                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2294        }
2295        Ok(())
2296    }
2297}
2298
2299#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2300#[wasm_bindgen]
2301pub fn ao_batch_into(
2302    in_high_ptr: *const f64,
2303    in_low_ptr: *const f64,
2304    out_ptr: *mut f64,
2305    len: usize,
2306    short_period_start: usize,
2307    short_period_end: usize,
2308    short_period_step: usize,
2309    long_period_start: usize,
2310    long_period_end: usize,
2311    long_period_step: usize,
2312) -> Result<usize, JsValue> {
2313    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
2314        return Err(JsValue::from_str("Null pointer provided"));
2315    }
2316
2317    unsafe {
2318        let high = std::slice::from_raw_parts(in_high_ptr, len);
2319        let low = std::slice::from_raw_parts(in_low_ptr, len);
2320
2321        let sweep = AoBatchRange {
2322            short_period: (short_period_start, short_period_end, short_period_step),
2323            long_period: (long_period_start, long_period_end, long_period_step),
2324        };
2325
2326        let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
2327        let rows = combos.len();
2328        let total = rows
2329            .checked_mul(len)
2330            .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;
2331        let out = std::slice::from_raw_parts_mut(out_ptr, total);
2332
2333        let hl2 = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;
2334
2335        ao_batch_inner_into(&hl2, &sweep, Kernel::Scalar, false, out)
2336            .map_err(|e| JsValue::from_str(&e.to_string()))?;
2337
2338        Ok(rows)
2339    }
2340}