Skip to main content

vector_ta/indicators/moving_averages/
wma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::wma_wrapper::DeviceArrayF32Py;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaWma;
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use serde::{Deserialize, Serialize};
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use wasm_bindgen::prelude::*;
21
22use crate::utilities::data_loader::{source_type, Candles};
23use crate::utilities::enums::Kernel;
24use crate::utilities::helpers::{
25    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
26    make_uninit_matrix,
27};
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33use std::convert::AsRef;
34use std::error::Error;
35use std::mem::MaybeUninit;
36use thiserror::Error;
37
38impl<'a> AsRef<[f64]> for WmaInput<'a> {
39    #[inline(always)]
40    fn as_ref(&self) -> &[f64] {
41        match &self.data {
42            WmaData::Slice(slice) => slice,
43            WmaData::Candles { candles, source } => source_type(candles, source),
44        }
45    }
46}
47
48#[derive(Debug, Clone)]
49pub enum WmaData<'a> {
50    Candles {
51        candles: &'a Candles,
52        source: &'a str,
53    },
54    Slice(&'a [f64]),
55}
56
/// Result of a single-series WMA run.
#[derive(Debug, Clone)]
pub struct WmaOutput {
    /// One value per input sample; the warm-up prefix is `NaN`.
    pub values: Vec<f64>,
}
61
/// Tunable parameters of the weighted moving average.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct WmaParams {
    /// Window length; consumers treat `None` as 30.
    pub period: Option<usize>,
}

