Skip to main content

vector_ta/indicators/moving_averages/
cwma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::{CudaCwma, DeviceArrayF32};
5use crate::utilities::aligned_vector::AlignedVec;
6use crate::utilities::data_loader::{source_type, Candles};
7use crate::utilities::enums::Kernel;
8use crate::utilities::helpers::{
9    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
10    make_uninit_matrix,
11};
12#[cfg(feature = "python")]
13use crate::utilities::kernel_validation::validate_kernel;
14#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
15use core::arch::x86_64::*;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(feature = "python")]
19use numpy::{IntoPyArray, PyArray1};
20#[cfg(feature = "python")]
21use pyo3::exceptions::PyValueError;
22#[cfg(feature = "python")]
23use pyo3::prelude::*;
24#[cfg(feature = "python")]
25use pyo3::types::PyDict;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30use std::convert::AsRef;
31use std::mem::MaybeUninit;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34use thiserror::Error;
35#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
36use wasm_bindgen::prelude::*;
37
/// Returns `x` raised to the third power, evaluated as `(x * x) * x`.
#[inline(always)]
fn cube(x: f64) -> f64 {
    let squared = x * x;
    squared * x
}
42
43impl<'a> AsRef<[f64]> for CwmaInput<'a> {
44    #[inline(always)]
45    fn as_ref(&self) -> &[f64] {
46        match &self.data {
47            CwmaData::Slice(slice) => slice,
48            CwmaData::Candles { candles, source } => source_type(candles, source),
49        }
50    }
51}
52
53#[derive(Debug, Clone)]
54pub enum CwmaData<'a> {
55    Candles {
56        candles: &'a Candles,
57        source: &'a str,
58    },
59    Slice(&'a [f64]),
60}
61
/// Result of a CWMA computation.
#[derive(Clone, Debug)]
pub struct CwmaOutput {
    /// Output series produced by the CWMA entry points.
    pub values: Vec<f64>,
}
66
/// Tunable parameters of the CWMA.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CwmaParams {
    /// Window period; consumers in this file fall back to 14 when it is `None`.
    pub period: Option<usize>,
}

