Skip to main content

vector_ta/indicators/moving_averages/
fwma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaFwma;
7use crate::utilities::aligned_vector::AlignedVec;
8use crate::utilities::data_loader::{source_type, Candles};
9use crate::utilities::enums::Kernel;
10use crate::utilities::helpers::{
11    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
12    make_uninit_matrix,
13};
14#[cfg(feature = "python")]
15use crate::utilities::kernel_validation::validate_kernel;
16#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
17use core::arch::x86_64::*;
18#[cfg(feature = "python")]
19use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
20#[cfg(feature = "python")]
21use pyo3::exceptions::PyValueError;
22#[cfg(feature = "python")]
23use pyo3::prelude::*;
24#[cfg(feature = "python")]
25use pyo3::types::PyDict;
26
27#[cfg(all(feature = "python", feature = "cuda"))]
28use cust::context::Context;
29#[cfg(all(feature = "python", feature = "cuda"))]
30use cust::memory::DeviceBuffer;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
34use serde::{Deserialize, Serialize};
35use std::convert::AsRef;
36use std::error::Error;
37use std::mem::{ManuallyDrop, MaybeUninit};
38#[cfg(all(feature = "python", feature = "cuda"))]
39use std::sync::Arc;
40use thiserror::Error;
41#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
42use wasm_bindgen::prelude::*;
43
// Python-facing wrapper around a 2-D `f32` array held in CUDA device memory. It is exported
// to Python as `ta_indicators.cuda.FwmaDeviceArrayF32` and declared `unsendable`.
// NOTE: plain `//` comments are used on purpose in this block; `///` on pyo3 items would be
// picked up as the Python-visible docstring.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "FwmaDeviceArrayF32", unsendable)]
pub struct DeviceArrayF32Py {
    // Device buffer together with its `rows` x `cols` extents.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context handle; not read by the methods below
    // (presumably held so the context outlives the buffer -- confirm).
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
    // Stream value stored by the producer; not read by the methods below.
    pub(crate) stream: usize,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dictionary: row-major shape and byte strides,
    // little-endian f32 typestr, the raw device pointer (flagged as not read-only) and
    // interface version 3.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = core::mem::size_of::<f32>();
        let row_bytes = self.inner.cols * elem_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; the device type is always the constant 2
    // (the CUDA device-type code in the DLPack specification).
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_type = 2;
        (device_type, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // If `dl_device` parses as an `(i32, i32)` pair that differs from this array's own device,
    // a `ValueError` is raised (with a dedicated message when `copy` is truthy). `stream` is
    // accepted but ignored. The device buffer is moved out of `self`, which is left holding an
    // empty 0 x 0 placeholder, so the array can be exported only once.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is absent or not an `(i32, i32)` tuple is silently accepted.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Swap an empty placeholder into `self` so the real buffer can be handed to the exporter.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
129
130impl<'a> AsRef<[f64]> for FwmaInput<'a> {
131    #[inline(always)]
132    fn as_ref(&self) -> &[f64] {
133        match &self.data {
134            FwmaData::Slice(slice) => slice,
135            FwmaData::Candles { candles, source } => source_type(candles, source),
136        }
137    }
138}
139
140#[derive(Debug, Clone)]
141pub enum FwmaData<'a> {
142    Candles {
143        candles: &'a Candles,
144        source: &'a str,
145    },
146    Slice(&'a [f64]),
147}
148
/// Result of a single-series FWMA run.
#[derive(Clone, Debug)]
pub struct FwmaOutput {
    /// Output series; the single-series entry points in this file allocate it with the same
    /// length as the input and a NaN warm-up prefix.
    pub values: Vec<f64>,
}
153
/// Tunable parameters of the FWMA indicator.
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
#[derive(Clone, Debug)]
pub struct FwmaParams {
    /// Window length; consumers in this file fall back to 5 when it is `None`.
    pub period: Option<usize>,
}

impl Default for FwmaParams {
    /// Default parameters: a window of 5.
    fn default() -> Self {
        FwmaParams { period: Some(5) }
    }
}
168
169#[derive(Debug, Clone)]
170pub struct FwmaInput<'a> {
171    pub data: FwmaData<'a>,
172    pub params: FwmaParams,
173}
174
175impl<'a> FwmaInput<'a> {
176    #[inline]
177    pub fn from_candles(c: &'a Candles, s: &'a str, p: FwmaParams) -> Self {
178        Self {
179            data: FwmaData::Candles {
180                candles: c,
181                source: s,
182            },
183            params: p,
184        }
185    }
186    #[inline]
187    pub fn from_slice(sl: &'a [f64], p: FwmaParams) -> Self {
188        Self {
189            data: FwmaData::Slice(sl),
190            params: p,
191        }
192    }
193    #[inline]
194    pub fn with_default_candles(c: &'a Candles) -> Self {
195        Self::from_candles(c, "close", FwmaParams::default())
196    }
197    #[inline]
198    pub fn get_period(&self) -> usize {
199        self.params.period.unwrap_or(5)
200    }
201}
202
203#[derive(Copy, Clone, Debug)]
204pub struct FwmaBuilder {
205    period: Option<usize>,
206    kernel: Kernel,
207}
208
209impl Default for FwmaBuilder {
210    fn default() -> Self {
211        Self {
212            period: None,
213            kernel: Kernel::Auto,
214        }
215    }
216}
217
218impl FwmaBuilder {
219    #[inline(always)]
220    pub fn new() -> Self {
221        Self::default()
222    }
223    #[inline(always)]
224    pub fn period(mut self, n: usize) -> Self {
225        self.period = Some(n);
226        self
227    }
228    #[inline(always)]
229    pub fn kernel(mut self, k: Kernel) -> Self {
230        self.kernel = k;
231        self
232    }
233
234    #[inline(always)]
235    pub fn apply(self, c: &Candles) -> Result<FwmaOutput, FwmaError> {
236        let p = FwmaParams {
237            period: self.period,
238        };
239        let i = FwmaInput::from_candles(c, "close", p);
240        fwma_with_kernel(&i, self.kernel)
241    }
242
243    #[inline(always)]
244    pub fn apply_slice(self, d: &[f64]) -> Result<FwmaOutput, FwmaError> {
245        let p = FwmaParams {
246            period: self.period,
247        };
248        let i = FwmaInput::from_slice(d, p);
249        fwma_with_kernel(&i, self.kernel)
250    }
251
252    #[inline(always)]
253    pub fn into_stream(self) -> Result<FwmaStream, FwmaError> {
254        let p = FwmaParams {
255            period: self.period,
256        };
257        FwmaStream::try_new(p)
258    }
259}
260
261#[derive(Debug, Error)]
262pub enum FwmaError {
263    #[error("fwma: Input data slice is empty.")]
264    EmptyInputData,
265    #[error("fwma: All values are NaN.")]
266    AllValuesNaN,
267    #[error("fwma: Invalid period: period = {period}, data length = {data_len}")]
268    InvalidPeriod { period: usize, data_len: usize },
269    #[error("fwma: Not enough valid data: needed = {needed}, valid = {valid}")]
270    NotEnoughValidData { needed: usize, valid: usize },
271    #[error("fwma: Fibonacci sum is zero. Cannot normalize weights.")]
272    ZeroFibonacciSum,
273    #[error("fwma: Output buffer length mismatch: expected = {expected}, got = {got}")]
274    OutputLengthMismatch { expected: usize, got: usize },
275    #[error("fwma: Invalid range: start={start}, end={end}, step={step}")]
276    InvalidRange {
277        start: usize,
278        end: usize,
279        step: usize,
280    },
281    #[error("fwma: Invalid kernel for batch API: {0:?}")]
282    InvalidKernelForBatch(Kernel),
283    #[error("fwma: arithmetic overflow while computing {context}")]
284    ArithmeticOverflow { context: &'static str },
285}
286
287#[inline]
288pub fn fwma(input: &FwmaInput) -> Result<FwmaOutput, FwmaError> {
289    fwma_with_kernel(input, Kernel::Auto)
290}
291
292#[inline(always)]
293fn fwma_prepare<'a>(
294    input: &'a FwmaInput,
295    kernel: Kernel,
296) -> Result<(&'a [f64], Vec<f64>, usize, usize, Kernel), FwmaError> {
297    let data: &[f64] = input.as_ref();
298    let len = data.len();
299    if len == 0 {
300        return Err(FwmaError::EmptyInputData);
301    }
302    let first = data
303        .iter()
304        .position(|x| !x.is_nan())
305        .ok_or(FwmaError::AllValuesNaN)?;
306    let period = input.get_period();
307
308    if period == 0 || period > len {
309        return Err(FwmaError::InvalidPeriod {
310            period,
311            data_len: len,
312        });
313    }
314    if (len - first) < period {
315        return Err(FwmaError::NotEnoughValidData {
316            needed: period,
317            valid: len - first,
318        });
319    }
320
321    let mut fib = vec![1.0; period];
322    for i in 2..period {
323        fib[i] = fib[i - 1] + fib[i - 2];
324    }
325    let fib_sum: f64 = fib.iter().sum();
326    if fib_sum == 0.0 {
327        return Err(FwmaError::ZeroFibonacciSum);
328    }
329    for w in &mut fib {
330        *w /= fib_sum;
331    }
332
333    let chosen = match kernel {
334        Kernel::Auto => detect_best_kernel(),
335        other => other,
336    };
337
338    Ok((data, fib, period, first, chosen))
339}
340
341#[inline(always)]
342fn fwma_compute_into(
343    data: &[f64],
344    fib: &[f64],
345    period: usize,
346    first: usize,
347    kernel: Kernel,
348    out: &mut [f64],
349) {
350    unsafe {
351        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
352        {
353            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
354                fwma_simd128(data, fib, period, first, out);
355                return;
356            }
357        }
358
359        match kernel {
360            Kernel::Scalar | Kernel::ScalarBatch => fwma_scalar(data, fib, period, first, out),
361            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
362            Kernel::Avx2 | Kernel::Avx2Batch => fwma_avx2(data, fib, period, first, out),
363            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
364            Kernel::Avx512 | Kernel::Avx512Batch => fwma_avx512(data, fib, period, first, out),
365            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
366            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
367                fwma_scalar(data, fib, period, first, out)
368            }
369            _ => unreachable!(),
370        }
371    }
372}
373
374pub fn fwma_with_kernel(input: &FwmaInput, kernel: Kernel) -> Result<FwmaOutput, FwmaError> {
375    let (data, fib, period, first, chosen) = fwma_prepare(input, kernel)?;
376
377    let warm = first + period - 1;
378    let mut out = alloc_with_nan_prefix(data.len(), warm);
379
380    fwma_compute_into(data, &fib, period, first, chosen, &mut out);
381
382    Ok(FwmaOutput { values: out })
383}
384
385#[inline]
386pub fn fwma_into_slice(dst: &mut [f64], input: &FwmaInput, kern: Kernel) -> Result<(), FwmaError> {
387    let (data, fib, period, first, chosen) = fwma_prepare(input, kern)?;
388
389    if dst.len() != data.len() {
390        return Err(FwmaError::OutputLengthMismatch {
391            expected: data.len(),
392            got: dst.len(),
393        });
394    }
395
396    let warmup_end = (first + period - 1).min(dst.len());
397    for v in &mut dst[..warmup_end] {
398        *v = f64::from_bits(0x7ff8_0000_0000_0000);
399    }
400
401    fwma_compute_into(data, &fib, period, first, chosen, dst);
402
403    Ok(())
404}
405
406#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
407#[inline]
408pub fn fwma_into(input: &FwmaInput, out: &mut [f64]) -> Result<(), FwmaError> {
409    let (data, fib, period, first, chosen) = fwma_prepare(input, Kernel::Auto)?;
410
411    if out.len() != data.len() {
412        return Err(FwmaError::OutputLengthMismatch {
413            expected: data.len(),
414            got: out.len(),
415        });
416    }
417
418    let warm = (first + period - 1).min(out.len());
419    for v in &mut out[..warm] {
420        *v = f64::from_bits(0x7ff8_0000_0000_0000);
421    }
422
423    fwma_compute_into(data, &fib, period, first, chosen, out);
424    Ok(())
425}
426
/// Portable FWMA kernel.
///
/// For every index `i >= first_val + period - 1`, writes to `out[i]` the dot product of `fib`
/// with the `period`-long window of `data` ending at `i`. Earlier slots of `out` are left
/// untouched. The window is processed four elements at a time, followed by the leftover
/// `period % 4` elements.
///
/// # Safety
///
/// The body performs only bounds-checked accesses; the function is `unsafe` to share a
/// signature with the SIMD kernels.
///
/// # Panics
///
/// Panics if `fib.len() != period` or `out.len() < data.len()`.
#[inline(always)]
pub unsafe fn fwma_scalar(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    // Length of the part of the window that splits evenly into groups of four.
    let quad_len = period & !3;
    let (w_quads, w_rest) = fib.split_at(quad_len);
    let warmup = first_val + period - 1;

    for end in warmup..data.len() {
        let begin = end + 1 - period;
        let (x_quads, x_rest) = data[begin..begin + period].split_at(quad_len);

        let mut acc = 0.0;
        for (x, w) in x_quads.chunks_exact(4).zip(w_quads.chunks_exact(4)) {
            acc += x[0] * w[0] + x[1] * w[1] + x[2] * w[2] + x[3] * w[3];
        }
        for (x, w) in x_rest.iter().zip(w_rest) {
            acc += x * w;
        }
        out[end] = acc;
    }
}
455
/// WebAssembly SIMD128 FWMA kernel.
///
/// For every index `i >= first_val + period - 1`, writes to `out[i]` the dot product of `fib`
/// with the `period`-long window of `data` ending at `i`, accumulating two lanes at a time
/// and handling an odd trailing element separately.
///
/// # Safety
///
/// Uses raw 128-bit loads; every load address is derived from a bounds-checked element
/// reference whose successor also lies inside the window.
///
/// # Panics
///
/// Panics if `fib.len() != period` or `out.len() < data.len()`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn fwma_simd128(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    const LANES: usize = 2;
    let paired_len = (period / LANES) * LANES;
    let has_tail = period % LANES != 0;

    for i in (first_val + period - 1)..data.len() {
        let begin = i + 1 - period;
        let window = &data[begin..begin + period];

        let mut acc = f64x2_splat(0.0);
        for k in (0..paired_len).step_by(LANES) {
            let x = v128_load(&window[k] as *const f64 as *const v128);
            let w = v128_load(&fib[k] as *const f64 as *const v128);
            acc = f64x2_add(acc, f64x2_mul(x, w));
        }

        let mut total = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        if has_tail {
            total += window[paired_len] * fib[paired_len];
        }

        out[i] = total;
    }
}
499
/// Sums the four `f64` lanes of `v`.
///
/// The lanes are combined as `(v2 + v3) + (v0 + v1)`. The result is read directly after the
/// 128-bit add: doubling it with another horizontal add and then halving it would be wasted
/// work and would overflow to infinity for sums above `f64::MAX / 2`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256d) -> f64 {
    // [v0+v1, v0+v1, v2+v3, v2+v3]
    let pairs = _mm256_hadd_pd(v, v);
    let high = _mm256_extractf128_pd(pairs, 1);
    let low = _mm256_castpd256_pd128(pairs);
    // Both lanes now hold the full sum; lane 0 is the answer.
    _mm_cvtsd_f64(_mm_add_pd(high, low))
}
510
/// AVX-512 FWMA kernel for short windows (`fwma_avx512` dispatches here for `period <= 32`).
///
/// For every index `idx >= first + period - 1`, writes to `out[idx]` the dot product of `fib`
/// with the `period`-long window of `data` ending at `idx`. The weights are copied once into
/// an aligned scratch buffer and preloaded into registers; a masked load covers the final
/// `period % 8` elements.
///
/// # Safety
///
/// * The CPU must support AVX-512F and FMA.
/// * `out.len()` must be at least `data.len()`: results are stored through a raw pointer
///   without bounds checks.
/// * `first + period - 1` must not underflow, and `period` must not exceed `data.len()`.
///
/// # Panics
///
/// Panics if `fib.len() != period` (slice copy length mismatch).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn fwma_avx512_short(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    const SIMD_WIDTH: usize = 8;
    let simd_chunks = period / SIMD_WIDTH;
    let remainder = period % SIMD_WIDTH;
    // Number of groups of four full vectors; loop-invariant.
    let chunks4 = simd_chunks / 4;

    let data_ptr = data.as_ptr();
    let out_ptr = out.as_mut_ptr();

    // NOTE(review): assumes `AlignedVec::with_capacity(n)` exposes `n` elements through
    // `as_mut_slice` and is 64-byte aligned, as the aligned loads below require -- confirm.
    let mut aligned_fib = AlignedVec::with_capacity(period + SIMD_WIDTH);
    let fib_buf = aligned_fib.as_mut_slice();
    fib_buf[..period].copy_from_slice(fib);
    let fib_ptr = fib_buf.as_ptr();

    let mut fib_vecs = Vec::with_capacity(simd_chunks);
    for chunk in 0..simd_chunks {
        fib_vecs.push(_mm512_load_pd(fib_ptr.add(chunk * SIMD_WIDTH)));
    }

    // Low `remainder` bits set; zero when the period is a multiple of eight.
    let tail_mask: __mmask8 = (1u8 << remainder).wrapping_sub(1);

    for idx in (first + period - 1)..data.len() {
        let start = idx + 1 - period;
        let window_ptr = data_ptr.add(start);

        // The prefetch target may lie past the end of `data`. Prefetching such an address is
        // harmless, but computing it with `add` would be undefined behaviour, hence
        // `wrapping_add`.
        _mm_prefetch(window_ptr.wrapping_add(64) as *const i8, _MM_HINT_T0);

        let mut sum0 = _mm512_setzero_pd();
        let mut sum1 = _mm512_setzero_pd();
        let mut sum2 = _mm512_setzero_pd();
        let mut sum3 = _mm512_setzero_pd();

        for i in 0..chunks4 {
            let base = window_ptr.add(i * 32);
            sum0 = _mm512_fmadd_pd(_mm512_loadu_pd(base), fib_vecs[i * 4], sum0);
            sum1 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(8)), fib_vecs[i * 4 + 1], sum1);
            sum2 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(16)), fib_vecs[i * 4 + 2], sum2);
            sum3 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(24)), fib_vecs[i * 4 + 3], sum3);
        }

        // Full vectors that did not fit into a group of four.
        for i in (chunks4 * 4)..simd_chunks {
            let base = window_ptr.add(i * SIMD_WIDTH);
            sum0 = _mm512_fmadd_pd(_mm512_loadu_pd(base), fib_vecs[i], sum0);
        }

        sum0 = _mm512_add_pd(sum0, sum1);
        sum2 = _mm512_add_pd(sum2, sum3);
        sum0 = _mm512_add_pd(sum0, sum2);

        if remainder != 0 {
            let data_tail =
                _mm512_maskz_loadu_pd(tail_mask, window_ptr.add(simd_chunks * SIMD_WIDTH));
            let weight_tail =
                _mm512_maskz_load_pd(tail_mask, fib_ptr.add(simd_chunks * SIMD_WIDTH));
            sum0 = _mm512_fmadd_pd(data_tail, weight_tail, sum0);
        }

        let total = _mm512_reduce_add_pd(sum0);
        *out_ptr.add(idx) = total;
    }
}
580
/// AVX-512 FWMA kernel for long windows (`fwma_avx512` dispatches here for `period > 32`).
///
/// Produces four consecutive outputs per iteration of the main loop, sharing each preloaded
/// weight vector across the four overlapping windows, then finishes the remaining outputs
/// one at a time. Output slots before `first + period - 1` are left untouched.
///
/// # Safety
///
/// * The CPU must support AVX-512F and FMA.
/// * `first + period - 1` must not underflow and `period` must not exceed `data.len()`;
///   input windows are read through raw pointers.
///
/// # Panics
///
/// Panics if `fib.len() != period`, or if `out` is shorter than `data` (checked indexing).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn fwma_avx512_long(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    const ROWS: usize = 4;

    let full_vecs = period / LANES;
    let tail_len = period % LANES;
    let tail_off = full_vecs * LANES;

    // NOTE(review): assumes `AlignedVec::with_capacity(n)` exposes `n` elements through
    // `as_mut_slice` and is 64-byte aligned, as the aligned loads below require -- confirm.
    let mut scratch = AlignedVec::with_capacity(period + LANES);
    let padded = scratch.as_mut_slice();
    padded[..period].copy_from_slice(fib);
    let w_ptr = padded.as_ptr();

    let mut w_regs = Vec::with_capacity(full_vecs);
    for k in 0..full_vecs {
        w_regs.push(_mm512_load_pd(w_ptr.add(k * LANES)));
    }

    // Low `tail_len` bits set; the tail weights exist only when the period is not a
    // multiple of eight.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);
    let w_tail = match tail_len {
        0 => None,
        _ => Some(_mm512_maskz_load_pd(tail_mask, w_ptr.add(tail_off))),
    };

    let n = data.len();
    // Last start index (exclusive) from which four outputs still fit.
    let quad_limit = n.saturating_sub(ROWS - 1);
    let src = data.as_ptr();
    let mut i = first + period - 1;

    while i < quad_limit {
        let p0 = src.add(i + 1 - period);
        let p1 = p0.add(1);
        let p2 = p0.add(2);
        let p3 = p0.add(3);

        let mut acc0 = _mm512_setzero_pd();
        let mut acc1 = _mm512_setzero_pd();
        let mut acc2 = _mm512_setzero_pd();
        let mut acc3 = _mm512_setzero_pd();

        for (k, &w) in w_regs.iter().enumerate() {
            let off = k * LANES;
            acc0 = _mm512_fmadd_pd(_mm512_loadu_pd(p0.add(off)), w, acc0);
            acc1 = _mm512_fmadd_pd(_mm512_loadu_pd(p1.add(off)), w, acc1);
            acc2 = _mm512_fmadd_pd(_mm512_loadu_pd(p2.add(off)), w, acc2);
            acc3 = _mm512_fmadd_pd(_mm512_loadu_pd(p3.add(off)), w, acc3);
        }

        if let Some(wt) = w_tail {
            acc0 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(tail_mask, p0.add(tail_off)), wt, acc0);
            acc1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(tail_mask, p1.add(tail_off)), wt, acc1);
            acc2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(tail_mask, p2.add(tail_off)), wt, acc2);
            acc3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(tail_mask, p3.add(tail_off)), wt, acc3);
        }

        out[i] = _mm512_reduce_add_pd(acc0);
        out[i + 1] = _mm512_reduce_add_pd(acc1);
        out[i + 2] = _mm512_reduce_add_pd(acc2);
        out[i + 3] = _mm512_reduce_add_pd(acc3);

        i += ROWS;
    }

    // Up to three outputs remain after the unrolled loop.
    for idx in i..n {
        let p = src.add(idx + 1 - period);
        let mut acc = _mm512_setzero_pd();

        for (k, &w) in w_regs.iter().enumerate() {
            acc = _mm512_fmadd_pd(_mm512_loadu_pd(p.add(k * LANES)), w, acc);
        }

        if let Some(wt) = w_tail {
            acc = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(tail_mask, p.add(tail_off)), wt, acc);
        }

        out[idx] = _mm512_reduce_add_pd(acc);
    }
}
681
/// AVX-512 FWMA entry point: windows of up to 32 elements go to the short-window kernel,
/// longer ones to the long-window kernel.
///
/// # Safety
///
/// The CPU must support AVX-512F and FMA, and the arguments must satisfy the requirements of
/// the selected kernel (`fib.len() == period`, `out.len() >= data.len()`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn fwma_avx512(data: &[f64], fib: &[f64], period: usize, first: usize, out: &mut [f64]) {
    match period {
        0..=32 => fwma_avx512_short(data, fib, period, first, out),
        _ => fwma_avx512_long(data, fib, period, first, out),
    }
}
691
/// AVX2 + FMA FWMA kernel.
///
/// For every index `i >= first_valid + period - 1`, writes to `out[i]` the dot product of
/// `fib` with the `period`-long window of `data` ending at `i`. The window is consumed eight
/// elements per step (two accumulators), then one optional four-element step, then a scalar
/// loop over the last `period % 4` elements.
///
/// `fma` is listed in `target_feature` because the body issues `_mm256_fmadd_pd`; with only
/// `avx2` enabled the intrinsic's own feature set would not match the enclosing function's.
///
/// # Safety
///
/// * The CPU must support AVX2 and FMA.
/// * `out.len()` must be at least `data.len()`: results are stored through a raw pointer
///   without bounds checks.
/// * `first_valid + period - 1` must not underflow, and `period` must not exceed `data.len()`.
///
/// # Panics
///
/// Panics if `fib.len() != period` (slice copy length mismatch).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn fwma_avx2(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const W: usize = 4;
    let tail = period % W;

    // NOTE(review): assumes `AlignedVec::with_capacity(n)` exposes `n` elements through
    // `as_mut_slice` and is at least 32-byte aligned, as `_mm256_load_pd` requires -- confirm.
    let mut aligned = AlignedVec::with_capacity(period + W);
    let fib_aln = aligned.as_mut_slice();
    fib_aln[..period].copy_from_slice(fib);
    let wptr = fib_aln.as_ptr();

    let dptr = data.as_ptr();
    let optr = out.as_mut_ptr();

    for i in (first_valid + period - 1)..data.len() {
        let base = dptr.add(i + 1 - period);

        let mut acc0 = _mm256_setzero_pd();
        let mut acc1 = _mm256_setzero_pd();

        let mut j = 0;
        while j + 8 <= period {
            let v0 = _mm256_loadu_pd(base.add(j));
            let w0 = _mm256_load_pd(wptr.add(j));
            acc0 = _mm256_fmadd_pd(v0, w0, acc0);

            let v1 = _mm256_loadu_pd(base.add(j + 4));
            let w1 = _mm256_load_pd(wptr.add(j + 4));
            acc1 = _mm256_fmadd_pd(v1, w1, acc1);

            j += 8;
        }
        if j + 4 <= period {
            let v = _mm256_loadu_pd(base.add(j));
            let w = _mm256_load_pd(wptr.add(j));
            acc0 = _mm256_fmadd_pd(v, w, acc0);
        }

        let sum_vec = _mm256_add_pd(acc0, acc1);
        let mut sum = horizontal_sum_avx2(sum_vec);

        // Scalar clean-up for the last `period % 4` elements of the window.
        for k in 0..tail {
            let idx = period - tail + k;
            sum += *base.add(idx) * *wptr.add(idx);
        }
        *optr.add(i) = sum;
    }
}
748
749#[derive(Debug, Clone)]
750pub struct FwmaStream {
751    period: usize,
752
753    w: Vec<f64>,
754    w0: f64,
755    w_last: f64,
756    w_prev: f64,
757    w_next: f64,
758
759    buffer: Vec<f64>,
760    head: usize,
761    filled: bool,
762
763    acc_n: f64,
764    acc_d: f64,
765
766    nan_count: usize,
767}
768
769impl FwmaStream {
770    pub fn try_new(params: FwmaParams) -> Result<Self, FwmaError> {
771        let period = params.period.unwrap_or(5);
772        if period == 0 {
773            return Err(FwmaError::InvalidPeriod {
774                period,
775                data_len: 0,
776            });
777        }
778
779        let mut w = vec![1.0; period];
780        for i in 2..period {
781            w[i] = w[i - 1] + w[i - 2];
782        }
783
784        let sum: f64 = w.iter().sum();
785        if sum == 0.0 {
786            return Err(FwmaError::ZeroFibonacciSum);
787        }
788        for wi in &mut w {
789            *wi /= sum;
790        }
791
792        let w0 = w[0];
793        let w_last = w[period - 1];
794        let w_prev = if period > 1 { w[period - 2] } else { 0.0 };
795        let w_next = w_last + w_prev;
796
797        Ok(Self {
798            period,
799            w,
800            w0,
801            w_last,
802            w_prev,
803            w_next,
804            buffer: vec![f64::NAN; period],
805            head: 0,
806            filled: false,
807            acc_n: 0.0,
808            acc_d: 0.0,
809            nan_count: 0,
810        })
811    }
812
813    #[inline(always)]
814    pub fn update(&mut self, value: f64) -> Option<f64> {
815        let old_raw = self.buffer[self.head];
816
817        if self.filled && old_raw.is_nan() {
818            self.nan_count = self.nan_count.saturating_sub(1);
819        }
820
821        self.buffer[self.head] = value;
822        if value.is_nan() {
823            self.nan_count += 1;
824        }
825
826        self.head += 1;
827        if self.head == self.period {
828            self.head = 0;
829        }
830
831        if !self.filled {
832            if self.head != 0 {
833                return None;
834            }
835
836            self.filled = true;
837
838            let mut n = 0.0f64;
839            let mut q = 0.0f64;
840            let last = self.period - 1;
841            for j in 0..self.period {
842                let x = self.buffer[j];
843                let xv = if x.is_nan() { 0.0 } else { x };
844
845                n += xv * self.w[j];
846
847                let wn = if j < last { self.w[j + 1] } else { self.w_next };
848                q += xv * wn;
849            }
850            self.acc_n = n;
851            self.acc_d = q - n;
852
853            return Some(if self.nan_count == 0 { n } else { f64::NAN });
854        }
855
856        let x_old = if old_raw.is_nan() { 0.0 } else { old_raw };
857        let x_new = if value.is_nan() { 0.0 } else { value };
858
859        let prev_n = self.acc_n;
860
861        let d_old = self.acc_d;
862        let n_prime = x_new.mul_add(self.w_last, d_old);
863
864        let d_prime = x_new.mul_add(self.w_prev, prev_n - d_old - self.w0 * x_old);
865
866        self.acc_n = n_prime;
867        self.acc_d = d_prime;
868
869        Some(if self.nan_count == 0 {
870            self.dot_ring()
871        } else {
872            f64::NAN
873        })
874    }
875}
876
877impl FwmaStream {
878    #[inline(always)]
879    fn dot_ring(&self) -> f64 {
880        let mut sum = 0.0;
881        let mut idx = self.head;
882        for &wj in &self.w {
883            sum += wj * self.buffer[idx];
884            idx += 1;
885            if idx == self.period {
886                idx = 0;
887            }
888        }
889        sum
890    }
891}
892
/// Parameter sweep for the batch API.
#[derive(Debug, Clone)]
pub struct FwmaBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for FwmaBatchRange {
    /// Periods 5 through 254 in steps of 1.
    fn default() -> Self {
        let period = (5, 254, 1);
        FwmaBatchRange { period }
    }
}
904
905#[derive(Clone, Debug, Default)]
906pub struct FwmaBatchBuilder {
907    range: FwmaBatchRange,
908    kernel: Kernel,
909}
910
911impl FwmaBatchBuilder {
912    pub fn new() -> Self {
913        Self::default()
914    }
915    pub fn kernel(mut self, k: Kernel) -> Self {
916        self.kernel = k;
917        self
918    }
919
920    #[inline]
921    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
922        self.range.period = (start, end, step);
923        self
924    }
925    #[inline]
926    pub fn period_static(mut self, p: usize) -> Self {
927        self.range.period = (p, p, 0);
928        self
929    }
930    pub fn apply_slice(self, data: &[f64]) -> Result<FwmaBatchOutput, FwmaError> {
931        fwma_batch_with_kernel(data, &self.range, self.kernel)
932    }
933    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<FwmaBatchOutput, FwmaError> {
934        FwmaBatchBuilder::new().kernel(k).apply_slice(data)
935    }
936    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<FwmaBatchOutput, FwmaError> {
937        let slice = source_type(c, src);
938        self.apply_slice(slice)
939    }
940    pub fn with_default_candles(c: &Candles) -> Result<FwmaBatchOutput, FwmaError> {
941        FwmaBatchBuilder::new()
942            .kernel(Kernel::Auto)
943            .apply_candles(c, "close")
944    }
945}
946
947pub fn fwma_batch_with_kernel(
948    data: &[f64],
949    sweep: &FwmaBatchRange,
950    k: Kernel,
951) -> Result<FwmaBatchOutput, FwmaError> {
952    let kernel = match k {
953        Kernel::Auto => detect_best_batch_kernel(),
954        other if other.is_batch() => other,
955        other => {
956            return Err(FwmaError::InvalidKernelForBatch(other));
957        }
958    };
959    let simd = match kernel {
960        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
961        Kernel::Avx512Batch => Kernel::Avx512,
962        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
963        Kernel::Avx2Batch => Kernel::Avx2,
964        Kernel::ScalarBatch => Kernel::Scalar,
965        _ => unreachable!(),
966    };
967    fwma_batch_par_slice(data, sweep, simd)
968}
969
970#[derive(Clone, Debug)]
971pub struct FwmaBatchOutput {
972    pub values: Vec<f64>,
973    pub combos: Vec<FwmaParams>,
974    pub rows: usize,
975    pub cols: usize,
976}
977impl FwmaBatchOutput {
978    pub fn row_for_params(&self, p: &FwmaParams) -> Option<usize> {
979        self.combos
980            .iter()
981            .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
982    }
983    pub fn values_for(&self, p: &FwmaParams) -> Option<&[f64]> {
984        self.row_for_params(p).map(|row| {
985            let start = row * self.cols;
986            &self.values[start..start + self.cols]
987        })
988    }
989}
990
991#[inline(always)]
992fn expand_grid(r: &FwmaBatchRange) -> Vec<FwmaParams> {
993    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
994        if step == 0 || start == end {
995            return vec![start];
996        }
997        let (lo, hi) = if start <= end {
998            (start, end)
999        } else {
1000            (end, start)
1001        };
1002        (lo..=hi).step_by(step).collect()
1003    }
1004    let periods = axis_usize(r.period);
1005    let mut out = Vec::with_capacity(periods.len());
1006    for &p in &periods {
1007        out.push(FwmaParams { period: Some(p) });
1008    }
1009    out
1010}
1011
/// Writes NaN into the leading cells of each row of a flat row-major buffer.
///
/// Row `r` starts at `r * cols`; its first `min(warmup_periods[r], cols)`
/// cells are overwritten. `rows` is not consulted: the number of rows
/// processed is `warmup_periods.len()`.
///
/// # Panics
/// Panics if a cell to be written lies outside `out_slice`.
#[inline(always)]
fn fill_nan_prefixes_slice(
    rows: usize,
    cols: usize,
    warmup_periods: &[usize],
    out_slice: &mut [f64],
) {
    let _ = rows;
    for (row_idx, warmup) in warmup_periods.iter().copied().enumerate() {
        let begin = row_idx * cols;
        let nan_count = warmup.min(cols);
        (begin..begin + nan_count).for_each(|cell| out_slice[cell] = f64::NAN);
    }
}
1027
1028#[inline(always)]
1029pub fn fwma_batch_slice(
1030    data: &[f64],
1031    sweep: &FwmaBatchRange,
1032    kern: Kernel,
1033) -> Result<FwmaBatchOutput, FwmaError> {
1034    fwma_batch_inner(data, sweep, kern, false)
1035}
1036
1037#[inline(always)]
1038pub fn fwma_batch_par_slice(
1039    data: &[f64],
1040    sweep: &FwmaBatchRange,
1041    kern: Kernel,
1042) -> Result<FwmaBatchOutput, FwmaError> {
1043    fwma_batch_inner(data, sweep, kern, true)
1044}
1045
1046#[inline(always)]
1047fn fwma_batch_inner(
1048    data: &[f64],
1049    sweep: &FwmaBatchRange,
1050    kern: Kernel,
1051    parallel: bool,
1052) -> Result<FwmaBatchOutput, FwmaError> {
1053    let combos = expand_grid(sweep);
1054    if combos.is_empty() {
1055        let (s, e, t) = sweep.period;
1056        return Err(FwmaError::InvalidRange {
1057            start: s,
1058            end: e,
1059            step: t,
1060        });
1061    }
1062
1063    let rows = combos.len();
1064    let cols = data.len();
1065
1066    let _total = rows
1067        .checked_mul(cols)
1068        .ok_or(FwmaError::ArithmeticOverflow {
1069            context: "rows*cols in fwma_batch_inner",
1070        })?;
1071
1072    let mut buf_mu = make_uninit_matrix(rows, cols);
1073
1074    let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1075    let warm: Vec<usize> = combos
1076        .iter()
1077        .map(|c| first + c.period.unwrap() - 1)
1078        .collect();
1079
1080    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1081
1082    let mut buf_guard = ManuallyDrop::new(buf_mu);
1083    let values_slice: &mut [f64] = unsafe {
1084        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1085    };
1086
1087    fwma_batch_inner_into(data, sweep, kern, parallel, values_slice)?;
1088
1089    let values = unsafe {
1090        Vec::from_raw_parts(
1091            buf_guard.as_mut_ptr() as *mut f64,
1092            buf_guard.len(),
1093            buf_guard.capacity(),
1094        )
1095    };
1096
1097    Ok(FwmaBatchOutput {
1098        values,
1099        combos,
1100        rows,
1101        cols,
1102    })
1103}
1104#[inline(always)]
1105fn fwma_batch_inner_into(
1106    data: &[f64],
1107    sweep: &FwmaBatchRange,
1108    kern: Kernel,
1109    parallel: bool,
1110    out: &mut [f64],
1111) -> Result<Vec<FwmaParams>, FwmaError> {
1112    let combos = expand_grid(sweep);
1113    if combos.is_empty() {
1114        let (s, e, t) = sweep.period;
1115        return Err(FwmaError::InvalidRange {
1116            start: s,
1117            end: e,
1118            step: t,
1119        });
1120    }
1121
1122    let first = data
1123        .iter()
1124        .position(|x| !x.is_nan())
1125        .ok_or(FwmaError::AllValuesNaN)?;
1126    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1127    if data.len() - first < max_p {
1128        return Err(FwmaError::NotEnoughValidData {
1129            needed: max_p,
1130            valid: data.len() - first,
1131        });
1132    }
1133
1134    let rows = combos.len();
1135    let cols = data.len();
1136
1137    let cap = rows
1138        .checked_mul(max_p)
1139        .ok_or(FwmaError::ArithmeticOverflow {
1140            context: "rows*max_p in fwma_batch_inner_into",
1141        })?;
1142    let mut aligned = AlignedVec::with_capacity(cap);
1143    let flat_fib = aligned.as_mut_slice();
1144
1145    for (row, prm) in combos.iter().enumerate() {
1146        let period = prm.period.unwrap();
1147        let base = row * max_p;
1148        let slice = &mut flat_fib[base..base + period];
1149        slice[0] = 1.0;
1150        if period > 1 {
1151            slice[1] = 1.0;
1152        }
1153        for i in 2..period {
1154            slice[i] = slice[i - 1] + slice[i - 2];
1155        }
1156        let sum: f64 = slice[..period].iter().sum();
1157        for w in &mut slice[..period] {
1158            *w /= sum;
1159        }
1160    }
1161
1162    let do_row = |row: usize, dst: &mut [f64]| unsafe {
1163        let period = combos[row].period.unwrap();
1164        let fib_ptr = flat_fib.as_ptr().add(row * max_p);
1165
1166        match kern {
1167            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1168            Kernel::Avx512 => fwma_row_avx512(data, first, period, max_p, fib_ptr, dst),
1169            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1170            Kernel::Avx2 => fwma_row_avx2(data, first, period, max_p, fib_ptr, dst),
1171            _ => fwma_row_scalar(data, first, period, max_p, fib_ptr, dst),
1172        }
1173    };
1174
1175    if parallel {
1176        #[cfg(not(target_arch = "wasm32"))]
1177        {
1178            out.par_chunks_mut(cols)
1179                .enumerate()
1180                .for_each(|(row, slice)| do_row(row, slice));
1181        }
1182        #[cfg(target_arch = "wasm32")]
1183        {
1184            for (row, slice) in out.chunks_mut(cols).enumerate() {
1185                do_row(row, slice);
1186            }
1187        }
1188    } else {
1189        for (row, slice) in out.chunks_mut(cols).enumerate() {
1190            do_row(row, slice);
1191        }
1192    }
1193
1194    Ok(combos)
1195}
1196
1197#[inline]
1198unsafe fn fwma_row_scalar(
1199    data: &[f64],
1200    first: usize,
1201    period: usize,
1202    _stride: usize,
1203    fib_ptr: *const f64,
1204    out: &mut [f64],
1205) {
1206    let fib = std::slice::from_raw_parts(fib_ptr, period);
1207    fwma_scalar(data, fib, period, first, out);
1208}
1209
/// AVX2 row worker: views the `period` weights behind `fib_ptr` as a slice
/// and forwards to `fwma_avx2`. `_stride` is unused.
///
/// # Safety
/// The running CPU must support AVX2, and `fib_ptr` must be properly aligned
/// and valid for reads of `period` consecutive `f64` values for the duration
/// of the call.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn fwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    fib_ptr: *const f64,
    out: &mut [f64],
) {
    let weights: &[f64] = std::slice::from_raw_parts(fib_ptr, period);
    fwma_avx2(data, weights, period, first, out);
}
1224
/// AVX-512 row worker: views the `period` weights behind `fib_ptr` as a
/// slice and forwards to `fwma_avx512`. `_stride` is unused.
///
/// # Safety
/// The running CPU must support AVX-512F and FMA, and `fib_ptr` must be
/// properly aligned and valid for reads of `period` consecutive `f64` values
/// for the duration of the call.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn fwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    fib_ptr: *const f64,
    out: &mut [f64],
) {
    let weights: &[f64] = std::slice::from_raw_parts(fib_ptr, period);
    fwma_avx512(data, weights, period, first, out);
}
1239
1240#[cfg(test)]
1241mod tests {
1242    use super::*;
1243    use crate::skip_if_unsupported;
1244    use crate::utilities::data_loader::read_candles_from_csv;
1245
1246    fn check_fwma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1247        skip_if_unsupported!(kernel, test_name);
1248        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1249        let candles = read_candles_from_csv(file_path)?;
1250
1251        let default_params = FwmaParams { period: None };
1252        let input = FwmaInput::from_candles(&candles, "close", default_params);
1253        let output = fwma_with_kernel(&input, kernel)?;
1254        assert_eq!(output.values.len(), candles.close.len());
1255
1256        Ok(())
1257    }
1258
1259    fn check_fwma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1260        skip_if_unsupported!(kernel, test_name);
1261        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1262        let candles = read_candles_from_csv(file_path)?;
1263
1264        let input = FwmaInput::with_default_candles(&candles);
1265        let result = fwma_with_kernel(&input, kernel)?;
1266        let expected_last_five = [
1267            59273.583333333336,
1268            59252.5,
1269            59167.083333333336,
1270            59151.0,
1271            58940.333333333336,
1272        ];
1273        let start = result.values.len().saturating_sub(5);
1274        for (i, &val) in result.values[start..].iter().enumerate() {
1275            let diff = (val - expected_last_five[i]).abs();
1276            assert!(
1277                diff < 1e-8,
1278                "[{}] FWMA {:?} mismatch at idx {}: got {}, expected {}",
1279                test_name,
1280                kernel,
1281                i,
1282                val,
1283                expected_last_five[i]
1284            );
1285        }
1286        Ok(())
1287    }
1288
1289    fn check_fwma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1290        skip_if_unsupported!(kernel, test_name);
1291        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1292        let candles = read_candles_from_csv(file_path)?;
1293
1294        let input = FwmaInput::with_default_candles(&candles);
1295        match input.data {
1296            FwmaData::Candles { source, .. } => assert_eq!(source, "close"),
1297            _ => panic!("Expected FwmaData::Candles"),
1298        }
1299        let output = fwma_with_kernel(&input, kernel)?;
1300        assert_eq!(output.values.len(), candles.close.len());
1301
1302        Ok(())
1303    }
1304
1305    fn check_fwma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1306        skip_if_unsupported!(kernel, test_name);
1307        let input_data = [10.0, 20.0, 30.0];
1308        let params = FwmaParams { period: Some(0) };
1309        let input = FwmaInput::from_slice(&input_data, params);
1310        let res = fwma_with_kernel(&input, kernel);
1311        assert!(
1312            res.is_err(),
1313            "[{}] FWMA should fail with zero period",
1314            test_name
1315        );
1316        Ok(())
1317    }
1318
1319    fn check_fwma_period_exceeds_length(
1320        test_name: &str,
1321        kernel: Kernel,
1322    ) -> Result<(), Box<dyn Error>> {
1323        skip_if_unsupported!(kernel, test_name);
1324        let data_small = [10.0, 20.0, 30.0];
1325        let params = FwmaParams { period: Some(10) };
1326        let input = FwmaInput::from_slice(&data_small, params);
1327        let res = fwma_with_kernel(&input, kernel);
1328        assert!(
1329            res.is_err(),
1330            "[{}] FWMA should fail with period exceeding length",
1331            test_name
1332        );
1333        Ok(())
1334    }
1335
1336    fn check_fwma_very_small_dataset(
1337        test_name: &str,
1338        kernel: Kernel,
1339    ) -> Result<(), Box<dyn Error>> {
1340        skip_if_unsupported!(kernel, test_name);
1341        let single_point = [42.0];
1342        let params = FwmaParams { period: Some(5) };
1343        let input = FwmaInput::from_slice(&single_point, params);
1344        let res = fwma_with_kernel(&input, kernel);
1345        assert!(
1346            res.is_err(),
1347            "[{}] FWMA should fail with insufficient data",
1348            test_name
1349        );
1350        Ok(())
1351    }
1352
1353    fn check_fwma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1354        skip_if_unsupported!(kernel, test_name);
1355        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1356        let candles = read_candles_from_csv(file_path)?;
1357
1358        let first_params = FwmaParams { period: Some(5) };
1359        let first_input = FwmaInput::from_candles(&candles, "close", first_params);
1360        let first_result = fwma_with_kernel(&first_input, kernel)?;
1361
1362        let second_params = FwmaParams { period: Some(3) };
1363        let second_input = FwmaInput::from_slice(&first_result.values, second_params);
1364        let second_result = fwma_with_kernel(&second_input, kernel)?;
1365
1366        assert_eq!(second_result.values.len(), first_result.values.len());
1367        for i in 240..second_result.values.len() {
1368            assert!(
1369                !second_result.values[i].is_nan(),
1370                "[{}] NaN found at idx {}",
1371                test_name,
1372                i
1373            );
1374        }
1375        Ok(())
1376    }
1377
1378    fn check_fwma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1379        skip_if_unsupported!(kernel, test_name);
1380        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1381        let candles = read_candles_from_csv(file_path)?;
1382
1383        let input = FwmaInput::from_candles(&candles, "close", FwmaParams { period: Some(5) });
1384        let res = fwma_with_kernel(&input, kernel)?;
1385        assert_eq!(res.values.len(), candles.close.len());
1386        if res.values.len() > 50 {
1387            for (i, &val) in res.values[50..].iter().enumerate() {
1388                assert!(
1389                    !val.is_nan(),
1390                    "[{}] Found unexpected NaN at out-index {}",
1391                    test_name,
1392                    50 + i
1393                );
1394            }
1395        }
1396        Ok(())
1397    }
1398
1399    fn check_fwma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1400        skip_if_unsupported!(kernel, test_name);
1401
1402        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1403        let candles = read_candles_from_csv(file_path)?;
1404
1405        let period = 5;
1406
1407        let input = FwmaInput::from_candles(
1408            &candles,
1409            "close",
1410            FwmaParams {
1411                period: Some(period),
1412            },
1413        );
1414        let batch_output = fwma_with_kernel(&input, kernel)?.values;
1415
1416        let mut stream = FwmaStream::try_new(FwmaParams {
1417            period: Some(period),
1418        })?;
1419
1420        let mut stream_values = Vec::with_capacity(candles.close.len());
1421        for &price in &candles.close {
1422            match stream.update(price) {
1423                Some(val) => stream_values.push(val),
1424                None => stream_values.push(f64::NAN),
1425            }
1426        }
1427
1428        assert_eq!(batch_output.len(), stream_values.len());
1429        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1430            if b.is_nan() && s.is_nan() {
1431                continue;
1432            }
1433            let diff = (b - s).abs();
1434            assert!(
1435                diff < 1e-9,
1436                "[{}] FWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1437                test_name,
1438                i,
1439                b,
1440                s,
1441                diff
1442            );
1443        }
1444        Ok(())
1445    }
1446
1447    #[cfg(debug_assertions)]
1448    fn check_fwma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1449        skip_if_unsupported!(kernel, test_name);
1450
1451        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1452        let candles = read_candles_from_csv(file_path)?;
1453
1454        let test_cases = vec![
1455            FwmaParams::default(),
1456            FwmaParams { period: Some(2) },
1457            FwmaParams { period: Some(3) },
1458            FwmaParams { period: Some(5) },
1459            FwmaParams { period: Some(8) },
1460            FwmaParams { period: Some(10) },
1461            FwmaParams { period: Some(15) },
1462            FwmaParams { period: Some(20) },
1463            FwmaParams { period: Some(30) },
1464            FwmaParams { period: Some(50) },
1465        ];
1466
1467        for params in test_cases {
1468            let input = FwmaInput::from_candles(&candles, "close", params.clone());
1469            let output = fwma_with_kernel(&input, kernel)?;
1470
1471            for (i, &val) in output.values.iter().enumerate() {
1472                if val.is_nan() {
1473                    continue;
1474                }
1475
1476                let bits = val.to_bits();
1477
1478                if bits == 0x11111111_11111111 {
1479                    panic!(
1480                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params period={:?}",
1481                        test_name, val, bits, i, params.period
1482                    );
1483                }
1484
1485                if bits == 0x22222222_22222222 {
1486                    panic!(
1487                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params period={:?}",
1488                        test_name, val, bits, i, params.period
1489                    );
1490                }
1491
1492                if bits == 0x33333333_33333333 {
1493                    panic!(
1494						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params period={:?}",
1495						test_name, val, bits, i, params.period
1496					);
1497                }
1498            }
1499        }
1500
1501        Ok(())
1502    }
1503
1504    #[cfg(not(debug_assertions))]
1505    fn check_fwma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1506        Ok(())
1507    }
1508
1509    macro_rules! generate_all_fwma_tests {
1510        ($($test_fn:ident),*) => {
1511            paste::paste! {
1512                $(
1513                    #[test]
1514                    fn [<$test_fn _scalar_f64>]() {
1515                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1516                    }
1517
1518                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1519                    #[test]
1520                    fn [<$test_fn _avx2_f64>]() {
1521                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1522                    }
1523
1524                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1525                    #[test]
1526                    fn [<$test_fn _avx512_f64>]() {
1527                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1528                    }
1529
1530
1531                    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1532                    #[test]
1533                    fn [<$test_fn _simd128_f64>]() {
1534                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1535                    }
1536                )*
1537            }
1538        }
1539    }
1540
1541    generate_all_fwma_tests!(
1542        check_fwma_partial_params,
1543        check_fwma_accuracy,
1544        check_fwma_default_candles,
1545        check_fwma_zero_period,
1546        check_fwma_period_exceeds_length,
1547        check_fwma_very_small_dataset,
1548        check_fwma_reinput,
1549        check_fwma_nan_handling,
1550        check_fwma_streaming,
1551        check_fwma_no_poison
1552    );
1553
1554    #[cfg(feature = "proptest")]
1555    generate_all_fwma_tests!(check_fwma_property);
1556
1557    #[test]
1558    fn test_fwma_into_matches_api() -> Result<(), Box<dyn Error>> {
1559        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1560        let candles = read_candles_from_csv(file_path)?;
1561
1562        let input = FwmaInput::with_default_candles(&candles);
1563        let baseline = fwma(&input)?.values;
1564
1565        let mut out = vec![0.0f64; baseline.len()];
1566
1567        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1568        {
1569            fwma_into(&input, &mut out)?;
1570        }
1571        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1572        {
1573            fwma_into_slice(&mut out, &input, Kernel::Auto)?;
1574        }
1575
1576        assert_eq!(out.len(), baseline.len());
1577
1578        for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1579            let equal = (a.is_nan() && b.is_nan()) || (a == b);
1580            assert!(
1581                equal,
1582                "into parity mismatch at idx {}: got {}, expected {}",
1583                i, a, b
1584            );
1585        }
1586
1587        Ok(())
1588    }
1589
1590    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1591        skip_if_unsupported!(kernel, test);
1592
1593        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1594        let c = read_candles_from_csv(file)?;
1595
1596        let output = FwmaBatchBuilder::new()
1597            .kernel(kernel)
1598            .apply_candles(&c, "close")?;
1599
1600        let def = FwmaParams::default();
1601        let row = output.values_for(&def).expect("default row missing");
1602
1603        assert_eq!(row.len(), c.close.len());
1604
1605        let expected = [
1606            59273.583333333336,
1607            59252.5,
1608            59167.083333333336,
1609            59151.0,
1610            58940.333333333336,
1611        ];
1612        let start = row.len() - 5;
1613        for (i, &v) in row[start..].iter().enumerate() {
1614            assert!(
1615                (v - expected[i]).abs() < 1e-8,
1616                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1617            );
1618        }
1619        Ok(())
1620    }
1621
1622    #[cfg(debug_assertions)]
1623    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1624        skip_if_unsupported!(kernel, test);
1625
1626        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1627        let c = read_candles_from_csv(file)?;
1628
1629        let test_configs = vec![
1630            (2, 5, 1),
1631            (5, 15, 2),
1632            (10, 30, 5),
1633            (3, 10, 1),
1634            (5, 50, 10),
1635            (2, 20, 1),
1636        ];
1637
1638        for (start, end, step) in test_configs {
1639            let output = FwmaBatchBuilder::new()
1640                .kernel(kernel)
1641                .period_range(start, end, step)
1642                .apply_candles(&c, "close")?;
1643
1644            for (idx, &val) in output.values.iter().enumerate() {
1645                if val.is_nan() {
1646                    continue;
1647                }
1648
1649                let bits = val.to_bits();
1650                let row = idx / output.cols;
1651                let col = idx % output.cols;
1652                let params = &output.combos[row];
1653
1654                if bits == 0x11111111_11111111 {
1655                    panic!(
1656                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1657                        test, val, bits, row, col, params.period
1658                    );
1659                }
1660
1661                if bits == 0x22222222_22222222 {
1662                    panic!(
1663                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1664                        test, val, bits, row, col, params.period
1665                    );
1666                }
1667
1668                if bits == 0x33333333_33333333 {
1669                    panic!(
1670                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1671                        test, val, bits, row, col, params.period
1672                    );
1673                }
1674            }
1675        }
1676
1677        Ok(())
1678    }
1679
1680    #[cfg(not(debug_assertions))]
1681    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1682        Ok(())
1683    }
1684
1685    #[cfg(feature = "proptest")]
1686    #[allow(clippy::float_cmp)]
1687    fn check_fwma_property(
1688        test_name: &str,
1689        kernel: Kernel,
1690    ) -> Result<(), Box<dyn std::error::Error>> {
1691        use proptest::prelude::*;
1692        skip_if_unsupported!(kernel, test_name);
1693
1694        let strat = (1usize..=64).prop_flat_map(|period| {
1695            (
1696                prop::collection::vec(
1697                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1698                    period..400,
1699                ),
1700                Just(period),
1701            )
1702        });
1703
1704        proptest::test_runner::TestRunner::default()
1705            .run(&strat, |(mut data, period)| {
1706                if data.len() > period && period > 1 {
1707                    if data.len() % 10 == 0 {
1708                        data.truncate(period);
1709                    }
1710                }
1711
1712                let params = FwmaParams {
1713                    period: Some(period),
1714                };
1715                let input = FwmaInput::from_slice(&data, params);
1716
1717                let FwmaOutput { values: out } = fwma_with_kernel(&input, kernel).unwrap();
1718
1719                let FwmaOutput { values: ref_out } =
1720                    fwma_with_kernel(&input, Kernel::Scalar).unwrap();
1721
1722                prop_assert_eq!(out.len(), data.len());
1723                prop_assert_eq!(ref_out.len(), data.len());
1724
1725                for i in 0..(period - 1).min(data.len()) {
1726                    prop_assert!(
1727                        out[i].is_nan(),
1728                        "Expected NaN during warmup at index {}, got {}",
1729                        i,
1730                        out[i]
1731                    );
1732                }
1733
1734                if period == 2 && data.len() >= 2 {
1735                    let expected = (data[0] + data[1]) / 2.0;
1736                    if out[1].is_finite() && data[0].is_finite() && data[1].is_finite() {
1737                        prop_assert!(
1738                            (out[1] - expected).abs() <= 1e-9,
1739                            "Period=2: output {} should equal average {} at index 1",
1740                            out[1],
1741                            expected
1742                        );
1743                    }
1744                }
1745
1746                for i in (period - 1)..data.len() {
1747                    let window = &data[i + 1 - period..=i];
1748                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
1749                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1750                    let y = out[i];
1751                    let r = ref_out[i];
1752
1753                    prop_assert!(
1754                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
1755                        "idx {}: {} ∉ [{}, {}]",
1756                        i,
1757                        y,
1758                        lo,
1759                        hi
1760                    );
1761
1762                    if period == 1 {
1763                        prop_assert!(
1764                            (y - data[i]).abs() <= f64::EPSILON,
1765                            "Period=1: output {} should equal input {} at index {}",
1766                            y,
1767                            data[i],
1768                            i
1769                        );
1770                    }
1771
1772                    if data.windows(2).all(|w| w[0] == w[1]) && !data.is_empty() {
1773                        prop_assert!(
1774                            (y - data[0]).abs() <= 1e-9,
1775                            "Constant data: output {} should equal constant {} at index {}",
1776                            y,
1777                            data[0],
1778                            i
1779                        );
1780                    }
1781
1782                    if window.iter().any(|x| x.is_nan()) {
1783                        prop_assert!(
1784                            y.is_nan(),
1785                            "Window contains NaN but output {} is not NaN at index {}",
1786                            y,
1787                            i
1788                        );
1789                    }
1790
1791                    if !y.is_finite() || !r.is_finite() {
1792                        prop_assert!(
1793                            y.to_bits() == r.to_bits(),
1794                            "finite/NaN mismatch idx {}: {} vs {}",
1795                            i,
1796                            y,
1797                            r
1798                        );
1799                        continue;
1800                    }
1801
1802                    let y_bits = y.to_bits();
1803                    let r_bits = r.to_bits();
1804                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
1805
1806                    prop_assert!(
1807                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1808                        "mismatch idx {}: {} vs {} (ULP={})",
1809                        i,
1810                        y,
1811                        r,
1812                        ulp_diff
1813                    );
1814                }
1815
1816                let is_monotonic_inc = data.windows(2).all(|w| w[0] <= w[1]);
1817                let is_monotonic_dec = data.windows(2).all(|w| w[0] >= w[1]);
1818
1819                if (is_monotonic_inc || is_monotonic_dec) && data.len() >= period + 1 {
1820                    for i in period..out.len() {
1821                        if out[i].is_finite() && out[i - 1].is_finite() {
1822                            if is_monotonic_inc {
1823                                prop_assert!(
1824									out[i] >= out[i-1] - 1e-9,
1825									"Monotonic increasing data but output decreases: {} < {} at index {}",
1826									out[i],
1827									out[i-1],
1828									i
1829								);
1830                            }
1831                            if is_monotonic_dec {
1832                                prop_assert!(
1833									out[i] <= out[i-1] + 1e-9,
1834									"Monotonic decreasing data but output increases: {} > {} at index {}",
1835									out[i],
1836									out[i-1],
1837									i
1838								);
1839                            }
1840                        }
1841                    }
1842                }
1843
1844                if period >= 3 && data.len() >= period * 2 {
1845                    let test_start = period;
1846                    if test_start + period <= data.len() {
1847                        let all_ascending = (0..period).all(|j| {
1848                            let idx = test_start + j;
1849                            idx == 0
1850                                || !data[idx].is_finite()
1851                                || !data[idx - 1].is_finite()
1852                                || data[idx] >= data[idx - 1]
1853                        });
1854
1855                        if all_ascending && out[test_start + period - 1].is_finite() {
1856                            let window = &data[test_start..test_start + period];
1857                            let window_avg = window.iter().sum::<f64>() / period as f64;
1858                            if window.iter().all(|x| x.is_finite()) {
1859                                prop_assert!(
1860                                    out[test_start + period - 1] >= window_avg - 1e-9,
1861                                    "FWMA {} should be >= average {} for ascending window",
1862                                    out[test_start + period - 1],
1863                                    window_avg
1864                                );
1865                            }
1866                        }
1867                    }
1868                }
1869
1870                if period > 1 && data.len() >= period * 2 {
1871                    let data_range = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b.abs()));
1872                    for &val in &out[(period - 1)..] {
1873                        if val.is_finite() && data_range > 0.0 {
1874                            prop_assert!(
1875                                val.abs() <= data_range * 1.1,
1876                                "Output {} exceeds reasonable bounds for data range {}",
1877                                val,
1878                                data_range
1879                            );
1880                        }
1881                    }
1882                }
1883
1884                if period == 3 && data.len() >= 3 {
1885                    let idx = period - 1;
1886                    if data[idx - 2].is_finite()
1887                        && data[idx - 1].is_finite()
1888                        && data[idx].is_finite()
1889                    {
1890                        let expected =
1891                            data[idx - 2] * 0.25 + data[idx - 1] * 0.25 + data[idx] * 0.5;
1892                        prop_assert!(
1893                            (out[idx] - expected).abs() <= 1e-9,
1894                            "Period=3: output {} should equal weighted avg {} at index {}",
1895                            out[idx],
1896                            expected,
1897                            idx
1898                        );
1899                    }
1900                }
1901
1902                Ok(())
1903            })
1904            .unwrap();
1905
1906        let nan_strat = (2usize..=10).prop_flat_map(|period| (Just(period), 1usize..10));
1907
1908        proptest::test_runner::TestRunner::default()
1909            .run(&nan_strat, |(period, nan_pos)| {
1910                let mut data = vec![
1911                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1912                ];
1913                if nan_pos < data.len() {
1914                    data[nan_pos] = f64::NAN;
1915                }
1916
1917                let params = FwmaParams {
1918                    period: Some(period),
1919                };
1920                let input = FwmaInput::from_slice(&data, params);
1921                let FwmaOutput { values: out } = fwma_with_kernel(&input, kernel).unwrap();
1922
1923                for i in (period - 1)..data.len() {
1924                    let window_start = i + 1 - period;
1925                    let window = &data[window_start..=i];
1926                    let has_nan = window.iter().any(|x| x.is_nan());
1927
1928                    if has_nan {
1929                        prop_assert!(
1930                            out[i].is_nan(),
1931                            "Window [{}, {}] contains NaN but output {} is not NaN at index {}",
1932                            window_start,
1933                            i,
1934                            out[i],
1935                            i
1936                        );
1937                    }
1938                }
1939                Ok(())
1940            })
1941            .unwrap();
1942
1943        let extreme_strat = (1usize..=10).prop_flat_map(|period| (Just(period), prop::bool::ANY));
1944
1945        proptest::test_runner::TestRunner::default()
1946            .run(&extreme_strat, |(period, use_max)| {
1947                let extreme_val = if use_max { 1e308 } else { 1e-308 };
1948                let data = vec![extreme_val; period * 2];
1949
1950                let params = FwmaParams {
1951                    period: Some(period),
1952                };
1953                let input = FwmaInput::from_slice(&data, params);
1954                let result = fwma_with_kernel(&input, kernel);
1955
1956                prop_assert!(result.is_ok(), "Failed to handle extreme values");
1957
1958                if let Ok(FwmaOutput { values: out }) = result {
1959                    for i in (period - 1)..data.len() {
1960                        if out[i].is_finite() {
1961                            prop_assert!(
1962                                (out[i] - extreme_val).abs() / extreme_val.abs() <= 1e-9,
1963                                "Extreme constant value {} doesn't match output {} at index {}",
1964                                extreme_val,
1965                                out[i],
1966                                i
1967                            );
1968                        }
1969                    }
1970                }
1971                Ok(())
1972            })
1973            .unwrap();
1974
1975        Ok(())
1976    }
1977    macro_rules! gen_batch_tests {
1978        ($fn_name:ident) => {
1979            paste::paste! {
1980                #[test]
1981                fn [<$fn_name _scalar>]() {
1982                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1983                }
1984                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1985                #[test]
1986                fn [<$fn_name _avx2>]() {
1987                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1988                }
1989                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1990                #[test]
1991                fn [<$fn_name _avx512>]() {
1992                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1993                }
1994                #[test]
1995                fn [<$fn_name _auto_detect>]() {
1996                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1997                }
1998            }
1999        };
2000    }
2001    gen_batch_tests!(check_batch_default_row);
2002    gen_batch_tests!(check_batch_no_poison);
2003
2004    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2005    #[test]
2006    fn test_fwma_batch_into_warmup() {
2007        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2008        let len = data.len();
2009
2010        let period_start = 2;
2011        let period_end = 4;
2012        let period_step = 1;
2013
2014        let sweep = FwmaBatchRange {
2015            period: (period_start, period_end, period_step),
2016        };
2017        let combos = expand_grid(&sweep);
2018        let rows = combos.len();
2019
2020        let mut output = vec![999.0; rows * len];
2021
2022        let result = unsafe {
2023            fwma_batch_into(
2024                data.as_ptr(),
2025                output.as_mut_ptr(),
2026                len,
2027                period_start,
2028                period_end,
2029                period_step,
2030            )
2031        };
2032
2033        assert!(result.is_ok());
2034        assert_eq!(result.unwrap(), rows);
2035
2036        for (r, params) in combos.iter().enumerate() {
2037            let period = params.period.unwrap();
2038            let warmup_end = period - 1;
2039
2040            for i in 0..warmup_end {
2041                let idx = r * len + i;
2042                assert!(
2043                    output[idx].is_nan(),
2044                    "Expected NaN at row {} col {} (period {}) but got {}",
2045                    r,
2046                    i,
2047                    period,
2048                    output[idx]
2049                );
2050            }
2051
2052            let first_valid_idx = r * len + warmup_end;
2053            assert!(
2054                !output[first_valid_idx].is_nan(),
2055                "Expected valid value at row {} col {} (period {}) but got NaN",
2056                r,
2057                warmup_end,
2058                period
2059            );
2060        }
2061    }
2062
2063    #[test]
2064    fn print_actual_fwma_values() {
2065        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2066        let candles = read_candles_from_csv(file_path).unwrap();
2067
2068        println!("\nCandle data info:");
2069        println!("Total candles: {}", candles.close.len());
2070        println!("Last 10 close prices:");
2071        let close_len = candles.close.len();
2072        for i in (close_len - 10)..close_len {
2073            println!("  [{}]: {}", i - (close_len - 10), candles.close[i]);
2074        }
2075
2076        let input = FwmaInput::with_default_candles(&candles);
2077        let result = fwma(&input).unwrap();
2078
2079        println!("\nFWMA results (period = 5):");
2080        println!("Last 5 values:");
2081        let result_len = result.values.len();
2082        for i in (result_len - 5)..result_len {
2083            println!("  [{}]: {:.12}", i - (result_len - 5), result.values[i]);
2084        }
2085
2086        let expected_last_five = [
2087            59273.583333333336,
2088            59252.5,
2089            59167.083333333336,
2090            59151.0,
2091            58940.333333333336,
2092        ];
2093
2094        println!("\nExpected values:");
2095        for (i, val) in expected_last_five.iter().enumerate() {
2096            println!("  [{}]: {:.12}", i, val);
2097        }
2098
2099        println!("\nDifferences:");
2100        for i in 0..5 {
2101            let actual = result.values[result_len - 5 + i];
2102            let expected = expected_last_five[i];
2103            println!(
2104                "  [{}]: {:.12} (diff: {:.2e})",
2105                i,
2106                actual - expected,
2107                (actual - expected).abs()
2108            );
2109        }
2110    }
2111}
2112
// Python entry point `fwma_cuda_batch_dev`: narrows the f64 input to f32, runs the CUDA batch
// FWMA for `period_range` on device `device_id`, and returns the device array wrapped together
// with the CUDA context handle and device id reported by `CudaFwma`.
// Any CUDA-side failure, or CUDA being unavailable, surfaces as a Python `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fwma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn fwma_cuda_batch_dev_py(
    py: Python<'_>,
    data: PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyArrayMethods;

    // Maps any stringifiable error onto a Python ValueError.
    fn value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f64 = data.as_slice()?;
    let sweep = FwmaBatchRange {
        period: period_range,
    };
    let host_f32: Vec<f32> = host_f64.iter().map(|&sample| sample as f32).collect();

    // The GPU work runs inside `allow_threads`; only owned/borrowed Rust data crosses into it.
    let (inner, _ctx, device_id) = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaFwma::new(device_id).map_err(value_error)?;
        let context = engine.context_arc_clone();
        let resolved_device = engine.device_id();
        let device_array = engine
            .fwma_batch_dev(&host_f32, &sweep)
            .map_err(value_error)?;
        Ok((device_array, context, resolved_device))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx,
        device_id,
        stream: 0,
    })
}
2151
// Python entry point `fwma_cuda_many_series_one_param_dev`: runs the CUDA FWMA with a single
// `period` over a 2-D f32 array and returns the resulting device array wrapped together with
// the CUDA context handle and device id reported by `CudaFwma`.
// Any CUDA-side failure, or CUDA being unavailable, surfaces as a Python `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn fwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    // Maps any stringifiable error onto a Python ValueError.
    fn value_error<E: ToString>(err: E) -> PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let dims = data_tm_f32.shape();
    let (rows, cols) = (dims[0], dims[1]);
    let params = FwmaParams {
        period: Some(period),
    };

    let (inner, _ctx, device_id) = py.allow_threads(|| -> PyResult<_> {
        let engine = CudaFwma::new(device_id).map_err(value_error)?;
        let context = engine.context_arc_clone();
        let resolved_device = engine.device_id();
        // NOTE(review): the callee receives `cols` (shape[1]) before `rows` (shape[0]) —
        // confirm this matches its parameter order.
        let device_array = engine
            .fwma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(value_error)?;
        Ok((device_array, context, resolved_device))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx,
        device_id,
        stream: 0,
    })
}
2191
// Python entry point `fwma`: computes FWMA over a 1-D f64 array for one `period` and returns
// a new NumPy array with the result. `kernel` is validated by `validate_kernel`; a computation
// error is raised as a Python `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "fwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn fwma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let samples = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let fwma_in = FwmaInput::from_slice(
        samples,
        FwmaParams {
            period: Some(period),
        },
    );

    // Only the numeric work runs inside `allow_threads`; the result vector is moved into a
    // NumPy array afterwards.
    let computed = py.allow_threads(|| fwma_with_kernel(&fwma_in, kern).map(|out| out.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
2217
// Python class `FwmaStream`: a thin wrapper that forwards to the Rust `FwmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "FwmaStream")]
pub struct FwmaStreamPy {
    stream: FwmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl FwmaStreamPy {
    // Builds the stream for `period`; a construction failure is raised as a `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        FwmaStream::try_new(FwmaParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one sample to the wrapped stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2241
// Python entry point `fwma_batch`: computes FWMA for every period produced by expanding
// `period_range` and returns a dict with
//   "values"  — a (rows, cols) f64 array, one row per parameter combination, and
//   "periods" — the period used for each row.
// Raises `ValueError` when rows*cols overflows, when `kernel` is rejected by `validate_kernel`,
// when a combination has no usable (non-zero) period, or when the batch computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "fwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn fwma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::PyArrayMethods;

    let slice_in = data.as_slice()?;

    let sweep = FwmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("fwma_batch: rows*cols overflow"))?;
    // The array is allocated uninitialised; the NaN prefix fill plus the batch kernel below are
    // expected to write every cell before it is handed to Python — TODO confirm against
    // `fwma_batch_inner_into`.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // Warm-up length per row: index of the first non-NaN input plus `period - 1`.
    // A missing or zero period comes straight from caller input and used to underflow
    // `first + period - 1` (or panic on `unwrap`) before any validation ran; report it as a
    // ValueError instead. The addition saturates so an oversized period cannot overflow here.
    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm = combos
        .iter()
        .map(|c| match c.period {
            Some(p) if p > 0 => Ok(first.saturating_add(p - 1)),
            _ => Err(PyValueError::new_err(
                "fwma_batch: period must be greater than zero",
            )),
        })
        .collect::<PyResult<Vec<usize>>>()?;

    fill_nan_prefixes_slice(rows, cols, &warm, slice_out);

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // The row kernels take the per-series kernel variant, so map each batch variant
            // onto its per-series counterpart.
            let simd = match kernel {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
                _ => Kernel::Scalar,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                _ => unreachable!(),
            };
            fwma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2313
// WASM entry point: computes FWMA for `data` with the given `period` and returns the values
// as a new vector of the same length. A computation error is returned as a `JsValue` string
// holding the error's message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = FwmaInput::from_slice(
        data,
        FwmaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];

    match fwma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2329
// WASM entry point: runs the batch FWMA over the period range
// `(period_start, period_end, period_step)` with the scalar kernel, without the parallel
// flag, and returns only the flat `values` buffer of the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = FwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match fwma_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2346
// WASM entry point: lists the periods described by `(period_start, period_end, period_step)`
// as f64 values.
// A zero step or equal endpoints yield just `period_start`; otherwise the periods run from the
// smaller endpoint up to the larger one (inclusive) in increments of `period_step`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<f64> {
    if period_step == 0 || period_start == period_end {
        return vec![period_start as f64];
    }

    let lowest = period_start.min(period_end);
    let highest = period_start.max(period_end);

    (lowest..=highest)
        .step_by(period_step)
        .map(|p| p as f64)
        .collect()
}
2369
// WASM entry point: reserves room for `len` f64 values and hands the raw pointer to the
// caller. The backing vector is never dropped here, so the memory stays allocated until the
// caller releases it (see `fwma_free`). The contents are uninitialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the vector's destructor, leaking the allocation on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2378
// WASM entry point: releases a buffer of `len` f64 values. A null `ptr` is ignored.
// The pointer is expected to come from `fwma_alloc` called with the same `len`; anything else
// is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }

    // SAFETY: per the contract above, `ptr` owns an allocation with capacity `len`.
    // The vector is rebuilt with length 0 rather than `len`: `fwma_alloc` returns uninitialised
    // memory, and `Vec::from_raw_parts` requires the first `length` elements to be initialised.
    // f64 has no destructor, so dropping frees exactly the same allocation either way.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2388
// WASM entry point: computes FWMA for the `len` values at `in_ptr` and writes `len` results
// to `out_ptr`. `in_ptr == out_ptr` (in-place use) is supported by computing into a scratch
// buffer first.
// Errors (as `JsValue` strings): null pointers, `period == 0 || period > len`, or a failure
// reported by `fwma_into_slice`.
// Both pointers must be valid for `len` f64 values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    // Maps any stringifiable error onto a JS string value.
    fn js_error<E: ToString>(err: E) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to fwma_into"));
    }

    // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let input = FwmaInput::from_slice(
        data,
        FwmaParams {
            period: Some(period),
        },
    );

    if in_ptr == out_ptr as *const f64 {
        // In-place request: compute into scratch storage so the input stays intact while it is
        // being read, then copy the finished result over it.
        let mut scratch = vec![0.0; len];
        fwma_into_slice(&mut scratch, &input, Kernel::Auto).map_err(js_error)?;

        // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid for `len` writes.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        fwma_into_slice(out, &input, Kernel::Auto).map_err(js_error)?;
    }

    Ok(())
}
2429
/// Configuration object deserialized from the JS value passed to the unified `fwma_batch`
/// WASM entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FwmaBatchConfig {
    /// Period sweep, forwarded unchanged as the `period` tuple of `FwmaBatchRange`.
    pub period_range: (usize, usize, usize),
}
2435
/// Result object serialized back to JS by the unified `fwma_batch` WASM entry point.
/// Each field is copied from the same-named field of the batch computation's output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct FwmaBatchJsOutput {
    /// Output values of the batch run.
    pub values: Vec<f64>,
    /// Parameter combinations of the batch run.
    pub combos: Vec<FwmaParams>,
    /// Row count reported by the batch run.
    pub rows: usize,
    /// Column count reported by the batch run.
    pub cols: usize,
}
2444
// WASM entry point exported to JS as `fwma_batch`: deserializes `config` into
// `FwmaBatchConfig`, runs the batch FWMA over `data`, and returns the values, parameter
// combinations, and dimensions as a serialized `FwmaBatchJsOutput`.
// Errors are returned as `JsValue` strings: "Invalid config: …" for a bad config, the
// computation error's message, or "Serialization error: …".
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = fwma_batch)]
pub fn fwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<FwmaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = FwmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match fwma_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };

    let payload = FwmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2468
