Skip to main content

vector_ta/indicators/moving_averages/
swma.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::cuda_available;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaSwma;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1};
26#[cfg(feature = "python")]
27use pyo3::exceptions::PyValueError;
28#[cfg(feature = "python")]
29use pyo3::prelude::*;
30#[cfg(feature = "python")]
31use pyo3::types::{PyDict, PyList};
32#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
33use serde::{Deserialize, Serialize};
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use wasm_bindgen::prelude::*;
36
37impl<'a> AsRef<[f64]> for SwmaInput<'a> {
38    #[inline(always)]
39    fn as_ref(&self) -> &[f64] {
40        match &self.data {
41            SwmaData::Slice(slice) => slice,
42            SwmaData::Candles { candles, source } => source_type(candles, source),
43        }
44    }
45}
46
47#[derive(Debug, Clone)]
48pub enum SwmaData<'a> {
49    Candles {
50        candles: &'a Candles,
51        source: &'a str,
52    },
53    Slice(&'a [f64]),
54}
55
/// Result of a single-series SWMA computation.
#[derive(Clone, Debug)]
pub struct SwmaOutput {
    /// One value per input sample; warm-up positions hold NaN.
    pub values: Vec<f64>,
}
60
/// Tunable parameters of the SWMA indicator.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SwmaParams {
    /// Length of the triangular window; `None` means "use the default".
    pub period: Option<usize>,
}