impl Default for WmaParams {
    /// A 30-sample window.
    fn default() -> Self {
        WmaParams { period: Some(30) }
    }
}
76
77#[derive(Debug, Clone)]
78pub struct WmaInput<'a> {
79    pub data: WmaData<'a>,
80    pub params: WmaParams,
81}
82
83impl<'a> WmaInput<'a> {
84    #[inline]
85    pub fn from_candles(c: &'a Candles, s: &'a str, p: WmaParams) -> Self {
86        Self {
87            data: WmaData::Candles {
88                candles: c,
89                source: s,
90            },
91            params: p,
92        }
93    }
94    #[inline]
95    pub fn from_slice(sl: &'a [f64], p: WmaParams) -> Self {
96        Self {
97            data: WmaData::Slice(sl),
98            params: p,
99        }
100    }
101    #[inline]
102    pub fn with_default_candles(c: &'a Candles) -> Self {
103        Self::from_candles(c, "close", WmaParams::default())
104    }
105    #[inline]
106    pub fn get_period(&self) -> usize {
107        self.params.period.unwrap_or(30)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct WmaBuilder {
113    period: Option<usize>,
114    kernel: Kernel,
115}
116
117impl Default for WmaBuilder {
118    fn default() -> Self {
119        Self {
120            period: None,
121            kernel: Kernel::Auto,
122        }
123    }
124}
125
126impl WmaBuilder {
127    #[inline(always)]
128    pub fn new() -> Self {
129        Self::default()
130    }
131    #[inline(always)]
132    pub fn period(mut self, n: usize) -> Self {
133        self.period = Some(n);
134        self
135    }
136    #[inline(always)]
137    pub fn kernel(mut self, k: Kernel) -> Self {
138        self.kernel = k;
139        self
140    }
141
142    #[inline(always)]
143    pub fn apply(self, c: &Candles) -> Result<WmaOutput, WmaError> {
144        let p = WmaParams {
145            period: self.period,
146        };
147        let i = WmaInput::from_candles(c, "close", p);
148        wma_with_kernel(&i, self.kernel)
149    }
150
151    #[inline(always)]
152    pub fn apply_slice(self, d: &[f64]) -> Result<WmaOutput, WmaError> {
153        let p = WmaParams {
154            period: self.period,
155        };
156        let i = WmaInput::from_slice(d, p);
157        wma_with_kernel(&i, self.kernel)
158    }
159
160    #[inline(always)]
161    pub fn into_stream(self) -> Result<WmaStream, WmaError> {
162        let p = WmaParams {
163            period: self.period,
164        };
165        WmaStream::try_new(p)
166    }
167}
168
169#[derive(Debug, Error)]
170pub enum WmaError {
171    #[error("wma: Input data slice is empty.")]
172    EmptyInputData,
173
174    #[error("wma: All values are NaN.")]
175    AllValuesNaN,
176
177    #[error("wma: Invalid period: period = {period}, data length = {data_len}")]
178    InvalidPeriod { period: usize, data_len: usize },
179
180    #[error("wma: Not enough valid data: needed = {needed}, valid = {valid}")]
181    NotEnoughValidData { needed: usize, valid: usize },
182
183    #[error("wma: Non-batch kernel passed to batch path: {0:?}")]
184    InvalidKernelForBatch(crate::utilities::enums::Kernel),
185
186    #[error("wma: Output length mismatch: expected = {expected}, got = {got}")]
187    OutputLengthMismatch { expected: usize, got: usize },
188
189    #[error("wma: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
190    InvalidRange {
191        start: usize,
192        end: usize,
193        step: usize,
194    },
195
196    #[error("wma: invalid input: {0}")]
197    InvalidInput(String),
198}
199
200#[inline]
201pub fn wma(input: &WmaInput) -> Result<WmaOutput, WmaError> {
202    wma_with_kernel(input, Kernel::Auto)
203}
204
205pub fn wma_with_kernel(input: &WmaInput, kernel: Kernel) -> Result<WmaOutput, WmaError> {
206    let (data, period, first, chosen) = wma_prepare(input, kernel)?;
207    let len = data.len();
208    let warm = first + period - 1;
209    let mut out = alloc_with_nan_prefix(len, warm);
210
211    wma_compute_into(data, period, first, chosen, &mut out);
212
213    Ok(WmaOutput { values: out })
214}
215
216#[inline]
217pub fn wma_into_slice(dst: &mut [f64], input: &WmaInput, kern: Kernel) -> Result<(), WmaError> {
218    let (data, period, first, chosen) = wma_prepare(input, kern)?;
219
220    if dst.len() != data.len() {
221        return Err(WmaError::OutputLengthMismatch {
222            expected: data.len(),
223            got: dst.len(),
224        });
225    }
226
227    wma_compute_into(data, period, first, chosen, dst);
228
229    let warmup_end = first + period - 1;
230    for v in &mut dst[..warmup_end] {
231        *v = f64::NAN;
232    }
233
234    Ok(())
235}
236
237#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
238#[inline]
239pub fn wma_into(input: &WmaInput, out: &mut [f64]) -> Result<(), WmaError> {
240    wma_into_slice(out, input, Kernel::Auto)
241}
242
243fn wma_prepare<'a>(
244    input: &'a WmaInput,
245    kernel: Kernel,
246) -> Result<(&'a [f64], usize, usize, Kernel), WmaError> {
247    let data: &[f64] = input.as_ref();
248    let len = data.len();
249    if len == 0 {
250        return Err(WmaError::EmptyInputData);
251    }
252
253    let first = data
254        .iter()
255        .position(|x| !x.is_nan())
256        .ok_or(WmaError::AllValuesNaN)?;
257    let period = input.get_period();
258
259    if period < 2 || period > len {
260        return Err(WmaError::InvalidPeriod {
261            period,
262            data_len: len,
263        });
264    }
265    if len - first < period {
266        return Err(WmaError::NotEnoughValidData {
267            needed: period,
268            valid: len - first,
269        });
270    }
271
272    let chosen = match kernel {
273        Kernel::Auto => Kernel::Scalar,
274        k => k,
275    };
276
277    Ok((data, period, first, chosen))
278}
279
280#[inline(always)]
281fn wma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
282    unsafe {
283        match kernel {
284            Kernel::Scalar | Kernel::ScalarBatch => wma_scalar(data, period, first, out),
285            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
286            Kernel::Avx2 | Kernel::Avx2Batch => wma_avx2(data, period, first, out),
287            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
288            Kernel::Avx512 | Kernel::Avx512Batch => wma_avx512(data, period, first, out),
289            _ => unreachable!(),
290        }
291    }
292}
293
/// Scalar weighted moving average of `data`, written into `out`.
///
/// Within the window ending at index `i`, samples are weighted `1, 2, ..., period`
/// from oldest to newest, and the total is divided by `period * (period + 1) / 2`.
/// Values are written to `out[first_val + period - 1..data.len()]`. Every other
/// slot of `out` is left untouched, so callers fill the warm-up prefix themselves.
///
/// Each step costs O(1) using two running sums: `sum` (plain window sum) and
/// `weight_sum` (weighted window sum). After a value is emitted, subtracting
/// `sum` lowers every weight by one. The next sample then enters with weight
/// `period`, and the sample that just reached weight 0 is dropped from `sum`.
///
/// If fewer than `period` samples exist from `first_val` onward, nothing is
/// written.
///
/// # Panics
///
/// Panics if `period == 0`, or if `out` is shorter than `data` and a window
/// fits. Debug builds additionally require `out.len() == data.len()`.
#[inline]
pub fn wma_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    debug_assert_eq!(out.len(), data.len());
    assert!(period > 0, "wma_scalar: period must be at least 1");

    let len = data.len();
    let lookback = period - 1;
    // Index of the first full window; bail out if no window fits.
    let start = match first_val.checked_add(lookback) {
        Some(s) if s < len => s,
        _ => return,
    };

    let period_f = period as f64;
    let weights = period_f * (period_f + 1.0) * 0.5;

    let mut sum = 0.0_f64;
    let mut weight_sum = 0.0_f64;
    // Prime with the first `period - 1` samples at weights 1..=period-1.
    for (k, &v) in data[first_val..start].iter().enumerate() {
        weight_sum += v * (k as f64 + 1.0);
        sum += v;
    }

    // `incoming[j]` enters the window in the same step that `outgoing[j]`
    // reaches weight 0. Both slices have `len - start` elements.
    let incoming = &data[start..];
    let outgoing = &data[first_val..len - lookback];
    // Slicing (not raw pointers) makes a too-short `out` panic instead of
    // writing out of bounds.
    let dst = &mut out[start..len];

    for ((slot, &v), &old) in dst.iter_mut().zip(incoming).zip(outgoing) {
        weight_sum += v * period_f;
        sum += v;

        *slot = weight_sum / weights;

        weight_sum -= sum;
        sum -= old;
    }
}
337
338#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
339#[inline]
340pub fn wma_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
341    wma_scalar(data, period, first_valid, out)
342}
343
344#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
345#[inline]
346pub fn wma_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
347    wma_scalar(data, period, first_valid, out)
348}
349
350#[inline]
351pub fn wma_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
352    wma_scalar(data, period, first_valid, out)
353}
354
355#[inline]
356pub fn wma_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
357    wma_scalar(data, period, first_valid, out)
358}
359
360#[inline(always)]
361pub fn wma_with_kernel_batch(
362    data: &[f64],
363    sweep: &WmaBatchRange,
364    k: Kernel,
365) -> Result<WmaBatchOutput, WmaError> {
366    let kernel = match k {
367        Kernel::Auto => detect_best_batch_kernel(),
368        other if other.is_batch() => other,
369        _ => return Err(WmaError::InvalidKernelForBatch(k)),
370    };
371
372    let simd = match kernel {
373        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
374        Kernel::Avx512Batch => Kernel::Avx512,
375        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
376        Kernel::Avx2Batch => Kernel::Avx2,
377        Kernel::ScalarBatch => Kernel::Scalar,
378        _ => Kernel::Scalar,
379    };
380    wma_batch_par_slice(data, sweep, simd)
381}
382
383#[derive(Clone, Debug)]
384pub struct WmaStream {
385    period: usize,
386    buffer: Vec<f64>,
387    head: usize,
388    filled: bool,
389
390    plain_sum: f64,
391    weighted_sum: f64,
392    inv_div: f64,
393    p_f64: f64,
394}
395
396impl WmaStream {
397    pub fn try_new(params: WmaParams) -> Result<Self, WmaError> {
398        let period = params.period.unwrap_or(30);
399        if period < 2 {
400            return Err(WmaError::InvalidPeriod {
401                period,
402                data_len: 0,
403            });
404        }
405
406        let sum_of_weights = (period * (period + 1)) as f64 * 0.5;
407        Ok(Self {
408            period,
409            buffer: vec![f64::NAN; period],
410            head: 0,
411            filled: false,
412            plain_sum: 0.0,
413            weighted_sum: 0.0,
414            inv_div: 1.0 / sum_of_weights,
415            p_f64: period as f64,
416        })
417    }
418
419    #[inline(always)]
420    pub fn update(&mut self, value: f64) -> Option<f64> {
421        let write_idx = self.head;
422        self.buffer[write_idx] = value;
423
424        self.head += 1;
425        if self.head == self.period {
426            self.head = 0;
427        }
428
429        if !self.filled {
430            if self.head == 0 {
431                let mut wsum = 0.0;
432                let mut ssum = 0.0;
433
434                let mut idx = self.head;
435                for w in 1..=self.period {
436                    let v = self.buffer[idx];
437                    ssum += v;
438                    wsum += (w as f64) * v;
439                    idx += 1;
440                    if idx == self.period {
441                        idx = 0;
442                    }
443                }
444                let out = wsum * self.inv_div;
445
446                self.weighted_sum = wsum - ssum;
447                let oldest_next = self.buffer[self.head];
448                self.plain_sum = ssum - oldest_next;
449
450                self.filled = true;
451                Some(out)
452            } else {
453                None
454            }
455        } else {
456            let oldest = self.buffer[self.head];
457            self.weighted_sum += self.p_f64 * value;
458            self.plain_sum += value;
459
460            let out = self.weighted_sum * self.inv_div;
461
462            self.weighted_sum -= self.plain_sum;
463            self.plain_sum -= oldest;
464
465            Some(out)
466        }
467    }
468}
469
/// Period sweep for batch WMA runs, as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct WmaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for WmaBatchRange {
    /// Periods 2 through 251 in steps of 1.
    fn default() -> Self {
        WmaBatchRange {
            period: (2, 251, 1),
        }
    }
}
482
483#[derive(Clone, Debug, Default)]
484pub struct WmaBatchBuilder {
485    range: WmaBatchRange,
486    kernel: Kernel,
487}
488
489impl WmaBatchBuilder {
490    pub fn new() -> Self {
491        Self::default()
492    }
493    pub fn kernel(mut self, k: Kernel) -> Self {
494        self.kernel = k;
495        self
496    }
497
498    #[inline]
499    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
500        self.range.period = (start, end, step);
501        self
502    }
503    #[inline]
504    pub fn period_static(mut self, p: usize) -> Self {
505        self.range.period = (p, p, 0);
506        self
507    }
508
509    pub fn apply_slice(self, data: &[f64]) -> Result<WmaBatchOutput, WmaError> {
510        wma_with_kernel_batch(data, &self.range, self.kernel)
511    }
512
513    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<WmaBatchOutput, WmaError> {
514        WmaBatchBuilder::new().kernel(k).apply_slice(data)
515    }
516
517    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<WmaBatchOutput, WmaError> {
518        let slice = source_type(c, src);
519        self.apply_slice(slice)
520    }
521
522    pub fn with_default_candles(c: &Candles) -> Result<WmaBatchOutput, WmaError> {
523        WmaBatchBuilder::new()
524            .kernel(Kernel::Auto)
525            .apply_candles(c, "close")
526    }
527}
528
529#[derive(Clone, Debug)]
530pub struct WmaBatchOutput {
531    pub values: Vec<f64>,
532    pub combos: Vec<WmaParams>,
533    pub rows: usize,
534    pub cols: usize,
535}
536impl WmaBatchOutput {
537    pub fn row_for_params(&self, p: &WmaParams) -> Option<usize> {
538        self.combos
539            .iter()
540            .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
541    }
542
543    pub fn values_for(&self, p: &WmaParams) -> Option<&[f64]> {
544        self.row_for_params(p).map(|row| {
545            let start = row * self.cols;
546            &self.values[start..start + self.cols]
547        })
548    }
549}
550
551#[inline(always)]
552fn expand_grid(r: &WmaBatchRange) -> Vec<WmaParams> {
553    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
554        if step == 0 || start == end {
555            return vec![start];
556        }
557        if start < end {
558            return (start..=end).step_by(step.max(1)).collect();
559        }
560
561        let mut out = Vec::new();
562        let mut x = start as isize;
563        let end_i = end as isize;
564        let st = (step as isize).max(1);
565        while x >= end_i {
566            out.push(x as usize);
567            x -= st;
568        }
569        if out.is_empty() {
570            return out;
571        }
572        if *out.last().unwrap() != end {
573            out.push(end);
574        }
575        out
576    }
577
578    let periods = axis_usize(r.period);
579    let mut out = Vec::with_capacity(periods.len());
580    for &p in &periods {
581        out.push(WmaParams { period: Some(p) });
582    }
583    out
584}
585
586#[inline(always)]
587pub fn wma_batch_slice(
588    data: &[f64],
589    sweep: &WmaBatchRange,
590    kern: Kernel,
591) -> Result<WmaBatchOutput, WmaError> {
592    wma_batch_inner(data, sweep, kern, false)
593}
594
595#[inline(always)]
596pub fn wma_batch_par_slice(
597    data: &[f64],
598    sweep: &WmaBatchRange,
599    kern: Kernel,
600) -> Result<WmaBatchOutput, WmaError> {
601    wma_batch_inner(data, sweep, kern, true)
602}
603
604#[inline(always)]
605fn wma_batch_inner(
606    data: &[f64],
607    sweep: &WmaBatchRange,
608    kern: Kernel,
609    parallel: bool,
610) -> Result<WmaBatchOutput, WmaError> {
611    let combos = expand_grid(sweep);
612    let rows = combos.len();
613    let cols = data.len();
614    if cols == 0 {
615        return Err(WmaError::EmptyInputData);
616    }
617
618    rows.checked_mul(cols)
619        .ok_or_else(|| WmaError::InvalidInput("rows*cols overflow".into()))?;
620
621    let mut buf_mu = make_uninit_matrix(rows, cols);
622
623    let first = data
624        .iter()
625        .position(|x| !x.is_nan())
626        .ok_or(WmaError::AllValuesNaN)?;
627    let warm: Vec<usize> = combos
628        .iter()
629        .map(|c| first + c.period.unwrap() - 1)
630        .collect();
631    init_matrix_prefixes(&mut buf_mu, cols, &warm);
632
633    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
634    let out: &mut [f64] =
635        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
636
637    wma_batch_inner_into(data, sweep, kern, parallel, out)?;
638
639    let values = unsafe {
640        Vec::from_raw_parts(
641            guard.as_mut_ptr() as *mut f64,
642            guard.len(),
643            guard.capacity(),
644        )
645    };
646
647    Ok(WmaBatchOutput {
648        values,
649        combos,
650        rows,
651        cols,
652    })
653}
654
655#[inline(always)]
656fn wma_batch_inner_into(
657    data: &[f64],
658    sweep: &WmaBatchRange,
659    kern: Kernel,
660    parallel: bool,
661    out: &mut [f64],
662) -> Result<Vec<WmaParams>, WmaError> {
663    let combos = expand_grid(sweep);
664    if combos.is_empty() {
665        let (start, end, step) = sweep.period;
666        return Err(WmaError::InvalidRange { start, end, step });
667    }
668
669    let first = data
670        .iter()
671        .position(|x| !x.is_nan())
672        .ok_or(WmaError::AllValuesNaN)?;
673    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
674    if data.len() - first < max_p {
675        return Err(WmaError::NotEnoughValidData {
676            needed: max_p,
677            valid: data.len() - first,
678        });
679    }
680
681    let rows = combos.len();
682    let cols = data.len();
683
684    let needed = rows
685        .checked_mul(cols)
686        .ok_or_else(|| WmaError::InvalidInput("rows*cols overflow".into()))?;
687    if out.len() != needed {
688        return Err(WmaError::OutputLengthMismatch {
689            expected: needed,
690            got: out.len(),
691        });
692    }
693
694    let warm: Vec<usize> = combos
695        .iter()
696        .map(|c| first + c.period.unwrap() - 1)
697        .collect();
698
699    let out_uninit = unsafe {
700        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
701    };
702
703    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
704
705    let cols = data.len();
706    let mut pref_a = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
707    let mut pref_b = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
708    unsafe {
709        pref_a.set_len(cols + 1);
710        pref_b.set_len(cols + 1);
711    }
712    pref_a[0] = 0.0;
713    pref_b[0] = 0.0;
714    for i in 0..cols {
715        let x = if i < first { 0.0 } else { data[i] };
716        pref_a[i + 1] = pref_a[i] + x;
717        pref_b[i + 1] = pref_b[i] + (i as f64) * x;
718    }
719
720    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
721        let period = combos[row].period.unwrap();
722        let denom = (period * (period + 1)) as f64 / 2.0;
723        let inv_div = 1.0 / denom;
724        let warm_end = first + period - 1;
725
726        let out_row =
727            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
728
729        for i in warm_end..cols {
730            let s_a = pref_a[i + 1] - pref_a[i + 1 - period];
731            let s_b = pref_b[i + 1] - pref_b[i + 1 - period];
732            let wsum = s_b - ((i + 1 - period) as f64 - 1.0) * s_a;
733            out_row[i] = wsum * inv_div;
734        }
735    };
736
737    if parallel {
738        #[cfg(not(target_arch = "wasm32"))]
739        {
740            out_uninit
741                .par_chunks_mut(cols)
742                .enumerate()
743                .for_each(|(row, slice)| do_row(row, slice));
744        }
745
746        #[cfg(target_arch = "wasm32")]
747        {
748            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
749                do_row(row, slice);
750            }
751        }
752    } else {
753        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
754            do_row(row, slice);
755        }
756    }
757
758    Ok(combos)
759}
760
761#[inline(always)]
762unsafe fn wma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
763    wma_scalar(data, period, first, out)
764}
765
766#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
767#[inline(always)]
768unsafe fn wma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
769    wma_row_scalar(data, first, period, out)
770}
771
772#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
773#[inline(always)]
774pub unsafe fn wma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
775    if period <= 32 {
776        wma_row_avx512_short(data, first, period, out);
777    } else {
778        wma_row_avx512_long(data, first, period, out);
779    }
780}
781
782#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
783#[inline(always)]
784unsafe fn wma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
785    wma_row_scalar(data, first, period, out)
786}
787
788#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
789#[inline(always)]
790unsafe fn wma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
791    wma_row_scalar(data, first, period, out)
792}
793
794#[cfg(test)]
795mod tests {
796    use super::*;
797    use crate::skip_if_unsupported;
798    use crate::utilities::data_loader::read_candles_from_csv;
799    use paste::paste;
800
801    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
802    #[test]
803    fn test_wma_into_matches_api() -> Result<(), Box<dyn Error>> {
804        let mut data = Vec::with_capacity(256);
805
806        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN, f64::NAN]);
807        for i in 0..252u32 {
808            let v = 0.5 * (i as f64) + ((i % 7) as f64);
809            data.push(v);
810        }
811
812        let params = WmaParams { period: Some(30) };
813        let input = WmaInput::from_slice(&data, params);
814
815        let baseline = wma(&input)?.values;
816
817        let mut out = vec![0.0; data.len()];
818        wma_into(&input, &mut out)?;
819
820        assert_eq!(baseline.len(), out.len());
821        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
822            let equal = (a.is_nan() && b.is_nan()) || (a == b);
823            assert!(equal, "Mismatch at index {}: api={} into={}", i, a, b);
824        }
825
826        Ok(())
827    }
828
829    fn check_wma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
830        skip_if_unsupported!(kernel, test_name);
831        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
832        let candles = read_candles_from_csv(file_path)?;
833
834        let default_params = WmaParams { period: None };
835        let input = WmaInput::from_candles(&candles, "close", default_params);
836        let output = wma_with_kernel(&input, kernel)?;
837        assert_eq!(output.values.len(), candles.close.len());
838
839        let params_period_14 = WmaParams { period: Some(14) };
840        let input2 = WmaInput::from_candles(&candles, "hl2", params_period_14);
841        let output2 = wma_with_kernel(&input2, kernel)?;
842        assert_eq!(output2.values.len(), candles.close.len());
843
844        let params_custom = WmaParams { period: Some(20) };
845        let input3 = WmaInput::from_candles(&candles, "hlc3", params_custom);
846        let output3 = wma_with_kernel(&input3, kernel)?;
847        assert_eq!(output3.values.len(), candles.close.len());
848        Ok(())
849    }
850
851    fn check_wma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
852        skip_if_unsupported!(kernel, test_name);
853        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
854        let candles = read_candles_from_csv(file_path)?;
855        let data = &candles.close;
856        let default_params = WmaParams::default();
857        let input = WmaInput::from_candles(&candles, "close", default_params);
858        let result = wma_with_kernel(&input, kernel)?;
859
860        let expected_last_five = [
861            59638.52903225806,
862            59563.7376344086,
863            59489.4064516129,
864            59432.02580645162,
865            59350.58279569892,
866        ];
867        assert!(result.values.len() >= 5, "Not enough WMA values");
868        assert_eq!(
869            result.values.len(),
870            data.len(),
871            "WMA output length should match input length"
872        );
873        let start_index = result.values.len().saturating_sub(5);
874        let last_five = &result.values[start_index..];
875        for (i, &value) in last_five.iter().enumerate() {
876            assert!(
877                (value - expected_last_five[i]).abs() < 1e-6,
878                "WMA value mismatch at index {}: expected {}, got {}",
879                i,
880                expected_last_five[i],
881                value
882            );
883        }
884        let period = input.params.period.unwrap_or(30);
885        for val in result.values.iter().skip(period - 1) {
886            if !val.is_nan() {
887                assert!(val.is_finite(), "WMA output should be finite");
888            }
889        }
890        Ok(())
891    }
892
893    fn check_wma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
894        skip_if_unsupported!(kernel, test_name);
895        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
896        let candles = read_candles_from_csv(file_path)?;
897        let input = WmaInput::with_default_candles(&candles);
898        match input.data {
899            WmaData::Candles { source, .. } => assert_eq!(source, "close"),
900            _ => panic!("Expected WmaData::Candles"),
901        }
902        let output = wma_with_kernel(&input, kernel)?;
903        assert_eq!(output.values.len(), candles.close.len());
904        Ok(())
905    }
906
907    fn check_wma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
908        skip_if_unsupported!(kernel, test_name);
909        let input_data = [10.0, 20.0, 30.0];
910        let params = WmaParams { period: Some(0) };
911        let input = WmaInput::from_slice(&input_data, params);
912        let res = wma_with_kernel(&input, kernel);
913        assert!(
914            res.is_err(),
915            "[{}] WMA should fail with zero period",
916            test_name
917        );
918        Ok(())
919    }
920
921    fn check_wma_period_exceeds_length(
922        test_name: &str,
923        kernel: Kernel,
924    ) -> Result<(), Box<dyn Error>> {
925        skip_if_unsupported!(kernel, test_name);
926        let data_small = [10.0, 20.0, 30.0];
927        let params = WmaParams { period: Some(10) };
928        let input = WmaInput::from_slice(&data_small, params);
929        let res = wma_with_kernel(&input, kernel);
930        assert!(
931            res.is_err(),
932            "[{}] WMA should fail with period exceeding length",
933            test_name
934        );
935        Ok(())
936    }
937
938    fn check_wma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
939        skip_if_unsupported!(kernel, test_name);
940        let single_point = [42.0];
941        let params = WmaParams { period: Some(9) };
942        let input = WmaInput::from_slice(&single_point, params);
943        let res = wma_with_kernel(&input, kernel);
944        assert!(
945            res.is_err(),
946            "[{}] WMA should fail with insufficient data",
947            test_name
948        );
949        Ok(())
950    }
951
952    fn check_wma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
953        skip_if_unsupported!(kernel, test_name);
954        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
955        let candles = read_candles_from_csv(file_path)?;
956        let first_params = WmaParams { period: Some(14) };
957        let first_input = WmaInput::from_candles(&candles, "close", first_params);
958        let first_result = wma_with_kernel(&first_input, kernel)?;
959        let second_params = WmaParams { period: Some(5) };
960        let second_input = WmaInput::from_slice(&first_result.values, second_params);
961        let second_result = wma_with_kernel(&second_input, kernel)?;
962        assert_eq!(second_result.values.len(), first_result.values.len());
963        for val in &second_result.values[50..] {
964            assert!(!val.is_nan());
965        }
966        Ok(())
967    }
968
969    fn check_wma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
970        skip_if_unsupported!(kernel, test_name);
971        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
972        let candles = read_candles_from_csv(file_path)?;
973        let params = WmaParams { period: Some(14) };
974        let input = WmaInput::from_candles(&candles, "close", params);
975        let result = wma_with_kernel(&input, kernel)?;
976        assert_eq!(result.values.len(), candles.close.len());
977        if result.values.len() > 50 {
978            for i in 50..result.values.len() {
979                assert!(!result.values[i].is_nan());
980            }
981        }
982        Ok(())
983    }
984
985    fn check_wma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
986        skip_if_unsupported!(kernel, test_name);
987        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
988        let candles = read_candles_from_csv(file_path)?;
989        let period = 30;
990        let input = WmaInput::from_candles(
991            &candles,
992            "close",
993            WmaParams {
994                period: Some(period),
995            },
996        );
997        let batch_output = wma_with_kernel(&input, kernel)?.values;
998
999        let mut stream = WmaStream::try_new(WmaParams {
1000            period: Some(period),
1001        })?;
1002        let mut stream_values = Vec::with_capacity(candles.close.len());
1003        for &price in &candles.close {
1004            match stream.update(price) {
1005                Some(val) => stream_values.push(val),
1006                None => stream_values.push(f64::NAN),
1007            }
1008        }
1009        assert_eq!(batch_output.len(), stream_values.len());
1010        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1011            if b.is_nan() && s.is_nan() {
1012                continue;
1013            }
1014            let diff = (b - s).abs();
1015            assert!(
1016                diff < 1e-8,
1017                "[{}] WMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1018                test_name,
1019                i,
1020                b,
1021                s,
1022                diff
1023            );
1024        }
1025        Ok(())
1026    }
1027
1028    #[cfg(debug_assertions)]
1029    fn check_wma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1030        skip_if_unsupported!(kernel, test_name);
1031
1032        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1033        let candles = read_candles_from_csv(file_path)?;
1034
1035        let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1036
1037        for &period in &test_periods {
1038            if period > candles.close.len() {
1039                continue;
1040            }
1041
1042            let input = WmaInput::from_candles(
1043                &candles,
1044                "close",
1045                WmaParams {
1046                    period: Some(period),
1047                },
1048            );
1049            let output = wma_with_kernel(&input, kernel)?;
1050
1051            for (i, &val) in output.values.iter().enumerate() {
1052                if val.is_nan() {
1053                    continue;
1054                }
1055
1056                let bits = val.to_bits();
1057
1058                if bits == 0x11111111_11111111 {
1059                    panic!(
1060						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1061						test_name, val, bits, i, period
1062					);
1063                }
1064
1065                if bits == 0x22222222_22222222 {
1066                    panic!(
1067						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1068						test_name, val, bits, i, period
1069					);
1070                }
1071
1072                if bits == 0x33333333_33333333 {
1073                    panic!(
1074						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1075						test_name, val, bits, i, period
1076					);
1077                }
1078            }
1079        }
1080
1081        Ok(())
1082    }
1083
1084    #[cfg(not(debug_assertions))]
1085    fn check_wma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1086        Ok(())
1087    }
1088
1089    macro_rules! generate_all_wma_tests {
1090        ($($test_fn:ident),*) => {
1091            paste! {
1092                $(
1093                    #[test]
1094                    fn [<$test_fn _scalar_f64>]() {
1095                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1096                    }
1097                )*
1098                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1099                $(
1100                    #[test]
1101                    fn [<$test_fn _avx2_f64>]() {
1102                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1103                    }
1104                    #[test]
1105                    fn [<$test_fn _avx512_f64>]() {
1106                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1107                    }
1108                )*
1109            }
1110        }
1111    }
1112
    /// Property test: random finite series must satisfy basic WMA invariants.
    ///
    /// Strategy: `period` in `2..=100` and a series of `period..400` values drawn
    /// from `[-1e6, 1e6)`. For each case the output of `kernel` is checked for:
    /// NaN warmup / finite tail, staying within each window's min/max, constant
    /// windows, agreement with the `Kernel::Scalar` reference, explicit weight
    /// formulas for small periods, a step response, the `len == period` edge
    /// case, monotonic input, and (debug builds) absence of poison patterns.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_wma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Each generated series is at least `period` long, so at least one
        // output is past warmup.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = WmaParams {
                    period: Some(period),
                };
                let input = WmaInput::from_slice(&data, params.clone());

                let WmaOutput { values: out } = wma_with_kernel(&input, kernel).unwrap();

                // Scalar output is the reference the selected kernel is compared to.
                let WmaOutput { values: ref_out } =
                    wma_with_kernel(&input, Kernel::Scalar).unwrap();

                // Data is all finite here, so `first` is 0 and the first valid
                // output index is `period - 1`.
                let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup_end = first + period - 1;

                for i in 0..warmup_end.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in warmup_end..out.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "Expected finite value after warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in warmup_end..data.len() {
                    // Window ending at `i` (inclusive) that produced `out[i]`.
                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];

                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];

                    // A weighted mean with positive weights lies within the
                    // window's range. NOTE(review): the 1e-9 tolerance is absolute
                    // while values reach ~1e6 (ulp ~1.2e-10); confirm it is not
                    // too tight for accumulated rounding in the kernels.
                    prop_assert!(
                        y >= lo - 1e-9 && y <= hi + 1e-9,
                        "WMA at index {} = {} is outside window bounds [{}, {}]",
                        i,
                        y,
                        lo,
                        hi
                    );

                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
                        prop_assert!(
                            (y - window[0]).abs() <= 1e-9,
                            "Constant input should produce constant output: {} vs {}",
                            y,
                            window[0]
                        );
                    }

                    // Kernel agreement: either within 1e-9 absolute or within
                    // 4 ULPs (bit-pattern distance) of the scalar reference.
                    let r = ref_out[i];
                    if y.is_finite() && r.is_finite() {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Period 2: weights 1 (older) and 2 (newer), divided by 3.
                if period == 2 && out.len() >= 2 {
                    let idx = warmup_end;
                    if idx < out.len() {
                        let expected = (data[idx - 1] + 2.0 * data[idx]) / 3.0;
                        prop_assert!(
                            (out[idx] - expected).abs() <= 1e-9,
                            "Period=2 calculation mismatch: {} vs expected {}",
                            out[idx],
                            expected
                        );
                    }
                }

                // Small periods: recompute the first valid output directly with
                // weights 1..=period, oldest to newest.
                if period <= 5 && warmup_end < out.len() {
                    let idx = warmup_end;
                    let window_start = idx + 1 - period;

                    let mut weighted_sum = 0.0;
                    let mut weight_sum = 0.0;
                    for (j, &val) in data[window_start..=idx].iter().enumerate() {
                        let weight = (j + 1) as f64;
                        weighted_sum += weight * val;
                        weight_sum += weight;
                    }
                    let expected = weighted_sum / weight_sum;

                    prop_assert!(
                        (out[idx] - expected).abs() <= 1e-9,
                        "Weight formula verification failed at index {}: {} vs expected {}",
                        idx,
                        out[idx],
                        expected
                    );
                }

                // Step response on a synthetic series: 10.0 before `mid`, 100.0 from `mid` on.
                // NOTE(review): the window ending at `mid + period - 1` lies
                // entirely in the 100.0 region, so this only checks that the WMA
                // reaches the new level, not that it favours recent values
                // mid-transition.
                if data.len() >= period * 2 {
                    let mid = data.len() / 2;
                    if mid > warmup_end {
                        let mut step_data = vec![10.0; data.len()];
                        for i in mid..step_data.len() {
                            step_data[i] = 100.0;
                        }

                        let step_input = WmaInput::from_slice(&step_data, params.clone());
                        let WmaOutput { values: step_out } =
                            wma_with_kernel(&step_input, kernel).unwrap();

                        if mid + period < step_out.len() {
                            let wma_after_step = step_out[mid + period - 1];
                            let distance_to_new = (wma_after_step - 100.0).abs();
                            let distance_to_old = (wma_after_step - 10.0).abs();
                            prop_assert!(
								distance_to_new < distance_to_old,
								"WMA should respond more to recent values: {} should be closer to 100 than 10",
								wma_after_step
							);
                        }
                    }
                }

                // Minimal-length input: only the last index has a full window.
                if data.len() == period {
                    let valid_count = out.iter().filter(|x| x.is_finite()).count();
                    prop_assert!(
                        valid_count == 1,
                        "With data.len() == period, should have exactly 1 valid output, got {}",
                        valid_count
                    );

                    prop_assert!(
                        out[data.len() - 1].is_finite(),
                        "Last value should be valid when data.len() == period"
                    );
                }

                // Non-decreasing input: each output difference is a positively
                // weighted sum of non-negative input differences.
                let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0] - 1e-12);
                if is_monotonic_increasing && out.len() > warmup_end + 1 {
                    for i in (warmup_end + 1)..out.len() {
                        prop_assert!(
                            out[i] >= out[i - 1] - 1e-9,
                            "Monotonic input should produce monotonic WMA: {} < {} at index {}",
                            out[i],
                            out[i - 1],
                            i
                        );
                    }
                }

                // Same poison patterns as the check_*_no_poison tests.
                #[cfg(debug_assertions)]
                {
                    for (i, &val) in out.iter().enumerate() {
                        if !val.is_nan() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Found poison value at index {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1319
    // Per-kernel `#[test]` wrappers (via `generate_all_wma_tests!`) for the
    // single-series checks.
    generate_all_wma_tests!(
        check_wma_partial_params,
        check_wma_accuracy,
        check_wma_default_candles,
        check_wma_zero_period,
        check_wma_period_exceeds_length,
        check_wma_very_small_dataset,
        check_wma_reinput,
        check_wma_nan_handling,
        check_wma_streaming,
        check_wma_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_wma_tests!(check_wma_property);
1335
1336    fn check_invalid_kernel_error(test: &str) -> Result<(), Box<dyn Error>> {
1337        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
1338        let sweep = WmaBatchRange { period: (2, 5, 1) };
1339
1340        let non_batch_kernels = vec![Kernel::Scalar, Kernel::Avx2, Kernel::Avx512];
1341        for kernel in non_batch_kernels {
1342            let result = wma_with_kernel_batch(&data, &sweep, kernel);
1343            assert!(
1344                matches!(result, Err(WmaError::InvalidKernelForBatch(_))),
1345                "[{}] Expected InvalidKernelForBatch error for {:?}, got {:?}",
1346                test,
1347                kernel,
1348                result
1349            );
1350        }
1351
1352        let batch_kernels = vec![Kernel::Auto, Kernel::ScalarBatch];
1353        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1354        let batch_kernels = vec![
1355            Kernel::Auto,
1356            Kernel::ScalarBatch,
1357            Kernel::Avx2Batch,
1358            Kernel::Avx512Batch,
1359        ];
1360
1361        for kernel in batch_kernels {
1362            let result = wma_with_kernel_batch(&data, &sweep, kernel);
1363            assert!(
1364                result.is_ok(),
1365                "[{}] Expected success for batch kernel {:?}, got error: {:?}",
1366                test,
1367                kernel,
1368                result.err()
1369            );
1370        }
1371
1372        Ok(())
1373    }
1374
1375    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1376        skip_if_unsupported!(kernel, test);
1377        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1378        let c = read_candles_from_csv(file)?;
1379
1380        let output = WmaBatchBuilder::new()
1381            .kernel(kernel)
1382            .apply_candles(&c, "close")?;
1383
1384        let def = WmaParams::default();
1385        let row = output.values_for(&def).expect("default row missing");
1386
1387        assert_eq!(row.len(), c.close.len());
1388
1389        let expected = [
1390            59638.52903225806,
1391            59563.7376344086,
1392            59489.4064516129,
1393            59432.02580645162,
1394            59350.58279569892,
1395        ];
1396        let start = row.len() - 5;
1397        for (i, &v) in row[start..].iter().enumerate() {
1398            assert!(
1399                (v - expected[i]).abs() < 1e-6,
1400                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1401            );
1402        }
1403        Ok(())
1404    }
1405
1406    #[cfg(debug_assertions)]
1407    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1408        skip_if_unsupported!(kernel, test);
1409
1410        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1411        let c = read_candles_from_csv(file)?;
1412
1413        let batch_configs = vec![
1414            (2, 10, 1),
1415            (5, 25, 5),
1416            (10, 30, 10),
1417            (20, 100, 10),
1418            (30, 150, 30),
1419            (50, 200, 50),
1420            (2, 5, 1),
1421        ];
1422
1423        for (start, end, step) in batch_configs {
1424            if start > c.close.len() {
1425                continue;
1426            }
1427
1428            let output = WmaBatchBuilder::new()
1429                .kernel(kernel)
1430                .period_range(start, end, step)
1431                .apply_candles(&c, "close")?;
1432
1433            for (idx, &val) in output.values.iter().enumerate() {
1434                if val.is_nan() {
1435                    continue;
1436                }
1437
1438                let bits = val.to_bits();
1439                let row = idx / output.cols;
1440                let col = idx % output.cols;
1441                let period = output.combos[row].period.unwrap_or(0);
1442
1443                if bits == 0x11111111_11111111 {
1444                    panic!(
1445                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1446                        test, val, bits, row, col, idx, period, start, end, step
1447                    );
1448                }
1449
1450                if bits == 0x22222222_22222222 {
1451                    panic!(
1452                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1453                        test, val, bits, row, col, idx, period, start, end, step
1454                    );
1455                }
1456
1457                if bits == 0x33333333_33333333 {
1458                    panic!(
1459                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1460                        test, val, bits, row, col, idx, period, start, end, step
1461                    );
1462                }
1463            }
1464        }
1465
1466        Ok(())
1467    }
1468
1469    #[cfg(not(debug_assertions))]
1470    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1471        Ok(())
1472    }
1473
1474    macro_rules! gen_batch_tests {
1475        ($fn_name:ident) => {
1476            paste! {
1477                #[test] fn [<$fn_name _scalar>]()      {
1478                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1479                }
1480                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1481                #[test] fn [<$fn_name _avx2>]()        {
1482                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1483                }
1484                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1485                #[test] fn [<$fn_name _avx512>]()      {
1486                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1487                }
1488                #[test] fn [<$fn_name _auto_detect>]() {
1489                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1490                }
1491            }
1492        };
1493    }
    // Per-kernel `#[test]` wrappers (via `gen_batch_tests!`) for the batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1496
