1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::convert::AsRef;
29use std::error::Error;
30use std::mem::MaybeUninit;
31use thiserror::Error;
32
33impl<'a> AsRef<[f64]> for MabInput<'a> {
34 #[inline(always)]
35 fn as_ref(&self) -> &[f64] {
36 match &self.data {
37 MabData::Slice(slice) => slice,
38 MabData::Candles { candles, source } => source_type(candles, source),
39 }
40 }
41}
42
/// Where the indicator's input prices come from.
#[derive(Debug, Clone)]
pub enum MabData<'a> {
    /// A candle set plus the name of the series to read from it; resolved with
    /// `source_type(candles, source)` when the input is read through `AsRef<[f64]>`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice, used as-is.
    Slice(&'a [f64]),
}
51
/// Output of `mab` / `mab_with_kernel`: three bands, each as long as the input series,
/// with NaN in the warmup prefix.
#[derive(Debug, Clone)]
pub struct MabOutput {
    /// `slow_ma + devup * dev`, where `dev` is the RMS of `fast_ma - slow_ma` over the
    /// last `fast_period` bars.
    pub upperband: Vec<f64>,
    /// The fast moving average.
    pub middleband: Vec<f64>,
    /// `slow_ma - devdn * dev`.
    pub lowerband: Vec<f64>,
}
58
/// Parameters for the MAB indicator (presumably "moving average bands").
///
/// `None` fields fall back to the defaults applied by the `MabInput::get_*` accessors:
/// fast 10, slow 50, devup 1.0, devdn 1.0, and `"sma"` for both average types.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MabParams {
    /// Period of the fast average (the middle band) and length of the deviation window.
    pub fast_period: Option<usize>,
    /// Period of the slow average that the upper/lower bands are centred on.
    pub slow_period: Option<usize>,
    /// Deviation multiplier for the upper band.
    pub devup: Option<f64>,
    /// Deviation multiplier for the lower band.
    pub devdn: Option<f64>,
    /// `"ema"` selects an EMA for the fast average; any other string selects an SMA.
    pub fast_ma_type: Option<String>,
    /// `"ema"` selects an EMA for the slow average; any other string selects an SMA.
    pub slow_ma_type: Option<String>,
}
72
73impl Default for MabParams {
74 fn default() -> Self {
75 Self {
76 fast_period: Some(10),
77 slow_period: Some(50),
78 devup: Some(1.0),
79 devdn: Some(1.0),
80 fast_ma_type: Some("sma".to_string()),
81 slow_ma_type: Some("sma".to_string()),
82 }
83 }
84}
85
/// Input to the MAB indicator: a price source plus its parameters.
#[derive(Debug, Clone)]
pub struct MabInput<'a> {
    /// Price source (candles + series name, or a raw slice).
    pub data: MabData<'a>,
    /// Indicator parameters; `None` fields are resolved by the `get_*` accessors.
    pub params: MabParams,
}
91
92impl<'a> MabInput<'a> {
93 pub fn from_candles(candles: &'a Candles, source: &'a str, params: MabParams) -> Self {
94 Self {
95 data: MabData::Candles { candles, source },
96 params,
97 }
98 }
99
100 pub fn from_slice(slice: &'a [f64], params: MabParams) -> Self {
101 Self {
102 data: MabData::Slice(slice),
103 params,
104 }
105 }
106
107 pub fn with_default_params(data: MabData<'a>) -> Self {
108 Self {
109 data,
110 params: MabParams::default(),
111 }
112 }
113
114 pub fn with_default_candles(candles: &'a Candles) -> Self {
115 Self::from_candles(candles, "close", MabParams::default())
116 }
117
118 pub fn get_fast_period(&self) -> usize {
119 self.params.fast_period.unwrap_or(10)
120 }
121
122 pub fn get_slow_period(&self) -> usize {
123 self.params.slow_period.unwrap_or(50)
124 }
125
126 pub fn get_devup(&self) -> f64 {
127 self.params.devup.unwrap_or(1.0)
128 }
129
130 pub fn get_devdn(&self) -> f64 {
131 self.params.devdn.unwrap_or(1.0)
132 }
133
134 pub fn get_fast_ma_type(&self) -> &str {
135 self.params
136 .fast_ma_type
137 .as_ref()
138 .map(|s| s.as_str())
139 .unwrap_or("sma")
140 }
141
142 pub fn get_slow_ma_type(&self) -> &str {
143 self.params
144 .slow_ma_type
145 .as_ref()
146 .map(|s| s.as_str())
147 .unwrap_or("sma")
148 }
149}
150
/// Errors produced by the MAB single-series, streaming-validation and batch paths.
#[derive(Error, Debug)]
pub enum MabError {
    /// The input series has no elements.
    #[error("mab: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN.
    #[error("mab: All values are NaN.")]
    AllValuesNaN,
    /// A period is zero or larger than the input length.
    #[error("mab: Invalid period: fast={fast} slow={slow} len={data_len}")]
    InvalidPeriod {
        fast: usize,
        slow: usize,
        data_len: usize,
    },
    /// Too few values from the first non-NaN sample onward. `mab_into_slice` also maps
    /// any failure of the underlying EMA/SMA call to this variant.
    #[error("mab: Not enough valid data: need={needed} valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A destination buffer's length differs from the expected length.
    #[error(
        "mab: Output length mismatch: expected={expected} upper={upper_len} middle={middle_len} lower={lower_len}"
    )]
    OutputLengthMismatch {
        upper_len: usize,
        middle_len: usize,
        lower_len: usize,
        expected: usize,
    },
    /// An integer sweep axis yielded no values. The batch paths also return this (with the
    /// fast-period axis) when `rows * cols` overflows.
    #[error("mab: Invalid range (start={start}, end={end}, step={step})")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A floating-point sweep axis yielded no values.
    #[error("mab: Invalid range (f64) (start={start}, end={end}, step={step})")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },
    /// A batch entry point received a kernel that is not a `*Batch` variant.
    #[error("mab: non-batch kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
185
186#[inline(always)]
187fn mab_validate(input: &MabInput) -> Result<usize, MabError> {
188 let data = input.as_ref();
189 if data.is_empty() {
190 return Err(MabError::EmptyInputData);
191 }
192 let fast = input.get_fast_period();
193 let slow = input.get_slow_period();
194 if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
195 return Err(MabError::InvalidPeriod {
196 fast,
197 slow,
198 data_len: data.len(),
199 });
200 }
201
202 let first_valid = data
203 .iter()
204 .position(|&x| !x.is_nan())
205 .ok_or(MabError::AllValuesNaN)?;
206 let max_period = fast.max(slow);
207 if data.len() - first_valid < max_period {
208 return Err(MabError::NotEnoughValidData {
209 needed: max_period,
210 valid: data.len() - first_valid,
211 });
212 }
213 Ok(first_valid)
214}
215
216#[inline(always)]
217fn mab_prepare<'a>(
218 input: &'a MabInput,
219 kernel: Kernel,
220) -> Result<(&'a [f64], usize, Kernel, usize, usize, f64, f64), MabError> {
221 let first_valid = mab_validate(input)?;
222 let data = input.as_ref();
223 let fast = input.get_fast_period();
224 let slow = input.get_slow_period();
225 let devup = input.get_devup();
226 let devdn = input.get_devdn();
227 let chosen = match kernel {
228 Kernel::Auto => Kernel::Scalar,
229 _ => kernel,
230 };
231 let warmup = first_valid + fast.max(slow) - 1;
232 Ok((data, warmup, chosen, fast, slow, devup, devdn))
233}
234
235#[inline(always)]
236fn mab_prepare2<'a>(
237 input: &'a MabInput,
238 kernel: Kernel,
239) -> Result<(&'a [f64], Kernel, usize, usize, usize, f64, f64), MabError> {
240 let data = input.as_ref();
241 if data.is_empty() {
242 return Err(MabError::EmptyInputData);
243 }
244 let fast = input.get_fast_period();
245 let slow = input.get_slow_period();
246 if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
247 return Err(MabError::InvalidPeriod {
248 fast,
249 slow,
250 data_len: data.len(),
251 });
252 }
253 let first = data
254 .iter()
255 .position(|&x| !x.is_nan())
256 .ok_or(MabError::AllValuesNaN)?;
257 let need = fast.max(slow);
258
259 let need_total = need + fast - 1;
260 if data.len() - first < need_total {
261 return Err(MabError::NotEnoughValidData {
262 needed: need_total,
263 valid: data.len() - first,
264 });
265 }
266 let chosen = match kernel {
267 Kernel::Auto => Kernel::Scalar,
268 k => k,
269 };
270 let warmup = first + need_total - 1;
271 Ok((
272 data,
273 chosen,
274 first,
275 warmup,
276 fast,
277 input.get_devup(),
278 input.get_devdn(),
279 ))
280}
281
282pub fn mab_into_slice(
283 upper_dst: &mut [f64],
284 middle_dst: &mut [f64],
285 lower_dst: &mut [f64],
286 input: &MabInput,
287 kern: Kernel,
288) -> Result<(), MabError> {
289 use crate::indicators::ema::{ema, EmaInput, EmaParams};
290 use crate::indicators::sma::{sma, SmaInput, SmaParams};
291
292 let (data, chosen, first, warmup, fast_period, devup, devdn) = mab_prepare2(input, kern)?;
293 let slow_period = input.get_slow_period();
294
295 let n = data.len();
296 if upper_dst.len() != n || middle_dst.len() != n || lower_dst.len() != n {
297 return Err(MabError::OutputLengthMismatch {
298 upper_len: upper_dst.len(),
299 middle_len: middle_dst.len(),
300 lower_len: lower_dst.len(),
301 expected: n,
302 });
303 }
304
305 let fast_ma_type = input.get_fast_ma_type();
306 let slow_ma_type = input.get_slow_ma_type();
307
308 let fast_ma = match fast_ma_type {
309 "ema" => {
310 let params = EmaParams {
311 period: Some(fast_period),
312 };
313 ema(&EmaInput::from_slice(data, params))
314 .map_err(|_| MabError::NotEnoughValidData {
315 needed: fast_period,
316 valid: n - first,
317 })?
318 .values
319 }
320 _ => {
321 let params = SmaParams {
322 period: Some(fast_period),
323 };
324 sma(&SmaInput::from_slice(data, params))
325 .map_err(|_| MabError::NotEnoughValidData {
326 needed: fast_period,
327 valid: n - first,
328 })?
329 .values
330 }
331 };
332
333 let slow_ma = match slow_ma_type {
334 "ema" => {
335 let params = EmaParams {
336 period: Some(slow_period),
337 };
338 ema(&EmaInput::from_slice(data, params))
339 .map_err(|_| MabError::NotEnoughValidData {
340 needed: slow_period,
341 valid: n - first,
342 })?
343 .values
344 }
345 _ => {
346 let params = SmaParams {
347 period: Some(slow_period),
348 };
349 sma(&SmaInput::from_slice(data, params))
350 .map_err(|_| MabError::NotEnoughValidData {
351 needed: slow_period,
352 valid: n - first,
353 })?
354 .values
355 }
356 };
357
358 let warmup_end = (warmup + 1).min(n);
359 for dst in [&mut *upper_dst, &mut *middle_dst, &mut *lower_dst] {
360 for v in &mut dst[..warmup_end] {
361 *v = f64::NAN;
362 }
363 }
364
365 mab_compute_into(
366 &fast_ma,
367 &slow_ma,
368 fast_period,
369 devup,
370 devdn,
371 warmup + 1,
372 chosen,
373 upper_dst,
374 middle_dst,
375 lower_dst,
376 );
377
378 Ok(())
379}
380
381#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
382pub fn mab_into(
383 input: &MabInput,
384 upper_dst: &mut [f64],
385 middle_dst: &mut [f64],
386 lower_dst: &mut [f64],
387) -> Result<(), MabError> {
388 mab_into_slice(upper_dst, middle_dst, lower_dst, input, Kernel::Auto)
389}
390
391pub fn mab(input: &MabInput) -> Result<MabOutput, MabError> {
392 let (_data, _warmup, kernel, _fast, _slow, _devup, _devdn) = mab_prepare(input, Kernel::Auto)?;
393 mab_with_kernel(input, kernel)
394}
395
396pub fn mab_with_kernel(input: &MabInput, kernel: Kernel) -> Result<MabOutput, MabError> {
397 let data = input.as_ref();
398 let (_, _, _, warmup, _, _, _) = mab_prepare2(input, kernel)?;
399
400 let mut upperband = alloc_with_nan_prefix(data.len(), warmup);
401 let mut middleband = alloc_with_nan_prefix(data.len(), warmup);
402 let mut lowerband = alloc_with_nan_prefix(data.len(), warmup);
403
404 mab_into_slice(
405 &mut upperband,
406 &mut middleband,
407 &mut lowerband,
408 input,
409 kernel,
410 )?;
411
412 Ok(MabOutput {
413 upperband,
414 middleband,
415 lowerband,
416 })
417}
418
#[cfg(test)]
mod into_parity_tests {
    use super::*;

