Skip to main content

vector_ta/indicators/moving_averages/
smma.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14#[cfg(all(feature = "python", feature = "cuda"))]
15use crate::cuda::cuda_available;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::moving_averages::smma_wrapper::DeviceArrayF32Smma;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaSmma;
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24    make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::mem::MaybeUninit;
31use thiserror::Error;
32
33impl<'a> AsRef<[f64]> for SmmaInput<'a> {
34    #[inline(always)]
35    fn as_ref(&self) -> &[f64] {
36        match &self.data {
37            SmmaData::Slice(slice) => slice,
38            SmmaData::Candles { candles, source } => source_type(candles, source),
39        }
40    }
41}
42
43#[derive(Debug, Clone)]
44pub enum SmmaData<'a> {
45    Candles {
46        candles: &'a Candles,
47        source: &'a str,
48    },
49    Slice(&'a [f64]),
50}
51
/// Result of a single-series SMMA run.
#[derive(Clone, Debug)]
pub struct SmmaOutput {
    /// One value per input sample; `smma_with_kernel` allocates the leading
    /// warm-up cells through `alloc_with_nan_prefix`.
    pub values: Vec<f64>,
}
56
/// Tunable parameters of the SMMA.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SmmaParams {
    /// Smoothing window length; `None` is resolved to 7 by the consumers in this file.
    pub period: Option<usize>,
}
65
66impl Default for SmmaParams {
67    fn default() -> Self {
68        Self { period: Some(7) }
69    }
70}
71
72#[derive(Debug, Clone)]
73pub struct SmmaInput<'a> {
74    pub data: SmmaData<'a>,
75    pub params: SmmaParams,
76}
77
78impl<'a> SmmaInput<'a> {
79    #[inline]
80    pub fn from_candles(c: &'a Candles, s: &'a str, p: SmmaParams) -> Self {
81        Self {
82            data: SmmaData::Candles {
83                candles: c,
84                source: s,
85            },
86            params: p,
87        }
88    }
89    #[inline]
90    pub fn from_slice(sl: &'a [f64], p: SmmaParams) -> Self {
91        Self {
92            data: SmmaData::Slice(sl),
93            params: p,
94        }
95    }
96    #[inline]
97    pub fn with_default_candles(c: &'a Candles) -> Self {
98        Self::from_candles(c, "close", SmmaParams::default())
99    }
100    #[inline]
101    pub fn get_period(&self) -> usize {
102        self.params.period.unwrap_or(7)
103    }
104}
105
106#[derive(Copy, Clone, Debug)]
107pub struct SmmaBuilder {
108    period: Option<usize>,
109    kernel: Kernel,
110}
111
112impl Default for SmmaBuilder {
113    fn default() -> Self {
114        Self {
115            period: None,
116            kernel: Kernel::Auto,
117        }
118    }
119}
120
121impl SmmaBuilder {
122    #[inline(always)]
123    pub fn new() -> Self {
124        Self::default()
125    }
126    #[inline(always)]
127    pub fn period(mut self, n: usize) -> Self {
128        self.period = Some(n);
129        self
130    }
131    #[inline(always)]
132    pub fn kernel(mut self, k: Kernel) -> Self {
133        self.kernel = k;
134        self
135    }
136
137    #[inline(always)]
138    pub fn apply(self, c: &Candles) -> Result<SmmaOutput, SmmaError> {
139        let p = SmmaParams {
140            period: self.period,
141        };
142        let i = SmmaInput::from_candles(c, "close", p);
143        smma_with_kernel(&i, self.kernel)
144    }
145
146    #[inline(always)]
147    pub fn apply_slice(self, d: &[f64]) -> Result<SmmaOutput, SmmaError> {
148        let p = SmmaParams {
149            period: self.period,
150        };
151        let i = SmmaInput::from_slice(d, p);
152        smma_with_kernel(&i, self.kernel)
153    }
154
155    #[inline(always)]
156    pub fn into_stream(self) -> Result<SmmaStream, SmmaError> {
157        let p = SmmaParams {
158            period: self.period,
159        };
160        SmmaStream::try_new(p)
161    }
162}
163
164#[derive(Debug, Error)]
165pub enum SmmaError {
166    #[error("smma: Input data slice is empty.")]
167    EmptyInputData,
168    #[error("smma: All values are NaN.")]
169    AllValuesNaN,
170    #[error("smma: Invalid period: period = {period}, data length = {data_len}")]
171    InvalidPeriod { period: usize, data_len: usize },
172    #[error("smma: Not enough valid data: needed = {needed}, valid = {valid}")]
173    NotEnoughValidData { needed: usize, valid: usize },
174    #[error("smma: Invalid kernel for batch operation: {kernel:?} (expected batch kernel)")]
175    InvalidKernelForBatch { kernel: Kernel },
176    #[error("smma: Invalid range expansion: start={start}, end={end}, step={step}")]
177    InvalidRange {
178        start: usize,
179        end: usize,
180        step: usize,
181    },
182    #[error("smma: Output buffer length mismatch: expected = {expected}, got = {got}")]
183    OutputLengthMismatch { expected: usize, got: usize },
184}
185
186#[inline]
187pub fn smma(input: &SmmaInput) -> Result<SmmaOutput, SmmaError> {
188    smma_with_kernel(input, Kernel::Auto)
189}
190
191pub fn smma_with_kernel(input: &SmmaInput, kernel: Kernel) -> Result<SmmaOutput, SmmaError> {
192    let data: &[f64] = match &input.data {
193        SmmaData::Candles { candles, source } => source_type(candles, source),
194        SmmaData::Slice(sl) => sl,
195    };
196
197    if data.is_empty() {
198        return Err(SmmaError::EmptyInputData);
199    }
200
201    let first = data
202        .iter()
203        .position(|x| !x.is_nan())
204        .ok_or(SmmaError::AllValuesNaN)?;
205    let len = data.len();
206    let period = input.get_period();
207
208    if period == 0 || period > len {
209        return Err(SmmaError::InvalidPeriod {
210            period,
211            data_len: len,
212        });
213    }
214    if (len - first) < period {
215        return Err(SmmaError::NotEnoughValidData {
216            needed: period,
217            valid: len - first,
218        });
219    }
220
221    let chosen = match kernel {
222        Kernel::Auto => detect_best_kernel(),
223        other => other,
224    };
225
226    let warm = first + period - 1;
227    let mut out = alloc_with_nan_prefix(len, warm);
228
229    unsafe {
230        match chosen {
231            Kernel::Scalar | Kernel::ScalarBatch => smma_scalar(data, period, first, &mut out),
232            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233            Kernel::Avx2 | Kernel::Avx2Batch => smma_avx2(data, period, first, &mut out),
234            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235            Kernel::Avx512 | Kernel::Avx512Batch => smma_avx512(data, period, first, &mut out),
236            _ => smma_scalar(data, period, first, &mut out),
237        }
238    }
239    Ok(SmmaOutput { values: out })
240}
241
/// Scalar SMMA kernel.
///
/// Writes into `out[first + period - 1..data.len()]` and leaves earlier cells
/// untouched. The seed value is the arithmetic mean of
/// `data[first..first + period]`; every later value is
/// `(prev * (period - 1) + x) / period`. With `period == 1` the input is
/// copied through from `first` onwards.
///
/// # Panics
///
/// Panics when the index ranges above fall outside `data` or `out`.
#[inline]
pub fn smma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    let seed_end = first + period;

    if period == 1 {
        // A one-sample window is the identity.
        out[first] = data[first];
        out[first + 1..n].copy_from_slice(&data[first + 1..]);
        return;
    }

    // Explicit 0.0 start keeps the accumulation order of a plain `+=` loop.
    let seed_sum = data[first..seed_end]
        .iter()
        .fold(0.0f64, |acc, &x| acc + x);

    let divisor = period as f64;
    let carry_weight = divisor - 1.0;

    let mut smoothed = seed_sum / divisor;
    out[seed_end - 1] = smoothed;

    for (slot, &x) in out[seed_end..n].iter_mut().zip(&data[seed_end..]) {
        smoothed = (smoothed * carry_weight + x) / divisor;
        *slot = smoothed;
    }
}
273
/// SMMA kernel selected on `wasm32` builds with `simd128` enabled.
///
/// Despite its name the body is a plain scalar recurrence; it differs from
/// [`smma_scalar`] only in multiplying by the precomputed reciprocal of the
/// period instead of dividing.
///
/// Writes into `out[first + period - 1..data.len()]`. With `period == 1` the
/// input is copied through from `first` onwards, as the other kernels do.
/// Without that fast path the recurrence would evaluate `prev * 0.0 + x`,
/// which turns into NaN forever after a single NaN or infinite sample.
///
/// # Safety
///
/// The body contains only bounds-checked slice operations, so there is no
/// memory-safety precondition; out-of-range `first`/`period` values panic.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn smma_simd128(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();

    if period == 1 {
        // Identity window: mirror `smma_scalar` and pass the samples through.
        out[first] = data[first];
        out[first + 1..len].copy_from_slice(&data[first + 1..]);
        return;
    }

    let end = first + period;
    let sum: f64 = data[first..end].iter().sum();
    let mut prev = sum / period as f64;
    out[end - 1] = prev;

    let period_f64 = period as f64;
    let period_minus_1 = period_f64 - 1.0;
    let inv_period = 1.0 / period_f64;

    for i in end..len {
        prev = (prev * period_minus_1 + data[i]) * inv_period;
        out[i] = prev;
    }
}
291
/// AVX-512 entry point of the single-series SMMA.
///
/// Always delegates to [`smma_avx512_long`], whatever the period.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn smma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // SAFETY: `smma_avx512_long` is declared `unsafe` but only forwards to the
    // safe, bounds-checked `smma_avx2`, so no extra invariant is needed here.
    unsafe {
        smma_avx512_long(data, period, first, out);
    }
}
297
/// SMMA kernel used for the AVX2 (and, by delegation, AVX-512) kernel choices.
///
/// The body uses no SIMD intrinsics: it is the scalar recurrence written with
/// a fused multiply-add and a precomputed reciprocal,
/// `prev = mul_add(prev, period - 1, x) * (1 / period)`.
///
/// Writes into `out[first + period - 1..data.len()]`; with `period == 1` the
/// input is copied through from `first` onwards.
///
/// # Panics
///
/// Panics when the index ranges above fall outside `data` or `out`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn smma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();
    let seed_end = first + period;

    if period == 1 {
        // A one-sample window is the identity.
        out[first] = data[first];
        out[first + 1..n].copy_from_slice(&data[first + 1..]);
        return;
    }

    // Explicit 0.0 start keeps the accumulation order of a plain `+=` loop.
    let seed_sum = data[first..seed_end]
        .iter()
        .fold(0.0f64, |acc, &x| acc + x);

    let divisor = period as f64;
    let carry_weight = divisor - 1.0;
    let recip = 1.0 / divisor;

    let mut smoothed = seed_sum * recip;
    out[seed_end - 1] = smoothed;

    // The recurrence is strictly sequential, so one rolled loop performs the
    // same operations in the same order as a manually unrolled one.
    for (slot, &x) in out[seed_end..n].iter_mut().zip(&data[seed_end..]) {
        smoothed = smoothed.mul_add(carry_weight, x) * recip;
        *slot = smoothed;
    }
}
355
/// Short-period AVX-512 variant; currently a thin forwarder to [`smma_avx2`].
///
/// # Safety
///
/// The body only calls the safe, bounds-checked [`smma_avx2`], so there is no
/// memory-safety precondition; out-of-range arguments panic.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn smma_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out);
}
361
/// Long-period AVX-512 variant; currently a thin forwarder to [`smma_avx2`].
///
/// # Safety
///
/// The body only calls the safe, bounds-checked [`smma_avx2`], so there is no
/// memory-safety precondition; out-of-range arguments panic.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn smma_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out);
}
367
368#[inline(always)]
369fn smma_prepare<'a>(
370    input: &'a SmmaInput,
371    kernel: Kernel,
372) -> Result<(&'a [f64], usize, usize, Kernel), SmmaError> {
373    let data: &[f64] = input.as_ref();
374
375    if data.is_empty() {
376        return Err(SmmaError::EmptyInputData);
377    }
378
379    let first = data
380        .iter()
381        .position(|x| !x.is_nan())
382        .ok_or(SmmaError::AllValuesNaN)?;
383    let len = data.len();
384    let period = input.get_period();
385
386    if period == 0 || period > len {
387        return Err(SmmaError::InvalidPeriod {
388            period,
389            data_len: len,
390        });
391    }
392    if (len - first) < period {
393        return Err(SmmaError::NotEnoughValidData {
394            needed: period,
395            valid: len - first,
396        });
397    }
398
399    let chosen = match kernel {
400        Kernel::Auto => detect_best_kernel(),
401        other => other,
402    };
403
404    Ok((data, period, first, chosen))
405}
406
407#[inline(always)]
408fn smma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
409    unsafe {
410        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
411        {
412            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
413                smma_simd128(data, period, first, out);
414                return;
415            }
416        }
417
418        match kernel {
419            Kernel::Scalar | Kernel::ScalarBatch => smma_scalar(data, period, first, out),
420            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
421            Kernel::Avx2 | Kernel::Avx2Batch => smma_avx2(data, period, first, out),
422            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
423            Kernel::Avx512 | Kernel::Avx512Batch => smma_avx512(data, period, first, out),
424            _ => smma_scalar(data, period, first, out),
425        }
426    }
427}
428
429#[inline]
430pub fn expand_grid(r: &SmmaBatchRange) -> Vec<SmmaParams> {
431    let axis_usize = |(start, end, step): (usize, usize, usize)| -> Vec<usize> {
432        if step == 0 {
433            return vec![start];
434        }
435        if start == end {
436            return vec![start];
437        }
438        let mut out = Vec::new();
439        if start < end {
440            let mut v = start;
441            while v <= end {
442                out.push(v);
443                match v.checked_add(step) {
444                    Some(next) => {
445                        if next == v {
446                            break;
447                        }
448                        v = next;
449                    }
450                    None => break,
451                }
452            }
453        } else {
454            let mut v = start;
455            while v >= end {
456                out.push(v);
457
458                let next = v.saturating_sub(step);
459                if next == v {
460                    break;
461                }
462                v = next;
463                if v == 0 {
464                    if end > 0 {
465                        break;
466                    }
467                }
468            }
469        }
470        out
471    };
472    axis_usize(r.period)
473        .into_iter()
474        .map(|p| SmmaParams { period: Some(p) })
475        .collect()
476}
477
478#[derive(Debug, Clone)]
479pub struct SmmaStream {
480    period: usize,
481
482    pf64: f64,
483    pm1: f64,
484    inv_p: f64,
485
486    state: SmmaStreamState,
487}
488
/// Phases an [`SmmaStream`] moves through.
#[derive(Clone, Debug)]
enum SmmaStreamState {
    /// No usable sample has been seen yet.
    SeekingFirst,
    /// Collecting the first `period` samples; `sum` and `count` track progress.
    Warming { sum: f64, count: usize },
    /// Producing outputs; `value` is the most recent one.
    Ready { value: f64 },
}
495
496impl SmmaStream {
497    #[inline]
498    pub fn try_new(params: SmmaParams) -> Result<Self, SmmaError> {
499        let period = params.period.unwrap_or(7);
500        if period == 0 {
501            return Err(SmmaError::InvalidPeriod {
502                period,
503                data_len: 0,
504            });
505        }
506        let pf64 = period as f64;
507        Ok(Self {
508            period,
509            pf64,
510            pm1: pf64 - 1.0,
511            inv_p: 1.0 / pf64,
512            state: SmmaStreamState::SeekingFirst,
513        })
514    }
515
516    #[inline(always)]
517    pub fn update(&mut self, v: f64) -> Option<f64> {
518        use SmmaStreamState::*;
519        match &mut self.state {
520            SeekingFirst => {
521                if v.is_finite() {
522                    if self.period == 1 {
523                        self.state = Ready { value: v };
524                        return Some(v);
525                    }
526                    self.state = Warming { sum: v, count: 1 };
527                }
528                None
529            }
530            Warming { sum, count } => {
531                *sum += v;
532                *count += 1;
533                if *count == self.period {
534                    let first_val = *sum / self.pf64;
535                    self.state = Ready { value: first_val };
536                    Some(first_val)
537                } else {
538                    None
539                }
540            }
541            Ready { value } => {
542                let next = (*value * self.pm1 + v) / self.pf64;
543                *value = next;
544                Some(next)
545            }
546        }
547    }
548
549    #[inline(always)]
550    pub fn is_ready(&self) -> bool {
551        matches!(self.state, SmmaStreamState::Ready { .. })
552    }
553    #[inline(always)]
554    pub fn current(&self) -> Option<f64> {
555        match self.state {
556            SmmaStreamState::Ready { value } => Some(value),
557            _ => None,
558        }
559    }
560    #[inline]
561    pub fn reset(&mut self) {
562        self.state = SmmaStreamState::SeekingFirst;
563    }
564}
565
/// Parameter sweep for batch SMMA runs.
#[derive(Clone, Debug)]
pub struct SmmaBatchRange {
    /// Period sweep as `(start, end, step)`; see `expand_grid` for how it expands.
    pub period: (usize, usize, usize),
}
570
571impl Default for SmmaBatchRange {
572    fn default() -> Self {
573        Self {
574            period: (7, 256, 1),
575        }
576    }
577}
578
579#[derive(Clone, Debug, Default)]
580pub struct SmmaBatchBuilder {
581    range: SmmaBatchRange,
582    kernel: Kernel,
583}
584
585impl SmmaBatchBuilder {
586    pub fn new() -> Self {
587        Self::default()
588    }
589    pub fn kernel(mut self, k: Kernel) -> Self {
590        self.kernel = k;
591        self
592    }
593    #[inline]
594    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
595        self.range.period = (start, end, step);
596        self
597    }
598    #[inline]
599    pub fn period_static(mut self, p: usize) -> Self {
600        self.range.period = (p, p, 0);
601        self
602    }
603
604    pub fn apply_slice(self, data: &[f64]) -> Result<SmmaBatchOutput, SmmaError> {
605        smma_batch_with_kernel(data, &self.range, self.kernel)
606    }
607    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SmmaBatchOutput, SmmaError> {
608        SmmaBatchBuilder::new().kernel(k).apply_slice(data)
609    }
610    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SmmaBatchOutput, SmmaError> {
611        let slice = source_type(c, src);
612        self.apply_slice(slice)
613    }
614    pub fn with_default_candles(c: &Candles) -> Result<SmmaBatchOutput, SmmaError> {
615        SmmaBatchBuilder::new()
616            .kernel(Kernel::Auto)
617            .apply_candles(c, "close")
618    }
619}
620
621pub fn smma_batch_with_kernel(
622    data: &[f64],
623    sweep: &SmmaBatchRange,
624    k: Kernel,
625) -> Result<SmmaBatchOutput, SmmaError> {
626    let kernel = match k {
627        Kernel::Auto => detect_best_batch_kernel(),
628        other if other.is_batch() => other,
629        other => return Err(SmmaError::InvalidKernelForBatch { kernel: other }),
630    };
631    let simd = match kernel {
632        Kernel::Avx512Batch => Kernel::Avx512,
633        Kernel::Avx2Batch => Kernel::Avx2,
634        Kernel::ScalarBatch => Kernel::Scalar,
635        _ => Kernel::Scalar,
636    };
637    smma_batch_par_slice(data, sweep, simd)
638}
639
640#[derive(Clone, Debug)]
641pub struct SmmaBatchOutput {
642    pub values: Vec<f64>,
643    pub combos: Vec<SmmaParams>,
644    pub rows: usize,
645    pub cols: usize,
646}
647impl SmmaBatchOutput {
648    pub fn row_for_params(&self, p: &SmmaParams) -> Option<usize> {
649        self.combos
650            .iter()
651            .position(|c| c.period.unwrap_or(7) == p.period.unwrap_or(7))
652    }
653    pub fn values_for(&self, p: &SmmaParams) -> Option<&[f64]> {
654        self.row_for_params(p).map(|row| {
655            let start = row * self.cols;
656            &self.values[start..start + self.cols]
657        })
658    }
659}
660
661#[inline(always)]
662pub fn smma_batch_slice(
663    data: &[f64],
664    sweep: &SmmaBatchRange,
665    kern: Kernel,
666) -> Result<SmmaBatchOutput, SmmaError> {
667    smma_batch_inner(data, sweep, kern, false)
668}
669#[inline(always)]
670pub fn smma_batch_par_slice(
671    data: &[f64],
672    sweep: &SmmaBatchRange,
673    kern: Kernel,
674) -> Result<SmmaBatchOutput, SmmaError> {
675    smma_batch_inner(data, sweep, kern, true)
676}
677#[inline(always)]
678fn smma_batch_inner(
679    data: &[f64],
680    sweep: &SmmaBatchRange,
681    kern: Kernel,
682    parallel: bool,
683) -> Result<SmmaBatchOutput, SmmaError> {
684    if data.is_empty() {
685        return Err(SmmaError::EmptyInputData);
686    }
687
688    let combos = expand_grid(sweep);
689    if combos.is_empty() {
690        let (s, e, st) = sweep.period;
691        return Err(SmmaError::InvalidRange {
692            start: s,
693            end: e,
694            step: st,
695        });
696    }
697    let first = data
698        .iter()
699        .position(|x| !x.is_nan())
700        .ok_or(SmmaError::AllValuesNaN)?;
701    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
702    if data.len() - first < max_p {
703        return Err(SmmaError::NotEnoughValidData {
704            needed: max_p,
705            valid: data.len() - first,
706        });
707    }
708    let rows = combos.len();
709    let cols = data.len();
710
711    let warm: Vec<usize> = combos
712        .iter()
713        .map(|c| first + c.period.unwrap() - 1)
714        .collect();
715
716    let mut raw = make_uninit_matrix(rows, cols);
717    unsafe {
718        init_matrix_prefixes(&mut raw, cols, &warm);
719    }
720
721    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
722        let period = combos[row].period.unwrap();
723
724        let out_row =
725            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
726
727        match kern {
728            Kernel::Scalar | Kernel::ScalarBatch => smma_row_scalar(data, first, period, out_row),
729            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
730            Kernel::Avx2 | Kernel::Avx2Batch => smma_row_avx2(data, first, period, out_row),
731            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
732            Kernel::Avx512 | Kernel::Avx512Batch => smma_row_avx512(data, first, period, out_row),
733            _ => smma_row_scalar(data, first, period, out_row),
734        }
735    };
736
737    if parallel {
738        #[cfg(not(target_arch = "wasm32"))]
739        {
740            raw.par_chunks_mut(cols)
741                .enumerate()
742                .for_each(|(row, slice)| do_row(row, slice));
743        }
744
745        #[cfg(target_arch = "wasm32")]
746        {
747            for (row, slice) in raw.chunks_mut(cols).enumerate() {
748                do_row(row, slice);
749            }
750        }
751    } else {
752        for (row, slice) in raw.chunks_mut(cols).enumerate() {
753            do_row(row, slice);
754        }
755    }
756
757    let mut guard = core::mem::ManuallyDrop::new(raw);
758    let values = unsafe {
759        Vec::from_raw_parts(
760            guard.as_mut_ptr() as *mut f64,
761            guard.len(),
762            guard.capacity(),
763        )
764    };
765
766    Ok(SmmaBatchOutput {
767        values,
768        combos,
769        rows,
770        cols,
771    })
772}
773
774#[inline(always)]
775fn smma_batch_inner_into(
776    data: &[f64],
777    sweep: &SmmaBatchRange,
778    kern: Kernel,
779    parallel: bool,
780    out: &mut [f64],
781) -> Result<Vec<SmmaParams>, SmmaError> {
782    if data.is_empty() {
783        return Err(SmmaError::EmptyInputData);
784    }
785
786    let combos = expand_grid(sweep);
787    if combos.is_empty() {
788        let (s, e, st) = sweep.period;
789        return Err(SmmaError::InvalidRange {
790            start: s,
791            end: e,
792            step: st,
793        });
794    }
795    let first = data
796        .iter()
797        .position(|x| !x.is_nan())
798        .ok_or(SmmaError::AllValuesNaN)?;
799    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
800    if data.len() - first < max_p {
801        return Err(SmmaError::NotEnoughValidData {
802            needed: max_p,
803            valid: data.len() - first,
804        });
805    }
806
807    let rows = combos.len();
808    let cols = data.len();
809
810    let warm: Vec<usize> = combos
811        .iter()
812        .map(|c| first + c.period.unwrap() - 1)
813        .collect();
814
815    let out_uninit = unsafe {
816        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
817    };
818
819    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
820
821    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
822        let period = combos[row].period.unwrap();
823
824        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
825
826        match kern {
827            Kernel::Scalar | Kernel::ScalarBatch => smma_row_scalar(data, first, period, dst),
828            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829            Kernel::Avx2 | Kernel::Avx2Batch => smma_row_avx2(data, first, period, dst),
830            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
831            Kernel::Avx512 | Kernel::Avx512Batch => smma_row_avx512(data, first, period, dst),
832            _ => smma_row_scalar(data, first, period, dst),
833        }
834    };
835
836    if parallel {
837        #[cfg(not(target_arch = "wasm32"))]
838        {
839            out_uninit
840                .par_chunks_mut(cols)
841                .enumerate()
842                .for_each(|(row, slice)| do_row(row, slice));
843        }
844        #[cfg(target_arch = "wasm32")]
845        {
846            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
847                do_row(row, slice);
848            }
849        }
850    } else {
851        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
852            do_row(row, slice);
853        }
854    }
855
856    Ok(combos)
857}
858
859#[inline(always)]
860pub unsafe fn smma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
861    smma_scalar(data, period, first, out)
862}
/// Batch-row adapter for the AVX2 kernel.
///
/// Takes `(first, period)` in the order used by the batch row dispatch and
/// forwards to [`smma_avx2`], which expects `(period, first)`.
///
/// # Safety
///
/// The body only calls the safe, bounds-checked [`smma_avx2`], so there is no
/// memory-safety precondition; out-of-range arguments panic.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out);
}
/// Batch-row adapter for the AVX-512 kernel.
///
/// Periods up to 32 go to [`smma_row_avx512_short`], longer ones to
/// [`smma_row_avx512_long`].
///
/// # Safety
///
/// Both callees are `unsafe fn`s that end up in the safe, bounds-checked
/// [`smma_avx2`]; no memory-safety precondition is visible in this file.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => smma_row_avx512_short(data, first, period, out),
        _ => smma_row_avx512_long(data, first, period, out),
    }
}
/// Short-period AVX-512 row variant; currently forwards to [`smma_row_avx2`].
///
/// # Safety
///
/// The callee ends up in the safe, bounds-checked [`smma_avx2`]; no
/// memory-safety precondition is visible in this file.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_row_avx2(data, first, period, out);
}
/// Long-period AVX-512 row variant; currently forwards to [`smma_row_avx2`].
///
/// # Safety
///
/// The callee ends up in the safe, bounds-checked [`smma_avx2`]; no
/// memory-safety precondition is visible in this file.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_row_avx2(data, first, period, out);
}
887
888#[cfg(test)]
889mod tests {
890    use super::*;
891    use crate::skip_if_unsupported;
892    use crate::utilities::data_loader::read_candles_from_csv;
893    #[cfg(feature = "proptest")]
894    use proptest::prelude::*;
895
896    fn check_smma_partial_params(
897        test_name: &str,
898        kernel: Kernel,
899    ) -> Result<(), Box<dyn std::error::Error>> {
900        skip_if_unsupported!(kernel, test_name);
901        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
902        let candles = read_candles_from_csv(file_path)?;
903        let input = SmmaInput::from_candles(&candles, "close", SmmaParams { period: None });
904        let output = smma_with_kernel(&input, kernel)?;
905        assert_eq!(output.values.len(), candles.close.len());
906        Ok(())
907    }
908    fn check_smma_accuracy(
909        test_name: &str,
910        kernel: Kernel,
911    ) -> Result<(), Box<dyn std::error::Error>> {
912        skip_if_unsupported!(kernel, test_name);
913        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
914        let candles = read_candles_from_csv(file_path)?;
915        let input = SmmaInput::from_candles(&candles, "close", SmmaParams::default());
916        let result = smma_with_kernel(&input, kernel)?;
917        let expected_last_five = [59434.4, 59398.2, 59346.9, 59319.4, 59224.5];
918        let start = result.values.len().saturating_sub(5);
919        for (i, &val) in result.values[start..].iter().enumerate() {
920            let diff = (val - expected_last_five[i]).abs();
921            assert!(
922                diff < 1e-1,
923                "[{}] SMMA {:?} mismatch at idx {}: got {}, expected {}",
924                test_name,
925                kernel,
926                i,
927                val,
928                expected_last_five[i]
929            );
930        }
931        Ok(())
932    }
933    fn check_smma_default_candles(
934        test_name: &str,
935        kernel: Kernel,
936    ) -> Result<(), Box<dyn std::error::Error>> {
937        skip_if_unsupported!(kernel, test_name);
938        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
939        let candles = read_candles_from_csv(file_path)?;
940        let input = SmmaInput::with_default_candles(&candles);
941        match input.data {
942            SmmaData::Candles { source, .. } => assert_eq!(source, "close"),
943            _ => panic!("Expected SmmaData::Candles"),
944        }
945        let output = smma_with_kernel(&input, kernel)?;
946        assert_eq!(output.values.len(), candles.close.len());
947        Ok(())
948    }
949    fn check_smma_zero_period(
950        test_name: &str,
951        kernel: Kernel,
952    ) -> Result<(), Box<dyn std::error::Error>> {
953        skip_if_unsupported!(kernel, test_name);
954        let input_data = [10.0, 20.0, 30.0];
955        let params = SmmaParams { period: Some(0) };
956        let input = SmmaInput::from_slice(&input_data, params);
957        let res = smma_with_kernel(&input, kernel);
958        assert!(
959            res.is_err(),
960            "[{}] SMMA should fail with zero period",
961            test_name
962        );
963        Ok(())
964    }
965    fn check_smma_period_exceeds_length(
966        test_name: &str,
967        kernel: Kernel,
968    ) -> Result<(), Box<dyn std::error::Error>> {
969        skip_if_unsupported!(kernel, test_name);
970        let data_small = [10.0, 20.0, 30.0];
971        let params = SmmaParams { period: Some(10) };
972        let input = SmmaInput::from_slice(&data_small, params);
973        let res = smma_with_kernel(&input, kernel);
974        assert!(
975            res.is_err(),
976            "[{}] SMMA should fail with period exceeding length",
977            test_name
978        );
979        Ok(())
980    }
981    fn check_smma_very_small_dataset(
982        test_name: &str,
983        kernel: Kernel,
984    ) -> Result<(), Box<dyn std::error::Error>> {
985        skip_if_unsupported!(kernel, test_name);
986        let single_point = [42.0];
987        let params = SmmaParams { period: Some(9) };
988        let input = SmmaInput::from_slice(&single_point, params);
989        let res = smma_with_kernel(&input, kernel);
990        assert!(
991            res.is_err(),
992            "[{}] SMMA should fail with insufficient data",
993            test_name
994        );
995        Ok(())
996    }
997    fn check_smma_empty_input(
998        test_name: &str,
999        kernel: Kernel,
1000    ) -> Result<(), Box<dyn std::error::Error>> {
1001        skip_if_unsupported!(kernel, test_name);
1002        let empty: Vec<f64> = vec![];
1003        let params = SmmaParams { period: Some(7) };
1004        let input = SmmaInput::from_slice(&empty, params);
1005        let res = smma_with_kernel(&input, kernel);
1006        assert!(
1007            res.is_err(),
1008            "[{}] SMMA should fail with empty input",
1009            test_name
1010        );
1011        if let Err(SmmaError::EmptyInputData) = res {
1012        } else {
1013            panic!(
1014                "[{}] Expected EmptyInputData error, got {:?}",
1015                test_name, res
1016            );
1017        }
1018        Ok(())
1019    }
1020    fn check_smma_reinput(
1021        test_name: &str,
1022        kernel: Kernel,
1023    ) -> Result<(), Box<dyn std::error::Error>> {
1024        skip_if_unsupported!(kernel, test_name);
1025        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1026        let candles = read_candles_from_csv(file_path)?;
1027        let first_params = SmmaParams { period: Some(7) };
1028        let first_input = SmmaInput::from_candles(&candles, "close", first_params);
1029        let first_result = smma_with_kernel(&first_input, kernel)?;
1030        let second_params = SmmaParams { period: Some(5) };
1031        let second_input = SmmaInput::from_slice(&first_result.values, second_params);
1032        let second_result = smma_with_kernel(&second_input, kernel)?;
1033        assert_eq!(second_result.values.len(), first_result.values.len());
1034        Ok(())
1035    }
1036    fn check_smma_nan_handling(
1037        test_name: &str,
1038        kernel: Kernel,
1039    ) -> Result<(), Box<dyn std::error::Error>> {
1040        skip_if_unsupported!(kernel, test_name);
1041        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1042        let candles = read_candles_from_csv(file_path)?;
1043        let input = SmmaInput::from_candles(&candles, "close", SmmaParams { period: Some(7) });
1044        let res = smma_with_kernel(&input, kernel)?;
1045        assert_eq!(res.values.len(), candles.close.len());
1046        if res.values.len() > 240 {
1047            for (i, &val) in res.values[240..].iter().enumerate() {
1048                assert!(
1049                    !val.is_nan(),
1050                    "[{}] Found unexpected NaN at out-index {}",
1051                    test_name,
1052                    240 + i
1053                );
1054            }
1055        }
1056        Ok(())
1057    }
1058    fn check_smma_streaming(
1059        test_name: &str,
1060        kernel: Kernel,
1061    ) -> Result<(), Box<dyn std::error::Error>> {
1062        skip_if_unsupported!(kernel, test_name);
1063        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1064        let candles = read_candles_from_csv(file_path)?;
1065        let period = 7;
1066        let input = SmmaInput::from_candles(
1067            &candles,
1068            "close",
1069            SmmaParams {
1070                period: Some(period),
1071            },
1072        );
1073        let batch_output = smma_with_kernel(&input, kernel)?.values;
1074        let mut stream = SmmaStream::try_new(SmmaParams {
1075            period: Some(period),
1076        })?;
1077        let mut stream_values = Vec::with_capacity(candles.close.len());
1078        for &price in &candles.close {
1079            match stream.update(price) {
1080                Some(smma_val) => stream_values.push(smma_val),
1081                None => stream_values.push(f64::NAN),
1082            }
1083        }
1084        assert_eq!(batch_output.len(), stream_values.len());
1085        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1086            if b.is_nan() && s.is_nan() {
1087                continue;
1088            }
1089            let diff = (b - s).abs();
1090            assert!(
1091                diff < 1e-7,
1092                "[{}] SMMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1093                test_name,
1094                i,
1095                b,
1096                s,
1097                diff
1098            );
1099        }
1100        Ok(())
1101    }
1102    #[cfg(feature = "proptest")]
1103    fn check_smma_property(
1104        test_name: &str,
1105        kernel: Kernel,
1106    ) -> Result<(), Box<dyn std::error::Error>> {
1107        use proptest::prelude::*;
1108        skip_if_unsupported!(kernel, test_name);
1109
1110        let strat = (1usize..=100).prop_flat_map(|period| {
1111            (
1112                prop::collection::vec(
1113                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1114                    period.max(2)..400,
1115                ),
1116                Just(period),
1117            )
1118        });
1119
1120        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1121            let params = SmmaParams {
1122                period: Some(period),
1123            };
1124            let input = SmmaInput::from_slice(&data, params);
1125
1126            let output = smma_with_kernel(&input, kernel)?;
1127
1128            let reference = smma_with_kernel(&input, Kernel::Scalar)?;
1129
1130            prop_assert_eq!(output.values.len(), data.len());
1131            prop_assert_eq!(reference.values.len(), data.len());
1132
1133            if period > 1 {
1134                for i in 0..period - 1 {
1135                    prop_assert!(
1136                        output.values[i].is_nan(),
1137                        "Expected NaN at index {} but got {}",
1138                        i,
1139                        output.values[i]
1140                    );
1141                }
1142            }
1143
1144            let first_smma_idx = if period == 1 { 0 } else { period - 1 };
1145            let first_sum: f64 = data[0..period].iter().sum();
1146            let expected_first = first_sum / period as f64;
1147            let actual_first = output.values[first_smma_idx];
1148            prop_assert!(
1149                (actual_first - expected_first).abs() < 1e-7,
1150                "First SMMA value mismatch: expected {}, got {} (diff: {})",
1151                expected_first,
1152                actual_first,
1153                (actual_first - expected_first).abs()
1154            );
1155
1156            if data.len() > period {
1157                for i in period..data.len().min(period + 10) {
1158                    let prev_smma = output.values[i - 1];
1159                    let expected = (prev_smma * (period as f64 - 1.0) + data[i]) / period as f64;
1160                    let actual = output.values[i];
1161
1162                    prop_assert!(
1163                        (actual - expected).abs() < 1e-7,
1164                        "Recursive formula mismatch at index {}: expected {}, got {} (diff: {})",
1165                        i,
1166                        expected,
1167                        actual,
1168                        (actual - expected).abs()
1169                    );
1170                }
1171            }
1172
1173            for i in 0..output.values.len() {
1174                let test_val = output.values[i];
1175                let ref_val = reference.values[i];
1176
1177                if test_val.is_nan() && ref_val.is_nan() {
1178                    continue;
1179                }
1180
1181                prop_assert!(
1182                    test_val.is_finite() == ref_val.is_finite(),
1183                    "Finite/NaN mismatch at index {}: test={}, ref={}",
1184                    i,
1185                    test_val,
1186                    ref_val
1187                );
1188
1189                if test_val.is_finite() && ref_val.is_finite() {
1190                    let test_bits = test_val.to_bits();
1191                    let ref_bits = ref_val.to_bits();
1192                    let ulp_diff = test_bits.abs_diff(ref_bits);
1193
1194                    let max_ulps = if matches!(kernel, Kernel::Avx512 | Kernel::Avx512Batch) {
1195                        20
1196                    } else {
1197                        10
1198                    };
1199
1200                    prop_assert!(
1201						ulp_diff <= max_ulps || (test_val - ref_val).abs() < 1e-7,
1202						"Cross-kernel mismatch at index {}: test={}, ref={}, ULP diff={}, abs diff={}",
1203						i,
1204						test_val,
1205						ref_val,
1206						ulp_diff,
1207						(test_val - ref_val).abs()
1208					);
1209                }
1210            }
1211
1212            let start_idx = if period == 1 { 0 } else { period - 1 };
1213            for i in start_idx..output.values.len() {
1214                let val = output.values[i];
1215                if val.is_finite() {
1216                    let data_up_to_i = &data[0..=i];
1217                    let min = data_up_to_i.iter().copied().fold(f64::INFINITY, f64::min);
1218                    let max = data_up_to_i
1219                        .iter()
1220                        .copied()
1221                        .fold(f64::NEG_INFINITY, f64::max);
1222
1223                    prop_assert!(
1224                        val >= min - 1e-9 && val <= max + 1e-9,
1225                        "SMMA value {} at index {} outside historical bounds [{}, {}]",
1226                        val,
1227                        i,
1228                        min,
1229                        max
1230                    );
1231                }
1232            }
1233
1234            if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
1235                let constant_val = data[0];
1236                let check_start = if period == 1 { 0 } else { period - 1 };
1237                for i in check_start..output.values.len() {
1238                    let val = output.values[i];
1239                    prop_assert!(
1240                        (val - constant_val).abs() < 1e-7,
1241                        "SMMA should converge to {} for constant data, but got {} at index {}",
1242                        constant_val,
1243                        val,
1244                        i
1245                    );
1246                }
1247            }
1248
1249            if period == 1 {
1250                prop_assert_eq!(
1251                    output.values[0],
1252                    data[0],
1253                    "Period=1: first value should equal input"
1254                );
1255                for i in 1..data.len() {
1256                    prop_assert!(
1257                        (output.values[i] - data[i]).abs() < 1e-7,
1258                        "Period=1: output should equal input at index {}: {} != {}",
1259                        i,
1260                        output.values[i],
1261                        data[i]
1262                    );
1263                }
1264            }
1265
1266            Ok(())
1267        })?;
1268
1269        Ok(())
1270    }
1271
1272    #[cfg(debug_assertions)]
1273    fn check_smma_no_poison(
1274        test_name: &str,
1275        kernel: Kernel,
1276    ) -> Result<(), Box<dyn std::error::Error>> {
1277        skip_if_unsupported!(kernel, test_name);
1278
1279        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1280        let candles = read_candles_from_csv(file_path)?;
1281
1282        let test_periods = vec![5, 7, 14, 50, 100, 200];
1283
1284        for &period in &test_periods {
1285            if period > candles.close.len() {
1286                continue;
1287            }
1288
1289            let input = SmmaInput::from_candles(
1290                &candles,
1291                "close",
1292                SmmaParams {
1293                    period: Some(period),
1294                },
1295            );
1296            let output = smma_with_kernel(&input, kernel)?;
1297
1298            for (i, &val) in output.values.iter().enumerate() {
1299                if val.is_nan() {
1300                    continue;
1301                }
1302
1303                let bits = val.to_bits();
1304
1305                if bits == 0x11111111_11111111 {
1306                    panic!(
1307						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1308						test_name, val, bits, i, period
1309					);
1310                }
1311
1312                if bits == 0x22222222_22222222 {
1313                    panic!(
1314						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1315						test_name, val, bits, i, period
1316					);
1317                }
1318
1319                if bits == 0x33333333_33333333 {
1320                    panic!(
1321						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1322						test_name, val, bits, i, period
1323					);
1324                }
1325            }
1326        }
1327
1328        Ok(())
1329    }
1330
1331    #[cfg(not(debug_assertions))]
1332    fn check_smma_no_poison(
1333        _test_name: &str,
1334        _kernel: Kernel,
1335    ) -> Result<(), Box<dyn std::error::Error>> {
1336        Ok(())
1337    }
1338
1339    macro_rules! generate_all_smma_tests {
1340        ($($test_fn:ident),*) => {
1341            paste::paste! {
1342                $(
1343                    #[test]
1344                    fn [<$test_fn _scalar_f64>]() {
1345                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1346                    }
1347                )*
1348                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1349                $(
1350                    #[test]
1351                    fn [<$test_fn _avx2_f64>]() {
1352                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1353                    }
1354                    #[test]
1355                    fn [<$test_fn _avx512_f64>]() {
1356                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1357                    }
1358                )*
1359            }
1360        }
1361    }
1362    generate_all_smma_tests!(
1363        check_smma_partial_params,
1364        check_smma_accuracy,
1365        check_smma_default_candles,
1366        check_smma_zero_period,
1367        check_smma_period_exceeds_length,
1368        check_smma_very_small_dataset,
1369        check_smma_empty_input,
1370        check_smma_reinput,
1371        check_smma_nan_handling,
1372        check_smma_streaming,
1373        check_smma_no_poison
1374    );
1375
1376    #[cfg(feature = "proptest")]
1377    generate_all_smma_tests!(check_smma_property);
1378
1379    #[test]
1380    fn test_smma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1381        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1382        let candles = read_candles_from_csv(file_path)?;
1383
1384        let input = SmmaInput::with_default_candles(&candles);
1385        let baseline = smma(&input)?.values;
1386
1387        let mut out = vec![0.0f64; baseline.len()];
1388
1389        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1390        {
1391            smma_into(&input, &mut out)?;
1392        }
1393        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1394        {
1395            smma_into_slice(&mut out, &input, Kernel::Auto)?;
1396        }
1397
1398        assert_eq!(out.len(), baseline.len());
1399
1400        for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1401            let equal = (a.is_nan() && b.is_nan()) || (a == b);
1402            assert!(
1403                equal,
1404                "into parity mismatch at idx {}: got {}, expected {}",
1405                i, a, b
1406            );
1407        }
1408
1409        Ok(())
1410    }
1411    fn check_batch_default_row(
1412        test: &str,
1413        kernel: Kernel,
1414    ) -> Result<(), Box<dyn std::error::Error>> {
1415        skip_if_unsupported!(kernel, test);
1416        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1417        let c = read_candles_from_csv(file)?;
1418        let output = SmmaBatchBuilder::new()
1419            .kernel(kernel)
1420            .apply_candles(&c, "close")?;
1421        let def = SmmaParams::default();
1422        let row = output.values_for(&def).expect("default row missing");
1423        assert_eq!(row.len(), c.close.len());
1424        let expected = [59434.4, 59398.2, 59346.9, 59319.4, 59224.5];
1425        let start = row.len() - 5;
1426        for (i, &v) in row[start..].iter().enumerate() {
1427            assert!(
1428                (v - expected[i]).abs() < 1e-1,
1429                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1430            );
1431        }
1432        Ok(())
1433    }
1434
1435    #[cfg(debug_assertions)]
1436    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1437        skip_if_unsupported!(kernel, test);
1438
1439        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1440        let c = read_candles_from_csv(file)?;
1441
1442        let batch_configs = vec![
1443            (3, 7, 1),
1444            (10, 30, 10),
1445            (5, 100, 5),
1446            (50, 200, 50),
1447            (1, 10, 1),
1448        ];
1449
1450        for (start, end, step) in batch_configs {
1451            if start > c.close.len() {
1452                continue;
1453            }
1454
1455            let output = SmmaBatchBuilder::new()
1456                .kernel(kernel)
1457                .period_range(start, end, step)
1458                .apply_candles(&c, "close")?;
1459
1460            for (idx, &val) in output.values.iter().enumerate() {
1461                if val.is_nan() {
1462                    continue;
1463                }
1464
1465                let bits = val.to_bits();
1466                let row = idx / output.cols;
1467                let col = idx % output.cols;
1468                let period = output.combos[row].period.unwrap_or(0);
1469
1470                if bits == 0x11111111_11111111 {
1471                    panic!(
1472                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1473                        test, val, bits, row, col, idx, period, start, end, step
1474                    );
1475                }
1476
1477                if bits == 0x22222222_22222222 {
1478                    panic!(
1479                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1480                        test, val, bits, row, col, idx, period, start, end, step
1481                    );
1482                }
1483
1484                if bits == 0x33333333_33333333 {
1485                    panic!(
1486                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1487                        test, val, bits, row, col, idx, period, start, end, step
1488                    );
1489                }
1490            }
1491        }
1492
1493        Ok(())
1494    }
1495
1496    #[cfg(not(debug_assertions))]
1497    fn check_batch_no_poison(
1498        _test: &str,
1499        _kernel: Kernel,
1500    ) -> Result<(), Box<dyn std::error::Error>> {
1501        Ok(())
1502    }
1503    macro_rules! gen_batch_tests {
1504        ($fn_name:ident) => {
1505            paste::paste! {
1506                #[test] fn [<$fn_name _scalar>]()      {
1507                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1508                }
1509                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1510                #[test] fn [<$fn_name _avx2>]()        {
1511                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1512                }
1513                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1514                #[test] fn [<$fn_name _avx512>]()      {
1515                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1516                }
1517                #[test] fn [<$fn_name _auto_detect>]() {
1518                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1519                }
1520            }
1521        };
1522    }
1523    gen_batch_tests!(check_batch_default_row);
1524    gen_batch_tests!(check_batch_no_poison);
1525}
1526
// Python entry point `smma(data, period, kernel=None)`.
//
// Reads `data` as a contiguous f64 slice, validates the optional kernel name,
// runs the indicator with the GIL released and returns the values as a new
// NumPy array. Indicator errors are surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "smma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn smma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = SmmaInput::from_slice(
        prices,
        SmmaParams {
            period: Some(period),
        },
    );

    // The computation does not touch Python objects, so it runs without the GIL.
    let computed = py.allow_threads(|| smma_with_kernel(&input, chosen).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1553
// Python class `SmmaStream`: a thin wrapper that owns a Rust `SmmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SmmaStream")]
pub struct SmmaStreamPy {
    stream: SmmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SmmaStreamPy {
    // Constructor; a failure of `SmmaStream::try_new` becomes a Python `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = SmmaParams {
            period: Some(period),
        };
        match SmmaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one value to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1577
// Python entry point `smma_batch(data, period_range, kernel=None)`.
//
// Returns a dict with
//   "values"  - 2-D float64 array, one row per period in the sweep, one column per input value
//   "periods" - 1-D uint64 array with the period used for each row
// An empty sweep, a `rows * cols` overflow or an indicator error raises `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "smma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn smma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = SmmaBatchRange {
        period: period_range,
    };

    // The grid is expanded up front only to size the output and reject empty sweeps.
    let rows = expand_grid(&sweep).len();
    if rows == 0 {
        let (start, end, step) = sweep.period;
        return Err(PyValueError::new_err(format!(
            "invalid period range: start={}, end={}, step={}",
            start, end, step
        )));
    }
    let cols = prices.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows * cols overflow")),
    };

    // The array is allocated uninitialised and handed to the batch routine as its
    // output buffer.
    // NOTE(review): assumes `smma_batch_inner_into` writes every element - confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let batch_kernel = match kern {
            Kernel::Auto => detect_best_batch_kernel(),
            explicit => explicit,
        };
        // The inner routine takes the per-row kernel that corresponds to the batch kernel.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };

        smma_batch_inner_into(prices, &sweep, simd, true, out_slice)
    });
    let combos = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1643
