Skip to main content

vector_ta/indicators/
mab.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28use std::convert::AsRef;
29use std::error::Error;
30use std::mem::MaybeUninit;
31use thiserror::Error;
32
33impl<'a> AsRef<[f64]> for MabInput<'a> {
34    #[inline(always)]
35    fn as_ref(&self) -> &[f64] {
36        match &self.data {
37            MabData::Slice(slice) => slice,
38            MabData::Candles { candles, source } => source_type(candles, source),
39        }
40    }
41}
42
/// Where the MAB input series comes from.
#[derive(Debug, Clone)]
pub enum MabData<'a> {
    /// A column of a candle set. The column is picked by name through `source_type`
    /// (`MabInput::with_default_candles` uses `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw slice of values, used as-is.
    Slice(&'a [f64]),
}
51
/// Result of [`mab`]: three series, each allocated with the input's length.
/// Positions inside the warmup period hold `NaN`.
#[derive(Debug, Clone)]
pub struct MabOutput {
    /// `slow_ma + devup * dev`.
    pub upperband: Vec<f64>,
    /// The fast moving average itself.
    pub middleband: Vec<f64>,
    /// `slow_ma - devdn * dev`.
    pub lowerband: Vec<f64>,
}
58
/// Tunable MAB parameters.
///
/// `None` fields fall back to the defaults applied by the `MabInput::get_*` accessors:
/// 10, 50, 1.0, 1.0, `"sma"`, `"sma"`.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MabParams {
    /// Period of the fast (middle-band) average. It is also the length of the
    /// deviation window.
    pub fast_period: Option<usize>,
    /// Period of the slow average that the upper and lower bands are centred on.
    pub slow_period: Option<usize>,
    /// Deviation multiplier for the upper band.
    pub devup: Option<f64>,
    /// Deviation multiplier for the lower band.
    pub devdn: Option<f64>,
    /// Fast-average type: exactly `"ema"` selects an EMA; anything else uses an SMA.
    pub fast_ma_type: Option<String>,
    /// Slow-average type: exactly `"ema"` selects an EMA; anything else uses an SMA.
    pub slow_ma_type: Option<String>,
}
72
73impl Default for MabParams {
74    fn default() -> Self {
75        Self {
76            fast_period: Some(10),
77            slow_period: Some(50),
78            devup: Some(1.0),
79            devdn: Some(1.0),
80            fast_ma_type: Some("sma".to_string()),
81            slow_ma_type: Some("sma".to_string()),
82        }
83    }
84}
85
/// A MAB request: the data source plus the parameters to run it with.
#[derive(Debug, Clone)]
pub struct MabInput<'a> {
    /// Candle column or raw slice to compute over.
    pub data: MabData<'a>,
    /// Parameters; unset fields are defaulted by the `get_*` accessors.
    pub params: MabParams,
}
91
92impl<'a> MabInput<'a> {
93    pub fn from_candles(candles: &'a Candles, source: &'a str, params: MabParams) -> Self {
94        Self {
95            data: MabData::Candles { candles, source },
96            params,
97        }
98    }
99
100    pub fn from_slice(slice: &'a [f64], params: MabParams) -> Self {
101        Self {
102            data: MabData::Slice(slice),
103            params,
104        }
105    }
106
107    pub fn with_default_params(data: MabData<'a>) -> Self {
108        Self {
109            data,
110            params: MabParams::default(),
111        }
112    }
113
114    pub fn with_default_candles(candles: &'a Candles) -> Self {
115        Self::from_candles(candles, "close", MabParams::default())
116    }
117
118    pub fn get_fast_period(&self) -> usize {
119        self.params.fast_period.unwrap_or(10)
120    }
121
122    pub fn get_slow_period(&self) -> usize {
123        self.params.slow_period.unwrap_or(50)
124    }
125
126    pub fn get_devup(&self) -> f64 {
127        self.params.devup.unwrap_or(1.0)
128    }
129
130    pub fn get_devdn(&self) -> f64 {
131        self.params.devdn.unwrap_or(1.0)
132    }
133
134    pub fn get_fast_ma_type(&self) -> &str {
135        self.params
136            .fast_ma_type
137            .as_ref()
138            .map(|s| s.as_str())
139            .unwrap_or("sma")
140    }
141
142    pub fn get_slow_ma_type(&self) -> &str {
143        self.params
144            .slow_ma_type
145            .as_ref()
146            .map(|s| s.as_str())
147            .unwrap_or("sma")
148    }
149}
150
/// Errors reported by the MAB entry points.
#[derive(Error, Debug)]
pub enum MabError {
    /// The input series has no elements.
    #[error("mab: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is `NaN`, so there is no first valid sample.
    #[error("mab: All values are NaN.")]
    AllValuesNaN,
    /// A period is zero or longer than the input.
    #[error("mab: Invalid period: fast={fast} slow={slow} len={data_len}")]
    InvalidPeriod {
        fast: usize,
        slow: usize,
        data_len: usize,
    },
    /// Too few samples after the first non-NaN value. This is also used when the
    /// underlying SMA/EMA computation fails.
    #[error("mab: Not enough valid data: need={needed} valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A destination buffer passed to `mab_into_slice` does not match the input length.
    #[error(
        "mab: Output length mismatch: expected={expected} upper={upper_len} middle={middle_len} lower={lower_len}"
    )]
    OutputLengthMismatch {
        upper_len: usize,
        middle_len: usize,
        lower_len: usize,
        expected: usize,
    },
    /// Invalid integer sweep range. Presumably raised by batch range expansion,
    /// which is not visible in this section.
    #[error("mab: Invalid range (start={start}, end={end}, step={step})")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// Invalid floating-point sweep range. Presumably raised by batch range expansion,
    /// which is not visible in this section.
    #[error("mab: Invalid range (f64) (start={start}, end={end}, step={step})")]
    InvalidRangeF64 { start: f64, end: f64, step: f64 },
    /// A non-batch kernel was given to a batch entry point (not raised in this section).
    #[error("mab: non-batch kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
185
186#[inline(always)]
187fn mab_validate(input: &MabInput) -> Result<usize, MabError> {
188    let data = input.as_ref();
189    if data.is_empty() {
190        return Err(MabError::EmptyInputData);
191    }
192    let fast = input.get_fast_period();
193    let slow = input.get_slow_period();
194    if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
195        return Err(MabError::InvalidPeriod {
196            fast,
197            slow,
198            data_len: data.len(),
199        });
200    }
201
202    let first_valid = data
203        .iter()
204        .position(|&x| !x.is_nan())
205        .ok_or(MabError::AllValuesNaN)?;
206    let max_period = fast.max(slow);
207    if data.len() - first_valid < max_period {
208        return Err(MabError::NotEnoughValidData {
209            needed: max_period,
210            valid: data.len() - first_valid,
211        });
212    }
213    Ok(first_valid)
214}
215
216#[inline(always)]
217fn mab_prepare<'a>(
218    input: &'a MabInput,
219    kernel: Kernel,
220) -> Result<(&'a [f64], usize, Kernel, usize, usize, f64, f64), MabError> {
221    let first_valid = mab_validate(input)?;
222    let data = input.as_ref();
223    let fast = input.get_fast_period();
224    let slow = input.get_slow_period();
225    let devup = input.get_devup();
226    let devdn = input.get_devdn();
227    let chosen = match kernel {
228        Kernel::Auto => Kernel::Scalar,
229        _ => kernel,
230    };
231    let warmup = first_valid + fast.max(slow) - 1;
232    Ok((data, warmup, chosen, fast, slow, devup, devdn))
233}
234
235#[inline(always)]
236fn mab_prepare2<'a>(
237    input: &'a MabInput,
238    kernel: Kernel,
239) -> Result<(&'a [f64], Kernel, usize, usize, usize, f64, f64), MabError> {
240    let data = input.as_ref();
241    if data.is_empty() {
242        return Err(MabError::EmptyInputData);
243    }
244    let fast = input.get_fast_period();
245    let slow = input.get_slow_period();
246    if fast == 0 || slow == 0 || fast > data.len() || slow > data.len() {
247        return Err(MabError::InvalidPeriod {
248            fast,
249            slow,
250            data_len: data.len(),
251        });
252    }
253    let first = data
254        .iter()
255        .position(|&x| !x.is_nan())
256        .ok_or(MabError::AllValuesNaN)?;
257    let need = fast.max(slow);
258
259    let need_total = need + fast - 1;
260    if data.len() - first < need_total {
261        return Err(MabError::NotEnoughValidData {
262            needed: need_total,
263            valid: data.len() - first,
264        });
265    }
266    let chosen = match kernel {
267        Kernel::Auto => Kernel::Scalar,
268        k => k,
269    };
270    let warmup = first + need_total - 1;
271    Ok((
272        data,
273        chosen,
274        first,
275        warmup,
276        fast,
277        input.get_devup(),
278        input.get_devdn(),
279    ))
280}
281
282pub fn mab_into_slice(
283    upper_dst: &mut [f64],
284    middle_dst: &mut [f64],
285    lower_dst: &mut [f64],
286    input: &MabInput,
287    kern: Kernel,
288) -> Result<(), MabError> {
289    use crate::indicators::ema::{ema, EmaInput, EmaParams};
290    use crate::indicators::sma::{sma, SmaInput, SmaParams};
291
292    let (data, chosen, first, warmup, fast_period, devup, devdn) = mab_prepare2(input, kern)?;
293    let slow_period = input.get_slow_period();
294
295    let n = data.len();
296    if upper_dst.len() != n || middle_dst.len() != n || lower_dst.len() != n {
297        return Err(MabError::OutputLengthMismatch {
298            upper_len: upper_dst.len(),
299            middle_len: middle_dst.len(),
300            lower_len: lower_dst.len(),
301            expected: n,
302        });
303    }
304
305    let fast_ma_type = input.get_fast_ma_type();
306    let slow_ma_type = input.get_slow_ma_type();
307
308    let fast_ma = match fast_ma_type {
309        "ema" => {
310            let params = EmaParams {
311                period: Some(fast_period),
312            };
313            ema(&EmaInput::from_slice(data, params))
314                .map_err(|_| MabError::NotEnoughValidData {
315                    needed: fast_period,
316                    valid: n - first,
317                })?
318                .values
319        }
320        _ => {
321            let params = SmaParams {
322                period: Some(fast_period),
323            };
324            sma(&SmaInput::from_slice(data, params))
325                .map_err(|_| MabError::NotEnoughValidData {
326                    needed: fast_period,
327                    valid: n - first,
328                })?
329                .values
330        }
331    };
332
333    let slow_ma = match slow_ma_type {
334        "ema" => {
335            let params = EmaParams {
336                period: Some(slow_period),
337            };
338            ema(&EmaInput::from_slice(data, params))
339                .map_err(|_| MabError::NotEnoughValidData {
340                    needed: slow_period,
341                    valid: n - first,
342                })?
343                .values
344        }
345        _ => {
346            let params = SmaParams {
347                period: Some(slow_period),
348            };
349            sma(&SmaInput::from_slice(data, params))
350                .map_err(|_| MabError::NotEnoughValidData {
351                    needed: slow_period,
352                    valid: n - first,
353                })?
354                .values
355        }
356    };
357
358    let warmup_end = (warmup + 1).min(n);
359    for dst in [&mut *upper_dst, &mut *middle_dst, &mut *lower_dst] {
360        for v in &mut dst[..warmup_end] {
361            *v = f64::NAN;
362        }
363    }
364
365    mab_compute_into(
366        &fast_ma,
367        &slow_ma,
368        fast_period,
369        devup,
370        devdn,
371        warmup + 1,
372        chosen,
373        upper_dst,
374        middle_dst,
375        lower_dst,
376    );
377
378    Ok(())
379}
380
381#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
382pub fn mab_into(
383    input: &MabInput,
384    upper_dst: &mut [f64],
385    middle_dst: &mut [f64],
386    lower_dst: &mut [f64],
387) -> Result<(), MabError> {
388    mab_into_slice(upper_dst, middle_dst, lower_dst, input, Kernel::Auto)
389}
390
391pub fn mab(input: &MabInput) -> Result<MabOutput, MabError> {
392    let (_data, _warmup, kernel, _fast, _slow, _devup, _devdn) = mab_prepare(input, Kernel::Auto)?;
393    mab_with_kernel(input, kernel)
394}
395
396pub fn mab_with_kernel(input: &MabInput, kernel: Kernel) -> Result<MabOutput, MabError> {
397    let data = input.as_ref();
398    let (_, _, _, warmup, _, _, _) = mab_prepare2(input, kernel)?;
399
400    let mut upperband = alloc_with_nan_prefix(data.len(), warmup);
401    let mut middleband = alloc_with_nan_prefix(data.len(), warmup);
402    let mut lowerband = alloc_with_nan_prefix(data.len(), warmup);
403
404    mab_into_slice(
405        &mut upperband,
406        &mut middleband,
407        &mut lowerband,
408        input,
409        kernel,
410    )?;
411
412    Ok(MabOutput {
413        upperband,
414        middleband,
415        lowerband,
416    })
417}
418
#[cfg(test)]
mod into_parity_tests {
    use super::*;

    /// `true` when both values are NaN, or when they compare equal.
    fn eq_or_both_nan(a: f64, b: f64) -> bool {
        (a.is_nan() && b.is_nan()) || (a == b)
    }

