Skip to main content

vector_ta/indicators/moving_averages/
sinwma.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::f64::consts::PI;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::cuda_available;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::cuda::moving_averages::CudaSinwma;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26#[cfg(feature = "python")]
27use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
28#[cfg(all(feature = "python", feature = "cuda"))]
29use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};
30#[cfg(feature = "python")]
31use pyo3::exceptions::PyValueError;
32#[cfg(feature = "python")]
33use pyo3::prelude::*;
34#[cfg(feature = "python")]
35use pyo3::types::{PyDict, PyList};
36
37#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
38use serde::{Deserialize, Serialize};
39#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
40use wasm_bindgen::prelude::*;
41
42impl<'a> AsRef<[f64]> for SinWmaInput<'a> {
43    #[inline(always)]
44    fn as_ref(&self) -> &[f64] {
45        match &self.data {
46            SinWmaData::Slice(slice) => slice,
47            SinWmaData::Candles { candles, source } => source_type(candles, source),
48        }
49    }
50}
51
52#[derive(Debug, Clone)]
53pub enum SinWmaData<'a> {
54    Candles {
55        candles: &'a Candles,
56        source: &'a str,
57    },
58    Slice(&'a [f64]),
59}
60
/// Result of a single-series SINWMA computation.
#[derive(Clone, Debug)]
pub struct SinWmaOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
65
/// Tunable parameters of the indicator.
#[derive(Clone, Debug)]
pub struct SinWmaParams {
    /// Window length; consumers in this file fall back to 14 when `None`.
    pub period: Option<usize>,
}