    /// True when both values are NaN, otherwise exact `==` equality.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        if a.is_nan() {
            b.is_nan()
        } else {
            a == b
        }
    }

    /// `mab_into` must reproduce `mab` exactly, including the NaN warmup prefix.
    #[test]
    fn test_mab_into_matches_api() {
        let n = 256usize;
        // Five leading NaNs exercise the first-valid offset.
        let data: Vec<f64> = (0..n)
            .map(|i| {
                if i < 5 {
                    f64::NAN
                } else {
                    let x = i as f64;
                    x.sin() * 0.5 + x.cos() * 0.25
                }
            })
            .collect();

        let input = MabInput::from_slice(&data, MabParams::default());
        let base = mab(&input).expect("mab baseline should succeed");

        let (mut up, mut mid, mut lo) = (vec![0.0; n], vec![0.0; n], vec![0.0; n]);

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            mab_into(&input, &mut up, &mut mid, &mut lo).expect("mab_into should succeed");
        }

        for len in [
            base.upperband.len(),
            base.middleband.len(),
            base.lowerband.len(),
            up.len(),
            mid.len(),
            lo.len(),
        ] {
            assert_eq!(len, n);
        }

        for i in 0..n {
            let (bu, bm, bl) = (base.upperband[i], base.middleband[i], base.lowerband[i]);
            assert!(
                eq_or_both_nan(bu, up[i]),
                "upper mismatch at {}: base={:?} into={:?}",
                i,
                bu,
                up[i]
            );
            assert!(
                eq_or_both_nan(bm, mid[i]),
                "middle mismatch at {}: base={:?} into={:?}",
                i,
                bm,
                mid[i]
            );
            assert!(
                eq_or_both_nan(bl, lo[i]),
                "lower mismatch at {}: base={:?} into={:?}",
                i,
                bl,
                lo[i]
            );
        }
    }
}
480
481#[inline]
482fn mab_compute_into(
483 fast_ma: &[f64],
484 slow_ma: &[f64],
485 fast_period: usize,
486 devup: f64,
487 devdn: f64,
488 first_output: usize,
489 kernel: Kernel,
490 upper: &mut [f64],
491 mid: &mut [f64],
492 lower: &mut [f64],
493) {
494 unsafe {
495 match kernel {
496 Kernel::Scalar | Kernel::ScalarBatch => mab_scalar(
497 fast_ma,
498 slow_ma,
499 fast_period,
500 devup,
501 devdn,
502 first_output,
503 upper,
504 mid,
505 lower,
506 ),
507 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
508 Kernel::Avx2 | Kernel::Avx2Batch => mab_avx2(
509 fast_ma,
510 slow_ma,
511 fast_period,
512 devup,
513 devdn,
514 first_output,
515 upper,
516 mid,
517 lower,
518 ),
519 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
520 Kernel::Avx512 | Kernel::Avx512Batch => mab_avx512(
521 fast_ma,
522 slow_ma,
523 fast_period,
524 devup,
525 devdn,
526 first_output,
527 upper,
528 mid,
529 lower,
530 ),
531 _ => unreachable!(),
532 }
533 }
534}
535
/// Scalar kernel: fills `mid`, `upper` and `lower` from index `first_output` to the end of
/// `fast_ma`; earlier indices are not written.
///
/// For each written index `i`, `dev` is the root mean square of `fast_ma[j] - slow_ma[j]`
/// over the trailing window `j ∈ [i + 1 - fast_period, i]`. The outputs are
/// `mid[i] = fast_ma[i]`, `upper[i] = slow_ma[i] + devup * dev` and
/// `lower[i] = slow_ma[i] - devdn * dev`. The window sum is updated incrementally: add the
/// newest square, subtract the oldest.
///
/// # Safety
/// Although declared `unsafe`, the body uses only bounds-checked slice indexing, so misuse
/// panics rather than causing UB. Callers should pass `first_output + 1 >= fast_period`
/// (otherwise the rolling update indexes before the start of the slices). They should also
/// make `slow_ma`, `upper`, `mid` and `lower` at least as long as `fast_ma`.
#[inline(always)]
pub unsafe fn mab_scalar(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    let len = fast_ma.len();

    // RMS of the window from its running sum of squares. When the true window sum is ~0,
    // the add/subtract update can cancel to a tiny negative residue (about -2e-16), and
    // sqrt of that is NaN. Treat such residue as 0. `s < 0.0` is false for NaN, so real
    // NaN inputs still propagate (unlike `f64::max`, which would replace them with 0).
    let rms = |s: f64| -> f64 {
        let s = if s < 0.0 { 0.0 } else { s };
        (s / fast_period as f64).sqrt()
    };

    // The first window ends at `first_output`: [first_output + 1 - fast_period, first_output].
    let start_idx = (first_output + 1).saturating_sub(fast_period);
    let mut sum_sq = 0.0;
    for i in start_idx..(start_idx + fast_period).min(len) {
        let diff = fast_ma[i] - slow_ma[i];
        sum_sq += diff * diff;
    }

    if first_output < len {
        let dev = rms(sum_sq);
        mid[first_output] = fast_ma[first_output];
        upper[first_output] = slow_ma[first_output] + devup * dev;
        lower[first_output] = slow_ma[first_output] - devdn * dev;
    }

    for i in (first_output + 1)..len {
        let old_idx = i - fast_period;
        let old = fast_ma[old_idx] - slow_ma[old_idx];
        let new = fast_ma[i] - slow_ma[i];
        sum_sq += new * new - old * old;
        let dev = rms(sum_sq);

        mid[i] = fast_ma[i];
        upper[i] = slow_ma[i] + devup * dev;
        lower[i] = slow_ma[i] - devdn * dev;
    }
}
579
/// AVX2/FMA kernel: same outputs as `mab_scalar` for indices `first_output..fast_ma.len()`,
/// but window sums come from a prefix sum of squared differences instead of a rolling
/// add/subtract.
///
/// Allocates two cache-line-aligned scratch buffers: `diffsq` (`m` entries) and `prefix`
/// (`m + 1` entries), where `m = n - start`.
///
/// # Safety
/// - The CPU must support AVX2 and FMA (`#[target_feature(enable = "avx2,fma")]`).
/// - `slow_ma`, `upper`, `mid` and `lower` must be at least `fast_ma.len()` long; they are
///   accessed through raw pointers without bounds checks.
/// - `first_output + 1 >= fast_period` and `fast_period > 0` (only `debug_assert!`ed);
///   otherwise `start` underflows.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn mab_avx2(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = fast_ma.len();
    if first_output >= n {
        return;
    }
    debug_assert!(fast_period > 0);
    debug_assert!(first_output + 1 >= fast_period);

    // `start` is the first index of the first output's window.
    let start = first_output + 1 - fast_period;
    let m = n - start;

    // Scratch buffers are sized with `set_len` before being written; every element is
    // written below before it is read.
    let mut diffsq: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m);
    diffsq.set_len(m);
    let mut prefix: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m + 1);
    prefix.set_len(m + 1);

    let f0 = fast_ma.as_ptr().add(start);
    let s0 = slow_ma.as_ptr().add(start);
    let dptr = diffsq.as_mut_ptr();

    // diffsq[k] = (fast_ma[start + k] - slow_ma[start + k])^2, 4 lanes at a time.
    let mut k = 0usize;
    while k + 4 <= m {
        let vf = _mm256_loadu_pd(f0.add(k));
        let vs = _mm256_loadu_pd(s0.add(k));
        let vd = _mm256_sub_pd(vf, vs);
        let vd2 = _mm256_mul_pd(vd, vd);
        _mm256_storeu_pd(dptr.add(k), vd2);
        k += 4;
    }
    while k < m {
        let d = *f0.add(k) - *s0.add(k);
        *dptr.add(k) = d * d;
        k += 1;
    }

    // prefix[k] = sum of diffsq[0..k]; prefix[0] = 0.
    let pptr = prefix.as_mut_ptr();
    *pptr = 0.0;
    let mut acc = 0.0f64;
    k = 0;
    while k < m {
        acc += *dptr.add(k);
        *pptr.add(k + 1) = acc;
        k += 1;
    }

    let invf = 1.0 / (fast_period as f64);
    let vinvf = _mm256_set1_pd(invf);
    let vup = _mm256_set1_pd(devup);
    let vdn = _mm256_set1_pd(devdn);

    let mut i = first_output;

    // Scalar head until `i` is a multiple of 4 (index alignment only; the vector loop
    // below uses unaligned loads/stores).
    while i < n && (i & 3) != 0 {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }

    // 4 outputs per iteration: window sum = prefix[end] - prefix[end - fast_period].
    while i + 3 < n {
        let base = i - start;
        let pend = _mm256_loadu_pd(pptr.add(base + 1));
        let psta = _mm256_loadu_pd(pptr.add(base + 1 - fast_period));
        let vsum = _mm256_sub_pd(pend, psta);
        let vdev = _mm256_sqrt_pd(_mm256_mul_pd(vsum, vinvf));

        let vfast = _mm256_loadu_pd(fast_ma.as_ptr().add(i));
        let vslow = _mm256_loadu_pd(slow_ma.as_ptr().add(i));

        _mm256_storeu_pd(mid.as_mut_ptr().add(i), vfast);
        // upper = devup * dev + slow; lower = -(devdn * dev) + slow.
        let vupper = _mm256_fmadd_pd(vup, vdev, vslow);
        let vlower = _mm256_fnmadd_pd(vdn, vdev, vslow);
        _mm256_storeu_pd(upper.as_mut_ptr().add(i), vupper);
        _mm256_storeu_pd(lower.as_mut_ptr().add(i), vlower);

        i += 4;
    }

    // Scalar tail for the final < 4 outputs.
    while i < n {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }
}
686
/// AVX-512F kernel: same algorithm as `mab_avx2` with 8-lane vectors. It writes the bands
/// for indices `first_output..fast_ma.len()`, with window sums taken from a prefix sum of
/// squared differences.
///
/// # Safety
/// - The CPU must support AVX-512F (`#[target_feature(enable = "avx512f")]`).
/// - `slow_ma`, `upper`, `mid` and `lower` must be at least `fast_ma.len()` long; they are
///   accessed through raw pointers without bounds checks.
/// - `first_output + 1 >= fast_period` and `fast_period > 0` (only `debug_assert!`ed);
///   otherwise `start` underflows.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn mab_avx512(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = fast_ma.len();
    if first_output >= n {
        return;
    }
    debug_assert!(fast_period > 0);
    debug_assert!(first_output + 1 >= fast_period);

    // `start` is the first index of the first output's window.
    let start = first_output + 1 - fast_period;
    let m = n - start;

    // Scratch buffers are sized with `set_len` before being written; every element is
    // written below before it is read.
    let mut diffsq: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m);
    diffsq.set_len(m);
    let mut prefix: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m + 1);
    prefix.set_len(m + 1);

    let f0 = fast_ma.as_ptr().add(start);
    let s0 = slow_ma.as_ptr().add(start);
    let dptr = diffsq.as_mut_ptr();

    // diffsq[k] = (fast_ma[start + k] - slow_ma[start + k])^2, 8 lanes at a time.
    let mut k = 0usize;
    while k + 8 <= m {
        let vf = _mm512_loadu_pd(f0.add(k));
        let vs = _mm512_loadu_pd(s0.add(k));
        let vd = _mm512_sub_pd(vf, vs);
        let vd2 = _mm512_mul_pd(vd, vd);
        _mm512_storeu_pd(dptr.add(k), vd2);
        k += 8;
    }
    while k < m {
        let d = *f0.add(k) - *s0.add(k);
        *dptr.add(k) = d * d;
        k += 1;
    }

    // prefix[k] = sum of diffsq[0..k]; prefix[0] = 0.
    let pptr = prefix.as_mut_ptr();
    *pptr = 0.0;
    let mut acc = 0.0f64;
    k = 0;
    while k < m {
        acc += *dptr.add(k);
        *pptr.add(k + 1) = acc;
        k += 1;
    }

    let invf = 1.0 / (fast_period as f64);
    let vinvf = _mm512_set1_pd(invf);
    let vup = _mm512_set1_pd(devup);
    let vdn = _mm512_set1_pd(devdn);

    let mut i = first_output;

    // Scalar head until `i` is a multiple of 8 (index alignment only; loads/stores below
    // are unaligned).
    while i < n && (i & 7) != 0 {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }

    // 8 outputs per iteration: window sum = prefix[end] - prefix[end - fast_period].
    while i + 7 < n {
        let base = i - start;
        let pend = _mm512_loadu_pd(pptr.add(base + 1));
        let psta = _mm512_loadu_pd(pptr.add(base + 1 - fast_period));
        let vsum = _mm512_sub_pd(pend, psta);
        let vdev = _mm512_sqrt_pd(_mm512_mul_pd(vsum, vinvf));

        let vfast = _mm512_loadu_pd(fast_ma.as_ptr().add(i));
        let vslow = _mm512_loadu_pd(slow_ma.as_ptr().add(i));

        _mm512_storeu_pd(mid.as_mut_ptr().add(i), vfast);
        // upper = devup * dev + slow; lower = -(devdn * dev) + slow.
        let vupper = _mm512_fmadd_pd(vup, vdev, vslow);
        let vlower = _mm512_fnmadd_pd(vdn, vdev, vslow);
        _mm512_storeu_pd(upper.as_mut_ptr().add(i), vupper);
        _mm512_storeu_pd(lower.as_mut_ptr().add(i), vlower);

        i += 8;
    }

    // Scalar tail for the final < 8 outputs.
    while i < n {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }
}
793
/// Incremental MAB calculator that consumes one value at a time (see `MabStream::update`).
pub struct MabStream {
    /// Ring of the last `fast_period` inputs (SMA path of the fast average).
    fast_buffer: Vec<f64>,
    /// Ring of the last `slow_period` inputs (SMA path of the slow average).
    slow_buffer: Vec<f64>,
    /// Ring of the last `fast_period` squared `fast_ma - slow_ma` differences.
    diffs_buffer: Vec<f64>,
    /// Next write position in `fast_buffer`.
    fast_index: usize,
    /// Next write position in `slow_buffer`.
    slow_index: usize,
    /// Next write position in `diffs_buffer`.
    diff_index: usize,
    /// Number of finite values accepted so far.
    count: usize,
    fast_period: usize,
    slow_period: usize,
    /// Upper-band deviation multiplier.
    devup: f64,
    /// Lower-band deviation multiplier.
    devdn: f64,
    /// `"ema"` selects an EMA for the fast average; anything else an SMA.
    fast_ma_type: String,
    /// `"ema"` selects an EMA for the slow average; anything else an SMA.
    slow_ma_type: String,
    /// Running sum of `fast_buffer`.
    fast_sum: f64,
    /// Running sum of `slow_buffer`.
    slow_sum: f64,
    /// Latest fast average.
    fast_ma: f64,
    /// Latest slow average.
    slow_ma: f64,
    /// EMA state for the fast average (EMA path only).
    ema_fast: f64,
    /// EMA state for the slow average (EMA path only).
    ema_slow: f64,
    /// Set from `detect_best_kernel()` in `try_new`; not read by `update`.
    kernel: Kernel,