impl Default for SwmaParams {
    /// The default window length is 5.
    fn default() -> Self {
        SwmaParams { period: Some(5) }
    }
}
75
76#[derive(Debug, Clone)]
77pub struct SwmaInput<'a> {
78    pub data: SwmaData<'a>,
79    pub params: SwmaParams,
80}
81
82impl<'a> SwmaInput<'a> {
83    #[inline]
84    pub fn from_candles(c: &'a Candles, s: &'a str, p: SwmaParams) -> Self {
85        Self {
86            data: SwmaData::Candles {
87                candles: c,
88                source: s,
89            },
90            params: p,
91        }
92    }
93    #[inline]
94    pub fn from_slice(sl: &'a [f64], p: SwmaParams) -> Self {
95        Self {
96            data: SwmaData::Slice(sl),
97            params: p,
98        }
99    }
100    #[inline]
101    pub fn with_default_candles(c: &'a Candles) -> Self {
102        Self::from_candles(c, "close", SwmaParams::default())
103    }
104    #[inline]
105    pub fn get_period(&self) -> usize {
106        self.params.period.unwrap_or(5)
107    }
108}
109
110#[derive(Copy, Clone, Debug)]
111pub struct SwmaBuilder {
112    period: Option<usize>,
113    kernel: Kernel,
114}
115
116impl Default for SwmaBuilder {
117    fn default() -> Self {
118        Self {
119            period: None,
120            kernel: Kernel::Auto,
121        }
122    }
123}
124
125impl SwmaBuilder {
126    #[inline(always)]
127    pub fn new() -> Self {
128        Self::default()
129    }
130    #[inline(always)]
131    pub fn period(mut self, n: usize) -> Self {
132        self.period = Some(n);
133        self
134    }
135    #[inline(always)]
136    pub fn kernel(mut self, k: Kernel) -> Self {
137        self.kernel = k;
138        self
139    }
140
141    #[inline(always)]
142    pub fn apply(self, c: &Candles) -> Result<SwmaOutput, SwmaError> {
143        let p = SwmaParams {
144            period: self.period,
145        };
146        let i = SwmaInput::from_candles(c, "close", p);
147        swma_with_kernel(&i, self.kernel)
148    }
149
150    #[inline(always)]
151    pub fn apply_slice(self, d: &[f64]) -> Result<SwmaOutput, SwmaError> {
152        let p = SwmaParams {
153            period: self.period,
154        };
155        let i = SwmaInput::from_slice(d, p);
156        swma_with_kernel(&i, self.kernel)
157    }
158
159    #[inline(always)]
160    pub fn into_stream(self) -> Result<SwmaStream, SwmaError> {
161        let p = SwmaParams {
162            period: self.period,
163        };
164        SwmaStream::try_new(p)
165    }
166}
167
168#[derive(Debug, Error)]
169pub enum SwmaError {
170    #[error("swma: Input data slice is empty.")]
171    EmptyInputData,
172    #[error("swma: All values are NaN.")]
173    AllValuesNaN,
174
175    #[error(
176		"swma: Invalid period: period = {period}, data length = {data_len}. Period must be between 1 and data length."
177	)]
178    InvalidPeriod { period: usize, data_len: usize },
179
180    #[error("swma: Not enough valid data: needed = {needed}, valid = {valid}")]
181    NotEnoughValidData { needed: usize, valid: usize },
182
183    #[error("swma: Output length mismatch: expected {expected}, got {got}")]
184    OutputLengthMismatch { expected: usize, got: usize },
185
186    #[error("swma: Invalid range expansion: start={start}, end={end}, step={step}")]
187    InvalidRange {
188        start: usize,
189        end: usize,
190        step: usize,
191    },
192
193    #[error("swma: Invalid kernel passed to batch path: {0:?}")]
194    InvalidKernelForBatch(Kernel),
195}
196
197#[inline]
198pub fn swma(input: &SwmaInput) -> Result<SwmaOutput, SwmaError> {
199    swma_with_kernel(input, Kernel::Auto)
200}
201
202#[inline]
203fn swma_prepare<'a>(
204    input: &'a SwmaInput,
205    kernel: Kernel,
206) -> Result<(&'a [f64], AVec<f64>, usize, usize, Kernel), SwmaError> {
207    let data: &[f64] = input.as_ref();
208    let len = data.len();
209    if len == 0 {
210        return Err(SwmaError::EmptyInputData);
211    }
212
213    let first = data
214        .iter()
215        .position(|x| !x.is_nan())
216        .ok_or(SwmaError::AllValuesNaN)?;
217    let period = input.get_period();
218
219    if period == 0 || period > len {
220        return Err(SwmaError::InvalidPeriod {
221            period,
222            data_len: len,
223        });
224    }
225    if len - first < period {
226        return Err(SwmaError::NotEnoughValidData {
227            needed: period,
228            valid: len - first,
229        });
230    }
231
232    let weights = build_symmetric_triangle_avec(period);
233    let chosen = match kernel {
234        Kernel::Auto => Kernel::Scalar,
235        k => k,
236    };
237
238    Ok((data, weights, period, first, chosen))
239}
240
241#[inline(always)]
242fn swma_compute_into(
243    data: &[f64],
244    weights: &[f64],
245    period: usize,
246    first: usize,
247    kernel: Kernel,
248    out: &mut [f64],
249) {
250    unsafe {
251        match kernel {
252            Kernel::Scalar | Kernel::ScalarBatch => swma_scalar(data, weights, period, first, out),
253            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
254            Kernel::Avx2 | Kernel::Avx2Batch => swma_avx2(data, weights, period, first, out),
255            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
256            Kernel::Avx512 | Kernel::Avx512Batch => swma_avx512(data, weights, period, first, out),
257            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
258            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
259                swma_scalar(data, weights, period, first, out)
260            }
261            _ => unreachable!(),
262        }
263    }
264}
265
266pub fn swma_with_kernel(input: &SwmaInput, kernel: Kernel) -> Result<SwmaOutput, SwmaError> {
267    let (data, weights, period, first, chosen) = swma_prepare(input, kernel)?;
268
269    let len = data.len();
270    let warm = first + period - 1;
271    let mut out = alloc_with_nan_prefix(len, warm);
272
273    swma_compute_into(data, &weights, period, first, chosen, &mut out);
274
275    Ok(SwmaOutput { values: out })
276}
277
278#[inline]
279pub fn swma_into_slice(dst: &mut [f64], input: &SwmaInput, kern: Kernel) -> Result<(), SwmaError> {
280    let (data, weights, period, first, chosen) = swma_prepare(input, kern)?;
281
282    if dst.len() != data.len() {
283        return Err(SwmaError::OutputLengthMismatch {
284            expected: data.len(),
285            got: dst.len(),
286        });
287    }
288
289    swma_compute_into(data, &weights, period, first, chosen, dst);
290
291    let warmup_end = first + period - 1;
292    for v in &mut dst[..warmup_end] {
293        *v = f64::from_bits(0x7ff8_0000_0000_0000);
294    }
295
296    Ok(())
297}
298
299#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
300#[inline]
301pub fn swma_into(input: &SwmaInput, out: &mut [f64]) -> Result<(), SwmaError> {
302    let (data, weights, period, first, chosen) = swma_prepare(input, Kernel::Auto)?;
303
304    if out.len() != data.len() {
305        return Err(SwmaError::OutputLengthMismatch {
306            expected: data.len(),
307            got: out.len(),
308        });
309    }
310
311    let warm = (first + period - 1).min(out.len());
312    for v in &mut out[..warm] {
313        *v = f64::from_bits(0x7ff8_0000_0000_0000);
314    }
315
316    swma_compute_into(data, &weights, period, first, chosen, out);
317    Ok(())
318}
319
320#[inline(always)]
321fn build_symmetric_triangle_vec(n: usize) -> Vec<f64> {
322    let mut w = Vec::with_capacity(n);
323    if n == 1 {
324        w.push(1.0);
325    } else if n == 2 {
326        w.extend_from_slice(&[0.5, 0.5]);
327    } else if n % 2 == 0 {
328        let half = n / 2;
329        for i in 1..=half {
330            w.push(i as f64);
331        }
332        for i in (1..=half).rev() {
333            w.push(i as f64);
334        }
335        let sum: f64 = triangle_weight_sum(n);
336        for x in &mut w {
337            *x /= sum;
338        }
339    } else {
340        let half_plus = (n + 1) / 2;
341        for i in 1..=half_plus {
342            w.push(i as f64);
343        }
344        for i in (1..half_plus).rev() {
345            w.push(i as f64);
346        }
347        let sum: f64 = triangle_weight_sum(n);
348        for x in &mut w {
349            *x /= sum;
350        }
351    }
352    w
353}
354
/// Sum of the unnormalized triangle weights `1, 2, …, peak, …, 2, 1` of length `n`.
///
/// For odd `n` the peak `p = (n + 1) / 2` appears once and the sum is `p²`.
/// For even `n` the peak `m = n / 2` appears twice and the sum is `m (m + 1)`.
#[inline(always)]
fn triangle_weight_sum(n: usize) -> f64 {
    if n % 2 != 0 {
        let peak = ((n + 1) / 2) as f64;
        return peak * peak;
    }
    let half = (n / 2) as f64;
    half * (half + 1.0)
}
365
366#[inline(always)]
367fn build_symmetric_triangle_avec(n: usize) -> AVec<f64> {
368    let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, n);
369
370    if n == 1 {
371        weights.push(1.0);
372    } else if n == 2 {
373        weights.push(0.5);
374        weights.push(0.5);
375    } else if n % 2 == 0 {
376        let half = n / 2;
377
378        for i in 1..=half {
379            weights.push(i as f64);
380        }
381
382        for i in (1..=half).rev() {
383            weights.push(i as f64);
384        }
385    } else {
386        let half_plus = (n + 1) / 2;
387
388        for i in 1..=half_plus {
389            weights.push(i as f64);
390        }
391
392        for i in (1..half_plus).rev() {
393            weights.push(i as f64);
394        }
395    }
396
397    let sum: f64 = if n <= 2 { 1.0 } else { triangle_weight_sum(n) };
398    for w in weights.iter_mut() {
399        *w /= sum;
400    }
401
402    weights
403}
404
405#[inline]
406pub fn swma_scalar(
407    data: &[f64],
408    _weights: &[f64],
409    period: usize,
410    first_val: usize,
411    out: &mut [f64],
412) {
413    debug_assert!(out.len() >= data.len());
414    debug_assert!(period >= 1);
415
416    let len = data.len();
417    if len == 0 {
418        return;
419    }
420
421    let (a, b) = if (period & 1) != 0 {
422        let m = (period + 1) >> 1;
423        (m, m)
424    } else {
425        let m = period >> 1;
426        (m, m + 1)
427    };
428
429    if period == 1 {
430        unsafe {
431            for i in first_val..len {
432                *out.get_unchecked_mut(i) = *data.get_unchecked(i);
433            }
434        }
435        return;
436    }
437
438    if period == 2 {
439        unsafe {
440            for i in (first_val + 1)..len {
441                *out.get_unchecked_mut(i) =
442                    (*data.get_unchecked(i - 1) + *data.get_unchecked(i)) * 0.5;
443            }
444        }
445        return;
446    }
447
448    let inv_ab = 1.0 / ((a as f64) * (b as f64));
449    let start_full_a = first_val + a - 1;
450    let start_full_ab = first_val + period - 1;
451
452    let mut ring = AVec::<f64>::with_capacity(CACHELINE_ALIGN, b);
453    ring.resize(b, 0.0);
454    let mut rb_idx = 0usize;
455
456    let mut s1_sum = 0.0_f64;
457    let mut s2_sum = 0.0_f64;
458
459    unsafe {
460        for i in first_val..len {
461            s1_sum += *data.get_unchecked(i);
462
463            if i >= start_full_a {
464                let old = *ring.get_unchecked(rb_idx);
465                s2_sum = s2_sum + (s1_sum - old);
466                *ring.get_unchecked_mut(rb_idx) = s1_sum;
467
468                rb_idx += 1;
469                if rb_idx == b {
470                    rb_idx = 0;
471                }
472
473                if i >= start_full_ab {
474                    *out.get_unchecked_mut(i) = s2_sum * inv_ab;
475                }
476
477                s1_sum -= *data.get_unchecked(i + 1 - a);
478            }
479        }
480    }
481}
482
/// AVX-512 entry point for the single-series SWMA.
///
/// Dispatches to the short (`period <= 32`) or long variant, both compiled with
/// AVX-512 target features. Calling a `#[target_feature]` function on a CPU
/// without those features is undefined behaviour. Because this function is
/// safe, it first verifies CPU support at runtime and falls back to
/// [`swma_scalar`] when the features are missing. Both variants currently
/// delegate to `swma_scalar`, so results are identical either way.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn swma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    if !(is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma")) {
        return swma_scalar(data, weights, period, first_valid, out);
    }

    if period <= 32 {
        // SAFETY: avx512f and fma were detected above, matching the
        // target features of `swma_avx512_short`.
        unsafe { swma_avx512_short(data, weights, period, first_valid, out) }
    } else if is_x86_feature_detected!("avx512dq") {
        // SAFETY: avx512f, avx512dq and fma were all detected, matching the
        // target features of `swma_avx512_long`.
        unsafe { swma_avx512_long(data, weights, period, first_valid, out) }
    } else {
        swma_scalar(data, weights, period, first_valid, out)
    }
}
498
/// AVX2 variant of the single-series SWMA; currently the scalar kernel compiled
/// with AVX2/FMA enabled.
///
/// # Safety
/// The CPU must support `avx2` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn swma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out);
}