// Python entry point `smma_cuda_batch_dev(data, period_range, device_id=0)`.
//
// Narrows the f64 input to f32, runs the CUDA batch sweep with the GIL released
// and returns the resulting device array wrapped for Python. Raises `ValueError`
// when CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "smma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn smma_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SmmaPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data.as_slice()?;
    let sweep = SmmaBatchRange {
        period: period_range,
    };

    // The CUDA wrapper is fed an f32 copy of the input.
    let mut prices_f32: Vec<f32> = Vec::with_capacity(prices.len());
    prices_f32.extend(prices.iter().map(|&v| v as f32));

    let device_array = py.allow_threads(|| {
        let cuda = match CudaSmma::new(device_id) {
            Ok(ctx) => ctx,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        cuda.smma_batch_dev(&prices_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SmmaPy {
        inner: Some(device_array),
    })
}
1673
// Python entry point
// `smma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a contiguous 2-D f32 array, passes its flat buffer together with
// `shape[1]` and `shape[0]` (in that order) to the CUDA wrapper, and returns the
// resulting device array. Raises `ValueError` when CUDA is unavailable or the
// CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "smma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn smma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SmmaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = SmmaParams {
        period: Some(period),
    };

    let device_array = py.allow_threads(|| {
        let cuda = match CudaSmma::new(device_id) {
            Ok(ctx) => ctx,
            Err(e) => return Err(PyValueError::new_err(e.to_string())),
        };
        // NOTE(review): argument order is (cols, rows); presumably (series count,
        // series length) for time-major data - confirm against the wrapper.
        cuda.smma_multi_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SmmaPy {
        inner: Some(device_array),
    })
}
1704
// Python handle around a device-resident f32 result matrix.
//
// `inner` is `Some` until the buffer is handed away by `__dlpack__`; after that
// every accessor reports "buffer already exported via __dlpack__".
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Smma", unsendable)]
pub struct DeviceArrayF32SmmaPy {
    pub(crate) inner: Option<DeviceArrayF32Smma>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SmmaPy {
    // Builds the `__cuda_array_interface__` dict (version 3) describing a
    // row-major `(rows, cols)` little-endian f32 buffer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one full row, then one element.
        d.set_item("strides", (inner.cols * elem, elem))?;
        // (pointer, read_only flag)
        d.set_item("data", (inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(2, device_id)` for the still-owned buffer.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        self.inner
            .as_ref()
            .map(|inner| (2, inner.device_id as i32))
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))
    }

    // Exports the buffer as a DLPack capsule and gives up ownership of it.
    //
    // A `dl_device` that parses as `(type, id)` and differs from this buffer's
    // device is rejected; `stream` is accepted but ignored.
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__()?;

        // A `dl_device` that is absent or not an `(i32, i32)` tuple is not checked.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let _ = stream;

        // `take` leaves `None` behind, so a second export attempt fails cleanly.
        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let (rows, cols) = (inner.rows, inner.cols);
        let buf = inner.buf;
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1788
1789#[inline]
1790pub fn smma_into_slice(dst: &mut [f64], input: &SmmaInput, kern: Kernel) -> Result<(), SmmaError> {
1791    let (data, period, first, chosen) = smma_prepare(input, kern)?;
1792
1793    if dst.len() != data.len() {
1794        return Err(SmmaError::OutputLengthMismatch {
1795            expected: data.len(),
1796            got: dst.len(),
1797        });
1798    }
1799
1800    let warmup_end = first + period - 1;
1801    unsafe {
1802        let dst_uninit =
1803            std::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut MaybeUninit<f64>, dst.len());
1804        init_matrix_prefixes(dst_uninit, dst.len(), &[warmup_end]);
1805    }
1806
1807    smma_compute_into(data, period, first, chosen, dst);
1808
1809    Ok(())
1810}
1811
1812#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1813#[inline]
1814pub fn smma_into(input: &SmmaInput, out: &mut [f64]) -> Result<(), SmmaError> {
1815    smma_into_slice(out, input, Kernel::Auto)
1816}
1817
// JS entry point `smma(data, period)`: returns a new array of the same length as
// `data`; an indicator error is returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma")]
pub fn smma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SmmaInput::from_slice(
        data,
        SmmaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match smma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1833
// Configuration object accepted by the JS `smma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmmaBatchConfig {
    // Period sweep, forwarded unchanged into `SmmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}

// Result object returned by the JS `smma_batch` entry point; its fields are
// copied one-to-one from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmmaBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SmmaParams>,
    pub rows: usize,
    pub cols: usize,
}

// JS entry point `smma_batch(data, config)`: deserialises `config`, runs the
// batch computation and serialises the result back into a JS value. Every
// failure is reported as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch")]
pub fn smma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SmmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = SmmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match smma_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let js_output = SmmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&js_output) {
        Ok(value) => Ok(value),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
1872
/// JS binding (`smma_batch_legacy`): runs a period sweep and returns only the
/// flattened values.
///
/// The sweep is `(period_start, period_end, period_step)`, forwarded to
/// `smma_batch_with_kernel` with `Kernel::Auto`. Failures are returned as a
/// `JsValue` holding the error message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_legacy")]
pub fn smma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let sweep = SmmaBatchRange { period };

    match smma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1889
/// JS binding (`smma_batch_metadata`): lists the period of every parameter
/// combination that `expand_grid` produces for the given sweep.
///
/// A combination without a period is reported as `7`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_metadata")]
pub fn smma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let sweep = SmmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let grid = expand_grid(&sweep);

    let mut periods = Vec::new();
    for combo in grid.iter() {
        periods.push(combo.period.unwrap_or(7));
    }
    periods
}
1904
/// JS binding (`smma_batch_rows_cols`): reports the batch output dimensions.
///
/// Returns a two-element vector `[rows, cols]`, where `rows` is the number of
/// combinations `expand_grid` yields for the sweep and `cols` is `data_len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_rows_cols")]
pub fn smma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = SmmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    vec![rows, data_len]
}
1920
/// Allocates storage for `len` `f64` values and hands ownership of it to the caller.
///
/// The returned memory is uninitialized. It is never freed by Rust on its own;
/// release it by passing the same pointer and `len` to `smma_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1929
/// Releases a buffer previously obtained from `smma_alloc`.
///
/// `ptr` and `len` must be exactly the pointer returned by `smma_alloc` and the
/// `len` it was called with, and the buffer must not be used or freed again
/// afterwards. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` created with
    // capacity `len`. The Vec is rebuilt with length 0 rather than `len`:
    // `Vec::from_raw_parts` requires the first `length` elements to be initialized,
    // and the caller may never have written to the buffer. `f64` has no destructor,
    // so the length does not affect what gets released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1939
/// Pointer-based JS binding: computes the SMMA of `len` values at `in_ptr` and
/// writes `len` results to `out_ptr`.
///
/// The input and output buffers may be the same buffer or overlap in any way; in
/// that case the result is computed into a temporary buffer first and copied out
/// afterwards, so the input is fully read before any output is written.
///
/// The caller must ensure `in_ptr` is valid for reading `len` `f64` values and
/// `out_ptr` is valid for writing `len` `f64` values.
///
/// # Errors
///
/// Returns a `JsValue` string when either pointer is null, when `period` is zero
/// or greater than `len` (`"Invalid period"`), or when `smma_into_slice` fails
/// (its error message).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to smma_into"));
    }

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Byte ranges [start, start + span) of both buffers. Checking for any
    // intersection (not only identical start addresses) prevents a shared and a
    // mutable slice from ever covering the same memory.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for
    // reading `len` f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    let params = SmmaParams {
        period: Some(period),
    };
    let input = SmmaInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        smma_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `out_ptr` is non-null and valid for writing `len` f64 values
        // (caller contract). `data` is not read again after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is non-null, valid for writing `len` f64 values
        // (caller contract), and was just shown not to overlap `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        smma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1980
/// Pointer-based JS binding: runs a period sweep over `len` values at `in_ptr`
/// and writes `rows * len` results to `out_ptr`, where `rows` is the number of
/// combinations `expand_grid` yields for `(period_start, period_end, period_step)`.
///
/// Returns `rows` on success.
///
/// If the input and output ranges overlap, the batch is computed into a temporary
/// buffer and copied out afterwards so that the input is never overwritten while
/// it is still being read.
///
/// The caller must ensure `in_ptr` is valid for reading `len` `f64` values and
/// `out_ptr` is valid for writing `rows * len` `f64` values.
///
/// # Errors
///
/// Returns a `JsValue` string when either pointer is null, when the sweep expands
/// to no combinations (`"invalid period range"`), when `rows * len` overflows
/// (`"rows*cols overflow"`), or when `smma_batch_inner_into` fails (its error
/// message).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to smma_batch_into"));
    }

    let sweep = SmmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(JsValue::from_str("invalid period range"));
    }
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Map the batch kernel choice onto the matching per-row kernel.
    let batch = detect_best_batch_kernel();
    let simd = match batch {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // Byte ranges of the input (`len` values) and output (`total` values) buffers.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for
    // reading `len` f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if overlaps {
        let mut temp = vec![0.0; total];
        smma_batch_inner_into(data, &sweep, simd, false, temp.as_mut_slice())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: `out_ptr` is non-null and valid for writing `total` f64 values
        // (caller contract). `data` is not read again after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is non-null, valid for writing `total` f64 values
        // (caller contract), and was just shown not to overlap `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        smma_batch_inner_into(data, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}