    /// Running sum of `diffs_buffer`, maintained by add/subtract.
    sumsq_diff: f64,
    /// Number of squared differences stored so far (stops growing at `fast_period`).
    diffs_filled: usize,
    /// `1.0 / fast_period`.
    inv_fast_len: f64,
    /// `max_period + fast_period - 1`: the accepted-value count at which output starts.
    ready_threshold: usize,
    /// EMA smoothing factor `2 / (fast_period + 1)`.
    k_fast: f64,
    /// EMA smoothing factor `2 / (slow_period + 1)`.
    k_slow: f64,
    /// `max(fast_period, slow_period)`.
    max_period: usize,
}
824
825impl MabStream {
826 pub fn try_new(params: MabParams) -> Result<Self, String> {
827 let fast_period = params.fast_period.unwrap_or(10);
828 let slow_period = params.slow_period.unwrap_or(50);
829 let devup = params.devup.unwrap_or(1.0);
830 let devdn = params.devdn.unwrap_or(1.0);
831 let fast_ma_type = params.fast_ma_type.unwrap_or_else(|| "sma".to_string());
832 let slow_ma_type = params.slow_ma_type.unwrap_or_else(|| "sma".to_string());
833
834 if fast_period == 0 || slow_period == 0 {
835 return Err("Period cannot be zero".to_string());
836 }
837
838 let max_period = fast_period.max(slow_period);
839 let ready_threshold = max_period + fast_period - 1;
840
841 Ok(Self {
842 fast_buffer: vec![0.0; fast_period],
843 slow_buffer: vec![0.0; slow_period],
844 diffs_buffer: vec![0.0; fast_period],
845 fast_index: 0,
846 slow_index: 0,
847 diff_index: 0,
848 count: 0,
849 fast_period,
850 slow_period,
851 devup,
852 devdn,
853 fast_ma_type,
854 slow_ma_type,
855 fast_sum: 0.0,
856 slow_sum: 0.0,
857 fast_ma: 0.0,
858 slow_ma: 0.0,
859 ema_fast: 0.0,
860 ema_slow: 0.0,
861 kernel: detect_best_kernel(),
862
863 sumsq_diff: 0.0,
864 diffs_filled: 0,
865 inv_fast_len: 1.0 / fast_period as f64,
866 ready_threshold,
867 k_fast: 2.0 / (fast_period as f64 + 1.0),
868 k_slow: 2.0 / (slow_period as f64 + 1.0),
869 max_period,
870 })
871 }
872
873 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
874 if !value.is_finite() {
875 return None;
876 }
877
878 self.count += 1;
879
880 match self.fast_ma_type.as_str() {
881 "ema" => {
882 if self.count == 1 {
883 self.ema_fast = value;
884 } else {
885 self.ema_fast = (1.0 - self.k_fast).mul_add(self.ema_fast, self.k_fast * value);
886 }
887 self.fast_ma = self.ema_fast;
888 }
889 _ => {
890 if self.count <= self.fast_period {
891 let idx = self.fast_index;
892 self.fast_sum += value;
893 self.fast_buffer[idx] = value;
894 if self.count == self.fast_period {
895 self.fast_ma = self.fast_sum * self.inv_fast_len;
896 }
897 self.fast_index += 1;
898 if self.fast_index == self.fast_period {
899 self.fast_index = 0;
900 }
901 } else {
902 let idx = self.fast_index;
903 let old = self.fast_buffer[idx];
904 self.fast_buffer[idx] = value;
905 self.fast_sum += value - old;
906 self.fast_ma = self.fast_sum * self.inv_fast_len;
907 self.fast_index += 1;
908 if self.fast_index == self.fast_period {
909 self.fast_index = 0;
910 }
911 }
912 }
913 }
914
915 match self.slow_ma_type.as_str() {
916 "ema" => {
917 if self.count == 1 {
918 self.ema_slow = value;
919 } else {
920 self.ema_slow = (1.0 - self.k_slow).mul_add(self.ema_slow, self.k_slow * value);
921 }
922 self.slow_ma = self.ema_slow;
923 }
924 _ => {
925 if self.count <= self.slow_period {
926 let idx = self.slow_index;
927 self.slow_sum += value;
928 self.slow_buffer[idx] = value;
929 if self.count == self.slow_period {
930 self.slow_ma = self.slow_sum / self.slow_period as f64;
931 }
932 self.slow_index += 1;
933 if self.slow_index == self.slow_period {
934 self.slow_index = 0;
935 }
936 } else {
937 let idx = self.slow_index;
938 let old = self.slow_buffer[idx];
939 self.slow_buffer[idx] = value;
940 self.slow_sum += value - old;
941 self.slow_ma = self.slow_sum / self.slow_period as f64;
942 self.slow_index += 1;
943 if self.slow_index == self.slow_period {
944 self.slow_index = 0;
945 }
946 }
947 }
948 }
949
950 if self.count < self.max_period {
951 return None;
952 }
953
954 let diff = self.fast_ma - self.slow_ma;
955 let diff2 = diff * diff;
956
957 if self.diffs_filled < self.fast_period {
958 self.sumsq_diff += diff2;
959 self.diffs_buffer[self.diff_index] = diff2;
960 self.diff_index += 1;
961 if self.diff_index == self.fast_period {
962 self.diff_index = 0;
963 }
964 self.diffs_filled += 1;
965 } else {
966 let old2 = self.diffs_buffer[self.diff_index];
967 self.sumsq_diff += diff2 - old2;
968 self.diffs_buffer[self.diff_index] = diff2;
969 self.diff_index += 1;
970 if self.diff_index == self.fast_period {
971 self.diff_index = 0;
972 }
973 }
974
975 if self.count < self.ready_threshold || self.diffs_filled < self.fast_period {
976 return None;
977 }
978
979 let dev = (self.sumsq_diff * self.inv_fast_len).sqrt();
980
981 let upper = dev.mul_add(self.devup, self.slow_ma);
982 let middle = self.fast_ma;
983 let lower = (-self.devdn * dev).mul_add(1.0, self.slow_ma);
984
985 Some((upper, middle, lower))
986 }
987}
988
/// Parameter sweep for `mab_batch`. Numeric axes are `(start, end, step)` tuples that
/// `expand_grid` expands.
#[derive(Clone, Debug)]
pub struct MabBatchRange {
    /// `(start, end, step)`; `step == 0` or `start == end` means the single value `start`.
    pub fast_period: (usize, usize, usize),
    /// Same convention as `fast_period`.
    pub slow_period: (usize, usize, usize),
    /// `(start, end, step)`; a step or span below `1e-12` means the single value `start`.
    pub devup: (f64, f64, f64),
    /// Same convention as `devup`.
    pub devdn: (f64, f64, f64),
    /// Only the first element is used by `expand_grid`; the MA type is not swept.
    pub fast_ma_type: (String, String, String),
    /// Only the first element is used by `expand_grid`; the MA type is not swept.
    pub slow_ma_type: (String, String, String),
}
998
999impl Default for MabBatchRange {
1000 fn default() -> Self {
1001 Self {
1002 fast_period: (10, 10, 0),
1003 slow_period: (50, 299, 1),
1004 devup: (1.0, 1.0, 0.0),
1005 devdn: (1.0, 1.0, 0.0),
1006 fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
1007 slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
1008 }
1009 }
1010}
1011
/// Result of a batch run: three row-major `rows × cols` matrices (one row per parameter
/// combination, one column per input sample) plus the combinations themselves.
#[derive(Clone, Debug)]
pub struct MabBatchOutput {
    /// Upper bands; row `r` occupies `[r * cols, (r + 1) * cols)`.
    pub upperbands: Vec<f64>,
    /// Middle bands (fast average), same layout.
    pub middlebands: Vec<f64>,
    /// Lower bands, same layout.
    pub lowerbands: Vec<f64>,
    /// `combos[r]` holds the parameters used for row `r`.
    pub combos: Vec<MabParams>,
    pub rows: usize,
    pub cols: usize,
}
1021impl MabBatchOutput {
1022 pub fn row_for_params(&self, p: &MabParams) -> Option<usize> {
1023 self.combos.iter().position(|c| {
1024 c.fast_period == p.fast_period
1025 && c.slow_period == p.slow_period
1026 && c.devup == p.devup
1027 && c.devdn == p.devdn
1028 && c.fast_ma_type == p.fast_ma_type
1029 && c.slow_ma_type == p.slow_ma_type
1030 })
1031 }
1032
1033 pub fn upper_slice(&self, row: usize) -> Option<&[f64]> {
1034 if row < self.rows {
1035 let start = row * self.cols;
1036 let end = start + self.cols;
1037 Some(&self.upperbands[start..end])
1038 } else {
1039 None
1040 }
1041 }
1042
1043 pub fn middle_slice(&self, row: usize) -> Option<&[f64]> {
1044 if row < self.rows {
1045 let start = row * self.cols;
1046 let end = start + self.cols;
1047 Some(&self.middlebands[start..end])
1048 } else {
1049 None
1050 }
1051 }
1052
1053 pub fn lower_slice(&self, row: usize) -> Option<&[f64]> {
1054 if row < self.rows {
1055 let start = row * self.cols;
1056 let end = start + self.cols;
1057 Some(&self.lowerbands[start..end])
1058 } else {
1059 None
1060 }
1061 }
1062}
1063
1064pub(crate) fn expand_grid(p: &MabBatchRange) -> Result<Vec<MabParams>, MabError> {
1065 fn axis_usize(axis: (usize, usize, usize)) -> Result<Vec<usize>, MabError> {
1066 let (start, end, step) = axis;
1067 if step == 0 || start == end {
1068 return Ok(vec![start]);
1069 }
1070 let (lo, hi) = if start <= end {
1071 (start, end)
1072 } else {
1073 (end, start)
1074 };
1075 let v: Vec<usize> = (lo..=hi).step_by(step).collect();
1076 if v.is_empty() {
1077 return Err(MabError::InvalidRange { start, end, step });
1078 }
1079 Ok(v)
1080 }
1081
1082 fn axis_f64(axis: (f64, f64, f64)) -> Result<Vec<f64>, MabError> {
1083 let (start, end, step) = axis;
1084 const EPS: f64 = 1e-12;
1085 if step.abs() < EPS || (start - end).abs() < EPS {
1086 return Ok(vec![start]);
1087 }
1088 let step_eff = if start <= end {
1089 step.abs()
1090 } else {
1091 -step.abs()
1092 };
1093 let mut v = Vec::new();
1094 let mut x = start;
1095 if step_eff > 0.0 {
1096 while x <= end + EPS {
1097 v.push(x);
1098 x += step_eff;
1099 }
1100 } else {
1101 while x >= end - EPS {
1102 v.push(x);
1103 x += step_eff;
1104 }
1105 }
1106 if v.is_empty() {
1107 return Err(MabError::InvalidRangeF64 { start, end, step });
1108 }
1109 Ok(v)
1110 }
1111
1112 let fast_periods = axis_usize(p.fast_period)?;
1113 let slow_periods = axis_usize(p.slow_period)?;
1114 let devups = axis_f64(p.devup)?;
1115 let devdns = axis_f64(p.devdn)?;
1116
1117 let mut combos =
1118 Vec::with_capacity(fast_periods.len() * slow_periods.len() * devups.len() * devdns.len());
1119
1120 for &fast in &fast_periods {
1121 for &slow in &slow_periods {
1122 for &devup in &devups {
1123 for &devdn in &devdns {
1124 combos.push(MabParams {
1125 fast_period: Some(fast),
1126 slow_period: Some(slow),
1127 devup: Some(devup),
1128 devdn: Some(devdn),
1129 fast_ma_type: Some(p.fast_ma_type.0.clone()),
1130 slow_ma_type: Some(p.slow_ma_type.0.clone()),
1131 });
1132 }
1133 }
1134 }
1135 }
1136
1137 Ok(combos)
1138}
1139
1140pub fn mab_batch(input: &[f64], sweep: &MabBatchRange) -> Result<MabBatchOutput, MabError> {
1141 mab_batch_inner(input, sweep, Kernel::Auto, false)
1142}
1143
/// Batch driver: evaluates MAB for every combination produced by `expand_grid(sweep)` and
/// returns the three bands as row-major `rows × cols` matrices.
///
/// `Kernel::Auto` resolves through `detect_best_batch_kernel()`. Batch kernels map to their
/// per-row SIMD counterpart, and any non-batch kernel is rejected with
/// `MabError::InvalidKernelForBatch`. `parallel` is forwarded to `mab_batch_inner_into`.
///
/// # Errors
/// - `InvalidKernelForBatch` for a non-batch kernel.
/// - Errors from `expand_grid`.
/// - `EmptyInputData` for an empty series.
/// - `InvalidRange` if `rows * cols` overflows.
/// - `AllValuesNaN` when every input value is NaN.
/// - `InvalidPeriod` / `NotEnoughValidData` for the first combination that cannot be computed.
/// - Anything returned by `mab_batch_inner_into`.
fn mab_batch_inner(
    input: &[f64],
    sweep: &MabBatchRange,
    kernel: Kernel,
    parallel: bool,
) -> Result<MabBatchOutput, MabError> {
    let kernel = match kernel {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // Per-row kernel used inside the batch.
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => {
            return Err(MabError::InvalidKernelForBatch(other));
        }
    };

    // Used here only to size the matrices and compute per-row warmups; the combos
    // returned to the caller come from `mab_batch_inner_into` below.
    let combos = expand_grid(sweep)?;
    let rows = combos.len();
    let cols = input.len();
    if cols == 0 {
        return Err(MabError::EmptyInputData);
    }
    // Overflow guard only; the product itself is not kept.
    rows.checked_mul(cols).ok_or(MabError::InvalidRange {
        start: sweep.fast_period.0,
        end: sweep.fast_period.1,
        step: sweep.fast_period.2,
    })?;

    let first_valid = input
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(MabError::AllValuesNaN)?;
    let valid = cols - first_valid;
    // Per-row warmup prefix length: first_valid + max(fast, slow) + fast - 1, the same
    // count the single-series path fills with NaN (warmup + 1 in mab_prepare2 terms).
    let warmup_prefixes: Vec<usize> = combos
        .iter()
        .map(|p| {
            // expand_grid always sets both periods, so these unwraps cannot fail.
            let fast = p.fast_period.unwrap();
            let slow = p.slow_period.unwrap();
            if fast == 0 || slow == 0 || fast > cols || slow > cols {
                return Err(MabError::InvalidPeriod {
                    fast,
                    slow,
                    data_len: cols,
                });
            }
            let need_total = fast.max(slow) + fast - 1;
            if valid < need_total {
                return Err(MabError::NotEnoughValidData {
                    needed: need_total,
                    valid,
                });
            }
            Ok(first_valid + need_total)
        })
        .collect::<Result<Vec<_>, MabError>>()?;

    // NOTE(review): presumably uninitialised (`MaybeUninit<f64>`) storage, with
    // `init_matrix_prefixes` writing NaN into each row's warmup prefix — confirm against
    // the helpers.
    let mut upper_buf = make_uninit_matrix(rows, cols);
    let mut middle_buf = make_uninit_matrix(rows, cols);
    let mut lower_buf = make_uninit_matrix(rows, cols);

    init_matrix_prefixes(&mut upper_buf, cols, &warmup_prefixes);
    init_matrix_prefixes(&mut middle_buf, cols, &warmup_prefixes);
    init_matrix_prefixes(&mut lower_buf, cols, &warmup_prefixes);

    // View each buffer as `rows * cols` plain f64 so the kernels can write into it.
    // NOTE(review): `slice::from_raw_parts_mut` documents that the elements must be
    // initialised; cells past the warmup prefix are not yet written here. Consider writing
    // through `MaybeUninit` instead.
    let upper_slice =
        unsafe { std::slice::from_raw_parts_mut(upper_buf.as_mut_ptr() as *mut f64, rows * cols) };
    let middle_slice =
        unsafe { std::slice::from_raw_parts_mut(middle_buf.as_mut_ptr() as *mut f64, rows * cols) };
    let lower_slice =
        unsafe { std::slice::from_raw_parts_mut(lower_buf.as_mut_ptr() as *mut f64, rows * cols) };

    let combos = mab_batch_inner_into(
        input,
        sweep,
        simd,
        parallel,
        upper_slice,
        middle_slice,
        lower_slice,
    )?;

    // Transfer each allocation into a `Vec<f64>` without copying. ManuallyDrop keeps the
    // original buffer from freeing it. Assumes the matrix element type is layout-compatible
    // with f64 (true for MaybeUninit<f64>).
    let mut upper_guard = core::mem::ManuallyDrop::new(upper_buf);
    let mut middle_guard = core::mem::ManuallyDrop::new(middle_buf);
    let mut lower_guard = core::mem::ManuallyDrop::new(lower_buf);

    let upperbands = unsafe {
        Vec::from_raw_parts(
            upper_guard.as_mut_ptr() as *mut f64,
            upper_guard.len(),
            upper_guard.capacity(),
        )
    };
    let middlebands = unsafe {
        Vec::from_raw_parts(
            middle_guard.as_mut_ptr() as *mut f64,
            middle_guard.len(),
            middle_guard.capacity(),
        )
    };
    let lowerbands = unsafe {
        Vec::from_raw_parts(
            lower_guard.as_mut_ptr() as *mut f64,
            lower_guard.len(),
            lower_guard.capacity(),
        )
    };

    Ok(MabBatchOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    })
}
1263
1264fn mab_batch_inner_into(
1265 input: &[f64],
1266 sweep: &MabBatchRange,
1267 kernel: Kernel,
1268 parallel: bool,
1269 upper_out: &mut [f64],
1270 middle_out: &mut [f64],
1271 lower_out: &mut [f64],
1272) -> Result<Vec<MabParams>, MabError> {
1273 let combos = expand_grid(sweep)?;
1274 let rows = combos.len();
1275 let cols = input.len();
1276 let expected = rows.checked_mul(cols).ok_or(MabError::InvalidRange {
1277 start: sweep.fast_period.0,
1278 end: sweep.fast_period.1,
1279 step: sweep.fast_period.2,
1280 })?;
1281
1282 if upper_out.len() != expected || middle_out.len() != expected || lower_out.len() != expected {
1283 return Err(MabError::OutputLengthMismatch {
1284 upper_len: upper_out.len(),
1285 middle_len: middle_out.len(),
1286 lower_len: lower_out.len(),
1287 expected,
1288 });
1289 }
1290
1291 if !combos.is_empty() {
1292 let p0 = &combos[0];
1293 let all_same_ma = combos.iter().all(|p| {
1294 p.fast_period == p0.fast_period
1295 && p.slow_period == p0.slow_period
1296 && p.fast_ma_type == p0.fast_ma_type
1297 && p.slow_ma_type == p0.slow_ma_type
1298 });
1299
1300 if all_same_ma {
1301 use crate::indicators::ema::{ema, EmaInput, EmaParams};
1302 use crate::indicators::sma::{sma, SmaInput, SmaParams};
1303
1304 let n = input.len();
1305 let first = input.iter().position(|x| !x.is_nan()).unwrap_or(0);
1306 let fast = p0.fast_period.unwrap();
1307 let slow = p0.slow_period.unwrap();
1308 let fast_ma_type = p0.fast_ma_type.as_deref().unwrap_or("sma");
1309 let slow_ma_type = p0.slow_ma_type.as_deref().unwrap_or("sma");
1310
1311 let fast_ma = match fast_ma_type {
1312 "ema" => {
1313 let params = EmaParams { period: Some(fast) };
1314 ema(&EmaInput::from_slice(input, params))
1315 .map_err(|_| MabError::NotEnoughValidData {
1316 needed: fast,
1317 valid: n - first,
1318 })?
1319 .values
1320 }
1321 _ => {
1322 let params = SmaParams { period: Some(fast) };
1323 sma(&SmaInput::from_slice(input, params))
1324 .map_err(|_| MabError::NotEnoughValidData {
1325 needed: fast,
1326 valid: n - first,
1327 })?
1328 .values
1329 }
1330 };
1331
1332 let slow_ma = match slow_ma_type {
1333 "ema" => {
1334 let params = EmaParams { period: Some(slow) };
1335 ema(&EmaInput::from_slice(input, params))
1336 .map_err(|_| MabError::NotEnoughValidData {
1337 needed: slow,
1338 valid: n - first,
1339 })?
1340 .values
1341 }
1342 _ => {
1343 let params = SmaParams { period: Some(slow) };
1344 sma(&SmaInput::from_slice(input, params))
1345 .map_err(|_| MabError::NotEnoughValidData {
1346 needed: slow,
1347 valid: n - first,
1348 })?
1349 .values
1350 }
1351 };
1352
1353 let need_total = fast.max(slow) + fast - 1;
1354 let warmup = first + need_total - 1;
1355 let first_output = warmup + 1;
1356
1357 if first_output < n {
1358 let mut dev: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, n);
1359 unsafe {
1360 dev.set_len(n);
1361 }
1362
1363 unsafe {
1364 let f_ptr = fast_ma.as_ptr();
1365 let s_ptr = slow_ma.as_ptr();
1366 let d_ptr = dev.as_mut_ptr();
1367
1368 let start = first_output + 1 - fast;
1369 let mut sum_sq = 0.0f64;
1370 let mut k = 0usize;
1371 while k < fast {
1372 let idx = start + k;
1373 let diff = *f_ptr.add(idx) - *s_ptr.add(idx);
1374 sum_sq += diff * diff;
1375 k += 1;
1376 }
1377
1378 *d_ptr.add(first_output) = (sum_sq / fast as f64).sqrt();
1379
1380 let mut i = first_output + 1;
1381 while i < n {
1382 let old_idx = i - fast;
1383 let old = *f_ptr.add(old_idx) - *s_ptr.add(old_idx);
1384 let new = *f_ptr.add(i) - *s_ptr.add(i);
1385 sum_sq += new * new - old * old;
1386 *d_ptr.add(i) = (sum_sq / fast as f64).sqrt();
1387 i += 1;
1388 }
1389 }
1390
1391 let fill_row = |row: usize, u: &mut [f64], m: &mut [f64], l: &mut [f64]| {
1392 let pr = &combos[row];
1393 let devup = pr.devup.unwrap();
1394 let devdn = pr.devdn.unwrap();
1395 for i in first_output..n {
1396 let d = dev[i];
1397 m[i] = fast_ma[i];
1398 u[i] = slow_ma[i] + devup * d;
1399 l[i] = slow_ma[i] - devdn * d;
1400 }
1401 Ok(())
1402 };
1403
1404 if parallel {
1405 #[cfg(not(target_arch = "wasm32"))]
1406 {
1407 use rayon::prelude::*;
1408 upper_out
1409 .par_chunks_mut(cols)
1410 .zip(middle_out.par_chunks_mut(cols))
1411 .zip(lower_out.par_chunks_mut(cols))
1412 .enumerate()
1413 .try_for_each(|(row, ((u, m), l))| fill_row(row, u, m, l))?;
1414 }
1415 #[cfg(target_arch = "wasm32")]
1416 {
1417 for row in 0..rows {
1418 let s = row * cols;
1419 fill_row(
1420 row,
1421 &mut upper_out[s..s + cols],
1422 &mut middle_out[s..s + cols],
1423 &mut lower_out[s..s + cols],
1424 )?;
1425 }
1426 }
1427 } else {
1428 for row in 0..rows {
1429 let s = row * cols;
1430 fill_row(
1431 row,
1432 &mut upper_out[s..s + cols],
1433 &mut middle_out[s..s + cols],
1434 &mut lower_out[s..s + cols],
1435 )?;
1436 }
1437 }
1438
1439 return Ok(combos);
1440 }
1441 }
1442 }
1443
1444 let process_row = |row: usize, u: &mut [f64], m: &mut [f64], l: &mut [f64]| {
1445 let p = &combos[row];
1446 let in_row = MabInput::from_slice(
1447 input,
1448 MabParams {
1449 fast_period: p.fast_period,
1450 slow_period: p.slow_period,
1451 devup: p.devup,
1452 devdn: p.devdn,
1453 fast_ma_type: p.fast_ma_type.clone(),
1454 slow_ma_type: p.slow_ma_type.clone(),
1455 },
1456 );
1457 mab_into_slice(u, m, l, &in_row, kernel)
1458 };
1459
1460 #[cfg(not(target_arch = "wasm32"))]
1461 {
1462 if parallel {
1463 use rayon::prelude::*;
1464 upper_out
1465 .par_chunks_mut(cols)
1466 .zip(middle_out.par_chunks_mut(cols))
1467 .zip(lower_out.par_chunks_mut(cols))
1468 .enumerate()
1469 .try_for_each(|(row, ((u, m), l))| process_row(row, u, m, l))?;
1470 } else {
1471 for row in 0..rows {
1472 let s = row * cols;
1473 process_row(
1474 row,
1475 &mut upper_out[s..s + cols],
1476 &mut middle_out[s..s + cols],
1477 &mut lower_out[s..s + cols],
1478 )?;
1479 }
1480 }
1481 }
1482
1483 #[cfg(target_arch = "wasm32")]
1484 {
1485 for row in 0..rows {
1486 let s = row * cols;
1487 process_row(
1488 row,
1489 &mut upper_out[s..s + cols],
1490 &mut middle_out[s..s + cols],
1491 &mut lower_out[s..s + cols],
1492 )?;
1493 }
1494 }
1495
1496 Ok(combos)
1497}
1498
/// Python entry point `mab`: computes the upper, middle and lower bands for a 1-D float64
/// array and returns them as three numpy arrays.
///
/// `kernel` is validated with `validate_kernel`. `Kernel::Auto` routes through `mab`; any
/// explicit kernel routes through `mab_with_kernel`. Computation runs with the GIL released.
/// Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mab")]
#[pyo3(signature = (data, fast_period=10, slow_period=50, devup=1.0, devdn=1.0, fast_ma_type="sma", slow_ma_type="sma", kernel=None))]
pub fn mab_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let values = data.as_slice()?;
    let input = MabInput::from_slice(
        values,
        MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        },
    );

    let kernel_choice = validate_kernel(kernel, false)?;

    let MabOutput {
        upperband,
        middleband,
        lowerband,
    } = py
        .allow_threads(|| {
            if matches!(kernel_choice, Kernel::Auto) {
                mab(&input)
            } else {
                mab_with_kernel(&input, kernel_choice)
            }
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        upperband.into_pyarray(py),
        middleband.into_pyarray(py),
        lowerband.into_pyarray(py),
    ))
}
1543
/// Python class `MabStream`: a thin wrapper around the Rust [`MabStream`] incremental
/// calculator.
#[cfg(feature = "python")]
#[pyclass(name = "MabStream")]
pub struct MabStreamPy {
    /// Underlying streaming state; `update` calls are forwarded to it.
    stream: MabStream,
}
1549
#[cfg(feature = "python")]
#[pymethods]
impl MabStreamPy {
    /// Builds a stream from explicit parameters.
    ///
    /// Raises `ValueError` with the `MabStream::try_new` error message if the parameters are
    /// rejected.
    #[new]
    fn new(
        fast_period: usize,
        slow_period: usize,
        devup: f64,
        devdn: f64,
        fast_ma_type: &str,
        slow_ma_type: &str,
    ) -> PyResult<Self> {
        MabStream::try_new(MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        })
        .map(|stream| MabStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns the underlying stream's result unchanged.
    ///
    /// The result is an `(f64, f64, f64)` triple, or `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        self.stream.update(value)
    }
}
1579
/// Python entry point `mab_batch`: evaluates a parameter sweep and returns a dict of 2-D
/// `(rows, cols)` band arrays plus the per-row parameter vectors.
///
/// Returned keys:
/// - `upperbands`, `middlebands`, `lowerbands`
/// - `fast_periods`, `slow_periods`, `devups`, `devdns`
///
/// Both MA types are fixed for the whole sweep. All validation (empty input, all-NaN input,
/// invalid or too-long periods) runs before the output arrays are allocated. Each row's
/// warm-up prefix is initialised through a view derived from the same mutable slice that the
/// kernel later writes to, so two live `&mut` views of one numpy buffer never coexist.
///
/// # Errors
/// Every failure is raised as `ValueError`. This includes a kernel that is not a batch kernel
/// after `Auto` resolution, which is reported as `InvalidKernelForBatch` instead of panicking.
#[cfg(feature = "python")]
#[pyfunction(name = "mab_batch")]
#[pyo3(signature = (data, fast_period_range, slow_period_range, devup_range=(1.0, 1.0, 0.0), devdn_range=(1.0, 1.0, 0.0), fast_ma_type="sma", slow_ma_type="sma", kernel=None))]
pub fn mab_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    devup_range: (f64, f64, f64),
    devdn_range: (f64, f64, f64),
    fast_ma_type: &str,
    slow_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;

    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: (
            fast_ma_type.to_string(),
            fast_ma_type.to_string(),
            "".to_string(),
        ),
        slow_ma_type: (
            slow_ma_type.to_string(),
            slow_ma_type.to_string(),
            "".to_string(),
        ),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    if cols == 0 {
        return Err(PyValueError::new_err(MabError::EmptyInputData.to_string()));
    }
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("mab_batch: rows*cols overflow"))?;

    let kern = validate_kernel(kernel, true)?;

    // Per-row warm-up length: first valid index + max(fast, slow) + fast - 1.
    let first_valid = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| PyValueError::new_err(MabError::AllValuesNaN.to_string()))?;
    let valid = cols - first_valid;
    let warmup_prefixes: Vec<usize> = combos
        .iter()
        .map(|p| {
            let fast = p.fast_period.unwrap();
            let slow = p.slow_period.unwrap();
            if fast == 0 || slow == 0 || fast > cols || slow > cols {
                return Err(MabError::InvalidPeriod {
                    fast,
                    slow,
                    data_len: cols,
                });
            }
            let need_total = fast.max(slow) + fast - 1;
            if valid < need_total {
                return Err(MabError::NotEnoughValidData {
                    needed: need_total,
                    valid,
                });
            }
            Ok(first_valid + need_total)
        })
        .collect::<Result<Vec<_>, MabError>>()
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Fresh, contiguous (non-Fortran) arrays; every element is written below, either by the
    // prefix initialisation or by the kernel.
    let upper_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let middle_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let lower_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let slice_upper = unsafe { upper_arr.as_slice_mut()? };
    let slice_middle = unsafe { middle_arr.as_slice_mut()? };
    let slice_lower = unsafe { lower_arr.as_slice_mut()? };

    // SAFETY: `MaybeUninit<f64>` has the same size and alignment as `f64`. Each temporary
    // view is derived from the corresponding `slice_*` (length `total`) and is dead before that
    // slice is used again, so no two live mutable references overlap.
    // NOTE(review): relies on `init_matrix_prefixes` writing initialised values (NaN) only.
    unsafe {
        init_matrix_prefixes(
            std::slice::from_raw_parts_mut(
                slice_upper.as_mut_ptr() as *mut MaybeUninit<f64>,
                total,
            ),
            cols,
            &warmup_prefixes,
        );
        init_matrix_prefixes(
            std::slice::from_raw_parts_mut(
                slice_middle.as_mut_ptr() as *mut MaybeUninit<f64>,
                total,
            ),
            cols,
            &warmup_prefixes,
        );
        init_matrix_prefixes(
            std::slice::from_raw_parts_mut(
                slice_lower.as_mut_ptr() as *mut MaybeUninit<f64>,
                total,
            ),
            cols,
            &warmup_prefixes,
        );
    }

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to its per-row SIMD kernel; anything else is an error rather
            // than a panic across the Python boundary.
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(MabError::InvalidKernelForBatch(other)),
            };

            mab_batch_inner_into(
                slice_in,
                &sweep,
                simd,
                true,
                slice_upper,
                slice_middle,
                slice_lower,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("upperbands", upper_arr.reshape((rows, cols))?)?;
    dict.set_item("middlebands", middle_arr.reshape((rows, cols))?)?;
    dict.set_item("lowerbands", lower_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|p| p.fast_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|p| p.slow_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "devups",
        combos
            .iter()
            .map(|p| p.devup.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "devdns",
        combos
            .iter()
            .map(|p| p.devdn.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1745
/// Serialized result of the wasm `mab` export: the three bands flattened into one buffer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabJsSingle {
    /// Upper, middle and lower bands concatenated in that order (`rows * cols` values).
    pub values: Vec<f64>,
    /// Number of bands stored in `values` (the `mab` export sets this to 3).
    pub rows: usize,
    /// Length of each band, equal to the input length.
    pub cols: usize,
}
1753
/// wasm export `mab`: computes the bands for `data` and returns a serialized [`MabJsSingle`].
///
/// In that result, `values` holds `[upper..., middle..., lower...]`, `rows == 3` and
/// `cols == data.len()`. The bands are written directly into disjoint sub-slices of the
/// returned buffer, avoiding three temporary vectors and the copies out of them.
///
/// # Errors
/// Returns the `MabError` message, or a serialization error, as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "mab")]
pub fn mab_wasm(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let params = MabParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        devup: Some(devup),
        devdn: Some(devdn),
        fast_ma_type: Some(fast_ma_type.to_string()),
        slow_ma_type: Some(slow_ma_type.to_string()),
    };
    let input = MabInput::from_slice(data, params);

    let len = data.len();
    let mut values = vec![0.0; 3 * len];
    // Layout: [0, len) upper, [len, 2*len) middle, [2*len, 3*len) lower.
    let (upper, rest) = values.split_at_mut(len);
    let (middle, lower) = rest.split_at_mut(len);
    mab_into_slice(upper, middle, lower, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&MabJsSingle {
        values,
        rows: 3,
        cols: len,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1800
/// wasm export `mab_js`: computes the bands for `data` and returns them as one flat array.
///
/// The array is `[upper..., middle..., lower...]`, each of length `data.len()`. Bands are
/// written straight into disjoint sub-slices of the returned buffer instead of three
/// temporaries that would then be copied.
///
/// # Errors
/// Returns the `MabError` message as a JS string when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_js(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let params = MabParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        devup: Some(devup),
        devdn: Some(devdn),
        fast_ma_type: Some(fast_ma_type.to_string()),
        slow_ma_type: Some(slow_ma_type.to_string()),
    };
    let input = MabInput::from_slice(data, params);

    let len = data.len();
    let mut result = vec![0.0; 3 * len];
    // Layout: [0, len) upper, [len, 2*len) middle, [2*len, 3*len) lower.
    let (upper, rest) = result.split_at_mut(len);
    let (middle, lower) = rest.split_at_mut(len);
    mab_into_slice(upper, middle, lower, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(result)
}
1842
/// Configuration object accepted (deserialized from JS) by the wasm `mab_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabBatchConfig {
    /// Fast MA period sweep as `(start, end, step)`.
    pub fast_period_range: (usize, usize, usize),
    /// Slow MA period sweep as `(start, end, step)`.
    pub slow_period_range: (usize, usize, usize),
    /// Upper-band deviation multiplier sweep as `(start, end, step)`.
    pub devup_range: (f64, f64, f64),
    /// Lower-band deviation multiplier sweep as `(start, end, step)`.
    pub devdn_range: (f64, f64, f64),
    /// Fast MA type name, used for every combination.
    pub fast_ma_type: String,
    /// Slow MA type name, used for every combination.
    pub slow_ma_type: String,
}
1853
/// Serialized result of the wasm `mab_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabBatchJsOutput {
    /// Upper bands, row-major `rows x cols`; row `r` belongs to `combos[r]`.
    pub upperbands: Vec<f64>,
    /// Middle bands (the fast MA), same layout as `upperbands`.
    pub middlebands: Vec<f64>,
    /// Lower bands, same layout as `upperbands`.
    pub lowerbands: Vec<f64>,
    /// Expanded parameter combinations, one per row.
    pub combos: Vec<MabParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Input length (values per row).
    pub cols: usize,
}
1864
/// wasm export `mab_batch`: runs a parameter sweep described by a JS config object.
///
/// The config is a [`MabBatchConfig`]; the result is a serialized [`MabBatchJsOutput`].
/// Runs sequentially with the automatically selected kernel.
///
/// # Errors
/// Returns a JS string for an invalid config, a computation error, or a serialization error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mab_batch)]
pub fn mab_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MabBatchConfig {
        fast_period_range,
        slow_period_range,
        devup_range,
        devdn_range,
        fast_ma_type,
        slow_ma_type,
    } = serde_wasm_bindgen::from_value::<MabBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    // MA types are fixed across the sweep: (start, end) are the same name, step is empty.
    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: (fast_ma_type.clone(), fast_ma_type, String::new()),
        slow_ma_type: (slow_ma_type.clone(), slow_ma_type, String::new()),
    };

    let MabBatchOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    } = mab_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&MabBatchJsOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1903