/// AVX-512 variant for short periods; currently the scalar kernel compiled
/// with AVX-512F/FMA enabled.
///
/// # Safety
/// The CPU must support `avx512f` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn swma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out);
}

/// AVX-512 variant for long periods; currently the scalar kernel compiled
/// with AVX-512F/DQ/FMA enabled.
///
/// # Safety
/// The CPU must support `avx512f`, `avx512dq` and `fma`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn swma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out);
}
534
535#[derive(Debug, Clone)]
536pub struct SwmaStream {
537    period: usize,
538
539    a: usize,
540    b: usize,
541    inv_ab: f64,
542
543    ring_a: aligned_vec::AVec<f64>,
544    idx_a: usize,
545    cnt_a: usize,
546    s1_sum: f64,
547
548    ring_b: aligned_vec::AVec<f64>,
549    idx_b: usize,
550    cnt_b: usize,
551    s2_sum: f64,
552}
553
impl SwmaStream {
    /// Creates an empty stream; `params.period` defaults to 5 when unset.
    ///
    /// The triangle window of length `period` is factored into two box filters
    /// of lengths `a` and `b` with `a + b - 1 == period`. Each filter is backed
    /// by a zero-initialized ring buffer.
    ///
    /// # Errors
    /// Returns [`SwmaError::InvalidPeriod`] (with `data_len: 0`) if `period == 0`.
    pub fn try_new(params: SwmaParams) -> Result<Self, SwmaError> {
        let period = params.period.unwrap_or(5);
        if period == 0 {
            return Err(SwmaError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Odd period 2m-1 -> two boxes of length m; even period 2m -> boxes of m and m+1.
        let (a, b) = if (period & 1) != 0 {
            let m = (period + 1) >> 1;
            (m, m)
        } else {
            let m = period >> 1;
            (m, m + 1)
        };

        let mut ring_a = aligned_vec::AVec::<f64>::with_capacity(aligned_vec::CACHELINE_ALIGN, a);
        ring_a.resize(a, 0.0);

        let mut ring_b = aligned_vec::AVec::<f64>::with_capacity(aligned_vec::CACHELINE_ALIGN, b);
        ring_b.resize(b, 0.0);

        Ok(Self {
            period,
            a,
            b,
            inv_ab: 1.0 / ((a as f64) * (b as f64)),
            ring_a,
            idx_a: 0,
            cnt_a: 0,
            s1_sum: 0.0,
            ring_b,
            idx_b: 0,
            cnt_b: 0,
            s2_sum: 0.0,
        })
    }

    /// Pushes one sample. Returns `None` until `period` samples have been
    /// pushed and `Some(swma)` from then on.
    ///
    /// Stage 1 keeps `s1_sum`, the running sum of the last `a` samples. Once that
    /// window is full, every new `s1_sum` is pushed into stage 2, which keeps
    /// `s2_sum`, the running sum of the last `b` stage-1 sums. The output is
    /// `s2_sum / (a * b)`. The running sums never drop a NaN once they absorb
    /// one, so a NaN input makes every later output NaN.
    #[inline(always)]
    pub fn update(&mut self, x: f64) -> Option<f64> {
        let ia = self.idx_a;

        // Value about to be overwritten in the stage-1 ring.
        let old_a = self.ring_a[ia];

        // Only evict from the sum once the ring is full; before that, the slot
        // still holds its initial 0.0 and we just count the new sample.
        if self.cnt_a == self.a {
            self.s1_sum -= old_a;
        } else {
            self.cnt_a += 1;
        }
        self.ring_a[ia] = x;
        self.s1_sum += x;

        self.idx_a = ia + 1;
        if self.idx_a == self.a {
            self.idx_a = 0;
        }

        // Stage 2 only sees complete a-sample box sums.
        if self.cnt_a == self.a {
            let ib = self.idx_b;
            let old_s1 = self.ring_b[ib];

            if self.cnt_b == self.b {
                self.s2_sum -= old_s1;
            } else {
                self.cnt_b += 1;
            }
            self.ring_b[ib] = self.s1_sum;
            self.s2_sum += self.s1_sum;

            self.idx_b = ib + 1;
            if self.idx_b == self.b {
                self.idx_b = 0;
            }

            // Both stages full: a + b - 1 == period samples have been consumed.
            if self.cnt_b == self.b {
                return Some(self.s2_sum * self.inv_ab);
            }
        }

        None
    }
}
638
/// Period sweep for batch SWMA, given as `(start, end, step)`.
///
/// A `step` of 0, or `start == end`, means the single period `start`.
#[derive(Debug, Clone)]
pub struct SwmaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for SwmaBatchRange {
    /// Every period from 5 through 254.
    fn default() -> Self {
        SwmaBatchRange {
            period: (5, 254, 1),
        }
    }
}
651
652#[derive(Clone, Debug, Default)]
653pub struct SwmaBatchBuilder {
654    range: SwmaBatchRange,
655    kernel: Kernel,
656}
657
658impl SwmaBatchBuilder {
659    pub fn new() -> Self {
660        Self::default()
661    }
662    pub fn kernel(mut self, k: Kernel) -> Self {
663        self.kernel = k;
664        self
665    }
666
667    #[inline]
668    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
669        self.range.period = (start, end, step);
670        self
671    }
672    #[inline]
673    pub fn period_static(mut self, p: usize) -> Self {
674        self.range.period = (p, p, 0);
675        self
676    }
677
678    pub fn apply_slice(self, data: &[f64]) -> Result<SwmaBatchOutput, SwmaError> {
679        swma_batch_with_kernel(data, &self.range, self.kernel)
680    }
681
682    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SwmaBatchOutput, SwmaError> {
683        SwmaBatchBuilder::new().kernel(k).apply_slice(data)
684    }
685
686    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SwmaBatchOutput, SwmaError> {
687        let slice = source_type(c, src);
688        self.apply_slice(slice)
689    }
690
691    pub fn with_default_candles(c: &Candles) -> Result<SwmaBatchOutput, SwmaError> {
692        SwmaBatchBuilder::new()
693            .kernel(Kernel::Auto)
694            .apply_candles(c, "close")
695    }
696}
697
698pub fn swma_batch_with_kernel(
699    data: &[f64],
700    sweep: &SwmaBatchRange,
701    k: Kernel,
702) -> Result<SwmaBatchOutput, SwmaError> {
703    let kernel = match k {
704        Kernel::Auto => detect_best_batch_kernel(),
705        other if other.is_batch() => other,
706        _ => return Err(SwmaError::InvalidKernelForBatch(k)),
707    };
708
709    let simd = match kernel {
710        Kernel::Avx512Batch => Kernel::Avx512,
711        Kernel::Avx2Batch => Kernel::Avx2,
712        Kernel::ScalarBatch => Kernel::Scalar,
713        _ => unreachable!(),
714    };
715    swma_batch_par_slice(data, sweep, simd)
716}
717
718#[derive(Clone, Debug)]
719pub struct SwmaBatchOutput {
720    pub values: Vec<f64>,
721    pub combos: Vec<SwmaParams>,
722    pub rows: usize,
723    pub cols: usize,
724}
725
726impl SwmaBatchOutput {
727    pub fn row_for_params(&self, p: &SwmaParams) -> Option<usize> {
728        self.combos
729            .iter()
730            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
731    }
732
733    pub fn values_for(&self, p: &SwmaParams) -> Option<&[f64]> {
734        self.row_for_params(p).map(|row| {
735            let start = row * self.cols;
736            &self.values[start..start + self.cols]
737        })
738    }
739}
740
741#[inline(always)]
742fn expand_grid(r: &SwmaBatchRange) -> Vec<SwmaParams> {
743    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
744        if step == 0 || start == end {
745            return vec![start];
746        }
747        if start < end {
748            return (start..=end).step_by(step.max(1)).collect();
749        }
750
751        let mut v = Vec::new();
752        let mut cur = start;
753        loop {
754            v.push(cur);
755            if cur <= end {
756                break;
757            }
758            match cur.checked_sub(step.max(1)) {
759                Some(next) => {
760                    cur = next;
761                    if cur < end {
762                        break;
763                    }
764                }
765                None => break,
766            }
767        }
768        v
769    }
770    let periods = axis_usize(r.period);
771    let mut out = Vec::with_capacity(periods.len());
772    for &p in &periods {
773        out.push(SwmaParams { period: Some(p) });
774    }
775    out
776}
777
778#[inline(always)]
779pub fn swma_batch_slice(
780    data: &[f64],
781    sweep: &SwmaBatchRange,
782    kern: Kernel,
783) -> Result<SwmaBatchOutput, SwmaError> {
784    swma_batch_inner(data, sweep, kern, false)
785}
786
787#[inline(always)]
788pub fn swma_batch_par_slice(
789    data: &[f64],
790    sweep: &SwmaBatchRange,
791    kern: Kernel,
792) -> Result<SwmaBatchOutput, SwmaError> {
793    swma_batch_inner(data, sweep, kern, true)
794}
795
796pub fn swma_batch_into_slice(
797    dst: &mut [f64],
798    data: &[f64],
799    sweep: &SwmaBatchRange,
800    k: Kernel,
801) -> Result<Vec<SwmaParams>, SwmaError> {
802    swma_batch_inner_into(data, sweep, k, true, dst)
803}
804
805#[inline(always)]
806fn swma_batch_inner(
807    data: &[f64],
808    sweep: &SwmaBatchRange,
809    kern: Kernel,
810    parallel: bool,
811) -> Result<SwmaBatchOutput, SwmaError> {
812    let combos = expand_grid(sweep);
813    if combos.is_empty() {
814        let (s, e, t) = sweep.period;
815        return Err(SwmaError::InvalidRange {
816            start: s,
817            end: e,
818            step: t,
819        });
820    }
821
822    let len = data.len();
823    if len == 0 {
824        return Err(SwmaError::EmptyInputData);
825    }
826
827    let first = data
828        .iter()
829        .position(|x| !x.is_nan())
830        .ok_or(SwmaError::AllValuesNaN)?;
831    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
832
833    if max_p == 0 || max_p > len {
834        return Err(SwmaError::InvalidPeriod {
835            period: max_p,
836            data_len: len,
837        });
838    }
839    if len - first < max_p {
840        return Err(SwmaError::NotEnoughValidData {
841            needed: max_p,
842            valid: len - first,
843        });
844    }
845
846    let rows = combos.len();
847    let cols = data.len();
848    let cap = rows.checked_mul(max_p).ok_or_else(|| {
849        let (s, e, t) = sweep.period;
850        SwmaError::InvalidRange {
851            start: s,
852            end: e,
853            step: t,
854        }
855    })?;
856    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
857    flat_w.resize(cap, 0.0);
858
859    for (row, combo) in combos.iter().enumerate() {
860        let period = combo.period.unwrap();
861        let w_start = row * max_p;
862
863        if period == 1 {
864            flat_w[w_start] = 1.0;
865        } else if period == 2 {
866            flat_w[w_start] = 0.5;
867            flat_w[w_start + 1] = 0.5;
868        } else if period % 2 == 0 {
869            let half = period / 2;
870
871            for i in 1..=half {
872                flat_w[w_start + i - 1] = i as f64;
873            }
874
875            for i in (1..=half).rev() {
876                flat_w[w_start + period - i] = i as f64;
877            }
878
879            let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
880            for i in 0..period {
881                flat_w[w_start + i] /= sum;
882            }
883        } else {
884            let half_plus = (period + 1) / 2;
885
886            for i in 1..=half_plus {
887                flat_w[w_start + i - 1] = i as f64;
888            }
889
890            for i in (1..half_plus).rev() {
891                flat_w[w_start + period - i] = i as f64;
892            }
893
894            let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
895            for i in 0..period {
896                flat_w[w_start + i] /= sum;
897            }
898        }
899    }
900
901    let warm: Vec<usize> = combos
902        .iter()
903        .map(|c| first + c.period.unwrap() - 1)
904        .collect();
905
906    let _ = rows.checked_mul(cols).ok_or_else(|| {
907        let (s, e, t) = sweep.period;
908        SwmaError::InvalidRange {
909            start: s,
910            end: e,
911            step: t,
912        }
913    })?;
914    let mut buf_mu = make_uninit_matrix(rows, cols);
915    init_matrix_prefixes(&mut buf_mu, cols, &warm);
916
917    let actual_kern = match kern {
918        Kernel::Auto => detect_best_batch_kernel(),
919        k => k,
920    };
921    let simd = match actual_kern {
922        Kernel::Avx512Batch => Kernel::Avx512,
923        Kernel::Avx2Batch => Kernel::Avx2,
924        Kernel::ScalarBatch => Kernel::Scalar,
925
926        other => other,
927    };
928
929    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
930        let period = combos[row].period.unwrap();
931        let w_ptr = flat_w.as_ptr().add(row * max_p);
932        let out_row =
933            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
934        match simd {
935            Kernel::Scalar => swma_row_scalar(data, first, period, w_ptr, out_row),
936            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
937            Kernel::Avx2 => swma_row_avx2(data, first, period, w_ptr, out_row),
938            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
939            Kernel::Avx512 => swma_row_avx512(data, first, period, w_ptr, out_row),
940            _ => swma_row_scalar(data, first, period, w_ptr, out_row),
941        }
942    };
943
944    {
945        use std::mem::MaybeUninit;
946        let rows_mut: &mut [MaybeUninit<f64>] = &mut buf_mu;
947        #[cfg(not(target_arch = "wasm32"))]
948        if parallel {
949            use rayon::prelude::*;
950            rows_mut
951                .par_chunks_mut(cols)
952                .enumerate()
953                .for_each(|(row, slice)| do_row(row, slice));
954        } else {
955            for (row, slice) in rows_mut.chunks_mut(cols).enumerate() {
956                do_row(row, slice);
957            }
958        }
959        #[cfg(target_arch = "wasm32")]
960        {
961            for (row, slice) in rows_mut.chunks_mut(cols).enumerate() {
962                do_row(row, slice);
963            }
964        }
965    }
966
967    use core::mem::ManuallyDrop;
968    let mut guard = ManuallyDrop::new(buf_mu);
969    let values = unsafe {
970        Vec::from_raw_parts(
971            guard.as_mut_ptr() as *mut f64,
972            guard.len(),
973            guard.capacity(),
974        )
975    };
976
977    Ok(SwmaBatchOutput {
978        values,
979        combos,
980        rows,
981        cols,
982    })
983}
984
985#[inline(always)]
986fn swma_batch_inner_into(
987    data: &[f64],
988    sweep: &SwmaBatchRange,
989    kern: Kernel,
990    parallel: bool,
991    out: &mut [f64],
992) -> Result<Vec<SwmaParams>, SwmaError> {
993    let combos = expand_grid(sweep);
994    if combos.is_empty() {
995        let (s, e, t) = sweep.period;
996        return Err(SwmaError::InvalidRange {
997            start: s,
998            end: e,
999            step: t,
1000        });
1001    }
1002
1003    let len = data.len();
1004    if len == 0 {
1005        return Err(SwmaError::EmptyInputData);
1006    }
1007
1008    let first = data
1009        .iter()
1010        .position(|x| !x.is_nan())
1011        .ok_or(SwmaError::AllValuesNaN)?;
1012    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1013
1014    if max_p == 0 || max_p > len {
1015        return Err(SwmaError::InvalidPeriod {
1016            period: max_p,
1017            data_len: len,
1018        });
1019    }
1020    if len - first < max_p {
1021        return Err(SwmaError::NotEnoughValidData {
1022            needed: max_p,
1023            valid: len - first,
1024        });
1025    }
1026
1027    let rows = combos.len();
1028    let cols = data.len();
1029    let cap = rows.checked_mul(max_p).ok_or_else(|| {
1030        let (s, e, t) = sweep.period;
1031        SwmaError::InvalidRange {
1032            start: s,
1033            end: e,
1034            step: t,
1035        }
1036    })?;
1037    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
1038    flat_w.resize(cap, 0.0);
1039
1040    for (row, combo) in combos.iter().enumerate() {
1041        let period = combo.period.unwrap();
1042        let w_start = row * max_p;
1043
1044        if period == 1 {
1045            flat_w[w_start] = 1.0;
1046        } else if period == 2 {
1047            flat_w[w_start] = 0.5;
1048            flat_w[w_start + 1] = 0.5;
1049        } else if period % 2 == 0 {
1050            let half = period / 2;
1051
1052            for i in 1..=half {
1053                flat_w[w_start + i - 1] = i as f64;
1054            }
1055
1056            for i in (1..=half).rev() {
1057                flat_w[w_start + period - i] = i as f64;
1058            }
1059
1060            let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
1061            for i in 0..period {
1062                flat_w[w_start + i] /= sum;
1063            }
1064        } else {
1065            let half_plus = (period + 1) / 2;
1066
1067            for i in 1..=half_plus {
1068                flat_w[w_start + i - 1] = i as f64;
1069            }
1070
1071            for i in (1..half_plus).rev() {
1072                flat_w[w_start + period - i] = i as f64;
1073            }
1074
1075            let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
1076            for i in 0..period {
1077                flat_w[w_start + i] /= sum;
1078            }
1079        }
1080    }
1081
1082    let warm: Vec<usize> = combos
1083        .iter()
1084        .map(|c| first + c.period.unwrap() - 1)
1085        .collect();
1086    let expected_len = rows.checked_mul(cols).ok_or_else(|| {
1087        let (s, e, t) = sweep.period;
1088        SwmaError::InvalidRange {
1089            start: s,
1090            end: e,
1091            step: t,
1092        }
1093    })?;
1094    if out.len() != expected_len {
1095        return Err(SwmaError::OutputLengthMismatch {
1096            expected: expected_len,
1097            got: out.len(),
1098        });
1099    }
1100    let out_uninit = unsafe {
1101        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1102    };
1103    init_matrix_prefixes(out_uninit, cols, &warm);
1104
1105    let actual_kern = match kern {
1106        Kernel::Auto => detect_best_batch_kernel(),
1107        k => k,
1108    };
1109    let simd = match actual_kern {
1110        Kernel::Avx512Batch => Kernel::Avx512,
1111        Kernel::Avx2Batch => Kernel::Avx2,
1112        Kernel::ScalarBatch => Kernel::Scalar,
1113        other => other,
1114    };
1115
1116    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1117        let period = combos[row].period.unwrap();
1118        let w_ptr = flat_w.as_ptr().add(row * max_p);
1119        let out_row =
1120            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1121        match simd {
1122            Kernel::Scalar => swma_row_scalar(data, first, period, w_ptr, out_row),
1123            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1124            Kernel::Avx2 => swma_row_avx2(data, first, period, w_ptr, out_row),
1125            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1126            Kernel::Avx512 => swma_row_avx512(data, first, period, w_ptr, out_row),
1127            _ => swma_row_scalar(data, first, period, w_ptr, out_row),
1128        }
1129    };
1130
1131    if parallel {
1132        #[cfg(not(target_arch = "wasm32"))]
1133        {
1134            out_uninit
1135                .par_chunks_mut(cols)
1136                .enumerate()
1137                .for_each(|(row, slice)| do_row(row, slice));
1138        }
1139        #[cfg(target_arch = "wasm32")]
1140        {
1141            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1142                do_row(row, slice);
1143            }
1144        }
1145    } else {
1146        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1147            do_row(row, slice);
1148        }
1149    }
1150
1151    Ok(combos)
1152}
1153
/// Scalar SWMA row kernel: fills `out[first + period - 1..data.len()]`.
///
/// The symmetric (triangular) weighted moving average is evaluated as two cascaded box sums
/// of lengths `a` and `b`, with `a + b - 1 == period`:
///
/// * odd `period`:  `a = b = (period + 1) / 2`
/// * even `period`: `a = period / 2`, `b = period / 2 + 1`
///
/// `s1_sum` is the running sum of the last `a` inputs. `s2_sum` is the running sum of the last
/// `b` values of `s1_sum`, which are kept in a ring buffer. Each output is `s2_sum / (a * b)`,
/// so every step costs O(1) whatever the period. `period == 1` copies the input, and
/// `period == 2` averages adjacent pairs.
///
/// Indices below `first + period - 1` are never written; filling the warm-up prefix is the
/// caller's job. For `period >= 3`, a NaN at or after `first` enters the running sums and
/// makes every later output NaN. A `period` of 0 writes nothing.
///
/// `_w_ptr` (the precomputed weight row) is unused because the cascaded sums do not need it.
/// It is kept so all row kernels share one signature.
///
/// # Safety
///
/// `out.len() >= data.len()` must hold, because `out` is written with unchecked indexing at
/// every data index in the output range. This is checked with `debug_assert!` in debug builds.
#[inline(always)]
unsafe fn swma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _w_ptr: *const f64,
    out: &mut [f64],
) {
    let len = data.len();
    // `period == 0` has no meaningful window. Without this guard, `first + a - 1` underflows,
    // and when `first > 0` the loop would read `data[i + 1]` past the end.
    if len == 0 || period == 0 {
        return;
    }
    debug_assert!(out.len() >= len, "swma_row_scalar: `out` is shorter than `data`");

    if period == 1 {
        for i in first..len {
            *out.get_unchecked_mut(i) = *data.get_unchecked(i);
        }
        return;
    }
    if period == 2 {
        for i in (first + 1)..len {
            *out.get_unchecked_mut(i) = (*data.get_unchecked(i - 1) + *data.get_unchecked(i)) * 0.5;
        }
        return;
    }

    // Box lengths of the two cascaded running sums (a + b - 1 == period).
    let (a, b) = if period % 2 == 1 {
        let m = (period + 1) / 2;
        (m, m)
    } else {
        let m = period / 2;
        (m, m + 1)
    };

    // a * b is the sum of the un-normalised triangular weights.
    let inv_ab = 1.0 / ((a as f64) * (b as f64));
    // First index at which `s1_sum` covers a full window of `a` inputs.
    let start_full_a = first + a - 1;
    // First index at which the ring holds `b` full `s1_sum` values: the first valid output.
    let start_full_ab = first + period - 1;

    // Last `b` values of `s1_sum`. Zero-initialised, so the first `b` subtractions are no-ops.
    // Only one slot is touched per step, so a plain `Vec` is enough here.
    let mut ring = vec![0.0_f64; b];
    let mut rb_idx = 0usize;

    let mut s1_sum = 0.0_f64;
    let mut s2_sum = 0.0_f64;

    for i in first..len {
        s1_sum += *data.get_unchecked(i);

        if i >= start_full_a {
            // Replace the oldest `s1_sum` in the ring and shift `s2_sum` by the difference.
            let old = *ring.get_unchecked(rb_idx);
            s2_sum += s1_sum - old;
            *ring.get_unchecked_mut(rb_idx) = s1_sum;
            rb_idx += 1;
            if rb_idx == b {
                rb_idx = 0;
            }

            if i >= start_full_ab {
                *out.get_unchecked_mut(i) = s2_sum * inv_ab;
            }

            // Drop the input that leaves the `a`-window before the next step.
            s1_sum -= *data.get_unchecked(i + 1 - a);
        }
    }
}
1219
/// AVX2 row kernel for SWMA.
///
/// There is no dedicated AVX2 implementation yet. This forwards every argument unchanged to
/// the scalar kernel (`swma_row_scalar`), so its results are identical to the scalar path.
///
/// # Safety
///
/// The running CPU must support `avx2` and `fma`, because the function is compiled with those
/// target features enabled. The caller must also uphold the scalar kernel's requirements:
/// `out` must be at least as long as `data`, and `period` must be at least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn swma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    // Delegate verbatim; the scalar kernel is the reference implementation.
    swma_row_scalar(data, first, period, w_ptr, out);
}
1231
/// AVX-512 entry point for one SWMA row.
///
/// Dispatches on `period`: `<= 32` goes to the short variant, anything larger to the long
/// variant. Both variants currently forward to the scalar kernel (`swma_row_scalar`).
///
/// # Safety
///
/// The running CPU must support `avx512f`, `avx512dq` and `fma`, because the function is
/// compiled with those target features enabled. The caller must also uphold the scalar
/// kernel's requirements: `out` must be at least as long as `data`, and `period` must be at
/// least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn swma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => swma_row_avx512_short(data, first, period, w_ptr, out),
        _ => swma_row_avx512_long(data, first, period, w_ptr, out),
    }
}
1247
/// AVX-512 variant used for short periods (`period <= 32`).
///
/// No specialised implementation exists yet. This forwards every argument unchanged to the
/// scalar kernel (`swma_row_scalar`).
///
/// # Safety
///
/// The running CPU must support `avx512f` and `fma`. The caller must also uphold the scalar
/// kernel's requirements: `out` must be at least as long as `data`, and `period` must be at
/// least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn swma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    // Delegate verbatim; the scalar kernel is the reference implementation.
    swma_row_scalar(data, first, period, w_ptr, out);
}
1259
/// AVX-512 variant used for long periods (`period > 32`).
///
/// No specialised implementation exists yet. This forwards every argument unchanged to the
/// scalar kernel (`swma_row_scalar`).
///
/// # Safety
///
/// The running CPU must support `avx512f`, `avx512dq` and `fma`. The caller must also uphold
/// the scalar kernel's requirements: `out` must be at least as long as `data`, and `period`
/// must be at least 1.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn swma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    // Delegate verbatim; the scalar kernel is the reference implementation.
    swma_row_scalar(data, first, period, w_ptr, out);
}
1271
/// Test suite for the SWMA indicator.
///
/// Each `check_*` function takes a test name and a `Kernel`. The `generate_all_swma_tests!`
/// and `gen_batch_tests!` macros stamp out one `#[test]` per kernel; SIMD variants are
/// compiled only with the `nightly-avx` feature on x86_64. Check results are unwrapped, so
/// an `Err` (e.g. a failed `?` while loading data or computing) fails the test instead of
/// being silently discarded.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Candle fixture used by every data-driven test.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference values for the last five outputs of the default-parameter SWMA on `close`.
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        59288.22222222222,
        59301.99999999999,
        59247.33333333333,
        59179.88888888889,
        59080.99999999999,
    ];

    fn check_swma_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let default_params = SwmaParams { period: None };
        let input = SwmaInput::from_candles(&candles, "close", default_params);
        let output = swma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_swma_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = SwmaInput::from_candles(&candles, "close", SwmaParams::default());
        let result = swma_with_kernel(&input, kernel)?;
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - EXPECTED_LAST_FIVE[i]).abs();
            assert!(
                diff < 1e-8,
                "[{}] SWMA {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                EXPECTED_LAST_FIVE[i]
            );
        }
        Ok(())
    }

    fn check_swma_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = SwmaInput::with_default_candles(&candles);
        match input.data {
            SwmaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected SwmaData::Candles"),
        }
        let output = swma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_swma_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = SwmaParams { period: Some(0) };
        let input = SwmaInput::from_slice(&input_data, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_swma_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = SwmaParams { period: Some(10) };
        let input = SwmaInput::from_slice(&data_small, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_swma_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = SwmaParams { period: Some(5) };
        let input = SwmaInput::from_slice(&single_point, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_swma_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let first_params = SwmaParams { period: Some(5) };
        let first_input = SwmaInput::from_candles(&candles, "close", first_params);
        let first_result = swma_with_kernel(&first_input, kernel)?;
        let second_params = SwmaParams { period: Some(3) };
        let second_input = SwmaInput::from_slice(&first_result.values, second_params);
        let second_result = swma_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    fn check_swma_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let params = SwmaParams { period: Some(5) };
        let input = SwmaInput::from_candles(&candles, "close", params);
        let res = swma_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    fn check_swma_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let period = 5;
        let input = SwmaInput::from_candles(
            &candles,
            "close",
            SwmaParams {
                period: Some(period),
            },
        );
        let batch_output = swma_with_kernel(&input, kernel)?.values;
        let mut stream = SwmaStream::try_new(SwmaParams {
            period: Some(period),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(swma_val) => stream_values.push(swma_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] SWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    // The SIMD `cfg` sits on each generated fn. An attribute placed in front of a
    // `$(...)*` repetition would only attach to the first expanded item.
    macro_rules! generate_all_swma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        };
    }

    /// Debug builds only: no output value may carry one of the allocator poison patterns.
    #[cfg(debug_assertions)]
    fn check_swma_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_periods = vec![1, 2, 3, 5, 7, 10, 15, 20, 30, 50, 100];

        for period in test_periods {
            if period > candles.close.len() {
                continue;
            }

            let params = SwmaParams {
                period: Some(period),
            };
            let input = SwmaInput::from_candles(&candles, "close", params);
            let output = swma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_swma_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    #[cfg(feature = "proptest")]
    fn check_swma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period.max(2)..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = SwmaParams {
                    period: Some(period),
                };
                let input = SwmaInput::from_slice(&data, params);

                let SwmaOutput { values: out } = swma_with_kernel(&input, kernel).unwrap();
                let SwmaOutput { values: ref_out } =
                    swma_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                if period > 1 {
                    for i in 0..(period - 1) {
                        prop_assert!(
                            out[i].is_nan(),
                            "Expected NaN during warmup at index {}, got {}",
                            i,
                            out[i]
                        );
                    }
                }

                let weights = build_symmetric_triangle_avec(period);

                let weight_sum: f64 = weights.iter().sum();
                prop_assert!(
                    (weight_sum - 1.0).abs() < 1e-10,
                    "Weights don't sum to 1.0, got {}",
                    weight_sum
                );

                for i in 0..period / 2 {
                    let left = weights[i];
                    let right = weights[period - 1 - i];
                    prop_assert!(
                        (left - right).abs() < 1e-10,
                        "Weights not symmetric at positions {} and {}: {} vs {}",
                        i,
                        period - 1 - i,
                        left,
                        right
                    );
                }

                // Evaluated once; the constant-data check below is per input, not per index.
                let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

                for i in (period - 1)..data.len() {
                    let window = &data[i + 1 - period..=i];
                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                        "idx {}: {} ∉ [{}, {}]",
                        i,
                        y,
                        lo,
                        hi
                    );

                    if period == 1 {
                        prop_assert!(
                            (y - data[i]).abs() <= f64::EPSILON,
                            "Period=1 should return input value at idx {}: {} vs {}",
                            i,
                            y,
                            data[i]
                        );
                    }

                    if period == 2 && i >= 1 {
                        let expected = (data[i - 1] + data[i]) / 2.0;
                        prop_assert!(
                            (y - expected).abs() < 1e-9,
                            "Period=2 should return average at idx {}: {} vs {}",
                            i,
                            y,
                            expected
                        );
                    }

                    if is_constant {
                        prop_assert!(
                            (y - data[0]).abs() < 1e-9,
                            "Constant data should produce constant output at idx {}: {} vs {}",
                            i,
                            y,
                            data[0]
                        );
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y_bits == r_bits,
                            "finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    let max_ulp = if matches!(kernel, Kernel::Avx512) {
                        20
                    } else {
                        10
                    };

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= max_ulp,
                        "mismatch idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_swma_tests!(
        check_swma_partial_params,
        check_swma_accuracy,
        check_swma_default_candles,
        check_swma_zero_period,
        check_swma_period_exceeds_length,
        check_swma_very_small_dataset,
        check_swma_reinput,
        check_swma_nan_handling,
        check_swma_streaming,
        check_swma_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_swma_tests!(check_swma_property);

    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = SwmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = SwmaParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let tail = &row[row.len() - 5..];
        for (i, &v) in tail.iter().enumerate() {
            assert!(
                (v - EXPECTED_LAST_FIVE[i]).abs() < 1e-8,
                "[{test}] default-row mismatch at idx {i}: {v} vs {}",
                EXPECTED_LAST_FIVE[i]
            );
        }
        Ok(())
    }

    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }

                #[test]
                fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }

    /// Debug builds only: no batch output cell may carry one of the allocator poison patterns.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let batch_configs = vec![
            (1, 10, 1),
            (3, 9, 3),
            (5, 25, 5),
            (10, 50, 10),
            (2, 2, 1),
            (1, 30, 2),
        ];

        for (start, end, step) in batch_configs {
            if end > c.close.len() {
                continue;
            }

            let output = SwmaBatchBuilder::new()
                .kernel(kernel)
                .period_range(start, end, step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let row = idx / output.cols;
                let col = idx % output.cols;
                let period = if row < output.combos.len() {
                    output.combos[row].period.unwrap_or(0)
                } else {
                    0
                };

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period {} in batch ({}, {}, {})",
                        test, val, bits, row, col, idx, period, start, end, step
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period {} in batch ({}, {}, {})",
                        test, val, bits, row, col, idx, period, start, end, step
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period {} in batch ({}, {}, {})",
                        test, val, bits, row, col, idx, period, start, end, step
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    #[test]
    fn test_swma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let input = SwmaInput::with_default_candles(&candles);
        let baseline = swma(&input)?.values;

        let mut out = vec![0.0f64; baseline.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            swma_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            swma_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(out.len(), baseline.len());

        for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
            let equal = (a.is_nan() && b.is_nan()) || (a == b);
            assert!(
                equal,
                "into parity mismatch at idx {}: got {}, expected {}",
                i, a, b
            );
        }

        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1901
/// Python binding `swma(data, period, kernel=None)`: SWMA of a 1-D float64 NumPy array.
///
/// `kernel` is an optional kernel name checked by `validate_kernel` (non-batch mode). The
/// computation runs with the GIL released. Any error from `swma_with_kernel` is raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "swma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn swma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    // Slice access is checked before the kernel name, matching the error precedence callers see.
    let input_slice = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = SwmaInput::from_slice(
        input_slice,
        SwmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| swma_with_kernel(&input, chosen_kernel));
    let values = match computed {
        Ok(output) => output.values,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1928
/// Python class `SwmaStream`: a thin wrapper around the incremental `SwmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SwmaStream")]
pub struct SwmaStreamPy {
    stream: SwmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SwmaStreamPy {
    /// Creates a stream for `period`; a construction error is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        SwmaStream::try_new(SwmaParams {
            period: Some(period),
        })
        .map(|stream| SwmaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds `value` to the underlying stream and returns its result (`None` when the stream
    /// reports no output for this step).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1952
/// Python binding `swma_batch(data, period_range, kernel=None)`: SWMA over a period sweep.
///
/// `period_range` is `(start, end, step)`. The result is a dict with `values` (a 2-D float64
/// array of shape `(rows, len(data))`, one row per period) and `periods` (the period used
/// for each row). Rows are computed in parallel with the GIL released, and errors are raised
/// as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "swma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn swma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let input = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = SwmaBatchRange {
        period: period_range,
    };

    // Shape of the output matrix: one row per expanded period, one column per sample.
    let n_rows = expand_grid(&sweep).len();
    let n_cols = input.len();
    let total = n_rows
        .checked_mul(n_cols)
        .ok_or_else(|| PyValueError::new_err("swma: rows*cols overflow during allocation"))?;

    // SAFETY: the NumPy buffer starts uninitialised. `swma_batch_inner_into` first checks that
    // its length equals rows*cols, then writes the warm-up prefixes and every row's values.
    // On error the array is dropped without ever being returned to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let row_params = py
        .allow_threads(|| swma_batch_inner_into(input, &sweep, kern, true, out_slice))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((n_rows, n_cols))?)?;

    let periods: Vec<_> = row_params.iter().map(|c| c.period.unwrap_or(5)).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2000
/// Python `swma_cuda_batch_dev(data, period_range, device_id=0)`.
///
/// Narrows `data` to f32, runs the SWMA period sweep on the CUDA device
/// `device_id`, and returns the device-resident result wrapped in a
/// `DeviceArrayF32Swma`, which also holds the CUDA context and device id.
///
/// Raises `ValueError` when CUDA is unavailable or when `CudaSwma` reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "swma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn swma_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SwmaPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    // The CUDA path consumes f32, so narrow the host series up front.
    let host_f32: Vec<f32> = data.as_slice()?.iter().map(|&x| x as f32).collect();
    let range = SwmaBatchRange {
        period: period_range,
    };

    // All device work runs with the GIL released.
    let (buffer, context, ordinal) = py.allow_threads(|| {
        let engine = CudaSwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        let buffer = engine
            .swma_batch_dev(&host_f32, &range)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((buffer, context, ordinal))
    })?;

    Ok(DeviceArrayF32SwmaPy {
        inner: Some(DeviceArrayF32Py {
            inner: buffer,
            _ctx: Some(context),
            device_id: Some(ordinal),
        }),
    })
}
2039
/// Python `swma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
///
/// Applies a single `period` to every series of a contiguous 2-D f32 array on
/// the GPU. The flattened buffer is passed to
/// `CudaSwma::swma_multi_series_one_param_time_major_dev` together with
/// `shape[1]` and then `shape[0]`. The layout is presumably time-major, i.e.
/// `(time, series)`; TODO confirm against the CUDA wrapper.
///
/// Raises `ValueError` when CUDA is unavailable, the array is not contiguous,
/// or `CudaSwma` reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "swma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn swma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SwmaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (n_rows, n_cols) = (dims[0], dims[1]);
    let params = SwmaParams {
        period: Some(period),
    };

    // All device work runs with the GIL released.
    let (buffer, context, ordinal) = py.allow_threads(|| {
        let engine = CudaSwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id();
        let buffer = engine
            .swma_multi_series_one_param_time_major_dev(flat, n_cols, n_rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((buffer, context, ordinal))
    })?;

    Ok(DeviceArrayF32SwmaPy {
        inner: Some(DeviceArrayF32Py {
            inner: buffer,
            _ctx: Some(context),
            device_id: Some(ordinal),
        }),
    })
}
2079
/// Python handle to a device-resident f32 SWMA result.
///
/// `inner` is `Some` until the buffer has been exported through `__dlpack__`.
/// After a successful export, every accessor raises `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Swma", unsendable)]
pub struct DeviceArrayF32SwmaPy {
    pub(crate) inner: Option<DeviceArrayF32Py>,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SwmaPy {
    /// CUDA Array Interface dict, delegated to the wrapped `DeviceArrayF32Py`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__cuda_array_interface__(py)
    }

    /// DLPack `(device_type, device_id)`, delegated to the wrapped array.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__dlpack_device__()
    }

    /// Exports the buffer as a DLPack capsule (one-shot).
    ///
    /// The wrapped array is retired only after the inner export succeeds.
    /// If the inner `__dlpack__` returns an error, this object stays usable;
    /// previously the buffer was taken out first and silently dropped on error.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let inner = self
            .inner
            .as_mut()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let capsule = inner.__dlpack__(py, stream, max_version, dl_device, copy)?;
        // Export succeeded: release our handle so later calls report "already exported".
        self.inner = None;
        Ok(capsule)
    }
}
2123
/// JS `swma_js(data, period)`: single-period SWMA using `Kernel::Auto`.
///
/// Returns a vector of the same length as `data`. Any error from
/// `swma_into_slice` is surfaced as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SwmaInput::from_slice(data, SwmaParams { period: Some(period) });
    let mut values = vec![0.0; data.len()];
    match swma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2139
