1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24
25#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
26use core::arch::x86_64::*;
27
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30
31use std::convert::AsRef;
32use std::error::Error;
33use std::mem::MaybeUninit;
34use thiserror::Error;
35
36#[derive(Debug, Clone)]
37pub enum AsoData<'a> {
38 Candles {
39 candles: &'a Candles,
40 source: &'a str,
41 },
42 Slices {
43 open: &'a [f64],
44 high: &'a [f64],
45 low: &'a [f64],
46 close: &'a [f64],
47 },
48}
49
/// Result of a single ASO run: one bulls and one bears value per input bar.
///
/// The producing functions fill the leading warm-up bars with NaN.
#[derive(Clone, Debug)]
pub struct AsoOutput {
    /// Bullish sentiment series, same length as the input.
    pub bulls: Vec<f64>,
    /// Bearish sentiment series, same length as the input.
    pub bears: Vec<f64>,
}
55
/// Tunable ASO parameters.
///
/// `None` fields fall back to the defaults used by [`AsoInput`]
/// (period 10, mode 0).
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AsoParams {
    /// Window length in bars.
    pub period: Option<usize>,
    /// 0 = average of intrabar and group components, 1 = intrabar only,
    /// 2 = group only.
    pub mode: Option<usize>,
}
65
66impl Default for AsoParams {
67 fn default() -> Self {
68 Self {
69 period: Some(10),
70 mode: Some(0),
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
76pub struct AsoInput<'a> {
77 pub data: AsoData<'a>,
78 pub params: AsoParams,
79}
80
81impl<'a> AsRef<[f64]> for AsoInput<'a> {
82 fn as_ref(&self) -> &[f64] {
83 match &self.data {
84 AsoData::Candles { candles, source } => source_type(candles, source),
85 AsoData::Slices { close, .. } => close,
86 }
87 }
88}
89
90impl<'a> AsoInput<'a> {
91 #[inline]
92 pub fn from_candles(c: &'a Candles, s: &'a str, p: AsoParams) -> Self {
93 Self {
94 data: AsoData::Candles {
95 candles: c,
96 source: s,
97 },
98 params: p,
99 }
100 }
101
102 #[inline]
103 pub fn from_slices(
104 open: &'a [f64],
105 high: &'a [f64],
106 low: &'a [f64],
107 close: &'a [f64],
108 p: AsoParams,
109 ) -> Self {
110 Self {
111 data: AsoData::Slices {
112 open,
113 high,
114 low,
115 close,
116 },
117 params: p,
118 }
119 }
120
121 #[inline]
122 pub fn with_default_candles(c: &'a Candles) -> Self {
123 Self::from_candles(c, "close", AsoParams::default())
124 }
125
126 #[inline]
127 pub fn get_period(&self) -> usize {
128 self.params.period.unwrap_or(10)
129 }
130
131 #[inline]
132 pub fn get_mode(&self) -> usize {
133 self.params.mode.unwrap_or(0)
134 }
135}
136
137#[derive(Copy, Clone, Debug)]
138pub struct AsoBuilder {
139 period: Option<usize>,
140 mode: Option<usize>,
141 kernel: Kernel,
142}
143
144impl Default for AsoBuilder {
145 fn default() -> Self {
146 Self {
147 period: None,
148 mode: None,
149 kernel: Kernel::Auto,
150 }
151 }
152}
153
154impl AsoBuilder {
155 #[inline(always)]
156 pub fn new() -> Self {
157 Self::default()
158 }
159
160 #[inline(always)]
161 pub fn period(mut self, val: usize) -> Self {
162 self.period = Some(val);
163 self
164 }
165
166 #[inline(always)]
167 pub fn mode(mut self, val: usize) -> Self {
168 self.mode = Some(val);
169 self
170 }
171
172 #[inline(always)]
173 pub fn kernel(mut self, k: Kernel) -> Self {
174 self.kernel = k;
175 self
176 }
177
178 #[inline(always)]
179 pub fn apply(self, c: &Candles) -> Result<AsoOutput, AsoError> {
180 self.apply_candles(c, "close")
181 }
182
183 #[inline(always)]
184 pub fn apply_candles(self, c: &Candles, s: &str) -> Result<AsoOutput, AsoError> {
185 let p = AsoParams {
186 period: self.period,
187 mode: self.mode,
188 };
189 let i = AsoInput::from_candles(c, s, p);
190 aso_with_kernel(&i, self.kernel)
191 }
192
193 #[inline(always)]
194 pub fn apply_slices(
195 self,
196 open: &[f64],
197 high: &[f64],
198 low: &[f64],
199 close: &[f64],
200 ) -> Result<AsoOutput, AsoError> {
201 let p = AsoParams {
202 period: self.period,
203 mode: self.mode,
204 };
205 let i = AsoInput::from_slices(open, high, low, close, p);
206 aso_with_kernel(&i, self.kernel)
207 }
208
209 #[inline(always)]
210 pub fn into_stream(self) -> Result<AsoStream, AsoError> {
211 let p = AsoParams {
212 period: self.period,
213 mode: self.mode,
214 };
215 AsoStream::try_new(p)
216 }
217}
218
219#[derive(Debug, Error)]
220pub enum AsoError {
221 #[error("aso: Input data slice is empty.")]
222 EmptyInputData,
223
224 #[error("aso: All values are NaN.")]
225 AllValuesNaN,
226
227 #[error("aso: Invalid period: period = {period}, data length = {data_len}")]
228 InvalidPeriod { period: usize, data_len: usize },
229
230 #[error("aso: Not enough valid data: needed = {needed}, valid = {valid}")]
231 NotEnoughValidData { needed: usize, valid: usize },
232
233 #[error("aso: Invalid mode: mode = {mode}, must be 0, 1, or 2")]
234 InvalidMode { mode: usize },
235
236 #[error("aso: Required OHLC data is missing or has mismatched lengths")]
237 MissingData,
238
239 #[error("aso: Output length mismatch: expected = {expected}, got = {got}")]
240 OutputLengthMismatch { expected: usize, got: usize },
241
242 #[error("aso: Invalid range: start={start} end={end} step={step}")]
243 InvalidRange {
244 start: usize,
245 end: usize,
246 step: usize,
247 },
248
249 #[error("aso: Invalid kernel for batch path: {0:?}")]
250 InvalidKernelForBatch(Kernel),
251}
252
253#[inline]
254pub fn aso(input: &AsoInput) -> Result<AsoOutput, AsoError> {
255 aso_with_kernel(input, Kernel::Auto)
256}
257
258pub fn aso_with_kernel(input: &AsoInput, kernel: Kernel) -> Result<AsoOutput, AsoError> {
259 let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, kernel)?;
260
261 let len = close.len();
262
263 let mut bulls = alloc_with_nan_prefix(len, first + period - 1);
264 let mut bears = alloc_with_nan_prefix(len, first + period - 1);
265
266 aso_compute_into(
267 open, high, low, close, period, mode, first, chosen, &mut bulls, &mut bears,
268 );
269
270 Ok(AsoOutput { bulls, bears })
271}
272
273#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
274#[inline]
275pub fn aso_into(
276 input: &AsoInput,
277 bulls_out: &mut [f64],
278 bears_out: &mut [f64],
279) -> Result<(), AsoError> {
280 let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, Kernel::Auto)?;
281
282 if bulls_out.len() != close.len() || bears_out.len() != close.len() {
283 return Err(AsoError::OutputLengthMismatch {
284 expected: close.len(),
285 got: bulls_out.len().min(bears_out.len()),
286 });
287 }
288
289 let warm = first + period - 1;
290 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
291 let warm = warm.min(close.len());
292 for v in &mut bulls_out[..warm] {
293 *v = qnan;
294 }
295 for v in &mut bears_out[..warm] {
296 *v = qnan;
297 }
298
299 aso_compute_into(
300 open, high, low, close, period, mode, first, chosen, bulls_out, bears_out,
301 );
302
303 Ok(())
304}
305
306#[inline]
307pub fn aso_into_slices(
308 bulls_dst: &mut [f64],
309 bears_dst: &mut [f64],
310 input: &AsoInput,
311 kern: Kernel,
312) -> Result<(), AsoError> {
313 let (open, high, low, close, period, mode, first, chosen) = aso_prepare(input, kern)?;
314
315 if bulls_dst.len() != close.len() || bears_dst.len() != close.len() {
316 return Err(AsoError::OutputLengthMismatch {
317 expected: close.len(),
318 got: bulls_dst.len().min(bears_dst.len()),
319 });
320 }
321
322 aso_compute_into(
323 open, high, low, close, period, mode, first, chosen, bulls_dst, bears_dst,
324 );
325
326 let warm = first + period - 1;
327 for v in &mut bulls_dst[..warm] {
328 *v = f64::NAN;
329 }
330 for v in &mut bears_dst[..warm] {
331 *v = f64::NAN;
332 }
333
334 Ok(())
335}
336
337#[inline(always)]
338fn aso_prepare<'a>(
339 input: &'a AsoInput,
340 kernel: Kernel,
341) -> Result<
342 (
343 &'a [f64],
344 &'a [f64],
345 &'a [f64],
346 &'a [f64],
347 usize,
348 usize,
349 usize,
350 Kernel,
351 ),
352 AsoError,
353> {
354 let (open, high, low, close) = match &input.data {
355 AsoData::Candles { candles: c, .. } => (&c.open[..], &c.high[..], &c.low[..], &c.close[..]),
356 AsoData::Slices {
357 open,
358 high,
359 low,
360 close,
361 } => (*open, *high, *low, *close),
362 };
363
364 let len = close.len();
365 if len == 0 {
366 return Err(AsoError::EmptyInputData);
367 }
368
369 if open.len() != len || high.len() != len || low.len() != len {
370 return Err(AsoError::MissingData);
371 }
372
373 let first = close
374 .iter()
375 .position(|x| !x.is_nan())
376 .ok_or(AsoError::AllValuesNaN)?;
377
378 let period = input.get_period();
379 let mode = input.get_mode();
380
381 if period == 0 || period > len {
382 return Err(AsoError::InvalidPeriod {
383 period,
384 data_len: len,
385 });
386 }
387
388 if mode > 2 {
389 return Err(AsoError::InvalidMode { mode });
390 }
391
392 if len - first < period {
393 return Err(AsoError::NotEnoughValidData {
394 needed: period,
395 valid: len - first,
396 });
397 }
398
399 let mut chosen = match kernel {
400 Kernel::Auto => detect_best_kernel(),
401 k => k,
402 };
403
404 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
405 if matches!(kernel, Kernel::Auto) && matches!(chosen, Kernel::Avx512 | Kernel::Avx512Batch) {
406 chosen = Kernel::Avx2;
407 }
408
409 Ok((open, high, low, close, period, mode, first, chosen))
410}
411
412#[inline(always)]
413fn aso_compute_into(
414 open: &[f64],
415 high: &[f64],
416 low: &[f64],
417 close: &[f64],
418 period: usize,
419 mode: usize,
420 first: usize,
421 kernel: Kernel,
422 out_bulls: &mut [f64],
423 out_bears: &mut [f64],
424) {
425 unsafe {
426 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
427 {
428 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
429 aso_simd128(
430 open, high, low, close, period, mode, first, out_bulls, out_bears,
431 );
432 return;
433 }
434 }
435
436 match kernel {
437 Kernel::Scalar | Kernel::ScalarBatch => aso_scalar(
438 open, high, low, close, period, mode, first, out_bulls, out_bears,
439 ),
440 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
441 Kernel::Avx2 | Kernel::Avx2Batch => aso_avx2(
442 open, high, low, close, period, mode, first, out_bulls, out_bears,
443 ),
444 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
445 Kernel::Avx512 | Kernel::Avx512Batch => aso_avx512(
446 open, high, low, close, period, mode, first, out_bulls, out_bears,
447 ),
448 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
449 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => aso_scalar(
450 open, high, low, close, period, mode, first, out_bulls, out_bears,
451 ),
452 _ => unreachable!(),
453 }
454 }
455}
456
/// Scalar ASO kernel.
///
/// Each bar `i >= first_val + period - 1` gets a (bulls, bears) pair from two
/// parts. The intrabar part uses the bar's own range. The group part uses the
/// highest high, the lowest low and the first open of the `period`-bar window
/// ending at `i`. Zero ranges are replaced by 1.0. `out_bulls[i]`/`out_bears[i]`
/// receive the running mean of those pairs over at most the last `period` bars.
/// Earlier indices are left untouched so the caller owns the warm-up prefix.
///
/// `mode`: 1 = intrabar only, 2 = group only, any other value = average of the
/// two (mode 0). The public APIs reject modes above 2. This function treats
/// them as mode 0 for every window length. Previously, windows longer than 64
/// bars treated them as mode 2.
///
/// Windows of up to 64 bars are scanned directly. Longer windows track their
/// extremes with monotonic deques (amortised O(1) per bar). Both strategies
/// share the same per-bar arithmetic, so results do not change discontinuously
/// at the threshold.
///
/// Does nothing when `close` is empty or `period == 0`.
///
/// # Panics
/// Panics if `open`, `high`, `low`, `out_bulls` or `out_bears` is shorter than
/// `close`. Previously this read/wrote out of bounds.
#[inline]
pub fn aso_scalar(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    let len = close.len();
    if len == 0 || period == 0 {
        return;
    }

    // Re-slice once. This fails fast on short inputs and lets the loops below
    // use plain, safe indexing.
    let (open, high, low) = (&open[..len], &high[..len], &low[..len]);
    let out_bulls = &mut out_bulls[..len];
    let out_bears = &mut out_bears[..len];

    let warm = first_val.saturating_add(period - 1);

    // Running-mean state over the last `period` emitted (bulls, bears) pairs.
    let mut ring_b = vec![0.0f64; period];
    let mut ring_e = vec![0.0f64; period];
    let mut sum_b = 0.0f64;
    let mut sum_e = 0.0f64;
    let mut head = 0usize;
    let mut filled = 0usize;

    // Emits bar `i` (>= warm) given the low/high extremes of `[i + 1 - period, i]`.
    let mut emit = |i: usize, gl: f64, gh: f64| {
        let (oi, hi, li, ci) = (open[i], high[i], low[i], close[i]);
        let gopen = open[i + 1 - period];

        let intrarange = hi - li;
        let k1 = if intrarange == 0.0 { 1.0 } else { intrarange };
        let intrabarbulls = (((ci - li) + (hi - oi)) * 50.0) / k1;
        let intrabarbears = (((hi - ci) + (oi - li)) * 50.0) / k1;

        let gr = gh - gl;
        let k2 = if gr == 0.0 { 1.0 } else { gr };
        let groupbulls = (((ci - gl) + (gh - gopen)) * 50.0) / k2;
        let groupbears = (((gh - ci) + (gopen - gl)) * 50.0) / k2;

        let (b, e) = match mode {
            1 => (intrabarbulls, intrabarbears),
            2 => (groupbulls, groupbears),
            _ => (
                0.5 * (intrabarbulls + groupbulls),
                0.5 * (intrabarbears + groupbears),
            ),
        };

        let (old_b, old_e) = if filled == period {
            (ring_b[head], ring_e[head])
        } else {
            filled += 1;
            (0.0, 0.0)
        };
        sum_b += b - old_b;
        sum_e += e - old_e;
        ring_b[head] = b;
        ring_e[head] = e;
        head = if head + 1 == period { 0 } else { head + 1 };

        let n = filled as f64;
        out_bulls[i] = sum_b / n;
        out_bears[i] = sum_e / n;
    };

    const DIRECT_SCAN_MAX_PERIOD: usize = 64;
    if period <= DIRECT_SCAN_MAX_PERIOD {
        for i in warm..len {
            let start = i + 1 - period;
            // Strict comparisons skip NaN entries, starting from the extreme finite values.
            let gl = low[start..=i]
                .iter()
                .fold(f64::MAX, |m, &x| if x < m { x } else { m });
            let gh = high[start..=i]
                .iter()
                .fold(f64::MIN, |m, &x| if x > m { x } else { m });
            emit(i, gl, gh);
        }
    } else {
        use std::collections::VecDeque;

        // Indices whose lows (highs) are increasing (decreasing) front to back;
        // the front is always the extreme of the current window.
        let mut lows: VecDeque<usize> = VecDeque::with_capacity(period + 1);
        let mut highs: VecDeque<usize> = VecDeque::with_capacity(period + 1);

        for i in first_val..len {
            let (hi, li) = (high[i], low[i]);

            while matches!(lows.back(), Some(&j) if li <= low[j]) {
                lows.pop_back();
            }
            lows.push_back(i);
            while matches!(highs.back(), Some(&j) if hi >= high[j]) {
                highs.pop_back();
            }
            highs.push_back(i);

            if i < warm {
                continue;
            }

            let start = i + 1 - period;
            while matches!(lows.front(), Some(&j) if j < start) {
                lows.pop_front();
            }
            while matches!(highs.front(), Some(&j) if j < start) {
                highs.pop_front();
            }

            // Both deques still contain `i` itself, so their fronts exist.
            let gl = low[lows[0]];
            let gh = high[highs[0]];
            emit(i, gl, gh);
        }
    }
}
723
/// wasm32 `simd128` entry point for the ASO kernel.
///
/// No dedicated SIMD implementation exists yet; this forwards to
/// [`aso_scalar`]. The previous `use core::arch::wasm32::*;` was never used and
/// only produced an `unused_imports` warning, so it has been removed.
///
/// # Safety
/// Performs no unsafe operations itself. It is declared `unsafe` presumably
/// for parity with the other target-specific kernels.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn aso_simd128(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(
        open, high, low, close, period, mode, first_val, out_bulls, out_bears,
    );
}
743
/// AVX2/FMA entry point for the ASO kernel.
///
/// There is no hand-vectorised body yet; this forwards to [`aso_scalar`]
/// (which may be inlined here and compiled with the enabled features).
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and FMA, as required by
/// `#[target_feature]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn aso_avx2(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(open, high, low, close, period, mode, first_val, out_bulls, out_bears)
}
761
/// AVX-512F/FMA entry point for the ASO kernel.
///
/// There is no hand-vectorised body yet; this forwards to [`aso_scalar`]
/// (which may be inlined here and compiled with the enabled features).
///
/// # Safety
/// The caller must ensure the running CPU supports AVX-512F and FMA, as
/// required by `#[target_feature]`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn aso_avx512(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: usize,
    mode: usize,
    first_val: usize,
    out_bulls: &mut [f64],
    out_bears: &mut [f64],
) {
    aso_scalar(open, high, low, close, period, mode, first_val, out_bulls, out_bears)
}
779
/// Parameter sweep for batch ASO: each axis is `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct AsoBatchRange {
    /// Window lengths to evaluate.
    pub period: (usize, usize, usize),
    /// Modes to evaluate.
    pub mode: (usize, usize, usize),
}
785
786impl Default for AsoBatchRange {
787 fn default() -> Self {
788 Self {
789 period: (10, 259, 1),
790 mode: (0, 0, 0),
791 }
792 }
793}
794
795#[derive(Clone, Debug, Default)]
796pub struct AsoBatchBuilder {
797 range: AsoBatchRange,
798 kernel: Kernel,
799}
800
801impl AsoBatchBuilder {
802 pub fn new() -> Self {
803 Self::default()
804 }
805
806 pub fn kernel(mut self, k: Kernel) -> Self {
807 self.kernel = k;
808 self
809 }
810
811 pub fn period_range(mut self, s: usize, e: usize, st: usize) -> Self {
812 self.range.period = (s, e, st);
813 self
814 }
815
816 pub fn period_static(mut self, p: usize) -> Self {
817 self.range.period = (p, p, 1);
818 self
819 }
820
821 pub fn mode_range(mut self, s: usize, e: usize, st: usize) -> Self {
822 self.range.mode = (s, e, st);
823 self
824 }
825
826 pub fn mode_static(mut self, m: usize) -> Self {
827 self.range.mode = (m, m, 1);
828 self
829 }
830
831 pub fn apply_candles(self, c: &Candles) -> Result<AsoBatchOutput, AsoError> {
832 let k = match self.kernel {
833 Kernel::Scalar => Kernel::ScalarBatch,
834 Kernel::Avx2 => Kernel::Avx2Batch,
835 Kernel::Avx512 => Kernel::Avx512Batch,
836 other => other,
837 };
838 aso_batch_with_kernel(&c.open, &c.high, &c.low, &c.close, &self.range, k)
839 }
840
841 pub fn apply_slices(
842 self,
843 o: &[f64],
844 h: &[f64],
845 l: &[f64],
846 c: &[f64],
847 ) -> Result<AsoBatchOutput, AsoError> {
848 let k = match self.kernel {
849 Kernel::Scalar => Kernel::ScalarBatch,
850 Kernel::Avx2 => Kernel::Avx2Batch,
851 Kernel::Avx512 => Kernel::Avx512Batch,
852 other => other,
853 };
854 aso_batch_with_kernel(o, h, l, c, &self.range, k)
855 }
856
857 pub fn with_default_candles(c: &Candles) -> Result<AsoBatchOutput, AsoError> {
858 Self::default().apply_candles(c)
859 }
860
861 pub fn with_default_slices(
862 o: &[f64],
863 h: &[f64],
864 l: &[f64],
865 c: &[f64],
866 k: Kernel,
867 ) -> Result<AsoBatchOutput, AsoError> {
868 Self::new().kernel(k).apply_slices(o, h, l, c)
869 }
870}
871
872#[derive(Clone, Debug)]
873pub struct AsoBatchOutput {
874 pub bulls: Vec<f64>,
875 pub bears: Vec<f64>,
876 pub combos: Vec<AsoParams>,
877 pub rows: usize,
878 pub cols: usize,
879}
880
881impl AsoBatchOutput {
882 #[inline]
883 pub fn bulls_row(&self, row: usize) -> &[f64] {
884 let s = row * self.cols;
885 &self.bulls[s..s + self.cols]
886 }
887
888 #[inline]
889 pub fn bears_row(&self, row: usize) -> &[f64] {
890 let s = row * self.cols;
891 &self.bears[s..s + self.cols]
892 }
893
894 #[inline]
895 pub fn row_for_params(&self, p: &AsoParams) -> Option<usize> {
896 self.combos
897 .iter()
898 .position(|c| c.period == p.period && c.mode == p.mode)
899 }
900
901 #[inline]
902 pub fn values_for(&self, p: &AsoParams) -> Option<(&[f64], &[f64])> {
903 self.row_for_params(p)
904 .map(|row| (self.bulls_row(row), self.bears_row(row)))
905 }
906}
907
908#[inline(always)]
909fn expand_grid_aso(r: &AsoBatchRange) -> Result<Vec<AsoParams>, AsoError> {
910 fn axis_usize((s, e, st): (usize, usize, usize)) -> Result<Vec<usize>, AsoError> {
911 if st == 0 || s == e {
912 return Ok(vec![s]);
913 }
914 let mut v = Vec::new();
915 if s < e {
916 let mut cur = s;
917 while cur <= e {
918 v.push(cur);
919 let next = cur.saturating_add(st);
920 if next == cur {
921 break;
922 }
923 cur = next;
924 }
925 } else {
926 let mut cur = s;
927 while cur >= e {
928 v.push(cur);
929 let next = cur.saturating_sub(st);
930 if next == cur {
931 break;
932 }
933 cur = next;
934 if cur == 0 && e > 0 {
935 break;
936 }
937 }
938 }
939 if v.is_empty() {
940 return Err(AsoError::InvalidRange {
941 start: s,
942 end: e,
943 step: st,
944 });
945 }
946 Ok(v)
947 }
948
949 let ps = axis_usize(r.period)?;
950 let ms = axis_usize(r.mode)?;
951 let total = ps
952 .len()
953 .checked_mul(ms.len())
954 .ok_or(AsoError::InvalidRange {
955 start: ps.len(),
956 end: ms.len(),
957 step: 0,
958 })?;
959 let mut out = Vec::with_capacity(total);
960 for &p in &ps {
961 for &m in &ms {
962 out.push(AsoParams {
963 period: Some(p),
964 mode: Some(m),
965 });
966 }
967 }
968 Ok(out)
969}
970
971pub fn aso_batch_with_kernel(
972 open: &[f64],
973 high: &[f64],
974 low: &[f64],
975 close: &[f64],
976 sweep: &AsoBatchRange,
977 k: Kernel,
978) -> Result<AsoBatchOutput, AsoError> {
979 let combos = expand_grid_aso(sweep)?;
980 let rows = combos.len();
981 let cols = close.len();
982
983 if cols == 0 {
984 return Err(AsoError::EmptyInputData);
985 }
986 if open.len() != cols || high.len() != cols || low.len() != cols {
987 return Err(AsoError::MissingData);
988 }
989
990 let first = close
991 .iter()
992 .position(|x| !x.is_nan())
993 .ok_or(AsoError::AllValuesNaN)?;
994 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
995
996 if cols - first < max_p {
997 return Err(AsoError::NotEnoughValidData {
998 needed: max_p,
999 valid: cols - first,
1000 });
1001 }
1002
1003 let mut bulls_mu = make_uninit_matrix(rows, cols);
1004 let mut bears_mu = make_uninit_matrix(rows, cols);
1005
1006 let warm: Vec<usize> = combos
1007 .iter()
1008 .map(|c| first + c.period.unwrap() - 1)
1009 .collect();
1010 init_matrix_prefixes(&mut bulls_mu, cols, &warm);
1011 init_matrix_prefixes(&mut bears_mu, cols, &warm);
1012
1013 let mut guard_b = core::mem::ManuallyDrop::new(bulls_mu);
1014 let mut guard_e = core::mem::ManuallyDrop::new(bears_mu);
1015 let bulls_out: &mut [f64] =
1016 unsafe { core::slice::from_raw_parts_mut(guard_b.as_mut_ptr() as *mut f64, guard_b.len()) };
1017 let bears_out: &mut [f64] =
1018 unsafe { core::slice::from_raw_parts_mut(guard_e.as_mut_ptr() as *mut f64, guard_e.len()) };
1019
1020 let mut actual = match k {
1021 Kernel::Auto => detect_best_batch_kernel(),
1022 Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => k,
1023 other => return Err(AsoError::InvalidKernelForBatch(other)),
1024 };
1025
1026 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1027 if matches!(k, Kernel::Auto) && matches!(actual, Kernel::Avx2Batch | Kernel::Avx512Batch) {
1028 actual = Kernel::ScalarBatch;
1029 }
1030
1031 let do_row = |row: usize, bulls_row: &mut [f64], bears_row: &mut [f64]| {
1032 let p = combos[row].period.unwrap();
1033 let m = combos[row].mode.unwrap();
1034 unsafe {
1035 match actual {
1036 Kernel::Scalar | Kernel::ScalarBatch => {
1037 aso_scalar(open, high, low, close, p, m, first, bulls_row, bears_row)
1038 }
1039 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1040 Kernel::Avx2 | Kernel::Avx2Batch => {
1041 aso_avx2(open, high, low, close, p, m, first, bulls_row, bears_row)
1042 }
1043 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1044 Kernel::Avx512 | Kernel::Avx512Batch => {
1045 aso_avx512(open, high, low, close, p, m, first, bulls_row, bears_row)
1046 }
1047 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1048 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1049 aso_scalar(open, high, low, close, p, m, first, bulls_row, bears_row)
1050 }
1051 Kernel::Auto => unreachable!(),
1052 }
1053 }
1054 };
1055
1056 #[cfg(not(target_arch = "wasm32"))]
1057 {
1058 bulls_out
1059 .chunks_mut(cols)
1060 .zip(bears_out.chunks_mut(cols))
1061 .enumerate()
1062 .par_bridge()
1063 .for_each(|(row, (b, e))| do_row(row, b, e));
1064 }
1065 #[cfg(target_arch = "wasm32")]
1066 {
1067 for (row, (b, e)) in bulls_out
1068 .chunks_mut(cols)
1069 .zip(bears_out.chunks_mut(cols))
1070 .enumerate()
1071 {
1072 do_row(row, b, e);
1073 }
1074 }
1075
1076 let bulls = unsafe {
1077 Vec::from_raw_parts(
1078 guard_b.as_mut_ptr() as *mut f64,
1079 guard_b.len(),
1080 guard_b.capacity(),
1081 )
1082 };
1083 let bears = unsafe {
1084 Vec::from_raw_parts(
1085 guard_e.as_mut_ptr() as *mut f64,
1086 guard_e.len(),
1087 guard_e.capacity(),
1088 )
1089 };
1090
1091 Ok(AsoBatchOutput {
1092 bulls,
1093 bears,
1094 combos,
1095 rows,
1096 cols,
1097 })
1098}
1099
1100pub fn aso_batch_slice(
1101 open: &[f64],
1102 high: &[f64],
1103 low: &[f64],
1104 close: &[f64],
1105 sweep: &AsoBatchRange,
1106 kern: Kernel,
1107) -> Result<AsoBatchOutput, AsoError> {
1108 let k = match kern {
1109 Kernel::Scalar => Kernel::ScalarBatch,
1110 Kernel::Avx2 => Kernel::Avx2Batch,
1111 Kernel::Avx512 => Kernel::Avx512Batch,
1112 other => other,
1113 };
1114 aso_batch_with_kernel(open, high, low, close, sweep, k)
1115}
1116
1117#[cfg(not(target_arch = "wasm32"))]
1118pub fn aso_batch_par_slice(
1119 open: &[f64],
1120 high: &[f64],
1121 low: &[f64],
1122 close: &[f64],
1123 sweep: &AsoBatchRange,
1124 kern: Kernel,
1125) -> Result<AsoBatchOutput, AsoError> {
1126 let k = match kern {
1127 Kernel::Scalar => Kernel::ScalarBatch,
1128 Kernel::Avx2 => Kernel::Avx2Batch,
1129 Kernel::Avx512 => Kernel::Avx512Batch,
1130 other => other,
1131 };
1132 aso_batch_with_kernel(open, high, low, close, sweep, k)
1133}
1134
/// Batch ASO over OHLC slices on wasm32 (rows are computed sequentially).
///
/// Like the native variant and [`aso_batch_slice`], single-series kernels
/// (`Scalar`, `Avx2`, `Avx512`) are promoted to their batch variants.
/// Previously they were passed through and rejected with
/// `InvalidKernelForBatch` on wasm only.
#[cfg(target_arch = "wasm32")]
pub fn aso_batch_par_slice(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    sweep: &AsoBatchRange,
    kern: Kernel,
) -> Result<AsoBatchOutput, AsoError> {
    let k = match kern {
        Kernel::Scalar => Kernel::ScalarBatch,
        Kernel::Avx2 => Kernel::Avx2Batch,
        Kernel::Avx512 => Kernel::Avx512Batch,
        other => other,
    };
    aso_batch_with_kernel(open, high, low, close, sweep, k)
}
1146
1147#[inline(always)]
1148fn aso_batch_inner_into(
1149 open: &[f64],
1150 high: &[f64],
1151 low: &[f64],
1152 close: &[f64],
1153 sweep: &AsoBatchRange,
1154 kern: Kernel,
1155 parallel: bool,
1156 out_bulls: &mut [f64],
1157 out_bears: &mut [f64],
1158) -> Result<Vec<AsoParams>, AsoError> {
1159 let combos = expand_grid_aso(sweep)?;
1160 if combos.is_empty() {
1161 return Err(AsoError::InvalidRange {
1162 start: 0,
1163 end: 0,
1164 step: 0,
1165 });
1166 }
1167
1168 let cols = close.len();
1169 if cols == 0 {
1170 return Err(AsoError::EmptyInputData);
1171 }
1172 if open.len() != cols || high.len() != cols || low.len() != cols {
1173 return Err(AsoError::MissingData);
1174 }
1175 let rows = combos.len();
1176 let total = rows.checked_mul(cols).ok_or(AsoError::InvalidRange {
1177 start: rows,
1178 end: cols,
1179 step: 0,
1180 })?;
1181 if out_bulls.len() != total || out_bears.len() != total {
1182 return Err(AsoError::OutputLengthMismatch {
1183 expected: total,
1184 got: out_bulls.len().min(out_bears.len()),
1185 });
1186 }
1187
1188 let first = close
1189 .iter()
1190 .position(|x| !x.is_nan())
1191 .ok_or(AsoError::AllValuesNaN)?;
1192 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1193 if cols - first < max_p {
1194 return Err(AsoError::NotEnoughValidData {
1195 needed: max_p,
1196 valid: cols - first,
1197 });
1198 }
1199
1200 let mut b_mu = unsafe {
1201 core::slice::from_raw_parts_mut(
1202 out_bulls.as_mut_ptr() as *mut MaybeUninit<f64>,
1203 out_bulls.len(),
1204 )
1205 };
1206 let mut e_mu = unsafe {
1207 core::slice::from_raw_parts_mut(
1208 out_bears.as_mut_ptr() as *mut MaybeUninit<f64>,
1209 out_bears.len(),
1210 )
1211 };
1212 let warm: Vec<usize> = combos
1213 .iter()
1214 .map(|c| first + c.period.unwrap() - 1)
1215 .collect();
1216 init_matrix_prefixes(&mut b_mu, cols, &warm);
1217 init_matrix_prefixes(&mut e_mu, cols, &warm);
1218
1219 let actual = match kern {
1220 Kernel::Auto => detect_best_batch_kernel(),
1221 Kernel::ScalarBatch | Kernel::Avx2Batch | Kernel::Avx512Batch => kern,
1222 other => return Err(AsoError::InvalidKernelForBatch(other)),
1223 };
1224
1225 let do_row = |row: usize, br: &mut [MaybeUninit<f64>], er: &mut [MaybeUninit<f64>]| unsafe {
1226 let p = combos[row].period.unwrap();
1227 let m = combos[row].mode.unwrap();
1228
1229 let b = core::slice::from_raw_parts_mut(br.as_mut_ptr() as *mut f64, br.len());
1230 let e = core::slice::from_raw_parts_mut(er.as_mut_ptr() as *mut f64, er.len());
1231
1232 match actual {
1233 Kernel::ScalarBatch => aso_scalar(open, high, low, close, p, m, first, b, e),
1234 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1235 Kernel::Avx2Batch => aso_avx2(open, high, low, close, p, m, first, b, e),
1236 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1237 Kernel::Avx512Batch => aso_avx512(open, high, low, close, p, m, first, b, e),
1238 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1239 Kernel::Avx2Batch | Kernel::Avx512Batch => {
1240 aso_scalar(open, high, low, close, p, m, first, b, e)
1241 }
1242 Kernel::Auto | Kernel::Scalar | Kernel::Avx2 | Kernel::Avx512 => unreachable!(),
1243 }
1244 };
1245
1246 if parallel {
1247 #[cfg(not(target_arch = "wasm32"))]
1248 {
1249 use rayon::prelude::*;
1250 b_mu.chunks_mut(cols)
1251 .zip(e_mu.chunks_mut(cols))
1252 .enumerate()
1253 .par_bridge()
1254 .for_each(|(row, (br, er))| do_row(row, br, er));
1255 }
1256 #[cfg(target_arch = "wasm32")]
1257 {
1258 for (row, (br, er)) in b_mu.chunks_mut(cols).zip(e_mu.chunks_mut(cols)).enumerate() {
1259 do_row(row, br, er);
1260 }
1261 }
1262 } else {
1263 for (row, (br, er)) in b_mu.chunks_mut(cols).zip(e_mu.chunks_mut(cols)).enumerate() {
1264 do_row(row, br, er);
1265 }
1266 }
1267
1268 Ok(combos)
1269}
1270
/// Incremental ASO calculator fed one OHLC bar at a time.
#[derive(Clone, Debug)]
pub struct AsoStream {
    // Ring buffers (length `period`) of the most recent open/high/low/close
    // values, indexed by bar count modulo `period`.
    o: Vec<f64>,
    h: Vec<f64>,
    l: Vec<f64>,
    c: Vec<f64>,