    /// The buffer-writing API must reproduce `mab` exactly, including NaN positions.
    #[test]
    fn test_mab_into_matches_api() {
        let n = 256usize;
        let mut data = vec![f64::NAN; n];
        for i in 5..n {
            data[i] = (i as f64).sin() * 0.5 + (i as f64).cos() * 0.25;
        }

        let input = MabInput::from_slice(&data, MabParams::default());

        let base = mab(&input).expect("mab baseline should succeed");

        let mut up = vec![0.0; n];
        let mut mid = vec![0.0; n];
        let mut lo = vec![0.0; n];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            mab_into(&input, &mut up, &mut mid, &mut lo).expect("mab_into should succeed");
        }
        // `mab_into` is not compiled for wasm + `wasm` feature builds. Without this branch
        // the buffers would stay 0.0 and the comparison below would spuriously fail.
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            mab_into_slice(&mut up, &mut mid, &mut lo, &input, Kernel::Auto)
                .expect("mab_into_slice should succeed");
        }

        assert_eq!(base.upperband.len(), n);
        assert_eq!(base.middleband.len(), n);
        assert_eq!(base.lowerband.len(), n);
        assert_eq!(up.len(), n);
        assert_eq!(mid.len(), n);
        assert_eq!(lo.len(), n);

        for i in 0..n {
            assert!(
                eq_or_both_nan(base.upperband[i], up[i]),
                "upper mismatch at {}: base={:?} into={:?}",
                i,
                base.upperband[i],
                up[i]
            );
            assert!(
                eq_or_both_nan(base.middleband[i], mid[i]),
                "middle mismatch at {}: base={:?} into={:?}",
                i,
                base.middleband[i],
                mid[i]
            );
            assert!(
                eq_or_both_nan(base.lowerband[i], lo[i]),
                "lower mismatch at {}: base={:?} into={:?}",
                i,
                base.lowerband[i],
                lo[i]
            );
        }
    }
}
480
481#[inline]
482fn mab_compute_into(
483    fast_ma: &[f64],
484    slow_ma: &[f64],
485    fast_period: usize,
486    devup: f64,
487    devdn: f64,
488    first_output: usize,
489    kernel: Kernel,
490    upper: &mut [f64],
491    mid: &mut [f64],
492    lower: &mut [f64],
493) {
494    unsafe {
495        match kernel {
496            Kernel::Scalar | Kernel::ScalarBatch => mab_scalar(
497                fast_ma,
498                slow_ma,
499                fast_period,
500                devup,
501                devdn,
502                first_output,
503                upper,
504                mid,
505                lower,
506            ),
507            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
508            Kernel::Avx2 | Kernel::Avx2Batch => mab_avx2(
509                fast_ma,
510                slow_ma,
511                fast_period,
512                devup,
513                devdn,
514                first_output,
515                upper,
516                mid,
517                lower,
518            ),
519            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
520            Kernel::Avx512 | Kernel::Avx512Batch => mab_avx512(
521                fast_ma,
522                slow_ma,
523                fast_period,
524                devup,
525                devdn,
526                first_output,
527                upper,
528                mid,
529                lower,
530            ),
531            _ => unreachable!(),
532        }
533    }
534}
535
/// Scalar kernel. Writes the MAB bands for every index in `first_output..fast_ma.len()`.
///
/// For each such index `i`:
/// * `mid[i] = fast_ma[i]`
/// * `upper[i] = slow_ma[i] + devup * dev`
/// * `lower[i] = slow_ma[i] - devdn * dev`
///
/// `dev` is the root-mean-square of `fast_ma - slow_ma` over the `fast_period` indices
/// ending at `i`. The window sum is maintained incrementally: add the newest square and
/// subtract the oldest.
///
/// That add/subtract can cancel to a tiny negative value when the window becomes all
/// (near-)zero. Left alone, `sqrt` would then turn the bands into `NaN`, and the result
/// would disagree with the prefix-sum SIMD kernels, whose sums never go negative.
/// Negative residues are therefore clamped to zero.
///
/// Indices below `first_output` are not written.
///
/// # Safety
///
/// The body only uses bounds-checked indexing. The function is `unsafe` to match the SIMD
/// kernels' signature. The following misuse panics:
/// * any of `slow_ma`, `upper`, `mid`, `lower` shorter than `fast_ma`;
/// * `first_output + 1 < fast_period` while outputs remain to compute.
#[inline(always)]
pub unsafe fn mab_scalar(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    // First index of the window that ends at `first_output`.
    let start_idx = if first_output >= fast_period {
        first_output - fast_period + 1
    } else {
        0
    };

    let mut sum_sq = 0.0;
    for i in start_idx..(start_idx + fast_period).min(fast_ma.len()) {
        let diff = fast_ma[i] - slow_ma[i];
        sum_sq += diff * diff;
    }

    if first_output < fast_ma.len() {
        let dev = (sum_sq / fast_period as f64).sqrt();
        mid[first_output] = fast_ma[first_output];
        upper[first_output] = slow_ma[first_output] + devup * dev;
        lower[first_output] = slow_ma[first_output] - devdn * dev;
    }

    for i in (first_output + 1)..fast_ma.len() {
        let old_idx = i - fast_period;
        let old = fast_ma[old_idx] - slow_ma[old_idx];
        let new = fast_ma[i] - slow_ma[i];
        sum_sq += new * new - old * old;
        // A true sum of squares is never negative, so any negative value is rounding
        // residue. Use a comparison rather than `f64::max`, which would turn a NaN sum
        // (from NaN input) into 0.0 and hide it.
        if sum_sq < 0.0 {
            sum_sq = 0.0;
        }
        let dev = (sum_sq / fast_period as f64).sqrt();

        mid[i] = fast_ma[i];
        upper[i] = slow_ma[i] + devup * dev;
        lower[i] = slow_ma[i] - devdn * dev;
    }
}
579
/// AVX2 + FMA kernel producing the same bands as `mab_scalar`.
///
/// Instead of a rolling add/subtract, it materialises the squared differences
/// `(fast_ma - slow_ma)^2` for indices `start..n`. Here `start = first_output + 1 - fast_period`
/// is the first index of the first window. It then builds an exclusive prefix sum over
/// them, so each window sum is `prefix[end] - prefix[begin]`. Each prefix step adds a
/// non-negative term, so the prefix is non-decreasing and window sums do not go negative
/// (barring NaN input). The band formulas are then evaluated four indices at a time.
///
/// Writes `mid`, `upper` and `lower` for every index in `first_output..n`; earlier indices
/// are untouched. Returns immediately when `first_output >= n`.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `slow_ma`, `upper`, `mid` and `lower` must each hold at least `fast_ma.len()` elements.
///   All accesses below are unchecked raw-pointer reads and writes.
/// * `fast_period > 0` and `first_output + 1 >= fast_period`. These are only checked by
///   `debug_assert!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn mab_avx2(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = fast_ma.len();
    if first_output >= n {
        return;
    }
    debug_assert!(fast_period > 0);
    debug_assert!(first_output + 1 >= fast_period);

    // `start` is the first index of the window ending at `first_output`; `m` counts the
    // indices from there to the end.
    let start = first_output + 1 - fast_period;
    let m = n - start;

    // Both buffers get their length set over uninitialised capacity. Every slot is written
    // by the loops below before it is read.
    let mut diffsq: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m);
    diffsq.set_len(m);
    let mut prefix: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m + 1);
    prefix.set_len(m + 1);

    let f0 = fast_ma.as_ptr().add(start);
    let s0 = slow_ma.as_ptr().add(start);
    let dptr = diffsq.as_mut_ptr();

    // diffsq[k] = (fast_ma[start + k] - slow_ma[start + k])^2, four lanes at a time.
    let mut k = 0usize;
    while k + 4 <= m {
        let vf = _mm256_loadu_pd(f0.add(k));
        let vs = _mm256_loadu_pd(s0.add(k));
        let vd = _mm256_sub_pd(vf, vs);
        let vd2 = _mm256_mul_pd(vd, vd);
        _mm256_storeu_pd(dptr.add(k), vd2);
        k += 4;
    }
    while k < m {
        let d = *f0.add(k) - *s0.add(k);
        *dptr.add(k) = d * d;
        k += 1;
    }

    // prefix[k] = sum of diffsq[..k]. This loop is serial because each step depends on the
    // previous one.
    let pptr = prefix.as_mut_ptr();
    *pptr = 0.0;
    let mut acc = 0.0f64;
    k = 0;
    while k < m {
        acc += *dptr.add(k);
        *pptr.add(k + 1) = acc;
        k += 1;
    }

    let invf = 1.0 / (fast_period as f64);
    let vinvf = _mm256_set1_pd(invf);
    let vup = _mm256_set1_pd(devup);
    let vdn = _mm256_set1_pd(devdn);

    let mut i = first_output;

    // Scalar peel until `i` is a multiple of 4. This aligns the index, not the memory
    // addresses; the vector loop below uses unaligned loads and stores regardless.
    while i < n && (i & 3) != 0 {
        let base = i - start;
        // Window ending at `i` covers diffsq[base + 1 - fast_period ..= base].
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }

    // NOTE(review): the vector lanes use fused multiply-add, while the scalar peel and tail
    // use a separate multiply and add. Results may therefore differ in the last bit
    // depending on an index's position.
    while i + 3 < n {
        let base = i - start;
        let pend = _mm256_loadu_pd(pptr.add(base + 1));
        let psta = _mm256_loadu_pd(pptr.add(base + 1 - fast_period));
        let vsum = _mm256_sub_pd(pend, psta);
        let vdev = _mm256_sqrt_pd(_mm256_mul_pd(vsum, vinvf));

        let vfast = _mm256_loadu_pd(fast_ma.as_ptr().add(i));
        let vslow = _mm256_loadu_pd(slow_ma.as_ptr().add(i));

        _mm256_storeu_pd(mid.as_mut_ptr().add(i), vfast);
        // upper = devup * dev + slow; lower = slow - devdn * dev
        let vupper = _mm256_fmadd_pd(vup, vdev, vslow);
        let vlower = _mm256_fnmadd_pd(vdn, vdev, vslow);
        _mm256_storeu_pd(upper.as_mut_ptr().add(i), vupper);
        _mm256_storeu_pd(lower.as_mut_ptr().add(i), vlower);

        i += 4;
    }

    // Scalar tail for the final < 4 indices.
    while i < n {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }
}
686
/// AVX-512F kernel producing the same bands as `mab_scalar`.
///
/// It uses the same prefix-sum scheme as `mab_avx2` with eight lanes per iteration.
/// Squared differences `(fast_ma - slow_ma)^2` are stored for indices `start..n`, where
/// `start = first_output + 1 - fast_period`. An exclusive prefix sum over them gives each
/// window sum as `prefix[end] - prefix[begin]`.
///
/// Writes `mid`, `upper` and `lower` for every index in `first_output..n`; earlier indices
/// are untouched. Returns immediately when `first_output >= n`.
///
/// # Safety
///
/// * The CPU must support AVX-512F.
/// * `slow_ma`, `upper`, `mid` and `lower` must each hold at least `fast_ma.len()` elements.
///   All accesses below are unchecked raw-pointer reads and writes.
/// * `fast_period > 0` and `first_output + 1 >= fast_period`. These are only checked by
///   `debug_assert!`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
pub unsafe fn mab_avx512(
    fast_ma: &[f64],
    slow_ma: &[f64],
    fast_period: usize,
    devup: f64,
    devdn: f64,
    first_output: usize,
    upper: &mut [f64],
    mid: &mut [f64],
    lower: &mut [f64],
) {
    use core::arch::x86_64::*;
    let n = fast_ma.len();
    if first_output >= n {
        return;
    }
    debug_assert!(fast_period > 0);
    debug_assert!(first_output + 1 >= fast_period);

    // `start` is the first index of the window ending at `first_output`.
    let start = first_output + 1 - fast_period;
    let m = n - start;

    // Lengths are set over uninitialised capacity. Every slot is written before it is read.
    let mut diffsq: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m);
    diffsq.set_len(m);
    let mut prefix: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, m + 1);
    prefix.set_len(m + 1);

    let f0 = fast_ma.as_ptr().add(start);
    let s0 = slow_ma.as_ptr().add(start);
    let dptr = diffsq.as_mut_ptr();

    // diffsq[k] = (fast_ma[start + k] - slow_ma[start + k])^2, eight lanes at a time.
    let mut k = 0usize;
    while k + 8 <= m {
        let vf = _mm512_loadu_pd(f0.add(k));
        let vs = _mm512_loadu_pd(s0.add(k));
        let vd = _mm512_sub_pd(vf, vs);
        let vd2 = _mm512_mul_pd(vd, vd);
        _mm512_storeu_pd(dptr.add(k), vd2);
        k += 8;
    }
    while k < m {
        let d = *f0.add(k) - *s0.add(k);
        *dptr.add(k) = d * d;
        k += 1;
    }

    // prefix[k] = sum of diffsq[..k] (serial dependency chain).
    let pptr = prefix.as_mut_ptr();
    *pptr = 0.0;
    let mut acc = 0.0f64;
    k = 0;
    while k < m {
        acc += *dptr.add(k);
        *pptr.add(k + 1) = acc;
        k += 1;
    }

    let invf = 1.0 / (fast_period as f64);
    let vinvf = _mm512_set1_pd(invf);
    let vup = _mm512_set1_pd(devup);
    let vdn = _mm512_set1_pd(devdn);

    let mut i = first_output;

    // Scalar peel until `i` is a multiple of 8. This aligns the index, not the addresses;
    // vector accesses below are unaligned loads and stores.
    while i < n && (i & 7) != 0 {
        let base = i - start;
        // Window ending at `i` covers diffsq[base + 1 - fast_period ..= base].
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }

    // NOTE(review): vector lanes use fused multiply-add, while the scalar peel and tail do
    // not. Last-bit differences between positions are possible.
    while i + 7 < n {
        let base = i - start;
        let pend = _mm512_loadu_pd(pptr.add(base + 1));
        let psta = _mm512_loadu_pd(pptr.add(base + 1 - fast_period));
        let vsum = _mm512_sub_pd(pend, psta);
        let vdev = _mm512_sqrt_pd(_mm512_mul_pd(vsum, vinvf));

        let vfast = _mm512_loadu_pd(fast_ma.as_ptr().add(i));
        let vslow = _mm512_loadu_pd(slow_ma.as_ptr().add(i));

        _mm512_storeu_pd(mid.as_mut_ptr().add(i), vfast);
        // upper = devup * dev + slow; lower = slow - devdn * dev
        let vupper = _mm512_fmadd_pd(vup, vdev, vslow);
        let vlower = _mm512_fnmadd_pd(vdn, vdev, vslow);
        _mm512_storeu_pd(upper.as_mut_ptr().add(i), vupper);
        _mm512_storeu_pd(lower.as_mut_ptr().add(i), vlower);

        i += 8;
    }

    // Scalar tail for the final < 8 indices.
    while i < n {
        let base = i - start;
        let sum = *pptr.add(base + 1) - *pptr.add(base + 1 - fast_period);
        let dev = (sum * invf).sqrt();
        let sm = *slow_ma.as_ptr().add(i);
        *mid.as_mut_ptr().add(i) = *fast_ma.as_ptr().add(i);
        *upper.as_mut_ptr().add(i) = sm + devup * dev;
        *lower.as_mut_ptr().add(i) = sm - devdn * dev;
        i += 1;
    }
}
793
/// Incremental (one value at a time) MAB calculator. Construct it with `MabStream::try_new`.
pub struct MabStream {
    /// Ring buffer of the last `fast_period` values (SMA mode).
    fast_buffer: Vec<f64>,
    /// Ring buffer of the last `slow_period` values (SMA mode).
    slow_buffer: Vec<f64>,
    /// Ring buffer of the last `fast_period` squared differences `(fast_ma - slow_ma)^2`.
    diffs_buffer: Vec<f64>,
    /// Next write slot in `fast_buffer`.
    fast_index: usize,
    /// Next write slot in `slow_buffer`.
    slow_index: usize,
    /// Next write slot in `diffs_buffer`.
    diff_index: usize,
    /// Number of finite values accepted so far.
    count: usize,
    fast_period: usize,
    slow_period: usize,
    /// Upper-band deviation multiplier.
    devup: f64,
    /// Lower-band deviation multiplier.
    devdn: f64,
    /// `"ema"` selects an EMA; anything else is treated as an SMA.
    fast_ma_type: String,
    /// `"ema"` selects an EMA; anything else is treated as an SMA.
    slow_ma_type: String,
    /// Running sum of `fast_buffer` (SMA mode).
    fast_sum: f64,
    /// Running sum of `slow_buffer` (SMA mode).
    slow_sum: f64,
    /// Latest fast moving-average value.
    fast_ma: f64,
    /// Latest slow moving-average value.
    slow_ma: f64,
    /// EMA state for the fast average (EMA mode).
    ema_fast: f64,
    /// EMA state for the slow average (EMA mode).
    ema_slow: f64,
    /// Set from `detect_best_kernel()` in `try_new`; not read by `update`.
    kernel: Kernel,