1497    #[test]
1498    fn test_invalid_kernel_error() {
1499        let _ = check_invalid_kernel_error("test_invalid_kernel_error");
1500    }
1501}
1502
// Python binding `wma(data, period, kernel=None)`: WMA of a 1-D float64 array.
// `kernel` is an optional kernel name checked by `validate_kernel(_, false)`
// (the batch binding passes `true`). The computation runs inside
// `allow_threads`, and any `WmaError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "wma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn wma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = WmaInput::from_slice(prices, WmaParams { period: Some(period) });

    let computed = py.allow_threads(|| wma_with_kernel(&input, chosen).map(|out| out.values));
    let values = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(values.into_pyarray(py))
}
1530
// Python class `WmaStream`: thin wrapper over the incremental `WmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "WmaStream")]
pub struct WmaStreamPy {
    stream: WmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl WmaStreamPy {
    // Creates a stream for `period`; a construction error from
    // `WmaStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        WmaStream::try_new(WmaParams { period: Some(period) })
            .map(|stream| WmaStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes one value and returns the stream's current output, or `None`
    // while the stream is not ready (e.g. during warmup).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1554
// Python binding `wma_batch(data, period_range, kernel=None)`: WMA for every
// period in `period_range = (start, end, step)` over a 1-D float64 array.
//
// Returns a dict whose keys are inserted in this order:
//   "values"  - (rows, cols) float64 array, one row per period combination
//   "periods" - uint64 array with each row's period
//   "rows", "cols" - matrix dimensions (cols == len(data))
// `kernel` is validated with `validate_kernel(_, true)`; `Auto` resolves via
// `detect_best_batch_kernel()`, and the batch kernel is mapped to its per-row
// counterpart before `wma_batch_inner_into` runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "wma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn wma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let sweep = WmaBatchRange {
        period: period_range,
    };
    let requested = validate_kernel(kernel, true)?;

    let rows = expand_grid(&sweep).len();
    let cols = prices.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: `PyArray1::new` allocates an uninitialized array that is used only
    // as the output buffer below. NOTE(review): this relies on
    // `wma_batch_inner_into` writing every cell (including warmup prefixes) —
    // confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was created just above, so no other view aliases it.
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            let row_kernel = match batch_kernel {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                // ScalarBatch and everything else run the scalar row kernel.
                _ => Kernel::Scalar,
            };
            wma_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;

    Ok(dict)
}
1620
// Python binding (CUDA): runs the WMA period sweep over an f32 series on
// device `device_id` and returns the result as a `DeviceArrayF32Py` together
// with the CUDA context and device id. Raises `ValueError` if CUDA is
// unavailable or any CUDA step fails; the GPU work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn wma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let series = data_f32.as_slice()?;
    let sweep = WmaBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaWma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device = cuda.device_id();
        let device_arr = cuda
            .wma_batch_dev(series, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_arr, context, device))
    })?;

    Ok(DeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1651
// Python binding (CUDA): applies a single WMA period to many f32 series at once
// on device `device_id` and returns a `DeviceArrayF32Py`. The 2-D input is
// passed flat with its shape as `(cols, rows)`. NOTE(review): given the
// `_tm_` / `time_major` naming, presumably rows = time steps and
// cols = series — confirm against `CudaWma`. Raises `ValueError` if CUDA is
// unavailable or any CUDA step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn wma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = WmaParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaWma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device = cuda.device_id();
        let device_arr = cuda
            .wma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_arr, context, device))
    })?;

    Ok(DeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1686
/// Configuration object deserialized from JS by the `wma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WmaBatchConfig {
    /// `(start, end, step)` of the period sweep; copied into `WmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1692
/// Result object serialized back to JS by the `wma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WmaBatchJsOutput {
    /// Flattened batch values as returned by `wma_batch_inner`
    /// (presumably row-major, `rows * cols` long — confirm).
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<WmaParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (length of the input series).
    pub cols: usize,
}
1701
// JS binding `wma_batch(data, config)`: `config` is deserialized into
// `WmaBatchConfig`, the sweep is run with `detect_best_kernel()`, and the
// result is serialized as `WmaBatchJsOutput`. Deserialization, computation and
// serialization failures are returned as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wma_batch)]
pub fn wma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: WmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let sweep = WmaBatchRange {
        period: cfg.period_range,
    };

    let batch = match wma_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = WmaBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1723