    // Running sums and ring buffers of per-bar bulls/bears values used for
    // the smoothing mean, with write head and fill count.
    rb: Vec<f64>,
    re: Vec<f64>,
    sum_b: f64,
    sum_e: f64,
    head_be: usize,
    filled_be: usize,

    // Circular deque (bar indices plus their lows) tracking the window minimum.
    dq_min_idx: Vec<usize>,
    dq_min_val: Vec<f64>,
    min_head: usize,
    min_tail: usize,
    min_len: usize,

    // Circular deque (bar indices plus their highs), presumably tracking the
    // window maximum.
    dq_max_idx: Vec<usize>,
    dq_max_val: Vec<f64>,
    max_head: usize,
    max_tail: usize,
    max_len: usize,

    // Window length and output mode (0, 1 or 2).
    period: usize,
    mode: usize,
    // Number of bars seen so far.
    i: usize,
    // NOTE(review): presumably set once enough bars have arrived to emit output.
    ready: bool,
}
1302
impl AsoStream {
    /// Creates an empty stream.
    ///
    /// `period` defaults to 10 and `mode` to 0 when the corresponding parameter is `None`.
    ///
    /// # Errors
    /// * [`AsoError::InvalidPeriod`] (with `data_len: 0`) when `period == 0`.
    /// * [`AsoError::InvalidMode`] when `mode > 2`.
    #[inline]
    pub fn try_new(params: AsoParams) -> Result<Self, AsoError> {
        let period = params.period.unwrap_or(10);
        let mode = params.mode.unwrap_or(0);

        if period == 0 {
            return Err(AsoError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }
        if mode > 2 {
            return Err(AsoError::InvalidMode { mode });
        }

        // Every ring buffer and deque is sized to exactly `period` slots.
        Ok(Self {
            o: vec![0.0; period],
            h: vec![0.0; period],
            l: vec![0.0; period],
            c: vec![0.0; period],

            rb: vec![0.0; period],
            re: vec![0.0; period],
            sum_b: 0.0,
            sum_e: 0.0,
            head_be: 0,
            filled_be: 0,

            dq_min_idx: vec![0usize; period],
            dq_min_val: vec![0.0; period],
            min_head: 0,
            min_tail: 0,
            min_len: 0,

            dq_max_idx: vec![0usize; period],
            dq_max_val: vec![0.0; period],
            max_head: 0,
            max_tail: 0,
            max_len: 0,

            period,
            mode,
            i: 0,
            ready: false,
        })
    }

    /// Returns `1 / x`, or `1.0` when `x` is exactly zero (avoids dividing by a flat range).
    #[inline(always)]
    fn inv_or_one(x: f64) -> f64 {
        if x != 0.0 {
            x.recip()
        } else {
            1.0
        }
    }

    /// Feeds one OHLC bar into the stream.
    ///
    /// Returns `None` until `period` bars have been consumed. From then on it returns
    /// `Some((bulls, bears))`, where each value is the mean of the per-bar bulls/bears values
    /// produced since the stream became ready, limited to the most recent `period` of them.
    /// `mode` picks the per-bar value: 0 averages the intrabar and group components, 1 uses
    /// only the intrabar component, any other value only the group component.
    #[inline]
    pub fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let p = self.period;
        let i = self.i;
        // Ring-buffer slot for this bar.
        let idx = i % p;

        self.o[idx] = open;
        self.h[idx] = high;
        self.l[idx] = low;
        self.c[idx] = close;

        // Rolling minimum of `low`: pop tail entries whose low is >= the incoming low so the
        // deque values strictly increase from head to tail and the head holds the minimum.
        while self.min_len > 0 {
            let back = if self.min_tail == 0 {
                p - 1
            } else {
                self.min_tail - 1
            };
            if low <= self.dq_min_val[back] {
                self.min_tail = back;
                self.min_len -= 1;
            } else {
                break;
            }
        }
        // Still full after popping: the head is the bar leaving the window, so drop it to make
        // room for the push below.
        if self.min_len == p {
            self.min_head += 1;
            if self.min_head == p {
                self.min_head = 0;
            }
            self.min_len -= 1;
        }
        self.dq_min_idx[self.min_tail] = i;
        self.dq_min_val[self.min_tail] = low;
        self.min_tail += 1;
        if self.min_tail == p {
            self.min_tail = 0;
        }
        self.min_len += 1;

        // Rolling maximum of `high`: same scheme with the comparison reversed.
        while self.max_len > 0 {
            let back = if self.max_tail == 0 {
                p - 1
            } else {
                self.max_tail - 1
            };
            if high >= self.dq_max_val[back] {
                self.max_tail = back;
                self.max_len -= 1;
            } else {
                break;
            }
        }
        if self.max_len == p {
            self.max_head += 1;
            if self.max_head == p {
                self.max_head = 0;
            }
            self.max_len -= 1;
        }
        self.dq_max_idx[self.max_tail] = i;
        self.dq_max_val[self.max_tail] = high;
        self.max_tail += 1;
        if self.max_tail == p {
            self.max_tail = 0;
        }
        self.max_len += 1;

        self.i = i + 1;
        if self.i >= p {
            self.ready = true;
        }
        if !self.ready {
            return None;
        }

        // The window for this bar is [start_abs, i]; evict head entries older than that.
        let start_abs = self.i - p;

        while self.min_len > 0 && self.dq_min_idx[self.min_head] < start_abs {
            self.min_head += 1;
            if self.min_head == p {
                self.min_head = 0;
            }
            self.min_len -= 1;
        }
        while self.max_len > 0 && self.dq_max_idx[self.max_head] < start_abs {
            self.max_head += 1;
            if self.max_head == p {
                self.max_head = 0;
            }
            self.max_len -= 1;
        }

        debug_assert!(self.min_len > 0 && self.max_len > 0);
        // Lowest low and highest high over the window.
        let gl = self.dq_min_val[self.min_head];
        let gh = self.dq_max_val[self.max_head];

        // Slot after the current one holds the oldest bar still in the window (bar i + 1 - p).
        let oldest_ring = if idx + 1 == p { 0 } else { idx + 1 };
        let gopen = self.o[oldest_ring];

        // Intrabar component: scaled by 50 / (high - low), or by 50 when the bar range is zero.
        let intrarange = high - low;
        let scale1 = 50.0 * Self::inv_or_one(intrarange);
        let intrabarbulls = ((close - low) + (high - open)) * scale1;
        let intrabarbears = ((high - close) + (open - low)) * scale1;

        // Group component: same formula using the window's extremes and its oldest open.
        let gr = gh - gl;
        let scale2 = 50.0 * Self::inv_or_one(gr);
        let groupbulls = ((close - gl) + (gh - gopen)) * scale2;
        let groupbears = ((gh - close) + (gopen - gl)) * scale2;

        let b = match self.mode {
            0 => 0.5 * (intrabarbulls + groupbulls),
            1 => intrabarbulls,
            _ => groupbulls,
        };
        let e = match self.mode {
            0 => 0.5 * (intrabarbears + groupbears),
            1 => intrabarbears,
            _ => groupbears,
        };

        // Running-sum average: once the ring is full, subtract the value being overwritten.
        let old_b = if self.filled_be == p {
            self.rb[self.head_be]
        } else {
            0.0
        };
        let old_e = if self.filled_be == p {
            self.re[self.head_be]
        } else {
            0.0
        };

        self.sum_b += b - old_b;
        self.sum_e += e - old_e;

        self.rb[self.head_be] = b;
        self.re[self.head_be] = e;

        self.head_be += 1;
        if self.head_be == p {
            self.head_be = 0;
        }
        if self.filled_be < p {
            self.filled_be += 1;
        }

        // Divide by the number of stored values (partial mean until `period` are available).
        let n = self.filled_be as f64;
        Some((self.sum_b / n, self.sum_e / n))
    }
}
1509
/// Python entry point: computes ASO bulls/bears for one (period, mode) pair.
///
/// All four OHLC arrays must have the same length; `kernel` is validated with
/// `validate_kernel` and the computation runs with the GIL released. Any error is
/// surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "aso")]
#[pyo3(signature = (open, high, low, close, period=None, mode=None, kernel=None))]
pub fn aso_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    mode: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (open_s, high_s, low_s, close_s) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );

    let n = open_s.len();
    let lengths_agree = [high_s.len(), low_s.len(), close_s.len()]
        .iter()
        .all(|&len| len == n);
    if !lengths_agree {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }

    let kern = validate_kernel(kernel, false)?;
    let input = AsoInput::from_slices(open_s, high_s, low_s, close_s, AsoParams { period, mode });

    let computed = py.allow_threads(|| aso_with_kernel(&input, kern));
    let AsoOutput { bulls, bears } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((bulls.into_pyarray(py), bears.into_pyarray(py)))
}
1544
/// Python entry point: evaluates ASO over a grid of (period, mode) combinations.
///
/// Returns a dict with:
/// * `bulls` / `bears`: `(rows, cols)` float64 arrays, one row per parameter combination;
/// * `periods` / `modes`: the parameters of each row.
///
/// Raises `ValueError` when the OHLC arrays differ in length, the kernel or ranges are
/// invalid, the output size overflows, or the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "aso_batch")]
#[pyo3(signature = (open, high, low, close, period_range, mode_range, kernel=None))]
pub fn aso_batch_py<'py>(
    py: Python<'py>,
    open: PyReadonlyArray1<'py, f64>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    mode_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (o, h, l, c) = (
        open.as_slice()?,
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
    );
    let cols = c.len();
    // Same contract as the single-parameter `aso` binding: all series must line up.
    if o.len() != cols || h.len() != cols || l.len() != cols {
        return Err(PyValueError::new_err(
            "All OHLC arrays must have the same length",
        ));
    }
    // Validate the kernel before allocating any output arrays.
    let kern = validate_kernel(kernel, true)?;

    let sweep = AsoBatchRange {
        period: period_range,
        mode: mode_range,
    };
    let combos = expand_grid_aso(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("size overflow"))?;

    // SAFETY: the arrays are freshly created, contiguous and not shared with Python yet.
    // NOTE(review): they start uninitialised; this assumes `aso_batch_inner_into` writes every
    // cell (warm-up prefix and values) — confirm against its implementation.
    let bulls_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bears_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let b = unsafe { bulls_arr.as_slice_mut()? };
    let e = unsafe { bears_arr.as_slice_mut()? };

    py.allow_threads(|| {
        let simd = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            k => k,
        };
        aso_batch_inner_into(o, h, l, c, &sweep, simd, true, b, e)
    })
    .map_err(|er| PyValueError::new_err(er.to_string()))?;

    let d = PyDict::new(py);
    d.set_item("bulls", bulls_arr.reshape((rows, cols))?)?;
    d.set_item("bears", bears_arr.reshape((rows, cols))?)?;
    d.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    d.set_item(
        "modes",
        combos
            .iter()
            .map(|p| p.mode.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(d)
}
1612
1613#[cfg(all(feature = "python", feature = "cuda"))]
1614use crate::cuda::cuda_available;
1615#[cfg(all(feature = "python", feature = "cuda"))]
1616use crate::cuda::moving_averages::DeviceArrayF32;
1617#[cfg(all(feature = "python", feature = "cuda"))]
1618use crate::cuda::oscillators::CudaAso;
1619#[cfg(all(feature = "python", feature = "cuda"))]
1620use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1621#[cfg(all(feature = "python", feature = "cuda"))]
1622use cust::context::Context as CudaContext;
1623#[cfg(all(feature = "python", feature = "cuda"))]
1624use cust::memory::DeviceBuffer;
1625#[cfg(all(feature = "python", feature = "cuda"))]
1626use std::sync::Arc;
1627
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python handle to a `rows x cols` f32 matrix in CUDA device memory, exposed through
/// `__cuda_array_interface__` (row-major strides) and a one-shot `__dlpack__` export.
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct AsoDeviceArrayF32Py {
    /// Device buffer; becomes `None` once ownership is handed out by `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    pub(crate) rows: usize,
    pub(crate) cols: usize,
    /// Shared CUDA context reference held alongside the buffer (not read directly;
    /// presumably keeps the context alive while the buffer exists — confirm).
    pub(crate) _ctx: Arc<CudaContext>,
    /// Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1637
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl AsoDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the row-major f32 matrix.
    ///
    /// Empty matrices report a null data pointer; otherwise fails if the buffer has
    /// already been exported through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem = std::mem::size_of::<f32>();
        let data_ptr = if self.rows == 0 || self.cols == 0 {
            0usize
        } else {
            match self.buf.as_ref() {
                Some(buf) => buf.as_device_ptr().as_raw() as usize,
                None => {
                    return Err(PyValueError::new_err(
                        "buffer already exported via __dlpack__",
                    ))
                }
            }
        };

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.rows, self.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (self.cols * elem, elem))?;
        iface.set_item("data", (data_ptr, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    /// DLPack device tuple: device type 2 with this handle's device ordinal.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule; can succeed at most once.
    ///
    /// A `dl_device` that parses as `(type, id)` and differs from this handle's device is
    /// rejected (copies are not implemented). `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if (dev_ty, dev_id) != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let msg = if wants_copy {
                    "__dlpack__(copy=True) not implemented for ASO device handle"
                } else {
                    "dl_device mismatch for ASO tensor"
                };
                return Err(PyValueError::new_err(msg));
            }
        }
        let _ = stream;

        let buf = match self.buf.take() {
            Some(buf) => buf,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));
        export_f32_cuda_dlpack_2d(py, buf, self.rows, self.cols, alloc_dev, max_version_bound)
    }
}
1714
/// Runs an ASO (period, mode) sweep on the GPU and returns the bulls and bears matrices as
/// device-resident handles that share one CUDA context.
///
/// Raises `ValueError` when CUDA is unavailable, the input is empty, the OHLC lengths differ,
/// or the CUDA backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "aso_cuda_batch_dev")]
#[pyo3(signature = (open, high, low, close, period_range, mode_range, device_id=0))]
pub fn aso_cuda_batch_dev_py(
    py: Python<'_>,
    open: PyReadonlyArray1<'_, f32>,
    high: PyReadonlyArray1<'_, f32>,
    low: PyReadonlyArray1<'_, f32>,
    close: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    mode_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(AsoDeviceArrayF32Py, AsoDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let o = open.as_slice()?;
    let h = high.as_slice()?;
    let l = low.as_slice()?;
    let c = close.as_slice()?;
    // Report empty input separately; it previously surfaced as a misleading length mismatch.
    if o.is_empty() {
        return Err(PyValueError::new_err("empty input"));
    }
    if h.len() != o.len() || l.len() != o.len() || c.len() != o.len() {
        return Err(PyValueError::new_err("mismatched input lengths"));
    }
    let sweep = AsoBatchRange {
        period: period_range,
        mode: mode_range,
    };
    let (bulls, bears, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda = CudaAso::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .aso_batch_dev(o, h, l, c, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out.0, out.1, cuda.context_arc(), cuda.device_id()))
    })?;
    Ok((
        AsoDeviceArrayF32Py {
            buf: Some(bulls.buf),
            rows: bulls.rows,
            cols: bulls.cols,
            _ctx: Arc::clone(&ctx_guard),
            device_id: dev_id,
        },
        AsoDeviceArrayF32Py {
            buf: Some(bears.buf),
            rows: bears.rows,
            cols: bears.cols,
            _ctx: ctx_guard,
            device_id: dev_id,
        },
    ))
}
1766
/// Runs ASO with one (period, mode) over many time-major series on the GPU.
///
/// Each input is a flat `rows * cols` array; the bulls and bears results are returned as
/// device-resident handles sharing one CUDA context. Raises `ValueError` on unavailable CUDA,
/// size overflow, mismatched sizes, `mode > 2`, or a CUDA backend error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "aso_cuda_many_series_one_param_dev")]
#[pyo3(signature = (open_tm, high_tm, low_tm, close_tm, cols, rows, period, mode, device_id=0))]
pub fn aso_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    open_tm: PyReadonlyArray1<'_, f32>,
    high_tm: PyReadonlyArray1<'_, f32>,
    low_tm: PyReadonlyArray1<'_, f32>,
    close_tm: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    mode: usize,
    device_id: usize,
) -> PyResult<(AsoDeviceArrayF32Py, AsoDeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let [o, h, l, c] = [
        open_tm.as_slice()?,
        high_tm.as_slice()?,
        low_tm.as_slice()?,
        close_tm.as_slice()?,
    ];

    let expected = match cols.checked_mul(rows) {
        Some(cells) => cells,
        None => return Err(PyValueError::new_err("size overflow")),
    };
    let n = o.len();
    let shapes_ok = expected == n && [h.len(), l.len(), c.len()].iter().all(|&len| len == n);
    if !shapes_ok {
        return Err(PyValueError::new_err("mismatched input sizes"));
    }
    if mode > 2 {
        return Err(PyValueError::new_err("invalid mode"));
    }

    let (bulls, bears, ctx_guard, dev_id) = py.allow_threads(|| {
        let cuda = CudaAso::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .aso_many_series_one_param_time_major_dev(o, h, l, c, cols, rows, period, mode)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((out.0, out.1, cuda.context_arc(), cuda.device_id()))
    })?;

    // Both handles point at the same device and hold the same context.
    let wrap = |buf, n_rows, n_cols, ctx| AsoDeviceArrayF32Py {
        buf: Some(buf),
        rows: n_rows,
        cols: n_cols,
        _ctx: ctx,
        device_id: dev_id,
    };
    let bulls_handle = wrap(bulls.buf, bulls.rows, bulls.cols, Arc::clone(&ctx_guard));
    let bears_handle = wrap(bears.buf, bears.rows, bears.cols, ctx_guard);
    Ok((bulls_handle, bears_handle))
}
1822
#[cfg(feature = "python")]
/// Python-visible wrapper (class name `AsoStream`) around the Rust [`AsoStream`] state.
#[pyclass(name = "AsoStream")]
pub struct AsoStreamPy {
    stream: AsoStream,
}
1828
#[cfg(feature = "python")]
#[pymethods]
impl AsoStreamPy {
    /// Creates a streaming ASO calculator.
    ///
    /// Both arguments are optional; when omitted, `AsoStream::try_new` uses period 10 and
    /// mode 0. The explicit signature keeps them optional: PyO3 deprecated the implicit
    /// `None` default for trailing `Option` parameters, and later versions removed it.
    /// Raises `ValueError` if the parameters are rejected.
    #[new]
    #[pyo3(signature = (period=None, mode=None))]
    fn new(period: Option<usize>, mode: Option<usize>) -> PyResult<Self> {
        let params = AsoParams { period, mode };
        let stream =
            AsoStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(AsoStreamPy { stream })
    }

