1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::error::Error;
30use std::mem::{ManuallyDrop, MaybeUninit};
31use thiserror::Error;
32
33#[cfg(all(feature = "python", feature = "cuda"))]
34use crate::cuda::CudaAo;
35#[cfg(all(feature = "python", feature = "cuda"))]
36use cust::context::Context;
37#[cfg(all(feature = "python", feature = "cuda"))]
38use cust::memory::DeviceBuffer;
39#[cfg(all(feature = "python", feature = "cuda"))]
40use std::sync::Arc;
41
42impl<'a> AsRef<[f64]> for AoInput<'a> {
43 #[inline(always)]
44 fn as_ref(&self) -> &[f64] {
45 match &self.data {
46 AoData::Slice(slice) => slice,
47 AoData::Candles { candles, source } => source_type(candles, source),
48 }
49 }
50}
51
/// Where the input series of an AO computation comes from.
#[derive(Debug, Clone)]
pub enum AoData<'a> {
    /// Borrowed candles plus the name of the source field (for example
    /// `"hl2"`) that `source_type` turns into a price slice.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed, caller-prepared series of values.
    Slice(&'a [f64]),
}
60
/// Result of a single-parameter AO computation.
#[derive(Debug, Clone)]
pub struct AoOutput {
    /// One value per input element; the warm-up prefix
    /// (`first_valid + long - 1` elements) is NaN.
    pub values: Vec<f64>,
}
65
/// Window lengths for AO. A `None` field falls back to the default used by
/// the consumers in this file (short = 5, long = 34).
///
/// `Serialize`/`Deserialize` are derived only when building for `wasm32`
/// with the `wasm` feature enabled.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AoParams {
    /// Length of the short averaging window; must end up non-zero and
    /// strictly smaller than the long window to pass validation.
    pub short_period: Option<usize>,
    /// Length of the long averaging window.
    pub long_period: Option<usize>,
}
75
/// Default parameters: short window 5, long window 34, both set explicitly.
impl Default for AoParams {
    fn default() -> Self {
        Self {
            short_period: Some(5),
            long_period: Some(34),
        }
    }
}
84
/// Everything a single AO computation needs: the data source and the
/// window parameters.
#[derive(Debug, Clone)]
pub struct AoInput<'a> {
    /// Borrowed data source (candles + field name, or a raw slice).
    pub data: AoData<'a>,
    /// Window lengths; unset fields resolve to 5 / 34 via the getters.
    pub params: AoParams,
}
90
91impl<'a> AoInput<'a> {
92 #[inline]
93 pub fn from_candles(c: &'a Candles, s: &'a str, p: AoParams) -> Self {
94 Self {
95 data: AoData::Candles {
96 candles: c,
97 source: s,
98 },
99 params: p,
100 }
101 }
102 #[inline]
103 pub fn from_slice(sl: &'a [f64], p: AoParams) -> Self {
104 Self {
105 data: AoData::Slice(sl),
106 params: p,
107 }
108 }
109 #[inline]
110 pub fn with_default_candles(c: &'a Candles) -> Self {
111 Self::from_candles(c, "hl2", AoParams::default())
112 }
113 #[inline]
114 pub fn get_short(&self) -> usize {
115 self.params.short_period.unwrap_or(5)
116 }
117 #[inline]
118 pub fn get_long(&self) -> usize {
119 self.params.long_period.unwrap_or(34)
120 }
121}
122
/// Fluent builder for one-off AO computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct AoBuilder {
    /// Short window; `None` leaves the choice to the consumer's default (5).
    short_period: Option<usize>,
    /// Long window; `None` leaves the choice to the consumer's default (34).
    long_period: Option<usize>,
    /// Kernel forwarded to `ao_with_kernel` by `apply` / `apply_slice`.
    kernel: Kernel,
}
129
/// A builder with both periods unset and `Kernel::Auto`.
impl Default for AoBuilder {
    fn default() -> Self {
        Self {
            short_period: None,
            long_period: None,
            kernel: Kernel::Auto,
        }
    }
}
139
140impl AoBuilder {
141 #[inline(always)]
142 pub fn new() -> Self {
143 Self::default()
144 }
145 #[inline(always)]
146 pub fn short_period(mut self, n: usize) -> Self {
147 self.short_period = Some(n);
148 self
149 }
150 #[inline(always)]
151 pub fn long_period(mut self, n: usize) -> Self {
152 self.long_period = Some(n);
153 self
154 }
155 #[inline(always)]
156 pub fn kernel(mut self, k: Kernel) -> Self {
157 self.kernel = k;
158 self
159 }
160
161 #[inline(always)]
162 pub fn apply(self, c: &Candles) -> Result<AoOutput, AoError> {
163 let p = AoParams {
164 short_period: self.short_period,
165 long_period: self.long_period,
166 };
167 let i = AoInput::from_candles(c, "hl2", p);
168 ao_with_kernel(&i, self.kernel)
169 }
170 #[inline(always)]
171 pub fn apply_slice(self, d: &[f64]) -> Result<AoOutput, AoError> {
172 let p = AoParams {
173 short_period: self.short_period,
174 long_period: self.long_period,
175 };
176 let i = AoInput::from_slice(d, p);
177 ao_with_kernel(&i, self.kernel)
178 }
179 #[inline(always)]
180 pub fn into_stream(self) -> Result<AoStream, AoError> {
181 let p = AoParams {
182 short_period: self.short_period,
183 long_period: self.long_period,
184 };
185 AoStream::try_new(p)
186 }
187}
188
/// Errors reported by the AO functions in this module.
#[derive(Debug, Error)]
pub enum AoError {
    /// The input series (or high/low pair) has zero length.
    #[error("ao: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("ao: All values are NaN.")]
    AllValuesNaN,
    /// At least one of the two periods is zero.
    #[error("ao: Invalid periods: short={short}, long={long}")]
    InvalidPeriods { short: usize, long: usize },
    /// `short >= long`; the short window must be strictly smaller.
    #[error("ao: Short period must be less than long period: short={short}, long={long}")]
    ShortPeriodNotLess { short: usize, long: usize },
    /// Fewer than `needed` elements remain from the first non-NaN value on.
    #[error("ao: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// `compute_hl2` received high/low slices of different lengths.
    #[error("ao: High and low arrays must have same length: high={high_len}, low={low_len}")]
    MismatchedArrayLengths { high_len: usize, low_len: usize },
    /// A caller-provided output buffer does not match the input length.
    #[error("ao: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A non-batch kernel was passed to the batch entry point.
    #[error("ao: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A `(start, end, step)` sweep axis could not be expanded.
    #[error("ao: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// Catch-all for other invalid inputs (e.g. size overflow, empty grid).
    #[error("ao: invalid input: {0}")]
    InvalidInput(String),
}
216
/// Computes AO for `input`, letting the runtime choose the kernel.
///
/// Equivalent to `ao_with_kernel(input, Kernel::Auto)`; any error comes
/// straight from that call.
#[inline]
pub fn ao(input: &AoInput) -> Result<AoOutput, AoError> {
    ao_with_kernel(input, Kernel::Auto)
}
221
222#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
223pub fn ao_into(input: &AoInput, out: &mut [f64]) -> Result<(), AoError> {
224 let (data, short, long, first, len) = ao_prepare(input)?;
225 if out.len() != len {
226 return Err(AoError::OutputLengthMismatch {
227 expected: len,
228 got: out.len(),
229 });
230 }
231
232 let warmup_end = first + long - 1;
233 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
234 for v in &mut out[..warmup_end.min(len)] {
235 *v = qnan;
236 }
237
238 let chosen = detect_best_kernel();
239 unsafe {
240 match chosen {
241 Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, out),
242 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
243 Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, out),
244 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
245 Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, out),
246 _ => unreachable!(),
247 }
248 }
249
250 Ok(())
251}
252
253#[inline]
254pub fn ao_into_slice(dst: &mut [f64], input: &AoInput, kern: Kernel) -> Result<(), AoError> {
255 let (data, short, long, first, len) = ao_prepare(input)?;
256 if dst.len() != len {
257 return Err(AoError::OutputLengthMismatch {
258 expected: len,
259 got: dst.len(),
260 });
261 }
262
263 let chosen = match kern {
264 Kernel::Auto => detect_best_kernel(),
265 k => k,
266 };
267
268 unsafe {
269 match chosen {
270 Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, dst),
271 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
272 Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, dst),
273 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
274 Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, dst),
275 _ => unreachable!(),
276 }
277 }
278
279 let warmup_end = first + long - 1;
280 for v in &mut dst[..warmup_end] {
281 *v = f64::NAN;
282 }
283 Ok(())
284}
285
286#[inline(always)]
287fn ao_prepare<'a>(input: &'a AoInput) -> Result<(&'a [f64], usize, usize, usize, usize), AoError> {
288 let data = input.as_ref();
289 if data.is_empty() {
290 return Err(AoError::EmptyInputData);
291 }
292
293 let first = data
294 .iter()
295 .position(|x| !x.is_nan())
296 .ok_or(AoError::AllValuesNaN)?;
297 let len = data.len();
298 let short = input.get_short();
299 let long = input.get_long();
300
301 if short == 0 || long == 0 {
302 return Err(AoError::InvalidPeriods { short, long });
303 }
304 if short >= long {
305 return Err(AoError::ShortPeriodNotLess { short, long });
306 }
307 if len - first < long {
308 return Err(AoError::NotEnoughValidData {
309 needed: long,
310 valid: len - first,
311 });
312 }
313 Ok((data, short, long, first, len))
314}
315
316pub fn ao_with_kernel(input: &AoInput, kernel: Kernel) -> Result<AoOutput, AoError> {
317 let (data, short, long, first, len) = ao_prepare(input)?;
318
319 let chosen = match kernel {
320 Kernel::Auto => detect_best_kernel(),
321 other => other,
322 };
323
324 let warmup_period = first + long - 1;
325
326 let mut out = alloc_with_nan_prefix(len, warmup_period);
327
328 unsafe {
329 match chosen {
330 Kernel::Scalar | Kernel::ScalarBatch => ao_scalar(data, short, long, first, &mut out),
331 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
332 Kernel::Avx2 | Kernel::Avx2Batch => ao_avx2(data, short, long, first, &mut out),
333 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
334 Kernel::Avx512 | Kernel::Avx512Batch => ao_avx512(data, short, long, first, &mut out),
335 _ => unreachable!(),
336 }
337 }
338
339 Ok(AoOutput { values: out })
340}
341
342#[inline]
343pub fn compute_hl2(high: &[f64], low: &[f64]) -> Result<Vec<f64>, AoError> {
344 if high.len() != low.len() {
345 return Err(AoError::MismatchedArrayLengths {
346 high_len: high.len(),
347 low_len: low.len(),
348 });
349 }
350 if high.is_empty() {
351 return Err(AoError::EmptyInputData);
352 }
353
354 let mut out = alloc_with_nan_prefix(high.len(), 0);
355
356 for i in 0..high.len() {
357 unsafe {
358 *out.get_unchecked_mut(i) = (*high.get_unchecked(i) + *low.get_unchecked(i)) * 0.5;
359 }
360 }
361 Ok(out)
362}
363
/// Scalar AO kernel: for every index `i >= warm = first + long - 1` writes
/// `mean(data[i+1-short..=i]) - mean(data[i+1-long..=i])` into `out[i]`,
/// using running sums over both windows.
///
/// Elements of `out` below `warm` are left untouched; callers are expected
/// to fill the warm-up prefix themselves. Nothing is written when `data` is
/// empty or `warm >= data.len()`.
///
/// * `short`, `long` – window lengths, both non-zero.
/// * `first` – index of the first element to include in the windows.
/// * `out` – destination, at least `data.len()` elements long.
///
/// # Panics
/// Because this is a safe public function, contract violations panic instead
/// of touching memory out of bounds:
/// - `short == 0` or `long == 0`;
/// - `short > first + long` (the short window would start before index 0);
/// - `out.len() < data.len()` when there is something to write.
#[inline(always)]
pub fn ao_scalar(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    if len == 0 {
        return;
    }
    assert!(
        short >= 1 && long >= 1,
        "ao_scalar: periods must be non-zero (short={}, long={})",
        short,
        long
    );

    // Saturate so an absurd `first` is treated as "no output" rather than
    // wrapping around to a small index.
    let warm = first.saturating_add(long - 1);
    if warm >= len {
        return;
    }
    assert!(
        short <= warm + 1,
        "ao_scalar: short period {} exceeds first + long = {}",
        short,
        warm + 1
    );
    assert!(
        out.len() >= len,
        "ao_scalar: output too short (need {}, got {})",
        len,
        out.len()
    );

    let inv_s = 1.0 / (short as f64);
    let inv_l = 1.0 / (long as f64);

    // Index of the oldest element of the short window that ends at `warm`.
    let short_start = warm + 1 - short;

    // Seed each running sum with all but the newest element of its first
    // window; the loop below adds the newest element before emitting.
    let mut long_sum = 0.0f64;
    for &v in &data[first..warm] {
        long_sum += v;
    }
    let mut short_sum = 0.0f64;
    for &v in &data[short_start..warm] {
        short_sum += v;
    }

    // Four cursors advance in lock-step: the output slot, the incoming
    // value, and the value leaving each window. `zip` stops at the shortest,
    // which is the output/incoming pair (`len - warm` elements).
    let incoming = &data[warm..];
    let long_leaving = &data[first..];
    let short_leaving = &data[short_start..];
    let slots = out[warm..len].iter_mut();

    for (((slot, &head), &tail_long), &tail_short) in slots
        .zip(incoming)
        .zip(long_leaving)
        .zip(short_leaving)
    {
        long_sum += head;
        short_sum += head;
        *slot = short_sum.mul_add(inv_s, -long_sum * inv_l);
        long_sum -= tail_long;
        short_sum -= tail_short;
    }
}
443
/// AVX-512 entry point for the single-series kernel.
///
/// Dispatches on the long period: `long <= 32` goes to `ao_avx512_short`,
/// anything larger to `ao_avx512_long`. As defined in this file, both of
/// those forward to `ao_scalar`, so no AVX-512 instructions are issued here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ao_avx512(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    if long <= 32 {
        unsafe { ao_avx512_short(data, short, long, first, out) }
    } else {
        unsafe { ao_avx512_long(data, short, long, first, out) }
    }
}
453
/// AVX2 entry point for the single-series kernel.
///
/// Currently a plain forward to `ao_scalar`; arguments and results are
/// exactly those of the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn ao_avx2(data: &[f64], short: usize, long: usize, first: usize, out: &mut [f64]) {
    unsafe { ao_scalar(data, short, long, first, out) }
}
459
/// AVX-512 variant selected by `ao_avx512` for `long <= 32`.
///
/// Currently a plain forward to `ao_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ao_avx512_short(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out)
}
471
/// AVX-512 variant selected by `ao_avx512` for `long > 32`.
///
/// Currently a plain forward to `ao_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn ao_avx512_long(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out)
}
483
/// AVX2 + FMA AO kernel built on prefix sums.
///
/// Builds `pref` with `pref[k]` = sum of the first `k` values of
/// `data[first..]`, so every window sum is the difference of two entries.
/// For each index `i >= warm = first + long - 1` it stores
/// `short_sum / short - long_sum / long` into `out[i]`, four lanes per
/// iteration with a scalar loop for the remainder. Elements of `out` below
/// `warm` are not written. Returns immediately when `data` is empty or
/// `warm >= data.len()`.
///
/// NOTE(review): no caller of this function is visible in the reviewed part
/// of the file; the AVX2 dispatchers above forward to `ao_scalar` instead.
///
/// # Safety
/// - The CPU must support AVX2 and FMA (required by `#[target_feature]`).
/// - `out` must hold at least `data.len()` elements; stores go through raw
///   pointers without bounds checks.
/// - `long >= 1` and `short <= long` are needed for the pointer offsets into
///   `pref` to stay in range (derived from `cur0.sub(short)` /
///   `cur0.sub(long)`); presumably callers validate this — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn ao_avx2_prefixsum(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    if len == 0 {
        return;
    }
    let warm = first + long - 1;
    if warm >= len {
        return;
    }

    // pref[0] = 0, pref[k] = data[first] + ... + data[first + k - 1].
    let suffix_len = len - first;
    let mut pref = Vec::<f64>::with_capacity(suffix_len + 1);
    pref.push(0.0);
    let mut acc = 0.0f64;
    let p = data.as_ptr().add(first);
    for k in 0..suffix_len {
        acc += *p.add(k);
        pref.push(acc);
    }
    let pref_ptr = pref.as_ptr();

    // Number of outputs, and the first output slot.
    let n_out = len - warm;
    let out_base = out.as_mut_ptr().add(warm);

    let inv_s = _mm256_set1_pd(1.0 / (short as f64));
    let inv_l = _mm256_set1_pd(1.0 / (long as f64));

    // For output j: window end is pref[long + j]; the window starts are
    // `short` and `long` entries earlier.
    let cur0 = pref_ptr.add(long);
    let mut cur = cur0;
    let mut prev_s = cur0.sub(short);
    let mut prev_l = cur0.sub(long);
    let mut dst = out_base;

    let vec_chunks = n_out / 4;
    for _ in 0..vec_chunks {
        let pc = _mm256_loadu_pd(cur);
        let ps = _mm256_loadu_pd(prev_s);
        let pl = _mm256_loadu_pd(prev_l);

        let short_sum = _mm256_sub_pd(pc, ps);
        let long_sum = _mm256_sub_pd(pc, pl);

        // fmsub computes short_sum * inv_s - z.
        let z = _mm256_mul_pd(long_sum, inv_l);
        let ao = _mm256_fmsub_pd(short_sum, inv_s, z);

        _mm256_storeu_pd(dst, ao);

        cur = cur.add(4);
        prev_s = prev_s.add(4);
        prev_l = prev_l.add(4);
        dst = dst.add(4);
    }

    // Scalar remainder (n_out mod 4 outputs); `dst` already points past the
    // vectorised part.
    let tail = n_out & 3;
    for t in 0..tail {
        let i_cur = long + (vec_chunks * 4 + t);
        let pc = *pref_ptr.add(i_cur);
        let ps = *pref_ptr.add(i_cur - short);
        let pl = *pref_ptr.add(i_cur - long);
        let y = pc - ps;
        let z = pc - pl;
        *dst.add(t) = y.mul_add(1.0 / (short as f64), -(z * (1.0 / (long as f64))));
    }
}
558
/// AVX-512F + FMA AO kernel built on prefix sums.
///
/// Same scheme as the AVX2 prefix-sum kernel, eight lanes per iteration:
/// `pref[k]` holds the sum of the first `k` values of `data[first..]`, and
/// each output `i >= warm = first + long - 1` is
/// `short_sum / short - long_sum / long` computed from differences of
/// prefix entries. Elements of `out` below `warm` are not written. Returns
/// immediately when `data` is empty or `warm >= data.len()`.
///
/// NOTE(review): no caller of this function is visible in the reviewed part
/// of the file; the AVX-512 dispatchers above forward to `ao_scalar`.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA (required by `#[target_feature]`).
/// - `out` must hold at least `data.len()` elements; stores go through raw
///   pointers without bounds checks.
/// - `long >= 1` and `short <= long` are needed for the pointer offsets into
///   `pref` to stay in range (derived from `cur0.sub(short)` /
///   `cur0.sub(long)`); presumably callers validate this — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn ao_avx512_prefixsum(
    data: &[f64],
    short: usize,
    long: usize,
    first: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;

    let len = data.len();
    if len == 0 {
        return;
    }
    let warm = first + long - 1;
    if warm >= len {
        return;
    }

    // pref[0] = 0, pref[k] = data[first] + ... + data[first + k - 1].
    let suffix_len = len - first;
    let mut pref = Vec::<f64>::with_capacity(suffix_len + 1);
    pref.push(0.0);
    let mut acc = 0.0f64;
    let p = data.as_ptr().add(first);
    for k in 0..suffix_len {
        acc += *p.add(k);
        pref.push(acc);
    }
    let pref_ptr = pref.as_ptr();

    // Number of outputs, and the first output slot.
    let n_out = len - warm;
    let out_base = out.as_mut_ptr().add(warm);

    let inv_s = _mm512_set1_pd(1.0 / (short as f64));
    let inv_l = _mm512_set1_pd(1.0 / (long as f64));

    // For output j: window end is pref[long + j]; the window starts are
    // `short` and `long` entries earlier.
    let cur0 = pref_ptr.add(long);
    let mut cur = cur0;
    let mut prev_s = cur0.sub(short);
    let mut prev_l = cur0.sub(long);
    let mut dst = out_base;

    let vec_chunks = n_out / 8;
    for _ in 0..vec_chunks {
        let pc = _mm512_loadu_pd(cur);
        let ps = _mm512_loadu_pd(prev_s);
        let pl = _mm512_loadu_pd(prev_l);

        let short_sum = _mm512_sub_pd(pc, ps);
        let long_sum = _mm512_sub_pd(pc, pl);

        // fmsub computes short_sum * inv_s - z.
        let z = _mm512_mul_pd(long_sum, inv_l);
        let ao = _mm512_fmsub_pd(short_sum, inv_s, z);

        _mm512_storeu_pd(dst, ao);

        cur = cur.add(8);
        prev_s = prev_s.add(8);
        prev_l = prev_l.add(8);
        dst = dst.add(8);
    }

    // Scalar remainder (n_out mod 8 outputs); `dst` already points past the
    // vectorised part.
    let tail = n_out & 7;
    for t in 0..tail {
        let i_cur = long + (vec_chunks * 8 + t);
        let pc = *pref_ptr.add(i_cur);
        let ps = *pref_ptr.add(i_cur - short);
        let pl = *pref_ptr.add(i_cur - long);
        let y = pc - ps;
        let z = pc - pl;
        *dst.add(t) = y.mul_add(1.0 / (short as f64), -(z * (1.0 / (long as f64))));
    }
}
633
/// Incremental AO calculator that consumes one value at a time.
#[derive(Debug, Clone)]
pub struct AoStream {
    /// Short window length.
    short: usize,
    /// Long window length; also the capacity of `buf`.
    long: usize,
    /// Cached `1.0 / short`.
    inv_short: f64,
    /// Cached `1.0 / long`.
    inv_long: f64,