/// wasm export: reserves an uninitialised buffer with room for `len` f64 values.
///
/// Returns a pointer that JS can use as an input or output buffer. Release it with
/// `mab_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1912
/// wasm export: releases a buffer obtained from `mab_alloc(len)`. A null `ptr` is ignored.
///
/// The vector is rebuilt with length 0 and capacity `len`. `Vec::from_raw_parts` requires
/// the first `length` elements to be initialised, and the buffer may never have been written.
/// `f64` has no destructor, so only the allocation itself needs releasing.
///
/// # Safety (caller contract)
/// `ptr` must come from `mab_alloc(len)` with the same `len` and must not be freed twice.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` was allocated by `Vec::<f64>::with_capacity(len)`;
    // a length of 0 makes no claim about the buffer's contents.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1922
/// wasm export: writes the bands for the `len` values at `in_ptr` into three caller-owned
/// buffers.
///
/// If any output pointer equals `in_ptr`, the bands are computed into temporaries first and
/// copied out afterwards, so the input is not overwritten while it is still being read. Only
/// exact pointer equality is detected; partially overlapping buffers are not.
///
/// # Safety (caller contract)
/// `in_ptr` must be valid for `len` reads and each output pointer for `len` writes (for
/// example, buffers from `mab_alloc(len)`).
///
/// # Errors
/// Returns a JS string for null pointers or when the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_into(
    in_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<(), JsValue> {
    let outputs = [upper_ptr, middle_ptr, lower_ptr];
    if in_ptr.is_null() || outputs.iter().any(|p| p.is_null()) {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: non-null was checked above; validity for `len` reads is the caller's contract.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = MabInput::from_slice(
        data,
        MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        },
    );

    let writes_over_input = outputs.iter().any(|&p| p as *const f64 == in_ptr);
    if writes_over_input {
        let mut upper = vec![0.0; len];
        let mut middle = vec![0.0; len];
        let mut lower = vec![0.0; len];
        mab_into_slice(&mut upper, &mut middle, &mut lower, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        for (dst, src) in outputs.iter().zip([&upper, &middle, &lower]) {
            // SAFETY: `dst` is non-null and valid for `len` writes per the caller contract.
            unsafe { std::slice::from_raw_parts_mut(*dst, len).copy_from_slice(src) };
        }
    } else {
        // SAFETY: the outputs do not alias the input; validity for `len` writes is the
        // caller's contract.
        let (upper, middle, lower) = unsafe {
            (
                std::slice::from_raw_parts_mut(upper_ptr, len),
                std::slice::from_raw_parts_mut(middle_ptr, len),
                std::slice::from_raw_parts_mut(lower_ptr, len),
            )
        };
        mab_into_slice(upper, middle, lower, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1988
/// wasm export: runs a parameter sweep over the `len` values at `in_ptr` and writes row-major
/// `rows x len` band matrices into three caller-owned buffers.
///
/// Returns the number of rows (parameter combinations). Row `r` of each output holds the
/// bands for the `r`-th combination produced by `expand_grid`. The MA types are fixed for the
/// whole sweep.
///
/// # Safety (caller contract)
/// `in_ptr` must be valid for `len` reads. Each output pointer must be valid for
/// `rows * len` writes; this function cannot check buffer capacity. NOTE(review): callers
/// presumably size the buffers from the same range arithmetic; confirm on the JS side.
/// Output buffers must not alias the input or each other.
///
/// # Errors
/// Returns a JS string for null pointers, an invalid range, empty or all-NaN input, invalid
/// periods, or a computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_batch_into(
    in_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    fast_period_start: usize,
    fast_period_end: usize,
    fast_period_step: usize,
    slow_period_start: usize,
    slow_period_end: usize,
    slow_period_step: usize,
    devup_start: f64,
    devup_end: f64,
    devup_step: f64,
    devdn_start: f64,
    devdn_end: f64,
    devdn_step: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || upper_ptr.is_null() || middle_ptr.is_null() || lower_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to mab_batch_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        // MA types are not swept: (start, end) carry the same name and the step is empty.
        let sweep = MabBatchRange {
            fast_period: (fast_period_start, fast_period_end, fast_period_step),
            slow_period: (slow_period_start, slow_period_end, slow_period_step),
            devup: (devup_start, devup_end, devup_step),
            devdn: (devdn_start, devdn_end, devdn_step),
            fast_ma_type: (
                fast_ma_type.to_string(),
                fast_ma_type.to_string(),
                "".to_string(),
            ),
            slow_ma_type: (
                slow_ma_type.to_string(),
                slow_ma_type.to_string(),
                "".to_string(),
            ),
        };

        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        if cols == 0 {
            return Err(JsValue::from_str(&MabError::EmptyInputData.to_string()));
        }
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("mab_batch_into: rows*cols overflow"))?;

        // Validate every combination up front and compute each row's warm-up length:
        // first valid index + max(fast, slow) + fast - 1.
        let first_valid = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or(MabError::AllValuesNaN)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let valid = cols - first_valid;
        let warmup_prefixes: Vec<usize> = combos
            .iter()
            .map(|p| {
                let fast = p.fast_period.unwrap();
                let slow = p.slow_period.unwrap();
                if fast == 0 || slow == 0 || fast > cols || slow > cols {
                    return Err(MabError::InvalidPeriod {
                        fast,
                        slow,
                        data_len: cols,
                    });
                }
                let need_total = fast.max(slow) + fast - 1;
                if valid < need_total {
                    return Err(MabError::NotEnoughValidData {
                        needed: need_total,
                        valid,
                    });
                }
                Ok(first_valid + need_total)
            })
            .collect::<Result<Vec<_>, MabError>>()
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // Initialise each row's warm-up prefix through a MaybeUninit view (the buffers may be
        // uninitialised). These views are not used again once the f64 slices below exist.
        let mu_upper: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(upper_ptr as *mut MaybeUninit<f64>, total);
        let mu_middle: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(middle_ptr as *mut MaybeUninit<f64>, total);
        let mu_lower: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(lower_ptr as *mut MaybeUninit<f64>, total);
        init_matrix_prefixes(mu_upper, cols, &warmup_prefixes);
        init_matrix_prefixes(mu_middle, cols, &warmup_prefixes);
        init_matrix_prefixes(mu_lower, cols, &warmup_prefixes);

        let upper_out = std::slice::from_raw_parts_mut(upper_ptr, total);
        let middle_out = std::slice::from_raw_parts_mut(middle_ptr, total);
        let lower_out = std::slice::from_raw_parts_mut(lower_ptr, total);

        // Sequential (`parallel = false`); the per-row kernel is left as `Kernel::Auto`.
        mab_batch_inner_into(
            data,
            &sweep,
            Kernel::Auto,
            false,
            upper_out,
            middle_out,
            lower_out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
2104
2105#[inline]
2106pub unsafe fn mab_scalar_classic_sma(
2107 data: &[f64],
2108 fast_period: usize,
2109 slow_period: usize,
2110 devup: f64,
2111 devdn: f64,
2112 first_valid_idx: usize,
2113 upper: &mut [f64],
2114 middle: &mut [f64],
2115 lower: &mut [f64],
2116) -> Result<(), MabError> {
2117 let n = data.len();
2118
2119 let mut fast_ma = vec![f64::NAN; n];
2120 if fast_period > 0 && first_valid_idx + fast_period <= n {
2121 let mut sum = 0.0;
2122 for i in 0..fast_period {
2123 sum += data[first_valid_idx + i];
2124 }
2125 fast_ma[first_valid_idx + fast_period - 1] = sum / fast_period as f64;
2126
2127 for i in (first_valid_idx + fast_period)..n {
2128 sum = sum - data[i - fast_period] + data[i];
2129 fast_ma[i] = sum / fast_period as f64;
2130 }
2131 }
2132
2133 let mut slow_ma = vec![f64::NAN; n];
2134 if slow_period > 0 && first_valid_idx + slow_period <= n {
2135 let mut sum = 0.0;
2136 for i in 0..slow_period {
2137 sum += data[first_valid_idx + i];
2138 }
2139 slow_ma[first_valid_idx + slow_period - 1] = sum / slow_period as f64;
2140
2141 for i in (first_valid_idx + slow_period)..n {
2142 sum = sum - data[i - slow_period] + data[i];
2143 slow_ma[i] = sum / slow_period as f64;
2144 }
2145 }
2146
2147 let need_total = slow_period.max(fast_period) + fast_period - 1;
2148 let warmup = first_valid_idx + need_total - 1;
2149 let first_output = warmup + 1;
2150
2151 for i in 0..first_output.min(n) {
2152 upper[i] = f64::NAN;
2153 middle[i] = f64::NAN;
2154 lower[i] = f64::NAN;
2155 }
2156
2157 if first_output >= n {
2158 return Ok(());
2159 }
2160
2161 let start_idx = if first_output >= fast_period {
2162 first_output - fast_period + 1
2163 } else {
2164 0
2165 };
2166
2167 let mut sum_sq = 0.0;
2168 for i in start_idx..(start_idx + fast_period).min(fast_ma.len()) {
2169 let diff = fast_ma[i] - slow_ma[i];
2170 if !diff.is_nan() {
2171 sum_sq += diff * diff;
2172 }
2173 }
2174
2175 if first_output < fast_ma.len() {
2176 let dev = (sum_sq / fast_period as f64).sqrt();
2177 middle[first_output] = fast_ma[first_output];
2178 upper[first_output] = slow_ma[first_output] + devup * dev;
2179 lower[first_output] = slow_ma[first_output] - devdn * dev;
2180 }
2181
2182 for i in (first_output + 1)..fast_ma.len() {
2183 let old_idx = i - fast_period;
2184 let old = fast_ma[old_idx] - slow_ma[old_idx];
2185 let new = fast_ma[i] - slow_ma[i];
2186 if !old.is_nan() && !new.is_nan() {
2187 sum_sq += new * new - old * old;
2188 }
2189 let dev = (sum_sq / fast_period as f64).sqrt();
2190
2191 middle[i] = fast_ma[i];
2192 upper[i] = slow_ma[i] + devup * dev;
2193 lower[i] = slow_ma[i] - devdn * dev;
2194 }
2195
2196 Ok(())
2197}
2198
2199#[cfg(all(feature = "python", feature = "cuda"))]
2200use crate::cuda::{cuda_available, moving_averages::CudaMab};
2201#[cfg(all(feature = "python", feature = "cuda"))]
2202use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
2203#[cfg(all(feature = "python", feature = "cuda"))]
2204#[cfg(all(feature = "python", feature = "cuda"))]
2205use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
2206#[cfg(all(feature = "python", feature = "cuda"))]
2207use pyo3::{pyfunction, PyResult, Python};
2208
/// Python entry point `mab_cuda_batch_dev`: runs a parameter sweep on the GPU over float32
/// data and returns the upper, middle and lower band matrices as device arrays.
///
/// MA types are fixed for the whole sweep. Raises `ValueError` when CUDA is unavailable or
/// the CUDA computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mab_cuda_batch_dev")]
#[pyo3(signature = (data_f32, fast_period_range, slow_period_range, devup_range=(1.0,1.0,0.0), devdn_range=(1.0,1.0,0.0), fast_ma_type="sma", slow_ma_type="sma", device_id=0))]
pub fn mab_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    devup_range: (f64, f64, f64),
    devdn_range: (f64, f64, f64),
    fast_ma_type: &str,
    slow_ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let series = data_f32.as_slice()?;
    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: (
            fast_ma_type.to_string(),
            fast_ma_type.to_string(),
            String::new(),
        ),
        slow_ma_type: (
            slow_ma_type.to_string(),
            slow_ma_type.to_string(),
            String::new(),
        ),
    };

    // GPU work runs with the GIL released; only device handles come back.
    let (upper, middle, lower) = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaMab::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (bands, _) = engine
            .mab_batch_dev(series, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bands.upper, bands.middle, bands.lower))
    })?;

    Ok((
        make_device_array_py(device_id, upper)?,
        make_device_array_py(device_id, middle)?,
        make_device_array_py(device_id, lower)?,
    ))
}
2257
/// Python entry point `mab_cuda_many_series_one_param_dev`: runs one parameter set over many
/// float32 series stored time-major in a 2-D array, and returns upper, middle and lower bands
/// as device arrays.
///
/// The array is forwarded as a flat slice together with `cols = shape[1]` and
/// `rows = shape[0]`. NOTE(review): given the time-major name, this presumably means
/// (series count, series length); confirm against `CudaMab`.
/// Raises `ValueError` when CUDA is unavailable, the array is not contiguous, or the CUDA
/// computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mab_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, fast_period, slow_period, devup=1.0, devdn=1.0, fast_ma_type="sma", slow_ma_type="sma", device_id=0))]
pub fn mab_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat: &[f32] = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = MabParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        devup: Some(devup),
        devdn: Some(devdn),
        fast_ma_type: Some(fast_ma_type.to_string()),
        slow_ma_type: Some(slow_ma_type.to_string()),
    };
    let (up, mid, lo) = py.allow_threads(|| {
        let cuda = CudaMab::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The params reference is `&params`; the previous source text had it mangled into a
        // non-compiling token.
        let trip = cuda
            .mab_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((trip.upper, trip.middle, trip.lower))
    })?;

    let up_py = make_device_array_py(device_id, up)?;
    let mid_py = make_device_array_py(device_id, mid)?;
    let lo_py = make_device_array_py(device_id, lo)?;

    Ok((up_py, mid_py, lo_py))
}
2300
2301#[cfg(test)]
2302mod tests {
2303 use super::*;
2304 use crate::utilities::data_loader::read_candles_from_csv;
2305 #[cfg(feature = "proptest")]
2306 use proptest::prelude::*;
2307 use std::error::Error;
2308
2309 macro_rules! skip_if_unsupported {
2310 ($kernel:expr, $test_name:expr) => {
2311 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2312 if matches!(
2313 $kernel,
2314 Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2315 ) {
2316 eprintln!(
2317 "[{}] Skipping - {:?} not supported on WASM",
2318 $test_name, $kernel
2319 );
2320 return Ok(());
2321 }
2322 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2323 if matches!(
2324 $kernel,
2325 Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2326 ) {
2327 eprintln!(
2328 "[{}] Skipping - {:?} requires 'nightly-avx' feature",
2329 $test_name, $kernel
2330 );
2331 return Ok(());
2332 }
2333 };
2334 }
2335
2336 #[cfg(debug_assertions)]
2337 fn check_mab_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2338 skip_if_unsupported!(kernel, test_name);
2339
2340 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2341 let candles = read_candles_from_csv(file_path)?;
2342
2343 let test_params = vec![
2344 MabParams::default(),
2345 MabParams {
2346 fast_period: Some(2),
2347 slow_period: Some(3),
2348 devup: Some(1.0),
2349 devdn: Some(1.0),
2350 fast_ma_type: Some("sma".to_string()),
2351 slow_ma_type: Some("sma".to_string()),
2352 },
2353 MabParams {
2354 fast_period: Some(5),
2355 slow_period: Some(10),
2356 devup: Some(0.5),
2357 devdn: Some(0.5),
2358 fast_ma_type: Some("sma".to_string()),
2359 slow_ma_type: Some("sma".to_string()),
2360 },
2361 MabParams {
2362 fast_period: Some(15),
2363 slow_period: Some(30),
2364 devup: Some(2.0),
2365 devdn: Some(2.0),
2366 fast_ma_type: Some("ema".to_string()),
2367 slow_ma_type: Some("ema".to_string()),
2368 },
2369 MabParams {
2370 fast_period: Some(50),
2371 slow_period: Some(100),
2372 devup: Some(3.0),
2373 devdn: Some(3.0),
2374 fast_ma_type: Some("sma".to_string()),
2375 slow_ma_type: Some("sma".to_string()),
2376 },
2377 MabParams {
2378 fast_period: Some(10),
2379 slow_period: Some(20),
2380 devup: Some(1.5),
2381 devdn: Some(1.5),
2382 fast_ma_type: Some("sma".to_string()),
2383 slow_ma_type: Some("ema".to_string()),
2384 },
2385 MabParams {
2386 fast_period: Some(8),
2387 slow_period: Some(21),
2388 devup: Some(2.5),
2389 devdn: Some(2.5),
2390 fast_ma_type: Some("ema".to_string()),
2391 slow_ma_type: Some("sma".to_string()),
2392 },
2393 MabParams {
2394 fast_period: Some(12),
2395 slow_period: Some(26),
2396 devup: Some(2.0),
2397 devdn: Some(1.0),
2398 fast_ma_type: Some("ema".to_string()),
2399 slow_ma_type: Some("ema".to_string()),
2400 },
2401 MabParams {
2402 fast_period: Some(9),
2403 slow_period: Some(10),
2404 devup: Some(1.0),
2405 devdn: Some(2.0),
2406 fast_ma_type: Some("sma".to_string()),
2407 slow_ma_type: Some("sma".to_string()),
2408 },
2409 MabParams {
2410 fast_period: Some(30),
2411 slow_period: Some(200),
2412 devup: Some(1.0),
2413 devdn: Some(1.0),
2414 fast_ma_type: Some("ema".to_string()),
2415 slow_ma_type: Some("ema".to_string()),
2416 },
2417 ];
2418
2419 for (param_idx, params) in test_params.iter().enumerate() {
2420 let input = MabInput::from_candles(&candles, "close", params.clone());
2421 let output = mab_with_kernel(&input, kernel)?;
2422
2423 for (i, &val) in output.upperband.iter().enumerate() {
2424 if val.is_nan() {
2425 continue;
2426 }
2427
2428 let bits = val.to_bits();
2429
2430 if bits == 0x11111111_11111111 {
2431 panic!(
2432 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in upperband \
2433 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2434 test_name, val, bits, i,
2435 params.fast_period.unwrap_or(10),
2436 params.slow_period.unwrap_or(50),
2437 params.devup.unwrap_or(1.0),
2438 params.devdn.unwrap_or(1.0),
2439 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2440 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2441 param_idx
2442 );
2443 }
2444
2445 if bits == 0x22222222_22222222 {
2446 panic!(
2447 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in upperband \
2448 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2449 test_name, val, bits, i,
2450 params.fast_period.unwrap_or(10),
2451 params.slow_period.unwrap_or(50),
2452 params.devup.unwrap_or(1.0),
2453 params.devdn.unwrap_or(1.0),
2454 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2455 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2456 param_idx
2457 );
2458 }
2459
2460 if bits == 0x33333333_33333333 {
2461 panic!(
2462 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in upperband \
2463 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2464 test_name, val, bits, i,
2465 params.fast_period.unwrap_or(10),
2466 params.slow_period.unwrap_or(50),
2467 params.devup.unwrap_or(1.0),
2468 params.devdn.unwrap_or(1.0),
2469 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2470 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2471 param_idx
2472 );
2473 }
2474 }
2475
2476 for (i, &val) in output.middleband.iter().enumerate() {
2477 if val.is_nan() {
2478 continue;
2479 }
2480
2481 let bits = val.to_bits();
2482
2483 if bits == 0x11111111_11111111 {
2484 panic!(
2485 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in middleband \
2486 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2487 test_name, val, bits, i,
2488 params.fast_period.unwrap_or(10),
2489 params.slow_period.unwrap_or(50),
2490 params.devup.unwrap_or(1.0),
2491 params.devdn.unwrap_or(1.0),
2492 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2493 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2494 param_idx
2495 );
2496 }
2497
2498 if bits == 0x22222222_22222222 {
2499 panic!(
2500 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in middleband \
2501 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2502 test_name, val, bits, i,
2503 params.fast_period.unwrap_or(10),
2504 params.slow_period.unwrap_or(50),
2505 params.devup.unwrap_or(1.0),
2506 params.devdn.unwrap_or(1.0),
2507 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2508 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2509 param_idx
2510 );
2511 }
2512
2513 if bits == 0x33333333_33333333 {
2514 panic!(
2515 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in middleband \
2516 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2517 test_name, val, bits, i,
2518 params.fast_period.unwrap_or(10),
2519 params.slow_period.unwrap_or(50),
2520 params.devup.unwrap_or(1.0),
2521 params.devdn.unwrap_or(1.0),
2522 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2523 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2524 param_idx
2525 );
2526 }
2527 }
2528
2529 for (i, &val) in output.lowerband.iter().enumerate() {
2530 if val.is_nan() {
2531 continue;
2532 }
2533
2534 let bits = val.to_bits();
2535
2536 if bits == 0x11111111_11111111 {
2537 panic!(
2538 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in lowerband \
2539 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2540 test_name, val, bits, i,
2541 params.fast_period.unwrap_or(10),
2542 params.slow_period.unwrap_or(50),
2543 params.devup.unwrap_or(1.0),
2544 params.devdn.unwrap_or(1.0),
2545 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2546 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2547 param_idx
2548 );
2549 }
2550
2551 if bits == 0x22222222_22222222 {
2552 panic!(
2553 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in lowerband \
2554 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2555 test_name, val, bits, i,
2556 params.fast_period.unwrap_or(10),
2557 params.slow_period.unwrap_or(50),
2558 params.devup.unwrap_or(1.0),
2559 params.devdn.unwrap_or(1.0),
2560 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2561 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2562 param_idx
2563 );
2564 }
2565
2566 if bits == 0x33333333_33333333 {
2567 panic!(
2568 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in lowerband \
2569 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2570 test_name, val, bits, i,
2571 params.fast_period.unwrap_or(10),
2572 params.slow_period.unwrap_or(50),
2573 params.devup.unwrap_or(1.0),
2574 params.devdn.unwrap_or(1.0),
2575 params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2576 params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2577 param_idx
2578 );
2579 }
2580 }
2581 }
2582
2583 Ok(())
2584 }
2585
2586 #[cfg(not(debug_assertions))]
2587 fn check_mab_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2588 Ok(())
2589 }
2590
2591 #[cfg(debug_assertions)]
2592 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2593 skip_if_unsupported!(kernel, test);
2594
2595 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2596 let c = read_candles_from_csv(file)?;
2597
2598 let test_configs = vec![
2599 ((2, 10, 2), (10, 20, 5), (1.0, 1.0, 0.0), (1.0, 1.0, 0.0)),
2600 ((5, 15, 5), (20, 40, 10), (0.5, 2.0, 0.5), (0.5, 2.0, 0.5)),
2601 (
2602 (20, 40, 10),
2603 (50, 100, 25),
2604 (1.0, 3.0, 1.0),
2605 (1.0, 3.0, 1.0),
2606 ),
2607 ((2, 5, 1), (6, 10, 1), (1.0, 2.0, 0.5), (1.0, 2.0, 0.5)),
2608 ((10, 10, 0), (20, 50, 10), (1.0, 1.0, 0.0), (1.0, 1.0, 0.0)),
2609 ((5, 20, 5), (50, 50, 0), (2.0, 2.0, 0.0), (2.0, 2.0, 0.0)),
2610 ((8, 12, 2), (26, 26, 0), (1.0, 3.0, 0.5), (0.5, 2.0, 0.5)),
2611 ];
2612
2613 for (cfg_idx, &(fast_range, slow_range, devup_range, devdn_range)) in
2614 test_configs.iter().enumerate()
2615 {
2616 let sweep = MabBatchRange {
2617 fast_period: fast_range,
2618 slow_period: slow_range,
2619 devup: devup_range,
2620 devdn: devdn_range,
2621 fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
2622 slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
2623 };
2624
2625 let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
2626
2627 for (idx, &val) in output.upperbands.iter().enumerate() {
2628 if val.is_nan() {
2629 continue;
2630 }
2631
2632 let bits = val.to_bits();
2633 let row = idx / output.cols;
2634 let col = idx % output.cols;
2635 let combo = &output.combos[row];
2636
2637 if bits == 0x11111111_11111111 {
2638 panic!(
2639 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in upperbands \
2640 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2641 test, cfg_idx, val, bits, row, col, idx,
2642 combo.fast_period.unwrap_or(10),
2643 combo.slow_period.unwrap_or(50),
2644 combo.devup.unwrap_or(1.0),
2645 combo.devdn.unwrap_or(1.0)
2646 );
2647 }
2648
2649 if bits == 0x22222222_22222222 {
2650 panic!(
2651 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in upperbands \
2652 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2653 test, cfg_idx, val, bits, row, col, idx,
2654 combo.fast_period.unwrap_or(10),
2655 combo.slow_period.unwrap_or(50),
2656 combo.devup.unwrap_or(1.0),
2657 combo.devdn.unwrap_or(1.0)
2658 );
2659 }
2660
2661 if bits == 0x33333333_33333333 {
2662 panic!(
2663 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in upperbands \
2664 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2665 test, cfg_idx, val, bits, row, col, idx,
2666 combo.fast_period.unwrap_or(10),
2667 combo.slow_period.unwrap_or(50),
2668 combo.devup.unwrap_or(1.0),
2669 combo.devdn.unwrap_or(1.0)
2670 );
2671 }
2672 }
2673
2674 for (idx, &val) in output.middlebands.iter().enumerate() {
2675 if val.is_nan() {
2676 continue;
2677 }
2678
2679 let bits = val.to_bits();
2680 let row = idx / output.cols;
2681 let col = idx % output.cols;
2682 let combo = &output.combos[row];
2683
2684 if bits == 0x11111111_11111111 {
2685 panic!(
2686 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in middlebands \
2687 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2688 test, cfg_idx, val, bits, row, col, idx,
2689 combo.fast_period.unwrap_or(10),
2690 combo.slow_period.unwrap_or(50),
2691 combo.devup.unwrap_or(1.0),
2692 combo.devdn.unwrap_or(1.0)
2693 );
2694 }
2695
2696 if bits == 0x22222222_22222222 {
2697 panic!(
2698 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in middlebands \
2699 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2700 test, cfg_idx, val, bits, row, col, idx,
2701 combo.fast_period.unwrap_or(10),
2702 combo.slow_period.unwrap_or(50),
2703 combo.devup.unwrap_or(1.0),
2704 combo.devdn.unwrap_or(1.0)
2705 );
2706 }
2707
2708 if bits == 0x33333333_33333333 {
2709 panic!(
2710 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in middlebands \
2711 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2712 test, cfg_idx, val, bits, row, col, idx,
2713 combo.fast_period.unwrap_or(10),
2714 combo.slow_period.unwrap_or(50),
2715 combo.devup.unwrap_or(1.0),
2716 combo.devdn.unwrap_or(1.0)
2717 );
2718 }
2719 }
2720
2721 for (idx, &val) in output.lowerbands.iter().enumerate() {
2722 if val.is_nan() {
2723 continue;
2724 }
2725
2726 let bits = val.to_bits();
2727 let row = idx / output.cols;
2728 let col = idx % output.cols;
2729 let combo = &output.combos[row];
2730
2731 if bits == 0x11111111_11111111 {
2732 panic!(
2733 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in lowerbands \
2734 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2735 test, cfg_idx, val, bits, row, col, idx,
2736 combo.fast_period.unwrap_or(10),
2737 combo.slow_period.unwrap_or(50),
2738 combo.devup.unwrap_or(1.0),
2739 combo.devdn.unwrap_or(1.0)
2740 );
2741 }
2742
2743 if bits == 0x22222222_22222222 {
2744 panic!(
2745 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in lowerbands \
2746 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2747 test, cfg_idx, val, bits, row, col, idx,
2748 combo.fast_period.unwrap_or(10),
2749 combo.slow_period.unwrap_or(50),
2750 combo.devup.unwrap_or(1.0),
2751 combo.devdn.unwrap_or(1.0)
2752 );
2753 }
2754
2755 if bits == 0x33333333_33333333 {
2756 panic!(
2757 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in lowerbands \
2758 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2759 test, cfg_idx, val, bits, row, col, idx,
2760 combo.fast_period.unwrap_or(10),
2761 combo.slow_period.unwrap_or(50),
2762 combo.devup.unwrap_or(1.0),
2763 combo.devdn.unwrap_or(1.0)
2764 );
2765 }
2766 }
2767 }
2768
2769 Ok(())
2770 }
2771
2772 #[cfg(not(debug_assertions))]
2773 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2774 Ok(())
2775 }
2776
2777 #[cfg(feature = "proptest")]
2778 #[allow(clippy::float_cmp)]
2779 fn check_mab_property(
2780 test_name: &str,
2781 kernel: Kernel,
2782 ) -> Result<(), Box<dyn std::error::Error>> {
2783 use proptest::prelude::*;
2784 skip_if_unsupported!(kernel, test_name);
2785
2786 let strat = (2usize..=50).prop_flat_map(|slow_period| {
2787 (2usize..=slow_period).prop_flat_map(move |fast_period| {
2788 (
2789 prop::collection::vec(
2790 (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
2791 slow_period..400,
2792 ),
2793 Just(fast_period),
2794 Just(slow_period),
2795 0.5f64..3.0f64,
2796 0.5f64..3.0f64,
2797 prop::bool::ANY,
2798 prop::bool::ANY,
2799 )
2800 })
2801 });
2802
2803 proptest::test_runner::TestRunner::default()
2804 .run(&strat, |(data, fast_period, slow_period, devup, devdn, fast_is_ema, slow_is_ema)| {
2805 let params = MabParams {
2806 fast_period: Some(fast_period),
2807 slow_period: Some(slow_period),
2808 devup: Some(devup),
2809 devdn: Some(devdn),
2810 fast_ma_type: Some(if fast_is_ema { "ema" } else { "sma" }.to_string()),
2811 slow_ma_type: Some(if slow_is_ema { "ema" } else { "sma" }.to_string()),
2812 };
2813 let input = MabInput::from_slice(&data, params.clone());
2814
2815
2816 let result = mab_with_kernel(&input, kernel).unwrap();
2817
2818
2819 let ref_params = params.clone();
2820 let ref_input = MabInput::from_slice(&data, ref_params);
2821 let ref_result = mab_with_kernel(&ref_input, Kernel::Scalar).unwrap();
2822
2823
2824 let first_valid_idx = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2825 let warmup_period = first_valid_idx + fast_period.max(slow_period) - 1;
2826
2827 for i in 0..data.len() {
2828 let upper = result.upperband[i];
2829 let middle = result.middleband[i];
2830 let lower = result.lowerband[i];
2831 let ref_upper = ref_result.upperband[i];
2832 let ref_middle = ref_result.middleband[i];
2833 let ref_lower = ref_result.lowerband[i];
2834
2835
2836 if upper.is_nan() {
2837 prop_assert!(ref_upper.is_nan(),
2838 "[{}] NaN mismatch in upperband at idx {}: kernel={:?} has NaN but scalar doesn't",
2839 test_name, i, kernel);
2840 }
2841 if middle.is_nan() {
2842 prop_assert!(ref_middle.is_nan(),
2843 "[{}] NaN mismatch in middleband at idx {}: kernel={:?} has NaN but scalar doesn't",
2844 test_name, i, kernel);
2845 }
2846 if lower.is_nan() {
2847 prop_assert!(ref_lower.is_nan(),
2848 "[{}] NaN mismatch in lowerband at idx {}: kernel={:?} has NaN but scalar doesn't",
2849 test_name, i, kernel);
2850 }
2851
2852
2853 if upper.is_finite() && ref_upper.is_finite() {
2854 let ulp_diff = upper.to_bits().abs_diff(ref_upper.to_bits());
2855 prop_assert!(
2856 (upper - ref_upper).abs() <= 1e-9 || ulp_diff <= 8,
2857 "[{}] Upperband mismatch at idx {}: {} vs {} (ULP={})",
2858 test_name, i, upper, ref_upper, ulp_diff
2859 );
2860 }
2861 if middle.is_finite() && ref_middle.is_finite() {
2862 let ulp_diff = middle.to_bits().abs_diff(ref_middle.to_bits());
2863 prop_assert!(
2864 (middle - ref_middle).abs() <= 1e-9 || ulp_diff <= 8,
2865 "[{}] Middleband mismatch at idx {}: {} vs {} (ULP={})",
2866 test_name, i, middle, ref_middle, ulp_diff
2867 );
2868 }
2869 if lower.is_finite() && ref_lower.is_finite() {
2870 let ulp_diff = lower.to_bits().abs_diff(ref_lower.to_bits());
2871 prop_assert!(
2872 (lower - ref_lower).abs() <= 1e-9 || ulp_diff <= 8,
2873 "[{}] Lowerband mismatch at idx {}: {} vs {} (ULP={})",
2874 test_name, i, lower, ref_lower, ulp_diff
2875 );
2876 }
2877 }
2878
2879
2880 for i in 0..warmup_period.min(data.len()) {
2881 prop_assert!(
2882 result.upperband[i].is_nan(),
2883 "[{}] Expected NaN in upperband during warmup at idx {} (warmup={})",
2884 test_name, i, warmup_period
2885 );
2886 prop_assert!(
2887 result.middleband[i].is_nan(),
2888 "[{}] Expected NaN in middleband during warmup at idx {} (warmup={})",
2889 test_name, i, warmup_period
2890 );
2891 prop_assert!(
2892 result.lowerband[i].is_nan(),
2893 "[{}] Expected NaN in lowerband during warmup at idx {} (warmup={})",
2894 test_name, i, warmup_period
2895 );
2896 }
2897
2898
2899
2900 let first_valid_output = warmup_period + fast_period - 1;
2901 if first_valid_output < data.len() {
2902 for i in first_valid_output..data.len() {
2903 prop_assert!(
2904 result.upperband[i].is_finite(),
2905 "[{}] Non-finite value in upperband at idx {} after warmup",
2906 test_name, i
2907 );
2908 prop_assert!(
2909 result.middleband[i].is_finite(),
2910 "[{}] Non-finite value in middleband at idx {} after warmup",
2911 test_name, i
2912 );
2913 prop_assert!(
2914 result.lowerband[i].is_finite(),
2915 "[{}] Non-finite value in lowerband at idx {} after warmup",
2916 test_name, i
2917 );
2918 }
2919 }
2920
2921
2922 for i in first_valid_output..data.len() {
2923 let upper = result.upperband[i];
2924 let middle = result.middleband[i];
2925 let lower = result.lowerband[i];
2926
2927 if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2928 prop_assert!(
2929 upper >= middle - 1e-10,
2930 "[{}] Band ordering violated: upper {} < middle {} at idx {}",
2931 test_name, upper, middle, i
2932 );
2933 prop_assert!(
2934 middle >= lower - 1e-10,
2935 "[{}] Band ordering violated: middle {} < lower {} at idx {}",
2936 test_name, middle, lower, i
2937 );
2938 }
2939 }
2940
2941
2942 if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > first_valid_output {
2943
2944 for i in first_valid_output..data.len() {
2945 let upper = result.upperband[i];
2946 let middle = result.middleband[i];
2947 let lower = result.lowerband[i];
2948
2949 if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2950 prop_assert!(
2951 (upper - middle).abs() <= 1e-9,
2952 "[{}] Constant data: upper {} != middle {} at idx {}",
2953 test_name, upper, middle, i
2954 );
2955 prop_assert!(
2956 (middle - lower).abs() <= 1e-9,
2957 "[{}] Constant data: middle {} != lower {} at idx {}",
2958 test_name, middle, lower, i
2959 );
2960 }
2961 }
2962 }
2963
2964
2965
2966 for i in first_valid_output..data.len() {
2967 let upper = result.upperband[i];
2968 let middle = result.middleband[i];
2969 let lower = result.lowerband[i];
2970
2971 if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2972 let upper_spread = upper - middle;
2973 let lower_spread = middle - lower;
2974
2975
2976 if upper_spread > 1e-10 && lower_spread > 1e-10 {
2977 let spread_ratio = upper_spread / lower_spread;
2978 let multiplier_ratio = devup / devdn;
2979
2980 prop_assert!(
2981 (spread_ratio - multiplier_ratio).abs() <= multiplier_ratio * 0.05,
2982 "[{}] Deviation multiplier ratio mismatch at idx {}: spread_ratio={} vs multiplier_ratio={}",
2983 test_name, i, spread_ratio, multiplier_ratio
2984 );
2985 }
2986 }
2987 }
2988
2989
2990
2991 use crate::indicators::sma::{sma, SmaInput, SmaParams};
2992 use crate::indicators::ema::{ema, EmaInput, EmaParams};
2993
2994 let fast_ma = if fast_is_ema {
2995 let ema_params = EmaParams { period: Some(fast_period) };
2996 let ema_input = EmaInput::from_slice(&data, ema_params);
2997 ema(&ema_input).unwrap().values
2998 } else {
2999 let sma_params = SmaParams { period: Some(fast_period) };
3000 let sma_input = SmaInput::from_slice(&data, sma_params);
3001 sma(&sma_input).unwrap().values
3002 };
3003
3004 for i in first_valid_output..data.len() {
3005 if result.middleband[i].is_finite() && fast_ma[i].is_finite() {
3006 prop_assert!(
3007 (result.middleband[i] - fast_ma[i]).abs() <= 1e-9,
3008 "[{}] Middle band != fast MA at idx {}: {} vs {}",
3009 test_name, i, result.middleband[i], fast_ma[i]
3010 );
3011 }
3012 }
3013
3014
3015
3016 for i in first_valid_output..data.len() {
3017 let upper = result.upperband[i];
3018 let middle = result.middleband[i];
3019 let lower = result.lowerband[i];
3020
3021 if upper.is_finite() && middle.is_finite() && lower.is_finite() {
3022 let upper_spread = upper - middle;
3023 let lower_spread = middle - lower;
3024
3025 prop_assert!(
3026 upper_spread >= -1e-10,
3027 "[{}] Negative upper spread at idx {}: {}",
3028 test_name, i, upper_spread
3029 );
3030 prop_assert!(
3031 lower_spread >= -1e-10,
3032 "[{}] Negative lower spread at idx {}: {}",
3033 test_name, i, lower_spread
3034 );
3035
3036
3037 let data_range = data.iter()
3038 .filter(|x| x.is_finite())
3039 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
3040 (min.min(x), max.max(x))
3041 });
3042 let range_span = data_range.1 - data_range.0;
3043
3044
3045
3046 if range_span > 0.0 {
3047 prop_assert!(
3048 upper_spread <= range_span * devup * 10.0,
3049 "[{}] Upper spread unreasonably large at idx {}: {} (range_span={})",
3050 test_name, i, upper_spread, range_span
3051 );
3052 prop_assert!(
3053 lower_spread <= range_span * devdn * 10.0,
3054 "[{}] Lower spread unreasonably large at idx {}: {} (range_span={})",
3055 test_name, i, lower_spread, range_span
3056 );
3057 }
3058 }
3059 }
3060
3061
3062
3063 let invalid_params = vec![
3064 MabParams {
3065 fast_period: Some(0),
3066 slow_period: Some(10),
3067 ..Default::default()
3068 },
3069 MabParams {
3070 fast_period: Some(10),
3071 slow_period: Some(0),
3072 ..Default::default()
3073 },
3074 MabParams {
3075 fast_period: Some(data.len() + 1),
3076 slow_period: Some(10),
3077 ..Default::default()
3078 },
3079 ];
3080
3081 for invalid_param in invalid_params {
3082 let invalid_input = MabInput::from_slice(&data, invalid_param);
3083 let invalid_result = mab_with_kernel(&invalid_input, kernel);
3084 prop_assert!(
3085 invalid_result.is_err(),
3086 "[{}] Expected error for invalid parameters but got Ok",
3087 test_name
3088 );
3089 }
3090
3091
3092
3093
3094 if fast_period == slow_period && data.len() > first_valid_output {
3095
3096 let equal_params = MabParams {
3097 fast_period: Some(fast_period),
3098 slow_period: Some(fast_period),
3099 devup: Some(devup),
3100 devdn: Some(devdn),
3101 fast_ma_type: params.fast_ma_type.clone(),
3102 slow_ma_type: params.slow_ma_type.clone(),
3103 };
3104 let equal_input = MabInput::from_slice(&data, equal_params);
3105 let equal_result = mab_with_kernel(&equal_input, kernel).unwrap();
3106
3107
3108
3109 if params.fast_ma_type == params.slow_ma_type {
3110 for i in first_valid_output..data.len().min(first_valid_output + 10) {
3111 if equal_result.upperband[i].is_finite() &&
3112 equal_result.middleband[i].is_finite() &&
3113 equal_result.lowerband[i].is_finite() {
3114 let upper_spread = equal_result.upperband[i] - equal_result.middleband[i];
3115 let lower_spread = equal_result.middleband[i] - equal_result.lowerband[i];
3116
3117
3118 prop_assert!(
3119 upper_spread <= 1e-6 || upper_spread <= equal_result.middleband[i].abs() * 1e-6,
3120 "[{}] Equal periods: upper spread too large at idx {}: {}",
3121 test_name, i, upper_spread
3122 );
3123 prop_assert!(
3124 lower_spread <= 1e-6 || lower_spread <= equal_result.middleband[i].abs() * 1e-6,
3125 "[{}] Equal periods: lower spread too large at idx {}: {}",
3126 test_name, i, lower_spread
3127 );
3128 }
3129 }
3130 }
3131 }
3132
3133 Ok(())
3134 })?;
3135
3136 Ok(())
3137 }
3138
3139 fn check_mab_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3140 skip_if_unsupported!(kernel, test_name);
3141 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3142 let candles = read_candles_from_csv(file_path)?;
3143 let default_params = MabParams {
3144 fast_period: None,
3145 ..MabParams::default()
3146 };
3147 let input = MabInput::from_candles(&candles, "close", default_params);
3148 let output = mab_with_kernel(&input, kernel)?;
3149 assert_eq!(output.upperband.len(), candles.close.len());
3150 Ok(())
3151 }
3152
3153 fn check_mab_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3154 skip_if_unsupported!(kernel, test_name);
3155 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3156 let candles = read_candles_from_csv(file_path)?;
3157 let params = MabParams::default();
3158 let input = MabInput::from_candles(&candles, "close", params);
3159 let result = mab_with_kernel(&input, kernel)?;
3160
3161 let expected_upper_last_five = [
3162 64002.843463352016,
3163 63976.62699738246,
3164 63949.00496307154,
3165 63912.13708526151,
3166 63828.40371728143,
3167 ];
3168 let expected_middle_last_five = [
3169 59213.90000000002,
3170 59180.800000000025,
3171 59161.40000000002,
3172 59132.00000000002,
3173 59042.40000000002,
3174 ];
3175 let expected_lower_last_five = [
3176 59350.676536647945,
3177 59296.93300261751,
3178 59252.75503692843,
3179 59190.30291473845,
3180 59070.11628271853,
3181 ];
3182
3183 let len = result.upperband.len();
3184 for i in 0..5 {
3185 let idx = len - 5 + i;
3186 assert!(
3187 (result.upperband[idx] - expected_upper_last_five[i]).abs() < 1e-4,
3188 "[{}] Upper band mismatch at index {}: {} vs expected {}",
3189 test_name,
3190 i,
3191 result.upperband[idx],
3192 expected_upper_last_five[i]
3193 );
3194 assert!(
3195 (result.middleband[idx] - expected_middle_last_five[i]).abs() < 1e-4,
3196 "[{}] Middle band mismatch at index {}: {} vs expected {}",
3197 test_name,
3198 i,
3199 result.middleband[idx],
3200 expected_middle_last_five[i]
3201 );
3202 assert!(
3203 (result.lowerband[idx] - expected_lower_last_five[i]).abs() < 1e-4,
3204 "[{}] Lower band mismatch at index {}: {} vs expected {}",
3205 test_name,
3206 i,
3207 result.lowerband[idx],
3208 expected_lower_last_five[i]
3209 );
3210 }
3211 Ok(())
3212 }
3213
3214 fn check_mab_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3215 skip_if_unsupported!(kernel, test_name);
3216 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3217 let candles = read_candles_from_csv(file_path)?;
3218 let input = MabInput::with_default_candles(&candles);
3219 let output = mab_with_kernel(&input, kernel)?;
3220 assert_eq!(output.upperband.len(), candles.close.len());
3221 Ok(())
3222 }
3223
3224 fn check_mab_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3225 skip_if_unsupported!(kernel, test_name);
3226 let input_data = [10.0, 20.0, 30.0];
3227 let params = MabParams {
3228 fast_period: Some(0),
3229 slow_period: Some(5),
3230 ..MabParams::default()
3231 };
3232 let input = MabInput::from_slice(&input_data, params);
3233 let res = mab_with_kernel(&input, kernel);
3234 assert!(
3235 res.is_err(),
3236 "[{}] Expected error for zero fast period",
3237 test_name
3238 );
3239
3240 let params2 = MabParams {
3241 fast_period: Some(5),
3242 slow_period: Some(0),
3243 ..MabParams::default()
3244 };
3245 let input2 = MabInput::from_slice(&input_data, params2);
3246 let res2 = mab_with_kernel(&input2, kernel);
3247 assert!(
3248 res2.is_err(),
3249 "[{}] Expected error for zero slow period",
3250 test_name
3251 );
3252 Ok(())
3253 }
3254
3255 fn check_mab_period_exceeds_length(
3256 test_name: &str,
3257 kernel: Kernel,
3258 ) -> Result<(), Box<dyn Error>> {
3259 skip_if_unsupported!(kernel, test_name);
3260 let data_small = [10.0, 20.0, 30.0];
3261 let params = MabParams {
3262 fast_period: Some(2),
3263 slow_period: Some(10),
3264 ..MabParams::default()
3265 };
3266 let input = MabInput::from_slice(&data_small, params);
3267 let res = mab_with_kernel(&input, kernel);
3268 assert!(
3269 res.is_err(),
3270 "[{}] Expected error when period exceeds data length",
3271 test_name
3272 );
3273 Ok(())
3274 }
3275
3276 fn check_mab_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3277 skip_if_unsupported!(kernel, test_name);
3278 let single_point = [42.0];
3279 let params = MabParams {
3280 fast_period: Some(10),
3281 slow_period: Some(20),
3282 ..MabParams::default()
3283 };
3284 let input = MabInput::from_slice(&single_point, params);
3285 let res = mab_with_kernel(&input, kernel);
3286 assert!(
3287 res.is_err(),
3288 "[{}] Expected error for insufficient data",
3289 test_name
3290 );
3291 Ok(())
3292 }
3293
3294 fn check_mab_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3295 skip_if_unsupported!(kernel, test_name);
3296 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3297 let candles = read_candles_from_csv(file_path)?;
3298 let params = MabParams::default();
3299 let first_input = MabInput::from_candles(&candles, "close", params.clone());
3300 let first_result = mab_with_kernel(&first_input, kernel)?;
3301
3302 let second_input = MabInput::from_slice(&first_result.upperband, params);
3303 let second_result = mab_with_kernel(&second_input, kernel)?;
3304 assert_eq!(second_result.upperband.len(), first_result.upperband.len());
3305
3306 let non_nan_count = second_result
3307 .upperband
3308 .iter()
3309 .skip(100)
3310 .filter(|x| !x.is_nan())
3311 .count();
3312 assert!(
3313 non_nan_count > 0,
3314 "[{}] Second calculation produced all NaN values",
3315 test_name
3316 );
3317 Ok(())
3318 }
3319
3320 fn check_mab_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3321 skip_if_unsupported!(kernel, test_name);
3322 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3323 let candles = read_candles_from_csv(file_path)?;
3324 let input = MabInput::from_candles(&candles, "close", MabParams::default());
3325 let res = mab_with_kernel(&input, kernel)?;
3326
3327 for i in 100..res.upperband.len().min(200) {
3328 assert!(
3329 !res.upperband[i].is_nan(),
3330 "[{}] Unexpected NaN in upper band at index {}",
3331 test_name,
3332 i
3333 );
3334 assert!(
3335 !res.middleband[i].is_nan(),
3336 "[{}] Unexpected NaN in middle band at index {}",
3337 test_name,
3338 i
3339 );
3340 assert!(
3341 !res.lowerband[i].is_nan(),
3342 "[{}] Unexpected NaN in lower band at index {}",
3343 test_name,
3344 i
3345 );
3346 }
3347 Ok(())
3348 }
3349
3350 #[allow(dead_code)]
3351 fn check_mab_streaming(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3352 Ok(())
3353 }
3354
3355 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3356 skip_if_unsupported!(kernel, test);
3357 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3358 let c = read_candles_from_csv(file)?;
3359
3360 let sweep = MabBatchRange {
3361 fast_period: (10, 10, 0),
3362 slow_period: (50, 50, 0),
3363 devup: (1.0, 1.0, 0.0),
3364 devdn: (1.0, 1.0, 0.0),
3365 fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3366 slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3367 };
3368
3369 let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
3370
3371 assert_eq!(
3372 output.rows, 1,
3373 "[{}] Expected 1 row for default params",
3374 test
3375 );
3376 assert_eq!(
3377 output.cols,
3378 c.close.len(),
3379 "[{}] Cols should match input length",
3380 test
3381 );
3382
3383 let expected_upper = [
3384 64002.843463352016,
3385 63976.62699738246,
3386 63949.00496307154,
3387 63912.13708526151,
3388 63828.40371728143,
3389 ];
3390 let expected_middle = [
3391 59213.90000000002,
3392 59180.800000000025,
3393 59161.40000000002,
3394 59132.00000000002,
3395 59042.40000000002,
3396 ];
3397 let expected_lower = [
3398 59350.676536647945,
3399 59296.93300261751,
3400 59252.75503692843,
3401 59190.30291473845,
3402 59070.11628271853,
3403 ];
3404
3405 let start = output.cols - 5;
3406 for i in 0..5 {
3407 let idx = start + i;
3408 assert!(
3409 (output.upperbands[idx] - expected_upper[i]).abs() < 1e-4,
3410 "[{}] batch upper mismatch at idx {}: {} vs expected {}",
3411 test,
3412 i,
3413 output.upperbands[idx],
3414 expected_upper[i]
3415 );
3416 assert!(
3417 (output.middlebands[idx] - expected_middle[i]).abs() < 1e-4,
3418 "[{}] batch middle mismatch at idx {}: {} vs expected {}",
3419 test,
3420 i,
3421 output.middlebands[idx],
3422 expected_middle[i]
3423 );
3424 assert!(
3425 (output.lowerbands[idx] - expected_lower[i]).abs() < 1e-4,
3426 "[{}] batch lower mismatch at idx {}: {} vs expected {}",
3427 test,
3428 i,
3429 output.lowerbands[idx],
3430 expected_lower[i]
3431 );
3432 }
3433 Ok(())
3434 }
3435
3436 fn check_batch_grid_varying_fast_period(
3437 test: &str,
3438 kernel: Kernel,
3439 ) -> Result<(), Box<dyn Error>> {
3440 skip_if_unsupported!(kernel, test);
3441 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3442 let c = read_candles_from_csv(file)?;
3443
3444 let sweep = MabBatchRange {
3445 fast_period: (10, 12, 1),
3446 slow_period: (50, 50, 0),
3447 devup: (1.0, 1.0, 0.0),
3448 devdn: (1.0, 1.0, 0.0),
3449 fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3450 slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3451 };
3452
3453 let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
3454
3455 assert_eq!(
3456 output.rows, 3,
3457 "[{}] Expected 3 rows for fast period 10-12",
3458 test
3459 );
3460 assert_eq!(
3461 output.combos.len(),
3462 3,
3463 "[{}] Expected 3 parameter combinations",
3464 test
3465 );
3466
3467 assert_eq!(
3468 output.combos[0].fast_period,
3469 Some(10),
3470 "[{}] First combo fast period",
3471 test
3472 );
3473 assert_eq!(
3474 output.combos[1].fast_period,
3475 Some(11),
3476 "[{}] Second combo fast period",
3477 test
3478 );
3479 assert_eq!(
3480 output.combos[2].fast_period,
3481 Some(12),
3482 "[{}] Third combo fast period",
3483 test
3484 );
3485
3486 for row in 0..3 {
3487 let row_start = row * output.cols;
3488 let row_data = &output.upperbands[row_start..row_start + output.cols];
3489
3490 let valid_count = row_data.iter().skip(100).filter(|x| !x.is_nan()).count();
3491 assert!(
3492 valid_count > 0,
3493 "[{}] Row {} should have valid values",
3494 test,
3495 row
3496 );
3497 }
3498 Ok(())
3499 }
3500
3501 macro_rules! generate_all_mab_tests {
3502 ($($test_fn:ident),*) => {
3503 paste::paste! {
3504 $(
3505 #[test]
3506 fn [<$test_fn _scalar_f64>]() {
3507 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3508 }
3509 )*
3510 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3511 $(
3512 #[test]
3513 fn [<$test_fn _avx2_f64>]() {
3514 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3515 }
3516 #[test]
3517 fn [<$test_fn _avx512_f64>]() {
3518 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3519 }
3520 )*
3521 }
3522 }
3523 }
3524
3525 macro_rules! gen_batch_tests {
3526 ($fn_name:ident) => {
3527 paste::paste! {
3528 #[test] fn [<$fn_name _scalar>]() {
3529 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3530 }
3531 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3532 #[test] fn [<$fn_name _avx2>]() {
3533 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3534 }
3535 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3536 #[test] fn [<$fn_name _avx512>]() {
3537 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3538 }
3539 #[test] fn [<$fn_name _auto_detect>]() {
3540 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3541 }
3542 }
3543 };
3544 }
3545
3546 generate_all_mab_tests!(
3547 check_mab_no_poison,
3548 check_mab_partial_params,
3549 check_mab_accuracy,
3550 check_mab_default_candles,
3551 check_mab_zero_period,
3552 check_mab_period_exceeds_length,
3553 check_mab_very_small_dataset,
3554 check_mab_reinput,
3555 check_mab_nan_handling
3556 );
3557
3558 #[cfg(feature = "proptest")]
3559 generate_all_mab_tests!(check_mab_property);
3560
3561 gen_batch_tests!(check_batch_no_poison);
3562 gen_batch_tests!(check_batch_default_row);
3563 gen_batch_tests!(check_batch_grid_varying_fast_period);
3564}