    /// Feeds one bar; returns `(bulls, bears)` once `period` bars have been seen, else `None`.
    fn update(&mut self, open: f64, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        self.stream.update(open, high, low, close)
    }
}
1844
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialized result of the JS `aso` call: a flattened `rows x cols` matrix.
///
/// In `aso_js`, `rows` is 2, with the bulls row followed by the bears row.
#[derive(Serialize, Deserialize)]
pub struct AsoResult {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}
1852
/// JS entry point: computes ASO bulls/bears for one (period, mode) pair.
///
/// Returns an `AsoResult` with `rows = 2` (bulls row, then bears row) and `cols = len`.
/// Indices before `first_valid + period - 1` are NaN. `period` defaults to 10 and `mode`
/// to 0. Errors are returned as JS strings for length mismatch, empty input, invalid mode,
/// all-NaN close, invalid period, or too little valid data.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "aso")]
pub fn aso_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    period: Option<usize>,
    mode: Option<usize>,
) -> Result<JsValue, JsValue> {
    let len = close.len();
    if open.len() != len || high.len() != len || low.len() != len {
        return Err(JsValue::from_str(
            "All OHLC arrays must have the same length",
        ));
    }
    // Without this, empty input was misreported as "All values NaN".
    if len == 0 {
        return Err(JsValue::from_str("Empty input"));
    }
    let p = period.unwrap_or(10);
    let m = mode.unwrap_or(0);
    if m > 2 {
        return Err(JsValue::from_str("Invalid mode"));
    }

    let first = close
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All values NaN"))?;
    if p == 0 || p > len {
        return Err(JsValue::from_str("Invalid period"));
    }
    if len - first < p {
        return Err(JsValue::from_str("Not enough valid data"));
    }

    let mut mu = make_uninit_matrix(2, len);
    let warm = first + p - 1;
    init_matrix_prefixes(&mut mu, len, &[warm, warm]);

    {
        // SAFETY: `MaybeUninit<f64>` has the same layout as `f64`, and the slice covers exactly
        // the `2 * len` cells owned by `mu`, which outlives this scope.
        let dst: &mut [f64] =
            unsafe { core::slice::from_raw_parts_mut(mu.as_mut_ptr() as *mut f64, mu.len()) };
        let (bulls_dst, bears_dst) = dst.split_at_mut(len);

        let chosen = detect_best_kernel();
        unsafe {
            aso_compute_into(
                open, high, low, close, p, m, first, chosen, bulls_dst, bears_dst,
            );
        }
    }

    // Transfer ownership only after the computation, so an unwinding kernel cannot leak `mu`.
    let mut guard = core::mem::ManuallyDrop::new(mu);
    // SAFETY: pointer, length and capacity come from the same allocation, whose element type
    // is layout-compatible with `f64`; `guard` is never dropped, so the buffer has one owner.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };
    let out = AsoResult {
        values,
        rows: 2,
        cols: len,
    };
    serde_wasm_bindgen::to_value(&out).map_err(|e| JsValue::from_str(&e.to_string()))
}
1916
/// Raw-pointer JS entry point: writes ASO bulls/bears for `len` bars into caller buffers.
///
/// Each input pointer must reference `len` readable f64s and each output pointer `len`
/// writable f64s. Inputs may share memory with outputs (in-place use): the result is then
/// computed into temporaries and copied, because a live `&[f64]` and `&mut [f64]` over the
/// same memory is undefined behaviour. Overlapping bulls/bears outputs are rejected.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    bulls_ptr: *mut f64,
    bears_ptr: *mut f64,
    len: usize,
    period: usize,
    mode: usize,
) -> Result<(), JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|p| p.is_null())
        || [bulls_ptr, bears_ptr].iter().any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    // Byte range [start, end) covered by `len` f64s starting at `ptr`.
    let span = |ptr: *const f64| {
        let start = ptr as usize;
        (
            start,
            start.saturating_add(len.saturating_mul(std::mem::size_of::<f64>())),
        )
    };
    // Non-empty ranges that intersect.
    let overlaps = |a: (usize, usize), b: (usize, usize)| {
        a.0 < a.1 && b.0 < b.1 && a.0 < b.1 && b.0 < a.1
    };
    let bulls_span = span(bulls_ptr as *const f64);
    let bears_span = span(bears_ptr as *const f64);
    if overlaps(bulls_span, bears_span) {
        return Err(JsValue::from_str("bulls_ptr and bears_ptr must not overlap"));
    }
    let inputs_alias_outputs = [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&p| {
            let s = span(p);
            overlaps(s, bulls_span) || overlaps(s, bears_span)
        });

    let params = AsoParams {
        period: Some(period),
        mode: Some(mode),
    };
    let kern = detect_best_kernel();

    // SAFETY: pointers are non-null and the caller guarantees `len` readable f64s each.
    let (o, h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(open_ptr, len),
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };
    let input = AsoInput::from_slices(o, h, l, c, params);

    if inputs_alias_outputs {
        let mut tmp_bulls = vec![0.0; len];
        let mut tmp_bears = vec![0.0; len];
        aso_into_slices(&mut tmp_bulls, &mut tmp_bears, &input, kern)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the input slices are not used past this point, the outputs are non-null,
        // hold `len` writable f64s, and do not overlap each other (checked above).
        unsafe {
            std::slice::from_raw_parts_mut(bulls_ptr, len).copy_from_slice(&tmp_bulls);
            std::slice::from_raw_parts_mut(bears_ptr, len).copy_from_slice(&tmp_bears);
        }
        Ok(())
    } else {
        // SAFETY: outputs are non-null, disjoint from each other and from every input.
        let (bulls, bears) = unsafe {
            (
                std::slice::from_raw_parts_mut(bulls_ptr, len),
                std::slice::from_raw_parts_mut(bears_ptr, len),
            )
        };
        aso_into_slices(bulls, bears, &input, kern).map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1961
/// Raw-pointer JS entry point for a (period, mode) sweep; returns the number of rows written.
///
/// Inputs must each reference `len` readable f64s. `bulls_out` / `bears_out` must each
/// reference `rows * len` writable f64s, where `rows` is the number of parameter
/// combinations. Output rows are laid out row-major, one row per combination. When any
/// input shares memory with an output, results are computed into temporaries and copied, to
/// avoid aliasing `&[f64]` with `&mut [f64]`. Overlapping bulls/bears outputs are rejected.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_batch_into(
    open_ptr: *const f64,
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    mode_start: usize,
    mode_end: usize,
    mode_step: usize,
    bulls_out: *mut f64,
    bears_out: *mut f64,
) -> Result<usize, JsValue> {
    if [open_ptr, high_ptr, low_ptr, close_ptr, bulls_out, bears_out]
        .iter()
        .any(|p| p.is_null())
    {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = AsoBatchRange {
        period: (period_start, period_end, period_step),
        mode: (mode_start, mode_end, mode_step),
    };
    let combos = expand_grid_aso(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Byte range [start, end) covered by `n` f64s starting at `ptr`.
    let span = |ptr: *const f64, n: usize| {
        let start = ptr as usize;
        (
            start,
            start.saturating_add(n.saturating_mul(std::mem::size_of::<f64>())),
        )
    };
    // Non-empty ranges that intersect.
    let overlaps = |a: (usize, usize), b: (usize, usize)| {
        a.0 < a.1 && b.0 < b.1 && a.0 < b.1 && b.0 < a.1
    };
    let bulls_span = span(bulls_out as *const f64, total);
    let bears_span = span(bears_out as *const f64, total);
    if overlaps(bulls_span, bears_span) {
        return Err(JsValue::from_str("bulls_out and bears_out must not overlap"));
    }
    let inputs_alias_outputs = [open_ptr, high_ptr, low_ptr, close_ptr]
        .iter()
        .any(|&p| {
            let s = span(p, cols);
            overlaps(s, bulls_span) || overlaps(s, bears_span)
        });

    let kern = detect_best_batch_kernel();

    // SAFETY: pointers are non-null and the caller guarantees `len` readable f64s each.
    let (o, h, l, c) = unsafe {
        (
            std::slice::from_raw_parts(open_ptr, len),
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
        )
    };

    if inputs_alias_outputs {
        let mut tmp_bulls = vec![0.0; total];
        let mut tmp_bears = vec![0.0; total];
        aso_batch_inner_into(o, h, l, c, &sweep, kern, false, &mut tmp_bulls, &mut tmp_bears)
            .map_err(|er| JsValue::from_str(&er.to_string()))?;
        // SAFETY: the input slices are not used past this point, the outputs are non-null,
        // hold `total` writable f64s, and do not overlap each other (checked above).
        unsafe {
            std::slice::from_raw_parts_mut(bulls_out, total).copy_from_slice(&tmp_bulls);
            std::slice::from_raw_parts_mut(bears_out, total).copy_from_slice(&tmp_bears);
        }
    } else {
        // SAFETY: outputs are non-null, disjoint from each other and from every input.
        let (b, e) = unsafe {
            (
                std::slice::from_raw_parts_mut(bulls_out, total),
                std::slice::from_raw_parts_mut(bears_out, total),
            )
        };
        aso_batch_inner_into(o, h, l, c, &sweep, kern, false, b, e)
            .map_err(|er| JsValue::from_str(&er.to_string()))?;
    }
    Ok(rows)
}
2011
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Sweep configuration deserialized from the JS `aso_batch` config object.
///
/// Each range is a `(start, end, step)` triple, forwarded unchanged to `AsoBatchRange`.
#[derive(Serialize, Deserialize)]
pub struct AsoBatchConfig {
    pub period_range: (usize, usize, usize),
    pub mode_range: (usize, usize, usize),
}
2018
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialized result of the JS `aso_batch` call.
///
/// `values` is a flattened `rows x cols` matrix. In `aso_batch_unified_js`, `rows` is twice
/// the number of combinations: all bulls rows come first, then all bears rows, each in
/// `combos` order.
#[derive(Serialize, Deserialize)]
pub struct AsoBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<AsoParams>,
    pub rows: usize,
    pub cols: usize,
}
2027
/// JS entry point for a (period, mode) sweep driven by a config object.
///
/// Returns an `AsoBatchJsOutput` whose matrix holds every bulls row (one per combination, in
/// `combos` order) followed by every bears row; `rows` is therefore `2 * combos.len()`.
/// Errors are returned as JS strings for a bad config, empty or mismatched input, size
/// overflow, all-NaN close, or a failure in the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "aso_batch")]
pub fn aso_batch_unified_js(
    open: &[f64],
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let cfg: AsoBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = AsoBatchRange {
        period: cfg.period_range,
        mode: cfg.mode_range,
    };
    let combos = expand_grid_aso(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = close.len();
    if cols == 0 {
        return Err(JsValue::from_str("Empty input"));
    }
    if open.len() != cols || high.len() != cols || low.len() != cols {
        return Err(JsValue::from_str("OHLC length mismatch"));
    }

    // Cells per output (bulls or bears); the backing matrix holds both, so check 2 * total too
    // before allocating.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;
    if total.checked_mul(2).is_none() {
        return Err(JsValue::from_str("size overflow"));
    }

    let first = close
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All values NaN"))?;

    // One warm-up length per combination, clamped to the row length so an over-long period
    // cannot push a prefix into the next row (such periods are presumably rejected by
    // `aso_batch_inner_into`).
    let combo_warms: Vec<usize> = combos
        .iter()
        .map(|c| (first + c.period.unwrap_or(0)).saturating_sub(1).min(cols))
        .collect();
    // The matrix is [bulls rows..., bears rows...], so the warm list repeats in the same
    // order. The previous `[w, w]` interleaving assigned row k the warm-up of combo k / 2.
    let warms: Vec<usize> = combo_warms
        .iter()
        .chain(combo_warms.iter())
        .copied()
        .collect();

    let mut mu = make_uninit_matrix(rows * 2, cols);
    init_matrix_prefixes(&mut mu, cols, &warms);

    // SAFETY: `MaybeUninit<f64>` has the same layout as `f64` and the slice covers exactly the
    // cells owned by `mu`, which stays alive (and is not moved) until the kernel returns.
    let dst: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(mu.as_mut_ptr() as *mut f64, mu.len()) };
    let (bulls_dst, bears_dst) = dst.split_at_mut(total);

    let kern = detect_best_batch_kernel();
    // On error `mu` is still a regular Vec and is freed normally (previously it was already
    // wrapped in ManuallyDrop here and leaked).
    aso_batch_inner_into(
        open, high, low, close, &sweep, kern, false, bulls_dst, bears_dst,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut guard = core::mem::ManuallyDrop::new(mu);
    // SAFETY: pointer, length and capacity come from the same allocation, whose element type
    // is layout-compatible with `f64`; `guard` is never dropped, so the buffer has one owner.
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    let out = AsoBatchJsOutput {
        values,
        combos,
        rows: rows * 2,
        cols,
    };
    serde_wasm_bindgen::to_value(&out)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2099
/// Allocates room for `len` f64 values in wasm memory and returns the raw pointer.
///
/// The contents are uninitialised. Release the buffer with `aso_free(ptr, len)`, passing the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop hands ownership of the allocation to the caller without freeing it.
    let mut storage = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2108
/// Frees a buffer previously returned by `aso_alloc(len)`; `len` must match that call.
///
/// A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn aso_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer.
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `Vec::<f64>::with_capacity(len)` in `aso_alloc`, which allocates
    // exactly `len` capacity. Length 0 is used because the contents may never have been
    // initialised; f64 has no destructor, so nothing needs to be read before deallocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2116
2117#[cfg(test)]
2118mod tests {
2119 use super::*;
2120 use crate::skip_if_unsupported;
2121 use crate::utilities::data_loader::read_candles_from_csv;
2122 #[cfg(feature = "proptest")]
2123 use proptest::prelude::*;
2124 use std::error::Error;
2125
    /// Compares the last five bulls/bears values (default params, close source, 4h Bitfinex
    /// CSV) against reference values with an absolute tolerance of 1e-6.
    fn check_aso_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
        let result = aso_with_kernel(&input, kernel)?;

        // Reference values for the final five bars.
        let expected_bulls = [
            48.48594883,
            46.37206396,
            47.20522805,
            46.83750720,
            43.28268188,
        ];

        let expected_bears = [
            51.51405117,
            53.62793604,
            52.79477195,
            53.16249280,
            56.71731812,
        ];

        let start = result.bulls.len().saturating_sub(5);
        for (i, (&bull_val, &bear_val)) in result.bulls[start..]
            .iter()
            .zip(result.bears[start..].iter())
            .enumerate()
        {
            let bull_diff = (bull_val - expected_bulls[i]).abs();
            let bear_diff = (bear_val - expected_bears[i]).abs();

            assert!(
                bull_diff < 1e-6,
                "[{}] ASO Bulls {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                bull_val,
                expected_bulls[i]
            );

            assert!(
                bear_diff < 1e-6,
                "[{}] ASO Bears {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                bear_val,
                expected_bears[i]
            );
        }
        Ok(())
    }
2182
    /// Runs ASO on raw OHLC slices from the CSV and checks both outputs match the input length.
    fn check_aso_slice_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = AsoInput::from_slices(
            &candles.open,
            &candles.high,
            &candles.low,
            &candles.close,
            AsoParams::default(),
        );
        let result = aso_with_kernel(&input, kernel)?;

        assert_eq!(result.bulls.len(), candles.close.len());
        assert_eq!(result.bears.len(), candles.close.len());

        Ok(())
    }
2203
    /// Checks that `aso_into_slices` writes NaN over the first 9 outputs (warm-up for the
    /// default period of 10) and a non-NaN value at index 20.
    fn check_aso_into_slices(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        // Pre-filled with 0.0 so the NaN checks prove the function wrote the warm-up prefix.
        let mut bulls = vec![0.0; candles.close.len()];
        let mut bears = vec![0.0; candles.close.len()];

        let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
        aso_into_slices(&mut bulls, &mut bears, &input, kernel)?;

        for i in 0..9 {
            assert!(bulls[i].is_nan());
            assert!(bears[i].is_nan());
        }

        assert!(!bulls[20].is_nan());
        assert!(!bears[20].is_nan());

        Ok(())
    }
2226
    /// Sweeps periods 3..=5 and modes 0..=2 over 10 synthetic bars and checks the 9x10
    /// output shape.
    fn check_aso_batch(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let open = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
        let high = vec![15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0];
        let low = vec![5.0, 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0];
        let close = vec![12.0, 22.0, 32.0, 42.0, 52.0, 62.0, 72.0, 82.0, 92.0, 102.0];

        let sweep = AsoBatchRange {
            period: (3, 5, 1),
            mode: (0, 2, 1),
        };

        let result = aso_batch_with_kernel(&open, &high, &low, &close, &sweep, kernel)?;

        // 3 periods x 3 modes = 9 combinations, 10 columns each.
        assert_eq!(result.rows, 9);
        assert_eq!(result.cols, 10);
        assert_eq!(result.bulls.len(), 90);
        assert_eq!(result.bears.len(), 90);
        assert_eq!(result.combos.len(), 9);

        Ok(())
    }
2250
    /// Passes `None` for both parameters and checks the outputs cover the full input length.
    fn check_aso_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = AsoParams {
            period: None,
            mode: None,
        };
        let input = AsoInput::from_candles(&candles, "close", default_params);
        let output = aso_with_kernel(&input, kernel)?;
        assert_eq!(output.bulls.len(), candles.close.len());
        assert_eq!(output.bears.len(), candles.close.len());

        Ok(())
    }
2268
    /// Builds the input with `AsoInput::with_default_candles` and checks the output lengths.
    fn check_aso_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = AsoInput::with_default_candles(&candles);
        let output = aso_with_kernel(&input, kernel)?;
        assert_eq!(output.bulls.len(), candles.close.len());
        assert_eq!(output.bears.len(), candles.close.len());

        Ok(())
    }