// WASM entry point: runs the batch FWMA for the period range
// `(period_start, period_end, period_step)` over the `len` values at `in_ptr` and writes
// `rows * len` results to `out_ptr`, one `len`-wide row per parameter combination.
// Returns the number of rows written.
// Errors (as `JsValue` strings): null pointers, rows*cols overflow, a combination without a
// usable (non-zero) period, or a failure reported by `fwma_batch_inner_into`.
// `in_ptr` must be valid for `len` reads and `out_ptr` for `rows * len` writes.
// `in_ptr == out_ptr` is supported; other partially overlapping buffers are not.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn fwma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to fwma_batch_into"));
    }

    let sweep = FwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("fwma_batch_into: rows*cols overflow"))?;

    unsafe {
        // When the caller passes the same buffer for input and output, snapshot the input
        // first: the NaN prefix fill and the row kernels write into `out` while the input is
        // still being read, which would otherwise corrupt it (and alias a shared slice with a
        // mutable one). This mirrors the in-place handling in `fwma_into`.
        let snapshot: Vec<f64>;
        let data: &[f64] = if in_ptr == out_ptr as *const f64 {
            // SAFETY: `in_ptr` is non-null and valid for `len` reads per the contract above.
            snapshot = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &snapshot
        } else {
            // SAFETY: as above; the buffers are distinct per the contract above.
            std::slice::from_raw_parts(in_ptr, len)
        };

        // Warm-up length per row: index of the first non-NaN input plus `period - 1`.
        // A missing or zero period comes straight from caller input and used to underflow
        // `first + period - 1` (or panic on `unwrap`) before any validation ran; report it as
        // an error instead. The addition saturates so an oversized period cannot overflow.
        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
        let mut warm = Vec::with_capacity(rows);
        for combo in &combos {
            match combo.period {
                Some(p) if p > 0 => warm.push(first.saturating_add(p - 1)),
                _ => {
                    return Err(JsValue::from_str(
                        "fwma_batch_into: period must be greater than zero",
                    ))
                }
            }
        }

        // SAFETY: `out_ptr` is non-null and valid for `total` writes per the contract above.
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        fill_nan_prefixes_slice(rows, cols, &warm, out);

        fwma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}