impl Default for CwmaParams {
    /// Default parameters: a period of 14.
    fn default() -> Self {
        let period = Some(14);
        Self { period }
    }
}
81
82#[derive(Debug, Clone)]
83pub struct CwmaInput<'a> {
84    pub data: CwmaData<'a>,
85    pub params: CwmaParams,
86}
87
88impl<'a> CwmaInput<'a> {
89    #[inline]
90    pub fn from_candles(c: &'a Candles, s: &'a str, p: CwmaParams) -> Self {
91        Self {
92            data: CwmaData::Candles {
93                candles: c,
94                source: s,
95            },
96            params: p,
97        }
98    }
99    #[inline]
100    pub fn from_slice(sl: &'a [f64], p: CwmaParams) -> Self {
101        Self {
102            data: CwmaData::Slice(sl),
103            params: p,
104        }
105    }
106    #[inline]
107    pub fn with_default_candles(c: &'a Candles) -> Self {
108        Self::from_candles(c, "close", CwmaParams::default())
109    }
110    #[inline]
111    pub fn get_period(&self) -> usize {
112        self.params.period.unwrap_or(14)
113    }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub struct CwmaBuilder {
118    period: Option<usize>,
119    kernel: Kernel,
120}
121
122impl Default for CwmaBuilder {
123    fn default() -> Self {
124        Self {
125            period: None,
126            kernel: Kernel::Auto,
127        }
128    }
129}
130
131impl CwmaBuilder {
132    #[inline(always)]
133    pub fn new() -> Self {
134        Self::default()
135    }
136    #[inline(always)]
137    pub fn period(mut self, n: usize) -> Self {
138        self.period = Some(n);
139        self
140    }
141    #[inline(always)]
142    pub fn kernel(mut self, k: Kernel) -> Self {
143        self.kernel = k;
144        self
145    }
146
147    #[inline(always)]
148    pub fn apply(self, c: &Candles) -> Result<CwmaOutput, CwmaError> {
149        let p = CwmaParams {
150            period: self.period,
151        };
152        let i = CwmaInput::from_candles(c, "close", p);
153        cwma_with_kernel(&i, self.kernel)
154    }
155
156    #[inline(always)]
157    pub fn apply_slice(self, d: &[f64]) -> Result<CwmaOutput, CwmaError> {
158        let p = CwmaParams {
159            period: self.period,
160        };
161        let i = CwmaInput::from_slice(d, p);
162        cwma_with_kernel(&i, self.kernel)
163    }
164
165    #[inline(always)]
166    pub fn into_stream(self) -> Result<CwmaStream, CwmaError> {
167        let p = CwmaParams {
168            period: self.period,
169        };
170        CwmaStream::try_new(p)
171    }
172}
173
174#[derive(Debug, Error)]
175pub enum CwmaError {
176    #[error("cwma: Input data slice is empty.")]
177    EmptyInputData,
178    #[error("cwma: All values are NaN.")]
179    AllValuesNaN,
180    #[error("cwma: Invalid period specified for CWMA calculation: period = {period}, data length = {data_len}")]
181    InvalidPeriod { period: usize, data_len: usize },
182    #[error(
183        "cwma: Not enough valid data points to compute CWMA: needed = {needed}, valid = {valid}"
184    )]
185    NotEnoughValidData { needed: usize, valid: usize },
186    #[error("cwma: output length mismatch: expected = {expected}, got = {got}")]
187    OutputLengthMismatch { expected: usize, got: usize },
188    #[error("cwma: invalid sweep range: start={start}, end={end}, step={step}")]
189    InvalidRange {
190        start: usize,
191        end: usize,
192        step: usize,
193    },
194    #[error("cwma: invalid kernel for batch API: {0:?}")]
195    InvalidKernelForBatch(Kernel),
196    #[error("cwma: size overflow while computing {ctx}")]
197    SizeOverflow { ctx: &'static str },
198}
199
200#[inline]
201pub fn cwma(input: &CwmaInput) -> Result<CwmaOutput, CwmaError> {
202    cwma_with_kernel(input, Kernel::Auto)
203}
204
205#[inline(always)]
206fn cwma_prepare<'a>(
207    input: &'a CwmaInput,
208    kernel: Kernel,
209) -> Result<(&'a [f64], Vec<f64>, usize, usize, f64, usize, Kernel), CwmaError> {
210    let data: &[f64] = match &input.data {
211        CwmaData::Candles { candles, source } => source_type(candles, source),
212        CwmaData::Slice(sl) => sl,
213    };
214    let len = data.len();
215    if len == 0 {
216        return Err(CwmaError::EmptyInputData);
217    }
218    let first = data
219        .iter()
220        .position(|x| !x.is_nan())
221        .ok_or(CwmaError::AllValuesNaN)?;
222
223    let period = input.get_period();
224
225    if period == 0 || period > len {
226        return Err(CwmaError::InvalidPeriod {
227            period,
228            data_len: len,
229        });
230    }
231
232    if period == 1 {
233        return Err(CwmaError::InvalidPeriod {
234            period,
235            data_len: len,
236        });
237    }
238    if (len - first) < period {
239        return Err(CwmaError::NotEnoughValidData {
240            needed: period,
241            valid: len - first,
242        });
243    }
244
245    let mut weights = Vec::with_capacity(period - 1);
246    let mut norm = 0.0;
247    for i in 0..period - 1 {
248        let w = cube((period - i) as f64);
249        weights.push(w);
250        norm += w;
251    }
252    let inv_norm = 1.0 / norm;
253
254    let warm = first + period - 1;
255
256    let chosen = match kernel {
257        Kernel::Auto => detect_best_kernel(),
258        k => k,
259    };
260
261    Ok((data, weights, period, first, inv_norm, warm, chosen))
262}
263
264#[inline(always)]
265fn cwma_compute_into(
266    data: &[f64],
267    weights: &[f64],
268    period: usize,
269    first: usize,
270    inv_norm: f64,
271    chosen: Kernel,
272    out: &mut [f64],
273) {
274    unsafe {
275        match chosen {
276            Kernel::Scalar | Kernel::ScalarBatch => {
277                cwma_scalar(data, weights, period, first, inv_norm, out)
278            }
279            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
280            Kernel::Avx2 | Kernel::Avx2Batch => {
281                cwma_avx2(data, weights, period, first, inv_norm, out)
282            }
283            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284            Kernel::Avx512 | Kernel::Avx512Batch => {
285                cwma_avx512(data, weights, period, first, inv_norm, out)
286            }
287            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
288            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
289                cwma_scalar(data, weights, period, first, inv_norm, out)
290            }
291            _ => unreachable!(),
292        }
293    }
294}
295
296#[inline]
297pub fn cwma_into_slice(dst: &mut [f64], input: &CwmaInput, kern: Kernel) -> Result<(), CwmaError> {
298    let (data, weights, period, first, inv_norm, warm, chosen) = cwma_prepare(input, kern)?;
299
300    if dst.len() != data.len() {
301        return Err(CwmaError::OutputLengthMismatch {
302            expected: data.len(),
303            got: dst.len(),
304        });
305    }
306
307    cwma_compute_into(data, &weights, period, first, inv_norm, chosen, dst);
308
309    for v in &mut dst[..warm] {
310        *v = f64::NAN;
311    }
312
313    Ok(())
314}
315
316pub fn cwma_with_kernel(input: &CwmaInput, kernel: Kernel) -> Result<CwmaOutput, CwmaError> {
317    let (data, weights, period, first, inv_norm, warm, chosen) = cwma_prepare(input, kernel)?;
318    let len = data.len();
319    let mut out = alloc_with_nan_prefix(len, warm);
320    cwma_compute_into(data, &weights, period, first, inv_norm, chosen, &mut out);
321    Ok(CwmaOutput { values: out })
322}
323
324#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
325#[inline]
326pub fn cwma_into(input: &CwmaInput, out: &mut [f64]) -> Result<(), CwmaError> {
327    cwma_into_slice(out, input, Kernel::Auto)
328}
329
/// Scalar CWMA kernel (x86 / x86_64 build, fused multiply-add).
///
/// For every index `i >= first_val + weights.len()` writes
/// `out[i] = inv_norm * sum_k data[i - k] * weights[k]`. Taps are processed in
/// groups of eight with independent accumulators, then the remaining taps are
/// folded in one by one. Entries of `out` before the first computed index are
/// left untouched.
///
/// # Safety
/// All accesses are unchecked. The caller must guarantee that `out` is at
/// least as long as `data`.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub unsafe fn cwma_scalar(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;

    let taps = weights.len();
    let w = weights.as_ptr();
    let src = data.as_ptr();
    // Number of taps covered by whole groups of eight.
    let grouped = taps - taps % LANES;
    let start = first_val + taps;

    for i in start..data.len() {
        let mut acc = [0.0f64; LANES];

        // Lane `j` accumulates taps j, j + 8, j + 16, ...
        let mut k = 0usize;
        while k < grouped {
            for lane in 0..LANES {
                let tap = k + lane;
                acc[lane] = (*src.add(i - tap)).mul_add(*w.add(tap), acc[lane]);
            }
            k += LANES;
        }

        let mut sum = (acc[0] + acc[1]) + (acc[2] + acc[3]) + (acc[4] + acc[5]) + (acc[6] + acc[7]);

        // Leftover taps (fewer than eight), newest to oldest.
        for tap in grouped..taps {
            sum = (*src.add(i - tap)).mul_add(*w.add(tap), sum);
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
450
/// Scalar CWMA kernel (non-x86 build, separate multiply and add).
///
/// For every index `i >= first_val + weights.len()` writes
/// `out[i] = inv_norm * sum_k data[i - k] * weights[k]`. Taps are processed in
/// groups of eight with independent accumulators, then the remaining taps are
/// folded in one by one. Entries of `out` before the first computed index are
/// left untouched.
///
/// # Safety
/// All accesses are unchecked. The caller must guarantee that `out` is at
/// least as long as `data`.
#[inline]
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub unsafe fn cwma_scalar(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;

    let taps = weights.len();
    let w = weights.as_ptr();
    let src = data.as_ptr();
    // Number of taps covered by whole groups of eight.
    let grouped = taps - taps % LANES;
    let start = first_val + taps;

    for i in start..data.len() {
        let mut acc = [0.0f64; LANES];

        // Lane `j` accumulates taps j, j + 8, j + 16, ...
        let mut k = 0usize;
        while k < grouped {
            for lane in 0..LANES {
                let tap = k + lane;
                acc[lane] += *src.add(i - tap) * *w.add(tap);
            }
            k += LANES;
        }

        let mut sum = (acc[0] + acc[1]) + (acc[2] + acc[3]) + (acc[4] + acc[5]) + (acc[6] + acc[7]);

        // Leftover taps (fewer than eight), newest to oldest.
        for tap in grouped..taps {
            sum += *src.add(i - tap) * *w.add(tap);
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
571
/// AVX2 + FMA CWMA kernel.
///
/// For every index `i >= first_valid + weights.len()` writes
/// `out[i] = inv_norm * sum_k data[i - k] * weights[k]`. Weights are consumed
/// four at a time against a lane-reversed load of the matching samples; the
/// remaining taps (fewer than four) are handled with scalar FMAs.
///
/// # Safety
/// The CPU must support AVX2 and FMA. All accesses are unchecked: the caller
/// must guarantee that `out` is at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn cwma_avx2(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const LANES: usize = 4;

    let taps = weights.len();
    // Number of taps covered by whole 4-wide vectors.
    let vec_taps = (taps / LANES) * LANES;
    let w_ptr = weights.as_ptr();
    let d_ptr = data.as_ptr();

    for i in (first_valid + taps)..data.len() {
        let mut acc = _mm256_setzero_pd();

        let mut off = 0usize;
        while off < vec_taps {
            let w = _mm256_loadu_pd(w_ptr.add(off));
            // Load samples i-off-3 ..= i-off, then reverse so lane 0 holds the
            // newest sample and lines up with weights[off].
            let oldest_first = _mm256_loadu_pd(d_ptr.add(i - off - (LANES - 1)));
            let newest_first = _mm256_permute4x64_pd(oldest_first, 0b00011011);
            acc = _mm256_fmadd_pd(newest_first, w, acc);
            off += LANES;
        }

        let mut tail_sum = 0.0;
        for tap in vec_taps..taps {
            let w = *weights.get_unchecked(tap);
            let d = *data.get_unchecked(i - tap);
            tail_sum = d.mul_add(w, tail_sum);
        }

        // Horizontal add of the four lanes.
        let upper = _mm256_extractf128_pd(acc, 1);
        let lower = _mm256_castpd256_pd128(acc);
        let pair = _mm_add_pd(upper, lower);
        let folded = _mm_add_pd(pair, _mm_unpackhi_pd(pair, pair));
        let sum = _mm_cvtsd_f64(folded) + tail_sum;

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
/// AVX-512 entry point for the CWMA kernels.
///
/// Windows with fewer than 24 weights are routed to the AVX2 kernel. Wider
/// windows use the short AVX-512 kernel when `period <= 32` and the long one
/// otherwise.
///
/// This is a safe function, so it must not execute instructions the CPU lacks.
/// The required features are therefore detected at runtime. When AVX-512F is
/// missing the AVX2 kernel is used, and when AVX2/FMA are missing too the
/// scalar kernel is used. On CPUs that support the requested features the
/// dispatch is the same as before.
///
/// NOTE(review): the kernels index without bounds checks; callers are expected
/// to pass an `out` at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cwma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    let has_fma = std::arch::is_x86_feature_detected!("fma");
    let has_avx2 = has_fma && std::arch::is_x86_feature_detected!("avx2");
    let has_avx512 = has_fma && std::arch::is_x86_feature_detected!("avx512f");

    // Below 24 weights the 512-bit kernels are not used at all.
    let wide_window = weights.len() >= 24;

    if wide_window && has_avx512 {
        if period <= 32 {
            // SAFETY: avx512f and fma were detected on the running CPU.
            unsafe { cwma_avx512_short(data, weights, period, first_valid, inv_norm, out) }
        } else {
            // SAFETY: avx512f and fma were detected on the running CPU.
            unsafe { cwma_avx512_long(data, weights, period, first_valid, inv_norm, out) }
        }
    } else if has_avx2 {
        // SAFETY: avx2 and fma were detected on the running CPU.
        unsafe { cwma_avx2(data, weights, period, first_valid, inv_norm, out) }
    } else {
        // SAFETY: the scalar kernel needs no special CPU features.
        unsafe { cwma_scalar(data, weights, period, first_valid, inv_norm, out) }
    }
}
/// Reverses the eight `f64` lanes of `v` (lane 0 swaps with lane 7, and so on).
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn reverse_vec(v: __m512d) -> __m512d {
    // `_mm512_set_epi64` takes its arguments from the highest lane down, so
    // lane 0 receives index 7 and lane 7 receives index 0.
    let reversed_order = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    _mm512_permutexvar_pd(reversed_order, v)
}
/// Packs `weights` into lane-reversed 8-wide vectors.
///
/// Only whole groups of eight are converted; a trailing remainder of fewer
/// than eight weights is not included.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn make_weight_blocks(weights: &[f64]) -> Vec<__m512d> {
    const LANES: usize = 8;

    let mut blocks = Vec::with_capacity(weights.len() / LANES);
    for group in weights.chunks_exact(LANES) {
        let loaded = _mm512_loadu_pd(group.as_ptr());
        blocks.push(reverse_vec(loaded));
    }
    blocks
}
/// AVX-512 CWMA kernel for short windows (`period <= 32`).
///
/// Up to four lane-reversed weight vectors are loaded once before the main
/// loop. Each output multiplies them against straight 8-wide loads of the
/// matching samples, and the remaining taps (fewer than eight) are handled
/// with scalar FMAs.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. All accesses are unchecked: the
/// caller must guarantee that `out` is at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn cwma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    debug_assert!(period <= 32);
    debug_assert_eq!(weights.len(), period - 1);

    const LANES: usize = 8;
    const MAX_BLOCKS: usize = 4;

    let taps = weights.len();
    let blocks = taps / LANES;
    let vec_taps = blocks * LANES;
    let start = first_valid + taps;

    // Blocks 0..=2 are used whenever present; the fourth block is used only
    // when there are exactly four.
    let active = if blocks == MAX_BLOCKS {
        MAX_BLOCKS
    } else {
        blocks.min(MAX_BLOCKS - 1)
    };

    let mut wv: [__m512d; MAX_BLOCKS] = [_mm512_setzero_pd(); MAX_BLOCKS];
    for (b, slot) in wv.iter_mut().enumerate().take(active) {
        *slot = reverse_vec(_mm512_loadu_pd(weights.as_ptr().add(b * LANES)));
    }

    for i in start..data.len() {
        let mut acc = _mm512_setzero_pd();

        for b in 0..active {
            // Distance from `i` back to the oldest sample of block `b`.
            let back = b * LANES + (LANES - 1);
            if i >= back {
                let d = _mm512_loadu_pd(data.as_ptr().add(i - back));
                acc = _mm512_fmadd_pd(d, wv[b], acc);
            }
        }

        let mut sum = _mm512_reduce_add_pd(acc);

        for tap in vec_taps..taps {
            let w = *weights.get_unchecked(tap);
            let d = *data.get_unchecked(i - tap);
            sum = d.mul_add(w, sum);
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
/// AVX-512 CWMA kernel for long windows.
///
/// Weight vectors are prepared once by `make_weight_blocks`. Blocks are
/// processed in pairs with two independent accumulators, then any odd block,
/// then the remaining taps (fewer than eight) with scalar FMAs. Windows with
/// fewer than 24 weights are handed to the AVX2 kernel.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. All accesses are unchecked: the
/// caller must guarantee that `out` is at least as long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn cwma_avx512_long(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;

    let taps = weights.len();
    if taps < 24 {
        cwma_avx2(data, weights, _period, first_valid, inv_norm, out);
        return;
    }

    let blocks = taps / LANES;
    let vec_taps = blocks * LANES;
    // Largest even block count not exceeding `blocks`.
    let paired_blocks = blocks & !1;
    let start = first_valid + taps;

    let wblocks = make_weight_blocks(weights);

    for i in start..data.len() {
        let mut acc_a = _mm512_setzero_pd();
        let mut acc_b = _mm512_setzero_pd();

        let mut b = 0;
        while b < paired_blocks {
            let back_a = b * LANES + (LANES - 1);
            let back_b = (b + 1) * LANES + (LANES - 1);
            // The pair is only processed when both loads stay inside `data`.
            if i >= back_a && i >= back_b {
                let d_a = _mm512_loadu_pd(data.as_ptr().add(i - back_a));
                let d_b = _mm512_loadu_pd(data.as_ptr().add(i - back_b));

                acc_a = _mm512_fmadd_pd(d_a, *wblocks.get_unchecked(b), acc_a);
                acc_b = _mm512_fmadd_pd(d_b, *wblocks.get_unchecked(b + 1), acc_b);
            }
            b += 2;
        }

        // Odd block left over after pairing.
        if b < blocks {
            let back = b * LANES + (LANES - 1);
            if i >= back {
                let d = _mm512_loadu_pd(data.as_ptr().add(i - back));
                acc_a = _mm512_fmadd_pd(d, *wblocks.get_unchecked(b), acc_a);
            }
        }

        let mut sum = _mm512_reduce_add_pd(_mm512_add_pd(acc_a, acc_b));

        for tap in vec_taps..taps {
            let w = *weights.get_unchecked(tap);
            let d = *data.get_unchecked(i - tap);
            sum = d.mul_add(w, sum);
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
790
/// Streaming CWMA calculator fed one value at a time.
///
/// Ranks below count from the newest sample (`r = 1`) to the oldest (`r = n`).
#[derive(Clone, Debug)]
pub struct CwmaStream {
    /// Configured period (at least 2).
    period: usize,
    /// Reciprocal of `2^3 + 3^3 + ... + period^3`.
    inv_norm: f64,

    /// Window length, `period - 1`.
    n: usize,

    /// Circular buffer holding the latest `n` accepted values.
    ring: Vec<f64>,
    /// Slot in `ring` that the next value is written to.
    head: usize,
    /// Number of accepted values so far, capped at `n + 1`.
    filled: usize,
    /// Number of NaN values tracked for the current window.
    nan_count: usize,

    /// Number of `update` calls so far.
    total_count: usize,
    /// Whether a non-NaN value has been seen yet.
    found_first: bool,
    /// Call index of the first non-NaN value.
    first_idx: usize,

    /// Sum of the window values.
    m0: f64,
    /// Sum of `r * v_r`.
    m1: f64,
    /// Sum of `r^2 * v_r`.
    m2: f64,
    /// Sum of `r^3 * v_r`.
    m3: f64,

    /// Incrementally maintained weighted sum, `sum (a - r)^3 * v_r`.
    s: f64,

    /// Weight offset, `n + 2`.
    a: f64,
    /// Weight of the newest value, `(n + 1)^3`.
    w1: f64,
    /// Weight of the oldest value, `8.0`.
    wn: f64,
    /// Coefficient `-3a^2 + 3a - 1` used in the incremental update of `s`.
    alpha0: f64,
    /// Coefficient `6a - 3` used in the incremental update of `s`.
    alpha1: f64,
    /// Coefficient `-3` used in the incremental update of `s`.
    alpha2: f64,

    /// `n` as a float.
    n_f: f64,
    /// `n^2` as a float.
    n_sq: f64,
    /// `n + 1` as a float.
    n_p1: f64,
    /// `(n + 1)^2` as a float.
    n_p1_sq: f64,

    /// Whether `m0..m3` and `s` currently reflect the ring contents.
    moments_ready: bool,
}
828
829impl CwmaStream {
830    pub fn try_new(params: CwmaParams) -> Result<Self, CwmaError> {
831        let period = params.period.unwrap_or(14);
832        if period <= 1 {
833            return Err(CwmaError::InvalidPeriod {
834                period,
835                data_len: 0,
836            });
837        }
838
839        let n = period - 1;
840
841        let mut norm = 0.0;
842        for j in 2..=period {
843            let jf = j as f64;
844            norm += jf * jf * jf;
845        }
846        let inv_norm = 1.0 / norm;
847
848        let n_f = n as f64;
849        let n_p1 = (n + 1) as f64;
850        let n_p1_sq = n_p1 * n_p1;
851        let a = (n + 2) as f64;
852
853        let w1 = n_p1 * n_p1 * n_p1;
854        let wn = 8.0;
855
856        let alpha0 = -3.0 * a * a + 3.0 * a - 1.0;
857        let alpha1 = 6.0 * a - 3.0;
858        let alpha2 = -3.0;
859
860        Ok(Self {
861            period,
862            inv_norm,
863            n,
864            ring: vec![f64::NAN; n.max(1)],
865            head: 0,
866            filled: 0,
867            nan_count: 0,
868            total_count: 0,
869            found_first: false,
870            first_idx: 0,
871
872            m0: 0.0,
873            m1: 0.0,
874            m2: 0.0,
875            m3: 0.0,
876            s: 0.0,
877
878            a,
879            w1,
880            wn,
881            alpha0,
882            alpha1,
883            alpha2,
884            n_f,
885            n_sq: n_f * n_f,
886            n_p1,
887            n_p1_sq,
888            moments_ready: false,
889        })
890    }
891
892    #[inline(always)]
893    pub fn update(&mut self, value: f64) -> Option<f64> {
894        let idx = self.total_count;
895        self.total_count = idx + 1;
896
897        if !self.found_first {
898            if value.is_nan() {
899                return None;
900            } else {
901                self.found_first = true;
902                self.first_idx = idx;
903            }
904        }
905
906        let mut old = f64::NAN;
907        if self.filled >= self.n {
908            old = self.ring[self.head];
909        }
910
911        let new_nan = value.is_nan() as usize;
912        let old_nan = (self.filled >= self.n && old.is_nan()) as usize;
913
914        if self.n > 0 {
915            self.ring[self.head] = value;
916            self.head = (self.head + 1) % self.n;
917        }
918
919        if self.filled <= self.n {
920            self.filled += 1;
921            self.nan_count += new_nan;
922
923            if self.filled == self.n + 1 {
924                self.nan_count -= old_nan;
925            }
926
927            if self.filled <= self.n {
928                return None;
929            }
930
931            if self.nan_count > 0 {
932                self.moments_ready = false;
933                return Some(f64::NAN);
934            }
935
936            self.rebuild_moments_and_sum();
937            self.moments_ready = true;
938            return Some(self.sum_weighted() * self.inv_norm);
939        }
940
941        self.nan_count = self.nan_count + new_nan - old_nan;
942
943        if self.nan_count > 0 {
944            self.moments_ready = false;
945            return Some(f64::NAN);
946        }
947
948        if !self.moments_ready {
949            self.rebuild_moments_and_sum();
950            self.moments_ready = true;
951            return Some(self.sum_weighted() * self.inv_norm);
952        }
953
954        let m0_prev = self.m0;
955        let m1_prev = self.m1;
956        let m2_prev = self.m2;
957        let m3_prev = self.m3;
958
959        let newv = value;
960        let oldv = old;
961
962        self.m0 = m0_prev + newv - oldv;
963        self.m1 = (-self.n_p1).mul_add(oldv, m1_prev + m0_prev + newv);
964        let tmp2 = m1_prev.mul_add(2.0, m2_prev + m0_prev + newv);
965        self.m2 = (-self.n_p1_sq).mul_add(oldv, tmp2);
966        let np13 = self.n_p1 * self.n_p1 * self.n_p1;
967        let tmp3 = m2_prev.mul_add(3.0, m3_prev + m0_prev + newv);
968        let tmp3 = m1_prev.mul_add(3.0, tmp3);
969        self.m3 = (-np13).mul_add(oldv, tmp3);
970
971        let mut ds = newv.mul_add(self.w1, 0.0);
972        ds = oldv.mul_add(-self.wn, ds);
973        let t1 = self.alpha0.mul_add(m0_prev - oldv, ds);
974        let u1 = (-self.n_f).mul_add(oldv, m1_prev);
975        let t2 = self.alpha1.mul_add(u1, t1);
976        let u2 = (-self.n_sq).mul_add(oldv, m2_prev);
977        let delta_s = self.alpha2.mul_add(u2, t2);
978        self.s += delta_s;
979
980        Some(self.sum_weighted() * self.inv_norm)
981    }
982
983    #[inline(always)]
984    fn rebuild_moments_and_sum(&mut self) {
985        debug_assert!(self.nan_count == 0, "rebuild called with NaNs present");
986        let mut m0 = 0.0;
987        let mut m1 = 0.0;
988        let mut m2 = 0.0;
989        let mut m3 = 0.0;
990        let mut s = 0.0;
991
992        let a = self.a;
993        for r in 1..=self.n {
994            let idx = (self.head + self.n - r) % self.n;
995            let v = self.ring[idx];
996            let rf = r as f64;
997
998            m0 += v;
999            m1 += rf * v;
1000            m2 += (rf * rf) * v;
1001            m3 += (rf * rf * rf) * v;
1002
1003            let w = {
1004                let t = a - rf;
1005                t * t * t
1006            };
1007            s = v.mul_add(w, s);
1008        }
1009
1010        self.m0 = m0;
1011        self.m1 = m1;
1012        self.m2 = m2;
1013        self.m3 = m3;
1014        self.s = s;
1015    }
1016
1017    #[inline(always)]
1018    fn sum_weighted(&self) -> f64 {
1019        let mut s = 0.0;
1020        let a = self.a;
1021        if self.n == 0 {
1022            return 0.0;
1023        }
1024        for r in 1..=self.n {
1025            let idx = (self.head + self.n - r) % self.n;
1026            let v = self.ring[idx];
1027            let rf = r as f64;
1028            let t = a - rf;
1029            let w = t * t * t;
1030            s = v.mul_add(w, s);
1031        }
1032        s
1033    }
1034}
1035
/// Parameter sweep description for a batch CWMA run.
#[derive(Clone, Debug)]
pub struct CwmaBatchRange {
    /// `(start, end, step)` for the period axis. The bounds are expanded
    /// inclusively; a `step` of 0 (or `start == end`) yields the single
    /// period `start`.
    pub period: (usize, usize, usize),
}
1040
1041impl Default for CwmaBatchRange {
1042    fn default() -> Self {
1043        Self {
1044            period: (14, 263, 1),
1045        }
1046    }
1047}
1048
/// Builder that collects a period sweep and a kernel choice before running a
/// batch CWMA computation.
#[derive(Clone, Debug, Default)]
pub struct CwmaBatchBuilder {
    // Period sweep forwarded to `cwma_batch_with_kernel`.
    range: CwmaBatchRange,
    // Kernel selection forwarded to `cwma_batch_with_kernel`.
    kernel: Kernel,
}
1054
1055impl CwmaBatchBuilder {
1056    pub fn new() -> Self {
1057        Self::default()
1058    }
1059    pub fn kernel(mut self, k: Kernel) -> Self {
1060        self.kernel = k;
1061        self
1062    }
1063
1064    #[inline]
1065    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1066        self.range.period = (start, end, step);
1067        self
1068    }
1069    #[inline]
1070    pub fn period_static(mut self, p: usize) -> Self {
1071        self.range.period = (p, p, 0);
1072        self
1073    }
1074
1075    pub fn apply_slice(self, data: &[f64]) -> Result<CwmaBatchOutput, CwmaError> {
1076        cwma_batch_with_kernel(data, &self.range, self.kernel)
1077    }
1078
1079    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CwmaBatchOutput, CwmaError> {
1080        CwmaBatchBuilder::new().kernel(k).apply_slice(data)
1081    }
1082
1083    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CwmaBatchOutput, CwmaError> {
1084        let slice = source_type(c, src);
1085        self.apply_slice(slice)
1086    }
1087
1088    pub fn with_default_candles(c: &Candles) -> Result<CwmaBatchOutput, CwmaError> {
1089        CwmaBatchBuilder::new()
1090            .kernel(Kernel::Auto)
1091            .apply_candles(c, "close")
1092    }
1093}
1094
1095pub fn cwma_batch_with_kernel(
1096    data: &[f64],
1097    sweep: &CwmaBatchRange,
1098    k: Kernel,
1099) -> Result<CwmaBatchOutput, CwmaError> {
1100    let kernel = match k {
1101        Kernel::Auto => detect_best_batch_kernel(),
1102        other if other.is_batch() => other,
1103        other => return Err(CwmaError::InvalidKernelForBatch(other)),
1104    };
1105
1106    let simd = match kernel {
1107        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1108        Kernel::Avx512Batch => Kernel::Avx512,
1109        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1110        Kernel::Avx2Batch => Kernel::Avx2,
1111        Kernel::ScalarBatch => Kernel::Scalar,
1112        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1113        Kernel::Avx2Batch | Kernel::Avx512Batch => Kernel::Scalar,
1114        _ => unreachable!(),
1115    };
1116    cwma_batch_par_slice(data, sweep, simd)
1117}
1118
/// Result of a batch CWMA run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct CwmaBatchOutput {
    /// Flattened matrix of results; row `r` occupies
    /// `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter set for each row, in row order.
    pub combos: Vec<CwmaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
1126
1127impl CwmaBatchOutput {
1128    pub fn row_for_params(&self, p: &CwmaParams) -> Option<usize> {
1129        self.combos
1130            .iter()
1131            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1132    }
1133    pub fn values_for(&self, p: &CwmaParams) -> Option<&[f64]> {
1134        self.row_for_params(p).map(|row| {
1135            let start = row * self.cols;
1136            &self.values[start..start + self.cols]
1137        })
1138    }
1139}
1140
1141#[inline(always)]
1142fn expand_grid(r: &CwmaBatchRange) -> Vec<CwmaParams> {
1143    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1144        if step == 0 || start == end {
1145            return vec![start];
1146        }
1147        let (lo, hi) = if start <= end {
1148            (start, end)
1149        } else {
1150            (end, start)
1151        };
1152        (lo..=hi).step_by(step).collect()
1153    }
1154
1155    let periods = axis_usize(r.period);
1156
1157    let mut out = Vec::with_capacity(periods.len());
1158    for &p in &periods {
1159        out.push(CwmaParams { period: Some(p) });
1160    }
1161    out
1162}
1163
1164#[inline(always)]
1165pub fn cwma_batch_slice(
1166    data: &[f64],
1167    sweep: &CwmaBatchRange,
1168    kern: Kernel,
1169) -> Result<CwmaBatchOutput, CwmaError> {
1170    cwma_batch_inner(data, sweep, kern, false)
1171}
1172
1173#[inline(always)]
1174pub fn cwma_batch_par_slice(
1175    data: &[f64],
1176    sweep: &CwmaBatchRange,
1177    kern: Kernel,
1178) -> Result<CwmaBatchOutput, CwmaError> {
1179    cwma_batch_inner(data, sweep, kern, true)
1180}
1181
/// Rounds `x` up to the next multiple of 8 (multiples of 8 are unchanged).
#[inline]
fn round_up8(x: usize) -> usize {
    const LANES: usize = 8;
    const LOW_BITS: usize = LANES - 1;
    (x + LOW_BITS) & !LOW_BITS
}
1186
1187#[inline(always)]
1188fn cwma_batch_inner(
1189    data: &[f64],
1190    sweep: &CwmaBatchRange,
1191    kern: Kernel,
1192    parallel: bool,
1193) -> Result<CwmaBatchOutput, CwmaError> {
1194    let combos = expand_grid(sweep);
1195    let cols = data.len();
1196    let rows = combos.len();
1197
1198    let _total = rows
1199        .checked_mul(cols)
1200        .ok_or(CwmaError::SizeOverflow { ctx: "rows*cols" })?;
1201
1202    if cols == 0 {
1203        return Err(CwmaError::EmptyInputData);
1204    }
1205
1206    let first = data
1207        .iter()
1208        .position(|x| !x.is_nan())
1209        .ok_or(CwmaError::AllValuesNaN)?;
1210
1211    let max_p = combos
1212        .iter()
1213        .map(|c| round_up8(c.period.unwrap()))
1214        .max()
1215        .unwrap();
1216
1217    if (cols - first) < max_p {
1218        return Err(CwmaError::NotEnoughValidData {
1219            needed: max_p,
1220            valid: cols - first,
1221        });
1222    }
1223
1224    let mut buf_mu = make_uninit_matrix(rows, cols);
1225
1226    let warm: Vec<usize> = combos
1227        .iter()
1228        .map(|c| first + c.period.unwrap() - 1)
1229        .collect();
1230
1231    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1232
1233    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1234    let out: &mut [f64] = unsafe {
1235        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1236    };
1237
1238    cwma_batch_inner_into(data, sweep, kern, parallel, out)?;
1239
1240    let values = unsafe {
1241        Vec::from_raw_parts(
1242            buf_guard.as_mut_ptr() as *mut f64,
1243            buf_guard.len(),
1244            buf_guard.capacity(),
1245        )
1246    };
1247
1248    Ok(CwmaBatchOutput {
1249        values,
1250        combos,
1251        rows,
1252        cols,
1253    })
1254}
1255
/// Computes every row of a batch sweep into the caller-provided `out` buffer
/// (row-major, `rows * cols` elements) and returns the expanded parameter
/// list in row order.
///
/// Only cells at index `first + period - 1` and later are written by the row
/// kernels; the caller is responsible for the warm-up prefix of each row.
///
/// # Errors
/// * `AllValuesNaN` when `data` has no non-NaN value.
/// * `SizeOverflow` when `rows * cols` or `rows * max_p` overflows `usize`.
/// * `OutputLengthMismatch` when `out.len() != rows * cols`.
///
/// NOTE(review): a period of 0 would underflow `period - 1` below; callers
/// appear to be expected to validate periods beforehand — confirm.
#[inline(always)]
fn cwma_batch_inner_into(
    data: &[f64],
    sweep: &CwmaBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<CwmaParams>, CwmaError> {
    let combos = expand_grid(sweep);

    // Index of the first non-NaN sample.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(CwmaError::AllValuesNaN)?;

    // Row stride of the weight table: the largest period rounded up to a
    // multiple of 8.
    let max_p = combos
        .iter()
        .map(|c| round_up8(c.period.unwrap()))
        .max()
        .unwrap();

    let rows = combos.len();
    let cols = data.len();
    let expected = rows
        .checked_mul(cols)
        .ok_or(CwmaError::SizeOverflow { ctx: "rows*cols" })?;
    if out.len() != expected {
        return Err(CwmaError::OutputLengthMismatch {
            expected,
            got: out.len(),
        });
    }
    let mut inv_norms = vec![0.0; rows];

    let cap = rows
        .checked_mul(max_p)
        .ok_or(CwmaError::SizeOverflow { ctx: "rows*max_p" })?;
    // NOTE(review): presumably `as_mut_slice` exposes all `cap` elements —
    // confirm against `AlignedVec`.
    let mut aligned = AlignedVec::with_capacity(cap);
    let flat_w = aligned.as_mut_slice();

    // Row `row` of the weight table holds `period - 1` cubic weights:
    // weight `i` is `(period - i)^3`. `inv_norms[row]` is the reciprocal of
    // their sum.
    for (row, prm) in combos.iter().enumerate() {
        let period = prm.period.unwrap();
        let mut norm = 0.0;
        for i in 0..period - 1 {
            let w = cube((period - i) as f64);
            flat_w[row * max_p + i] = w;
            norm += w;
        }
        let inv_norm = 1.0 / norm;
        inv_norms[row] = inv_norm;
    }

    // View the output as `MaybeUninit<f64>` so it can be chunked per row.
    let out_uninit = unsafe {
        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
    };

    // Computes one output row with the kernel selected by `kern`.
    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();
        let w_ptr = flat_w.as_ptr().add(row * max_p);
        let inv_n = *inv_norms.get_unchecked(row);

        let out_row =
            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());

        match kern {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => cwma_row_avx512(data, first, period, max_p, w_ptr, inv_n, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => cwma_row_avx2(data, first, period, max_p, w_ptr, inv_n, out_row),
            // Without AVX support compiled in, SIMD kernels fall back to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => {
                cwma_row_scalar(data, first, period, max_p, w_ptr, inv_n, out_row)
            }
            _ => cwma_row_scalar(data, first, period, max_p, w_ptr, inv_n, out_row),
        }
    };

    if parallel {
        // Rows are independent, so they can be processed in parallel.
        #[cfg(not(target_arch = "wasm32"))]
        {
            out_uninit
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // On wasm32 the parallel request degrades to a sequential loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
1356
/// Scalar row kernel: for every index `i >= first + period - 1` writes
/// `inv_n * sum(data[i - k] * w[k])` over `k in 0..period - 1` into `out[i]`.
///
/// # Safety
/// `w_ptr` must be valid for reads of `period - 1` consecutive `f64` values,
/// and `period` must be at least 1.
#[inline(always)]
unsafe fn cwma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    // The weight-table stride is not needed by the scalar kernel.
    let _ = stride;

    let taps = period - 1;
    let end = data.len().min(out.len());

    for pos in (first + taps)..end {
        let mut acc = 0.0;
        for lag in 0..taps {
            let sample = *data.get_unchecked(pos - lag);
            let weight = *w_ptr.add(lag);
            acc += sample * weight;
        }
        out[pos] = acc * inv_n;
    }
}
1379
/// Loads 8 consecutive `f64` values from `ptr` (unaligned) and returns them
/// with the lane order reversed, so lane 0 holds `ptr[7]`.
///
/// # Safety
/// `ptr` must be valid for reads of 8 `f64` values and AVX-512F must be
/// available on the running CPU.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn load_w512_rev(ptr: *const f64) -> __m512d {
    let forward = _mm512_loadu_pd(ptr);
    let lane_order = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    _mm512_permutexvar_pd(lane_order, forward)
}
1386
/// Loads 4 consecutive `f64` values from `ptr` (unaligned) and returns them
/// with the lane order reversed, so lane 0 holds `ptr[3]`.
///
/// # Safety
/// `ptr` must be valid for reads of 4 `f64` values and AVX2 must be available
/// on the running CPU.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn load_w256_rev(ptr: *const f64) -> __m256d {
    // Lane selectors 3, 2, 1, 0 packed two bits each, lowest lane first.
    const REVERSE_LANES: i32 = 0b00_01_10_11;
    let forward = _mm256_loadu_pd(ptr);
    _mm256_permute4x64_pd::<REVERSE_LANES>(forward)
}
1392
/// AVX-512 row kernel entry point: periods up to 32 go to the short-window
/// kernel, larger periods to the long-window kernel.
///
/// # Safety
/// Same contract as the kernels it dispatches to: `w_ptr` must point to the
/// row's weights and AVX-512F/FMA must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn cwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => cwma_row_avx512_short(data, first, period, w_ptr, inv_n, out),
        _ => cwma_row_avx512_long(data, first, period, w_ptr, inv_n, out),
    }
}
1411
/// AVX-512 row kernel for short windows.
///
/// For every index `i >= first + period - 1` writes
/// `inv_n * sum(data[i - k] * w[k])` over `k in 0..period - 1` into `out[i]`.
/// Full groups of 8 taps are processed with FMA on 512-bit registers; the
/// remaining `(period - 1) % 8` taps are accumulated in scalar code.
///
/// Every full chunk is accumulated. Earlier revisions handled only the first
/// two chunks, which silently dropped taps 16..23 for periods 25..=32.
///
/// # Safety
/// `w_ptr` must be valid for reads of `period - 1` `f64` values, `period`
/// must be at least 1, and AVX-512F/FMA must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn cwma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    // Number of weight registers kept resident across the whole row; covers
    // every period this kernel is dispatched for (at most 3 full chunks).
    const HOISTED: usize = 4;

    let wlen = period - 1;
    let chunks = wlen / STEP;
    let tail_len = wlen % STEP;
    let tail_base = chunks * STEP;

    // Pre-load the reversed weight chunks once; only full chunks are read, so
    // no weight beyond `wlen` is touched.
    let mut wregs = [_mm512_setzero_pd(); HOISTED];
    for (c, reg) in wregs.iter_mut().enumerate().take(chunks) {
        *reg = load_w512_rev(w_ptr.add(c * STEP));
    }

    let start_idx = first + wlen;
    for i in start_idx..data.len().min(out.len()) {
        let mut acc = _mm512_setzero_pd();

        for c in 0..chunks {
            let w = if c < HOISTED {
                wregs[c]
            } else {
                // Defensive path for windows longer than the hoisted set.
                load_w512_rev(w_ptr.add(c * STEP))
            };
            // Chunk `c` covers taps `c*8 ..= c*8 + 7`, i.e. samples
            // `data[i - c*8 - 7 ..= i - c*8]`; `i >= wlen >= chunks*8` keeps
            // the start index in range.
            let d = _mm512_loadu_pd(data.as_ptr().add(i - c * STEP - (STEP - 1)));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        let mut tail_sum = 0.0;
        for k in 0..tail_len {
            let w = *w_ptr.add(tail_base + k);
            let d = *data.get_unchecked(i - (tail_base + k));
            tail_sum = d.mul_add(w, tail_sum);
        }

        let sum = _mm512_reduce_add_pd(acc) + tail_sum;
        // `i` is below `out.len()` by the loop bound.
        *out.get_unchecked_mut(i) = sum * inv_n;
    }
}
/// AVX-512 row kernel for long windows.
///
/// For every index `i >= first + period - 1` writes
/// `inv_n * sum(data[i - k] * w[k])` over `k in 0..period - 1` into `out[i]`,
/// matching the scalar and AVX2 row kernels. The window
/// `data[i + 1 - wlen ..= i]` is walked oldest-to-newest in blocks of 8 with
/// the weights pre-reversed to line up with that order.
///
/// Two defects of the earlier revision are corrected here:
/// * the window started at `i - wlen`, so it lagged by one sample and never
///   included `data[i]`;
/// * the tail weights were reversed across all 8 lanes, which misaligned them
///   with the masked tail samples.
///
/// # Safety
/// `w_ptr` must be valid for reads of `period - 1` `f64` values, `period`
/// must be at least 1, and AVX-512F/DQ/FMA must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn cwma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    const MAX: usize = 512;

    let wlen = period - 1;
    let n_chunks = wlen / STEP;
    let tail_len = wlen % STEP;
    // Low `tail_len` bits set: selects the valid lanes of the tail load.
    let tmask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);
    let n_regs = n_chunks + (tail_len != 0) as usize;

    debug_assert!(n_regs <= MAX);

    let mut wregs = [core::mem::MaybeUninit::<__m512d>::uninit(); MAX];

    // Window offset `m` (0 = oldest sample) pairs with weight `w[wlen-1-m]`.
    // Block `blk` therefore needs `w[wlen-(blk+1)*8 .. wlen-blk*8]` reversed.
    for blk in 0..n_chunks {
        let src = w_ptr.add(wlen - (blk + 1) * STEP);
        wregs[blk] = core::mem::MaybeUninit::new(load_w512_rev(src));
    }
    if tail_len != 0 {
        // The tail covers the newest `tail_len` samples, which pair with
        // `w[tail_len-1], ..., w[0]` in lanes `0..tail_len`; unused lanes
        // stay zero.
        let mut tail_w = [0.0f64; STEP];
        for (t, slot) in tail_w.iter_mut().enumerate().take(tail_len) {
            *slot = *w_ptr.add(tail_len - 1 - t);
        }
        wregs[n_chunks] = core::mem::MaybeUninit::new(_mm512_loadu_pd(tail_w.as_ptr()));
    }
    // Only the first `n_regs` entries were initialised above.
    let wregs: &[__m512d] =
        core::slice::from_raw_parts(wregs.as_ptr() as *const __m512d, n_regs);

    let start_idx = first + wlen;
    for i in start_idx..data.len().min(out.len()) {
        // Oldest sample of the window that ends at (and includes) `data[i]`.
        let base_ptr = data.as_ptr().add(i + 1 - wlen);

        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        // Four independent accumulators hide FMA latency.
        let mut blk = 0;
        while blk + 3 < n_chunks {
            let d0 = _mm512_loadu_pd(base_ptr.add(blk * STEP));
            let d1 = _mm512_loadu_pd(base_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(base_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(base_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk), s0);
            s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
            s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
            s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);

            blk += 4;
        }

        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(base_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(r), s0);
        }

        let mut sum =
            _mm512_reduce_add_pd(_mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3)));

        if tail_len != 0 {
            // Masked load reads only up to `data[i]`, never past the window.
            let d_tail = _mm512_maskz_loadu_pd(tmask, base_ptr.add(n_chunks * STEP));
            let tail = _mm512_mul_pd(d_tail, *wregs.get_unchecked(n_chunks));
            sum += _mm512_reduce_add_pd(tail);
        }

        out[i] = sum * inv_n;
    }
}
/// AVX2 row kernel: for every index `i >= first + period - 1` writes
/// `inv_n * sum(data[i - k] * w[k])` over `k in 0..period - 1` into `out[i]`.
/// Groups of 4 taps use FMA on 256-bit registers; the remaining taps are
/// accumulated in scalar code.
///
/// # Safety
/// `w_ptr` must be valid for reads of `period - 1` `f64` values, `period`
/// must be at least 1, and AVX2/FMA must be available.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn cwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const LANES: usize = 4;

    let taps = period - 1;
    let full_blocks = taps / LANES;
    let rem = taps % LANES;
    let rem_base = full_blocks * LANES;
    let end = data.len().min(out.len());

    for pos in (first + taps)..end {
        let mut acc = _mm256_setzero_pd();

        for b in 0..full_blocks {
            let off = b * LANES;
            // Samples `data[pos - off - 3 ..= pos - off]` against the
            // reversed weights `w[off .. off + 4]`.
            let samples = _mm256_loadu_pd(data.as_ptr().add(pos - off - 3));
            let weights = load_w256_rev(w_ptr.add(off));
            acc = _mm256_fmadd_pd(samples, weights, acc);
        }

        // Horizontal sum of the four lanes.
        let upper = _mm256_extractf128_pd(acc, 1);
        let lower = _mm256_castpd256_pd128(acc);
        let pair = _mm_add_pd(upper, lower);
        let folded = _mm_add_pd(pair, _mm_unpackhi_pd(pair, pair));
        let vec_sum = _mm_cvtsd_f64(folded);

        out[pos] = if rem != 0 {
            let mut rem_sum = 0.0;
            for k in 0..rem {
                let weight = *w_ptr.add(rem_base + k);
                let sample = *data.get_unchecked(pos - (rem_base + k));
                rem_sum = sample.mul_add(weight, rem_sum);
            }
            (vec_sum + rem_sum) * inv_n
        } else {
            vec_sum * inv_n
        };
    }
}
1604
// Python entry point `cwma(data, period, kernel=None)`: computes the CWMA of
// a 1-D float64 array and returns a new NumPy array. The computation runs
// with the GIL released; any indicator error is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn cwma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let prices = data.as_slice()?;
    let selected = validate_kernel(kernel, false)?;

    let input = CwmaInput::from_slice(
        prices,
        CwmaParams {
            period: Some(period),
        },
    );

    let computed: Result<Vec<f64>, _> =
        py.allow_threads(|| cwma_with_kernel(&input, selected).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1630
// Python-visible wrapper (exposed as `CwmaStream`) around the Rust streaming
// CWMA state.
#[cfg(feature = "python")]
#[pyclass(name = "CwmaStream")]
pub struct CwmaStreamPy {
    // Underlying streaming calculator that `update` delegates to.
    stream: CwmaStream,
}
1636
#[cfg(feature = "python")]
#[pymethods]
impl CwmaStreamPy {
    // Builds a stream for `period`; construction errors surface as
    // `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = CwmaParams {
            period: Some(period),
        };
        match CwmaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    // Feeds one value into the stream and returns whatever the underlying
    // `CwmaStream::update` yields (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1654
// Python entry point `cwma_batch(data, period_range, kernel=None)`.
//
// Returns a dict with:
//   "values"  - float64 array of shape (rows, cols), one row per period
//   "periods" - the period used for each row
//
// Raises `ValueError` when the sweep contains a period of 0, when
// `rows * cols` overflows, when the kernel is not usable for batch work, or
// when the underlying computation reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "cwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cwma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = CwmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);

    // A zero period would underflow `period - 1` in the warm-up computation
    // below and inside the row kernels, so reject it up front.
    if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
        return Err(PyValueError::new_err(
            "cwma: period must be greater than zero",
        ));
    }

    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("cwma: rows*cols overflows usize"))?;

    // The array is allocated uninitialised: the warm-up prefix is filled
    // here and the remaining cells are written by the row kernels.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .unwrap_or(slice_in.len());
    for (row, combo) in combos.iter().enumerate() {
        let warmup = (first + combo.period.unwrap() - 1).min(cols);
        let row_start = row * cols;
        slice_out[row_start..row_start + warmup].fill(f64::NAN);
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // Report a non-batch kernel as an error instead of panicking,
                // mirroring `cwma_batch_with_kernel`.
                other => return Err(CwmaError::InvalidKernelForBatch(other)),
            };

            cwma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1723