2282
    /// A period of 0 must be rejected.
    fn check_aso_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let open = vec![10.0, 20.0, 30.0];
        let high = vec![15.0, 25.0, 35.0];
        let low = vec![8.0, 18.0, 28.0];
        let close = vec![12.0, 22.0, 32.0];

        let params = AsoParams {
            period: Some(0),
            mode: None,
        };
        let input = AsoInput::from_slices(&open, &high, &low, &close, params);
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with zero period",
            test_name
        );
        Ok(())
    }
2304
    /// A period (10) longer than the 3-bar input must be rejected.
    fn check_aso_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let open = vec![10.0, 20.0, 30.0];
        let high = vec![15.0, 25.0, 35.0];
        let low = vec![8.0, 18.0, 28.0];
        let close = vec![12.0, 22.0, 32.0];

        let params = AsoParams {
            period: Some(10),
            mode: None,
        };
        let input = AsoInput::from_slices(&open, &high, &low, &close, params);
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with period exceeding length",
            test_name
        );
        Ok(())
    }
2329
    /// Mode 3 (outside 0..=2) must be rejected.
    fn check_aso_invalid_mode(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let params = AsoParams {
            period: Some(10),
            mode: Some(3),
        };
        let input = AsoInput::from_candles(&candles, "close", params);
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with invalid mode",
            test_name
        );
        Ok(())
    }