    /// Ring buffer holding the last `long` values.
    buf: Vec<f64>,

    /// Index in `buf` that the next value will be written to.
    head: usize,

    /// Number of values stored so far, capped at `long`.
    filled: usize,

    /// Running sum over the short window.
    short_sum: f64,
    /// Running sum over the long window.
    long_sum: f64,
}
650
651impl AoStream {
652 pub fn try_new(params: AoParams) -> Result<Self, AoError> {
653 let short = params.short_period.unwrap_or(5);
654 let long = params.long_period.unwrap_or(34);
655 if short == 0 || long == 0 {
656 return Err(AoError::InvalidPeriods { short, long });
657 }
658 if short >= long {
659 return Err(AoError::ShortPeriodNotLess { short, long });
660 }
661 Ok(Self {
662 short,
663 long,
664 inv_short: 1.0 / (short as f64),
665 inv_long: 1.0 / (long as f64),
666 buf: vec![0.0; long],
667 head: 0,
668 filled: 0,
669 short_sum: 0.0,
670 long_sum: 0.0,
671 })
672 }
673
674 #[inline(always)]
675 pub fn update(&mut self, value: f64) -> Option<f64> {
676 let old_long = if self.filled == self.long {
677 self.buf[self.head]
678 } else {
679 0.0
680 };
681
682 let old_short = if self.filled >= self.short {
683 let idx = if self.head >= self.short {
684 self.head - self.short
685 } else {
686 self.head + self.long - self.short
687 };
688 self.buf[idx]
689 } else {
690 0.0
691 };
692
693 self.buf[self.head] = value;
694
695 self.head += 1;
696 if self.head == self.long {
697 self.head = 0;
698 }
699
700 if self.filled < self.long {
701 self.filled += 1;
702 }
703
704 self.long_sum += value - old_long;
705 self.short_sum += value - old_short;
706
707 if self.filled == self.long {
708 Some(
709 self.short_sum
710 .mul_add(self.inv_short, -(self.long_sum * self.inv_long)),
711 )
712 } else {
713 None
714 }
715 }
716
717 #[inline(always)]
718 pub fn reset(&mut self) {
719 self.head = 0;
720 self.filled = 0;
721 self.short_sum = 0.0;
722 self.long_sum = 0.0;
723 for x in &mut self.buf {
724 *x = 0.0;
725 }
726 }
727}
728
/// Parameter sweep for batch AO: one `(start, end, step)` triple per axis.
///
/// A `step` of 0 (or `start == end`) yields the single value `start`.
#[derive(Clone, Debug)]
pub struct AoBatchRange {
    /// Sweep of short window lengths as `(start, end, step)`.
    pub short_period: (usize, usize, usize),
    /// Sweep of long window lengths as `(start, end, step)`.
    pub long_period: (usize, usize, usize),
}
734
/// Default sweep: short fixed at 5, long from 34 to 283 in steps of 1.
impl Default for AoBatchRange {
    fn default() -> Self {
        Self {
            short_period: (5, 5, 0),
            long_period: (34, 283, 1),
        }
    }
}
743
/// Fluent builder for batch (parameter-sweep) AO runs.
#[derive(Clone, Debug, Default)]
pub struct AoBatchBuilder {
    /// The parameter grid to evaluate.
    range: AoBatchRange,
    /// Kernel handed to `ao_batch_with_kernel`.
    kernel: Kernel,
}
749
750impl AoBatchBuilder {
751 pub fn new() -> Self {
752 Self::default()
753 }
754 pub fn kernel(mut self, k: Kernel) -> Self {
755 self.kernel = k;
756 self
757 }
758 pub fn short_range(mut self, start: usize, end: usize, step: usize) -> Self {
759 self.range.short_period = (start, end, step);
760 self
761 }
762 pub fn short_static(mut self, v: usize) -> Self {
763 self.range.short_period = (v, v, 0);
764 self
765 }
766 pub fn long_range(mut self, start: usize, end: usize, step: usize) -> Self {
767 self.range.long_period = (start, end, step);
768 self
769 }
770 pub fn long_static(mut self, v: usize) -> Self {
771 self.range.long_period = (v, v, 0);
772 self
773 }
774 pub fn apply_slice(self, data: &[f64]) -> Result<AoBatchOutput, AoError> {
775 ao_batch_with_kernel(data, &self.range, self.kernel)
776 }
777 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<AoBatchOutput, AoError> {
778 AoBatchBuilder::new().kernel(k).apply_slice(data)
779 }
780 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<AoBatchOutput, AoError> {
781 let slice = source_type(c, src);
782 self.apply_slice(slice)
783 }
784 pub fn with_default_candles(c: &Candles) -> Result<AoBatchOutput, AoError> {
785 AoBatchBuilder::new()
786 .kernel(Kernel::Auto)
787 .apply_candles(c, "hl2")
788 }
789}
790
791pub fn ao_batch_with_kernel(
792 data: &[f64],
793 sweep: &AoBatchRange,
794 k: Kernel,
795) -> Result<AoBatchOutput, AoError> {
796 let kernel = match k {
797 Kernel::Auto => detect_best_batch_kernel(),
798 other if other.is_batch() => other,
799 other => return Err(AoError::InvalidKernelForBatch(other)),
800 };
801 let simd = match kernel {
802 Kernel::Avx512Batch => Kernel::Avx512,
803 Kernel::Avx2Batch => Kernel::Avx2,
804 Kernel::ScalarBatch => Kernel::Scalar,
805 _ => unreachable!(),
806 };
807 ao_batch_par_slice(data, sweep, simd)
808}
809
/// Result of a batch AO run: a row-major `rows x cols` matrix plus the
/// parameter set that produced each row.
#[derive(Clone, Debug)]
pub struct AoBatchOutput {
    /// Row-major matrix values; row `r` occupies `r * cols .. (r + 1) * cols`.
    pub values: Vec<f64>,
    /// Parameters per row, in row order.
    pub combos: Vec<AoParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
817impl AoBatchOutput {
818 pub fn row_for_params(&self, p: &AoParams) -> Option<usize> {
819 self.combos.iter().position(|c| {
820 c.short_period.unwrap_or(5) == p.short_period.unwrap_or(5)
821 && c.long_period.unwrap_or(34) == p.long_period.unwrap_or(34)
822 })
823 }
824 pub fn values_for(&self, p: &AoParams) -> Option<&[f64]> {
825 self.row_for_params(p).map(|row| {
826 let start = row * self.cols;
827 &self.values[start..start + self.cols]
828 })
829 }
830}
831
832#[inline(always)]
833fn expand_grid_checked(r: &AoBatchRange) -> Result<Vec<AoParams>, AoError> {
834 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, AoError> {
835 if step == 0 {
836 return Ok(vec![start]);
837 }
838 if start == end {
839 return Ok(vec![start]);
840 }
841 let mut out = Vec::new();
842 if start < end {
843 let mut v = start;
844 while v <= end {
845 out.push(v);
846 match v.checked_add(step) {
847 Some(next) => {
848 if next == v {
849 break;
850 }
851 v = next;
852 }
853 None => break,
854 }
855 }
856 } else {
857 let mut v = start;
858 while v >= end {
859 out.push(v);
860 if v < end + step {
861 break;
862 }
863 v -= step;
864 if v == 0 {
865 break;
866 }
867 }
868 }
869 if out.is_empty() {
870 return Err(AoError::InvalidRange { start, end, step });
871 }
872 Ok(out)
873 }
874 let shorts = axis_usize(r.short_period)?;
875 let longs = axis_usize(r.long_period)?;
876
877 let cap = shorts
878 .len()
879 .checked_mul(longs.len())
880 .ok_or_else(|| AoError::InvalidInput("rows*cols overflow".into()))?;
881 let mut out = Vec::with_capacity(cap);
882 for &s in &shorts {
883 for &l in &longs {
884 if s < l && s > 0 && l > 0 {
885 out.push(AoParams {
886 short_period: Some(s),
887 long_period: Some(l),
888 });
889 }
890 }
891 }
892 if out.is_empty() {
893 return Err(AoError::InvalidInput(
894 "no valid parameter combinations".into(),
895 ));
896 }
897 Ok(out)
898}
899
/// Runs the sweep over `data` one row after another.
///
/// Forwards to `ao_batch_inner` with `parallel = false`; `kern` is the
/// per-row kernel.
#[inline(always)]
pub fn ao_batch_slice(
    data: &[f64],
    sweep: &AoBatchRange,
    kern: Kernel,
) -> Result<AoBatchOutput, AoError> {
    ao_batch_inner(data, sweep, kern, false)
}
/// Runs the sweep over `data` with rows processed in parallel where the
/// target supports it.
///
/// Forwards to `ao_batch_inner` with `parallel = true`; `kern` is the
/// per-row kernel.
#[inline(always)]
pub fn ao_batch_par_slice(
    data: &[f64],
    sweep: &AoBatchRange,
    kern: Kernel,
) -> Result<AoBatchOutput, AoError> {
    ao_batch_inner(data, sweep, kern, true)
}
916#[inline(always)]
917fn ao_batch_inner_into(
918 data: &[f64],
919 sweep: &AoBatchRange,
920 kern: Kernel,
921 parallel: bool,
922 out: &mut [f64],
923) -> Result<Vec<AoParams>, AoError> {
924 let combos = expand_grid_checked(sweep)?;
925 let first = data
926 .iter()
927 .position(|x| !x.is_nan())
928 .ok_or(AoError::AllValuesNaN)?;
929 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
930 if data.len() - first < max_long {
931 return Err(AoError::NotEnoughValidData {
932 needed: max_long,
933 valid: data.len() - first,
934 });
935 }
936 let rows = combos.len();
937 let cols = data.len();
938 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
939 let short = combos[row].short_period.unwrap();
940 let long = combos[row].long_period.unwrap();
941 match kern {
942 Kernel::Scalar => ao_row_scalar(data, first, short, long, out_row),
943 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
944 Kernel::Avx2 => ao_row_avx2(data, first, short, long, out_row),
945 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
946 Kernel::Avx512 => ao_row_avx512(data, first, short, long, out_row),
947 _ => unreachable!(),
948 }
949 };
950 if parallel {
951 #[cfg(not(target_arch = "wasm32"))]
952 {
953 out.par_chunks_mut(cols)
954 .enumerate()
955 .for_each(|(row, slice)| do_row(row, slice));
956 }
957
958 #[cfg(target_arch = "wasm32")]
959 {
960 for (row, slice) in out.chunks_mut(cols).enumerate() {
961 do_row(row, slice);
962 }
963 }
964 } else {
965 for (row, slice) in out.chunks_mut(cols).enumerate() {
966 do_row(row, slice);
967 }
968 }
969 Ok(combos)
970}
971
972#[inline(always)]
973fn ao_batch_inner(
974 data: &[f64],
975 sweep: &AoBatchRange,
976 kern: Kernel,
977 parallel: bool,
978) -> Result<AoBatchOutput, AoError> {
979 let combos = expand_grid_checked(sweep)?;
980 let rows = combos.len();
981 let cols = data.len();
982
983 if cols == 0 {
984 return Err(AoError::EmptyInputData);
985 }
986
987 let first = data
988 .iter()
989 .position(|x| !x.is_nan())
990 .ok_or(AoError::AllValuesNaN)?;
991 let max_long = combos.iter().map(|c| c.long_period.unwrap()).max().unwrap();
992 if data.len() - first < max_long {
993 return Err(AoError::NotEnoughValidData {
994 needed: max_long,
995 valid: data.len() - first,
996 });
997 }
998
999 let _total = rows
1000 .checked_mul(cols)
1001 .ok_or_else(|| AoError::InvalidInput("rows*cols overflow".into()))?;
1002
1003 let mut buf_mu = make_uninit_matrix(rows, cols);
1004
1005 let warm: Vec<usize> = combos
1006 .iter()
1007 .map(|c| first + c.long_period.unwrap() - 1)
1008 .collect();
1009
1010 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1011
1012 let mut buf_guard = std::mem::ManuallyDrop::new(buf_mu);
1013 let values_ptr = buf_guard.as_mut_ptr() as *mut f64;
1014 let values_len = buf_guard.len();
1015 let values_cap = buf_guard.capacity();
1016
1017 let values = unsafe {
1018 let slice = std::slice::from_raw_parts_mut(values_ptr, values_len);
1019
1020 ao_batch_inner_into(data, sweep, kern, parallel, slice)?;
1021
1022 Vec::from_raw_parts(values_ptr, values_len, values_cap)
1023 };
1024
1025 Ok(AoBatchOutput {
1026 values,
1027 combos,
1028 rows,
1029 cols,
1030 })
1031}
/// Scalar row kernel for batch runs: forwards to `ao_scalar`, adapting the
/// argument order (`first` comes before the periods here).
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[inline(always)]
unsafe fn ao_row_scalar(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
    ao_scalar(data, short, long, first, out)
}
/// AVX2 row kernel for batch runs; currently a plain forward to `ao_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx2(data: &[f64], first: usize, short: usize, long: usize, out: &mut [f64]) {
    ao_scalar(data, short, long, first, out)
}
/// AVX-512 row kernel for batch runs.
///
/// Dispatches on the long period: `long <= 32` goes to
/// `ao_row_avx512_short`, anything larger to `ao_row_avx512_long`. As
/// defined in this file, both forward to `ao_scalar`.
///
/// # Safety
/// The body only calls the two `unsafe fn` variants. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    if long <= 32 {
        ao_row_avx512_short(data, first, short, long, out);
    } else {
        ao_row_avx512_long(data, first, short, long, out);
    }
}
/// AVX-512 row kernel used for `long <= 32`; currently a plain forward to
/// `ao_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512_short(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out)
}
/// AVX-512 row kernel used for `long > 32`; currently a plain forward to
/// `ao_scalar`.
///
/// # Safety
/// The body performs no unsafe operation of its own. Callers should uphold
/// `ao_scalar`'s preconditions: non-zero periods and `out` at least as long
/// as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn ao_row_avx512_long(
    data: &[f64],
    first: usize,
    short: usize,
    long: usize,
    out: &mut [f64],
) {
    ao_scalar(data, short, long, first, out)
}
1078
1079#[cfg(test)]
1080mod tests {
1081 use super::*;
1082 use crate::skip_if_unsupported;
1083 use crate::utilities::data_loader::read_candles_from_csv;
1084 use paste::paste;
1085
1086 #[test]
1087 fn test_ao_into_matches_api() {
1088 let len = 256;
1089 let mut data = Vec::with_capacity(len);
1090 for i in 0..len {
1091 let x = i as f64;
1092 data.push((x * 0.01).mul_add(1.0, (x * 0.0314159).sin()));
1093 }
1094
1095 let input = AoInput::from_slice(&data, AoParams::default());
1096
1097 let base = ao(&input).expect("ao() should succeed");
1098
1099 let mut out = vec![0.0f64; len];
1100 ao_into(&input, &mut out).expect("ao_into() should succeed");
1101
1102 assert_eq!(base.values.len(), out.len());
1103
1104 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1105 (a.is_nan() && b.is_nan()) || (a == b)
1106 }
1107
1108 for i in 0..len {
1109 assert!(
1110 eq_or_both_nan(base.values[i], out[i]),
1111 "Mismatch at index {}: got {}, expected {}",
1112 i,
1113 out[i],
1114 base.values[i]
1115 );
1116 }
1117 }
1118
1119 fn check_ao_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1120 skip_if_unsupported!(kernel, test_name);
1121 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1122 let candles = read_candles_from_csv(file_path)?;
1123 let partial_params = AoParams {
1124 short_period: Some(3),
1125 long_period: None,
1126 };
1127 let input = AoInput::from_candles(&candles, "hl2", partial_params);
1128 let result = ao_with_kernel(&input, kernel)?;
1129 assert_eq!(result.values.len(), candles.close.len());
1130 Ok(())
1131 }
1132 fn check_ao_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1133 skip_if_unsupported!(kernel, test_name);
1134 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1135 let candles = read_candles_from_csv(file_path)?;
1136 let input = AoInput::with_default_candles(&candles);
1137 let result = ao_with_kernel(&input, kernel)?;
1138 let expected_last_five = [-1671.3, -1401.6706, -1262.3559, -1178.4941, -1157.4118];
1139 let start = result.values.len().saturating_sub(5);
1140 for (i, &val) in result.values[start..].iter().enumerate() {
1141 let diff = (val - expected_last_five[i]).abs();
1142 assert!(
1143 diff < 1e-1,
1144 "[{}] AO {:?} mismatch at idx {}: got {}, expected {}",
1145 test_name,
1146 kernel,
1147 i,
1148 val,
1149 expected_last_five[i]
1150 );
1151 }
1152 Ok(())
1153 }
1154 fn check_ao_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1155 skip_if_unsupported!(kernel, test_name);
1156 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1157 let candles = read_candles_from_csv(file_path)?;
1158 let input = AoInput::with_default_candles(&candles);
1159 match input.data {
1160 AoData::Candles { source, .. } => assert_eq!(source, "hl2"),
1161 _ => panic!("Expected AoData::Candles"),
1162 }
1163 let output = ao_with_kernel(&input, kernel)?;
1164 assert_eq!(output.values.len(), candles.close.len());
1165 Ok(())
1166 }
1167 fn check_ao_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1168 skip_if_unsupported!(kernel, test_name);
1169 let input_data = [10.0, 20.0, 30.0];
1170 let params = AoParams {
1171 short_period: Some(0),
1172 long_period: Some(34),
1173 };
1174 let input = AoInput::from_slice(&input_data, params);
1175 let res = ao_with_kernel(&input, kernel);
1176 assert!(
1177 res.is_err(),
1178 "[{}] AO should fail with zero period",
1179 test_name
1180 );
1181 Ok(())
1182 }
1183 fn check_ao_period_exceeds_length(
1184 test_name: &str,
1185 kernel: Kernel,
1186 ) -> Result<(), Box<dyn Error>> {
1187 skip_if_unsupported!(kernel, test_name);
1188 let data_small = [10.0, 20.0, 30.0];
1189 let params = AoParams {
1190 short_period: Some(5),
1191 long_period: Some(10),
1192 };
1193 let input = AoInput::from_slice(&data_small, params);
1194 let res = ao_with_kernel(&input, kernel);
1195 assert!(
1196 res.is_err(),
1197 "[{}] AO should fail with period exceeding length",
1198 test_name
1199 );
1200 Ok(())
1201 }
1202 fn check_ao_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1203 skip_if_unsupported!(kernel, test_name);
1204 let single_point = [42.0];
1205 let params = AoParams {
1206 short_period: Some(5),
1207 long_period: Some(34),
1208 };
1209 let input = AoInput::from_slice(&single_point, params);
1210 let res = ao_with_kernel(&input, kernel);
1211 assert!(
1212 res.is_err(),
1213 "[{}] AO should fail with insufficient data",
1214 test_name
1215 );
1216 Ok(())
1217 }
1218 fn check_ao_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1219 skip_if_unsupported!(kernel, test_name);
1220 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1221 let candles = read_candles_from_csv(file_path)?;
1222 let first_params = AoParams {
1223 short_period: Some(5),
1224 long_period: Some(34),
1225 };
1226 let first_input = AoInput::from_candles(&candles, "hl2", first_params);
1227 let first_result = ao_with_kernel(&first_input, kernel)?;
1228 let second_params = AoParams {
1229 short_period: Some(3),
1230 long_period: Some(10),
1231 };
1232 let second_input = AoInput::from_slice(&first_result.values, second_params);
1233 let second_result = ao_with_kernel(&second_input, kernel)?;
1234 assert_eq!(second_result.values.len(), first_result.values.len());
1235 Ok(())
1236 }
1237 fn check_ao_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1238 skip_if_unsupported!(kernel, test_name);
1239 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1240 let candles = read_candles_from_csv(file_path)?;
1241 let input = AoInput::from_candles(
1242 &candles,
1243 "hl2",
1244 AoParams {
1245 short_period: Some(5),
1246 long_period: Some(34),
1247 },
1248 );
1249 let res = ao_with_kernel(&input, kernel)?;
1250 assert_eq!(res.values.len(), candles.close.len());
1251 if res.values.len() > 240 {
1252 for (i, &val) in res.values[240..].iter().enumerate() {
1253 assert!(
1254 !val.is_nan(),
1255 "[{}] Found unexpected NaN at out-index {}",
1256 test_name,
1257 240 + i
1258 );
1259 }
1260 }
1261 Ok(())
1262 }
1263
1264 #[cfg(debug_assertions)]
1265 fn check_ao_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1266 skip_if_unsupported!(kernel, test_name);
1267
1268 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1269 let candles = read_candles_from_csv(file_path)?;
1270
1271 let test_params = vec![
1272 AoParams::default(),
1273 AoParams {
1274 short_period: Some(2),
1275 long_period: Some(10),
1276 },
1277 AoParams {
1278 short_period: Some(3),
1279 long_period: Some(20),
1280 },
1281 AoParams {
1282 short_period: Some(10),
1283 long_period: Some(50),
1284 },
1285 AoParams {
1286 short_period: Some(20),
1287 long_period: Some(100),
1288 },
1289 AoParams {
1290 short_period: Some(5),
1291 long_period: Some(200),
1292 },
1293 ];
1294
1295 for (param_idx, params) in test_params.iter().enumerate() {
1296 let input = AoInput::from_candles(&candles, "hl2", params.clone());
1297 let result = ao_with_kernel(&input, kernel)?;
1298
1299 for (i, &val) in result.values.iter().enumerate() {
1300 if val.is_nan() {
1301 continue;
1302 }
1303
1304 let bits = val.to_bits();
1305
1306 if bits == 0x11111111_11111111 {
1307 panic!(
1308 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1309 with params: short={}, long={} (param set {})",
1310 test_name,
1311 val,
1312 bits,
1313 i,
1314 params.short_period.unwrap_or(5),
1315 params.long_period.unwrap_or(34),
1316 param_idx
1317 );
1318 }
1319
1320 if bits == 0x22222222_22222222 {
1321 panic!(
1322 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1323 with params: short={}, long={} (param set {})",
1324 test_name,
1325 val,
1326 bits,
1327 i,
1328 params.short_period.unwrap_or(5),
1329 params.long_period.unwrap_or(34),
1330 param_idx
1331 );
1332 }
1333
1334 if bits == 0x33333333_33333333 {
1335 panic!(
1336 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1337 with params: short={}, long={} (param set {})",
1338 test_name,
1339 val,
1340 bits,
1341 i,
1342 params.short_period.unwrap_or(5),
1343 params.long_period.unwrap_or(34),
1344 param_idx
1345 );
1346 }
1347 }
1348 }
1349
1350 Ok(())
1351 }
1352
    /// Release-build stand-in for the poison check: does nothing and reports success, so
    /// the test list can reference `check_ao_no_poison` in every build profile.
    #[cfg(not(debug_assertions))]
    fn check_ao_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1357