impl Default for SinWmaParams {
    /// Default parameters: a 14-element window.
    fn default() -> Self {
        let period = Some(14);
        Self { period }
    }
}
76
77#[derive(Debug, Clone)]
78pub struct SinWmaInput<'a> {
79    pub data: SinWmaData<'a>,
80    pub params: SinWmaParams,
81}
82
83impl<'a> SinWmaInput<'a> {
84    #[inline]
85    pub fn from_candles(c: &'a Candles, s: &'a str, p: SinWmaParams) -> Self {
86        Self {
87            data: SinWmaData::Candles {
88                candles: c,
89                source: s,
90            },
91            params: p,
92        }
93    }
94    #[inline]
95    pub fn from_slice(sl: &'a [f64], p: SinWmaParams) -> Self {
96        Self {
97            data: SinWmaData::Slice(sl),
98            params: p,
99        }
100    }
101    #[inline]
102    pub fn with_default_candles(c: &'a Candles) -> Self {
103        Self::from_candles(c, "close", SinWmaParams::default())
104    }
105    #[inline]
106    pub fn get_period(&self) -> usize {
107        self.params.period.unwrap_or(14)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct SinWmaBuilder {
113    period: Option<usize>,
114    kernel: Kernel,
115}
116
117impl Default for SinWmaBuilder {
118    fn default() -> Self {
119        Self {
120            period: None,
121            kernel: Kernel::Auto,
122        }
123    }
124}
125
126impl SinWmaBuilder {
127    #[inline(always)]
128    pub fn new() -> Self {
129        Self::default()
130    }
131    #[inline(always)]
132    pub fn period(mut self, n: usize) -> Self {
133        self.period = Some(n);
134        self
135    }
136    #[inline(always)]
137    pub fn kernel(mut self, k: Kernel) -> Self {
138        self.kernel = k;
139        self
140    }
141
142    #[inline(always)]
143    pub fn apply(self, c: &Candles) -> Result<SinWmaOutput, SinWmaError> {
144        let p = SinWmaParams {
145            period: self.period,
146        };
147        let i = SinWmaInput::from_candles(c, "close", p);
148        sinwma_with_kernel(&i, self.kernel)
149    }
150
151    #[inline(always)]
152    pub fn apply_slice(self, d: &[f64]) -> Result<SinWmaOutput, SinWmaError> {
153        let p = SinWmaParams {
154            period: self.period,
155        };
156        let i = SinWmaInput::from_slice(d, p);
157        sinwma_with_kernel(&i, self.kernel)
158    }
159
160    #[inline(always)]
161    pub fn into_stream(self) -> Result<SinWmaStream, SinWmaError> {
162        let p = SinWmaParams {
163            period: self.period,
164        };
165        SinWmaStream::try_new(p)
166    }
167}
168
169#[derive(Debug, Error)]
170pub enum SinWmaError {
171    #[error("sinwma: No data provided (empty slice).")]
172    EmptyInputData,
173    #[error("sinwma: All values are NaN.")]
174    AllValuesNaN,
175    #[error("sinwma: Invalid period: period = {period}, data length = {data_len}")]
176    InvalidPeriod { period: usize, data_len: usize },
177    #[error("sinwma: Not enough valid data: needed = {needed}, valid = {valid}")]
178    NotEnoughValidData { needed: usize, valid: usize },
179    #[error("sinwma: Sum of sines is zero or too close to zero. sum_sines = {sum_sines}")]
180    ZeroSumSines { sum_sines: f64 },
181    #[error("sinwma: Output slice length mismatch: expected = {expected}, got = {got}")]
182    OutputLengthMismatch { expected: usize, got: usize },
183    #[error("sinwma: Invalid range: start={start}, end={end}, step={step}")]
184    InvalidRange {
185        start: usize,
186        end: usize,
187        step: usize,
188    },
189    #[error("sinwma: Invalid kernel for batch path: {0:?}")]
190    InvalidKernelForBatch(Kernel),
191}
192
193#[inline]
194pub fn sinwma(input: &SinWmaInput) -> Result<SinWmaOutput, SinWmaError> {
195    sinwma_with_kernel(input, Kernel::Auto)
196}
197
198pub fn sinwma_with_kernel(
199    input: &SinWmaInput,
200    kernel: Kernel,
201) -> Result<SinWmaOutput, SinWmaError> {
202    let (data, weights, period, first, chosen) = sinwma_prepare(input, kernel)?;
203
204    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
205
206    sinwma_compute_into(data, &weights, period, first, chosen, &mut out);
207
208    Ok(SinWmaOutput { values: out })
209}
210
211#[inline]
212pub fn sinwma_into_slice(
213    dst: &mut [f64],
214    input: &SinWmaInput,
215    kern: Kernel,
216) -> Result<(), SinWmaError> {
217    let (data, weights, period, first, chosen) = sinwma_prepare(input, kern)?;
218
219    if dst.len() != data.len() {
220        return Err(SinWmaError::OutputLengthMismatch {
221            expected: data.len(),
222            got: dst.len(),
223        });
224    }
225
226    sinwma_compute_into(data, &weights, period, first, chosen, dst);
227
228    let warmup_end = first + period - 1;
229    for v in &mut dst[..warmup_end] {
230        *v = f64::NAN;
231    }
232
233    Ok(())
234}
235
236#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
237#[inline]
238pub fn sinwma_into(input: &SinWmaInput, out: &mut [f64]) -> Result<(), SinWmaError> {
239    sinwma_into_slice(out, input, Kernel::Auto)
240}
241
242#[inline(always)]
243fn sinwma_prepare<'a>(
244    input: &'a SinWmaInput,
245    kernel: Kernel,
246) -> Result<(&'a [f64], AVec<f64>, usize, usize, Kernel), SinWmaError> {
247    let data: &[f64] = input.as_ref();
248    let len = data.len();
249
250    if len == 0 {
251        return Err(SinWmaError::EmptyInputData);
252    }
253
254    let first = data
255        .iter()
256        .position(|x| !x.is_nan())
257        .ok_or(SinWmaError::AllValuesNaN)?;
258
259    let period = input.get_period();
260
261    if period == 0 || period > len {
262        return Err(SinWmaError::InvalidPeriod {
263            period,
264            data_len: len,
265        });
266    }
267    if (len - first) < period {
268        return Err(SinWmaError::NotEnoughValidData {
269            needed: period,
270            valid: len - first,
271        });
272    }
273
274    let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, period);
275    weights.resize(period, 0.0);
276    let mut sum_sines = 0.0;
277    for k in 0..period {
278        let angle = (k as f64 + 1.0) * PI / (period as f64 + 1.0);
279        let val = angle.sin();
280        weights[k] = val;
281        sum_sines += val;
282    }
283
284    if sum_sines.abs() < f64::EPSILON {
285        return Err(SinWmaError::ZeroSumSines { sum_sines });
286    }
287    let inv_sum = 1.0 / sum_sines;
288    for w in &mut weights[..] {
289        *w *= inv_sum;
290    }
291
292    let chosen = match kernel {
293        Kernel::Auto => detect_best_kernel(),
294        k => k,
295    };
296
297    Ok((data, weights, period, first, chosen))
298}
299
300#[inline(always)]
301fn sinwma_compute_into(
302    data: &[f64],
303    weights: &[f64],
304    period: usize,
305    first: usize,
306    kernel: Kernel,
307    out: &mut [f64],
308) {
309    unsafe {
310        match kernel {
311            Kernel::Scalar | Kernel::ScalarBatch => {
312                sinwma_scalar(data, weights, period, first, out)
313            }
314            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
315            Kernel::Avx2 | Kernel::Avx2Batch => sinwma_avx2(data, weights, period, first, out),
316            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
317            Kernel::Avx512 | Kernel::Avx512Batch => {
318                sinwma_avx512(data, weights, period, first, out)
319            }
320            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
321            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
322                sinwma_scalar(data, weights, period, first, out)
323            }
324            _ => unreachable!(),
325        }
326    }
327}
328
/// Horizontally adds the eight `f64` lanes of a 512-bit vector.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn hsum_pd_zmm(v: __m512d) -> f64 {
    let total = _mm512_reduce_add_pd(v);
    total
}
335
/// AVX-512 entry point: picks the short-period or long-period kernel.
///
/// Periods up to 32 use `sinwma_avx512_short`; longer periods use
/// `sinwma_avx512_long`, which caches the weight vectors.
///
/// Because this is a safe function, the preconditions the unsafe kernels only
/// `debug_assert!` are enforced here in every build profile, and the kernels
/// are only entered when the running CPU reports the required features
/// (otherwise the scalar kernel is used).
///
/// # Panics
/// Panics if `period == 0`, `weights.len() != period`, or
/// `out.len() != data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sinwma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert!(period >= 1, "`period` must be at least 1");
    assert_eq!(weights.len(), period, "weights.len() must equal `period`");
    assert_eq!(
        out.len(),
        data.len(),
        "`out` must have the same length as `data`"
    );

    // No complete window fits after `first_valid`: nothing to write. Returning
    // here also keeps the long kernel from forming out-of-bounds pointers.
    if data.len().saturating_sub(first_valid) < period {
        return;
    }

    let cpu_ok = is_x86_feature_detected!("avx512f")
        && is_x86_feature_detected!("avx512dq")
        && is_x86_feature_detected!("fma");
    if !cpu_ok {
        sinwma_scalar(data, weights, period, first_valid, out);
        return;
    }

    // SAFETY: the required CPU features were verified above; the assertions
    // guarantee `weights` holds `period` elements and `out` covers every index
    // of `data`; the early return guarantees at least one full window exists,
    // so every pointer offset used by the kernels stays within the slices.
    unsafe {
        if period <= 32 {
            sinwma_avx512_short(data, weights, period, first_valid, out)
        } else {
            sinwma_avx512_long(data, weights, period, first_valid, out)
        }
    }
}
351
/// Portable kernel: writes the weighted sum of every full window into `out`.
///
/// For each index `i` from `first_val + period - 1` to the end of `data`,
/// `out[i]` receives the dot product of `weights` with the `period` elements
/// of `data` ending at `i`. Entries before that index are left untouched.
/// Products are accumulated four at a time, followed by the leftover
/// elements one by one.
///
/// # Panics
/// Panics if `weights.len() != period` or `out` is shorter than `data`.
#[inline]
pub fn sinwma_scalar(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    assert_eq!(weights.len(), period, "weights.len() must equal `period`");
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    // Largest multiple of four not exceeding `period`.
    let quads = period & !3;
    let (w_body, w_tail) = weights.split_at(quads);

    for end in (first_val + period - 1)..data.len() {
        let begin = end + 1 - period;
        let (d_body, d_tail) = data[begin..begin + period].split_at(quads);

        let mut acc = 0.0;
        let mut k = 0;
        while k < quads {
            acc += d_body[k] * w_body[k]
                + d_body[k + 1] * w_body[k + 1]
                + d_body[k + 2] * w_body[k + 2]
                + d_body[k + 3] * w_body[k + 3];
            k += 4;
        }

        for (x, w) in d_tail.iter().zip(w_tail) {
            acc += x * w;
        }

        out[end] = acc;
    }
}
387
/// AVX2/FMA kernel: one fused multiply-add per four window elements, a
/// horizontal reduction, then scalar fused multiply-adds for the leftovers.
///
/// # Safety
/// The CPU must support AVX2 and FMA. `weights` must hold `period` elements
/// and `out` must be as long as `data`; these are only checked in debug
/// builds, and all element access is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn sinwma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    // Number of elements covered by whole 4-lane vectors.
    let vec_len = period & !3usize;
    let w_base = weights.as_ptr();
    let d_base = data.as_ptr();

    for end in (first_valid + period - 1)..data.len() {
        let begin = end + 1 - period;

        let mut lanes = _mm256_setzero_pd();
        for k in (0..vec_len).step_by(4) {
            let wv = _mm256_loadu_pd(w_base.add(k));
            let dv = _mm256_loadu_pd(d_base.add(begin + k));
            lanes = _mm256_fmadd_pd(dv, wv, lanes);
        }

        // Fold 4 lanes -> 2 lanes -> scalar.
        let lo = _mm256_castpd256_pd128(lanes);
        let hi = _mm256_extractf128_pd(lanes, 1);
        let pair = _mm_add_pd(lo, hi);
        let mut total = _mm_cvtsd_f64(_mm_hadd_pd(pair, pair));

        for k in vec_len..period {
            total = (*data.get_unchecked(begin + k)).mul_add(*weights.get_unchecked(k), total);
        }

        *out.get_unchecked_mut(end) = total;
    }
}
426
/// AVX-512 kernel for short periods: weights are reloaded per window and the
/// partial trailing vector is handled with a masked load.
///
/// Periods below eight take a dedicated path with a single masked
/// multiply per window.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. `weights` must hold `period`
/// elements and `out` must be as long as `data`; these are only checked in
/// debug builds, and all memory access is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn sinwma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    const LANES: usize = 8;
    let full = period / LANES;
    let rem = period % LANES;
    // Low `rem` bits set: selects the valid lanes of the trailing vector.
    let rem_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    let w_base = weights.as_ptr();
    let d_base = data.as_ptr();

    if full == 0 {
        // The whole window fits in one masked vector.
        let wv = _mm512_maskz_loadu_pd(rem_mask, w_base);
        for end in (first_valid + period - 1)..data.len() {
            let begin = end + 1 - period;
            let dv = _mm512_maskz_loadu_pd(rem_mask, d_base.add(begin));
            *out.get_unchecked_mut(end) = hsum_pd_zmm(_mm512_mul_pd(dv, wv));
        }
        return;
    }

    for end in (first_valid + period - 1)..data.len() {
        let begin = end + 1 - period;
        let mut lanes = _mm512_setzero_pd();

        let mut blk = 0usize;
        while blk < full {
            let offset = blk * LANES;
            let wv = _mm512_loadu_pd(w_base.add(offset));
            let dv = _mm512_loadu_pd(d_base.add(begin + offset));
            lanes = _mm512_fmadd_pd(dv, wv, lanes);
            blk += 1;
        }

        if rem != 0 {
            let offset = full * LANES;
            let wt = _mm512_maskz_loadu_pd(rem_mask, w_base.add(offset));
            let dt = _mm512_maskz_loadu_pd(rem_mask, d_base.add(begin + offset));
            lanes = _mm512_fmadd_pd(dt, wt, lanes);
        }

        *out.get_unchecked_mut(end) = hsum_pd_zmm(lanes);
    }
}
476
/// AVX-512 kernel for long periods: the weight vectors are loaded once, and
/// each window is accumulated in four independent accumulators over groups
/// of four 8-lane blocks.
///
/// Blocks left over after the groups of four, and the masked partial block
/// (when `period` is not a multiple of eight), are folded into the first
/// accumulator.
///
/// # Safety
/// The CPU must support AVX-512F, AVX-512DQ and FMA. `weights` must hold
/// `period` elements, `out` must be as long as `data`, and at least one full
/// window must fit after `first_valid`; none of this is checked in release
/// builds and all memory access goes through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let full_blocks = period / LANES;
    let rem = period % LANES;
    // Number of blocks handled by the 4-way unrolled loop.
    let unrolled = full_blocks & !3;
    let rem_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    // Cache every full weight vector once; windows only reload data.
    let mut w_blocks: Vec<__m512d> = Vec::with_capacity(full_blocks);
    for blk in 0..full_blocks {
        w_blocks.push(_mm512_loadu_pd(weights.as_ptr().add(blk * LANES)));
    }
    let w_rem = if rem != 0 {
        Some(_mm512_maskz_loadu_pd(
            rem_mask,
            weights.as_ptr().add(full_blocks * LANES),
        ))
    } else {
        None
    };

    let mut src = data.as_ptr().add(first_valid);
    let end = data.as_ptr().add(data.len());
    let mut dst = out.as_mut_ptr().add(first_valid + period - 1);

    while src.add(period) <= end {
        let mut a0 = _mm512_setzero_pd();
        let mut a1 = _mm512_setzero_pd();
        let mut a2 = _mm512_setzero_pd();
        let mut a3 = _mm512_setzero_pd();

        let mut blk = 0usize;
        while blk < unrolled {
            let d0 = _mm512_loadu_pd(src.add(blk * LANES));
            let d1 = _mm512_loadu_pd(src.add((blk + 1) * LANES));
            let d2 = _mm512_loadu_pd(src.add((blk + 2) * LANES));
            let d3 = _mm512_loadu_pd(src.add((blk + 3) * LANES));

            a0 = _mm512_fmadd_pd(d0, *w_blocks.get_unchecked(blk), a0);
            a1 = _mm512_fmadd_pd(d1, *w_blocks.get_unchecked(blk + 1), a1);
            a2 = _mm512_fmadd_pd(d2, *w_blocks.get_unchecked(blk + 2), a2);
            a3 = _mm512_fmadd_pd(d3, *w_blocks.get_unchecked(blk + 3), a3);
            blk += 4;
        }

        // Remaining full blocks (at most three) go into the first accumulator.
        while blk < full_blocks {
            let d = _mm512_loadu_pd(src.add(blk * LANES));
            a0 = _mm512_fmadd_pd(d, *w_blocks.get_unchecked(blk), a0);
            blk += 1;
        }

        // Masked partial block, present only when `period % 8 != 0`.
        if let Some(wt) = w_rem {
            let dt = _mm512_maskz_loadu_pd(rem_mask, src.add(full_blocks * LANES));
            a0 = _mm512_fmadd_pd(dt, wt, a0);
        }

        let lo = _mm512_add_pd(a0, a1);
        let hi = _mm512_add_pd(a2, a3);
        *dst = hsum_pd_zmm(_mm512_add_pd(lo, hi));

        src = src.add(1);
        dst = dst.add(1);
    }
}
584
585#[derive(Debug, Clone)]
586pub struct SinWmaStream {
587    period: usize,
588
589    r_re: f64,
590    r_im: f64,
591    rp_re: f64,
592    rp_im: f64,
593
594    sinp: f64,
595    cosp: f64,
596
597    inv_sum: f64,
598
599    z_re: f64,
600    z_im: f64,
601
602    buffer: Vec<f64>,
603    head: usize,
604    filled: bool,
605
606    nan_count: usize,
607    z_valid: bool,
608}
609
610impl SinWmaStream {
611    pub fn try_new(params: SinWmaParams) -> Result<Self, SinWmaError> {
612        let period = params.period.unwrap_or(14);
613        if period == 0 {
614            return Err(SinWmaError::InvalidPeriod {
615                period,
616                data_len: 0,
617            });
618        }
619
620        let alpha = PI / (period as f64 + 1.0);
621
622        let (sina, cosa) = alpha.sin_cos();
623        let (sinp, cosp) = (alpha * period as f64).sin_cos();
624
625        let denom = (0.5 * alpha).sin();
626        if denom.abs() < f64::EPSILON {
627            return Err(SinWmaError::ZeroSumSines { sum_sines: 0.0 });
628        }
629        let sum_sines = (0.5 * alpha * period as f64).sin() / denom;
630        if sum_sines.abs() < f64::EPSILON {
631            return Err(SinWmaError::ZeroSumSines { sum_sines });
632        }
633        let inv_sum = 1.0 / sum_sines;
634
635        Ok(Self {
636            period,
637            r_re: cosa,
638            r_im: sina,
639            rp_re: cosp,
640            rp_im: sinp,
641            sinp,
642            cosp,
643            inv_sum,
644            z_re: 0.0,
645            z_im: 0.0,
646            buffer: vec![f64::NAN; period],
647            head: 0,
648            filled: false,
649            nan_count: period,
650            z_valid: false,
651        })
652    }
653
654    #[inline(always)]
655    pub fn update(&mut self, value: f64) -> Option<f64> {
656        if !self.filled {
657            let idx = self.head;
658            let old = self.buffer[idx];
659            if old.is_nan() {
660                self.nan_count = self.nan_count.saturating_sub(1);
661            }
662            if value.is_nan() {
663                self.nan_count += 1;
664            }
665            self.buffer[idx] = value;
666            self.head = (idx + 1) % self.period;
667
668            if self.head != 0 {
669                return None;
670            }
671
672            self.filled = true;
673
674            if self.nan_count != 0 {
675                self.z_valid = false;
676                return Some(f64::NAN);
677            }
678
679            self.rebuild_z();
680            return Some(self.output_from_z());
681        }
682
683        let idx_old = self.head;
684        let x_old = self.buffer[idx_old];
685
686        self.buffer[idx_old] = value;
687        self.head = (idx_old + 1) % self.period;
688
689        if x_old.is_nan() {
690            self.nan_count = self.nan_count.saturating_sub(1);
691        }
692        if value.is_nan() {
693            self.nan_count += 1;
694        }
695
696        if self.nan_count != 0 {
697            self.z_valid = false;
698            return Some(f64::NAN);
699        }
700
701        if !self.z_valid {
702            self.rebuild_z();
703            return Some(self.output_from_z());
704        }
705
706        let rZ_re = self.r_re.mul_add(self.z_re, -self.r_im * self.z_im);
707        let rZ_im = self.r_re.mul_add(self.z_im, self.r_im * self.z_re);
708
709        self.z_re = rZ_re + value - self.rp_re * x_old;
710        self.z_im = rZ_im - self.rp_im * x_old;
711
712        Some(self.output_from_z())
713    }
714
715    #[inline(always)]
716    fn output_from_z(&self) -> f64 {
717        let y_unscaled = self.sinp.mul_add(self.z_re, -self.cosp * self.z_im);
718        y_unscaled * self.inv_sum
719    }
720
721    #[inline(always)]
722    fn rebuild_z(&mut self) {
723        debug_assert!(
724            self.nan_count == 0,
725            "rebuild_z called on NaN-contaminated window"
726        );
727        let newest = (self.head + self.period - 1) % self.period;
728
729        let mut rj_re = 1.0f64;
730        let mut rj_im = 0.0f64;
731
732        let mut zr = 0.0f64;
733        let mut zi = 0.0f64;
734
735        for j in 0..self.period {
736            let idx = (newest + self.period - j) % self.period;
737            let x = self.buffer[idx];
738            zr = rj_re.mul_add(x, zr);
739            zi = rj_im.mul_add(x, zi);
740
741            let nr_re = rj_re.mul_add(self.r_re, -rj_im * self.r_im);
742            let nr_im = rj_re.mul_add(self.r_im, rj_im * self.r_re);
743            rj_re = nr_re;
744            rj_im = nr_im;
745        }
746
747        self.z_re = zr;
748        self.z_im = zi;
749        self.z_valid = true;
750    }
751}
752
/// Parameter sweep for the batch API.
#[derive(Debug, Clone)]
pub struct SinWmaBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for SinWmaBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        Self { period }
    }
}
765
766#[derive(Clone, Debug, Default)]
767pub struct SinWmaBatchBuilder {
768    range: SinWmaBatchRange,
769    kernel: Kernel,
770}
771
772impl SinWmaBatchBuilder {
773    pub fn new() -> Self {
774        Self::default()
775    }
776    pub fn kernel(mut self, k: Kernel) -> Self {
777        self.kernel = k;
778        self
779    }
780
781    #[inline]
782    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
783        self.range.period = (start, end, step);
784        self
785    }
786    #[inline]
787    pub fn period_static(mut self, p: usize) -> Self {
788        self.range.period = (p, p, 0);
789        self
790    }
791
792    pub fn apply_slice(self, data: &[f64]) -> Result<SinWmaBatchOutput, SinWmaError> {
793        sinwma_batch_with_kernel(data, &self.range, self.kernel)
794    }
795
796    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SinWmaBatchOutput, SinWmaError> {
797        SinWmaBatchBuilder::new().kernel(k).apply_slice(data)
798    }
799
800    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SinWmaBatchOutput, SinWmaError> {
801        let slice = source_type(c, src);
802        self.apply_slice(slice)
803    }
804
805    pub fn with_default_candles(c: &Candles) -> Result<SinWmaBatchOutput, SinWmaError> {
806        SinWmaBatchBuilder::new()
807            .kernel(Kernel::Auto)
808            .apply_candles(c, "close")
809    }
810}
811
812pub fn sinwma_batch_with_kernel(
813    data: &[f64],
814    sweep: &SinWmaBatchRange,
815    k: Kernel,
816) -> Result<SinWmaBatchOutput, SinWmaError> {
817    let kernel = match k {
818        Kernel::Auto => detect_best_batch_kernel(),
819        other if other.is_batch() => other,
820        _ => return Err(SinWmaError::InvalidKernelForBatch(k)),
821    };
822
823    let simd = match kernel {
824        Kernel::Avx512Batch => Kernel::Avx512,
825        Kernel::Avx2Batch => Kernel::Avx2,
826        Kernel::ScalarBatch => Kernel::Scalar,
827        _ => unreachable!(),
828    };
829    sinwma_batch_par_slice(data, sweep, simd)
830}
831
832#[derive(Clone, Debug)]
833pub struct SinWmaBatchOutput {
834    pub values: Vec<f64>,
835    pub combos: Vec<SinWmaParams>,
836    pub rows: usize,
837    pub cols: usize,
838}
839impl SinWmaBatchOutput {
840    pub fn row_for_params(&self, p: &SinWmaParams) -> Option<usize> {
841        self.combos
842            .iter()
843            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
844    }
845
846    pub fn values_for(&self, p: &SinWmaParams) -> Option<&[f64]> {
847        self.row_for_params(p).map(|row| {
848            let start = row * self.cols;
849            &self.values[start..start + self.cols]
850        })
851    }
852}
853
854#[inline(always)]
855fn expand_grid(r: &SinWmaBatchRange) -> Result<Vec<SinWmaParams>, SinWmaError> {
856    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SinWmaError> {
857        if step == 0 || start == end {
858            return Ok(vec![start]);
859        }
860        if start < end {
861            let v: Vec<usize> = (start..=end).step_by(step).collect();
862            if v.is_empty() {
863                return Err(SinWmaError::InvalidRange { start, end, step });
864            }
865            return Ok(v);
866        }
867
868        let mut v = Vec::new();
869        let mut cur = start;
870        while cur >= end {
871            v.push(cur);
872            if cur == end {
873                break;
874            }
875
876            let next = cur.saturating_sub(step);
877            if next == cur {
878                break;
879            }
880            cur = next;
881        }
882        if v.is_empty() {
883            return Err(SinWmaError::InvalidRange { start, end, step });
884        }
885        Ok(v)
886    }
887
888    let periods = axis_usize(r.period)?;
889
890    let mut out = Vec::with_capacity(periods.len());
891    for &p in &periods {
892        out.push(SinWmaParams { period: Some(p) });
893    }
894    Ok(out)
895}
896
/// Rounds `x` up to the next multiple of eight (values already on a
/// multiple are returned unchanged).
#[inline]
fn round_up8(x: usize) -> usize {
    let bumped = x + 7;
    bumped - (bumped % 8)
}
901
902#[inline(always)]
903pub unsafe fn sinwma_row_dispatch(
904    data: &[f64],
905    first: usize,
906    period: usize,
907    w_ptr: *const f64,
908    out: &mut [f64],
909    kern: Kernel,
910) {
911    match kern {
912        Kernel::Scalar | Kernel::ScalarBatch => sinwma_row_scalar(data, first, period, w_ptr, out),
913        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
914        Kernel::Avx2 | Kernel::Avx2Batch => sinwma_row_avx2(data, first, period, w_ptr, out),
915        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
916        Kernel::Avx512 | Kernel::Avx512Batch => sinwma_row_avx512(data, first, period, w_ptr, out),
917        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
918        Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
919            sinwma_row_scalar(data, first, period, w_ptr, out)
920        }
921        _ => unreachable!(),
922    }
923}
924
925#[inline(always)]
926pub fn sinwma_batch_slice(
927    data: &[f64],
928    sweep: &SinWmaBatchRange,
929    kern: Kernel,
930) -> Result<SinWmaBatchOutput, SinWmaError> {
931    sinwma_batch_inner(data, sweep, kern, false)
932}
933
934#[inline(always)]
935pub fn sinwma_batch_par_slice(
936    data: &[f64],
937    sweep: &SinWmaBatchRange,
938    kern: Kernel,
939) -> Result<SinWmaBatchOutput, SinWmaError> {
940    sinwma_batch_inner(data, sweep, kern, true)
941}
942
943#[inline(always)]
944fn sinwma_batch_inner(
945    data: &[f64],
946    sweep: &SinWmaBatchRange,
947    kern: Kernel,
948    parallel: bool,
949) -> Result<SinWmaBatchOutput, SinWmaError> {
950    let combos = expand_grid(sweep)?;
951    if combos.is_empty() {
952        return Err(SinWmaError::InvalidRange {
953            start: sweep.period.0,
954            end: sweep.period.1,
955            step: sweep.period.2,
956        });
957    }
958
959    if data.is_empty() {
960        return Err(SinWmaError::EmptyInputData);
961    }
962
963    let first = data
964        .iter()
965        .position(|x| !x.is_nan())
966        .ok_or(SinWmaError::AllValuesNaN)?;
967    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
968    if data.len() - first < max_p {
969        return Err(SinWmaError::NotEnoughValidData {
970            needed: max_p,
971            valid: data.len() - first,
972        });
973    }
974
975    let rows = combos.len();
976    let cols = data.len();
977
978    let _total = rows.checked_mul(cols).ok_or(SinWmaError::InvalidRange {
979        start: sweep.period.0,
980        end: sweep.period.1,
981        step: sweep.period.2,
982    })?;
983    let warm: Vec<usize> = combos
984        .iter()
985        .map(|c| first + c.period.unwrap() - 1)
986        .collect();
987
988    let mut raw = make_uninit_matrix(rows, cols);
989    init_matrix_prefixes(&mut raw, cols, &warm);
990
991    let stride = round_up8(max_p);
992    let cap = rows.checked_mul(stride).ok_or(SinWmaError::InvalidRange {
993        start: sweep.period.0,
994        end: sweep.period.1,
995        step: sweep.period.2,
996    })?;
997    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
998    flat_w.resize(cap, 0.0);
999
1000    for (row, prm) in combos.iter().enumerate() {
1001        let p = prm.period.unwrap();
1002        let base = row * stride;
1003        let mut sum = 0.0;
1004        for k in 0..p {
1005            let a = (k as f64 + 1.0) * PI / (p as f64 + 1.0);
1006            let v = a.sin();
1007            flat_w[base + k] = v;
1008            sum += v;
1009        }
1010        let inv = 1.0 / sum;
1011        for k in 0..p {
1012            flat_w[base + k] *= inv;
1013        }
1014    }
1015
1016    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1017        let p = combos[row].period.unwrap();
1018        let w_ptr = flat_w.as_ptr().add(row * stride);
1019        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1020        sinwma_row_dispatch(data, first, p, w_ptr, dst, kern);
1021    };
1022
1023    if parallel {
1024        #[cfg(not(target_arch = "wasm32"))]
1025        {
1026            raw.par_chunks_mut(cols)
1027                .enumerate()
1028                .for_each(|(r, sl)| do_row(r, sl));
1029        }
1030
1031        #[cfg(target_arch = "wasm32")]
1032        {
1033            for (r, sl) in raw.chunks_mut(cols).enumerate() {
1034                do_row(r, sl);
1035            }
1036        }
1037    } else {
1038        for (r, sl) in raw.chunks_mut(cols).enumerate() {
1039            do_row(r, sl);
1040        }
1041    }
1042
1043    let mut guard = core::mem::ManuallyDrop::new(raw);
1044    let values = unsafe {
1045        Vec::from_raw_parts(
1046            guard.as_mut_ptr() as *mut f64,
1047            guard.len(),
1048            guard.capacity(),
1049        )
1050    };
1051
1052    Ok(SinWmaBatchOutput {
1053        values,
1054        combos,
1055        rows,
1056        cols,
1057    })
1058}
1059
1060#[inline(always)]
1061fn sinwma_batch_inner_into(
1062    data: &[f64],
1063    sweep: &SinWmaBatchRange,
1064    kern: Kernel,
1065    parallel: bool,
1066    out: &mut [f64],
1067) -> Result<Vec<SinWmaParams>, SinWmaError> {
1068    let combos = expand_grid(sweep)?;
1069    if combos.is_empty() {
1070        return Err(SinWmaError::InvalidRange {
1071            start: sweep.period.0,
1072            end: sweep.period.1,
1073            step: sweep.period.2,
1074        });
1075    }
1076
1077    if data.is_empty() {
1078        return Err(SinWmaError::EmptyInputData);
1079    }
1080
1081    let first = data
1082        .iter()
1083        .position(|x| !x.is_nan())
1084        .ok_or(SinWmaError::AllValuesNaN)?;
1085    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1086    if data.len() - first < max_p {
1087        return Err(SinWmaError::NotEnoughValidData {
1088            needed: max_p,
1089            valid: data.len() - first,
1090        });
1091    }
1092
1093    let rows = combos.len();
1094    let cols = data.len();
1095    let warm: Vec<usize> = combos
1096        .iter()
1097        .map(|c| first + c.period.unwrap() - 1)
1098        .collect();
1099
1100    let expected = rows.checked_mul(cols).ok_or(SinWmaError::InvalidRange {
1101        start: sweep.period.0,
1102        end: sweep.period.1,
1103        step: sweep.period.2,
1104    })?;
1105    if out.len() != expected {
1106        return Err(SinWmaError::OutputLengthMismatch {
1107            expected,
1108            got: out.len(),
1109        });
1110    }
1111
1112    let out_mu = unsafe {
1113        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1114    };
1115    init_matrix_prefixes(out_mu, cols, &warm);
1116
1117    let stride = round_up8(max_p);
1118    let cap = rows.checked_mul(stride).ok_or(SinWmaError::InvalidRange {
1119        start: sweep.period.0,
1120        end: sweep.period.1,
1121        step: sweep.period.2,
1122    })?;
1123    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
1124    flat_w.resize(cap, 0.0);
1125
1126    for (row, prm) in combos.iter().enumerate() {
1127        let p = prm.period.unwrap();
1128        let base = row * stride;
1129        let mut sum = 0.0;
1130        for k in 0..p {
1131            let a = (k as f64 + 1.0) * PI / (p as f64 + 1.0);
1132            let v = a.sin();
1133            flat_w[base + k] = v;
1134            sum += v;
1135        }
1136        let inv = 1.0 / sum;
1137        for k in 0..p {
1138            flat_w[base + k] *= inv;
1139        }
1140    }
1141
1142    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1143        let p = combos[row].period.unwrap();
1144        let w_ptr = flat_w.as_ptr().add(row * stride);
1145        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1146        sinwma_row_dispatch(data, first, p, w_ptr, dst, kern);
1147    };
1148
1149    if parallel {
1150        #[cfg(not(target_arch = "wasm32"))]
1151        {
1152            out_mu
1153                .par_chunks_mut(cols)
1154                .enumerate()
1155                .for_each(|(r, sl)| do_row(r, sl));
1156        }
1157        #[cfg(target_arch = "wasm32")]
1158        {
1159            for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1160                do_row(r, sl);
1161            }
1162        }
1163    } else {
1164        for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1165            do_row(r, sl);
1166        }
1167    }
1168
1169    Ok(combos)
1170}
1171
/// Scalar SINWMA kernel for a single output row.
///
/// For every index `i` in `first + period - 1 .. data.len()` this stores in `out[i]`
/// the dot product of the `period` samples ending at `i` with the `period` weights
/// starting at `w_ptr`. Entries of `out` before that range are left untouched.
///
/// # Panics
///
/// Panics if `out` is too short to hold an index that gets written.
///
/// # Safety
///
/// `w_ptr` must be valid for reads of `period` consecutive `f64` values.
#[inline(always)]
pub unsafe fn sinwma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    // Largest multiple of four not exceeding `period`; handled four taps at a time.
    let quads = period & !3;
    for end in (first + period - 1)..data.len() {
        let base = end + 1 - period;
        let mut acc = 0.0;

        let mut k = 0;
        while k < quads {
            let w = std::slice::from_raw_parts(w_ptr.add(k), 4);
            let d = &data[base + k..base + k + 4];
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
            k += 4;
        }
        // Remaining 0..=3 taps.
        while k < period {
            acc += *data.get_unchecked(base + k) * *w_ptr.add(k);
            k += 1;
        }

        out[end] = acc;
    }
}
1195
/// AVX2 + FMA SINWMA kernel for a single output row.
///
/// Writes, for every `i` in `first + period - 1 .. data.len()`, the dot product of the
/// `period` samples ending at `i` with the weights at `w_ptr` into `out[i]`. Four taps
/// are accumulated per 256-bit FMA; the leftover taps use scalar `mul_add`.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `w_ptr` must be valid for reads of `period` consecutive `f64` values.
/// * `out` must be at least `data.len()` long (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn sinwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    let vec_len = period & !3usize;
    let src = data.as_ptr();

    for end in (first + period - 1)..data.len() {
        // Pointer to the oldest sample of the window ending at `end`.
        let win = src.add(end + 1 - period);

        let mut acc = _mm256_setzero_pd();
        for k in (0..vec_len).step_by(4) {
            let d = _mm256_loadu_pd(win.add(k));
            let w = _mm256_loadu_pd(w_ptr.add(k));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Horizontal reduction: fold the two 128-bit halves, then the two lanes.
        let lo = _mm256_castpd256_pd128(acc);
        let hi = _mm256_extractf128_pd(acc, 1);
        let pair = _mm_add_pd(lo, hi);
        let mut sum = _mm_cvtsd_f64(_mm_hadd_pd(pair, pair));

        for k in vec_len..period {
            sum = (*win.add(k)).mul_add(*w_ptr.add(k), sum);
        }

        *out.get_unchecked_mut(end) = sum;
    }
}
1230
/// AVX-512 SINWMA row kernel: routes to the short (`period <= 32`) or long variant.
///
/// # Safety
///
/// The CPU must support the enabled target features, and the arguments must satisfy
/// the requirements of [`sinwma_row_avx512_short`] / [`sinwma_row_avx512_long`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => sinwma_row_avx512_short(data, first, period, w_ptr, out),
        _ => sinwma_row_avx512_long(data, first, period, w_ptr, out),
    }
}
1247
/// AVX-512 SINWMA row kernel that reloads the weights for every output.
///
/// Writes, for every `i` in `first + period - 1 .. data.len()`, the dot product of the
/// `period` samples ending at `i` with the weights at `w_ptr` into `out[i]`. Windows are
/// consumed eight taps at a time; the final partial group is read with a lane mask so
/// no element past the window (or past the `period` weights) is touched.
///
/// # Safety
///
/// * The CPU must support AVX-512F and FMA.
/// * `w_ptr` must be valid for reads of `period` consecutive `f64` values.
/// * `out` must be at least `data.len()` long (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn sinwma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let full = period / LANES;
    let rem = period % LANES;
    // Low `rem` bits set: selects the lanes of the trailing partial group.
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);
    let src = data.as_ptr();

    // Window fits in a single masked vector: hoist the weight load out of the loop.
    if full == 0 {
        let w = _mm512_maskz_loadu_pd(tail_mask, w_ptr);
        for end in (first + period - 1)..data.len() {
            let d = _mm512_maskz_loadu_pd(tail_mask, src.add(end + 1 - period));
            *out.get_unchecked_mut(end) = hsum_pd_zmm(_mm512_mul_pd(d, w));
        }
        return;
    }

    let tail_off = full * LANES;
    for end in (first + period - 1)..data.len() {
        let win = src.add(end + 1 - period);
        let mut acc = _mm512_setzero_pd();

        let mut off = 0;
        while off < tail_off {
            let d = _mm512_loadu_pd(win.add(off));
            let w = _mm512_loadu_pd(w_ptr.add(off));
            acc = _mm512_fmadd_pd(d, w, acc);
            off += LANES;
        }

        if rem != 0 {
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, win.add(tail_off));
            let w_tail = _mm512_maskz_loadu_pd(tail_mask, w_ptr.add(tail_off));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        *out.get_unchecked_mut(end) = hsum_pd_zmm(acc);
    }
}
1290
/// AVX-512 SINWMA row kernel that preloads the weights once per row.
///
/// All full eight-lane weight vectors are loaded up front into a scratch `Vec`, then
/// each output is computed with four independent FMA accumulators (blocks taken four at
/// a time), a single-accumulator remainder loop, and a masked tail group. The four
/// accumulators are summed pairwise before the horizontal reduction.
///
/// # Safety
///
/// * The CPU must support the enabled target features.
/// * `w_ptr` must be valid for reads of `period` consecutive `f64` values.
/// * `out` must be at least `data.len()` long (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let full_blocks = period / LANES;
    let tail_lanes = period % LANES;
    // Number of full blocks covered by the four-way unrolled loop.
    let quad_blocks = full_blocks & !3;
    let tail_mask: __mmask8 = (1u8 << tail_lanes).wrapping_sub(1);

    let mut weights: Vec<__m512d> = Vec::with_capacity(full_blocks);
    let mut b = 0;
    while b < full_blocks {
        weights.push(_mm512_loadu_pd(w_ptr.add(b * LANES)));
        b += 1;
    }
    let tail_weights = match tail_lanes {
        0 => None,
        _ => Some(_mm512_maskz_loadu_pd(
            tail_mask,
            w_ptr.add(full_blocks * LANES),
        )),
    };

    let src = data.as_ptr();
    for end in (first + period - 1)..data.len() {
        let win = src.add(end + 1 - period);
        let mut a0 = _mm512_setzero_pd();
        let mut a1 = _mm512_setzero_pd();
        let mut a2 = _mm512_setzero_pd();
        let mut a3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk < quad_blocks {
            let x0 = _mm512_loadu_pd(win.add(blk * LANES));
            let x1 = _mm512_loadu_pd(win.add((blk + 1) * LANES));
            let x2 = _mm512_loadu_pd(win.add((blk + 2) * LANES));
            let x3 = _mm512_loadu_pd(win.add((blk + 3) * LANES));

            a0 = _mm512_fmadd_pd(x0, *weights.get_unchecked(blk), a0);
            a1 = _mm512_fmadd_pd(x1, *weights.get_unchecked(blk + 1), a1);
            a2 = _mm512_fmadd_pd(x2, *weights.get_unchecked(blk + 2), a2);
            a3 = _mm512_fmadd_pd(x3, *weights.get_unchecked(blk + 3), a3);
            blk += 4;
        }

        // Up to three leftover full blocks go into the first accumulator.
        while blk < full_blocks {
            let x = _mm512_loadu_pd(win.add(blk * LANES));
            a0 = _mm512_fmadd_pd(x, *weights.get_unchecked(blk), a0);
            blk += 1;
        }

        if let Some(w_tail) = tail_weights {
            let x_tail = _mm512_maskz_loadu_pd(tail_mask, win.add(full_blocks * LANES));
            a0 = _mm512_fmadd_pd(x_tail, w_tail, a0);
        }

        let left = _mm512_add_pd(a0, a1);
        let right = _mm512_add_pd(a2, a3);
        *out.get_unchecked_mut(end) = hsum_pd_zmm(_mm512_add_pd(left, right));
    }
}
1352
1353#[cfg(test)]
1354mod tests {
1355    use super::*;
1356    use crate::skip_if_unsupported;
1357    use crate::utilities::data_loader::read_candles_from_csv;
1358    #[cfg(feature = "proptest")]
1359    use proptest::prelude::*;
1360
1361    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1362    #[test]
1363    fn test_sinwma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1364        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1365        let candles = read_candles_from_csv(file_path)?;
1366
1367        let input = SinWmaInput::with_default_candles(&candles);
1368
1369        let baseline = sinwma(&input)?.values;
1370
1371        let mut out = vec![0.0; candles.close.len()];
1372        sinwma_into(&input, &mut out)?;
1373
1374        assert_eq!(baseline.len(), out.len());
1375
1376        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1377            (a.is_nan() && b.is_nan()) || (a == b)
1378        }
1379
1380        for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
1381            assert!(
1382                eq_or_both_nan(*a, *b),
1383                "mismatch at index {}: baseline={}, into={}",
1384                i,
1385                a,
1386                b
1387            );
1388        }
1389
1390        Ok(())
1391    }
1392
1393    fn check_sinwma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1394        skip_if_unsupported!(kernel, test_name);
1395        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1396        let candles = read_candles_from_csv(file_path)?;
1397        let default_params = SinWmaParams { period: None };
1398        let input = SinWmaInput::from_candles(&candles, "close", default_params);
1399        let output = sinwma_with_kernel(&input, kernel)?;
1400        assert_eq!(output.values.len(), candles.close.len());
1401        Ok(())
1402    }
1403
1404    fn check_sinwma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1405        skip_if_unsupported!(kernel, test_name);
1406        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1407        let candles = read_candles_from_csv(file_path)?;
1408        let input = SinWmaInput::from_candles(&candles, "close", SinWmaParams { period: Some(14) });
1409        let result = sinwma_with_kernel(&input, kernel)?;
1410        let expected_last_five = [
1411            59376.72903536103,
1412            59300.76862770367,
1413            59229.27622157621,
1414            59178.48781774477,
1415            59154.66580703081,
1416        ];
1417        let start = result.values.len().saturating_sub(5);
1418        for (i, &val) in result.values[start..].iter().enumerate() {
1419            let diff = (val - expected_last_five[i]).abs();
1420            assert!(
1421                diff < 1e-6,
1422                "[{}] SINWMA {:?} mismatch at idx {}: got {}, expected {}",
1423                test_name,
1424                kernel,
1425                i,
1426                val,
1427                expected_last_five[i]
1428            );
1429        }
1430        Ok(())
1431    }
1432
1433    fn check_sinwma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1434        skip_if_unsupported!(kernel, test_name);
1435        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1436        let candles = read_candles_from_csv(file_path)?;
1437        let input = SinWmaInput::with_default_candles(&candles);
1438        match input.data {
1439            SinWmaData::Candles { source, .. } => assert_eq!(source, "close"),
1440            _ => panic!("Expected SinWmaData::Candles"),
1441        }
1442        let output = sinwma_with_kernel(&input, kernel)?;
1443        assert_eq!(output.values.len(), candles.close.len());
1444        Ok(())
1445    }
1446
1447    fn check_sinwma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1448        skip_if_unsupported!(kernel, test_name);
1449        let input_data = [10.0, 20.0, 30.0];
1450        let params = SinWmaParams { period: Some(0) };
1451        let input = SinWmaInput::from_slice(&input_data, params);
1452        let res = sinwma_with_kernel(&input, kernel);
1453        assert!(
1454            res.is_err(),
1455            "[{}] SINWMA should fail with zero period",
1456            test_name
1457        );
1458        Ok(())
1459    }
1460
1461    fn check_sinwma_period_exceeds_length(
1462        test_name: &str,
1463        kernel: Kernel,
1464    ) -> Result<(), Box<dyn Error>> {
1465        skip_if_unsupported!(kernel, test_name);
1466        let data_small = [10.0, 20.0, 30.0];
1467        let params = SinWmaParams { period: Some(10) };
1468        let input = SinWmaInput::from_slice(&data_small, params);
1469        let res = sinwma_with_kernel(&input, kernel);
1470        assert!(
1471            res.is_err(),
1472            "[{}] SINWMA should fail with period exceeding length",
1473            test_name
1474        );
1475        Ok(())
1476    }
1477
1478    fn check_sinwma_very_small_dataset(
1479        test_name: &str,
1480        kernel: Kernel,
1481    ) -> Result<(), Box<dyn Error>> {
1482        skip_if_unsupported!(kernel, test_name);
1483        let single_point = [42.0];
1484        let params = SinWmaParams { period: Some(14) };
1485        let input = SinWmaInput::from_slice(&single_point, params);
1486        let res = sinwma_with_kernel(&input, kernel);
1487        assert!(
1488            res.is_err(),
1489            "[{}] SINWMA should fail with insufficient data",
1490            test_name
1491        );
1492        Ok(())
1493    }
1494
1495    fn check_sinwma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1496        skip_if_unsupported!(kernel, test_name);
1497        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1498        let candles = read_candles_from_csv(file_path)?;
1499
1500        let first_params = SinWmaParams { period: Some(14) };
1501        let first_input = SinWmaInput::from_candles(&candles, "close", first_params);
1502        let first_result = sinwma_with_kernel(&first_input, kernel)?;
1503
1504        let second_params = SinWmaParams { period: Some(5) };
1505        let second_input = SinWmaInput::from_slice(&first_result.values, second_params);
1506        let second_result = sinwma_with_kernel(&second_input, kernel)?;
1507
1508        assert_eq!(second_result.values.len(), first_result.values.len());
1509        for &val in second_result.values.iter().skip(240) {
1510            assert!(val.is_finite());
1511        }
1512        Ok(())
1513    }
1514
1515    fn check_sinwma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1516        skip_if_unsupported!(kernel, test_name);
1517        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1518        let candles = read_candles_from_csv(file_path)?;
1519        let input = SinWmaInput::from_candles(&candles, "close", SinWmaParams { period: Some(14) });
1520        let res = sinwma_with_kernel(&input, kernel)?;
1521        assert_eq!(res.values.len(), candles.close.len());
1522        if res.values.len() > 240 {
1523            for (i, &val) in res.values[240..].iter().enumerate() {
1524                assert!(
1525                    !val.is_nan(),
1526                    "[{}] Found unexpected NaN at out-index {}",
1527                    test_name,
1528                    240 + i
1529                );
1530            }
1531        }
1532        Ok(())
1533    }
1534
1535    fn check_sinwma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1536        skip_if_unsupported!(kernel, test_name);
1537
1538        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1539        let candles = read_candles_from_csv(file_path)?;
1540
1541        let period = 14;
1542
1543        let input = SinWmaInput::from_candles(
1544            &candles,
1545            "close",
1546            SinWmaParams {
1547                period: Some(period),
1548            },
1549        );
1550        let batch_output = sinwma_with_kernel(&input, kernel)?.values;
1551
1552        let mut stream = SinWmaStream::try_new(SinWmaParams {
1553            period: Some(period),
1554        })?;
1555
1556        let mut stream_values = Vec::with_capacity(candles.close.len());
1557        for &price in &candles.close {
1558            match stream.update(price) {
1559                Some(val) => stream_values.push(val),
1560                None => stream_values.push(f64::NAN),
1561            }
1562        }
1563
1564        assert_eq!(batch_output.len(), stream_values.len());
1565        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1566            if b.is_nan() && s.is_nan() {
1567                continue;
1568            }
1569            let diff = (b - s).abs();
1570            assert!(
1571                diff < 1e-9,
1572                "[{}] SINWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1573                test_name,
1574                i,
1575                b,
1576                s,
1577                diff
1578            );
1579        }
1580        Ok(())
1581    }
1582
1583    #[cfg(debug_assertions)]
1584    fn check_sinwma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1585        skip_if_unsupported!(kernel, test_name);
1586
1587        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1588        let candles = read_candles_from_csv(file_path)?;
1589
1590        let test_periods = vec![5, 10, 14, 20, 30, 50];
1591
1592        for period in test_periods {
1593            let params = SinWmaParams {
1594                period: Some(period),
1595            };
1596            let input = SinWmaInput::from_candles(&candles, "close", params);
1597            let output = sinwma_with_kernel(&input, kernel)?;
1598
1599            for (i, &val) in output.values.iter().enumerate() {
1600                if val.is_nan() {
1601                    continue;
1602                }
1603
1604                let bits = val.to_bits();
1605
1606                if bits == 0x11111111_11111111 {
1607                    panic!(
1608						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={})",
1609						test_name, val, bits, i, period
1610					);
1611                }
1612
1613                if bits == 0x22222222_22222222 {
1614                    panic!(
1615						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={})",
1616						test_name, val, bits, i, period
1617					);
1618                }
1619
1620                if bits == 0x33333333_33333333 {
1621                    panic!(
1622						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={})",
1623						test_name, val, bits, i, period
1624					);
1625                }
1626            }
1627        }
1628
1629        Ok(())
1630    }
1631
1632    #[cfg(not(debug_assertions))]
1633    fn check_sinwma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1634        Ok(())
1635    }
1636
1637    #[cfg(feature = "proptest")]
1638    #[allow(clippy::float_cmp)]
1639    fn check_sinwma_property(
1640        test_name: &str,
1641        kernel: Kernel,
1642    ) -> Result<(), Box<dyn std::error::Error>> {
1643        use proptest::prelude::*;
1644        skip_if_unsupported!(kernel, test_name);
1645
1646        let strat = (1usize..=100).prop_flat_map(|period| {
1647            (
1648                prop::collection::vec(
1649                    (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
1650                    period..=500,
1651                ),
1652                Just(period),
1653            )
1654        });
1655
1656        proptest::test_runner::TestRunner::default()
1657			.run(&strat, |(data, period)| {
1658				let params = SinWmaParams { period: Some(period) };
1659				let input = SinWmaInput::from_slice(&data, params);
1660
1661				let SinWmaOutput { values: out } = sinwma_with_kernel(&input, kernel).unwrap();
1662				let SinWmaOutput { values: ref_out } = sinwma_with_kernel(&input, Kernel::Scalar).unwrap();
1663
1664
1665				prop_assert_eq!(out.len(), data.len(), "Output length should match input length");
1666
1667
1668
1669				let warmup_end = period - 1;
1670				for i in 0..warmup_end.min(data.len()) {
1671					prop_assert!(
1672						out[i].is_nan(),
1673						"[{}] Expected NaN at index {} during warmup (period={})",
1674						test_name,
1675						i,
1676						period
1677					);
1678				}
1679
1680
1681				for i in warmup_end..data.len() {
1682					let y = out[i];
1683
1684
1685					if y.is_nan() {
1686						continue;
1687					}
1688
1689
1690					let window_start = i + 1 - period;
1691					let window = &data[window_start..=i];
1692
1693					let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
1694					let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1695
1696
1697					let tolerance = 1e-9 + (hi - lo).abs() * 1e-12;
1698					prop_assert!(
1699						y >= lo - tolerance && y <= hi + tolerance,
1700						"[{}] idx {}: value {} not in window bounds [{}, {}] (period={})",
1701						test_name,
1702						i,
1703						y,
1704						lo,
1705						hi,
1706						period
1707					);
1708				}
1709
1710
1711				if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && !data.is_empty() {
1712					for i in warmup_end..data.len() {
1713						if !out[i].is_nan() {
1714							prop_assert!(
1715								(out[i] - data[0]).abs() <= 1e-9,
1716								"[{}] Constant input should produce constant output: expected {}, got {} at index {}",
1717								test_name,
1718								data[0],
1719								out[i],
1720								i
1721							);
1722						}
1723					}
1724				}
1725
1726
1727				if period == 1 {
1728
1729
1730					for i in 0..data.len() {
1731						if !out[i].is_nan() && !data[i].is_nan() {
1732							prop_assert!(
1733								(out[i] - data[i]).abs() <= 1e-12,
1734								"[{}] Period=1 should pass through input: expected {}, got {} at index {}",
1735								test_name,
1736								data[i],
1737								out[i],
1738								i
1739							);
1740						}
1741					}
1742				}
1743
1744
1745
1746				for i in 0..data.len() {
1747					if out[i].is_nan() && ref_out[i].is_nan() {
1748						continue;
1749					}
1750
1751
1752					let y_bits = out[i].to_bits();
1753					let r_bits = ref_out[i].to_bits();
1754
1755					prop_assert_eq!(
1756						y_bits,
1757						r_bits,
1758						"[{}] Kernel consistency failed at index {}: {:?} gives {}, Scalar gives {}",
1759						test_name,
1760						i,
1761						kernel,
1762						out[i],
1763						ref_out[i]
1764					);
1765				}
1766
1767
1768
1769
1770
1771
1772				for i in warmup_end..data.len() {
1773					if out[i].is_nan() {
1774						continue;
1775					}
1776
1777					let window_start = i + 1 - period;
1778					let window = &data[window_start..=i];
1779
1780					let all_positive = window.iter().all(|&x| x > 0.0);
1781					let all_negative = window.iter().all(|&x| x < 0.0);
1782
1783					if all_positive {
1784						prop_assert!(
1785							out[i] > 0.0,
1786							"[{}] All positive window should produce positive output, got {} at index {}",
1787							test_name,
1788							out[i],
1789							i
1790						);
1791					}
1792
1793					if all_negative {
1794						prop_assert!(
1795							out[i] < 0.0,
1796							"[{}] All negative window should produce negative output, got {} at index {}",
1797							test_name,
1798							out[i],
1799							i
1800						);
1801					}
1802				}
1803
1804
1805
1806
1807
1808				if data.len() >= period * 2 {
1809					for i in (warmup_end + period)..data.len().min(warmup_end + period * 3) {
1810						if out[i].is_nan() {
1811							continue;
1812						}
1813
1814						let window_start = i + 1 - period;
1815						let window = &data[window_start..=i];
1816
1817
1818						let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
1819						let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1820						let range = window_max - window_min;
1821
1822						if range > 1.0 {
1823
1824							let middle_idx = period / 2;
1825							let middle_value = window[middle_idx];
1826							let first_value = window[0];
1827							let last_value = window[period - 1];
1828
1829
1830
1831							let clearly_ascending = first_value < middle_value && middle_value < last_value
1832								&& (last_value - first_value) > range * 0.8;
1833							let clearly_descending = first_value > middle_value && middle_value > last_value
1834								&& (first_value - last_value) > range * 0.8;
1835
1836							if clearly_ascending || clearly_descending {
1837
1838								let dist_to_middle = (out[i] - middle_value).abs();
1839								let dist_to_first = (out[i] - first_value).abs();
1840								let dist_to_last = (out[i] - last_value).abs();
1841
1842
1843
1844								prop_assert!(
1845									dist_to_middle < dist_to_first.min(dist_to_last) * 1.2,
1846									"[{}] idx {}: SINWMA output {} should be closer to middle {} than to extremes [{}, {}] (period={})",
1847									test_name,
1848									i,
1849									out[i],
1850									middle_value,
1851									first_value,
1852									last_value,
1853									period
1854								);
1855							}
1856						}
1857					}
1858				}
1859
1860				Ok(())
1861			})?;
1862
1863        Ok(())
1864    }
1865
1866    macro_rules! generate_all_sinwma_tests {
1867        ($($test_fn:ident),*) => {
1868            paste::paste! {
1869                $(
1870                    #[test]
1871                    fn [<$test_fn _scalar_f64>]() {
1872                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1873                    }
1874                )*
1875                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1876                $(
1877                    #[test]
1878                    fn [<$test_fn _avx2_f64>]() {
1879                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1880                    }
1881                    #[test]
1882                    fn [<$test_fn _avx512_f64>]() {
1883                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1884                    }
1885                )*
1886            }
1887        }
1888    }
1889
1890    generate_all_sinwma_tests!(
1891        check_sinwma_partial_params,
1892        check_sinwma_accuracy,
1893        check_sinwma_default_candles,
1894        check_sinwma_zero_period,
1895        check_sinwma_period_exceeds_length,
1896        check_sinwma_very_small_dataset,
1897        check_sinwma_reinput,
1898        check_sinwma_nan_handling,
1899        check_sinwma_streaming,
1900        check_sinwma_no_poison
1901    );
1902
1903    #[cfg(feature = "proptest")]
1904    generate_all_sinwma_tests!(check_sinwma_property);
1905
1906    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1907        skip_if_unsupported!(kernel, test);
1908
1909        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1910        let c = read_candles_from_csv(file)?;
1911
1912        let output = SinWmaBatchBuilder::new()
1913            .kernel(kernel)
1914            .apply_candles(&c, "close")?;
1915
1916        let def = SinWmaParams::default();
1917        let row = output.values_for(&def).expect("default row missing");
1918
1919        assert_eq!(row.len(), c.close.len());
1920
1921        let expected = [
1922            59376.72903536103,
1923            59300.76862770367,
1924            59229.27622157621,
1925            59178.48781774477,
1926            59154.66580703081,
1927        ];
1928        let start = row.len() - 5;
1929        for (i, &v) in row[start..].iter().enumerate() {
1930            assert!(
1931                (v - expected[i]).abs() < 1e-6,
1932                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1933            );
1934        }
1935        Ok(())
1936    }
1937
1938    #[cfg(debug_assertions)]
1939    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1940        skip_if_unsupported!(kernel, test);
1941
1942        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1943        let c = read_candles_from_csv(file)?;
1944
1945        let test_configs = vec![(5, 15, 5), (10, 30, 10), (20, 50, 15), (2, 10, 2)];
1946
1947        for (start, end, step) in test_configs {
1948            let output = SinWmaBatchBuilder::new()
1949                .kernel(kernel)
1950                .period_range(start, end, step)
1951                .apply_candles(&c, "close")?;
1952
1953            for (idx, &val) in output.values.iter().enumerate() {
1954                if val.is_nan() {
1955                    continue;
1956                }
1957
1958                let bits = val.to_bits();
1959                let row = idx / output.cols;
1960                let col = idx % output.cols;
1961                let period = output.combos[row].period.unwrap();
1962
1963                if bits == 0x11111111_11111111 {
1964                    panic!(
1965                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1966                        test, val, bits, row, col, idx, period
1967                    );
1968                }
1969
1970                if bits == 0x22222222_22222222 {
1971                    panic!(
1972                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1973                        test, val, bits, row, col, idx, period
1974                    );
1975                }
1976
1977                if bits == 0x33333333_33333333 {
1978                    panic!(
1979                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1980                        test, val, bits, row, col, idx, period
1981                    );
1982                }
1983            }
1984        }
1985
1986        Ok(())
1987    }
1988
1989    #[cfg(not(debug_assertions))]
1990    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1991        Ok(())
1992    }
1993
1994    macro_rules! gen_batch_tests {
1995        ($fn_name:ident) => {
1996            paste::paste! {
1997                #[test] fn [<$fn_name _scalar>]()      {
1998                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1999                }
2000                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2001                #[test] fn [<$fn_name _avx2>]()        {
2002                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2003                }
2004                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2005                #[test] fn [<$fn_name _avx512>]()      {
2006                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2007                }
2008                #[test] fn [<$fn_name _auto_detect>]() {
2009                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2010                }
2011            }
2012        };
2013    }
2014    gen_batch_tests!(check_batch_default_row);
2015    gen_batch_tests!(check_batch_no_poison);
2016}
2017
// Python binding `sinwma(data, period, kernel=None)`.
//
// Borrows the NumPy input as a contiguous slice, validates the kernel name,
// runs the indicator with the GIL released, and returns the values as a new
// 1-D NumPy array. Indicator errors are surfaced as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sinwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn sinwma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = SinWmaInput::from_slice(
        slice_in,
        SinWmaParams {
            period: Some(period),
        },
    );

    // The computation itself runs without holding the GIL.
    let computed = py.allow_threads(|| sinwma_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2044
// Python class `SinWmaStream`: a thin wrapper around the Rust `SinWmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SinWmaStream")]
pub struct SinWmaStreamPy {
    stream: SinWmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SinWmaStreamPy {
    // Builds a stream for `period`; construction errors become `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = SinWmaParams {
            period: Some(period),
        };
        match SinWmaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one sample to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2070
// Python binding `sinwma_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into one parameter combination per row, computes
// every row directly into a freshly allocated NumPy buffer (GIL released),
// and returns a dict with:
//   * "values"  – the buffer reshaped to (rows, cols), cols = len(data)
//   * "periods" – the period of each row as u64
//
// Errors from grid expansion, kernel validation, or the batch computation are
// raised as `ValueError`, as is an output size that would overflow `usize`.
#[cfg(feature = "python")]
#[pyfunction(name = "sinwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn sinwma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;

    let sweep = SinWmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    // Guard the element count: a wrapped product would size the uninitialised
    // buffer incorrectly before it is handed to the kernel as a raw slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("sinwma_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is freshly allocated and not yet shared with Python, so
    // taking a mutable slice cannot alias. Its contents are uninitialised here;
    // this relies on `sinwma_batch_inner_into` writing every element before the
    // array is exposed (that function is outside this excerpt — TODO confirm).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row SIMD kernel.
            // NOTE(review): the fallthrough assumes `validate_kernel(_, true)`
            // only yields `Auto` or `*Batch` kernels — confirm.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            sinwma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2128
// Python binding `sinwma_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the period sweep on the selected CUDA device (GIL released) and wraps
// the device-side result for Python. Raises `ValueError` when CUDA is not
// available or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sinwma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn sinwma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series = data_f32.as_slice()?;
    let sweep = SinWmaBatchRange {
        period: period_range,
    };

    // Context creation and the kernel launch both happen without the GIL.
    let launched = py.allow_threads(|| {
        CudaSinwma::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.sinwma_batch_dev(host_series, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    });
    let inner = launched?;

    Ok(make_device_array_py(device_id, inner)?)
}
2155
// Python binding
// `sinwma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Applies a single period to a 2-D f32 array on the selected CUDA device
// (GIL released). The array's second dimension is passed to the CUDA wrapper
// first, followed by its first dimension. Raises `ValueError` when CUDA is
// not available or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sinwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn sinwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = SinWmaParams {
        period: Some(period),
    };

    // Context creation and the kernel launch both happen without the GIL.
    let launched = py.allow_threads(|| {
        CudaSinwma::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.sinwma_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    });
    let inner = launched?;

    Ok(make_device_array_py(device_id, inner)?)
}
2184
// WASM export: computes SinWMA over `data` with the given `period` and
// returns a new vector of the same length. Indicator errors are converted to
// a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SinWmaInput::from_slice(
        data,
        SinWmaParams {
            period: Some(period),
        },
    );

    // Zero-filled destination; the indicator writes its results into it.
    let mut output = vec![0.0; data.len()];

    match sinwma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2200