    /// Running sum of `diffs_buffer`.
    sumsq_diff: f64,
    /// Number of squared differences stored so far, capped at `fast_period`.
    diffs_filled: usize,
    /// `1 / fast_period`.
    inv_fast_len: f64,
    /// `max(fast, slow) + fast - 1`: accepted values needed before the first output.
    ready_threshold: usize,
    /// EMA smoothing factor `2 / (fast_period + 1)`.
    k_fast: f64,
    /// EMA smoothing factor `2 / (slow_period + 1)`.
    k_slow: f64,
    /// `max(fast_period, slow_period)`: accepted values before differences are recorded.
    max_period: usize,
}
819    ready_threshold: usize,
820    k_fast: f64,
821    k_slow: f64,
822    max_period: usize,
823}
824
825impl MabStream {
826    pub fn try_new(params: MabParams) -> Result<Self, String> {
827        let fast_period = params.fast_period.unwrap_or(10);
828        let slow_period = params.slow_period.unwrap_or(50);
829        let devup = params.devup.unwrap_or(1.0);
830        let devdn = params.devdn.unwrap_or(1.0);
831        let fast_ma_type = params.fast_ma_type.unwrap_or_else(|| "sma".to_string());
832        let slow_ma_type = params.slow_ma_type.unwrap_or_else(|| "sma".to_string());
833
834        if fast_period == 0 || slow_period == 0 {
835            return Err("Period cannot be zero".to_string());
836        }
837
838        let max_period = fast_period.max(slow_period);
839        let ready_threshold = max_period + fast_period - 1;
840
841        Ok(Self {
842            fast_buffer: vec![0.0; fast_period],
843            slow_buffer: vec![0.0; slow_period],
844            diffs_buffer: vec![0.0; fast_period],
845            fast_index: 0,
846            slow_index: 0,
847            diff_index: 0,
848            count: 0,
849            fast_period,
850            slow_period,
851            devup,
852            devdn,
853            fast_ma_type,
854            slow_ma_type,
855            fast_sum: 0.0,
856            slow_sum: 0.0,
857            fast_ma: 0.0,
858            slow_ma: 0.0,
859            ema_fast: 0.0,
860            ema_slow: 0.0,
861            kernel: detect_best_kernel(),
862
863            sumsq_diff: 0.0,
864            diffs_filled: 0,
865            inv_fast_len: 1.0 / fast_period as f64,
866            ready_threshold,
867            k_fast: 2.0 / (fast_period as f64 + 1.0),
868            k_slow: 2.0 / (slow_period as f64 + 1.0),
869            max_period,
870        })
871    }
872
873    pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
874        if !value.is_finite() {
875            return None;
876        }
877
878        self.count += 1;
879
880        match self.fast_ma_type.as_str() {
881            "ema" => {
882                if self.count == 1 {
883                    self.ema_fast = value;
884                } else {
885                    self.ema_fast = (1.0 - self.k_fast).mul_add(self.ema_fast, self.k_fast * value);
886                }
887                self.fast_ma = self.ema_fast;
888            }
889            _ => {
890                if self.count <= self.fast_period {
891                    let idx = self.fast_index;
892                    self.fast_sum += value;
893                    self.fast_buffer[idx] = value;
894                    if self.count == self.fast_period {
895                        self.fast_ma = self.fast_sum * self.inv_fast_len;
896                    }
897                    self.fast_index += 1;
898                    if self.fast_index == self.fast_period {
899                        self.fast_index = 0;
900                    }
901                } else {
902                    let idx = self.fast_index;
903                    let old = self.fast_buffer[idx];
904                    self.fast_buffer[idx] = value;
905                    self.fast_sum += value - old;
906                    self.fast_ma = self.fast_sum * self.inv_fast_len;
907                    self.fast_index += 1;
908                    if self.fast_index == self.fast_period {
909                        self.fast_index = 0;
910                    }
911                }
912            }
913        }
914
915        match self.slow_ma_type.as_str() {
916            "ema" => {
917                if self.count == 1 {
918                    self.ema_slow = value;
919                } else {
920                    self.ema_slow = (1.0 - self.k_slow).mul_add(self.ema_slow, self.k_slow * value);
921                }
922                self.slow_ma = self.ema_slow;
923            }
924            _ => {
925                if self.count <= self.slow_period {
926                    let idx = self.slow_index;
927                    self.slow_sum += value;
928                    self.slow_buffer[idx] = value;
929                    if self.count == self.slow_period {
930                        self.slow_ma = self.slow_sum / self.slow_period as f64;
931                    }
932                    self.slow_index += 1;
933                    if self.slow_index == self.slow_period {
934                        self.slow_index = 0;
935                    }
936                } else {
937                    let idx = self.slow_index;
938                    let old = self.slow_buffer[idx];
939                    self.slow_buffer[idx] = value;
940                    self.slow_sum += value - old;
941                    self.slow_ma = self.slow_sum / self.slow_period as f64;
942                    self.slow_index += 1;
943                    if self.slow_index == self.slow_period {
944                        self.slow_index = 0;
945                    }
946                }
947            }
948        }
949
950        if self.count < self.max_period {
951            return None;
952        }
953
954        let diff = self.fast_ma - self.slow_ma;
955        let diff2 = diff * diff;
956
957        if self.diffs_filled < self.fast_period {
958            self.sumsq_diff += diff2;
959            self.diffs_buffer[self.diff_index] = diff2;
960            self.diff_index += 1;
961            if self.diff_index == self.fast_period {
962                self.diff_index = 0;
963            }
964            self.diffs_filled += 1;
965        } else {
966            let old2 = self.diffs_buffer[self.diff_index];
967            self.sumsq_diff += diff2 - old2;
968            self.diffs_buffer[self.diff_index] = diff2;
969            self.diff_index += 1;
970            if self.diff_index == self.fast_period {
971                self.diff_index = 0;
972            }
973        }
974
975        if self.count < self.ready_threshold || self.diffs_filled < self.fast_period {
976            return None;
977        }
978
979        let dev = (self.sumsq_diff * self.inv_fast_len).sqrt();
980
981        let upper = dev.mul_add(self.devup, self.slow_ma);
982        let middle = self.fast_ma;
983        let lower = (-self.devdn * dev).mul_add(1.0, self.slow_ma);
984
985        Some((upper, middle, lower))
986    }
987}
988
/// Parameter sweep specification for batch MAB runs.
///
/// NOTE(review): each tuple appears to be `(start, end, step)`, matching the fields of
/// `MabError::InvalidRange`. The expansion logic is not visible in this section; confirm.
/// In the default, the third element of the string tuples is an empty `String`.
#[derive(Clone, Debug)]
pub struct MabBatchRange {
    pub fast_period: (usize, usize, usize),
    pub slow_period: (usize, usize, usize),
    pub devup: (f64, f64, f64),
    pub devdn: (f64, f64, f64),
    pub fast_ma_type: (String, String, String),
    pub slow_ma_type: (String, String, String),
}
998
999impl Default for MabBatchRange {
1000    fn default() -> Self {
1001        Self {
1002            fast_period: (10, 10, 0),
1003            slow_period: (50, 299, 1),
1004            devup: (1.0, 1.0, 0.0),
1005            devdn: (1.0, 1.0, 0.0),
1006            fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
1007            slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
1008        }
1009    }
1010}
1011
/// Result of a MAB parameter sweep.
///
/// Each band buffer is a row-major `rows x cols` matrix flattened into one `Vec`:
/// row `r` occupies `[r * cols, (r + 1) * cols)` and was computed with `combos[r]`.
#[derive(Clone, Debug)]
pub struct MabBatchOutput {
    /// Upper band values, `rows * cols` long.
    pub upperbands: Vec<f64>,
    /// Middle band values, `rows * cols` long.
    pub middlebands: Vec<f64>,
    /// Lower band values, `rows * cols` long.
    pub lowerbands: Vec<f64>,
    /// Parameter combination used for each row, in row order.
    pub combos: Vec<MabParams>,
    /// Number of parameter combinations (matrix rows).
    pub rows: usize,
    /// Length of the input series (matrix columns).
    pub cols: usize,
}
1021impl MabBatchOutput {
1022    pub fn row_for_params(&self, p: &MabParams) -> Option<usize> {
1023        self.combos.iter().position(|c| {
1024            c.fast_period == p.fast_period
1025                && c.slow_period == p.slow_period
1026                && c.devup == p.devup
1027                && c.devdn == p.devdn
1028                && c.fast_ma_type == p.fast_ma_type
1029                && c.slow_ma_type == p.slow_ma_type
1030        })
1031    }
1032
1033    pub fn upper_slice(&self, row: usize) -> Option<&[f64]> {
1034        if row < self.rows {
1035            let start = row * self.cols;
1036            let end = start + self.cols;
1037            Some(&self.upperbands[start..end])
1038        } else {
1039            None
1040        }
1041    }
1042
1043    pub fn middle_slice(&self, row: usize) -> Option<&[f64]> {
1044        if row < self.rows {
1045            let start = row * self.cols;
1046            let end = start + self.cols;
1047            Some(&self.middlebands[start..end])
1048        } else {
1049            None
1050        }
1051    }
1052
1053    pub fn lower_slice(&self, row: usize) -> Option<&[f64]> {
1054        if row < self.rows {
1055            let start = row * self.cols;
1056            let end = start + self.cols;
1057            Some(&self.lowerbands[start..end])
1058        } else {
1059            None
1060        }
1061    }
1062}
1063
1064pub(crate) fn expand_grid(p: &MabBatchRange) -> Result<Vec<MabParams>, MabError> {
1065    fn axis_usize(axis: (usize, usize, usize)) -> Result<Vec<usize>, MabError> {
1066        let (start, end, step) = axis;
1067        if step == 0 || start == end {
1068            return Ok(vec![start]);
1069        }
1070        let (lo, hi) = if start <= end {
1071            (start, end)
1072        } else {
1073            (end, start)
1074        };
1075        let v: Vec<usize> = (lo..=hi).step_by(step).collect();
1076        if v.is_empty() {
1077            return Err(MabError::InvalidRange { start, end, step });
1078        }
1079        Ok(v)
1080    }
1081
1082    fn axis_f64(axis: (f64, f64, f64)) -> Result<Vec<f64>, MabError> {
1083        let (start, end, step) = axis;
1084        const EPS: f64 = 1e-12;
1085        if step.abs() < EPS || (start - end).abs() < EPS {
1086            return Ok(vec![start]);
1087        }
1088        let step_eff = if start <= end {
1089            step.abs()
1090        } else {
1091            -step.abs()
1092        };
1093        let mut v = Vec::new();
1094        let mut x = start;
1095        if step_eff > 0.0 {
1096            while x <= end + EPS {
1097                v.push(x);
1098                x += step_eff;
1099            }
1100        } else {
1101            while x >= end - EPS {
1102                v.push(x);
1103                x += step_eff;
1104            }
1105        }
1106        if v.is_empty() {
1107            return Err(MabError::InvalidRangeF64 { start, end, step });
1108        }
1109        Ok(v)
1110    }
1111
1112    let fast_periods = axis_usize(p.fast_period)?;
1113    let slow_periods = axis_usize(p.slow_period)?;
1114    let devups = axis_f64(p.devup)?;
1115    let devdns = axis_f64(p.devdn)?;
1116
1117    let mut combos =
1118        Vec::with_capacity(fast_periods.len() * slow_periods.len() * devups.len() * devdns.len());
1119
1120    for &fast in &fast_periods {
1121        for &slow in &slow_periods {
1122            for &devup in &devups {
1123                for &devdn in &devdns {
1124                    combos.push(MabParams {
1125                        fast_period: Some(fast),
1126                        slow_period: Some(slow),
1127                        devup: Some(devup),
1128                        devdn: Some(devdn),
1129                        fast_ma_type: Some(p.fast_ma_type.0.clone()),
1130                        slow_ma_type: Some(p.slow_ma_type.0.clone()),
1131                    });
1132                }
1133            }
1134        }
1135    }
1136
1137    Ok(combos)
1138}
1139
1140pub fn mab_batch(input: &[f64], sweep: &MabBatchRange) -> Result<MabBatchOutput, MabError> {
1141    mab_batch_inner(input, sweep, Kernel::Auto, false)
1142}
1143
/// Allocates `rows x cols` output matrices and fills them via `mab_batch_inner_into`.
///
/// `kernel` must be `Kernel::Auto` or a batch kernel (`ScalarBatch`, `Avx2Batch`,
/// `Avx512Batch`); it is mapped to the matching single-series kernel before the
/// per-row computation. `parallel` is forwarded unchanged.
///
/// # Errors
/// - `InvalidKernelForBatch` for a non-batch kernel.
/// - `EmptyInputData` if `input` is empty; `AllValuesNaN` if every value is NaN.
/// - `InvalidPeriod` / `NotEnoughValidData` if any combination cannot be computed.
/// - `InvalidRange` / `InvalidRangeF64` for a malformed sweep; `InvalidRange` also
///   if `rows * cols` overflows.
fn mab_batch_inner(
    input: &[f64],
    sweep: &MabBatchRange,
    kernel: Kernel,
    parallel: bool,
) -> Result<MabBatchOutput, MabError> {
    let kernel = match kernel {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // Batch kernels are only selectors; each row is computed with the
    // corresponding single-series kernel.
    let simd = match kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => {
            return Err(MabError::InvalidKernelForBatch(other));
        }
    };

    let combos = expand_grid(sweep)?;
    let rows = combos.len();
    let cols = input.len();
    if cols == 0 {
        return Err(MabError::EmptyInputData);
    }
    // Overflow check only; the product is recomputed as `rows * cols` below.
    rows.checked_mul(cols).ok_or(MabError::InvalidRange {
        start: sweep.fast_period.0,
        end: sweep.fast_period.1,
        step: sweep.fast_period.2,
    })?;

    // Leading NaNs are skipped; all rows share this first valid index.
    let first_valid = input
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(MabError::AllValuesNaN)?;
    let valid = cols - first_valid;
    // Per-row warmup length (cells before the first output):
    // first_valid + fast.max(slow) + fast - 1. Every combo is validated here,
    // before any output allocation.
    let warmup_prefixes: Vec<usize> = combos
        .iter()
        .map(|p| {
            let fast = p.fast_period.unwrap();
            let slow = p.slow_period.unwrap();
            if fast == 0 || slow == 0 || fast > cols || slow > cols {
                return Err(MabError::InvalidPeriod {
                    fast,
                    slow,
                    data_len: cols,
                });
            }
            let need_total = fast.max(slow) + fast - 1;
            if valid < need_total {
                return Err(MabError::NotEnoughValidData {
                    needed: need_total,
                    valid,
                });
            }
            Ok(first_valid + need_total)
        })
        .collect::<Result<Vec<_>, MabError>>()?;

    // Output matrices start uninitialized; only each row's warmup prefix is
    // initialized here. The remaining cells are expected to be written by
    // `mab_batch_inner_into`.
    let mut upper_buf = make_uninit_matrix(rows, cols);
    let mut middle_buf = make_uninit_matrix(rows, cols);
    let mut lower_buf = make_uninit_matrix(rows, cols);

    init_matrix_prefixes(&mut upper_buf, cols, &warmup_prefixes);
    init_matrix_prefixes(&mut middle_buf, cols, &warmup_prefixes);
    init_matrix_prefixes(&mut lower_buf, cols, &warmup_prefixes);

    // NOTE(review): these casts assume `make_uninit_matrix` yields elements with
    // `f64` layout (presumably `MaybeUninit<f64>`) — confirm.
    let upper_slice =
        unsafe { std::slice::from_raw_parts_mut(upper_buf.as_mut_ptr() as *mut f64, rows * cols) };
    let middle_slice =
        unsafe { std::slice::from_raw_parts_mut(middle_buf.as_mut_ptr() as *mut f64, rows * cols) };
    let lower_slice =
        unsafe { std::slice::from_raw_parts_mut(lower_buf.as_mut_ptr() as *mut f64, rows * cols) };

    // `mab_batch_inner_into` re-expands the sweep; its combos are the ones returned.
    let combos = mab_batch_inner_into(
        input,
        sweep,
        simd,
        parallel,
        upper_slice,
        middle_slice,
        lower_slice,
    )?;

    // Reinterpret the filled buffers as `Vec<f64>` without copying. `ManuallyDrop`
    // keeps the original vectors from freeing the allocations the new `Vec`s own.
    let mut upper_guard = core::mem::ManuallyDrop::new(upper_buf);
    let mut middle_guard = core::mem::ManuallyDrop::new(middle_buf);
    let mut lower_guard = core::mem::ManuallyDrop::new(lower_buf);

    let upperbands = unsafe {
        Vec::from_raw_parts(
            upper_guard.as_mut_ptr() as *mut f64,
            upper_guard.len(),
            upper_guard.capacity(),
        )
    };
    let middlebands = unsafe {
        Vec::from_raw_parts(
            middle_guard.as_mut_ptr() as *mut f64,
            middle_guard.len(),
            middle_guard.capacity(),
        )
    };
    let lowerbands = unsafe {
        Vec::from_raw_parts(
            lower_guard.as_mut_ptr() as *mut f64,
            lower_guard.len(),
            lower_guard.capacity(),
        )
    };

    Ok(MabBatchOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    })
}
1263
1264fn mab_batch_inner_into(
1265    input: &[f64],
1266    sweep: &MabBatchRange,
1267    kernel: Kernel,
1268    parallel: bool,
1269    upper_out: &mut [f64],
1270    middle_out: &mut [f64],
1271    lower_out: &mut [f64],
1272) -> Result<Vec<MabParams>, MabError> {
1273    let combos = expand_grid(sweep)?;
1274    let rows = combos.len();
1275    let cols = input.len();
1276    let expected = rows.checked_mul(cols).ok_or(MabError::InvalidRange {
1277        start: sweep.fast_period.0,
1278        end: sweep.fast_period.1,
1279        step: sweep.fast_period.2,
1280    })?;
1281
1282    if upper_out.len() != expected || middle_out.len() != expected || lower_out.len() != expected {
1283        return Err(MabError::OutputLengthMismatch {
1284            upper_len: upper_out.len(),
1285            middle_len: middle_out.len(),
1286            lower_len: lower_out.len(),
1287            expected,
1288        });
1289    }
1290
1291    if !combos.is_empty() {
1292        let p0 = &combos[0];
1293        let all_same_ma = combos.iter().all(|p| {
1294            p.fast_period == p0.fast_period
1295                && p.slow_period == p0.slow_period
1296                && p.fast_ma_type == p0.fast_ma_type
1297                && p.slow_ma_type == p0.slow_ma_type
1298        });
1299
1300        if all_same_ma {
1301            use crate::indicators::ema::{ema, EmaInput, EmaParams};
1302            use crate::indicators::sma::{sma, SmaInput, SmaParams};
1303
1304            let n = input.len();
1305            let first = input.iter().position(|x| !x.is_nan()).unwrap_or(0);
1306            let fast = p0.fast_period.unwrap();
1307            let slow = p0.slow_period.unwrap();
1308            let fast_ma_type = p0.fast_ma_type.as_deref().unwrap_or("sma");
1309            let slow_ma_type = p0.slow_ma_type.as_deref().unwrap_or("sma");
1310
1311            let fast_ma = match fast_ma_type {
1312                "ema" => {
1313                    let params = EmaParams { period: Some(fast) };
1314                    ema(&EmaInput::from_slice(input, params))
1315                        .map_err(|_| MabError::NotEnoughValidData {
1316                            needed: fast,
1317                            valid: n - first,
1318                        })?
1319                        .values
1320                }
1321                _ => {
1322                    let params = SmaParams { period: Some(fast) };
1323                    sma(&SmaInput::from_slice(input, params))
1324                        .map_err(|_| MabError::NotEnoughValidData {
1325                            needed: fast,
1326                            valid: n - first,
1327                        })?
1328                        .values
1329                }
1330            };
1331
1332            let slow_ma = match slow_ma_type {
1333                "ema" => {
1334                    let params = EmaParams { period: Some(slow) };
1335                    ema(&EmaInput::from_slice(input, params))
1336                        .map_err(|_| MabError::NotEnoughValidData {
1337                            needed: slow,
1338                            valid: n - first,
1339                        })?
1340                        .values
1341                }
1342                _ => {
1343                    let params = SmaParams { period: Some(slow) };
1344                    sma(&SmaInput::from_slice(input, params))
1345                        .map_err(|_| MabError::NotEnoughValidData {
1346                            needed: slow,
1347                            valid: n - first,
1348                        })?
1349                        .values
1350                }
1351            };
1352
1353            let need_total = fast.max(slow) + fast - 1;
1354            let warmup = first + need_total - 1;
1355            let first_output = warmup + 1;
1356
1357            if first_output < n {
1358                let mut dev: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, n);
1359                unsafe {
1360                    dev.set_len(n);
1361                }
1362
1363                unsafe {
1364                    let f_ptr = fast_ma.as_ptr();
1365                    let s_ptr = slow_ma.as_ptr();
1366                    let d_ptr = dev.as_mut_ptr();
1367
1368                    let start = first_output + 1 - fast;
1369                    let mut sum_sq = 0.0f64;
1370                    let mut k = 0usize;
1371                    while k < fast {
1372                        let idx = start + k;
1373                        let diff = *f_ptr.add(idx) - *s_ptr.add(idx);
1374                        sum_sq += diff * diff;
1375                        k += 1;
1376                    }
1377
1378                    *d_ptr.add(first_output) = (sum_sq / fast as f64).sqrt();
1379
1380                    let mut i = first_output + 1;
1381                    while i < n {
1382                        let old_idx = i - fast;
1383                        let old = *f_ptr.add(old_idx) - *s_ptr.add(old_idx);
1384                        let new = *f_ptr.add(i) - *s_ptr.add(i);
1385                        sum_sq += new * new - old * old;
1386                        *d_ptr.add(i) = (sum_sq / fast as f64).sqrt();
1387                        i += 1;
1388                    }
1389                }
1390
1391                let fill_row = |row: usize, u: &mut [f64], m: &mut [f64], l: &mut [f64]| {
1392                    let pr = &combos[row];
1393                    let devup = pr.devup.unwrap();
1394                    let devdn = pr.devdn.unwrap();
1395                    for i in first_output..n {
1396                        let d = dev[i];
1397                        m[i] = fast_ma[i];
1398                        u[i] = slow_ma[i] + devup * d;
1399                        l[i] = slow_ma[i] - devdn * d;
1400                    }
1401                    Ok(())
1402                };
1403
1404                if parallel {
1405                    #[cfg(not(target_arch = "wasm32"))]
1406                    {
1407                        use rayon::prelude::*;
1408                        upper_out
1409                            .par_chunks_mut(cols)
1410                            .zip(middle_out.par_chunks_mut(cols))
1411                            .zip(lower_out.par_chunks_mut(cols))
1412                            .enumerate()
1413                            .try_for_each(|(row, ((u, m), l))| fill_row(row, u, m, l))?;
1414                    }
1415                    #[cfg(target_arch = "wasm32")]
1416                    {
1417                        for row in 0..rows {
1418                            let s = row * cols;
1419                            fill_row(
1420                                row,
1421                                &mut upper_out[s..s + cols],
1422                                &mut middle_out[s..s + cols],
1423                                &mut lower_out[s..s + cols],
1424                            )?;
1425                        }
1426                    }
1427                } else {
1428                    for row in 0..rows {
1429                        let s = row * cols;
1430                        fill_row(
1431                            row,
1432                            &mut upper_out[s..s + cols],
1433                            &mut middle_out[s..s + cols],
1434                            &mut lower_out[s..s + cols],
1435                        )?;
1436                    }
1437                }
1438
1439                return Ok(combos);
1440            }
1441        }
1442    }
1443
1444    let process_row = |row: usize, u: &mut [f64], m: &mut [f64], l: &mut [f64]| {
1445        let p = &combos[row];
1446        let in_row = MabInput::from_slice(
1447            input,
1448            MabParams {
1449                fast_period: p.fast_period,
1450                slow_period: p.slow_period,
1451                devup: p.devup,
1452                devdn: p.devdn,
1453                fast_ma_type: p.fast_ma_type.clone(),
1454                slow_ma_type: p.slow_ma_type.clone(),
1455            },
1456        );
1457        mab_into_slice(u, m, l, &in_row, kernel)
1458    };
1459
1460    #[cfg(not(target_arch = "wasm32"))]
1461    {
1462        if parallel {
1463            use rayon::prelude::*;
1464            upper_out
1465                .par_chunks_mut(cols)
1466                .zip(middle_out.par_chunks_mut(cols))
1467                .zip(lower_out.par_chunks_mut(cols))
1468                .enumerate()
1469                .try_for_each(|(row, ((u, m), l))| process_row(row, u, m, l))?;
1470        } else {
1471            for row in 0..rows {
1472                let s = row * cols;
1473                process_row(
1474                    row,
1475                    &mut upper_out[s..s + cols],
1476                    &mut middle_out[s..s + cols],
1477                    &mut lower_out[s..s + cols],
1478                )?;
1479            }
1480        }
1481    }
1482
1483    #[cfg(target_arch = "wasm32")]
1484    {
1485        for row in 0..rows {
1486            let s = row * cols;
1487            process_row(
1488                row,
1489                &mut upper_out[s..s + cols],
1490                &mut middle_out[s..s + cols],
1491                &mut lower_out[s..s + cols],
1492            )?;
1493        }
1494    }
1495
1496    Ok(combos)
1497}
1498
/// Python binding for single-series MAB.
///
/// Returns `(upperband, middleband, lowerband)` as NumPy arrays. `kernel` is parsed
/// with `validate_kernel`; when it resolves to `Kernel::Auto` the default `mab` entry
/// point is used, otherwise `mab_with_kernel`. Computation errors are raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mab")]
#[pyo3(signature = (data, fast_period=10, slow_period=50, devup=1.0, devdn=1.0, fast_ma_type="sma", slow_ma_type="sma", kernel=None))]
pub fn mab_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let series = data.as_slice()?;
    let input = MabInput::from_slice(
        series,
        MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        },
    );

    let requested = validate_kernel(kernel, false)?;

    // The GIL is released while the bands are computed.
    let MabOutput {
        upperband,
        middleband,
        lowerband,
    } = py
        .allow_threads(|| {
            if matches!(requested, Kernel::Auto) {
                mab(&input)
            } else {
                mab_with_kernel(&input, requested)
            }
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        upperband.into_pyarray(py),
        middleband.into_pyarray(py),
        lowerband.into_pyarray(py),
    ))
}
1543
/// Python wrapper around [`MabStream`] for incremental (value-by-value) MAB updates.
#[cfg(feature = "python")]
#[pyclass(name = "MabStream")]
pub struct MabStreamPy {
    // The wrapped Rust streaming state.
    stream: MabStream,
}
1549
#[cfg(feature = "python")]
#[pymethods]
impl MabStreamPy {
    /// Builds a stream from explicit parameters.
    ///
    /// Invalid parameters (as rejected by `MabStream::try_new`) raise `ValueError`.
    #[new]
    fn new(
        fast_period: usize,
        slow_period: usize,
        devup: f64,
        devdn: f64,
        fast_ma_type: &str,
        slow_ma_type: &str,
    ) -> PyResult<Self> {
        MabStream::try_new(MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value; returns `(upper, middle, lower)` once the stream is warmed up,
    /// otherwise `None`.
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1579
/// Python binding for MAB parameter sweeps.
///
/// Returns a dict with `upperbands`, `middlebands`, `lowerbands` (each reshaped to
/// `(rows, cols)`) and the per-row parameters `fast_periods`, `slow_periods`,
/// `devups`, `devdns`. Rows are computed in parallel with the GIL released. Invalid
/// input, parameters or kernels raise `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mab_batch")]
#[pyo3(signature = (data, fast_period_range, slow_period_range, devup_range=(1.0, 1.0, 0.0), devdn_range=(1.0, 1.0, 0.0), fast_ma_type="sma", slow_ma_type="sma", kernel=None))]
pub fn mab_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    devup_range: (f64, f64, f64),
    devdn_range: (f64, f64, f64),
    fast_ma_type: &str,
    slow_ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    // Views an initialized `f64` slice as `MaybeUninit<f64>` for `init_matrix_prefixes`.
    // The view is a reborrow of `s`, so no second independent `&mut` to the same NumPy
    // buffer is created (the previous code derived it from `as_array_mut()` while the
    // `as_slice_mut()` slices were still live and used later, which aliased `&mut`).
    fn as_uninit(s: &mut [f64]) -> &mut [MaybeUninit<f64>] {
        // SAFETY: `MaybeUninit<f64>` has the same size and alignment as `f64`; the
        // pointer and length come from a valid `&mut [f64]` whose borrow this view
        // inherits; writing initialized values through it keeps every element valid.
        unsafe { std::slice::from_raw_parts_mut(s.as_mut_ptr() as *mut MaybeUninit<f64>, s.len()) }
    }

    let slice_in = data.as_slice()?;

    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: (
            fast_ma_type.to_string(),
            fast_ma_type.to_string(),
            "".to_string(),
        ),
        slow_ma_type: (
            slow_ma_type.to_string(),
            slow_ma_type.to_string(),
            "".to_string(),
        ),
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    if cols == 0 {
        return Err(PyValueError::new_err(MabError::EmptyInputData.to_string()));
    }
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("mab_batch: rows*cols overflow"))?;

    // SAFETY: the arrays are freshly allocated, contiguous, and not yet shared with
    // Python; every cell is written below (warmup prefix here, the rest by the batch fill).
    let upper_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let middle_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let lower_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    let slice_upper = unsafe { upper_arr.as_slice_mut()? };
    let slice_middle = unsafe { middle_arr.as_slice_mut()? };
    let slice_lower = unsafe { lower_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let first_valid = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| PyValueError::new_err(MabError::AllValuesNaN.to_string()))?;
    let valid = cols - first_valid;
    // Per-row warmup length: first_valid + fast.max(slow) + fast - 1.
    let warmup_prefixes: Vec<usize> = combos
        .iter()
        .map(|p| {
            let fast = p.fast_period.unwrap();
            let slow = p.slow_period.unwrap();
            if fast == 0 || slow == 0 || fast > cols || slow > cols {
                return Err(MabError::InvalidPeriod {
                    fast,
                    slow,
                    data_len: cols,
                });
            }
            let need_total = fast.max(slow) + fast - 1;
            if valid < need_total {
                return Err(MabError::NotEnoughValidData {
                    needed: need_total,
                    valid,
                });
            }
            Ok(first_valid + need_total)
        })
        .collect::<Result<Vec<_>, MabError>>()
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    init_matrix_prefixes(as_uninit(slice_upper), cols, &warmup_prefixes);
    init_matrix_prefixes(as_uninit(slice_middle), cols, &warmup_prefixes);
    init_matrix_prefixes(as_uninit(slice_lower), cols, &warmup_prefixes);

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Report an unexpected kernel as an error instead of panicking across FFI.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => return Err(MabError::InvalidKernelForBatch(other)),
            };