1358 #[cfg(feature = "proptest")]
1359 #[allow(clippy::float_cmp)]
1360 fn check_ao_property(
1361 test_name: &str,
1362 kernel: Kernel,
1363 ) -> Result<(), Box<dyn std::error::Error>> {
1364 use proptest::prelude::*;
1365 skip_if_unsupported!(kernel, test_name);
1366
1367 let strat = (1usize..=50).prop_flat_map(|short_period| {
1368 ((short_period + 1)..=100).prop_flat_map(move |long_period| {
1369 (
1370 prop::collection::vec(
1371 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1372 long_period..400,
1373 ),
1374 Just(short_period),
1375 Just(long_period),
1376 )
1377 })
1378 });
1379
1380 proptest::test_runner::TestRunner::default()
1381 .run(&strat, |(data, short_period, long_period)| {
1382 let params = AoParams {
1383 short_period: Some(short_period),
1384 long_period: Some(long_period),
1385 };
1386 let input = AoInput::from_slice(&data, params);
1387
1388 let AoOutput { values: out } = ao_with_kernel(&input, kernel).unwrap();
1389
1390 let AoOutput { values: ref_out } = ao_with_kernel(&input, Kernel::Scalar).unwrap();
1391
1392 for i in 0..(long_period - 1) {
1393 prop_assert!(
1394 out[i].is_nan(),
1395 "Expected NaN during warmup at index {}, got {}",
1396 i,
1397 out[i]
1398 );
1399 }
1400
1401 for i in (long_period - 1)..data.len() {
1402 prop_assert!(
1403 out[i].is_finite(),
1404 "Expected finite value after warmup at index {}, got {}",
1405 i,
1406 out[i]
1407 );
1408 }
1409
1410 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() >= long_period
1411 {
1412 for i in (long_period - 1)..data.len() {
1413 prop_assert!(
1414 out[i].abs() < 1e-9,
1415 "For constant data, AO should be 0 at index {}, got {}",
1416 i,
1417 out[i]
1418 );
1419 }
1420 }
1421
1422 if data.len() >= long_period {
1423 for i in (long_period - 1)..data.len() {
1424 let short_start = i + 1 - short_period;
1425 let long_start = i + 1 - long_period;
1426
1427 let long_window = &data[long_start..=i];
1428 let window_min = long_window.iter().cloned().fold(f64::INFINITY, f64::min);
1429 let window_max = long_window
1430 .iter()
1431 .cloned()
1432 .fold(f64::NEG_INFINITY, f64::max);
1433
1434 let theoretical_max = window_max - window_min;
1435
1436 prop_assert!(
1437 out[i].abs() <= theoretical_max + 1e-9,
1438 "AO value {} at index {} exceeds theoretical max {}",
1439 out[i],
1440 i,
1441 theoretical_max
1442 );
1443 }
1444 }
1445
1446 if short_period == 1 && long_period == 2 && data.len() >= 2 {
1447 for i in 1..data.len() {
1448 let expected = data[i] - (data[i] + data[i - 1]) / 2.0;
1449 let actual = out[i];
1450 prop_assert!(
1451 (actual - expected).abs() < 1e-9,
1452 "Special case (short=1, long=2) mismatch at index {}: expected {}, got {}",
1453 i,
1454 expected,
1455 actual
1456 );
1457 }
1458 }
1459
1460 let is_increasing = data.windows(2).all(|w| w[1] > w[0] + 1e-10);
1461 let is_decreasing = data.windows(2).all(|w| w[1] < w[0] - 1e-10);
1462
1463 if is_increasing && data.len() >= long_period {
1464 for i in (long_period - 1)..data.len() {
1465 prop_assert!(
1466 out[i] > -1e-9,
1467 "For strictly increasing data, AO should be positive at index {}, got {}",
1468 i,
1469 out[i]
1470 );
1471 }
1472 }
1473
1474 if is_decreasing && data.len() >= long_period {
1475 for i in (long_period - 1)..data.len() {
1476 prop_assert!(
1477 out[i] < 1e-9,
1478 "For strictly decreasing data, AO should be negative at index {}, got {}",
1479 i,
1480 out[i]
1481 );
1482 }
1483 }
1484
1485 if data.len() >= 3 {
1486 let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
1487 let is_linear = diffs.windows(2).all(|w| (w[1] - w[0]).abs() < 1e-10);
1488
1489 if is_linear && data.len() >= long_period + 10 {
1490 let stable_start = long_period + 5;
1491 if stable_start < data.len() - 1 {
1492 let stable_values = &out[stable_start..];
1493 if stable_values.len() >= 2 {
1494 let first_stable = stable_values[0];
1495 for (idx, &val) in stable_values.iter().enumerate() {
1496 prop_assert!(
1497 (val - first_stable).abs() < 1e-8,
1498 "For linear data, AO should stabilize. Value at {} differs from stable value: {} vs {}",
1499 stable_start + idx,
1500 val,
1501 first_stable
1502 );
1503 }
1504 }
1505 }
1506 }
1507 }
1508
1509 for i in 0..data.len() {
1510 let y = out[i];
1511 let r = ref_out[i];
1512
1513 if y.is_nan() || r.is_nan() {
1514 prop_assert!(
1515 y.is_nan() && r.is_nan(),
1516 "NaN mismatch at index {}: kernel={:?} is_nan={}, scalar is_nan={}",
1517 i,
1518 kernel,
1519 y.is_nan(),
1520 r.is_nan()
1521 );
1522 continue;
1523 }
1524
1525 let y_bits = y.to_bits();
1526 let r_bits = r.to_bits();
1527 let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1528
1529 prop_assert!(
1530 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1531 "Kernel mismatch at index {}: kernel={:?} value={}, scalar value={}, \
1532 diff={}, ULP diff={}",
1533 i,
1534 kernel,
1535 y,
1536 r,
1537 (y - r).abs(),
1538 ulp_diff
1539 );
1540 }
1541
1542 Ok(())
1543 })
1544 .unwrap();
1545
1546 Ok(())
1547 }
1548
1549 macro_rules! generate_all_ao_tests {
1550 ($($test_fn:ident),*) => {
1551 paste! {
1552 $(
1553 #[test]
1554 fn [<$test_fn _scalar_f64>]() {
1555 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1556 }
1557 )*
1558 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1559 $(
1560 #[test]
1561 fn [<$test_fn _avx2_f64>]() {
1562 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1563 }
1564 #[test]
1565 fn [<$test_fn _avx512_f64>]() {
1566 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1567 }
1568 )*
1569 }
1570 }
1571 }
    // Expand every single-series AO check into its per-kernel `#[test]` functions.
    generate_all_ao_tests!(
        check_ao_partial_params,
        check_ao_accuracy,
        check_ao_default_candles,
        check_ao_zero_period,
        check_ao_period_exceeds_length,
        check_ao_very_small_dataset,
        check_ao_reinput,
        check_ao_nan_handling,
        check_ao_no_poison
    );