// Python handle to a 2-D f32 result that lives in CUDA device memory.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32CwmaPy {
    // Device buffer together with its `rows` / `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA wrapper that produced `inner`; also supplies the device id.
    // NOTE(review): presumably held to keep the CUDA context alive for as
    // long as the buffer exists — confirm against `CudaCwma`.
    pub(crate) guard: Arc<CudaCwma>,
}
1730
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32CwmaPy {
    // Describes the device buffer as a row-major `<f4` array using version 3
    // of the CUDA array interface.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let device_addr = self.inner.device_ptr() as usize;

        let d = PyDict::new(py);
        d.set_item("shape", (rows, cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (cols * elem_bytes, elem_bytes))?;
        d.set_item("data", (device_addr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`; device type 2 is reported together
    // with the id of the device owned by `guard`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_id = self.guard.device_id() as i32;
        (2, device_id)
    }

    // Exports the buffer as a DLPack capsule. The buffer is moved out of
    // `self` (replaced by an empty 0x0 placeholder), so the export can only
    // hand over the data once. A `dl_device` that parses as `(i32, i32)` and
    // differs from this array's device is rejected with `ValueError`.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that cannot be read as a pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|dev| dev.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        // The consumer stream is accepted but not used.
        let _ = stream;

        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1809
// Python entry point `cwma_cuda_batch_dev(data_f32, period_range, device_id=0)`:
// runs the period sweep on the given CUDA device and returns a handle to the
// device-resident result. Raises `ValueError` when CUDA is unavailable or the
// CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cwma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn cwma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32CwmaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = CwmaBatchRange {
        period: period_range,
    };

    let cuda = match CudaCwma::new(device_id) {
        Ok(ctx) => Arc::new(ctx),
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    // The device computation runs with the GIL released.
    let inner = py.allow_threads(|| {
        cuda.cwma_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32CwmaPy { inner, guard: cuda })
}
1836
// Python entry point
// `cwma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`:
// applies one period to a 2-D f32 array on the given CUDA device and returns
// a handle to the device-resident result. Raises `ValueError` when CUDA is
// unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn cwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32CwmaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = CwmaParams {
        period: Some(period),
    };

    let cuda = match CudaCwma::new(device_id) {
        Ok(ctx) => Arc::new(ctx),
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    // The device computation runs with the GIL released; note the argument
    // order `(cols, rows)` expected by the CUDA wrapper call.
    let inner = py.allow_threads(|| {
        cuda.cwma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32CwmaPy { inner, guard: cuda })
}
1868
// wasm entry point: computes the CWMA of `data` for `period` into a freshly
// allocated vector. Indicator errors are returned as their string message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = CwmaInput::from_slice(
        data,
        CwmaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];

    match cwma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1884
// wasm helper: allocates room for `len` f64 values and hands the raw pointer
// to the caller, who must release it with `cwma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1893
// wasm helper: releases a buffer obtained from `cwma_alloc(len)`. Null
// pointers and zero lengths are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // Rebuild the vector with the original capacity so it is deallocated.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1903
// wasm entry point: computes the CWMA of the `len` values at `in_ptr` and
// writes `len` results to `out_ptr`. When both pointers are equal the result
// is computed into a scratch buffer first and then copied over the input.
//
// Errors: null pointers, `period == 0`, `period > len`, or any indicator
// error (returned as its string message).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cwma_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = CwmaInput::from_slice(
            data,
            CwmaParams {
                period: Some(period),
            },
        );

        let aliased = in_ptr == out_ptr;
        if aliased {
            // In-place request: avoid reading and writing the same memory.
            let mut scratch = vec![0.0; len];
            cwma_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dest = std::slice::from_raw_parts_mut(out_ptr, len);
            cwma_into_slice(dest, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1944
// Deprecated wasm entry point: runs a sequential batch sweep and returns only
// the flattened row-major values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use cwma_batch instead")]
pub fn cwma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = CwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match cwma_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(output) => Ok(output.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1962
// wasm entry point: returns the periods a sweep expands to, as f64 values in
// row order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = CwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let periods = expand_grid(&sweep)
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(periods)
}
1982
// Configuration object deserialized from the JS side for `cwma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CwmaBatchConfig {
    // `(start, end, step)` period sweep, forwarded to `CwmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1988
/// Serializable result handed back to JavaScript by the `cwma_batch` binding.
///
/// Each field is copied from the same-named field of the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CwmaBatchJsOutput {
    /// Flat buffer holding every computed value.
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<CwmaParams>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1997
/// Unified wasm batch API, exported to JavaScript as `cwma_batch`.
///
/// `config` must deserialize into [`CwmaBatchConfig`]. The sweep is evaluated
/// with `cwma_batch_inner` using `Kernel::Auto`, and the values, combos, rows
/// and cols of the result are serialized back as a [`CwmaBatchJsOutput`].
///
/// # Errors
/// Returns a `JsValue` string when the config cannot be deserialized
/// (`"Invalid config: …"`), when the computation fails, or when the result
/// cannot be serialized (`"Serialization error: …"`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cwma_batch)]
pub fn cwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = match serde_wasm_bindgen::from_value::<CwmaBatchConfig>(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let range = CwmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match cwma_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = CwmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2021
/// Runs a CWMA period sweep over a caller-owned input buffer and writes the
/// result into a caller-owned output buffer.
///
/// The number of rows is the number of combinations `expand_grid` yields for
/// `(period_start, period_end, period_step)`; each row has `len` columns, so
/// `rows * len` values are written starting at `out_ptr`. Returns `rows`.
///
/// # Caller contract
/// * `in_ptr` must be valid for reading `len` `f64` values.
/// * `out_ptr` must be valid for writing `rows * len` `f64` values.
/// * `in_ptr == out_ptr` is supported (the input is copied first). Buffers
///   that overlap only partially are not detected and must not be passed.
///
/// # Errors
/// Returns a `JsValue` string if either pointer is null, if `rows * len`
/// overflows `usize`, or if `cwma_batch_inner_into` reports an error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cwma_batch_into"));
    }

    let sweep = CwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    let cols = len;

    // A wrapped product would describe an output slice shorter than the data
    // the batch kernel is about to write, so reject it up front.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("cwma_batch_into: rows * cols overflows usize"))?;

    let aliased = in_ptr == out_ptr as *const f64;

    // SAFETY: both pointers are non-null (checked above) and the caller
    // guarantees `in_ptr` covers `len` elements and `out_ptr` covers `total`
    // elements. When the buffers start at the same address the input is copied
    // into an owned Vec before the mutable output slice is created, so a shared
    // and a mutable slice over the same memory never coexist.
    let result = unsafe {
        if aliased {
            let owned = std::slice::from_raw_parts(in_ptr, len).to_vec();
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            cwma_batch_inner_into(&owned, &sweep, Kernel::Auto, false, out)
        } else {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            cwma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        }
    };
    result.map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2055
/// Registers the CWMA Python bindings on module `m`.
///
/// Adds the `cwma_py` and `cwma_batch_py` functions and the `CwmaStreamPy`
/// class; with the `cuda` feature enabled it also adds the two CUDA device
/// functions. The first registration failure is returned to the caller.
#[cfg(feature = "python")]
pub fn register_cwma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(cwma_py, m)?;
    m.add_function(single)?;

    let batch = wrap_pyfunction!(cwma_batch_py, m)?;
    m.add_function(batch)?;

    m.add_class::<CwmaStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        let cuda_batch = wrap_pyfunction!(cwma_cuda_batch_dev_py, m)?;
        m.add_function(cuda_batch)?;

        let cuda_many = wrap_pyfunction!(cwma_cuda_many_series_one_param_dev_py, m)?;
        m.add_function(cuda_many)?;
    }

    Ok(())
}
2068
2069#[cfg(test)]
2070mod tests {
2071    use super::*;
2072    use crate::skip_if_unsupported;
2073    use crate::utilities::data_loader::read_candles_from_csv;
2074    #[cfg(feature = "proptest")]
2075    use proptest::prelude::*;
2076
/// `cwma_into` must write exactly what `cwma` returns: same length and, per
/// index, either both NaN or equal values.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_cwma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
    let input = CwmaInput::with_default_candles(&candles);

    let reference = cwma(&input)?.values;

    let mut written = vec![0.0; candles.close.len()];
    cwma_into(&input, &mut written)?;

    assert_eq!(reference.len(), written.len());

    for (idx, (&want, &got)) in reference.iter().zip(written.iter()).enumerate() {
        // NaN never compares equal to itself, so it needs its own branch.
        let same = if want.is_nan() {
            got.is_nan()
        } else {
            want == got
        };
        assert!(
            same,
            "mismatch at index {}: baseline={}, into={}",
            idx, want, got
        );
    }

    Ok(())
}
2108
2109    fn check_cwma_partial_params(
2110        test_name: &str,
2111        kernel: Kernel,
2112    ) -> Result<(), Box<dyn std::error::Error>> {
2113        skip_if_unsupported!(kernel, test_name);
2114        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2115        let candles = read_candles_from_csv(file_path)?;
2116
2117        let default_params = CwmaParams { period: None };
2118        let input_def = CwmaInput::from_candles(&candles, "close", default_params);
2119        let output_def = cwma_with_kernel(&input_def, kernel)?;
2120        assert_eq!(output_def.values.len(), candles.close.len());
2121
2122        let params_14 = CwmaParams { period: Some(14) };
2123        let input_14 = CwmaInput::from_candles(&candles, "hl2", params_14);
2124        let output_14 = cwma_with_kernel(&input_14, kernel)?;
2125        assert_eq!(output_14.values.len(), candles.close.len());
2126
2127        let params_custom = CwmaParams { period: Some(20) };
2128        let input_custom = CwmaInput::from_candles(&candles, "hlc3", params_custom);
2129        let output_custom = cwma_with_kernel(&input_custom, kernel)?;
2130        assert_eq!(output_custom.values.len(), candles.close.len());
2131
2132        Ok(())
2133    }
2134
2135    fn check_cwma_accuracy(
2136        test_name: &str,
2137        kernel: Kernel,
2138    ) -> Result<(), Box<dyn std::error::Error>> {
2139        skip_if_unsupported!(kernel, test_name);
2140        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2141        let candles = read_candles_from_csv(file_path)?;
2142
2143        let input = CwmaInput::with_default_candles(&candles);
2144        let result = cwma_with_kernel(&input, kernel)?;
2145        assert_eq!(result.values.len(), candles.close.len());
2146
2147        let expected_last_five = [
2148            59224.641237300435,
2149            59213.64831277214,
2150            59171.21190130624,
2151            59167.01279027576,
2152            59039.413552249636,
2153        ];
2154
2155        let start = result.values.len().saturating_sub(5);
2156        for (i, &val) in result.values[start..].iter().enumerate() {
2157            let diff = (val - expected_last_five[i]).abs();
2158            assert!(
2159                diff < 1e-9,
2160                "[{}] CWMA {:?} mismatch at idx {}: got {}, expected {}",
2161                test_name,
2162                kernel,
2163                i,
2164                val,
2165                expected_last_five[i]
2166            );
2167        }
2168        Ok(())
2169    }
2170
2171    fn check_cwma_default_candles(
2172        test_name: &str,
2173        kernel: Kernel,
2174    ) -> Result<(), Box<dyn std::error::Error>> {
2175        skip_if_unsupported!(kernel, test_name);
2176        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2177        let candles = read_candles_from_csv(file_path)?;
2178
2179        let input = CwmaInput::with_default_candles(&candles);
2180        match input.data {
2181            CwmaData::Candles { source, .. } => assert_eq!(source, "close"),
2182            _ => panic!("Expected CwmaData::Candles"),
2183        }
2184        let output = cwma_with_kernel(&input, kernel)?;
2185        assert_eq!(output.values.len(), candles.close.len());
2186
2187        Ok(())
2188    }
2189
2190    fn check_cwma_zero_period(
2191        test_name: &str,
2192        kernel: Kernel,
2193    ) -> Result<(), Box<dyn std::error::Error>> {
2194        skip_if_unsupported!(kernel, test_name);
2195        let input_data = [10.0, 20.0, 30.0];
2196        let params = CwmaParams { period: Some(0) };
2197        let input = CwmaInput::from_slice(&input_data, params);
2198        let res = cwma_with_kernel(&input, kernel);
2199        assert!(res.is_err(), "[{}] Should fail with zero period", test_name);
2200        Ok(())
2201    }
2202
2203    fn check_cwma_period_exceeds_length(
2204        test_name: &str,
2205        kernel: Kernel,
2206    ) -> Result<(), Box<dyn std::error::Error>> {
2207        skip_if_unsupported!(kernel, test_name);
2208        let data_small = [10.0, 20.0, 30.0];
2209        let params = CwmaParams { period: Some(10) };
2210        let input = CwmaInput::from_slice(&data_small, params);
2211        let res = cwma_with_kernel(&input, kernel);
2212        assert!(
2213            res.is_err(),
2214            "[{}] Should fail with period exceeding length",
2215            test_name
2216        );
2217        Ok(())
2218    }
2219
2220    fn check_cwma_very_small_dataset(
2221        test_name: &str,
2222        kernel: Kernel,
2223    ) -> Result<(), Box<dyn std::error::Error>> {
2224        skip_if_unsupported!(kernel, test_name);
2225        let single_point = [42.0];
2226        let params = CwmaParams { period: Some(9) };
2227        let input = CwmaInput::from_slice(&single_point, params);
2228        let res = cwma_with_kernel(&input, kernel);
2229        assert!(
2230            res.is_err(),
2231            "[{}] Should fail with insufficient data",
2232            test_name
2233        );
2234        Ok(())
2235    }
2236
2237    fn check_cwma_empty_input(
2238        test_name: &str,
2239        kernel: Kernel,
2240    ) -> Result<(), Box<dyn std::error::Error>> {
2241        skip_if_unsupported!(kernel, test_name);
2242        let empty: [f64; 0] = [];
2243        let input = CwmaInput::from_slice(&empty, CwmaParams::default());
2244        let res = cwma_with_kernel(&input, kernel);
2245        assert!(
2246            matches!(res, Err(CwmaError::EmptyInputData)),
2247            "[{}] Should fail with empty input",
2248            test_name
2249        );
2250        Ok(())
2251    }
2252
2253    fn check_cwma_reinput(
2254        test_name: &str,
2255        kernel: Kernel,
2256    ) -> Result<(), Box<dyn std::error::Error>> {
2257        skip_if_unsupported!(kernel, test_name);
2258        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2259        let candles = read_candles_from_csv(file_path)?;
2260
2261        let first_params = CwmaParams { period: Some(80) };
2262        let first_input = CwmaInput::from_candles(&candles, "close", first_params);
2263        let first_result = cwma_with_kernel(&first_input, kernel)?;
2264
2265        let second_params = CwmaParams { period: Some(60) };
2266        let second_input = CwmaInput::from_slice(&first_result.values, second_params);
2267        let second_result = cwma_with_kernel(&second_input, kernel)?;
2268        assert_eq!(second_result.values.len(), first_result.values.len());
2269
2270        if second_result.values.len() > 240 {
2271            for i in 240..second_result.values.len() {
2272                assert!(
2273                    !second_result.values[i].is_nan(),
2274                    "[{}] Found unexpected NaN at index {}",
2275                    test_name,
2276                    i
2277                );
2278            }
2279        }
2280        Ok(())
2281    }
2282
2283    fn check_cwma_nan_handling(
2284        test_name: &str,
2285        kernel: Kernel,
2286    ) -> Result<(), Box<dyn std::error::Error>> {
2287        skip_if_unsupported!(kernel, test_name);
2288        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2289        let candles = read_candles_from_csv(file_path)?;
2290
2291        let input = CwmaInput::from_candles(&candles, "close", CwmaParams { period: Some(9) });
2292        let res = cwma_with_kernel(&input, kernel)?;
2293        assert_eq!(res.values.len(), candles.close.len());
2294        if res.values.len() > 240 {
2295            for (i, &val) in res.values[240..].iter().enumerate() {
2296                assert!(
2297                    !val.is_nan(),
2298                    "[{}] Found unexpected NaN at out-index {}",
2299                    test_name,
2300                    240 + i
2301                );
2302            }
2303        }
2304        Ok(())
2305    }
2306
2307    fn check_cwma_streaming(
2308        test_name: &str,
2309        kernel: Kernel,
2310    ) -> Result<(), Box<dyn std::error::Error>> {
2311        skip_if_unsupported!(kernel, test_name);
2312        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2313        let candles = read_candles_from_csv(file_path)?;
2314
2315        let period = 9;
2316        let input = CwmaInput::from_candles(
2317            &candles,
2318            "close",
2319            CwmaParams {
2320                period: Some(period),
2321            },
2322        );
2323        let batch_output = cwma_with_kernel(&input, kernel)?.values;
2324
2325        let mut stream = CwmaStream::try_new(CwmaParams {
2326            period: Some(period),
2327        })?;
2328        let mut stream_values = Vec::with_capacity(candles.close.len());
2329        for &price in &candles.close {
2330            match stream.update(price) {
2331                Some(val) => stream_values.push(val),
2332                None => stream_values.push(f64::NAN),
2333            }
2334        }
2335
2336        assert_eq!(batch_output.len(), stream_values.len());
2337        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2338            if b.is_nan() && s.is_nan() {
2339                continue;
2340            }
2341            let diff = (b - s).abs();
2342            assert!(
2343                diff < 1e-9,
2344                "[{}] CWMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
2345                test_name,
2346                i,
2347                b,
2348                s,
2349                diff
2350            );
2351        }
2352        Ok(())
2353    }
2354
2355    #[cfg(debug_assertions)]
2356    fn check_cwma_no_poison(
2357        test_name: &str,
2358        kernel: Kernel,
2359    ) -> Result<(), Box<dyn std::error::Error>> {
2360        skip_if_unsupported!(kernel, test_name);
2361
2362        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2363        let candles = read_candles_from_csv(file_path)?;
2364
2365        let test_params = vec![
2366            CwmaParams::default(),
2367            CwmaParams { period: Some(2) },
2368            CwmaParams { period: Some(3) },
2369            CwmaParams { period: Some(5) },
2370            CwmaParams { period: Some(7) },
2371            CwmaParams { period: Some(10) },
2372            CwmaParams { period: Some(14) },
2373            CwmaParams { period: Some(20) },
2374            CwmaParams { period: Some(30) },
2375            CwmaParams { period: Some(50) },
2376            CwmaParams { period: Some(100) },
2377            CwmaParams { period: Some(200) },
2378            CwmaParams { period: Some(250) },
2379        ];
2380
2381        for (param_idx, params) in test_params.iter().enumerate() {
2382            let input = CwmaInput::from_candles(&candles, "close", params.clone());
2383            let output = cwma_with_kernel(&input, kernel)?;
2384
2385            for (i, &val) in output.values.iter().enumerate() {
2386                if val.is_nan() {
2387                    continue;
2388                }
2389
2390                let bits = val.to_bits();
2391
2392                if bits == 0x11111111_11111111 {
2393                    panic!(
2394                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2395                        with params: period={}",
2396                        test_name,
2397                        val,
2398                        bits,
2399                        i,
2400                        params.period.unwrap_or(14)
2401                    );
2402                }
2403
2404                if bits == 0x22222222_22222222 {
2405                    panic!(
2406                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2407                        with params: period={}",
2408                        test_name,
2409                        val,
2410                        bits,
2411                        i,
2412                        params.period.unwrap_or(14)
2413                    );
2414                }
2415
2416                if bits == 0x33333333_33333333 {
2417                    panic!(
2418                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2419                        with params: period={}",
2420                        test_name,
2421                        val,
2422                        bits,
2423                        i,
2424                        params.period.unwrap_or(14)
2425                    );
2426                }
2427            }
2428        }
2429
2430        Ok(())
2431    }
2432
2433    #[cfg(not(debug_assertions))]
2434    fn check_cwma_no_poison(
2435        _test_name: &str,
2436        _kernel: Kernel,
2437    ) -> Result<(), Box<dyn std::error::Error>> {
2438        Ok(())
2439    }
2440
/// Property-based cross-check of a CWMA kernel (built only with the
/// `proptest` feature).
///
/// Generated cases use a period in `2..=32`, a finite series of `period..400`
/// values in `[-1e6, 1e6)` and an affine map `a * x + b` with `a != 0`.
/// For every index from `period - 1` on, the test asserts that the output:
/// * lies within the min/max of its window (±1e-9),
/// * equals the window value when the window is constant,
/// * does not decrease while the input prefix is non-decreasing,
/// * commutes with the affine map (`cwma(a*x+b) == a*cwma(x)+b` within tolerance),
/// * matches the scalar kernel (≤1e-9 or ≤4 ULP),
/// * matches the streaming implementation (≤1e-9).
/// Finally the first `period - 1` outputs must be NaN, and a few fixed invalid
/// inputs must be rejected.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_cwma_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    let strat = (2usize..=32).prop_flat_map(|period| {
        (
            prop::collection::vec(
                (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                period..400,
            ),
            Just(period),
            (-1e3f64..1e3f64).prop_filter("finite a", |a| a.is_finite() && *a != 0.0),
            -1e3f64..1e3f64,
        )
    });

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(data, period, a, b)| {
            let params = CwmaParams {
                period: Some(period),
            };
            let input = CwmaInput::from_slice(&data, params.clone());

            let fast = cwma_with_kernel(&input, kernel);
            let slow = cwma_with_kernel(&input, Kernel::Scalar);

            match (fast, slow) {
                (Err(e1), Err(e2))
                    if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
                {
                    return Ok(());
                }

                (Err(e1), Err(e2)) => {
                    prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
                }

                (Err(e1), Ok(_)) => {
                    prop_assert!(false, "fast errored {e1:?} but scalar succeeded")
                }
                (Ok(_), Err(e2)) => {
                    prop_assert!(false, "scalar errored {e2:?} but fast succeeded")
                }

                (Ok(fast), Ok(reference)) => {
                    let CwmaOutput { values: out } = fast;
                    let CwmaOutput { values: rref } = reference;

                    let mut stream = CwmaStream::try_new(params.clone()).unwrap();
                    let mut s_out = Vec::with_capacity(data.len());
                    for &v in &data {
                        s_out.push(stream.update(v).unwrap_or(f64::NAN));
                    }

                    let transformed: Vec<f64> = data.iter().map(|x| a * x + b).collect();
                    let t_out = cwma(&CwmaInput::from_slice(&transformed, params))?.values;

                    // Length of the longest non-decreasing prefix of `data`,
                    // computed once instead of rescanning `data[..=i]` for
                    // every index. `data[..=i]` is non-decreasing exactly when
                    // `i < sorted_prefix_len`. Inputs are finite by
                    // construction, so `>` is the exact negation of `<=`.
                    let sorted_prefix_len = data
                        .windows(2)
                        .position(|p| p[0] > p[1])
                        .map_or(data.len(), |k| k + 1);

                    for i in (period - 1)..data.len() {
                        let w = &data[i + 1 - period..=i];
                        let (lo, hi) = w
                            .iter()
                            .fold((f64::INFINITY, f64::NEG_INFINITY), |(l, h), &v| {
                                (l.min(v), h.max(v))
                            });
                        let y = out[i];
                        let yr = rref[i];
                        let ys = s_out[i];
                        let yt = t_out[i];

                        prop_assert!(
                            y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                            "idx {i}: {y} ∉ [{lo}, {hi}]"
                        );

                        if w.iter().all(|v| *v == w[0]) {
                            prop_assert!((y - w[0]).abs() <= 1e-9);
                        }

                        // `i >= period - 1 >= 1`, so `out[i - 1]` is in bounds.
                        if i < sorted_prefix_len && y.is_finite() && out[i - 1].is_finite() {
                            prop_assert!(y >= out[i - 1] - 1e-12);
                        }

                        {
                            let expected = a * y + b;
                            let diff = (yt - expected).abs();
                            let tol_abs = 1e-9_f64;
                            let tol_rel = expected.abs() * 1e-9;
                            let ulp = yt.to_bits().abs_diff(expected.to_bits());

                            prop_assert!(
                                diff <= tol_abs.max(tol_rel) || ulp <= 8,
                                "idx {i}: affine mismatch diff={diff:e}  ULP={ulp}"
                            );
                        }

                        let ulp = y.to_bits().abs_diff(yr.to_bits());
                        prop_assert!(
                            (y - yr).abs() <= 1e-9 || ulp <= 4,
                            "idx {i}: fast={y} ref={yr} ULP={ulp}"
                        );

                        prop_assert!(
                            (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
                            "idx {i}: stream mismatch"
                        );
                    }

                    // Everything before the first full window must be NaN.
                    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());
                    let warm = first + period - 1;
                    prop_assert!(out[..warm].iter().all(|v| v.is_nan()));
                }
            }

            Ok(())
        })
        .unwrap();

    // Fixed invalid inputs: empty, all-NaN, period > length, period == 0.
    assert!(cwma(&CwmaInput::from_slice(&[], CwmaParams::default())).is_err());
    assert!(cwma(&CwmaInput::from_slice(
        &[f64::NAN; 12],
        CwmaParams::default()
    ))
    .is_err());
    assert!(cwma(&CwmaInput::from_slice(
        &[1.0; 5],
        CwmaParams { period: Some(8) }
    ))
    .is_err());
    assert!(cwma(&CwmaInput::from_slice(
        &[1.0; 5],
        CwmaParams { period: Some(0) }
    ))
    .is_err());

    Ok(())
}
2589
2590    macro_rules! generate_all_cwma_tests {
2591        ($($test_fn:ident),*) => {
2592            paste::paste! {
2593                $(
2594                    #[test]
2595                    fn [<$test_fn _scalar_f64>]() {
2596                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2597                    }
2598                )*
2599                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2600                $(
2601                    #[test]
2602                    fn [<$test_fn _avx2_f64>]() {
2603                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2604                    }
2605                    #[test]
2606                    fn [<$test_fn _avx512_f64>]() {
2607                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2608                    }
2609                )*
2610            }
2611        }
2612    }
2613
2614    generate_all_cwma_tests!(
2615        check_cwma_partial_params,
2616        check_cwma_accuracy,
2617        check_cwma_default_candles,
2618        check_cwma_zero_period,
2619        check_cwma_period_exceeds_length,
2620        check_cwma_very_small_dataset,
2621        check_cwma_empty_input,
2622        check_cwma_reinput,
2623        check_cwma_nan_handling,
2624        check_cwma_streaming,
2625        check_cwma_no_poison
2626    );
2627
2628    #[cfg(feature = "proptest")]
2629    generate_all_cwma_tests!(check_cwma_property);
2630
2631    fn check_batch_default_row(
2632        test: &str,
2633        kernel: Kernel,
2634    ) -> Result<(), Box<dyn std::error::Error>> {
2635        skip_if_unsupported!(kernel, test);
2636
2637        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2638        let c = read_candles_from_csv(file)?;
2639
2640        let output = CwmaBatchBuilder::new()
2641            .kernel(kernel)
2642            .apply_candles(&c, "close")?;
2643
2644        let def = CwmaParams::default();
2645        let row = output.values_for(&def).expect("default row missing");
2646
2647        assert_eq!(row.len(), c.close.len());
2648
2649        let expected = [
2650            59224.641237300435,
2651            59213.64831277214,
2652            59171.21190130624,
2653            59167.01279027576,
2654            59039.413552249636,
2655        ];
2656        let start = row.len() - 5;
2657        for (i, &v) in row[start..].iter().enumerate() {
2658            assert!(
2659                (v - expected[i]).abs() < 1e-8,
2660                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2661            );
2662        }
2663        Ok(())
2664    }
2665
2666    #[cfg(debug_assertions)]
2667    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2668        skip_if_unsupported!(kernel, test);
2669
2670        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2671        let c = read_candles_from_csv(file)?;
2672
2673        let test_configs = vec![
2674            (2, 5, 1),
2675            (5, 25, 5),
2676            (10, 50, 10),
2677            (2, 4, 1),
2678            (50, 150, 25),
2679            (9, 21, 2),
2680            (9, 21, 4),
2681            (100, 300, 50),
2682        ];
2683
2684        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2685            let output = CwmaBatchBuilder::new()
2686                .kernel(kernel)
2687                .period_range(p_start, p_end, p_step)
2688                .apply_candles(&c, "close")?;
2689
2690            for (idx, &val) in output.values.iter().enumerate() {
2691                if val.is_nan() {
2692                    continue;
2693                }
2694
2695                let bits = val.to_bits();
2696                let row = idx / output.cols;
2697                let col = idx % output.cols;
2698                let combo = &output.combos[row];
2699
2700                if bits == 0x11111111_11111111 {
2701                    panic!(
2702                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2703                        at row {} col {} (flat index {}) with params: period={}",
2704                        test,
2705                        cfg_idx,
2706                        val,
2707                        bits,
2708                        row,
2709                        col,
2710                        idx,
2711                        combo.period.unwrap_or(14)
2712                    );
2713                }
2714
2715                if bits == 0x22222222_22222222 {
2716                    panic!(
2717                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2718                        at row {} col {} (flat index {}) with params: period={}",
2719                        test,
2720                        cfg_idx,
2721                        val,
2722                        bits,
2723                        row,
2724                        col,
2725                        idx,
2726                        combo.period.unwrap_or(14)
2727                    );
2728                }
2729
2730                if bits == 0x33333333_33333333 {
2731                    panic!(
2732                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2733                        at row {} col {} (flat index {}) with params: period={}",
2734                        test,
2735                        cfg_idx,
2736                        val,
2737                        bits,
2738                        row,
2739                        col,
2740                        idx,
2741                        combo.period.unwrap_or(14)
2742                    );
2743                }
2744            }
2745        }
2746
2747        Ok(())
2748    }
2749
2750    #[cfg(not(debug_assertions))]
2751    fn check_batch_no_poison(
2752        _test: &str,
2753        _kernel: Kernel,
2754    ) -> Result<(), Box<dyn std::error::Error>> {
2755        Ok(())
2756    }
2757
2758    macro_rules! gen_batch_tests {
2759        ($fn_name:ident) => {
2760            paste::paste! {
2761                #[test] fn [<$fn_name _scalar>]()      {
2762                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2763                }
2764                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2765                #[test] fn [<$fn_name _avx2>]()        {
2766                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2767                }
2768                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2769                #[test] fn [<$fn_name _avx512>]()      {
2770                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2771                }
2772                #[test] fn [<$fn_name _auto_detect>]() {
2773                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2774                }
2775            }
2776        };
2777    }
2778    gen_batch_tests!(check_batch_default_row);
2779    gen_batch_tests!(check_batch_no_poison);
2780}