            mab_batch_inner_into(
                slice_in,
                &sweep,
                simd,
                true,
                slice_upper,
                slice_middle,
                slice_lower,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("upperbands", upper_arr.reshape((rows, cols))?)?;
    dict.set_item("middlebands", middle_arr.reshape((rows, cols))?)?;
    dict.set_item("lowerbands", lower_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "fast_periods",
        combos
            .iter()
            .map(|p| p.fast_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "slow_periods",
        combos
            .iter()
            .map(|p| p.slow_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "devups",
        combos
            .iter()
            .map(|p| p.devup.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "devdns",
        combos
            .iter()
            .map(|p| p.devdn.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1745
/// Serialized result of the single-series `mab` wasm export.
///
/// `values` holds the three bands back to back (upper, middle, lower), so it forms
/// a `rows x cols` row-major matrix with `rows == 3`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabJsSingle {
    /// Concatenated band values, `rows * cols` long.
    pub values: Vec<f64>,
    /// Number of bands (3).
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
1753
/// wasm export `mab`: computes the three bands and returns them as a serialized
/// `MabJsSingle` (upper, middle, lower concatenated; `rows = 3`, `cols = data.len()`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "mab")]
pub fn mab_wasm(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<JsValue, JsValue> {
    let len = data.len();
    let input = MabInput::from_slice(
        data,
        MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        },
    );

    let (mut upper, mut middle, mut lower) = (vec![0.0; len], vec![0.0; len], vec![0.0; len]);

    if let Err(e) = mab_into_slice(
        &mut upper,
        &mut middle,
        &mut lower,
        &input,
        detect_best_kernel(),
    ) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    let payload = MabJsSingle {
        values: [upper, middle, lower].concat(),
        rows: 3,
        cols: len,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1800
/// wasm export returning the three bands as one flat vector: upper, then middle,
/// then lower, each `data.len()` long.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_js(
    data: &[f64],
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let len = data.len();
    let input = MabInput::from_slice(
        data,
        MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        },
    );

    let (mut upper, mut middle, mut lower) = (vec![0.0; len], vec![0.0; len], vec![0.0; len]);

    if let Err(e) = mab_into_slice(
        &mut upper,
        &mut middle,
        &mut lower,
        &input,
        detect_best_kernel(),
    ) {
        return Err(JsValue::from_str(&e.to_string()));
    }

    Ok([upper, middle, lower].concat())
}
1842
/// Configuration object accepted by the `mab_batch` wasm export.
///
/// Range fields are `(start, end, step)` triples; the MA type strings are applied to
/// every combination (they are not swept).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabBatchConfig {
    /// Fast MA period range `(start, end, step)`.
    pub fast_period_range: (usize, usize, usize),
    /// Slow MA period range `(start, end, step)`.
    pub slow_period_range: (usize, usize, usize),
    /// Upper deviation multiplier range `(start, end, step)`.
    pub devup_range: (f64, f64, f64),
    /// Lower deviation multiplier range `(start, end, step)`.
    pub devdn_range: (f64, f64, f64),
    /// Fast MA type name used for all combinations.
    pub fast_ma_type: String,
    /// Slow MA type name used for all combinations.
    pub slow_ma_type: String,
}
1853
/// Serialized result of the `mab_batch` wasm export; mirrors `MabBatchOutput`.
///
/// Each band buffer is a row-major `rows x cols` matrix; row `r` uses `combos[r]`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MabBatchJsOutput {
    /// Upper band values, `rows * cols` long.
    pub upperbands: Vec<f64>,
    /// Middle band values, `rows * cols` long.
    pub middlebands: Vec<f64>,
    /// Lower band values, `rows * cols` long.
    pub lowerbands: Vec<f64>,
    /// Parameter combination for each row.
    pub combos: Vec<MabParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
1864
/// wasm export `mab_batch`: parses a `MabBatchConfig`, runs the sweep sequentially
/// with the automatically selected batch kernel, and returns a serialized
/// `MabBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mab_batch)]
pub fn mab_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: MabBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let MabBatchConfig {
        fast_period_range,
        slow_period_range,
        devup_range,
        devdn_range,
        fast_ma_type,
        slow_ma_type,
    } = parsed;

    // MA types are fixed across the sweep; the third tuple slot is unused.
    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: (fast_ma_type.clone(), fast_ma_type, String::new()),
        slow_ma_type: (slow_ma_type.clone(), slow_ma_type, String::new()),
    };

    let MabBatchOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    } = mab_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = MabBatchJsOutput {
        upperbands,
        middlebands,
        lowerbands,
        combos,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1903
/// Allocate room for `len` f64s in WASM linear memory and return the pointer to JS.
///
/// The buffer contents are uninitialized. Ownership passes to the caller, who
/// must release it with `mab_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1912
/// Release a buffer previously returned by `mab_alloc(len)`.
///
/// `len` must be the same length that was passed to `mab_alloc`. A null
/// pointer is a no-op.
///
/// The buffer is rebuilt with length 0 and capacity `len`. Dropping it frees
/// the allocation without claiming that its elements were ever initialized,
/// which matters because JS may free a buffer it never wrote.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from `mab_alloc(len)`, i.e. a `Vec<f64>` allocated with
    // capacity `len`. Length 0 is always a valid length (it asserts no initialized
    // elements), and f64 has no destructor, so dropping the Vec only deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1922
/// WASM: compute MAB for a single parameter set into caller-owned buffers.
///
/// `in_ptr` points to `len` input values. `upper_ptr`, `middle_ptr` and
/// `lower_ptr` must each point to `len` writable f64s.
///
/// If any output buffer overlaps the input, whether at the same address or
/// only partially, the bands are first computed into temporaries and then
/// copied out. This keeps the input from being overwritten while it is still
/// being read.
///
/// # Errors
/// Returns an error for a null pointer, or for any error reported by `mab_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_into(
    in_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || upper_ptr.is_null() || middle_ptr.is_null() || lower_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte ranges of the input and of each output buffer (all `len` f64s long).
    // Checking range overlap catches partial aliasing, not just identical base pointers.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(byte_len);
    let overlaps_input = |out_ptr: *mut f64| {
        let out_start = out_ptr as usize;
        let out_end = out_start.saturating_add(byte_len);
        in_start < out_end && out_start < in_end
    };
    let need_temp =
        overlaps_input(upper_ptr) || overlaps_input(middle_ptr) || overlaps_input(lower_ptr);

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = MabParams {
            fast_period: Some(fast_period),
            slow_period: Some(slow_period),
            devup: Some(devup),
            devdn: Some(devdn),
            fast_ma_type: Some(fast_ma_type.to_string()),
            slow_ma_type: Some(slow_ma_type.to_string()),
        };
        let input = MabInput::from_slice(data, params);

        if need_temp {
            // Compute fully before touching the (overlapping) output memory.
            let mut temp_upper = vec![0.0; len];
            let mut temp_middle = vec![0.0; len];
            let mut temp_lower = vec![0.0; len];

            mab_into_slice(
                &mut temp_upper,
                &mut temp_middle,
                &mut temp_lower,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let upper_out = std::slice::from_raw_parts_mut(upper_ptr, len);
            let middle_out = std::slice::from_raw_parts_mut(middle_ptr, len);
            let lower_out = std::slice::from_raw_parts_mut(lower_ptr, len);

            upper_out.copy_from_slice(&temp_upper);
            middle_out.copy_from_slice(&temp_middle);
            lower_out.copy_from_slice(&temp_lower);
        } else {
            let upper_out = std::slice::from_raw_parts_mut(upper_ptr, len);
            let middle_out = std::slice::from_raw_parts_mut(middle_ptr, len);
            let lower_out = std::slice::from_raw_parts_mut(lower_ptr, len);

            mab_into_slice(upper_out, middle_out, lower_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1988
/// WASM: run a MAB parameter sweep, writing the results into caller-owned buffers.
///
/// `in_ptr` points to `len` input values. `upper_ptr`, `middle_ptr` and
/// `lower_ptr` must each point to `rows * len` writable f64s, where `rows` is
/// the number of parameter combinations expanded from the `(start, end, step)`
/// ranges. Each MA type is passed to the sweep as `(name, name, "")`.
/// Returns `rows`.
///
/// If the input overlaps any output matrix, it is copied into an owned buffer
/// first. Otherwise the NaN warmup prefixes written below could overwrite input
/// values before the kernel reads them.
///
/// # Errors
/// Returns an error for null pointers, an invalid sweep, empty or all-NaN input,
/// an invalid period, too little valid data for a combination, or `rows * len`
/// overflowing `usize`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mab_batch_into(
    in_ptr: *const f64,
    upper_ptr: *mut f64,
    middle_ptr: *mut f64,
    lower_ptr: *mut f64,
    len: usize,
    fast_period_start: usize,
    fast_period_end: usize,
    fast_period_step: usize,
    slow_period_start: usize,
    slow_period_end: usize,
    slow_period_step: usize,
    devup_start: f64,
    devup_end: f64,
    devup_step: f64,
    devdn_start: f64,
    devdn_end: f64,
    devdn_step: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || upper_ptr.is_null() || middle_ptr.is_null() || lower_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer passed to mab_batch_into"));
    }

    let sweep = MabBatchRange {
        fast_period: (fast_period_start, fast_period_end, fast_period_step),
        slow_period: (slow_period_start, slow_period_end, slow_period_step),
        devup: (devup_start, devup_end, devup_step),
        devdn: (devdn_start, devdn_end, devdn_step),
        fast_ma_type: (
            fast_ma_type.to_string(),
            fast_ma_type.to_string(),
            "".to_string(),
        ),
        slow_ma_type: (
            slow_ma_type.to_string(),
            slow_ma_type.to_string(),
            "".to_string(),
        ),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    if cols == 0 {
        return Err(JsValue::from_str(&MabError::EmptyInputData.to_string()));
    }
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("mab_batch_into: rows*cols overflow"))?;

    // Detect any byte-range overlap between the input and an output matrix.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(cols.saturating_mul(elem));
    let out_bytes = total.saturating_mul(elem);
    let overlaps_input = |out_ptr: *mut f64| {
        let out_start = out_ptr as usize;
        in_start < out_start.saturating_add(out_bytes) && out_start < in_end
    };
    let input_aliased =
        overlaps_input(upper_ptr) || overlaps_input(middle_ptr) || overlaps_input(lower_ptr);

    unsafe {
        let raw_input = std::slice::from_raw_parts(in_ptr, len);
        // Owned snapshot of the input, used only when an output overlaps it, so the
        // shared input slice is never read while output memory is being written.
        let owned_input: Vec<f64>;
        let data: &[f64] = if input_aliased {
            owned_input = raw_input.to_vec();
            &owned_input
        } else {
            raw_input
        };

        let first_valid = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or(MabError::AllValuesNaN)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        let valid = cols - first_valid;

        // Per-row NaN prefix length: first_valid + max(fast, slow) + fast - 1.
        let warmup_prefixes: Vec<usize> = combos
            .iter()
            .map(|p| {
                let fast = p.fast_period.unwrap();
                let slow = p.slow_period.unwrap();
                if fast == 0 || slow == 0 || fast > cols || slow > cols {
                    return Err(MabError::InvalidPeriod {
                        fast,
                        slow,
                        data_len: cols,
                    });
                }
                let need_total = fast.max(slow) + fast - 1;
                if valid < need_total {
                    return Err(MabError::NotEnoughValidData {
                        needed: need_total,
                        valid,
                    });
                }
                Ok(first_valid + need_total)
            })
            .collect::<Result<Vec<_>, MabError>>()
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let mu_upper: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(upper_ptr as *mut MaybeUninit<f64>, total);
        let mu_middle: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(middle_ptr as *mut MaybeUninit<f64>, total);
        let mu_lower: &mut [MaybeUninit<f64>] =
            std::slice::from_raw_parts_mut(lower_ptr as *mut MaybeUninit<f64>, total);
        init_matrix_prefixes(mu_upper, cols, &warmup_prefixes);
        init_matrix_prefixes(mu_middle, cols, &warmup_prefixes);
        init_matrix_prefixes(mu_lower, cols, &warmup_prefixes);

        let upper_out = std::slice::from_raw_parts_mut(upper_ptr, total);
        let middle_out = std::slice::from_raw_parts_mut(middle_ptr, total);
        let lower_out = std::slice::from_raw_parts_mut(lower_ptr, total);

        mab_batch_inner_into(
            data,
            &sweep,
            Kernel::Auto,
            false,
            upper_out,
            middle_out,
            lower_out,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
2104
2105#[inline]
2106pub unsafe fn mab_scalar_classic_sma(
2107    data: &[f64],
2108    fast_period: usize,
2109    slow_period: usize,
2110    devup: f64,
2111    devdn: f64,
2112    first_valid_idx: usize,
2113    upper: &mut [f64],
2114    middle: &mut [f64],
2115    lower: &mut [f64],
2116) -> Result<(), MabError> {
2117    let n = data.len();
2118
2119    let mut fast_ma = vec![f64::NAN; n];
2120    if fast_period > 0 && first_valid_idx + fast_period <= n {
2121        let mut sum = 0.0;
2122        for i in 0..fast_period {
2123            sum += data[first_valid_idx + i];
2124        }
2125        fast_ma[first_valid_idx + fast_period - 1] = sum / fast_period as f64;
2126
2127        for i in (first_valid_idx + fast_period)..n {
2128            sum = sum - data[i - fast_period] + data[i];
2129            fast_ma[i] = sum / fast_period as f64;
2130        }
2131    }
2132
2133    let mut slow_ma = vec![f64::NAN; n];
2134    if slow_period > 0 && first_valid_idx + slow_period <= n {
2135        let mut sum = 0.0;
2136        for i in 0..slow_period {
2137            sum += data[first_valid_idx + i];
2138        }
2139        slow_ma[first_valid_idx + slow_period - 1] = sum / slow_period as f64;
2140
2141        for i in (first_valid_idx + slow_period)..n {
2142            sum = sum - data[i - slow_period] + data[i];
2143            slow_ma[i] = sum / slow_period as f64;
2144        }
2145    }
2146
2147    let need_total = slow_period.max(fast_period) + fast_period - 1;
2148    let warmup = first_valid_idx + need_total - 1;
2149    let first_output = warmup + 1;
2150
2151    for i in 0..first_output.min(n) {
2152        upper[i] = f64::NAN;
2153        middle[i] = f64::NAN;
2154        lower[i] = f64::NAN;
2155    }
2156
2157    if first_output >= n {
2158        return Ok(());
2159    }
2160
2161    let start_idx = if first_output >= fast_period {
2162        first_output - fast_period + 1
2163    } else {
2164        0
2165    };
2166
2167    let mut sum_sq = 0.0;
2168    for i in start_idx..(start_idx + fast_period).min(fast_ma.len()) {
2169        let diff = fast_ma[i] - slow_ma[i];
2170        if !diff.is_nan() {
2171            sum_sq += diff * diff;
2172        }
2173    }
2174
2175    if first_output < fast_ma.len() {
2176        let dev = (sum_sq / fast_period as f64).sqrt();
2177        middle[first_output] = fast_ma[first_output];
2178        upper[first_output] = slow_ma[first_output] + devup * dev;
2179        lower[first_output] = slow_ma[first_output] - devdn * dev;
2180    }
2181
2182    for i in (first_output + 1)..fast_ma.len() {
2183        let old_idx = i - fast_period;
2184        let old = fast_ma[old_idx] - slow_ma[old_idx];
2185        let new = fast_ma[i] - slow_ma[i];
2186        if !old.is_nan() && !new.is_nan() {
2187            sum_sq += new * new - old * old;
2188        }
2189        let dev = (sum_sq / fast_period as f64).sqrt();
2190
2191        middle[i] = fast_ma[i];
2192        upper[i] = slow_ma[i] + devup * dev;
2193        lower[i] = slow_ma[i] - devdn * dev;
2194    }
2195
2196    Ok(())
2197}
2198
2199#[cfg(all(feature = "python", feature = "cuda"))]
2200use crate::cuda::{cuda_available, moving_averages::CudaMab};
2201#[cfg(all(feature = "python", feature = "cuda"))]
2202use crate::indicators::moving_averages::alma::{make_device_array_py, DeviceArrayF32Py};
2203#[cfg(all(feature = "python", feature = "cuda"))]
2204#[cfg(all(feature = "python", feature = "cuda"))]
2205use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
2206#[cfg(all(feature = "python", feature = "cuda"))]
2207use pyo3::{pyfunction, PyResult, Python};
2208
/// Python (`mab_cuda_batch_dev`): run a MAB parameter sweep on the GPU over an f32 series.
///
/// Each MA type is handed to the sweep as `(name, name, "")`. The GIL is
/// released while `CudaMab` runs. The result is the upper, middle and lower
/// band matrices as device arrays on `device_id`; the combo list returned by
/// the CUDA wrapper is discarded.
///
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mab_cuda_batch_dev")]
#[pyo3(signature = (data_f32, fast_period_range, slow_period_range, devup_range=(1.0,1.0,0.0), devdn_range=(1.0,1.0,0.0), fast_ma_type="sma", slow_ma_type="sma", device_id=0))]
pub fn mab_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    fast_period_range: (usize, usize, usize),
    slow_period_range: (usize, usize, usize),
    devup_range: (f64, f64, f64),
    devdn_range: (f64, f64, f64),
    fast_ma_type: &str,
    slow_ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host: &[f32] = data_f32.as_slice()?;

    // A single MA type is expressed to the sweep as a degenerate (start, end, step) triple.
    let fixed_type = |name: &str| (name.to_string(), name.to_string(), String::new());
    let sweep = MabBatchRange {
        fast_period: fast_period_range,
        slow_period: slow_period_range,
        devup: devup_range,
        devdn: devdn_range,
        fast_ma_type: fixed_type(fast_ma_type),
        slow_ma_type: fixed_type(slow_ma_type),
    };

    let (up, mid, lo) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMab::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (bands, _combos) = cuda
            .mab_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bands.upper, bands.middle, bands.lower))
    })?;

    Ok((
        make_device_array_py(device_id, up)?,
        make_device_array_py(device_id, mid)?,
        make_device_array_py(device_id, lo)?,
    ))
}
2257
/// Python (`mab_cuda_many_series_one_param_dev`): MAB with one parameter set over many series on the GPU.
///
/// The 2-D input's shape is read as `(rows, cols)` and passed to the CUDA
/// wrapper as `(cols, rows)`. Presumably the layout is time-major
/// `(time, series)`; TODO confirm against `CudaMab`. The GIL is released while
/// the kernel runs. The upper, middle and lower bands are returned as device
/// arrays on `device_id`.
///
/// Raises `ValueError` when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mab_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, fast_period, slow_period, devup=1.0, devdn=1.0, fast_ma_type="sma", slow_ma_type="sma", device_id=0))]
pub fn mab_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    fast_period: usize,
    slow_period: usize,
    devup: f64,
    devdn: f64,
    fast_ma_type: &str,
    slow_ma_type: &str,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let host: &[f32] = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);

    let params = MabParams {
        fast_period: Some(fast_period),
        slow_period: Some(slow_period),
        devup: Some(devup),
        devdn: Some(devdn),
        fast_ma_type: Some(fast_ma_type.to_string()),
        slow_ma_type: Some(slow_ma_type.to_string()),
    };

    let (up, mid, lo) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMab::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let bands = cuda
            .mab_many_series_one_param_time_major_dev(host, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((bands.upper, bands.middle, bands.lower))
    })?;

    Ok((
        make_device_array_py(device_id, up)?,
        make_device_array_py(device_id, mid)?,
        make_device_array_py(device_id, lo)?,
    ))
}
2300
2301#[cfg(test)]
2302mod tests {
2303    use super::*;
2304    use crate::utilities::data_loader::read_candles_from_csv;
2305    #[cfg(feature = "proptest")]
2306    use proptest::prelude::*;
2307    use std::error::Error;
2308
2309    macro_rules! skip_if_unsupported {
2310        ($kernel:expr, $test_name:expr) => {
2311            #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2312            if matches!(
2313                $kernel,
2314                Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2315            ) {
2316                eprintln!(
2317                    "[{}] Skipping - {:?} not supported on WASM",
2318                    $test_name, $kernel
2319                );
2320                return Ok(());
2321            }
2322            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
2323            if matches!(
2324                $kernel,
2325                Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
2326            ) {
2327                eprintln!(
2328                    "[{}] Skipping - {:?} requires 'nightly-avx' feature",
2329                    $test_name, $kernel
2330                );
2331                return Ok(());
2332            }
2333        };
2334    }
2335
2336    #[cfg(debug_assertions)]
2337    fn check_mab_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2338        skip_if_unsupported!(kernel, test_name);
2339
2340        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2341        let candles = read_candles_from_csv(file_path)?;
2342
2343        let test_params = vec![
2344            MabParams::default(),
2345            MabParams {
2346                fast_period: Some(2),
2347                slow_period: Some(3),
2348                devup: Some(1.0),
2349                devdn: Some(1.0),
2350                fast_ma_type: Some("sma".to_string()),
2351                slow_ma_type: Some("sma".to_string()),
2352            },
2353            MabParams {
2354                fast_period: Some(5),
2355                slow_period: Some(10),
2356                devup: Some(0.5),
2357                devdn: Some(0.5),
2358                fast_ma_type: Some("sma".to_string()),
2359                slow_ma_type: Some("sma".to_string()),
2360            },
2361            MabParams {
2362                fast_period: Some(15),
2363                slow_period: Some(30),
2364                devup: Some(2.0),
2365                devdn: Some(2.0),
2366                fast_ma_type: Some("ema".to_string()),
2367                slow_ma_type: Some("ema".to_string()),
2368            },
2369            MabParams {
2370                fast_period: Some(50),
2371                slow_period: Some(100),
2372                devup: Some(3.0),
2373                devdn: Some(3.0),
2374                fast_ma_type: Some("sma".to_string()),
2375                slow_ma_type: Some("sma".to_string()),
2376            },
2377            MabParams {
2378                fast_period: Some(10),
2379                slow_period: Some(20),
2380                devup: Some(1.5),
2381                devdn: Some(1.5),
2382                fast_ma_type: Some("sma".to_string()),
2383                slow_ma_type: Some("ema".to_string()),
2384            },
2385            MabParams {
2386                fast_period: Some(8),
2387                slow_period: Some(21),
2388                devup: Some(2.5),
2389                devdn: Some(2.5),
2390                fast_ma_type: Some("ema".to_string()),
2391                slow_ma_type: Some("sma".to_string()),
2392            },
2393            MabParams {
2394                fast_period: Some(12),
2395                slow_period: Some(26),
2396                devup: Some(2.0),
2397                devdn: Some(1.0),
2398                fast_ma_type: Some("ema".to_string()),
2399                slow_ma_type: Some("ema".to_string()),
2400            },
2401            MabParams {
2402                fast_period: Some(9),
2403                slow_period: Some(10),
2404                devup: Some(1.0),
2405                devdn: Some(2.0),
2406                fast_ma_type: Some("sma".to_string()),
2407                slow_ma_type: Some("sma".to_string()),
2408            },
2409            MabParams {
2410                fast_period: Some(30),
2411                slow_period: Some(200),
2412                devup: Some(1.0),
2413                devdn: Some(1.0),
2414                fast_ma_type: Some("ema".to_string()),
2415                slow_ma_type: Some("ema".to_string()),
2416            },
2417        ];
2418
2419        for (param_idx, params) in test_params.iter().enumerate() {
2420            let input = MabInput::from_candles(&candles, "close", params.clone());
2421            let output = mab_with_kernel(&input, kernel)?;
2422
2423            for (i, &val) in output.upperband.iter().enumerate() {
2424                if val.is_nan() {
2425                    continue;
2426                }
2427
2428                let bits = val.to_bits();
2429
2430                if bits == 0x11111111_11111111 {
2431                    panic!(
2432						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in upperband \
2433						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2434						test_name, val, bits, i,
2435						params.fast_period.unwrap_or(10),
2436						params.slow_period.unwrap_or(50),
2437						params.devup.unwrap_or(1.0),
2438						params.devdn.unwrap_or(1.0),
2439						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2440						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2441						param_idx
2442					);
2443                }
2444
2445                if bits == 0x22222222_22222222 {
2446                    panic!(
2447						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in upperband \
2448						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2449						test_name, val, bits, i,
2450						params.fast_period.unwrap_or(10),
2451						params.slow_period.unwrap_or(50),
2452						params.devup.unwrap_or(1.0),
2453						params.devdn.unwrap_or(1.0),
2454						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2455						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2456						param_idx
2457					);
2458                }
2459
2460                if bits == 0x33333333_33333333 {
2461                    panic!(
2462						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in upperband \
2463						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2464						test_name, val, bits, i,
2465						params.fast_period.unwrap_or(10),
2466						params.slow_period.unwrap_or(50),
2467						params.devup.unwrap_or(1.0),
2468						params.devdn.unwrap_or(1.0),
2469						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2470						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2471						param_idx
2472					);
2473                }
2474            }
2475
2476            for (i, &val) in output.middleband.iter().enumerate() {
2477                if val.is_nan() {
2478                    continue;
2479                }
2480
2481                let bits = val.to_bits();
2482
2483                if bits == 0x11111111_11111111 {
2484                    panic!(
2485						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in middleband \
2486						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2487						test_name, val, bits, i,
2488						params.fast_period.unwrap_or(10),
2489						params.slow_period.unwrap_or(50),
2490						params.devup.unwrap_or(1.0),
2491						params.devdn.unwrap_or(1.0),
2492						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2493						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2494						param_idx
2495					);
2496                }
2497
2498                if bits == 0x22222222_22222222 {
2499                    panic!(
2500						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in middleband \
2501						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2502						test_name, val, bits, i,
2503						params.fast_period.unwrap_or(10),
2504						params.slow_period.unwrap_or(50),
2505						params.devup.unwrap_or(1.0),
2506						params.devdn.unwrap_or(1.0),
2507						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2508						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2509						param_idx
2510					);
2511                }
2512
2513                if bits == 0x33333333_33333333 {
2514                    panic!(
2515						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in middleband \
2516						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2517						test_name, val, bits, i,
2518						params.fast_period.unwrap_or(10),
2519						params.slow_period.unwrap_or(50),
2520						params.devup.unwrap_or(1.0),
2521						params.devdn.unwrap_or(1.0),
2522						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2523						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2524						param_idx
2525					);
2526                }
2527            }
2528
2529            for (i, &val) in output.lowerband.iter().enumerate() {
2530                if val.is_nan() {
2531                    continue;
2532                }
2533
2534                let bits = val.to_bits();
2535
2536                if bits == 0x11111111_11111111 {
2537                    panic!(
2538						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} in lowerband \
2539						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2540						test_name, val, bits, i,
2541						params.fast_period.unwrap_or(10),
2542						params.slow_period.unwrap_or(50),
2543						params.devup.unwrap_or(1.0),
2544						params.devdn.unwrap_or(1.0),
2545						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2546						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2547						param_idx
2548					);
2549                }
2550
2551                if bits == 0x22222222_22222222 {
2552                    panic!(
2553						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} in lowerband \
2554						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2555						test_name, val, bits, i,
2556						params.fast_period.unwrap_or(10),
2557						params.slow_period.unwrap_or(50),
2558						params.devup.unwrap_or(1.0),
2559						params.devdn.unwrap_or(1.0),
2560						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2561						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2562						param_idx
2563					);
2564                }
2565
2566                if bits == 0x33333333_33333333 {
2567                    panic!(
2568						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} in lowerband \
2569						 with params: fast_period={}, slow_period={}, devup={}, devdn={}, fast_ma_type={}, slow_ma_type={} (param set {})",
2570						test_name, val, bits, i,
2571						params.fast_period.unwrap_or(10),
2572						params.slow_period.unwrap_or(50),
2573						params.devup.unwrap_or(1.0),
2574						params.devdn.unwrap_or(1.0),
2575						params.fast_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2576						params.slow_ma_type.as_ref().unwrap_or(&"sma".to_string()),
2577						param_idx
2578					);
2579                }
2580            }
2581        }
2582
2583        Ok(())
2584    }
2585
2586    #[cfg(not(debug_assertions))]
2587    fn check_mab_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2588        Ok(())
2589    }
2590
2591    #[cfg(debug_assertions)]
2592    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2593        skip_if_unsupported!(kernel, test);
2594
2595        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2596        let c = read_candles_from_csv(file)?;
2597
2598        let test_configs = vec![
2599            ((2, 10, 2), (10, 20, 5), (1.0, 1.0, 0.0), (1.0, 1.0, 0.0)),
2600            ((5, 15, 5), (20, 40, 10), (0.5, 2.0, 0.5), (0.5, 2.0, 0.5)),
2601            (
2602                (20, 40, 10),
2603                (50, 100, 25),
2604                (1.0, 3.0, 1.0),
2605                (1.0, 3.0, 1.0),
2606            ),
2607            ((2, 5, 1), (6, 10, 1), (1.0, 2.0, 0.5), (1.0, 2.0, 0.5)),
2608            ((10, 10, 0), (20, 50, 10), (1.0, 1.0, 0.0), (1.0, 1.0, 0.0)),
2609            ((5, 20, 5), (50, 50, 0), (2.0, 2.0, 0.0), (2.0, 2.0, 0.0)),
2610            ((8, 12, 2), (26, 26, 0), (1.0, 3.0, 0.5), (0.5, 2.0, 0.5)),
2611        ];
2612
2613        for (cfg_idx, &(fast_range, slow_range, devup_range, devdn_range)) in
2614            test_configs.iter().enumerate()
2615        {
2616            let sweep = MabBatchRange {
2617                fast_period: fast_range,
2618                slow_period: slow_range,
2619                devup: devup_range,
2620                devdn: devdn_range,
2621                fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
2622                slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
2623            };
2624
2625            let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
2626
2627            for (idx, &val) in output.upperbands.iter().enumerate() {
2628                if val.is_nan() {
2629                    continue;
2630                }
2631
2632                let bits = val.to_bits();
2633                let row = idx / output.cols;
2634                let col = idx % output.cols;
2635                let combo = &output.combos[row];
2636
2637                if bits == 0x11111111_11111111 {
2638                    panic!(
2639						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in upperbands \
2640						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2641						test, cfg_idx, val, bits, row, col, idx,
2642						combo.fast_period.unwrap_or(10),
2643						combo.slow_period.unwrap_or(50),
2644						combo.devup.unwrap_or(1.0),
2645						combo.devdn.unwrap_or(1.0)
2646					);
2647                }
2648
2649                if bits == 0x22222222_22222222 {
2650                    panic!(
2651						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in upperbands \
2652						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2653						test, cfg_idx, val, bits, row, col, idx,
2654						combo.fast_period.unwrap_or(10),
2655						combo.slow_period.unwrap_or(50),
2656						combo.devup.unwrap_or(1.0),
2657						combo.devdn.unwrap_or(1.0)
2658					);
2659                }
2660
2661                if bits == 0x33333333_33333333 {
2662                    panic!(
2663						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in upperbands \
2664						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2665						test, cfg_idx, val, bits, row, col, idx,
2666						combo.fast_period.unwrap_or(10),
2667						combo.slow_period.unwrap_or(50),
2668						combo.devup.unwrap_or(1.0),
2669						combo.devdn.unwrap_or(1.0)
2670					);
2671                }
2672            }
2673
2674            for (idx, &val) in output.middlebands.iter().enumerate() {
2675                if val.is_nan() {
2676                    continue;
2677                }
2678
2679                let bits = val.to_bits();
2680                let row = idx / output.cols;
2681                let col = idx % output.cols;
2682                let combo = &output.combos[row];
2683
2684                if bits == 0x11111111_11111111 {
2685                    panic!(
2686						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in middlebands \
2687						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2688						test, cfg_idx, val, bits, row, col, idx,
2689						combo.fast_period.unwrap_or(10),
2690						combo.slow_period.unwrap_or(50),
2691						combo.devup.unwrap_or(1.0),
2692						combo.devdn.unwrap_or(1.0)
2693					);
2694                }
2695
2696                if bits == 0x22222222_22222222 {
2697                    panic!(
2698						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in middlebands \
2699						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2700						test, cfg_idx, val, bits, row, col, idx,
2701						combo.fast_period.unwrap_or(10),
2702						combo.slow_period.unwrap_or(50),
2703						combo.devup.unwrap_or(1.0),
2704						combo.devdn.unwrap_or(1.0)
2705					);
2706                }
2707
2708                if bits == 0x33333333_33333333 {
2709                    panic!(
2710						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in middlebands \
2711						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2712						test, cfg_idx, val, bits, row, col, idx,
2713						combo.fast_period.unwrap_or(10),
2714						combo.slow_period.unwrap_or(50),
2715						combo.devup.unwrap_or(1.0),
2716						combo.devdn.unwrap_or(1.0)
2717					);
2718                }
2719            }
2720
2721            for (idx, &val) in output.lowerbands.iter().enumerate() {
2722                if val.is_nan() {
2723                    continue;
2724                }
2725
2726                let bits = val.to_bits();
2727                let row = idx / output.cols;
2728                let col = idx % output.cols;
2729                let combo = &output.combos[row];
2730
2731                if bits == 0x11111111_11111111 {
2732                    panic!(
2733						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) in lowerbands \
2734						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2735						test, cfg_idx, val, bits, row, col, idx,
2736						combo.fast_period.unwrap_or(10),
2737						combo.slow_period.unwrap_or(50),
2738						combo.devup.unwrap_or(1.0),
2739						combo.devdn.unwrap_or(1.0)
2740					);
2741                }
2742
2743                if bits == 0x22222222_22222222 {
2744                    panic!(
2745						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) in lowerbands \
2746						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2747						test, cfg_idx, val, bits, row, col, idx,
2748						combo.fast_period.unwrap_or(10),
2749						combo.slow_period.unwrap_or(50),
2750						combo.devup.unwrap_or(1.0),
2751						combo.devdn.unwrap_or(1.0)
2752					);
2753                }
2754
2755                if bits == 0x33333333_33333333 {
2756                    panic!(
2757						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) in lowerbands \
2758						 at row {} col {} (flat index {}) with params: fast_period={}, slow_period={}, devup={}, devdn={}",
2759						test, cfg_idx, val, bits, row, col, idx,
2760						combo.fast_period.unwrap_or(10),
2761						combo.slow_period.unwrap_or(50),
2762						combo.devup.unwrap_or(1.0),
2763						combo.devdn.unwrap_or(1.0)
2764					);
2765                }
2766            }
2767        }
2768
2769        Ok(())
2770    }
2771
2772    #[cfg(not(debug_assertions))]
2773    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2774        Ok(())
2775    }
2776
2777    #[cfg(feature = "proptest")]
2778    #[allow(clippy::float_cmp)]
2779    fn check_mab_property(
2780        test_name: &str,
2781        kernel: Kernel,
2782    ) -> Result<(), Box<dyn std::error::Error>> {
2783        use proptest::prelude::*;
2784        skip_if_unsupported!(kernel, test_name);
2785
2786        let strat = (2usize..=50).prop_flat_map(|slow_period| {
2787            (2usize..=slow_period).prop_flat_map(move |fast_period| {
2788                (
2789                    prop::collection::vec(
2790                        (1f64..1000f64).prop_filter("finite", |x| x.is_finite()),
2791                        slow_period..400,
2792                    ),
2793                    Just(fast_period),
2794                    Just(slow_period),
2795                    0.5f64..3.0f64,
2796                    0.5f64..3.0f64,
2797                    prop::bool::ANY,
2798                    prop::bool::ANY,
2799                )
2800            })
2801        });
2802
2803        proptest::test_runner::TestRunner::default()
2804			.run(&strat, |(data, fast_period, slow_period, devup, devdn, fast_is_ema, slow_is_ema)| {
2805				let params = MabParams {
2806					fast_period: Some(fast_period),
2807					slow_period: Some(slow_period),
2808					devup: Some(devup),
2809					devdn: Some(devdn),
2810					fast_ma_type: Some(if fast_is_ema { "ema" } else { "sma" }.to_string()),
2811					slow_ma_type: Some(if slow_is_ema { "ema" } else { "sma" }.to_string()),
2812				};
2813				let input = MabInput::from_slice(&data, params.clone());
2814
2815
2816				let result = mab_with_kernel(&input, kernel).unwrap();
2817
2818
2819				let ref_params = params.clone();
2820				let ref_input = MabInput::from_slice(&data, ref_params);
2821				let ref_result = mab_with_kernel(&ref_input, Kernel::Scalar).unwrap();
2822
2823
2824				let first_valid_idx = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2825				let warmup_period = first_valid_idx + fast_period.max(slow_period) - 1;
2826
2827				for i in 0..data.len() {
2828					let upper = result.upperband[i];
2829					let middle = result.middleband[i];
2830					let lower = result.lowerband[i];
2831					let ref_upper = ref_result.upperband[i];
2832					let ref_middle = ref_result.middleband[i];
2833					let ref_lower = ref_result.lowerband[i];
2834
2835
2836					if upper.is_nan() {
2837						prop_assert!(ref_upper.is_nan(),
2838							"[{}] NaN mismatch in upperband at idx {}: kernel={:?} has NaN but scalar doesn't",
2839							test_name, i, kernel);
2840					}
2841					if middle.is_nan() {
2842						prop_assert!(ref_middle.is_nan(),
2843							"[{}] NaN mismatch in middleband at idx {}: kernel={:?} has NaN but scalar doesn't",
2844							test_name, i, kernel);
2845					}
2846					if lower.is_nan() {
2847						prop_assert!(ref_lower.is_nan(),
2848							"[{}] NaN mismatch in lowerband at idx {}: kernel={:?} has NaN but scalar doesn't",
2849							test_name, i, kernel);
2850					}
2851
2852
2853					if upper.is_finite() && ref_upper.is_finite() {
2854						let ulp_diff = upper.to_bits().abs_diff(ref_upper.to_bits());
2855						prop_assert!(
2856							(upper - ref_upper).abs() <= 1e-9 || ulp_diff <= 8,
2857							"[{}] Upperband mismatch at idx {}: {} vs {} (ULP={})",
2858							test_name, i, upper, ref_upper, ulp_diff
2859						);
2860					}
2861					if middle.is_finite() && ref_middle.is_finite() {
2862						let ulp_diff = middle.to_bits().abs_diff(ref_middle.to_bits());
2863						prop_assert!(
2864							(middle - ref_middle).abs() <= 1e-9 || ulp_diff <= 8,
2865							"[{}] Middleband mismatch at idx {}: {} vs {} (ULP={})",
2866							test_name, i, middle, ref_middle, ulp_diff
2867						);
2868					}
2869					if lower.is_finite() && ref_lower.is_finite() {
2870						let ulp_diff = lower.to_bits().abs_diff(ref_lower.to_bits());
2871						prop_assert!(
2872							(lower - ref_lower).abs() <= 1e-9 || ulp_diff <= 8,
2873							"[{}] Lowerband mismatch at idx {}: {} vs {} (ULP={})",
2874							test_name, i, lower, ref_lower, ulp_diff
2875						);
2876					}
2877				}
2878
2879
2880				for i in 0..warmup_period.min(data.len()) {
2881					prop_assert!(
2882						result.upperband[i].is_nan(),
2883						"[{}] Expected NaN in upperband during warmup at idx {} (warmup={})",
2884						test_name, i, warmup_period
2885					);
2886					prop_assert!(
2887						result.middleband[i].is_nan(),
2888						"[{}] Expected NaN in middleband during warmup at idx {} (warmup={})",
2889						test_name, i, warmup_period
2890					);
2891					prop_assert!(
2892						result.lowerband[i].is_nan(),
2893						"[{}] Expected NaN in lowerband during warmup at idx {} (warmup={})",
2894						test_name, i, warmup_period
2895					);
2896				}
2897
2898
2899
2900				let first_valid_output = warmup_period + fast_period - 1;
2901				if first_valid_output < data.len() {
2902					for i in first_valid_output..data.len() {
2903						prop_assert!(
2904							result.upperband[i].is_finite(),
2905							"[{}] Non-finite value in upperband at idx {} after warmup",
2906							test_name, i
2907						);
2908						prop_assert!(
2909							result.middleband[i].is_finite(),
2910							"[{}] Non-finite value in middleband at idx {} after warmup",
2911							test_name, i
2912						);
2913						prop_assert!(
2914							result.lowerband[i].is_finite(),
2915							"[{}] Non-finite value in lowerband at idx {} after warmup",
2916							test_name, i
2917						);
2918					}
2919				}
2920
2921
2922				for i in first_valid_output..data.len() {
2923					let upper = result.upperband[i];
2924					let middle = result.middleband[i];
2925					let lower = result.lowerband[i];
2926
2927					if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2928						prop_assert!(
2929							upper >= middle - 1e-10,
2930							"[{}] Band ordering violated: upper {} < middle {} at idx {}",
2931							test_name, upper, middle, i
2932						);
2933						prop_assert!(
2934							middle >= lower - 1e-10,
2935							"[{}] Band ordering violated: middle {} < lower {} at idx {}",
2936							test_name, middle, lower, i
2937						);
2938					}
2939				}
2940
2941
2942				if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && data.len() > first_valid_output {
2943
2944					for i in first_valid_output..data.len() {
2945						let upper = result.upperband[i];
2946						let middle = result.middleband[i];
2947						let lower = result.lowerband[i];
2948
2949						if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2950							prop_assert!(
2951								(upper - middle).abs() <= 1e-9,
2952								"[{}] Constant data: upper {} != middle {} at idx {}",
2953								test_name, upper, middle, i
2954							);
2955							prop_assert!(
2956								(middle - lower).abs() <= 1e-9,
2957								"[{}] Constant data: middle {} != lower {} at idx {}",
2958								test_name, middle, lower, i
2959							);
2960						}
2961					}
2962				}
2963
2964
2965
2966				for i in first_valid_output..data.len() {
2967					let upper = result.upperband[i];
2968					let middle = result.middleband[i];
2969					let lower = result.lowerband[i];
2970
2971					if upper.is_finite() && middle.is_finite() && lower.is_finite() {
2972						let upper_spread = upper - middle;
2973						let lower_spread = middle - lower;
2974
2975
2976						if upper_spread > 1e-10 && lower_spread > 1e-10 {
2977							let spread_ratio = upper_spread / lower_spread;
2978							let multiplier_ratio = devup / devdn;
2979
2980							prop_assert!(
2981								(spread_ratio - multiplier_ratio).abs() <= multiplier_ratio * 0.05,
2982								"[{}] Deviation multiplier ratio mismatch at idx {}: spread_ratio={} vs multiplier_ratio={}",
2983								test_name, i, spread_ratio, multiplier_ratio
2984							);
2985						}
2986					}
2987				}
2988
2989
2990
2991				use crate::indicators::sma::{sma, SmaInput, SmaParams};
2992				use crate::indicators::ema::{ema, EmaInput, EmaParams};
2993
2994				let fast_ma = if fast_is_ema {
2995					let ema_params = EmaParams { period: Some(fast_period) };
2996					let ema_input = EmaInput::from_slice(&data, ema_params);
2997					ema(&ema_input).unwrap().values
2998				} else {
2999					let sma_params = SmaParams { period: Some(fast_period) };
3000					let sma_input = SmaInput::from_slice(&data, sma_params);
3001					sma(&sma_input).unwrap().values
3002				};
3003
3004				for i in first_valid_output..data.len() {
3005					if result.middleband[i].is_finite() && fast_ma[i].is_finite() {
3006						prop_assert!(
3007							(result.middleband[i] - fast_ma[i]).abs() <= 1e-9,
3008							"[{}] Middle band != fast MA at idx {}: {} vs {}",
3009							test_name, i, result.middleband[i], fast_ma[i]
3010						);
3011					}
3012				}
3013
3014
3015
3016				for i in first_valid_output..data.len() {
3017					let upper = result.upperband[i];
3018					let middle = result.middleband[i];
3019					let lower = result.lowerband[i];
3020
3021					if upper.is_finite() && middle.is_finite() && lower.is_finite() {
3022						let upper_spread = upper - middle;
3023						let lower_spread = middle - lower;
3024
3025						prop_assert!(
3026							upper_spread >= -1e-10,
3027							"[{}] Negative upper spread at idx {}: {}",
3028							test_name, i, upper_spread
3029						);
3030						prop_assert!(
3031							lower_spread >= -1e-10,
3032							"[{}] Negative lower spread at idx {}: {}",
3033							test_name, i, lower_spread
3034						);
3035
3036
3037						let data_range = data.iter()
3038							.filter(|x| x.is_finite())
3039							.fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), &x| {
3040								(min.min(x), max.max(x))
3041							});
3042						let range_span = data_range.1 - data_range.0;
3043
3044
3045
3046						if range_span > 0.0 {
3047							prop_assert!(
3048								upper_spread <= range_span * devup * 10.0,
3049								"[{}] Upper spread unreasonably large at idx {}: {} (range_span={})",
3050								test_name, i, upper_spread, range_span
3051							);
3052							prop_assert!(
3053								lower_spread <= range_span * devdn * 10.0,
3054								"[{}] Lower spread unreasonably large at idx {}: {} (range_span={})",
3055								test_name, i, lower_spread, range_span
3056							);
3057						}
3058					}
3059				}
3060
3061
3062
3063				let invalid_params = vec![
3064					MabParams {
3065						fast_period: Some(0),
3066						slow_period: Some(10),
3067						..Default::default()
3068					},
3069					MabParams {
3070						fast_period: Some(10),
3071						slow_period: Some(0),
3072						..Default::default()
3073					},
3074					MabParams {
3075						fast_period: Some(data.len() + 1),
3076						slow_period: Some(10),
3077						..Default::default()
3078					},
3079				];
3080
3081				for invalid_param in invalid_params {
3082					let invalid_input = MabInput::from_slice(&data, invalid_param);
3083					let invalid_result = mab_with_kernel(&invalid_input, kernel);
3084					prop_assert!(
3085						invalid_result.is_err(),
3086						"[{}] Expected error for invalid parameters but got Ok",
3087						test_name
3088					);
3089				}
3090
3091
3092
3093
3094				if fast_period == slow_period && data.len() > first_valid_output {
3095
3096					let equal_params = MabParams {
3097						fast_period: Some(fast_period),
3098						slow_period: Some(fast_period),
3099						devup: Some(devup),
3100						devdn: Some(devdn),
3101						fast_ma_type: params.fast_ma_type.clone(),
3102						slow_ma_type: params.slow_ma_type.clone(),
3103					};
3104					let equal_input = MabInput::from_slice(&data, equal_params);
3105					let equal_result = mab_with_kernel(&equal_input, kernel).unwrap();
3106
3107
3108
3109					if params.fast_ma_type == params.slow_ma_type {
3110						for i in first_valid_output..data.len().min(first_valid_output + 10) {
3111							if equal_result.upperband[i].is_finite() &&
3112							   equal_result.middleband[i].is_finite() &&
3113							   equal_result.lowerband[i].is_finite() {
3114								let upper_spread = equal_result.upperband[i] - equal_result.middleband[i];
3115								let lower_spread = equal_result.middleband[i] - equal_result.lowerband[i];
3116
3117
3118								prop_assert!(
3119									upper_spread <= 1e-6 || upper_spread <= equal_result.middleband[i].abs() * 1e-6,
3120									"[{}] Equal periods: upper spread too large at idx {}: {}",
3121									test_name, i, upper_spread
3122								);
3123								prop_assert!(
3124									lower_spread <= 1e-6 || lower_spread <= equal_result.middleband[i].abs() * 1e-6,
3125									"[{}] Equal periods: lower spread too large at idx {}: {}",
3126									test_name, i, lower_spread
3127								);
3128							}
3129						}
3130					}
3131				}
3132
3133				Ok(())
3134			})?;
3135
3136        Ok(())
3137    }
3138
3139    fn check_mab_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3140        skip_if_unsupported!(kernel, test_name);
3141        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3142        let candles = read_candles_from_csv(file_path)?;
3143        let default_params = MabParams {
3144            fast_period: None,
3145            ..MabParams::default()
3146        };
3147        let input = MabInput::from_candles(&candles, "close", default_params);
3148        let output = mab_with_kernel(&input, kernel)?;
3149        assert_eq!(output.upperband.len(), candles.close.len());
3150        Ok(())
3151    }
3152
3153    fn check_mab_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3154        skip_if_unsupported!(kernel, test_name);
3155        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3156        let candles = read_candles_from_csv(file_path)?;
3157        let params = MabParams::default();
3158        let input = MabInput::from_candles(&candles, "close", params);
3159        let result = mab_with_kernel(&input, kernel)?;
3160
3161        let expected_upper_last_five = [
3162            64002.843463352016,
3163            63976.62699738246,
3164            63949.00496307154,
3165            63912.13708526151,
3166            63828.40371728143,
3167        ];
3168        let expected_middle_last_five = [
3169            59213.90000000002,
3170            59180.800000000025,
3171            59161.40000000002,
3172            59132.00000000002,
3173            59042.40000000002,
3174        ];
3175        let expected_lower_last_five = [
3176            59350.676536647945,
3177            59296.93300261751,
3178            59252.75503692843,
3179            59190.30291473845,
3180            59070.11628271853,
3181        ];
3182
3183        let len = result.upperband.len();
3184        for i in 0..5 {
3185            let idx = len - 5 + i;
3186            assert!(
3187                (result.upperband[idx] - expected_upper_last_five[i]).abs() < 1e-4,
3188                "[{}] Upper band mismatch at index {}: {} vs expected {}",
3189                test_name,
3190                i,
3191                result.upperband[idx],
3192                expected_upper_last_five[i]
3193            );
3194            assert!(
3195                (result.middleband[idx] - expected_middle_last_five[i]).abs() < 1e-4,
3196                "[{}] Middle band mismatch at index {}: {} vs expected {}",
3197                test_name,
3198                i,
3199                result.middleband[idx],
3200                expected_middle_last_five[i]
3201            );
3202            assert!(
3203                (result.lowerband[idx] - expected_lower_last_five[i]).abs() < 1e-4,
3204                "[{}] Lower band mismatch at index {}: {} vs expected {}",
3205                test_name,
3206                i,
3207                result.lowerband[idx],
3208                expected_lower_last_five[i]
3209            );
3210        }
3211        Ok(())
3212    }
3213
3214    fn check_mab_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3215        skip_if_unsupported!(kernel, test_name);
3216        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3217        let candles = read_candles_from_csv(file_path)?;
3218        let input = MabInput::with_default_candles(&candles);
3219        let output = mab_with_kernel(&input, kernel)?;
3220        assert_eq!(output.upperband.len(), candles.close.len());
3221        Ok(())
3222    }
3223
3224    fn check_mab_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3225        skip_if_unsupported!(kernel, test_name);
3226        let input_data = [10.0, 20.0, 30.0];
3227        let params = MabParams {
3228            fast_period: Some(0),
3229            slow_period: Some(5),
3230            ..MabParams::default()
3231        };
3232        let input = MabInput::from_slice(&input_data, params);
3233        let res = mab_with_kernel(&input, kernel);
3234        assert!(
3235            res.is_err(),
3236            "[{}] Expected error for zero fast period",
3237            test_name
3238        );
3239
3240        let params2 = MabParams {
3241            fast_period: Some(5),
3242            slow_period: Some(0),
3243            ..MabParams::default()
3244        };
3245        let input2 = MabInput::from_slice(&input_data, params2);
3246        let res2 = mab_with_kernel(&input2, kernel);
3247        assert!(
3248            res2.is_err(),
3249            "[{}] Expected error for zero slow period",
3250            test_name
3251        );
3252        Ok(())
3253    }
3254
3255    fn check_mab_period_exceeds_length(
3256        test_name: &str,
3257        kernel: Kernel,
3258    ) -> Result<(), Box<dyn Error>> {
3259        skip_if_unsupported!(kernel, test_name);
3260        let data_small = [10.0, 20.0, 30.0];
3261        let params = MabParams {
3262            fast_period: Some(2),
3263            slow_period: Some(10),
3264            ..MabParams::default()
3265        };
3266        let input = MabInput::from_slice(&data_small, params);
3267        let res = mab_with_kernel(&input, kernel);
3268        assert!(
3269            res.is_err(),
3270            "[{}] Expected error when period exceeds data length",
3271            test_name
3272        );
3273        Ok(())
3274    }
3275
3276    fn check_mab_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3277        skip_if_unsupported!(kernel, test_name);
3278        let single_point = [42.0];
3279        let params = MabParams {
3280            fast_period: Some(10),
3281            slow_period: Some(20),
3282            ..MabParams::default()
3283        };
3284        let input = MabInput::from_slice(&single_point, params);
3285        let res = mab_with_kernel(&input, kernel);
3286        assert!(
3287            res.is_err(),
3288            "[{}] Expected error for insufficient data",
3289            test_name
3290        );
3291        Ok(())
3292    }
3293
3294    fn check_mab_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3295        skip_if_unsupported!(kernel, test_name);
3296        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3297        let candles = read_candles_from_csv(file_path)?;
3298        let params = MabParams::default();
3299        let first_input = MabInput::from_candles(&candles, "close", params.clone());
3300        let first_result = mab_with_kernel(&first_input, kernel)?;
3301
3302        let second_input = MabInput::from_slice(&first_result.upperband, params);
3303        let second_result = mab_with_kernel(&second_input, kernel)?;
3304        assert_eq!(second_result.upperband.len(), first_result.upperband.len());
3305
3306        let non_nan_count = second_result
3307            .upperband
3308            .iter()
3309            .skip(100)
3310            .filter(|x| !x.is_nan())
3311            .count();
3312        assert!(
3313            non_nan_count > 0,
3314            "[{}] Second calculation produced all NaN values",
3315            test_name
3316        );
3317        Ok(())
3318    }
3319
3320    fn check_mab_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3321        skip_if_unsupported!(kernel, test_name);
3322        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3323        let candles = read_candles_from_csv(file_path)?;
3324        let input = MabInput::from_candles(&candles, "close", MabParams::default());
3325        let res = mab_with_kernel(&input, kernel)?;
3326
3327        for i in 100..res.upperband.len().min(200) {
3328            assert!(
3329                !res.upperband[i].is_nan(),
3330                "[{}] Unexpected NaN in upper band at index {}",
3331                test_name,
3332                i
3333            );
3334            assert!(
3335                !res.middleband[i].is_nan(),
3336                "[{}] Unexpected NaN in middle band at index {}",
3337                test_name,
3338                i
3339            );
3340            assert!(
3341                !res.lowerband[i].is_nan(),
3342                "[{}] Unexpected NaN in lower band at index {}",
3343                test_name,
3344                i
3345            );
3346        }
3347        Ok(())
3348    }
3349
3350    #[allow(dead_code)]
3351    fn check_mab_streaming(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
3352        Ok(())
3353    }
3354
3355    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
3356        skip_if_unsupported!(kernel, test);
3357        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3358        let c = read_candles_from_csv(file)?;
3359
3360        let sweep = MabBatchRange {
3361            fast_period: (10, 10, 0),
3362            slow_period: (50, 50, 0),
3363            devup: (1.0, 1.0, 0.0),
3364            devdn: (1.0, 1.0, 0.0),
3365            fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3366            slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3367        };
3368
3369        let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
3370
3371        assert_eq!(
3372            output.rows, 1,
3373            "[{}] Expected 1 row for default params",
3374            test
3375        );
3376        assert_eq!(
3377            output.cols,
3378            c.close.len(),
3379            "[{}] Cols should match input length",
3380            test
3381        );
3382
3383        let expected_upper = [
3384            64002.843463352016,
3385            63976.62699738246,
3386            63949.00496307154,
3387            63912.13708526151,
3388            63828.40371728143,
3389        ];
3390        let expected_middle = [
3391            59213.90000000002,
3392            59180.800000000025,
3393            59161.40000000002,
3394            59132.00000000002,
3395            59042.40000000002,
3396        ];
3397        let expected_lower = [
3398            59350.676536647945,
3399            59296.93300261751,
3400            59252.75503692843,
3401            59190.30291473845,
3402            59070.11628271853,
3403        ];
3404
3405        let start = output.cols - 5;
3406        for i in 0..5 {
3407            let idx = start + i;
3408            assert!(
3409                (output.upperbands[idx] - expected_upper[i]).abs() < 1e-4,
3410                "[{}] batch upper mismatch at idx {}: {} vs expected {}",
3411                test,
3412                i,
3413                output.upperbands[idx],
3414                expected_upper[i]
3415            );
3416            assert!(
3417                (output.middlebands[idx] - expected_middle[i]).abs() < 1e-4,
3418                "[{}] batch middle mismatch at idx {}: {} vs expected {}",
3419                test,
3420                i,
3421                output.middlebands[idx],
3422                expected_middle[i]
3423            );
3424            assert!(
3425                (output.lowerbands[idx] - expected_lower[i]).abs() < 1e-4,
3426                "[{}] batch lower mismatch at idx {}: {} vs expected {}",
3427                test,
3428                i,
3429                output.lowerbands[idx],
3430                expected_lower[i]
3431            );
3432        }
3433        Ok(())
3434    }
3435
3436    fn check_batch_grid_varying_fast_period(
3437        test: &str,
3438        kernel: Kernel,
3439    ) -> Result<(), Box<dyn Error>> {
3440        skip_if_unsupported!(kernel, test);
3441        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
3442        let c = read_candles_from_csv(file)?;
3443
3444        let sweep = MabBatchRange {
3445            fast_period: (10, 12, 1),
3446            slow_period: (50, 50, 0),
3447            devup: (1.0, 1.0, 0.0),
3448            devdn: (1.0, 1.0, 0.0),
3449            fast_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3450            slow_ma_type: ("sma".to_string(), "sma".to_string(), String::new()),
3451        };
3452
3453        let output = mab_batch_inner(c.close.as_slice(), &sweep, kernel, false)?;
3454
3455        assert_eq!(
3456            output.rows, 3,
3457            "[{}] Expected 3 rows for fast period 10-12",
3458            test
3459        );
3460        assert_eq!(
3461            output.combos.len(),
3462            3,
3463            "[{}] Expected 3 parameter combinations",
3464            test
3465        );
3466
3467        assert_eq!(
3468            output.combos[0].fast_period,
3469            Some(10),
3470            "[{}] First combo fast period",
3471            test
3472        );
3473        assert_eq!(
3474            output.combos[1].fast_period,
3475            Some(11),
3476            "[{}] Second combo fast period",
3477            test
3478        );
3479        assert_eq!(
3480            output.combos[2].fast_period,
3481            Some(12),
3482            "[{}] Third combo fast period",
3483            test
3484        );
3485
3486        for row in 0..3 {
3487            let row_start = row * output.cols;
3488            let row_data = &output.upperbands[row_start..row_start + output.cols];
3489
3490            let valid_count = row_data.iter().skip(100).filter(|x| !x.is_nan()).count();
3491            assert!(
3492                valid_count > 0,
3493                "[{}] Row {} should have valid values",
3494                test,
3495                row
3496            );
3497        }
3498        Ok(())
3499    }
3500
3501    macro_rules! generate_all_mab_tests {
3502		($($test_fn:ident),*) => {
3503			paste::paste! {
3504				$(
3505					#[test]
3506					fn [<$test_fn _scalar_f64>]() {
3507						let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
3508					}
3509				)*
3510				#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3511				$(
3512					#[test]
3513					fn [<$test_fn _avx2_f64>]() {
3514						let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
3515					}
3516					#[test]
3517					fn [<$test_fn _avx512_f64>]() {
3518						let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
3519					}
3520				)*
3521			}
3522		}
3523	}
3524
3525    macro_rules! gen_batch_tests {
3526        ($fn_name:ident) => {
3527            paste::paste! {
3528                #[test] fn [<$fn_name _scalar>]() {
3529                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
3530                }
3531                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3532                #[test] fn [<$fn_name _avx2>]() {
3533                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
3534                }
3535                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3536                #[test] fn [<$fn_name _avx512>]() {
3537                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
3538                }
3539                #[test] fn [<$fn_name _auto_detect>]() {
3540                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
3541                }
3542            }
3543        };
3544    }
3545
    // Single-series checks, expanded into per-kernel `#[test]` functions by
    // `generate_all_mab_tests!`. Keep the list order stable: the macro's expansion order
    // determines which generated item any leading attribute in the macro attaches to.
    generate_all_mab_tests!(
        check_mab_no_poison,
        check_mab_partial_params,
        check_mab_accuracy,
        check_mab_default_candles,
        check_mab_zero_period,
        check_mab_period_exceeds_length,
        check_mab_very_small_dataset,
        check_mab_reinput,
        check_mab_nan_handling
    );

    // The property test (and its generated tests) exists only with the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_mab_tests!(check_mab_property);

    // Batch-path checks, expanded into per-batch-kernel `#[test]` functions.
    gen_batch_tests!(check_batch_no_poison);
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_grid_varying_fast_period);
3564}