2349
    /// Empty OHLC slices must be rejected.
    fn check_aso_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let empty: Vec<f64> = vec![];
        let params = AsoParams::default();
        let input = AsoInput::from_slices(&empty, &empty, &empty, &empty, params);
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with empty input",
            test_name
        );
        Ok(())
    }
2364
    /// Input made entirely of NaN must be rejected.
    fn check_aso_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let nan_data = vec![f64::NAN, f64::NAN, f64::NAN];
        let params = AsoParams::default();
        let input = AsoInput::from_slices(&nan_data, &nan_data, &nan_data, &nan_data, params);
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with all NaN values",
            test_name
        );
        Ok(())
    }
2379
    /// A single bar with period 10 must be rejected as insufficient data.
    fn check_aso_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = AsoParams {
            period: Some(10),
            mode: None,
        };
        let input = AsoInput::from_slices(
            &single_point,
            &single_point,
            &single_point,
            &single_point,
            params,
        );
        let res = aso_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] ASO should fail with insufficient data",
            test_name
        );
        Ok(())
    }
2402
2403 fn check_aso_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2404 skip_if_unsupported!(kernel, test_name);
2405 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2406 let candles = read_candles_from_csv(file_path)?;
2407
2408 let first_params = AsoParams {
2409 period: Some(10),
2410 mode: Some(0),
2411 };
2412 let first_input = AsoInput::from_candles(&candles, "close", first_params);
2413 let first_result = aso_with_kernel(&first_input, kernel)?;
2414
2415 let second_params = AsoParams {
2416 period: Some(10),
2417 mode: Some(0),
2418 };
2419 let second_input = AsoInput::from_slices(
2420 &first_result.bulls,
2421 &first_result.bulls,
2422 &first_result.bulls,
2423 &first_result.bulls,
2424 second_params,
2425 );
2426 let second_result = aso_with_kernel(&second_input, kernel)?;
2427
2428 assert_eq!(second_result.bulls.len(), first_result.bulls.len());
2429 assert_eq!(second_result.bears.len(), first_result.bears.len());
2430
2431 if second_result.bulls.len() > 30 {
2432 assert!(!second_result.bulls[30].is_nan());
2433 assert!(!second_result.bears[30].is_nan());
2434 }
2435
2436 Ok(())
2437 }
2438
2439 fn check_aso_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2440 skip_if_unsupported!(kernel, test_name);
2441 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2442 let candles = read_candles_from_csv(file_path)?;
2443
2444 let input = AsoInput::from_candles(
2445 &candles,
2446 "close",
2447 AsoParams {
2448 period: Some(10),
2449 mode: Some(0),
2450 },
2451 );
2452 let res = aso_with_kernel(&input, kernel)?;
2453 assert_eq!(res.bulls.len(), candles.close.len());
2454 assert_eq!(res.bears.len(), candles.close.len());
2455
2456 if res.bulls.len() > 240 {
2457 for (i, (&bull_val, &bear_val)) in res.bulls[240..]
2458 .iter()
2459 .zip(res.bears[240..].iter())
2460 .enumerate()
2461 {
2462 assert!(
2463 !bull_val.is_nan(),
2464 "[{}] Found unexpected NaN in bulls at out-index {}",
2465 test_name,
2466 240 + i
2467 );
2468 assert!(
2469 !bear_val.is_nan(),
2470 "[{}] Found unexpected NaN in bears at out-index {}",
2471 test_name,
2472 240 + i
2473 );
2474 }
2475 }
2476 Ok(())
2477 }
2478
2479 fn check_aso_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2480 skip_if_unsupported!(kernel, test_name);
2481
2482 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2483 let candles = read_candles_from_csv(file_path)?;
2484
2485 let period = 10;
2486 let mode = 0;
2487
2488 let input = AsoInput::from_candles(
2489 &candles,
2490 "close",
2491 AsoParams {
2492 period: Some(period),
2493 mode: Some(mode),
2494 },
2495 );
2496 let batch_output = aso_with_kernel(&input, kernel)?;
2497
2498 let mut stream = AsoStream::try_new(AsoParams {
2499 period: Some(period),
2500 mode: Some(mode),
2501 })?;
2502
2503 let mut stream_bulls = Vec::with_capacity(candles.close.len());
2504 let mut stream_bears = Vec::with_capacity(candles.close.len());
2505
2506 for i in 0..candles.close.len() {
2507 match stream.update(
2508 candles.open[i],
2509 candles.high[i],
2510 candles.low[i],
2511 candles.close[i],
2512 ) {
2513 Some((bull, bear)) => {
2514 stream_bulls.push(bull);
2515 stream_bears.push(bear);
2516 }
2517 None => {
2518 stream_bulls.push(f64::NAN);
2519 stream_bears.push(f64::NAN);
2520 }
2521 }
2522 }
2523
2524 assert_eq!(batch_output.bulls.len(), stream_bulls.len());
2525 assert_eq!(batch_output.bears.len(), stream_bears.len());
2526
2527 for (i, ((&batch_bull, &stream_bull), (&batch_bear, &stream_bear))) in batch_output
2528 .bulls
2529 .iter()
2530 .zip(stream_bulls.iter())
2531 .zip(batch_output.bears.iter().zip(stream_bears.iter()))
2532 .enumerate()
2533 {
2534 if batch_bull.is_nan() && stream_bull.is_nan() {
2535 continue;
2536 }
2537 if batch_bear.is_nan() && stream_bear.is_nan() {
2538 continue;
2539 }
2540
2541 if i >= period {
2542 if !batch_bull.is_nan() && !stream_bull.is_nan() {
2543 assert!(
2544 stream_bull >= -1e-9 && stream_bull <= 100.0 + 1e-9,
2545 "[{}] ASO streaming bulls out of range at idx {}: {}",
2546 test_name,
2547 i,
2548 stream_bull
2549 );
2550 }
2551 if !batch_bear.is_nan() && !stream_bear.is_nan() {
2552 assert!(
2553 stream_bear >= -1e-9 && stream_bear <= 100.0 + 1e-9,
2554 "[{}] ASO streaming bears out of range at idx {}: {}",
2555 test_name,
2556 i,
2557 stream_bear
2558 );
2559 }
2560
2561 if mode != 0 && !stream_bull.is_nan() && !stream_bear.is_nan() {
2562 let sum = stream_bull + stream_bear;
2563 assert!(
2564 (sum - 100.0).abs() < 1e-9,
2565 "[{}] ASO streaming bulls + bears != 100 at idx {} (mode {}): {} + {} = {}",
2566 test_name,
2567 i,
2568 mode,
2569 stream_bull,
2570 stream_bear,
2571 sum
2572 );
2573 }
2574 }
2575 }
2576 Ok(())
2577 }
2578
2579 #[cfg(debug_assertions)]
2580 fn check_aso_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2581 skip_if_unsupported!(kernel, test_name);
2582
2583 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2584 let candles = read_candles_from_csv(file_path)?;
2585
2586 let test_params = vec![
2587 AsoParams::default(),
2588 AsoParams {
2589 period: Some(5),
2590 mode: Some(0),
2591 },
2592 AsoParams {
2593 period: Some(5),
2594 mode: Some(1),
2595 },
2596 AsoParams {
2597 period: Some(5),
2598 mode: Some(2),
2599 },
2600 AsoParams {
2601 period: Some(10),
2602 mode: Some(0),
2603 },
2604 AsoParams {
2605 period: Some(10),
2606 mode: Some(1),
2607 },
2608 AsoParams {
2609 period: Some(10),
2610 mode: Some(2),
2611 },
2612 AsoParams {
2613 period: Some(20),
2614 mode: Some(0),
2615 },
2616 AsoParams {
2617 period: Some(20),
2618 mode: Some(1),
2619 },
2620 AsoParams {
2621 period: Some(20),
2622 mode: Some(2),
2623 },
2624 AsoParams {
2625 period: Some(2),
2626 mode: Some(0),
2627 },
2628 AsoParams {
2629 period: Some(50),
2630 mode: Some(1),
2631 },
2632 AsoParams {
2633 period: Some(100),
2634 mode: Some(2),
2635 },
2636 ];
2637
2638 for (param_idx, params) in test_params.iter().enumerate() {
2639 let input = AsoInput::from_candles(&candles, "close", params.clone());
2640 let output = aso_with_kernel(&input, kernel)?;
2641
2642 for (i, (&bull_val, &bear_val)) in
2643 output.bulls.iter().zip(output.bears.iter()).enumerate()
2644 {
2645 if bull_val.is_nan() || bear_val.is_nan() {
2646 continue;
2647 }
2648
2649 let bull_bits = bull_val.to_bits();
2650 let bear_bits = bear_val.to_bits();
2651
2652 for (val, bits, name) in [
2653 (bull_val, bull_bits, "bulls"),
2654 (bear_val, bear_bits, "bears"),
2655 ] {
2656 if bits == 0x11111111_11111111 {
2657 panic!(
2658 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) in {} at index {} \
2659 with params: period={}, mode={}",
2660 test_name,
2661 val,
2662 bits,
2663 name,
2664 i,
2665 params.period.unwrap_or(10),
2666 params.mode.unwrap_or(0)
2667 );
2668 }
2669
2670 if bits == 0x22222222_22222222 {
2671 panic!(
2672 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) in {} at index {} \
2673 with params: period={}, mode={}",
2674 test_name,
2675 val,
2676 bits,
2677 name,
2678 i,
2679 params.period.unwrap_or(10),
2680 params.mode.unwrap_or(0)
2681 );
2682 }
2683
2684 if bits == 0x33333333_33333333 {
2685 panic!(
2686 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) in {} at index {} \
2687 with params: period={}, mode={}",
2688 test_name,
2689 val,
2690 bits,
2691 name,
2692 i,
2693 params.period.unwrap_or(10),
2694 params.mode.unwrap_or(0)
2695 );
2696 }
2697 }
2698 }
2699 }
2700
2701 Ok(())
2702 }
2703
2704 #[cfg(not(debug_assertions))]
2705 fn check_aso_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2706 Ok(())
2707 }
2708
    /// Property test over random price paths, `period` in 2..=50 and `mode` in 0..=2.
    ///
    /// From index `period - 1` onward it checks three things:
    /// - bulls + bears equals 100 (within 1e-9) wherever both are non-NaN;
    /// - each line stays within [0, 100], with 1e-9 slack;
    /// - the requested `kernel` agrees with `Kernel::Scalar`. Non-finite values
    ///   must match bit-for-bit; finite values must be within 1e-9 absolute or
    ///   4 ULPs.
    ///
    /// Any proptest failure panics via the trailing `.unwrap()`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_aso_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Draw the period first, then a series at least `period` long (fewer than
        // 400 values) and a mode.
        // NOTE(review): the "finite" filter never rejects anything for this bounded range.
        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                0usize..=2,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, mode)| {
                let params = AsoParams {
                    period: Some(period),
                    mode: Some(mode),
                };

                let mut open = Vec::with_capacity(data.len());
                let mut high = Vec::with_capacity(data.len());
                let mut low = Vec::with_capacity(data.len());
                let mut close = Vec::with_capacity(data.len());

                // Synthesize OHLC around each sampled value: high/low sit `spread`
                // above/below it, and close sits half a spread above. The `+ 1.0`
                // keeps the spread non-zero even when the value is 0.
                for &val in &data {
                    let spread = val.abs() * 0.1 + 1.0;
                    open.push(val);
                    high.push(val + spread);
                    low.push(val - spread);
                    close.push(val + spread * 0.5);
                }

                let input = AsoInput::from_slices(&open, &high, &low, &close, params);

                // Output of the kernel under test, and the scalar reference.
                let AsoOutput {
                    bulls: out_bulls,
                    bears: out_bears,
                } = aso_with_kernel(&input, kernel).unwrap();
                let AsoOutput {
                    bulls: ref_bulls,
                    bears: ref_bears,
                } = aso_with_kernel(&input, Kernel::Scalar).unwrap();

                for i in (period - 1)..data.len() {
                    let bull = out_bulls[i];
                    let bear = out_bears[i];
                    let ref_bull = ref_bulls[i];
                    let ref_bear = ref_bears[i];

                    // Invariant: the two lines are complementary percentages.
                    if !bull.is_nan() && !bear.is_nan() {
                        let sum = bull + bear;
                        prop_assert!(
                            (sum - 100.0).abs() < 1e-9,
                            "idx {}: bulls + bears = {} + {} = {}, expected 100",
                            i,
                            bull,
                            bear,
                            sum
                        );
                    }

                    // Invariant: each line is bounded to [0, 100].
                    if !bull.is_nan() {
                        prop_assert!(
                            bull >= -1e-9 && bull <= 100.0 + 1e-9,
                            "idx {}: bull {} out of range [0, 100]",
                            i,
                            bull
                        );
                    }
                    if !bear.is_nan() {
                        prop_assert!(
                            bear >= -1e-9 && bear <= 100.0 + 1e-9,
                            "idx {}: bear {} out of range [0, 100]",
                            i,
                            bear
                        );
                    }

                    let bull_bits = bull.to_bits();
                    let bear_bits = bear.to_bits();
                    let ref_bull_bits = ref_bull.to_bits();
                    let ref_bear_bits = ref_bear.to_bits();

                    // Kernel vs. scalar: if either side is non-finite, require an
                    // exact bit match. Otherwise accept a small absolute or ULP
                    // difference.
                    if !bull.is_finite() || !ref_bull.is_finite() {
                        prop_assert!(
                            bull_bits == ref_bull_bits,
                            "bull finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            bull,
                            ref_bull
                        );
                    } else {
                        let ulp_diff: u64 = bull_bits.abs_diff(ref_bull_bits);
                        prop_assert!(
                            (bull - ref_bull).abs() <= 1e-9 || ulp_diff <= 4,
                            "bull mismatch idx {}: {} vs {} (ULP={})",
                            i,
                            bull,
                            ref_bull,
                            ulp_diff
                        );
                    }

                    if !bear.is_finite() || !ref_bear.is_finite() {
                        prop_assert!(
                            bear_bits == ref_bear_bits,
                            "bear finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            bear,
                            ref_bear
                        );
                    } else {
                        let ulp_diff: u64 = bear_bits.abs_diff(ref_bear_bits);
                        prop_assert!(
                            (bear - ref_bear).abs() <= 1e-9 || ulp_diff <= 4,
                            "bear mismatch idx {}: {} vs {} (ULP={})",
                            i,
                            bear,
                            ref_bear,
                            ulp_diff
                        );
                    }
                }
                Ok(())
            })
            .unwrap();

        Ok(())
    }