// WASM export: reserves room for `len` f64 values and hands the raw pointer
// to the caller. The allocation is deliberately not dropped here; it is meant
// to be released later through `sinwma_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the Vec's destructor from running, leaking the
    // buffer on purpose.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
2209
// WASM export: releases a buffer of capacity `len`. A null `ptr` is ignored.
//
// The caller must pass a pointer obtained from an allocation of exactly `len`
// f64 elements (e.g. `sinwma_alloc(len)`) that has not already been freed.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: per the contract above, `ptr` owns an allocation with
        // capacity `len`. The Vec is rebuilt with length 0 because the buffer
        // was created with `with_capacity` and may never have been fully
        // written; `from_raw_parts` requires the first `length` elements to be
        // initialised, and f64 needs no per-element drop anyway.
        unsafe {
            drop(Vec::from_raw_parts(ptr, 0, len));
        }
    }
}
2219
// WASM export: computes SinWMA for `len` values read from `in_ptr` and writes
// `len` results to `out_ptr`.
//
// Returns an error for null pointers, for `period == 0` or `period > len`,
// and for any error reported by the indicator. Passing the same pointer for
// input and output is supported: the result is staged in a temporary buffer
// and copied back.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sinwma_into"));
    }

    // Valid periods are exactly 1..=len.
    if !(1..=len).contains(&period) {
        return Err(JsValue::from_str("Invalid period"));
    }

    let aliased = in_ptr == out_ptr;

    // SAFETY: both pointers were checked for null; the caller is responsible
    // for each pointing at `len` valid f64 values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = SinWmaInput::from_slice(
            data,
            SinWmaParams {
                period: Some(period),
            },
        );

        if !aliased {
            // Distinct buffers: compute straight into the destination.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            sinwma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            return Ok(());
        }

        // In-place request: compute into scratch space first so the input is
        // not overwritten while it is still being read.
        let mut scratch = vec![0.0; len];
        sinwma_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        Ok(())
    }
}
2260
/// Batch configuration deserialized from the JS `config` argument of the
/// `sinwma_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SinWmaBatchConfig {
    /// Period sweep, forwarded unchanged as `SinWmaBatchRange::period`.
    /// Presumably `(start, end, step)`, matching the positional arguments of
    /// `sinwma_batch_js` — confirm against `SinWmaBatchRange`.
    pub period_range: (usize, usize, usize),
}
2266
/// Result object serialized back to JS by the `sinwma_batch` export.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SinWmaBatchJsOutput {
    /// Flat batch output, taken as-is from the batch result's `values`.
    pub values: Vec<f64>,
    /// Period of each parameter combination, in combination order.
    pub periods: Vec<usize>,
    /// Row count reported by the batch result.
    pub rows: usize,
    /// Column count reported by the batch result.
    pub cols: usize,
}
2275
// WASM export `sinwma_batch`: decodes a `SinWmaBatchConfig` from `config`,
// runs the batch with automatic kernel selection, and returns a serialized
// `SinWmaBatchJsOutput`. Decode, compute and encode failures are each
// reported as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sinwma_batch)]
pub fn sinwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SinWmaBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<SinWmaBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SinWmaBatchRange {
        period: period_range,
    };

    let output = match sinwma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Collect the periods before `values` is moved out of `output`.
    let mut periods: Vec<usize> = Vec::with_capacity(output.combos.len());
    periods.extend(output.combos.iter().map(|c| c.period.unwrap()));

    let js_output = SinWmaBatchJsOutput {
        values: output.values,
        periods,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2301
// WASM export: runs a batch over the period sweep
// `(period_start, period_end, period_step)` with automatic kernel selection
// and returns only the flat values vector. Errors become a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match sinwma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2318
// WASM export: lists the period of every combination produced by the sweep
// `(period_start, period_end, period_step)`, narrowed to u32. If the sweep
// cannot be expanded, an empty list is returned instead of an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<u32> {
    let sweep = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match expand_grid(&sweep) {
        Ok(combos) => combos.iter().map(|p| p.period.unwrap() as u32).collect(),
        Err(_) => Vec::new(),
    }
}
2333
// WASM export: returns `[rows, cols]` for a batch over the given sweep, where
// rows is the number of expanded combinations (0 when the sweep cannot be
// expanded) and cols is `data_len`. Both are narrowed to u32.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<u32> {
    let sweep = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).map_or(0, |combos| combos.len());
    vec![rows as u32, data_len as u32]
}
2349
// WASM export: runs a batch over the sweep
// `(period_start, period_end, period_step)` for `len` values read from
// `in_ptr`, writing `rows * len` results to `out_ptr`, and returns the number
// of rows (expanded combinations).
//
// Returns an error for null pointers, for a sweep that cannot be expanded,
// for an output size that overflows `usize`, and for any error reported by
// the batch kernel. The caller must provide an output buffer large enough
// for `rows * len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to sinwma_batch_into",
        ));
    }

    let sweep = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    // usize is 32 bits on wasm32, so the product can realistically wrap; a
    // wrapped value would build an output slice of the wrong size.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("sinwma_batch_into: rows * cols overflows usize"))?;

    // SAFETY: both pointers were checked for null; the caller is responsible
    // for `in_ptr` covering `len` values and `out_ptr` covering `total` values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        // Map the detected batch kernel to its per-row counterpart; anything
        // else falls back to the scalar kernel.
        let simd = match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            _ => Kernel::Scalar,
        };
        sinwma_batch_inner_into(data, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        Ok(rows)
    }
}