Skip to main content

vector_ta/indicators/moving_averages/
sma.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(not(target_arch = "wasm32"))]
13use rayon::prelude::*;
14use std::convert::AsRef;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::{CudaSma, DeviceArrayF32};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use cust::context::Context;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use cust::memory::DeviceBuffer;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
26#[cfg(feature = "python")]
27use pyo3::exceptions::PyValueError;
28#[cfg(feature = "python")]
29use pyo3::prelude::*;
30#[cfg(feature = "python")]
31use pyo3::types::PyDict;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34
35#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
36use serde::{Deserialize, Serialize};
37#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
38use wasm_bindgen::prelude::*;
39
40impl<'a> AsRef<[f64]> for SmaInput<'a> {
41    #[inline(always)]
42    fn as_ref(&self) -> &[f64] {
43        match &self.data {
44            SmaData::Slice(slice) => slice,
45            SmaData::Candles { candles, source } => source_type(candles, source),
46        }
47    }
48}
49
50#[derive(Debug, Clone)]
51pub enum SmaData<'a> {
52    Candles {
53        candles: &'a Candles,
54        source: &'a str,
55    },
56    Slice(&'a [f64]),
57}
58
/// Result of a single-series SMA run: one value per input sample.
#[derive(Clone, Debug)]
pub struct SmaOutput {
    /// Moving-average values, aligned index-for-index with the input.
    pub values: Vec<f64>,
}
63
/// Tunable parameters of the simple moving average.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SmaParams {
    /// Window length in samples; `None` means "use the default of 9".
    pub period: Option<usize>,
}