// JS binding: returns the WMA of `data` for `period` as a new array computed
// with `Kernel::Auto`; any WMA error becomes a JS string error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = WmaInput::from_slice(data, WmaParams { period: Some(period) });
    let mut values = vec![0.0; data.len()];
    match wma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1738
// JS binding: runs the period sweep `(period_start, period_end, period_step)`
// with `Kernel::Auto` and returns only the flattened values; errors become JS
// string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    match wma_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1755
// JS binding: lists, as f64, the period of each row that the sweep
// `(period_start, period_end, period_step)` would produce, in row order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let periods: Vec<f64> = expand_grid(&range)
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(periods)
}
1776
// Allocates an uninitialized buffer with capacity for `len` f64 values and
// returns its pointer without freeing it; ownership passes to the caller. The
// matching deallocator is `wma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1785
// Releases a buffer obtained from `wma_alloc(len)`; a null pointer is ignored.
//
// The Vec is rebuilt with length 0 and capacity `len`. Deallocation only needs
// the capacity, and `Vec::from_raw_parts` requires the first `length` elements
// to be initialized, which is not guaranteed for a buffer JS may never have
// written.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer returned by `wma_alloc(len)` with
    // the same `len`, not previously freed; length 0 asserts no initialized
    // elements.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1795
// Raw-pointer JS binding: computes the WMA of `len` values at `in_ptr` into
// `len` slots at `out_ptr` using `Kernel::Auto`.
//
// Both pointers must be non-null and valid for `len` f64 elements. If the input
// and output ranges overlap at all (not only when the pointers are equal), the
// result is computed into a temporary buffer and then copied out. This avoids
// holding a shared and a mutable slice over the same memory.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte length of both the input and output ranges.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span));

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = WmaParams {
        period: Some(period),
    };
    let input = WmaInput::from_slice(data, params);

    if overlaps {
        let mut temp = vec![0.0; len];
        wma_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes;
        // `data`/`input` are not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes, and
        // the output range is disjoint from the input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        wma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1829
// Raw-pointer JS binding: runs the period sweep
// `(period_start, period_end, period_step)` over `len` values at `in_ptr` and
// writes `rows * len` values to `out_ptr`, returning `rows`.
//
// `rows * len` is checked for overflow, because `usize` is 32-bit on wasm32 and
// a wrapped size would produce a wrongly-sized output slice. If the input and
// output ranges overlap, results are computed into a temporary buffer and then
// copied, so the input is never read through memory that is being written.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).len();
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    let elem = std::mem::size_of::<f64>();
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_addr < out_addr.saturating_add(total_size.saturating_mul(elem))
        && out_addr < in_addr.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if overlaps {
        let mut temp = vec![0.0; total_size];
        wma_batch_inner_into(data, &sweep, Kernel::Auto, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` is valid for `total_size`
        // writes; `data` is not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `total_size`
        // writes, and the output range is disjoint from the input range.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total_size) };
        wma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}