/// JS `swma_batch_js(data, start, end, step)`: returns the flattened batch
/// output produced by `swma_batch_with_kernel` with `Kernel::Auto`.
///
/// Use `swma_batch_metadata_js` to recover the period of each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = SwmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    match swma_batch_with_kernel(data, &range, Kernel::Auto) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2155
/// JS `swma_batch_metadata_js(start, end, step)`: the period of each batch row,
/// as f64, in `expand_grid` order. A combo without a period reports 5.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = SwmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    Ok(expand_grid(&range)
        .into_iter()
        .map(|combo| combo.period.unwrap_or(5) as f64)
        .collect())
}
2176
/// Configuration object accepted by the JS `swma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SwmaBatchConfig {
    /// `(start, end, step)` period sweep.
    pub period_range: (usize, usize, usize),
}

/// Result object returned to JS by `swma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SwmaBatchJsOutput {
    /// Flattened `rows * cols` value matrix (presumably row-major — confirm).
    pub values: Vec<f64>,
    /// Parameter set for each row, as reported by the batch routine.
    pub combos: Vec<SwmaParams>,
    /// Number of parameter combinations (rows).
    pub rows: usize,
    /// Number of input samples (columns).
    pub cols: usize,
}

/// JS `swma_batch(data, config)`: deserialises a `SwmaBatchConfig`, runs the
/// sweep with `Kernel::Auto`, and returns a serialised `SwmaBatchJsOutput`.
///
/// Errors come back as string `JsValue`s, prefixed with "Invalid config:" or
/// "Serialization error:" for the respective serde failures.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = swma_batch)]
pub fn swma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SwmaBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<SwmaBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let range = SwmaBatchRange {
        period: period_range,
    };
    let batch = swma_batch_with_kernel(data, &range, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = SwmaBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2215
/// Allocates an uninitialised buffer with room for `len` f64 values in wasm
/// linear memory and hands ownership to the JS caller.
///
/// Release it with `swma_free(ptr, len)`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2224
/// Releases a buffer obtained from `swma_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must equal the value passed to `swma_alloc`,
/// because it is used as the capacity of the reconstructed allocation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `swma_alloc(len)`, i.e. a `Vec<f64>` created with
    // capacity `len`. The length is 0 because `Vec::from_raw_parts` requires
    // the first `length` elements to be initialised, and the caller may never
    // have written them. `f64` has no destructor, so nothing else is skipped.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2234
/// Computes SWMA over `len` values at `in_ptr` and writes `len` results to `out_ptr`.
///
/// Both pointers address f64 buffers in wasm linear memory (e.g. obtained
/// from `swma_alloc`). The buffers may be identical or overlap in any way.
/// In that case the result is computed into a scratch vector and copied out,
/// so no `&[f64]` and `&mut [f64]` ever alias.
///
/// # Errors
/// * `"Null pointer provided"` if either pointer is null;
/// * `"Invalid period"` if `period` is 0 or greater than `len`;
/// * the message of any error returned by `swma_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller guarantees both pointers address at least `len` f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let params = SwmaParams {
            period: Some(period),
        };
        let input = SwmaInput::from_slice(data, params);

        // Check for any overlap of the two byte ranges, not only equal start
        // pointers: a partially overlapping output slice would alias `data`.
        let span = len.saturating_mul(std::mem::size_of::<f64>());
        let in_start = in_ptr as usize;
        let out_start = out_ptr as usize;
        let overlaps = in_start < out_start.saturating_add(span)
            && out_start < in_start.saturating_add(span);

        if overlaps {
            let mut temp = vec![0.0; len];
            swma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            swma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
2275
/// Batch SWMA over raw wasm pointers: one output row of `len` values per
/// period in `(period_start, period_end, period_step)`.
///
/// `out_ptr` must have room for `rows * len` f64s, where `rows` is the number
/// of expanded periods (as listed by `swma_batch_metadata_js`). If the output
/// range overlaps the input, for example during in-place use, the batch is
/// computed into a scratch buffer first. This keeps the input from being
/// overwritten while it is still being read.
///
/// Returns the number of rows written.
///
/// # Errors
/// A null pointer, an empty period expansion, a rows*cols overflow, or any
/// error reported by `swma_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to swma_batch_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` f64s and
    // `out_ptr` addresses `rows * len` f64s.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = SwmaBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid(&sweep);
        if combos.is_empty() {
            return Err(JsValue::from_str(
                "swma: invalid period range (empty expansion)",
            ));
        }
        let rows = combos.len();
        let cols = len;
        let rows_cols = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("swma: rows*cols overflow"))?;

        // Byte-range overlap test between the input and the (larger) output.
        let elem = std::mem::size_of::<f64>();
        let in_start = in_ptr as usize;
        let in_end = in_start.saturating_add(cols.saturating_mul(elem));
        let out_start = out_ptr as usize;
        let out_end = out_start.saturating_add(rows_cols.saturating_mul(elem));
        let overlaps = in_start < out_end && out_start < in_end;

        if overlaps {
            let mut temp = vec![0.0; rows_cols];
            swma_batch_inner_into(data, &sweep, Kernel::Auto, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, rows_cols);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, rows_cols);
            swma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(rows)
    }
}