2846
2847 fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2848 skip_if_unsupported!(kernel, test_name);
2849
2850 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2851 let candles = read_candles_from_csv(file_path)?;
2852
2853 let output = AsoBatchBuilder::new()
2854 .kernel(kernel)
2855 .apply_candles(&candles)?;
2856
2857 let default_params = AsoParams::default();
2858 let default_row_idx = output
2859 .combos
2860 .iter()
2861 .position(|p| p.period == default_params.period && p.mode == default_params.mode)
2862 .expect("default row missing");
2863
2864 let bulls_row = output.bulls_row(default_row_idx);
2865 let bears_row = output.bears_row(default_row_idx);
2866
2867 assert_eq!(bulls_row.len(), candles.close.len());
2868 assert_eq!(bears_row.len(), candles.close.len());
2869
2870 let expected_bulls = [
2871 48.48594883,
2872 46.37206396,
2873 47.20522805,
2874 46.83750720,
2875 43.28268188,
2876 ];
2877 let expected_bears = [
2878 51.51405117,
2879 53.62793604,
2880 52.79477195,
2881 53.16249280,
2882 56.71731812,
2883 ];
2884
2885 let start = bulls_row.len() - 5;
2886 for (i, (&bull, &bear)) in bulls_row[start..]
2887 .iter()
2888 .zip(bears_row[start..].iter())
2889 .enumerate()
2890 {
2891 assert!(
2892 (bull - expected_bulls[i]).abs() < 1e-6,
2893 "[{}] default-row bulls mismatch at idx {}: {} vs {}",
2894 test_name,
2895 i,
2896 bull,
2897 expected_bulls[i]
2898 );
2899 assert!(
2900 (bear - expected_bears[i]).abs() < 1e-6,
2901 "[{}] default-row bears mismatch at idx {}: {} vs {}",
2902 test_name,
2903 i,
2904 bear,
2905 expected_bears[i]
2906 );
2907 }
2908 Ok(())
2909 }
2910
2911 fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2912 skip_if_unsupported!(kernel, test_name);
2913
2914 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2915 let candles = read_candles_from_csv(file_path)?;
2916
2917 let output = AsoBatchBuilder::new()
2918 .kernel(kernel)
2919 .period_range(10, 20, 2)
2920 .mode_range(0, 2, 1)
2921 .apply_candles(&candles)?;
2922
2923 let expected_combos = 6 * 3;
2924 assert_eq!(output.combos.len(), expected_combos);
2925 assert_eq!(output.rows, expected_combos);
2926 assert_eq!(output.cols, candles.close.len());
2927
2928 let mut found_combos = 0;
2929 for period in (10..=20).step_by(2) {
2930 for mode in 0..=2 {
2931 let found = output
2932 .combos
2933 .iter()
2934 .any(|c| c.period == Some(period) && c.mode == Some(mode));
2935 assert!(
2936 found,
2937 "[{}] Missing combo: period={}, mode={}",
2938 test_name, period, mode
2939 );
2940 if found {
2941 found_combos += 1;
2942 }
2943 }
2944 }
2945 assert_eq!(found_combos, expected_combos);
2946
2947 Ok(())
2948 }
2949
2950 #[cfg(debug_assertions)]
2951 fn check_batch_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2952 skip_if_unsupported!(kernel, test_name);
2953
2954 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2955 let candles = read_candles_from_csv(file_path)?;
2956
2957 let test_configs = vec![
2958 (2, 10, 2, 0, 2, 1),
2959 (5, 25, 5, 0, 0, 1),
2960 (10, 10, 1, 0, 2, 1),
2961 (2, 5, 1, 1, 1, 1),
2962 (30, 60, 15, 2, 2, 1),
2963 (9, 15, 3, 0, 2, 2),
2964 (8, 12, 1, 0, 1, 1),
2965 ];
2966
2967 for (cfg_idx, &(p_start, p_end, p_step, m_start, m_end, m_step)) in
2968 test_configs.iter().enumerate()
2969 {
2970 let output = AsoBatchBuilder::new()
2971 .kernel(kernel)
2972 .period_range(p_start, p_end, p_step)
2973 .mode_range(m_start, m_end, m_step)
2974 .apply_candles(&candles)?;
2975
2976 for (row_idx, combo) in output.combos.iter().enumerate() {
2977 let bulls_row = output.bulls_row(row_idx);
2978 let bears_row = output.bears_row(row_idx);
2979
2980 for (col_idx, (&bull_val, &bear_val)) in
2981 bulls_row.iter().zip(bears_row.iter()).enumerate()
2982 {
2983 if bull_val.is_nan() || bear_val.is_nan() {
2984 continue;
2985 }
2986
2987 let bull_bits = bull_val.to_bits();
2988 let bear_bits = bear_val.to_bits();
2989
2990 for (val, bits, name) in [
2991 (bull_val, bull_bits, "bulls"),
2992 (bear_val, bear_bits, "bears"),
2993 ] {
2994 if bits == 0x11111111_11111111 {
2995 panic!(
2996 "[{}] Config {}: Found alloc_with_nan_prefix poison {} (0x{:016X}) in {} \
2997 at row {} col {} (period={}, mode={})",
2998 test_name, cfg_idx, val, bits, name, row_idx, col_idx,
2999 combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3000 );
3001 }
3002
3003 if bits == 0x22222222_22222222 {
3004 panic!(
3005 "[{}] Config {}: Found init_matrix_prefixes poison {} (0x{:016X}) in {} \
3006 at row {} col {} (period={}, mode={})",
3007 test_name, cfg_idx, val, bits, name, row_idx, col_idx,
3008 combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3009 );
3010 }
3011
3012 if bits == 0x33333333_33333333 {
3013 panic!(
3014 "[{}] Config {}: Found make_uninit_matrix poison {} (0x{:016X}) in {} \
3015 at row {} col {} (period={}, mode={})",
3016 test_name, cfg_idx, val, bits, name, row_idx, col_idx,
3017 combo.period.unwrap_or(10), combo.mode.unwrap_or(0)
3018 );
3019 }
3020 }
3021 }
3022 }
3023 }
3024
3025 Ok(())
3026 }
3027
3028 #[cfg(not(debug_assertions))]
3029 fn check_batch_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3030 Ok(())
3031 }
3032
3033 macro_rules! generate_all_aso_tests {
3034 ($($test_fn:ident),*) => {
3035 paste::paste! {
3036 $(
3037 #[test]
3038 fn [<$test_fn _scalar_f64>]() {
3039 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3040 }
3041 )*
3042 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3043 $(
3044 #[test]
3045 fn [<$test_fn _avx2_f64>]() {
3046 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3047 }
3048 #[test]
3049 fn [<$test_fn _avx512_f64>]() {
3050 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3051 }
3052 )*
3053 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
3054 $(
3055 #[test]
3056 fn [<$test_fn _simd128_f64>]() {
3057 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
3058 }
3059 )*
3060 }
3061 }
3062 }
3063
3064 generate_all_aso_tests!(
3065 check_aso_accuracy,
3066 check_aso_slice_input,
3067 check_aso_into_slices,
3068 check_aso_batch,
3069 check_aso_partial_params,
3070 check_aso_default_candles,
3071 check_aso_zero_period,
3072 check_aso_period_exceeds_length,
3073 check_aso_invalid_mode,
3074 check_aso_empty_input,
3075 check_aso_all_nan,
3076 check_aso_very_small_dataset,
3077 check_aso_reinput,
3078 check_aso_nan_handling,
3079 check_aso_streaming,
3080 check_aso_no_poison
3081 );
3082
3083 #[cfg(feature = "proptest")]
3084 generate_all_aso_tests!(check_aso_property);
3085
3086 macro_rules! gen_batch_tests {
3087 ($fn_name:ident) => {
3088 paste::paste! {
3089 #[test] fn [<$fn_name _scalar>]() {
3090 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3091 }
3092 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3093 #[test] fn [<$fn_name _avx2>]() {
3094 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3095 }
3096 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3097 #[test] fn [<$fn_name _avx512>]() {
3098 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3099 }
3100 #[test] fn [<$fn_name _auto_detect>]() {
3101 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3102 }
3103 }
3104 };
3105 }
3106
3107 gen_batch_tests!(check_batch_default_row);
3108 gen_batch_tests!(check_batch_sweep);
3109 gen_batch_tests!(check_batch_no_poison);
3110
3111 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3112 #[test]
3113 fn test_aso_into_matches_api() -> Result<(), Box<dyn Error>> {
3114 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3115 let candles = read_candles_from_csv(file_path)?;
3116
3117 let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
3118
3119 let base = aso(&input)?;
3120
3121 let mut bulls = vec![0.0; candles.close.len()];
3122 let mut bears = vec![0.0; candles.close.len()];
3123 aso_into(&input, &mut bulls, &mut bears)?;
3124
3125 assert_eq!(bulls.len(), base.bulls.len());
3126 assert_eq!(bears.len(), base.bears.len());
3127
3128 fn eq_or_both_nan(a: f64, b: f64) -> bool {
3129 (a.is_nan() && b.is_nan()) || (a == b)
3130 }
3131
3132 for i in 0..bulls.len() {
3133 assert!(
3134 eq_or_both_nan(bulls[i], base.bulls[i]),
3135 "bulls mismatch at {}: got {}, expected {}",
3136 i,
3137 bulls[i],
3138 base.bulls[i]
3139 );
3140 assert!(
3141 eq_or_both_nan(bears[i], base.bears[i]),
3142 "bears mismatch at {}: got {}, expected {}",
3143 i,
3144 bears[i],
3145 base.bears[i]
3146 );
3147 }
3148
3149 Ok(())
3150 }
3151
3152 #[test]
3153 fn test_new_api_features() -> Result<(), Box<dyn Error>> {
3154 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3155 let candles = read_candles_from_csv(file_path)?;
3156
3157 let input = AsoInput::from_candles(&candles, "close", AsoParams::default());
3158 let _data_ref: &[f64] = input.as_ref();
3159
3160 let builder = AsoBatchBuilder::new().period_static(10).mode_static(0);
3161 let output = builder.apply_candles(&candles)?;
3162 assert_eq!(output.combos.len(), 1);
3163
3164 let output2 = AsoBatchBuilder::with_default_candles(&candles)?;
3165 assert!(output2.combos.len() > 0);
3166
3167 let output3 = AsoBatchBuilder::with_default_slices(
3168 &candles.open,
3169 &candles.high,
3170 &candles.low,
3171 &candles.close,
3172 Kernel::Scalar,
3173 )?;
3174 assert!(output3.combos.len() > 0);
3175
3176 let params = AsoParams::default();
3177 if let Some(row) = output2.row_for_params(¶ms) {
3178 let bulls_row = output2.bulls_row(row);
3179 let bears_row = output2.bears_row(row);
3180 assert_eq!(bulls_row.len(), candles.close.len());
3181 assert_eq!(bears_row.len(), candles.close.len());
3182 }
3183
3184 if let Some((bulls, bears)) = output2.values_for(¶ms) {
3185 assert_eq!(bulls.len(), candles.close.len());
3186 assert_eq!(bears.len(), candles.close.len());
3187 }
3188
3189 let sweep = AsoBatchRange::default();
3190 let output4 = aso_batch_slice(
3191 &candles.open,
3192 &candles.high,
3193 &candles.low,
3194 &candles.close,
3195 &sweep,
3196 Kernel::Scalar,
3197 )?;
3198 assert!(output4.combos.len() > 0);
3199
3200 let output5 = aso_batch_par_slice(
3201 &candles.open,
3202 &candles.high,
3203 &candles.low,
3204 &candles.close,
3205 &sweep,
3206 Kernel::Scalar,
3207 )?;
3208 assert_eq!(output4.combos.len(), output5.combos.len());
3209
3210 let input_high = AsoInput::from_candles(&candles, "high", AsoParams::default());
3211 let _output_high = aso(&input_high)?;
3212
3213 Ok(())
3214 }
3215}