impl Default for SmaParams {
    /// A nine-sample window.
    fn default() -> Self {
        let period = Some(9);
        SmaParams { period }
    }
}
78
79#[derive(Debug, Clone)]
80pub struct SmaInput<'a> {
81    pub data: SmaData<'a>,
82    pub params: SmaParams,
83}
84
85impl<'a> SmaInput<'a> {
86    #[inline]
87    pub fn from_candles(c: &'a Candles, s: &'a str, p: SmaParams) -> Self {
88        Self {
89            data: SmaData::Candles {
90                candles: c,
91                source: s,
92            },
93            params: p,
94        }
95    }
96    #[inline]
97    pub fn from_slice(sl: &'a [f64], p: SmaParams) -> Self {
98        Self {
99            data: SmaData::Slice(sl),
100            params: p,
101        }
102    }
103    #[inline]
104    pub fn with_default_candles(c: &'a Candles) -> Self {
105        Self::from_candles(c, "close", SmaParams::default())
106    }
107    #[inline]
108    pub fn get_period(&self) -> usize {
109        self.params.period.unwrap_or(9)
110    }
111}
112
113#[derive(Copy, Clone, Debug)]
114pub struct SmaBuilder {
115    period: Option<usize>,
116    kernel: Kernel,
117}
118
119impl Default for SmaBuilder {
120    fn default() -> Self {
121        Self {
122            period: None,
123            kernel: Kernel::Auto,
124        }
125    }
126}
127
128impl SmaBuilder {
129    #[inline(always)]
130    pub fn new() -> Self {
131        Self::default()
132    }
133    #[inline(always)]
134    pub fn period(mut self, n: usize) -> Self {
135        self.period = Some(n);
136        self
137    }
138    #[inline(always)]
139    pub fn kernel(mut self, k: Kernel) -> Self {
140        self.kernel = k;
141        self
142    }
143    #[inline(always)]
144    pub fn apply(self, c: &Candles) -> Result<SmaOutput, SmaError> {
145        let p = SmaParams {
146            period: self.period,
147        };
148        let i = SmaInput::from_candles(c, "close", p);
149        sma_with_kernel(&i, self.kernel)
150    }
151    #[inline(always)]
152    pub fn apply_slice(self, d: &[f64]) -> Result<SmaOutput, SmaError> {
153        let p = SmaParams {
154            period: self.period,
155        };
156        let i = SmaInput::from_slice(d, p);
157        sma_with_kernel(&i, self.kernel)
158    }
159    #[inline(always)]
160    pub fn into_stream(self) -> Result<SmaStream, SmaError> {
161        let p = SmaParams {
162            period: self.period,
163        };
164        SmaStream::try_new(p)
165    }
166}
167
168#[derive(Debug, Error)]
169pub enum SmaError {
170    #[error("sma: Empty input data.")]
171    EmptyInputData,
172    #[error("sma: Invalid period: period = {period}, data length = {data_len}")]
173    InvalidPeriod { period: usize, data_len: usize },
174    #[error("sma: Not enough valid data: needed = {needed}, valid = {valid}")]
175    NotEnoughValidData { needed: usize, valid: usize },
176    #[error("sma: All values are NaN.")]
177    AllValuesNaN,
178    #[error("sma: Output buffer size mismatch: expected = {expected}, got = {got}")]
179    OutputLengthMismatch { expected: usize, got: usize },
180    #[error("sma: Invalid range: start={start}, end={end}, step={step}")]
181    InvalidRange {
182        start: usize,
183        end: usize,
184        step: usize,
185    },
186    #[error("sma: Invalid kernel for batch: {0:?}")]
187    InvalidKernelForBatch(Kernel),
188}
189
190#[inline]
191pub fn sma(input: &SmaInput) -> Result<SmaOutput, SmaError> {
192    sma_with_kernel(input, Kernel::Auto)
193}
194
195pub fn sma_with_kernel(input: &SmaInput, kernel: Kernel) -> Result<SmaOutput, SmaError> {
196    let (data, period, first, chosen) = sma_prepare(input, kernel)?;
197    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
198    sma_compute_into(data, period, first, chosen, &mut out);
199    Ok(SmaOutput { values: out })
200}
201
202#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
203#[inline]
204pub fn sma_into(input: &SmaInput, out: &mut [f64]) -> Result<(), SmaError> {
205    let (data, period, first, chosen) = sma_prepare(input, Kernel::Auto)?;
206
207    if out.len() != data.len() {
208        return Err(SmaError::OutputLengthMismatch {
209            expected: data.len(),
210            got: out.len(),
211        });
212    }
213
214    let warm = (first + period - 1).min(out.len());
215    for v in &mut out[..warm] {
216        *v = f64::from_bits(0x7ff8_0000_0000_0000);
217    }
218
219    sma_compute_into(data, period, first, chosen, out);
220    Ok(())
221}
222
223#[inline]
224pub fn sma_into_slice(dst: &mut [f64], input: &SmaInput, kern: Kernel) -> Result<(), SmaError> {
225    let (data, period, first, chosen) = sma_prepare(input, kern)?;
226
227    if dst.len() != data.len() {
228        return Err(SmaError::OutputLengthMismatch {
229            expected: data.len(),
230            got: dst.len(),
231        });
232    }
233
234    let warmup = first + period - 1;
235    for v in &mut dst[..warmup] {
236        *v = f64::NAN;
237    }
238
239    sma_compute_into(data, period, first, chosen, dst);
240
241    Ok(())
242}
243
244#[inline(always)]
245fn sma_prepare<'a>(
246    input: &'a SmaInput,
247    kernel: Kernel,
248) -> Result<(&'a [f64], usize, usize, Kernel), SmaError> {
249    let data: &[f64] = input.as_ref();
250    if data.is_empty() {
251        return Err(SmaError::EmptyInputData);
252    }
253
254    let period = input.get_period();
255    let len = data.len();
256    if period == 0 || period > len {
257        return Err(SmaError::InvalidPeriod {
258            period,
259            data_len: len,
260        });
261    }
262
263    let first = data
264        .iter()
265        .position(|x| !x.is_nan())
266        .ok_or(SmaError::AllValuesNaN)?;
267    if len - first < period {
268        return Err(SmaError::NotEnoughValidData {
269            needed: period,
270            valid: len - first,
271        });
272    }
273
274    let chosen = match kernel {
275        Kernel::Auto => detect_best_kernel(),
276        k => k,
277    };
278    Ok((data, period, first, chosen))
279}
280
281#[inline]
282fn sma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
283    unsafe {
284        match kernel {
285            Kernel::Scalar | Kernel::ScalarBatch => {
286                sma_scalar(data, period, first, out);
287            }
288            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
289            Kernel::Avx2 | Kernel::Avx2Batch => {
290                sma_scalar(data, period, first, out);
291            }
292            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
293            Kernel::Avx512 | Kernel::Avx512Batch => {
294                sma_avx512(data, period, first, out);
295            }
296            _ => unreachable!(),
297        }
298    }
299}
300
/// Scalar rolling-sum SMA.
///
/// Seeds a running sum over `data[first..first + period]` and writes its mean
/// at index `first + period - 1`. It then slides the window one sample at a
/// time, adding the incoming value and subtracting the outgoing one. Entries of
/// `out` before `first + period - 1` are left untouched.
///
/// # Safety
/// Callers must pass `period >= 1`, `out.len() == data.len()` and
/// `first + period <= data.len()`.
#[inline(always)]
pub unsafe fn sma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();
    // Trimming `out` to the input length lets the optimiser elide the
    // per-element bounds checks in the loops below.
    let out = &mut out[..n];

    if period == 1 {
        // A one-sample window is the sample itself.
        out[first..].copy_from_slice(&data[first..]);
        return;
    }

    let seed_end = first + period;
    let mut window_sum = data[first..seed_end]
        .iter()
        .fold(0.0, |acc, &x| acc + x);
    let scale = 1.0 / (period as f64);
    out[seed_end - 1] = window_sum * scale;

    for idx in seed_end..n {
        // Incoming sample minus the one that just left the window.
        window_sum += data[idx] - data[idx - period];
        out[idx] = window_sum * scale;
    }
}
330
/// AVX2 sliding-window SMA.
///
/// 1. Seeds the window sum over `data[first..first + period]` four lanes at a
///    time, with a scalar remainder.
/// 2. Writes the first mean at `first + period - 1`.
/// 3. Produces four outputs per iteration. The per-lane deltas
///    `data[i] - data[i - period]` become an inclusive prefix sum inside the
///    register (each 128-bit half is prefix-summed, then the low half's total is
///    carried into the high half). The running sum is added, the lanes are
///    scaled by `1/period`, and the result is stored. The last lane becomes the
///    running sum for the next block.
/// 4. A scalar loop finishes the tail.
///
/// Entries of `out` before `first + period - 1` are not written.
///
/// # Safety
/// The CPU must support AVX2, `data.len() == out.len()`, and
/// `first + period <= data.len()`. All pointer accesses are unchecked; the
/// length relationship is only `debug_assert!`ed.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn sma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());

    let len = data.len();
    let dp = data.as_ptr();
    let op = out.as_mut_ptr();

    if period == 1 {
        // A one-sample window is the identity.
        let mut i = first;
        while i < len {
            *op.add(i) = *dp.add(i);
            i += 1;
        }
        return;
    }

    // Seed: vector-accumulate the largest multiple of 4 samples of the window.
    let mut acc256 = _mm256_setzero_pd();
    let mut k = 0usize;
    let base = first;
    let p4 = period & !3;

    while k < p4 {
        let v = _mm256_loadu_pd(dp.add(base + k));
        acc256 = _mm256_add_pd(acc256, v);
        k += 4;
    }

    // Horizontal reduction: hadd gives [a0+a1, a0+a1, a2+a3, a2+a3].
    let hadd = _mm256_hadd_pd(acc256, acc256);
    let lo = _mm256_castpd256_pd128(hadd);
    let hi = _mm256_extractf128_pd(hadd, 1);
    let sum128 = _mm_add_sd(lo, hi);
    let mut sum = _mm_cvtsd_f64(sum128);

    // Scalar remainder of the seed window.
    while k < period {
        sum += *dp.add(base + k);
        k += 1;
    }

    let inv = 1.0 / (period as f64);
    let inv_v = _mm256_set1_pd(inv);
    // `warm` is never reassigned (the original `let mut` tripped `unused_mut`).
    let warm = first + period - 1;
    *op.add(warm) = sum.mul_add(inv, 0.0);

    let mut i = warm + 1;
    let end = len;
    let stride = 4usize;

    while i + stride - 1 < end {
        let v_new = _mm256_loadu_pd(dp.add(i));
        let v_old = _mm256_loadu_pd(dp.add(i - period));
        let d = _mm256_sub_pd(v_new, v_old);

        let d_lo = _mm256_castpd256_pd128(d);
        let d_hi = _mm256_extractf128_pd(d, 1);

        // Within each half: [x0, x1] + [0, x0] = [x0, x0 + x1].
        let t_lo = _mm_unpacklo_pd(_mm_setzero_pd(), d_lo);
        let p_lo = _mm_add_pd(d_lo, t_lo);

        let t_hi = _mm_unpacklo_pd(_mm_setzero_pd(), d_hi);
        let mut p_hi = _mm_add_pd(d_hi, t_hi);

        // Carry the low half's total (d0 + d1) into both high lanes.
        let carry = _mm_permute_pd(p_lo, 0b11);
        p_hi = _mm_add_pd(p_hi, carry);

        let mut prefix = _mm256_castpd128_pd256(p_lo);
        prefix = _mm256_insertf128_pd(prefix, p_hi, 1);

        let sum_v = _mm256_set1_pd(sum);
        let sums = _mm256_add_pd(sum_v, prefix);

        let out_v = _mm256_mul_pd(sums, inv_v);
        _mm256_storeu_pd(op.add(i), out_v);

        // Lane 3 holds the window sum ending at i + 3: the new running sum.
        let sums_hi = _mm256_extractf128_pd(sums, 1);
        let last = _mm_unpackhi_pd(sums_hi, sums_hi);
        sum = _mm_cvtsd_f64(last);

        i += stride;
    }

    // Scalar tail for the final < 4 outputs.
    while i < end {
        sum += *dp.add(i) - *dp.add(i - period);
        *op.add(i) = sum.mul_add(inv, 0.0);
        i += 1;
    }
}
422
/// AVX-512 SMA entry point.
///
/// Periods up to 32 go to `sma_avx512_short`; longer ones go to
/// `sma_avx512_long`.
///
/// NOTE(review): this is a safe function that unconditionally executes
/// `avx512f` code. Callers must ensure the CPU supports AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let use_short = period <= 32;
    // SAFETY: relies on the caller having selected an AVX-512 kernel only on
    // capable hardware, and on the length invariants of the kernels.
    unsafe {
        if use_short {
            sma_avx512_short(data, period, first, out)
        } else {
            sma_avx512_long(data, period, first, out)
        }
    }
}
432
/// Short-period AVX-512 SMA.
///
/// There is no dedicated short-window kernel; this forwards to
/// `sma_avx512_long`.
///
/// # Safety
/// Same contract as `sma_avx512_long`: the CPU must support AVX-512F, and the
/// slice lengths must satisfy that function's requirements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn sma_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    sma_avx512_long(data, period, first, out)
}
439
/// AVX-512 sliding-window SMA.
///
/// 1. Seeds the window sum over `data[first..first + period]` eight lanes at a
///    time, with a scalar remainder.
/// 2. Writes the first mean at `first + period - 1`.
/// 3. Processes the rest in blocks of eight. Each block computes the per-lane
///    deltas `data[i] - data[i - period]` and turns them into an in-register
///    inclusive prefix sum with three masked lane shifts (by 1, 2 and 4). It
///    then adds the running sum and stores the eight scaled means. The last
///    lane becomes the running sum for the next block.
/// 4. Leftover elements use a scalar loop.
///
/// Entries of `out` before `first + period - 1` are not written.
///
/// # Safety
/// The CPU must support AVX-512F. `data` and `out` must have equal length
/// (only `debug_assert!`ed), and `first + period <= data.len()`, because every
/// load/store is an unchecked pointer access.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn sma_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());

    let len = data.len();
    let dp = data.as_ptr();
    let op = out.as_mut_ptr();

    // A one-sample window is the identity: copy from `first` onward.
    if period == 1 {
        let mut i = first;
        while i < len {
            *op.add(i) = *dp.add(i);
            i += 1;
        }
        return;
    }

    // Seed: vector-accumulate the largest multiple of 8 samples of the window.
    let mut acc512 = _mm512_setzero_pd();
    let mut k = 0usize;
    let base = first;
    let p8 = period & !7;

    while k < p8 {
        let v = _mm512_loadu_pd(dp.add(base + k));
        acc512 = _mm512_add_pd(acc512, v);
        k += 8;
    }

    // Reduce 8 lanes -> 4 (add the two 256-bit halves) -> 1 (hadd + add_sd).
    let acc_lo256 = _mm512_castpd512_pd256(acc512);
    let acc_hi256 = _mm512_extractf64x4_pd(acc512, 1);
    let acc256 = _mm256_add_pd(acc_lo256, acc_hi256);

    let hadd = _mm256_hadd_pd(acc256, acc256);
    let lo = _mm256_castpd256_pd128(hadd);
    let hi = _mm256_extractf128_pd(hadd, 1);
    let sum128 = _mm_add_sd(lo, hi);
    let mut sum = _mm_cvtsd_f64(sum128);

    // Scalar remainder of the seed window.
    while k < period {
        sum += *dp.add(base + k);
        k += 1;
    }

    let inv = 1.0 / (period as f64);
    let inv_v = _mm512_set1_pd(inv);
    let warm = first + period - 1;
    *op.add(warm) = sum.mul_add(inv, 0.0);

    // Permutation indices for "shift lanes up by s". `_mm512_set_epi64` lists
    // lanes from 7 down to 0, so lane j reads lane j - s. The low s lanes are
    // zeroed by the masks used with these indices below.
    let idx_sl1 = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);

    let idx_sl2 = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);

    let idx_sl4 = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);

    let mut i = warm + 1;
    let end = len;

    // Main loop: eight outputs per iteration.
    while i + 7 < end {
        let v_new = _mm512_loadu_pd(dp.add(i));
        let v_old = _mm512_loadu_pd(dp.add(i - period));
        let d = _mm512_sub_pd(v_new, v_old);

        // Log-step inclusive scan of the deltas (shift by 1, 2, 4 and add).
        let mut pref = d;
        let sh1 = _mm512_maskz_permutexvar_pd(0b1111_1110, idx_sl1, pref);
        pref = _mm512_add_pd(pref, sh1);

        let sh2 = _mm512_maskz_permutexvar_pd(0b1111_1100, idx_sl2, pref);
        pref = _mm512_add_pd(pref, sh2);

        let sh4 = _mm512_maskz_permutexvar_pd(0b1111_0000, idx_sl4, pref);
        pref = _mm512_add_pd(pref, sh4);

        // Window sums for outputs i..i+8.
        let sums = _mm512_add_pd(_mm512_set1_pd(sum), pref);

        let out_v = _mm512_mul_pd(sums, inv_v);
        _mm512_storeu_pd(op.add(i), out_v);

        // Extract lane 7 (upper 256 -> upper 128 -> high f64) as the new running sum.
        let sums_hi256 = _mm512_extractf64x4_pd(sums, 1);
        let sums_hi128 = _mm256_extractf128_pd(sums_hi256, 1);
        let last = _mm_unpackhi_pd(sums_hi128, sums_hi128);
        sum = _mm_cvtsd_f64(last);

        i += 8;
    }

    // Scalar tail for the final < 8 outputs.
    while i < end {
        sum += *dp.add(i) - *dp.add(i - period);
        *op.add(i) = sum.mul_add(inv, 0.0);
        i += 1;
    }
}
535
536#[derive(Debug, Clone)]
537pub struct SmaStream {
538    period: usize,
539    buffer: Vec<f64>,
540    head: usize,
541    sum: f64,
542    count: usize,
543    inv: f64,
544
545    use_mask: bool,
546    mask: usize,
547}
548
549impl SmaStream {
550    #[inline(always)]
551    pub fn try_new(params: SmaParams) -> Result<Self, SmaError> {
552        let period = params.period.unwrap_or(9);
553        if period == 0 {
554            return Err(SmaError::InvalidPeriod {
555                period,
556                data_len: 0,
557            });
558        }
559        let use_mask = period.is_power_of_two();
560        Ok(Self {
561            period,
562            buffer: vec![0.0; period],
563            head: 0,
564            sum: 0.0,
565            count: 0,
566            inv: (period as f64).recip(),
567            use_mask,
568            mask: period.wrapping_sub(1),
569        })
570    }
571
572    #[inline(always)]
573    fn advance_head(&mut self) {
574        if self.use_mask {
575            self.head = (self.head + 1) & self.mask;
576        } else {
577            let next = self.head + 1;
578            self.head = if next == self.period { 0 } else { next };
579        }
580    }
581
582    #[inline(always)]
583    pub fn update(&mut self, value: f64) -> Option<f64> {
584        if self.period == 1 {
585            self.sum = value;
586            self.buffer[0] = value;
587            self.count = 1;
588            return Some(value);
589        }
590
591        if self.count < self.period {
592            self.sum += value;
593            self.buffer[self.head] = value;
594            self.advance_head();
595            self.count += 1;
596            if self.count == self.period {
597                return Some(self.sum * self.inv);
598            }
599            return None;
600        }
601
602        let old = self.buffer[self.head];
603        self.sum += value - old;
604        self.buffer[self.head] = value;
605        self.advance_head();
606        Some(self.sum * self.inv)
607    }
608}
609
/// `(start, end, step)` sweep of SMA periods for batch runs.
#[derive(Debug, Clone)]
pub struct SmaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for SmaBatchRange {
    /// Periods 9 through 258 in steps of one.
    fn default() -> Self {
        let (start, end, step) = (9, 258, 1);
        SmaBatchRange {
            period: (start, end, step),
        }
    }
}
622
623#[derive(Clone, Debug, Default)]
624pub struct SmaBatchBuilder {
625    range: SmaBatchRange,
626    kernel: Kernel,
627}
628
629impl SmaBatchBuilder {
630    pub fn new() -> Self {
631        Self::default()
632    }
633    pub fn kernel(mut self, k: Kernel) -> Self {
634        self.kernel = k;
635        self
636    }
637    #[inline]
638    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
639        self.range.period = (start, end, step);
640        self
641    }
642    #[inline]
643    pub fn period_static(mut self, p: usize) -> Self {
644        self.range.period = (p, p, 0);
645        self
646    }
647    pub fn apply_slice(self, data: &[f64]) -> Result<SmaBatchOutput, SmaError> {
648        sma_batch_with_kernel(data, &self.range, self.kernel)
649    }
650    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SmaBatchOutput, SmaError> {
651        SmaBatchBuilder::new().kernel(k).apply_slice(data)
652    }
653    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SmaBatchOutput, SmaError> {
654        let slice = source_type(c, src);
655        self.apply_slice(slice)
656    }
657    pub fn with_default_candles(c: &Candles) -> Result<SmaBatchOutput, SmaError> {
658        SmaBatchBuilder::new()
659            .kernel(Kernel::Auto)
660            .apply_candles(c, "close")
661    }
662}
663
664pub fn sma_batch_with_kernel(
665    data: &[f64],
666    sweep: &SmaBatchRange,
667    k: Kernel,
668) -> Result<SmaBatchOutput, SmaError> {
669    let kernel = match k {
670        Kernel::Auto => detect_best_batch_kernel(),
671        other if other.is_batch() => other,
672        other => return Err(SmaError::InvalidKernelForBatch(other)),
673    };
674    let simd = match kernel {
675        Kernel::Avx512Batch => Kernel::Avx512,
676        Kernel::Avx2Batch => Kernel::Avx2,
677        Kernel::ScalarBatch => Kernel::Scalar,
678        _ => unreachable!(),
679    };
680    sma_batch_par_slice(data, sweep, simd)
681}
682
683#[derive(Clone, Debug)]
684pub struct SmaBatchOutput {
685    pub values: Vec<f64>,
686    pub combos: Vec<SmaParams>,
687    pub rows: usize,
688    pub cols: usize,
689}
690impl SmaBatchOutput {
691    pub fn row_for_params(&self, p: &SmaParams) -> Option<usize> {
692        self.combos
693            .iter()
694            .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
695    }
696    pub fn values_for(&self, p: &SmaParams) -> Option<&[f64]> {
697        self.row_for_params(p).map(|row| {
698            let start = row * self.cols;
699            &self.values[start..start + self.cols]
700        })
701    }
702}
703
704#[inline(always)]
705pub fn expand_grid_sma(r: &SmaBatchRange) -> Result<Vec<SmaParams>, SmaError> {
706    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SmaError> {
707        if step == 0 {
708            return Ok(vec![start]);
709        }
710        if start == end {
711            return Ok(vec![start]);
712        }
713        let mut vals = Vec::new();
714        if start < end {
715            let mut v = start;
716            while v <= end {
717                vals.push(v);
718                match v.checked_add(step) {
719                    Some(next) => {
720                        if next == v {
721                            break;
722                        }
723                        v = next;
724                    }
725                    None => break,
726                }
727            }
728        } else {
729            let mut v = start;
730            while v >= end {
731                vals.push(v);
732                if v == 0 {
733                    break;
734                }
735                let next = v.saturating_sub(step);
736                if next == v {
737                    break;
738                }
739                v = next;
740                if v < end {
741                    break;
742                }
743            }
744        }
745        if vals.is_empty() {
746            return Err(SmaError::InvalidRange { start, end, step });
747        }
748        Ok(vals)
749    }
750    let periods = axis_usize(r.period)?;
751    let mut out = Vec::with_capacity(periods.len());
752    for &p in &periods {
753        out.push(SmaParams { period: Some(p) });
754    }
755    Ok(out)
756}
757
758#[inline(always)]
759pub fn sma_batch_slice(
760    data: &[f64],
761    sweep: &SmaBatchRange,
762    kern: Kernel,
763) -> Result<SmaBatchOutput, SmaError> {
764    sma_batch_inner(data, sweep, kern, false)
765}
766
767#[inline(always)]
768pub fn sma_batch_par_slice(
769    data: &[f64],
770    sweep: &SmaBatchRange,
771    kern: Kernel,
772) -> Result<SmaBatchOutput, SmaError> {
773    sma_batch_inner(data, sweep, kern, true)
774}
775
776#[inline(always)]
777fn sma_batch_inner(
778    data: &[f64],
779    sweep: &SmaBatchRange,
780    kern: Kernel,
781    parallel: bool,
782) -> Result<SmaBatchOutput, SmaError> {
783    let combos = expand_grid_sma(sweep)?;
784    if data.is_empty() {
785        return Err(SmaError::EmptyInputData);
786    }
787
788    let cols = data.len();
789    let first = data
790        .iter()
791        .position(|x| !x.is_nan())
792        .ok_or(SmaError::AllValuesNaN)?;
793    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
794    if cols - first < max_p {
795        return Err(SmaError::NotEnoughValidData {
796            needed: max_p,
797            valid: cols - first,
798        });
799    }
800
801    let rows = combos.len();
802
803    rows.checked_mul(cols).ok_or(SmaError::InvalidRange {
804        start: sweep.period.0,
805        end: sweep.period.1,
806        step: sweep.period.2,
807    })?;
808
809    let mut buf_mu = make_uninit_matrix(rows, cols);
810
811    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
812    let out_slice: &mut [f64] =
813        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
814
815    sma_batch_inner_into(data, sweep, kern, parallel, out_slice)?;
816
817    let values = unsafe {
818        Vec::from_raw_parts(
819            guard.as_mut_ptr() as *mut f64,
820            guard.len(),
821            guard.capacity(),
822        )
823    };
824
825    Ok(SmaBatchOutput {
826        values,
827        combos,
828        rows,
829        cols,
830    })
831}
832
/// Writes `(ps[j] - ps[j - period]) * inv` to `dst[j]` for every `j` in `i..cols`.
///
/// With `ps` a running prefix sum, the difference is the sum of the `period`
/// samples ending at `j`.
///
/// # Safety
/// `ps.len() >= cols`, `i >= period` (so `j - period` never underflows), and
/// `dst` must be valid for writes at every index in `i..cols`.
#[inline(always)]
unsafe fn sma_batch_row_prefixsum_scalar(
    ps: &[f64],
    period: usize,
    i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    for j in i..cols {
        let window_sum = *ps.get_unchecked(j) - *ps.get_unchecked(j - period);
        *dst.add(j) = window_sum * inv;
    }
}
849
/// AVX2 variant of the prefix-sum row writer.
///
/// Handles four outputs per iteration and delegates the remainder to the
/// scalar version.
///
/// # Safety
/// The CPU must support AVX2. Otherwise the contract is the same as
/// `sma_batch_row_prefixsum_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn sma_batch_row_prefixsum_avx2(
    ps: &[f64],
    period: usize,
    i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    use core::arch::x86_64::*;

    const LANES: usize = 4;
    let scale = _mm256_set1_pd(inv);
    let base = ps.as_ptr();
    let mut j = i;

    while j + (LANES - 1) < cols {
        let upper = _mm256_loadu_pd(base.add(j));
        let lower = _mm256_loadu_pd(base.add(j - period));
        let window = _mm256_sub_pd(upper, lower);
        _mm256_storeu_pd(dst.add(j), _mm256_mul_pd(window, scale));
        j += LANES;
    }

    sma_batch_row_prefixsum_scalar(ps, period, j, cols, inv, dst);
}
878
/// AVX-512 variant of the prefix-sum row writer.
///
/// Handles eight outputs per iteration and delegates the remainder to the
/// scalar version.
///
/// # Safety
/// The CPU must support AVX-512F. Otherwise the contract is the same as
/// `sma_batch_row_prefixsum_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn sma_batch_row_prefixsum_avx512(
    ps: &[f64],
    period: usize,
    i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    use core::arch::x86_64::*;

    const LANES: usize = 8;
    let scale = _mm512_set1_pd(inv);
    let base = ps.as_ptr();
    let mut j = i;

    while j + (LANES - 1) < cols {
        let upper = _mm512_loadu_pd(base.add(j));
        let lower = _mm512_loadu_pd(base.add(j - period));
        let window = _mm512_sub_pd(upper, lower);
        _mm512_storeu_pd(dst.add(j), _mm512_mul_pd(window, scale));
        j += LANES;
    }

    sma_batch_row_prefixsum_scalar(ps, period, j, cols, inv, dst);
}
907
908#[inline(always)]
909fn sma_batch_inner_into(
910    data: &[f64],
911    sweep: &SmaBatchRange,
912    kern: Kernel,
913    parallel: bool,
914    out: &mut [f64],
915) -> Result<Vec<SmaParams>, SmaError> {
916    let combos = expand_grid_sma(sweep)?;
917    if data.is_empty() {
918        return Err(SmaError::EmptyInputData);
919    }
920
921    let first = data
922        .iter()
923        .position(|x| !x.is_nan())
924        .ok_or(SmaError::AllValuesNaN)?;
925    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
926    if data.len() - first < max_p {
927        return Err(SmaError::NotEnoughValidData {
928            needed: max_p,
929            valid: data.len() - first,
930        });
931    }
932
933    let rows = combos.len();
934    let cols = data.len();
935    rows.checked_mul(cols).ok_or(SmaError::InvalidRange {
936        start: sweep.period.0,
937        end: sweep.period.1,
938        step: sweep.period.2,
939    })?;
940
941    let actual_kern = match kern {
942        Kernel::Auto => detect_best_batch_kernel(),
943        k => k,
944    };
945    let actual_kern = match actual_kern {
946        Kernel::Avx512Batch => Kernel::Avx512,
947        Kernel::Avx2Batch => Kernel::Avx2,
948        Kernel::ScalarBatch => Kernel::Scalar,
949        other => other,
950    };
951
952    let out_uninit: &mut [MaybeUninit<f64>] = unsafe {
953        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
954    };
955
956    let warm: Vec<usize> = combos
957        .iter()
958        .map(|c| first + c.period.unwrap_or(9) - 1)
959        .collect();
960    init_matrix_prefixes(out_uninit, cols, &warm);
961
962    let mut ps = vec![0.0_f64; cols];
963    if first < cols {
964        ps[first] = data[first];
965        for i in (first + 1)..cols {
966            ps[i] = ps[i - 1] + data[i];
967        }
968    }
969
970    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
971        let period = combos[row].period.unwrap();
972        let warm = first + period - 1;
973
974        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
975        if warm >= cols {
976            return;
977        }
978        let inv = (period as f64).recip();
979
980        let s_hi = *ps.get_unchecked(warm);
981        let s_lo = if warm >= period {
982            *ps.get_unchecked(warm - period)
983        } else {
984            0.0
985        };
986        dst[warm] = (s_hi - s_lo) * inv;
987
988        let mut i = warm + 1;
989        if i >= cols {
990            return;
991        }
992
993        let dst_ptr = dst.as_mut_ptr();
994        match actual_kern {
995            Kernel::Scalar => sma_batch_row_prefixsum_scalar(&ps, period, i, cols, inv, dst_ptr),
996            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
997            Kernel::Avx2 => sma_batch_row_prefixsum_avx2(&ps, period, i, cols, inv, dst_ptr),
998            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
999            Kernel::Avx512 => sma_batch_row_prefixsum_avx512(&ps, period, i, cols, inv, dst_ptr),
1000            _ => sma_batch_row_prefixsum_scalar(&ps, period, i, cols, inv, dst_ptr),
1001        }
1002    };
1003
1004    if parallel {
1005        #[cfg(not(target_arch = "wasm32"))]
1006        out_uninit
1007            .par_chunks_mut(cols)
1008            .enumerate()
1009            .for_each(|(row, slice)| do_row(row, slice));
1010        #[cfg(target_arch = "wasm32")]
1011        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1012            do_row(row, slice);
1013        }
1014    } else {
1015        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1016            do_row(row, slice);
1017        }
1018    }
1019
1020    Ok(combos)
1021}
1022
1023#[inline(always)]
1024unsafe fn sma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1025    sma_scalar(data, period, first, out);
1026}
1027
1028#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1029#[inline(always)]
1030unsafe fn sma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1031    sma_avx2(data, period, first, out);
1032}
1033
1034#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1035#[inline(always)]
1036unsafe fn sma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1037    if period <= 32 {
1038        sma_avx512_short(data, period, first, out);
1039    } else {
1040        sma_avx512_long(data, period, first, out);
1041    }
1042}
1043
1044#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1045#[inline(always)]
1046unsafe fn sma_row_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
1047    sma_avx512_short(data, period, first, out);
1048}
1049
/// Forwards to the long-window AVX-512 SMA kernel.
///
/// NOTE(review): argument order is `(data, period, first, out)`, unlike
/// `sma_row_avx512` which takes `first` before `period`.
///
/// # Safety
/// Presumably requires AVX-512 support on the running CPU (confirm).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sma_row_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let (src, dst) = (data, out);
    sma_avx512_long(src, period, first, dst);
}
1055
/// Python binding: simple moving average of a 1-D float64 NumPy array.
///
/// `kernel` names the compute kernel and is checked by `validate_kernel`;
/// `None` leaves the choice to the library. The computation runs with the GIL
/// released. Returns a new float64 array of the same length as `data`.
///
/// Non-contiguous inputs (strided or reversed views such as `a[::-1]`) are
/// first copied into a contiguous buffer in logical element order.
///
/// # Errors
/// Raises `ValueError` if the kernel name is rejected or the SMA fails
/// (e.g. invalid period, insufficient data).
#[cfg(feature = "python")]
#[pyfunction(name = "sma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn sma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let kern = validate_kernel(kernel, false)?;

    // Borrow the NumPy buffer directly when it is C-contiguous. Otherwise
    // gather the elements in logical order. `ArrayView::to_owned` keeps the
    // source memory layout (a reversed view stays stride -1), so `as_slice()`
    // on that copy can still fail. Iterating avoids that.
    let owned: Vec<f64>;
    let data_slice: &[f64] = match data.as_slice() {
        Ok(slice) => slice,
        Err(_) => {
            owned = data.as_array().iter().copied().collect();
            &owned
        }
    };

    let params = SmaParams {
        period: Some(period),
    };
    let input = SmaInput::from_slice(data_slice, params);
    let result_vec: Vec<f64> = py
        .allow_threads(|| sma_with_kernel(&input, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(result_vec.into_pyarray(py))
}
1090
/// Python binding: SMA for every period produced by `period_range`.
///
/// `period_range` is `(start, end, step)` and is expanded by `expand_grid_sma`.
/// Returns a dict with:
/// - `"values"`: a `(rows, cols)` float64 array, one row per expanded period;
/// - `"periods"`: the period used for each row.
///
/// Non-contiguous inputs are copied into a contiguous buffer in logical order,
/// matching the single-series `sma` binding.
///
/// # Errors
/// Raises `ValueError` for:
/// - an invalid kernel name or an invalid period range;
/// - empty input or a `rows * cols` overflow;
/// - a failure inside the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "sma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn sma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let kern = validate_kernel(kernel, true)?;

    // Same contiguity handling as `sma`: borrow when possible, otherwise
    // gather in logical order.
    let owned: Vec<f64>;
    let data_slice: &[f64] = match data.as_slice() {
        Ok(slice) => slice,
        Err(_) => {
            owned = data.as_array().iter().copied().collect();
            &owned
        }
    };
    let range = SmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid_sma(&range).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if data_slice.is_empty() {
        return Err(PyValueError::new_err("Empty data"));
    }

    let rows = combos.len();
    let cols = data_slice.len();
    let nelems = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: `PyArray1::new` leaves the buffer uninitialised. This relies on
    // `sma_batch_inner_into` writing every element (assumed — its full body is
    // outside this view). No other reference to the new array exists yet.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [nelems], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Batch kernels map to their per-row SIMD counterparts. Any other
            // kernel is passed through unchanged rather than panicking, as
            // `sma_batch_into` already does.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            sma_batch_inner_into(data_slice, &range, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap_or(9) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1157
/// CUDA batch SMA over a sweep of periods, leaving the result on the device.
///
/// `data_f32` must be C-contiguous (it is borrowed via `as_slice`).
/// `period_range` is `(start, end, step)`, and `device_id` selects the GPU
/// handed to `CudaSma::new`. Work runs with the GIL released and the device
/// is synchronised before returning.
///
/// Returns the device array handle plus a dict whose `"periods"` entry lists
/// the period of each output row.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn sma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(SmaDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;

    // Uniform conversion of backend errors into Python `ValueError`s.
    fn value_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = SmaBatchRange {
        period: period_range,
    };

    let (inner, combos, ctx_arc, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSma::new(device_id).map_err(value_err)?;
        let (dev, combos) = cuda.sma_batch_dev(host, &sweep).map_err(value_err)?;
        cuda.synchronize().map_err(value_err)?;
        Ok((dev, combos, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    let handle = SmaDeviceArrayF32Py {
        inner,
        _ctx: ctx_arc,
        device_id: dev_id,
    };
    Ok((handle, meta))
}
1203
/// CUDA SMA over many series with one shared period, result left on device.
///
/// `data_tm_f32` is a C-contiguous 2-D float32 array. Its shape `(rows, cols)`
/// is passed to the time-major CUDA kernel as `(cols, rows)`. Presumably rows
/// are time steps and columns are series — confirm against `CudaSma`.
///
/// # Errors
/// Raises `ValueError` if CUDA is unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn sma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<SmaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    // Uniform conversion of backend errors into Python `ValueError`s.
    fn value_err<E: ToString>(e: E) -> PyErr {
        PyValueError::new_err(e.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = SmaParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSma::new(device_id).map_err(value_err)?;
        let dev = cuda
            .sma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(value_err)?;
        cuda.synchronize().map_err(value_err)?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(SmaDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
1243
/// Python handle to a device-resident float32 SMA result matrix.
///
/// Exposed to Python as `ta_indicators.cuda.SmaDeviceArrayF32` and marked
/// `unsendable`, so it cannot be used from a thread other than the creator.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "SmaDeviceArrayF32", unsendable)]
pub struct SmaDeviceArrayF32Py {
    // Device buffer plus its `rows` x `cols` dimensions.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context; presumably held so the context outlives `inner`'s
    // allocation — confirm against `CudaSma::context_arc_clone`.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1251
/// Interop protocols (CUDA Array Interface and DLPack) for the device result.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the buffer.
    ///
    /// Shape is `(rows, cols)` of little-endian float32 (`"<f4"`), with
    /// row-major byte strides. `data` is `(device pointer, false)`; per the CAI
    /// spec the second element is the read-only flag.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);

        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one full row of f32s, then one f32.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: `(2, device_id)`, where 2 is `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Export the buffer as a DLPack capsule. This is a one-shot operation.
    ///
    /// The device buffer is moved out and replaced with an empty placeholder
    /// (`rows == cols == 0`). After the first call this object describes an
    /// empty array.
    ///
    /// `dl_device`, when it parses as `(i32, i32)` and differs from this
    /// buffer's device, raises `ValueError`; the message depends on `copy`.
    /// A `dl_device` that does not parse is ignored. `stream` is ignored. The
    /// constructors visible in this file synchronise the device before building
    /// this object.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Cross-device export is unsupported; `copy` only changes
                    // which error is reported.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Swap in an empty buffer so ownership of the real one can move into
        // the DLPack capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
1334
/// Python wrapper around the incremental `SmaStream` calculator.
#[cfg(feature = "python")]
#[pyclass(name = "SmaStream")]

pub struct SmaStreamPy {
    // Underlying Rust streaming state.
    inner: SmaStream,
}
1341
/// Python-facing methods for the streaming SMA.
#[cfg(feature = "python")]
#[pymethods]
impl SmaStreamPy {
    /// Create a streaming SMA with window length `period`.
    ///
    /// Raises `ValueError` if `SmaStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (period))]
    pub fn new(period: usize) -> PyResult<Self> {
        SmaStream::try_new(SmaParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feed one value and return the current SMA.
    ///
    /// Returns `None` when the stream has no value to report — presumably
    /// during warm-up, before `period` values have arrived (confirm).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let Self { inner } = self;
        inner.update(value)
    }
}
1359
/// WASM binding: simple moving average of `data` with window `period`.
///
/// Returns a new array of the same length as `data`. Errors are surfaced as
/// JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "sma")]
pub fn sma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SmaInput::from_slice(
        data,
        SmaParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];
    match sma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1376
/// JS-side configuration for `sma_batch`, deserialised with `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmaBatchConfig {
    // `(start, end, step)` period sweep, forwarded to `SmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1382
/// Result object returned to JS by `sma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmaBatchJsOutput {
    // Flattened `rows * cols` matrix; presumably row-major with one row per
    // entry of `combos` (the Python binding reshapes it that way) — confirm.
    pub values: Vec<f64>,
    // Parameter set used for each row.
    pub combos: Vec<SmaParams>,
    // Period of each row (convenience copy of `combos[i].period`).
    pub periods: Vec<usize>,
    // Number of parameter combinations.
    pub rows: usize,
    // Length of the input series.
    pub cols: usize,
}
1392
/// WASM binding: batch SMA over a period sweep described by a JS config object.
///
/// `config` must deserialise into `SmaBatchConfig`. Returns a serialised
/// `SmaBatchJsOutput` (values plus per-row metadata). Errors are JS strings;
/// a malformed config is prefixed with `Invalid config: `.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "sma_batch")]
pub fn sma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SmaBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<SmaBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SmaBatchRange {
        period: period_range,
    };
    let output = sma_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let periods: Vec<usize> = output
        .combos
        .iter()
        .map(|c| c.period.unwrap_or(9))
        .collect();

    let js_output = SmaBatchJsOutput {
        values: output.values,
        combos: output.combos,
        periods,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1420
/// Legacy WASM binding: batch SMA returning only the flattened values.
///
/// Use `sma_batch`, which also returns row metadata.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatch")]
#[deprecated(since = "1.0.0", note = "Use sma_batch instead")]
pub fn sma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    match sma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1438
/// Legacy WASM binding: the period of each row a batch sweep would produce.
///
/// An invalid sweep yields an empty list rather than an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatchMetadata")]
#[deprecated(since = "1.0.0", note = "Use sma_batch which returns metadata")]
pub fn sma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    match expand_grid_sma(&sweep) {
        Ok(combos) => combos
            .into_iter()
            .map(|p| p.period.unwrap_or(9))
            .collect(),
        Err(_) => Vec::new(),
    }
}
1453
/// Legacy WASM binding: `[rows, cols]` of the matrix a batch sweep would produce.
///
/// `rows` is the number of expanded periods (0 for an invalid sweep); `cols`
/// is `data_len` unchanged.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatchRowsCols")]
#[deprecated(since = "1.0.0", note = "Use sma_batch which returns rows and cols")]
pub fn sma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid_sma(&sweep).map_or(0, |combos| combos.len());
    vec![rows, data_len]
}
1469
/// Allocate an uninitialised buffer with room for `len` `f64`s in WASM memory.
///
/// Ownership passes to the caller, who must release it with
/// `sma_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sma_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1478
/// Release a buffer previously returned by `sma_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must be the value passed to `sma_alloc`,
/// and each buffer must be freed at most once.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was allocated by `sma_alloc` as a `Vec<f64>` with capacity
    // `len`. The Vec is rebuilt with length 0 because `from_raw_parts` requires
    // the first `length` elements to be initialised, and the buffer from
    // `sma_alloc` starts uninitialised. Deallocation only needs the capacity.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1488
/// WASM binding: SMA of `len` values at `in_ptr`, written to `len` slots at `out_ptr`.
///
/// Input and output may overlap, including fully in-place calls. Any overlap
/// makes the result go through a scratch buffer before it is copied out.
///
/// # Errors
/// Returns a JS string error for null pointers or when the SMA fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Detect any overlap between [in, in + len) and [out, out + len), not just
    // identical start pointers. A partial overlap would otherwise alias the
    // `&[f64]` input with the `&mut [f64]` output.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(byte_len)
        && out_start < in_start.saturating_add(byte_len);

    // SAFETY: the caller guarantees both pointers reference `len` valid f64s.
    // When the regions overlap, the mutable output slice is only created
    // after the input has been fully consumed into `temp`.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = SmaParams {
            period: Some(period),
        };
        let input = SmaInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            sma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            sma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1525
/// WASM binding: batch SMA over a period sweep, written into `out_ptr`.
///
/// `out_ptr` must have room for `rows * len` f64s, where `rows` is the number
/// of periods in `(period_start, period_end, period_step)`. Returns `rows`.
/// Input and output may overlap; overlapping calls compute into a scratch
/// buffer first.
///
/// # Errors
/// Returns a JS string error for:
/// - null pointers or an invalid sweep;
/// - `rows * len` overflowing `usize`;
/// - a failure in the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_sma(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    // usize is 32-bit on wasm32. An unchecked product could wrap and build an
    // output slice shorter than what the batch kernel writes.
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let kernel = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => other,
    };

    // Overlap test between the input [in, in + len) and output
    // [out, out + total_size) byte ranges.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(total_size.saturating_mul(elem))
        && out_start < in_start.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` references `len` f64s and
    // `out_ptr` references `total_size` f64s. For overlapping regions, the
    // mutable output slice is created only after the input has been consumed.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if overlaps {
            let mut temp = vec![0.0; total_size];
            sma_batch_inner_into(data, &sweep, kernel, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, total_size);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total_size);
            sma_batch_inner_into(data, &sweep, kernel, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1566
1567#[cfg(test)]
1568mod tests {
1569    use super::*;
1570    use crate::skip_if_unsupported;
1571    use crate::utilities::data_loader::read_candles_from_csv;
1572
/// Checks that the native `sma_into` produces the same values as `sma_with_kernel`.
///
/// The input starts with three NaNs to exercise the warm-up prefix. The test is
/// compiled out on wasm32+wasm builds. There `sma_into` is the pointer-based
/// `wasm_bindgen` export, so the call cannot be made; `out` would stay all
/// zeros and the comparison would fail.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_sma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    let mut data = Vec::with_capacity(256);
    data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
    data.extend((0..253u32).map(|i| ((i % 17) as f64) * 1.2345 + (i as f64).sin() * 0.001));

    let params = SmaParams::default();
    let input = SmaInput::from_slice(&data, params);

    let base = sma_with_kernel(&input, Kernel::Auto)?.values;

    let mut out = vec![0.0; data.len()];
    sma_into(&input, &mut out)?;

    assert_eq!(base.len(), out.len());

    for (i, (a, b)) in base.iter().zip(out.iter()).enumerate() {
        // NaN positions must match exactly; finite values within 1e-12.
        let ok = if a.is_nan() && b.is_nan() {
            true
        } else {
            (a - b).abs() <= 1e-12
        };
        assert!(ok, "Mismatch at index {}: base={} vs into={}", i, a, b);
    }
    Ok(())
}
1605    fn check_sma_partial_params(
1606        test_name: &str,
1607        kernel: Kernel,
1608    ) -> Result<(), Box<dyn std::error::Error>> {
1609        skip_if_unsupported!(kernel, test_name);
1610        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1611        let candles = read_candles_from_csv(file_path)?;
1612        let default_params = SmaParams { period: None };
1613        let input = SmaInput::from_candles(&candles, "close", default_params);
1614        let output = sma_with_kernel(&input, kernel)?;
1615        assert_eq!(output.values.len(), candles.close.len());
1616        Ok(())
1617    }
1618    fn check_sma_accuracy(
1619        test_name: &str,
1620        kernel: Kernel,
1621    ) -> Result<(), Box<dyn std::error::Error>> {
1622        skip_if_unsupported!(kernel, test_name);
1623        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1624        let candles = read_candles_from_csv(file_path)?;
1625        let params = SmaParams { period: Some(9) };
1626        let input = SmaInput::from_candles(&candles, "close", params);
1627        let result = sma_with_kernel(&input, kernel)?;
1628        let expected_last_five = [59180.8, 59175.0, 59129.4, 59085.4, 59133.7];
1629        let start = result.values.len().saturating_sub(5);
1630        for (i, &val) in result.values[start..].iter().enumerate() {
1631            let diff = (val - expected_last_five[i]).abs();
1632            assert!(
1633                diff < 1e-1,
1634                "[{}] SMA {:?} mismatch at idx {}: got {}, expected {}",
1635                test_name,
1636                kernel,
1637                i,
1638                val,
1639                expected_last_five[i]
1640            );
1641        }
1642        Ok(())
1643    }
1644    fn check_sma_default_candles(
1645        test_name: &str,
1646        kernel: Kernel,
1647    ) -> Result<(), Box<dyn std::error::Error>> {
1648        skip_if_unsupported!(kernel, test_name);
1649        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1650        let candles = read_candles_from_csv(file_path)?;
1651        let input = SmaInput::with_default_candles(&candles);
1652        match input.data {
1653            SmaData::Candles { source, .. } => assert_eq!(source, "close"),
1654            _ => panic!("Expected SmaData::Candles"),
1655        }
1656        let output = sma_with_kernel(&input, kernel)?;
1657        assert_eq!(output.values.len(), candles.close.len());
1658        Ok(())
1659    }
1660    fn check_sma_zero_period(
1661        test_name: &str,
1662        kernel: Kernel,
1663    ) -> Result<(), Box<dyn std::error::Error>> {
1664        skip_if_unsupported!(kernel, test_name);
1665        let input_data = [10.0, 20.0, 30.0];
1666        let params = SmaParams { period: Some(0) };
1667        let input = SmaInput::from_slice(&input_data, params);
1668        let res = sma_with_kernel(&input, kernel);
1669        assert!(
1670            res.is_err(),
1671            "[{}] SMA should fail with zero period",
1672            test_name
1673        );
1674        Ok(())
1675    }
1676    fn check_sma_period_exceeds_length(
1677        test_name: &str,
1678        kernel: Kernel,
1679    ) -> Result<(), Box<dyn std::error::Error>> {
1680        skip_if_unsupported!(kernel, test_name);
1681        let data_small = [10.0, 20.0, 30.0];
1682        let params = SmaParams { period: Some(10) };
1683        let input = SmaInput::from_slice(&data_small, params);
1684        let res = sma_with_kernel(&input, kernel);
1685        assert!(
1686            res.is_err(),
1687            "[{}] SMA should fail with period exceeding length",
1688            test_name
1689        );
1690        Ok(())
1691    }
1692    fn check_sma_very_small_dataset(
1693        test_name: &str,
1694        kernel: Kernel,
1695    ) -> Result<(), Box<dyn std::error::Error>> {
1696        skip_if_unsupported!(kernel, test_name);
1697        let single_point = [42.0];
1698        let params = SmaParams { period: Some(9) };
1699        let input = SmaInput::from_slice(&single_point, params);
1700        let res = sma_with_kernel(&input, kernel);
1701        assert!(
1702            res.is_err(),
1703            "[{}] SMA should fail with insufficient data",
1704            test_name
1705        );
1706        Ok(())
1707    }
1708    fn check_sma_reinput(
1709        test_name: &str,
1710        kernel: Kernel,
1711    ) -> Result<(), Box<dyn std::error::Error>> {
1712        skip_if_unsupported!(kernel, test_name);
1713        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1714        let candles = read_candles_from_csv(file_path)?;
1715        let first_params = SmaParams { period: Some(14) };
1716        let first_input = SmaInput::from_candles(&candles, "close", first_params);
1717        let first_result = sma_with_kernel(&first_input, kernel)?;
1718        let second_params = SmaParams { period: Some(14) };
1719        let second_input = SmaInput::from_slice(&first_result.values, second_params);
1720        let second_result = sma_with_kernel(&second_input, kernel)?;
1721        assert_eq!(second_result.values.len(), first_result.values.len());
1722        Ok(())
1723    }
1724    fn check_sma_nan_handling(
1725        test_name: &str,
1726        kernel: Kernel,
1727    ) -> Result<(), Box<dyn std::error::Error>> {
1728        skip_if_unsupported!(kernel, test_name);
1729        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1730        let candles = read_candles_from_csv(file_path)?;
1731        let input = SmaInput::from_candles(&candles, "close", SmaParams { period: Some(9) });
1732        let res = sma_with_kernel(&input, kernel)?;
1733        assert_eq!(res.values.len(), candles.close.len());
1734        if res.values.len() > 240 {
1735            for (i, &val) in res.values[240..].iter().enumerate() {
1736                assert!(
1737                    !val.is_nan(),
1738                    "[{}] Found unexpected NaN at out-index {}",
1739                    test_name,
1740                    240 + i
1741                );
1742            }
1743        }
1744        Ok(())
1745    }
1746    fn check_sma_streaming(
1747        test_name: &str,
1748        kernel: Kernel,
1749    ) -> Result<(), Box<dyn std::error::Error>> {
1750        skip_if_unsupported!(kernel, test_name);
1751        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1752        let candles = read_candles_from_csv(file_path)?;
1753        let period = 9;
1754        let input = SmaInput::from_candles(
1755            &candles,
1756            "close",
1757            SmaParams {
1758                period: Some(period),
1759            },
1760        );
1761        let batch_output = sma_with_kernel(&input, kernel)?.values;
1762        let mut stream = SmaStream::try_new(SmaParams {
1763            period: Some(period),
1764        })?;
1765        let mut stream_values = Vec::with_capacity(candles.close.len());
1766        for &price in &candles.close {
1767            match stream.update(price) {
1768                Some(sma_val) => stream_values.push(sma_val),
1769                None => stream_values.push(f64::NAN),
1770            }
1771        }
1772        assert_eq!(batch_output.len(), stream_values.len());
1773        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1774            if b.is_nan() && s.is_nan() {
1775                continue;
1776            }
1777            let diff = (b - s).abs();
1778            assert!(
1779                diff < 1e-9,
1780                "[{}] SMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1781                test_name,
1782                i,
1783                b,
1784                s,
1785                diff
1786            );
1787        }
1788        Ok(())
1789    }
1790
1791    #[cfg(debug_assertions)]
1792    fn check_sma_no_poison(
1793        test_name: &str,
1794        kernel: Kernel,
1795    ) -> Result<(), Box<dyn std::error::Error>> {
1796        skip_if_unsupported!(kernel, test_name);
1797
1798        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1799        let candles = read_candles_from_csv(file_path)?;
1800
1801        let test_periods = vec![5, 9, 14, 20, 30, 50];
1802
1803        for period in test_periods {
1804            let params = SmaParams {
1805                period: Some(period),
1806            };
1807            let input = SmaInput::from_candles(&candles, "close", params);
1808            let output = sma_with_kernel(&input, kernel)?;
1809
1810            for (i, &val) in output.values.iter().enumerate() {
1811                if val.is_nan() {
1812                    continue;
1813                }
1814
1815                let bits = val.to_bits();
1816
1817                if bits == 0x11111111_11111111 {
1818                    panic!(
1819						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={})",
1820						test_name, val, bits, i, period
1821					);
1822                }
1823
1824                if bits == 0x22222222_22222222 {
1825                    panic!(
1826						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={})",
1827						test_name, val, bits, i, period
1828					);
1829                }
1830
1831                if bits == 0x33333333_33333333 {
1832                    panic!(
1833						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={})",
1834						test_name, val, bits, i, period
1835					);
1836                }
1837            }
1838        }
1839
1840        Ok(())
1841    }
1842
/// Release-build stub: the poison check only runs with debug assertions, so
/// this variant always passes.
#[cfg(not(debug_assertions))]
fn check_sma_no_poison(
    _test_name: &str,
    _kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}
1850
1851    #[cfg(feature = "proptest")]
1852    #[allow(clippy::float_cmp)]
1853    fn check_sma_property(
1854        test_name: &str,
1855        kernel: Kernel,
1856    ) -> Result<(), Box<dyn std::error::Error>> {
1857        use proptest::prelude::*;
1858        skip_if_unsupported!(kernel, test_name);
1859
1860        let strat = (1usize..=100).prop_flat_map(|period| {
1861            (
1862                prop::collection::vec(
1863                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1864                    period..400,
1865                ),
1866                Just(period),
1867            )
1868        });
1869
1870        proptest::test_runner::TestRunner::default()
1871            .run(&strat, |(data, period)| {
1872                let params = SmaParams {
1873                    period: Some(period),
1874                };
1875                let input = SmaInput::from_slice(&data, params);
1876
1877                let SmaOutput { values: out } = sma_with_kernel(&input, kernel).unwrap();
1878                let SmaOutput { values: ref_out } =
1879                    sma_with_kernel(&input, Kernel::Scalar).unwrap();
1880
1881                for i in 0..(period - 1) {
1882                    prop_assert!(
1883                        out[i].is_nan(),
1884                        "Expected NaN during warmup at index {}, got {}",
1885                        i,
1886                        out[i]
1887                    );
1888                }
1889
1890                for i in (period - 1)..data.len() {
1891                    let window_start = i + 1 - period;
1892                    let window = &data[window_start..=i];
1893
1894                    let expected_sum: f64 = window.iter().sum();
1895                    let expected_mean = expected_sum / period as f64;
1896
1897                    let abs_tolerance = 1e-8_f64;
1898                    let rel_tolerance = 1e-12_f64;
1899                    let tolerance = abs_tolerance.max(expected_mean.abs() * rel_tolerance);
1900
1901                    let kernel_tol = 5e-8_f64.max(tolerance);
1902                    prop_assert!(
1903                        (out[i] - expected_mean).abs() <= tolerance,
1904                        "SMA mismatch at index {}: expected {}, got {} (diff: {})",
1905                        i,
1906                        expected_mean,
1907                        out[i],
1908                        (out[i] - expected_mean).abs()
1909                    );
1910
1911                    let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
1912                    let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1913
1914                    prop_assert!(
1915                        out[i] >= window_min - kernel_tol && out[i] <= window_max + kernel_tol,
1916                        "SMA out of bounds at index {}: {} not in [{}, {}]",
1917                        i,
1918                        out[i],
1919                        window_min,
1920                        window_max
1921                    );
1922
1923                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
1924                        let tolerance = kernel_tol.max(if period == 1 { 1e-8 } else { 1e-9 });
1925                        prop_assert!(
1926                            (out[i] - window[0]).abs() <= tolerance,
1927                            "Constant input property failed at index {}: expected {}, got {}",
1928                            i,
1929                            window[0],
1930                            out[i]
1931                        );
1932                    }
1933
1934                    if period >= 3 {
1935                        let diffs: Vec<f64> = window.windows(2).map(|w| w[1] - w[0]).collect();
1936                        let is_linear = diffs.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9);
1937
1938                        if is_linear && !diffs.is_empty() {
1939                            let midpoint_value = window[period / 2];
1940                            let tolerance = if period % 2 == 0 {
1941                                (window[period / 2 - 1] - window[period / 2]).abs() / 2.0
1942                                    + kernel_tol
1943                            } else {
1944                                kernel_tol
1945                            };
1946
1947                            prop_assert!(
1948                                (out[i] - midpoint_value).abs() <= tolerance,
1949                                "Linear trend property failed at index {}: expected ~{}, got {}",
1950                                i,
1951                                midpoint_value,
1952                                out[i]
1953                            );
1954                        }
1955                    }
1956
1957                    prop_assert!(
1958                        (out[i] - ref_out[i]).abs() <= kernel_tol
1959                            || (out[i].is_nan() && ref_out[i].is_nan()),
1960                        "Kernel mismatch at index {}: {} ({:?}) vs {} (Scalar)",
1961                        i,
1962                        out[i],
1963                        kernel,
1964                        ref_out[i]
1965                    );
1966
1967                    if i >= period {
1968                        let new_value = data[i];
1969                        let old_value = data[i - period];
1970                        let expected_sma_change = (new_value - old_value) / period as f64;
1971                        let actual_sma_change = out[i] - out[i - 1];
1972                        let lag_tol = (expected_sma_change.abs() * rel_tolerance)
1973                            .max(5e-8_f64)
1974                            .max(2.0 * kernel_tol);
1975
1976                        prop_assert!(
1977								(actual_sma_change - expected_sma_change).abs() <= lag_tol,
1978								"Lag property failed at index {}: SMA change {} should be {} (new: {}, old: {})",
1979								i,
1980								actual_sma_change,
1981							expected_sma_change,
1982							new_value,
1983							old_value
1984						);
1985                    }
1986
1987                    #[cfg(debug_assertions)]
1988                    {
1989                        let bits = out[i].to_bits();
1990                        prop_assert!(
1991                            bits != 0x11111111_11111111
1992                                && bits != 0x22222222_22222222
1993                                && bits != 0x33333333_33333333,
1994                            "Found poison value at index {}: {} (0x{:016X})",
1995                            i,
1996                            out[i],
1997                            bits
1998                        );
1999                    }
2000                }
2001
2002                if period == 1 {
2003                    for i in 0..data.len() {
2004                        prop_assert!(
2005                            (out[i] - data[i]).abs() <= 1e-8,
2006                            "Period=1 property failed at index {}: expected {}, got {}",
2007                            i,
2008                            data[i],
2009                            out[i]
2010                        );
2011                    }
2012                }
2013
2014                Ok(())
2015            })
2016            .unwrap();
2017
2018        Ok(())
2019    }
2020
2021    macro_rules! generate_all_sma_tests {
2022        ($($test_fn:ident),*) => {
2023            paste::paste! {
2024                $(
2025                    #[test]
2026                    fn [<$test_fn _scalar_f64>]() {
2027                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2028                    }
2029                )*
2030                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2031                $(
2032                    #[test]
2033                    fn [<$test_fn _avx2_f64>]() {
2034                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2035                    }
2036                    #[test]
2037                    fn [<$test_fn _avx512_f64>]() {
2038                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2039                    }
2040                )*
2041            }
2042        }
2043    }
2044    generate_all_sma_tests!(
2045        check_sma_partial_params,
2046        check_sma_accuracy,
2047        check_sma_default_candles,
2048        check_sma_zero_period,
2049        check_sma_period_exceeds_length,
2050        check_sma_very_small_dataset,
2051        check_sma_reinput,
2052        check_sma_nan_handling,
2053        check_sma_streaming,
2054        check_sma_no_poison
2055    );
2056
2057    #[cfg(feature = "proptest")]
2058    generate_all_sma_tests!(check_sma_property);
2059    fn check_batch_default_row(
2060        test: &str,
2061        kernel: Kernel,
2062    ) -> Result<(), Box<dyn std::error::Error>> {
2063        skip_if_unsupported!(kernel, test);
2064        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2065        let c = read_candles_from_csv(file)?;
2066        let output = SmaBatchBuilder::new()
2067            .kernel(kernel)
2068            .apply_candles(&c, "close")?;
2069        let def = SmaParams::default();
2070        let row = output.values_for(&def).expect("default row missing");
2071        assert_eq!(row.len(), c.close.len());
2072        let expected = [59180.8, 59175.0, 59129.4, 59085.4, 59133.7];
2073        let start = row.len() - 5;
2074        for (i, &v) in row[start..].iter().enumerate() {
2075            assert!(
2076                (v - expected[i]).abs() < 1e-1,
2077                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2078            );
2079        }
2080        Ok(())
2081    }
2082
2083    #[cfg(debug_assertions)]
2084    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2085        skip_if_unsupported!(kernel, test);
2086
2087        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2088        let c = read_candles_from_csv(file)?;
2089
2090        let test_configs = vec![(5, 15, 5), (10, 30, 10), (20, 50, 15), (2, 10, 2)];
2091
2092        for (start, end, step) in test_configs {
2093            let output = SmaBatchBuilder::new()
2094                .kernel(kernel)
2095                .period_range(start, end, step)
2096                .apply_candles(&c, "close")?;
2097
2098            for (idx, &val) in output.values.iter().enumerate() {
2099                if val.is_nan() {
2100                    continue;
2101                }
2102
2103                let bits = val.to_bits();
2104                let row = idx / output.cols;
2105                let col = idx % output.cols;
2106                let period = output.combos[row].period.unwrap();
2107
2108                if bits == 0x11111111_11111111 {
2109                    panic!(
2110                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
2111                        test, val, bits, row, col, idx, period
2112                    );
2113                }
2114
2115                if bits == 0x22222222_22222222 {
2116                    panic!(
2117                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
2118                        test, val, bits, row, col, idx, period
2119                    );
2120                }
2121
2122                if bits == 0x33333333_33333333 {
2123                    panic!(
2124                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
2125                        test, val, bits, row, col, idx, period
2126                    );
2127                }
2128            }
2129        }
2130
2131        Ok(())
2132    }
2133
2134    #[cfg(not(debug_assertions))]
2135    fn check_batch_no_poison(
2136        _test: &str,
2137        _kernel: Kernel,
2138    ) -> Result<(), Box<dyn std::error::Error>> {
2139        Ok(())
2140    }
2141    macro_rules! gen_batch_tests {
2142        ($fn_name:ident) => {
2143            paste::paste! {
2144                #[test] fn [<$fn_name _scalar>]()      {
2145                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2146                }
2147                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2148                #[test] fn [<$fn_name _avx2>]()        {
2149                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2150                }
2151                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2152                #[test] fn [<$fn_name _avx512>]()      {
2153                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2154                }
2155                #[test] fn [<$fn_name _auto_detect>]() {
2156                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2157                }
2158            }
2159        };
2160    }
2161    gen_batch_tests!(check_batch_default_row);
2162    gen_batch_tests!(check_batch_no_poison);
2163}