1583
    // The property-based check exists only when the `proptest` feature is enabled, so
    // its test wrappers are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_ao_tests!(check_ao_property);
1586
1587 #[test]
1588 fn test_output_len_mismatch_error() {
1589 let data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
1590 let params = AoParams {
1591 short_period: Some(2),
1592 long_period: Some(3),
1593 };
1594 let input = AoInput::from_slice(&data, params);
1595
1596 let mut wrong_sized_buf = vec![0.0; 10];
1597
1598 let result = ao_into_slice(&mut wrong_sized_buf, &input, Kernel::Auto);
1599 assert!(result.is_err());
1600
1601 if let Err(AoError::OutputLengthMismatch { expected, got }) = result {
1602 assert_eq!(expected, 5);
1603 assert_eq!(got, 10);
1604 } else {
1605 panic!("Expected OutputLenMismatch error");
1606 }
1607 }
1608
1609 #[test]
1610 fn test_invalid_kernel_error() {
1611 let data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
1612 let sweep = AoBatchRange::default();
1613
1614 let result = ao_batch_with_kernel(&data, &sweep, Kernel::Scalar);
1615 assert!(result.is_err());
1616
1617 if let Err(AoError::InvalidKernelForBatch(kernel)) = result {
1618 assert!(matches!(kernel, Kernel::Scalar));
1619 } else {
1620 panic!("Expected InvalidKernel error");
1621 }
1622 }
1623
1624 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1625 skip_if_unsupported!(kernel, test);
1626 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1627 let c = read_candles_from_csv(file)?;
1628 let output = AoBatchBuilder::new()
1629 .kernel(kernel)
1630 .apply_candles(&c, "hl2")?;
1631 let def = AoParams::default();
1632 let row = output.values_for(&def).expect("default row missing");
1633 assert_eq!(row.len(), c.close.len());
1634 let expected = [-1671.3, -1401.6706, -1262.3559, -1178.4941, -1157.4118];
1635 let start = row.len() - 5;
1636 for (i, &v) in row[start..].iter().enumerate() {
1637 assert!(
1638 (v - expected[i]).abs() < 1e-1,
1639 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1640 );
1641 }
1642 Ok(())
1643 }
1644
1645 #[cfg(debug_assertions)]
1646 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1647 skip_if_unsupported!(kernel, test);
1648
1649 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1650 let c = read_candles_from_csv(file)?;
1651
1652 let test_configs = vec![
1653 (2, 10, 2, 15, 40, 5),
1654 (5, 20, 5, 30, 60, 10),
1655 (10, 30, 10, 40, 100, 20),
1656 (3, 3, 0, 10, 50, 10),
1657 (5, 15, 5, 34, 34, 0),
1658 ];
1659
1660 for (cfg_idx, &(short_start, short_end, short_step, long_start, long_end, long_step)) in
1661 test_configs.iter().enumerate()
1662 {
1663 let sweep = AoBatchRange {
1664 short_period: (short_start, short_end, short_step),
1665 long_period: (long_start, long_end, long_step),
1666 };
1667
1668 let output = ao_batch_with_kernel(source_type(&c, "hl2"), &sweep, kernel)?;
1669
1670 for (row, combo) in output.combos.iter().enumerate() {
1671 let row_start = row * output.cols;
1672 let row_end = row_start + output.cols;
1673 let row_values = &output.values[row_start..row_end];
1674
1675 for (col, &val) in row_values.iter().enumerate() {
1676 if val.is_nan() {
1677 continue;
1678 }
1679
1680 let bits = val.to_bits();
1681 let idx = row * output.cols + col;
1682
1683 if bits == 0x11111111_11111111 {
1684 panic!(
1685 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1686 at row {} col {} (flat index {}) with params: short={}, long={}",
1687 test,
1688 cfg_idx,
1689 val,
1690 bits,
1691 row,
1692 col,
1693 idx,
1694 combo.short_period.unwrap_or(5),
1695 combo.long_period.unwrap_or(34)
1696 );
1697 }
1698
1699 if bits == 0x22222222_22222222 {
1700 panic!(
1701 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1702 at row {} col {} (flat index {}) with params: short={}, long={}",
1703 test,
1704 cfg_idx,
1705 val,
1706 bits,
1707 row,
1708 col,
1709 idx,
1710 combo.short_period.unwrap_or(5),
1711 combo.long_period.unwrap_or(34)
1712 );
1713 }
1714
1715 if bits == 0x33333333_33333333 {
1716 panic!(
1717 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1718 at row {} col {} (flat index {}) with params: short={}, long={}",
1719 test,
1720 cfg_idx,
1721 val,
1722 bits,
1723 row,
1724 col,
1725 idx,
1726 combo.short_period.unwrap_or(5),
1727 combo.long_period.unwrap_or(34)
1728 );
1729 }
1730 }
1731 }
1732 }
1733
1734 Ok(())
1735 }
1736
    /// Release-build stand-in for the batch poison check: does nothing and reports
    /// success, so `gen_batch_tests!(check_batch_no_poison)` compiles in every profile.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1741 macro_rules! gen_batch_tests {
1742 ($fn_name:ident) => {
1743 paste! {
1744 #[test] fn [<$fn_name _scalar>]() {
1745 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1746 }
1747 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1748 #[test] fn [<$fn_name _avx2>]() {
1749 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1750 }
1751 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1752 #[test] fn [<$fn_name _avx512>]() {
1753 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1754 }
1755 #[test] fn [<$fn_name _auto_detect>]() {
1756 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1757 }
1758 }
1759 };
1760 }
    // Expand both batch checks into their per-kernel `#[test]` functions.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1763}
1764
/// Python binding `ao(high, low, short_period, long_period, kernel=None)`.
///
/// Builds the hl2 series from `high` and `low` via `compute_hl2`, runs
/// `ao_with_kernel` on it with the GIL released, and returns the values as a 1-D
/// NumPy array. Any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ao")]
#[pyo3(signature = (high, low, short_period, long_period, kernel=None))]
pub fn ao_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    short_period: usize,
    long_period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    // Borrow both inputs as slices; `as_slice` fails for non-contiguous arrays.
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    let kern = validate_kernel(kernel, false)?;

    // The numeric work runs without the GIL.
    let outcome = py.allow_threads(|| -> Result<Vec<f64>, AoError> {
        let midpoints = compute_hl2(high_slice, low_slice)?;
        let input = AoInput::from_slice(
            &midpoints,
            AoParams {
                short_period: Some(short_period),
                long_period: Some(long_period),
            },
        );
        let AoOutput { values } = ao_with_kernel(&input, kern)?;
        Ok(values)
    });

    match outcome {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1799
/// Python-facing wrapper (`AoStream`) around the Rust `AoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "AoStream")]
pub struct AoStreamPy {
    stream: AoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AoStreamPy {
    /// Creates a stream for the given periods; a construction error from
    /// `AoStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(short_period: usize, long_period: usize) -> PyResult<Self> {
        let params = AoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        };
        match AoStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds the midpoint `(high + low) / 2` of one bar into the stream and returns
    /// whatever the underlying stream's `update` returns.
    fn update(&mut self, high: f64, low: f64) -> Option<f64> {
        self.stream.update((high + low) / 2.0)
    }
}
1824
/// Python binding `ao_batch(high, low, short_period_range, long_period_range, kernel=None)`.
///
/// Expands the two `(start, end, step)` ranges into parameter combinations, computes
/// one AO row per combination directly into a freshly allocated NumPy buffer, and
/// returns a dict with:
///
/// * `"values"` – array reshaped to `(rows, cols)`, where `rows` is the number of
///   combinations and `cols` is `len(high)`;
/// * `"short_periods"` / `"long_periods"` – the periods of each row as `u64` arrays.
///
/// All indicator errors are raised as `ValueError`. A kernel that is neither `Auto`
/// nor one of the batch kernels is reported as `AoError::InvalidKernelForBatch`
/// rather than panicking.
#[cfg(feature = "python")]
#[pyfunction(name = "ao_batch")]
#[pyo3(signature = (high, low, short_period_range, long_period_range, kernel=None))]
pub fn ao_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;

    let kern = validate_kernel(kernel, true)?;

    let sweep = AoBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
    };

    // The grid is expanded up front only to size the output and the warmup prefixes;
    // the combos that are returned to Python come from `ao_batch_inner_into` below.
    let combos = expand_grid_checked(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = high_slice.len();

    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array is created here and nothing else references it yet; its
    // uninitialized contents are written below before the array is handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [expected], false) };
    // SAFETY: `out_arr` was just allocated as a 1-D array and has no other borrows.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| -> Result<Vec<AoParams>, AoError> {
            // Resolve the kernel before touching the buffer so that an unsupported
            // kernel yields an error instead of a panic while the GIL is released.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(AoError::InvalidKernelForBatch(other)),
            };

            let hl2 = compute_hl2(high_slice, low_slice)?;

            // Per-row warmup length: index of the first non-NaN input plus `long - 1`.
            let first = hl2.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warm: Vec<usize> = combos
                .iter()
                .map(|c| first + c.long_period.unwrap() - 1)
                .collect();

            // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`; this is a
            // reinterpreting view of `slice_out` with the same pointer and length.
            let slice_mu = unsafe {
                std::slice::from_raw_parts_mut(
                    slice_out.as_mut_ptr() as *mut MaybeUninit<f64>,
                    slice_out.len(),
                )
            };

            init_matrix_prefixes(slice_mu, cols, &warm);

            ao_batch_inner_into(&hl2, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "short_periods",
        combos
            .iter()
            .map(|p| p.short_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "long_periods",
        combos
            .iter()
            .map(|p| p.long_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1914
/// Python handle to a 2-D `f32` CUDA device buffer holding AO results.
///
/// `buf` becomes `None` once the buffer has been handed out through `__dlpack__`;
/// `_ctx` holds a reference to the CUDA context alongside the buffer.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32AoPy {
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    pub(crate) _ctx: Arc<Context>,
    pub(crate) device_id: u32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32AoPy {
    /// Builds the `__cuda_array_interface__` dict (version 3) for a
    /// `(rows, cols)` little-endian `f32` array with row-major byte strides.
    /// Fails with `ValueError` once the buffer has been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_size = std::mem::size_of::<f32>();

        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (self.cols * elem_size, elem_size))?;

        let device_buf = match self.buf.as_ref() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };
        let ptr = device_buf.as_device_ptr().as_raw() as usize;
        // (pointer, read_only = false)
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(2, device_id)` as the DLPack device tuple.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule. May be called only once, because the
    /// buffer is moved out of `self`.
    ///
    /// Rejects a `dl_device` that differs from this array's device and rejects
    /// `stream == 0`. Arguments that cannot be extracted to the expected Python
    /// type are ignored.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A requested device only matters when it parses as an (i32, i32) tuple.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(requested) = requested {
            if requested != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let stream_id = stream.as_ref().and_then(|obj| obj.extract::<i64>(py).ok());
        if stream_id == Some(0) {
            return Err(PyValueError::new_err(
                "__dlpack__: stream 0 is disallowed for CUDA",
            ));
        }

        // Moving the buffer out marks this object as exported.
        let buf = match self.buf.take() {
            Some(b) => b,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(
            py,
            buf,
            self.rows,
            self.cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
2007
/// Python binding `ao_cuda_batch_dev`: runs the AO parameter sweep on a CUDA device
/// and returns the result as a device-resident array handle.
///
/// `high` and `low` must have equal length; their element-wise midpoint is the
/// series passed to `CudaAo::ao_batch_dev`. Errors are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ao_cuda_batch_dev")]
#[pyo3(signature = (high, low, short_period_range, long_period_range, device_id=0))]
pub fn ao_cuda_batch_dev_py(
    py: Python<'_>,
    high: numpy::PyReadonlyArray1<'_, f32>,
    low: numpy::PyReadonlyArray1<'_, f32>,
    short_period_range: (usize, usize, usize),
    long_period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32AoPy> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    if high_slice.len() != low_slice.len() {
        return Err(PyValueError::new_err("high/low length mismatch"));
    }

    // Midpoint of each bar, computed on the host in f32.
    let hl2_f32: Vec<f32> = high_slice
        .iter()
        .zip(low_slice.iter())
        .map(|(&h, &l)| (h + l) * 0.5)
        .collect();

    let sweep = AoBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
    };

    // Device setup and the kernel launch run with the GIL released.
    let inner = py.allow_threads(|| {
        let cuda = CudaAo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.ao_batch_dev(&hl2_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32AoPy {
        buf: Some(inner.buf),
        rows: inner.rows,
        cols: inner.cols,
        _ctx: inner.ctx,
        device_id: inner.device_id,
    })
}
2058
/// Python binding `ao_cuda_many_series_one_param_dev`: runs AO with a single
/// parameter pair over many series on a CUDA device.
///
/// `high_tm` and `low_tm` are flat arrays that must each hold exactly
/// `cols * rows` values; their element-wise midpoint is passed to
/// `CudaAo::ao_many_series_one_param_time_major_dev`. Errors are raised as
/// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ao_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm, low_tm, cols, rows, short_period, long_period, device_id=0))]
pub fn ao_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm: numpy::PyReadonlyArray1<'_, f32>,
    low_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    short_period: usize,
    long_period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32AoPy> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let high_slice = high_tm.as_slice()?;
    let low_slice = low_tm.as_slice()?;

    let expected = match cols.checked_mul(rows) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if high_slice.len() != expected || low_slice.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // Both slices have exactly `expected` elements here, so `zip` covers them all.
    let hl2_f32: Vec<f32> = high_slice
        .iter()
        .zip(low_slice.iter())
        .map(|(&h, &l)| (h + l) * 0.5)
        .collect();

    // Device setup and the kernel launch run with the GIL released.
    let inner = py.allow_threads(|| {
        let cuda = CudaAo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.ao_many_series_one_param_time_major_dev(
            &hl2_f32,
            cols,
            rows,
            short_period,
            long_period,
        )
        .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32AoPy {
        buf: Some(inner.buf),
        rows: inner.rows,
        cols: inner.cols,
        _ctx: inner.ctx,
        device_id: inner.device_id,
    })
}
2114
/// WASM binding: computes AO from `high`/`low` with the given periods using
/// `Kernel::Auto` and returns the values. Errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_js(
    high: &[f64],
    low: &[f64],
    short_period: usize,
    long_period: usize,
) -> Result<Vec<f64>, JsValue> {
    let midpoints = match compute_hl2(high, low) {
        Ok(series) => series,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let input = AoInput::from_slice(
        &midpoints,
        AoParams {
            short_period: Some(short_period),
            long_period: Some(long_period),
        },
    );

    match ao_with_kernel(&input, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2135
/// WASM binding: runs the AO sweep described by the two `(start, end, step)` triples
/// with the scalar kernel and returns the flat value buffer of the batch output.
/// Errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_batch_js(
    high: &[f64],
    low: &[f64],
    short_start: usize,
    short_end: usize,
    short_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let midpoints = match compute_hl2(high, low) {
        Ok(series) => series,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let sweep = AoBatchRange {
        short_period: (short_start, short_end, short_step),
        long_period: (long_start, long_end, long_step),
    };

    match ao_batch_inner(&midpoints, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2159
/// WASM binding: expands the sweep into parameter combinations and returns them
/// flattened as `[short_0, long_0, short_1, long_1, ...]` (as `f64`).
/// A grid-expansion error is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_batch_metadata_js(
    short_start: usize,
    short_end: usize,
    short_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AoBatchRange {
        short_period: (short_start, short_end, short_step),
        long_period: (long_start, long_end, long_step),
    };

    let combos = match expand_grid_checked(&sweep) {
        Ok(list) => list,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Two entries per combination: short period, then long period.
    let mut metadata = Vec::with_capacity(combos.len() * 2);
    for combo in &combos {
        let short = combo.short_period.unwrap() as f64;
        let long = combo.long_period.unwrap() as f64;
        metadata.extend_from_slice(&[short, long]);
    }

    Ok(metadata)
}
2185
/// Configuration object accepted by the unified WASM `ao_batch` entry point.
///
/// Each range is copied unchanged into the matching field of `AoBatchRange`;
/// elsewhere in this file those triples are built as `(start, end, step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AoBatchConfig {
    /// Range for the short period.
    pub short_period_range: (usize, usize, usize),
    /// Range for the long period.
    pub long_period_range: (usize, usize, usize),
}
2192
/// Serializable result of the unified WASM `ao_batch` entry point; every field is
/// copied from the batch output of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AoBatchJsOutput {
    /// Flat value buffer of the batch output.
    // NOTE(review): presumably row-major with `rows * cols` entries — confirm against
    // the batch output type.
    pub values: Vec<f64>,
    /// Parameter combinations of the batch output.
    pub combos: Vec<AoParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
2201
/// WASM binding exported as `ao_batch`: deserializes an `AoBatchConfig` from
/// `config`, runs the sweep with the scalar kernel, and returns an
/// `AoBatchJsOutput` serialized back to a `JsValue`.
///
/// Config parsing, indicator, and serialization failures are all returned as
/// string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ao_batch)]
pub fn ao_batch_unified_js(high: &[f64], low: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let AoBatchConfig {
        short_period_range,
        long_period_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let midpoints = compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let sweep = AoBatchRange {
        short_period: short_period_range,
        long_period: long_period_range,
    };

    let batch = ao_batch_inner(&midpoints, &sweep, Kernel::Scalar, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = AoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2228
/// WASM binding: allocates an uninitialized buffer with capacity for `len` `f64`
/// values and returns its pointer. Ownership passes to the caller; the buffer is
/// not freed by Rust until it is handed back (see `ao_free`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2237
/// WASM binding: releases a buffer by rebuilding a `Vec<f64>` with length and
/// capacity `len` from `ptr` and dropping it. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: assumes `ptr` was returned by `ao_alloc(len)` with this same `len` and
    // has not been freed already — the caller must uphold this.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2247
/// WASM binding: computes AO from raw `high`/`low` buffers and writes `len` values
/// to `out_ptr`.
///
/// `out_ptr` may overlap either input buffer. In that case the result is computed
/// into a temporary and copied over at the end, so the output region is only
/// written after a successful computation.
///
/// # Errors
/// Returns a string `JsValue` if any pointer is null or if the indicator fails.
///
/// # Safety contract (for the JS caller)
/// `in_high_ptr` and `in_low_ptr` must each be valid for reading `len` `f64`s and
/// `out_ptr` must be valid for writing `len` `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period: usize,
    long_period: usize,
) -> Result<(), JsValue> {
    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Pure address arithmetic: do the byte ranges [a, a + n*8) and [b, b + n*8)
    // intersect? Wrapping arithmetic keeps an absurd `n` from panicking on overflow.
    #[inline(always)]
    fn overlaps(a: *const f64, b: *const f64, n: usize) -> bool {
        let bytes = n.wrapping_mul(core::mem::size_of::<f64>());
        let a_start = a as usize;
        let a_end = a_start.wrapping_add(bytes);
        let b_start = b as usize;
        let b_end = b_start.wrapping_add(bytes);
        !(a_end <= b_start || b_end <= a_start)
    }

    let alias = overlaps(in_high_ptr, out_ptr as *const f64, len)
        || overlaps(in_low_ptr, out_ptr as *const f64, len);

    // The shared input slices live only inside this block. They must be gone before a
    // `&mut` slice over `out_ptr` is created, because the output may alias an input.
    // NOTE(review): this relies on `compute_hl2` returning an owned buffer that does
    // not borrow from its arguments — confirm against its signature.
    let hl2 = {
        // SAFETY: both pointers are non-null (checked above) and the caller guarantees
        // they are valid for `len` reads.
        let (high, low) = unsafe {
            (
                std::slice::from_raw_parts(in_high_ptr, len),
                std::slice::from_raw_parts(in_low_ptr, len),
            )
        };
        compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    let params = AoParams {
        short_period: Some(short_period),
        long_period: Some(long_period),
    };
    let input = AoInput::from_slice(&hl2, params);

    if alias {
        // Compute into scratch space first so a failure leaves the aliased region intact.
        let mut tmp = vec![0.0; len];
        ao_into_slice(&mut tmp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null and valid for `len` writes per the caller's
        // contract; no other reference to that memory exists at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&tmp);
    } else {
        // SAFETY: as above; additionally the region does not overlap either input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ao_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2298
/// WASM binding: runs the AO sweep over raw `high`/`low` buffers with the scalar
/// kernel, writes `rows * len` values to `out_ptr`, and returns `rows` (the number
/// of parameter combinations).
///
/// # Errors
/// Returns a string `JsValue` if any pointer is null, if the grid cannot be
/// expanded, if `rows * len` overflows, or if the indicator fails.
///
/// # Safety contract (for the JS caller)
/// `in_high_ptr` and `in_low_ptr` must each be valid for reading `len` `f64`s and
/// `out_ptr` must be valid for writing `rows * len` `f64`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ao_batch_into(
    in_high_ptr: *const f64,
    in_low_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    short_period_start: usize,
    short_period_end: usize,
    short_period_step: usize,
    long_period_start: usize,
    long_period_end: usize,
    long_period_step: usize,
) -> Result<usize, JsValue> {
    if in_high_ptr.is_null() || in_low_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = AoBatchRange {
        short_period: (short_period_start, short_period_end, short_period_step),
        long_period: (long_period_start, long_period_end, long_period_step),
    };

    // Size the output from the number of combinations before touching any memory.
    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

    // The shared input slices live only inside this block, so no `&[f64]` over the
    // inputs exists once the `&mut` output slice is created below (the output region
    // could overlap an input).
    // NOTE(review): this relies on `compute_hl2` returning an owned buffer that does
    // not borrow from its arguments — confirm against its signature.
    let hl2 = {
        // SAFETY: both pointers are non-null (checked above) and the caller guarantees
        // they are valid for `len` reads.
        let (high, low) = unsafe {
            (
                std::slice::from_raw_parts(in_high_ptr, len),
                std::slice::from_raw_parts(in_low_ptr, len),
            )
        };
        compute_hl2(high, low).map_err(|e| JsValue::from_str(&e.to_string()))?
    };

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for `total`
    // writes; the input slices above are no longer alive.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    ao_batch_inner_into(&hl2, &sweep, Kernel::Scalar, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}