1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::alma_wrapper::DeviceArrayF32;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaFwma;
7use crate::utilities::aligned_vector::AlignedVec;
8use crate::utilities::data_loader::{source_type, Candles};
9use crate::utilities::enums::Kernel;
10use crate::utilities::helpers::{
11 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
12 make_uninit_matrix,
13};
14#[cfg(feature = "python")]
15use crate::utilities::kernel_validation::validate_kernel;
16#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
17use core::arch::x86_64::*;
18#[cfg(feature = "python")]
19use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
20#[cfg(feature = "python")]
21use pyo3::exceptions::PyValueError;
22#[cfg(feature = "python")]
23use pyo3::prelude::*;
24#[cfg(feature = "python")]
25use pyo3::types::PyDict;
26
27#[cfg(all(feature = "python", feature = "cuda"))]
28use cust::context::Context;
29#[cfg(all(feature = "python", feature = "cuda"))]
30use cust::memory::DeviceBuffer;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
34use serde::{Deserialize, Serialize};
35use std::convert::AsRef;
36use std::error::Error;
37use std::mem::{ManuallyDrop, MaybeUninit};
38#[cfg(all(feature = "python", feature = "cuda"))]
39use std::sync::Arc;
40use thiserror::Error;
41#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
42use wasm_bindgen::prelude::*;
43
// Python-visible handle for a 2-D `f32` array that lives in CUDA device memory.
// Plain `//` comments are used on purpose so that nothing here becomes a Python docstring.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "FwmaDeviceArrayF32", unsendable)]
pub struct DeviceArrayF32Py {
    // Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // Shared CUDA context handle; presumably held so the context outlives `inner` — TODO confirm.
    pub(crate) _ctx: Arc<Context>,
    // Device ordinal; this is the value reported by `__dlpack_device__`.
    pub(crate) device_id: u32,
    // NOTE(review): looks like a raw stream handle stored as an integer; it is not read
    // anywhere in the visible part of this file — confirm its meaning with the producer.
    pub(crate) stream: usize,
}
52
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` (version 3) dictionary for a row-major
    // `rows x cols` little-endian f32 array. Strides are expressed in bytes.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = core::mem::size_of::<f32>();
        let row_bytes = self.inner.cols * elem_bytes;

        let iface = PyDict::new(py);
        iface.set_item("shape", (self.inner.rows, self.inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (row_bytes, elem_bytes))?;
        // Second tuple element is the read-only flag.
        iface.set_item("data", (self.inner.device_ptr() as usize, false))?;

        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)`; the device type code is the constant 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer as a DLPack capsule.
    //
    // The device buffer is moved out of `self` and handed to the exporter; afterwards
    // `self.inner` holds an empty 0 x 0 placeholder. `stream` is accepted but ignored.
    // A `dl_device` that can be read as `(i32, i32)` and differs from this array's device
    // is rejected with `ValueError`; an unreadable `dl_device` is ignored.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (own_type, own_id) = self.__dlpack_device__();
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((req_type, req_id)) = requested {
            if req_type != own_type || req_id != own_id {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Swap an empty buffer into `self` so the real one can be moved into the capsule.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, own_id, max_version_bound)
    }
}
129
130impl<'a> AsRef<[f64]> for FwmaInput<'a> {
131 #[inline(always)]
132 fn as_ref(&self) -> &[f64] {
133 match &self.data {
134 FwmaData::Slice(slice) => slice,
135 FwmaData::Candles { candles, source } => source_type(candles, source),
136 }
137 }
138}
139
140#[derive(Debug, Clone)]
141pub enum FwmaData<'a> {
142 Candles {
143 candles: &'a Candles,
144 source: &'a str,
145 },
146 Slice(&'a [f64]),
147}
148
/// Result of a single-series FWMA computation.
#[derive(Debug, Clone)]
pub struct FwmaOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
153
/// Tunable parameters of the FWMA indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct FwmaParams {
    /// Window length; `None` is treated as 5 by the code in this file.
    pub period: Option<usize>,
}
162
163impl Default for FwmaParams {
164 fn default() -> Self {
165 Self { period: Some(5) }
166 }
167}
168
169#[derive(Debug, Clone)]
170pub struct FwmaInput<'a> {
171 pub data: FwmaData<'a>,
172 pub params: FwmaParams,
173}
174
175impl<'a> FwmaInput<'a> {
176 #[inline]
177 pub fn from_candles(c: &'a Candles, s: &'a str, p: FwmaParams) -> Self {
178 Self {
179 data: FwmaData::Candles {
180 candles: c,
181 source: s,
182 },
183 params: p,
184 }
185 }
186 #[inline]
187 pub fn from_slice(sl: &'a [f64], p: FwmaParams) -> Self {
188 Self {
189 data: FwmaData::Slice(sl),
190 params: p,
191 }
192 }
193 #[inline]
194 pub fn with_default_candles(c: &'a Candles) -> Self {
195 Self::from_candles(c, "close", FwmaParams::default())
196 }
197 #[inline]
198 pub fn get_period(&self) -> usize {
199 self.params.period.unwrap_or(5)
200 }
201}
202
203#[derive(Copy, Clone, Debug)]
204pub struct FwmaBuilder {
205 period: Option<usize>,
206 kernel: Kernel,
207}
208
209impl Default for FwmaBuilder {
210 fn default() -> Self {
211 Self {
212 period: None,
213 kernel: Kernel::Auto,
214 }
215 }
216}
217
218impl FwmaBuilder {
219 #[inline(always)]
220 pub fn new() -> Self {
221 Self::default()
222 }
223 #[inline(always)]
224 pub fn period(mut self, n: usize) -> Self {
225 self.period = Some(n);
226 self
227 }
228 #[inline(always)]
229 pub fn kernel(mut self, k: Kernel) -> Self {
230 self.kernel = k;
231 self
232 }
233
234 #[inline(always)]
235 pub fn apply(self, c: &Candles) -> Result<FwmaOutput, FwmaError> {
236 let p = FwmaParams {
237 period: self.period,
238 };
239 let i = FwmaInput::from_candles(c, "close", p);
240 fwma_with_kernel(&i, self.kernel)
241 }
242
243 #[inline(always)]
244 pub fn apply_slice(self, d: &[f64]) -> Result<FwmaOutput, FwmaError> {
245 let p = FwmaParams {
246 period: self.period,
247 };
248 let i = FwmaInput::from_slice(d, p);
249 fwma_with_kernel(&i, self.kernel)
250 }
251
252 #[inline(always)]
253 pub fn into_stream(self) -> Result<FwmaStream, FwmaError> {
254 let p = FwmaParams {
255 period: self.period,
256 };
257 FwmaStream::try_new(p)
258 }
259}
260
261#[derive(Debug, Error)]
262pub enum FwmaError {
263 #[error("fwma: Input data slice is empty.")]
264 EmptyInputData,
265 #[error("fwma: All values are NaN.")]
266 AllValuesNaN,
267 #[error("fwma: Invalid period: period = {period}, data length = {data_len}")]
268 InvalidPeriod { period: usize, data_len: usize },
269 #[error("fwma: Not enough valid data: needed = {needed}, valid = {valid}")]
270 NotEnoughValidData { needed: usize, valid: usize },
271 #[error("fwma: Fibonacci sum is zero. Cannot normalize weights.")]
272 ZeroFibonacciSum,
273 #[error("fwma: Output buffer length mismatch: expected = {expected}, got = {got}")]
274 OutputLengthMismatch { expected: usize, got: usize },
275 #[error("fwma: Invalid range: start={start}, end={end}, step={step}")]
276 InvalidRange {
277 start: usize,
278 end: usize,
279 step: usize,
280 },
281 #[error("fwma: Invalid kernel for batch API: {0:?}")]
282 InvalidKernelForBatch(Kernel),
283 #[error("fwma: arithmetic overflow while computing {context}")]
284 ArithmeticOverflow { context: &'static str },
285}
286
287#[inline]
288pub fn fwma(input: &FwmaInput) -> Result<FwmaOutput, FwmaError> {
289 fwma_with_kernel(input, Kernel::Auto)
290}
291
292#[inline(always)]
293fn fwma_prepare<'a>(
294 input: &'a FwmaInput,
295 kernel: Kernel,
296) -> Result<(&'a [f64], Vec<f64>, usize, usize, Kernel), FwmaError> {
297 let data: &[f64] = input.as_ref();
298 let len = data.len();
299 if len == 0 {
300 return Err(FwmaError::EmptyInputData);
301 }
302 let first = data
303 .iter()
304 .position(|x| !x.is_nan())
305 .ok_or(FwmaError::AllValuesNaN)?;
306 let period = input.get_period();
307
308 if period == 0 || period > len {
309 return Err(FwmaError::InvalidPeriod {
310 period,
311 data_len: len,
312 });
313 }
314 if (len - first) < period {
315 return Err(FwmaError::NotEnoughValidData {
316 needed: period,
317 valid: len - first,
318 });
319 }
320
321 let mut fib = vec![1.0; period];
322 for i in 2..period {
323 fib[i] = fib[i - 1] + fib[i - 2];
324 }
325 let fib_sum: f64 = fib.iter().sum();
326 if fib_sum == 0.0 {
327 return Err(FwmaError::ZeroFibonacciSum);
328 }
329 for w in &mut fib {
330 *w /= fib_sum;
331 }
332
333 let chosen = match kernel {
334 Kernel::Auto => detect_best_kernel(),
335 other => other,
336 };
337
338 Ok((data, fib, period, first, chosen))
339}
340
341#[inline(always)]
342fn fwma_compute_into(
343 data: &[f64],
344 fib: &[f64],
345 period: usize,
346 first: usize,
347 kernel: Kernel,
348 out: &mut [f64],
349) {
350 unsafe {
351 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
352 {
353 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
354 fwma_simd128(data, fib, period, first, out);
355 return;
356 }
357 }
358
359 match kernel {
360 Kernel::Scalar | Kernel::ScalarBatch => fwma_scalar(data, fib, period, first, out),
361 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
362 Kernel::Avx2 | Kernel::Avx2Batch => fwma_avx2(data, fib, period, first, out),
363 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
364 Kernel::Avx512 | Kernel::Avx512Batch => fwma_avx512(data, fib, period, first, out),
365 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
366 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
367 fwma_scalar(data, fib, period, first, out)
368 }
369 _ => unreachable!(),
370 }
371 }
372}
373
374pub fn fwma_with_kernel(input: &FwmaInput, kernel: Kernel) -> Result<FwmaOutput, FwmaError> {
375 let (data, fib, period, first, chosen) = fwma_prepare(input, kernel)?;
376
377 let warm = first + period - 1;
378 let mut out = alloc_with_nan_prefix(data.len(), warm);
379
380 fwma_compute_into(data, &fib, period, first, chosen, &mut out);
381
382 Ok(FwmaOutput { values: out })
383}
384
385#[inline]
386pub fn fwma_into_slice(dst: &mut [f64], input: &FwmaInput, kern: Kernel) -> Result<(), FwmaError> {
387 let (data, fib, period, first, chosen) = fwma_prepare(input, kern)?;
388
389 if dst.len() != data.len() {
390 return Err(FwmaError::OutputLengthMismatch {
391 expected: data.len(),
392 got: dst.len(),
393 });
394 }
395
396 let warmup_end = (first + period - 1).min(dst.len());
397 for v in &mut dst[..warmup_end] {
398 *v = f64::from_bits(0x7ff8_0000_0000_0000);
399 }
400
401 fwma_compute_into(data, &fib, period, first, chosen, dst);
402
403 Ok(())
404}
405
406#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
407#[inline]
408pub fn fwma_into(input: &FwmaInput, out: &mut [f64]) -> Result<(), FwmaError> {
409 let (data, fib, period, first, chosen) = fwma_prepare(input, Kernel::Auto)?;
410
411 if out.len() != data.len() {
412 return Err(FwmaError::OutputLengthMismatch {
413 expected: data.len(),
414 got: out.len(),
415 });
416 }
417
418 let warm = (first + period - 1).min(out.len());
419 for v in &mut out[..warm] {
420 *v = f64::from_bits(0x7ff8_0000_0000_0000);
421 }
422
423 fwma_compute_into(data, &fib, period, first, chosen, out);
424 Ok(())
425}
426
/// Scalar FWMA kernel.
///
/// For every index `i >= first_val + period - 1`, writes to `out[i]` the dot product of
/// `fib` with the `period` elements of `data` ending at `i` (`fib[0]` pairs with the oldest
/// element). Indices before that are left untouched.
///
/// # Safety
/// The body performs no unchecked operations; the function is `unsafe` only to share a
/// signature with the SIMD kernels.
///
/// # Panics
/// Panics if `fib.len() != period` or `out.len() < data.len()`.
#[inline(always)]
pub unsafe fn fwma_scalar(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    // Largest multiple of four not exceeding `period`; that part is processed 4 at a time.
    let head_len = period & !3;
    let (w_head, w_tail) = fib.split_at(head_len);

    let first_out = first_val + period - 1;
    for end in first_out..data.len() {
        let lo = end + 1 - period;
        let (x_head, x_tail) = data[lo..lo + period].split_at(head_len);

        let mut acc = 0.0;
        for (x, w) in x_head.chunks_exact(4).zip(w_head.chunks_exact(4)) {
            acc += x[0] * w[0] + x[1] * w[1] + x[2] * w[2] + x[3] * w[3];
        }
        for (x, w) in x_tail.iter().zip(w_tail) {
            acc += x * w;
        }
        out[end] = acc;
    }
}
455
/// WebAssembly SIMD128 FWMA kernel: same contract as the scalar kernel, two lanes at a time.
///
/// # Safety
/// Must only run where the `simd128` target feature is available. Loads are performed
/// through raw pointers taken from bounds-checked element references.
///
/// # Panics
/// Panics if `fib.len() != period` or `out.len() < data.len()`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn fwma_simd128(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    const LANES: usize = 2;
    let pairs = period / LANES;
    let has_odd_element = period % LANES > 0;
    let odd_at = pairs * LANES;

    for end in (first_val + period - 1)..data.len() {
        let lo = end + 1 - period;
        let window = &data[lo..lo + period];

        let mut acc = f64x2_splat(0.0);
        for k in 0..pairs {
            let at = k * LANES;
            let x = v128_load(&window[at] as *const f64 as *const v128);
            let w = v128_load(&fib[at] as *const f64 as *const v128);
            acc = f64x2_add(acc, f64x2_mul(x, w));
        }

        // Fold the two lanes, then add the odd trailing element if there is one.
        let mut total = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        if has_odd_element {
            total += window[odd_at] * fib[odd_at];
        }

        out[end] = total;
    }
}
499
/// Returns the sum of the four `f64` lanes of `v`.
///
/// The result is read directly from lane 0 after the pairwise adds. The previous version
/// added the total to itself and multiplied by 0.5, which returned infinity whenever twice
/// the sum exceeded the `f64` range; for all other inputs the result is bit-identical.
///
/// # Safety
/// The CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256d) -> f64 {
    // [a, b, c, d] -> [a+b, a+b, c+d, c+d]
    let pairs = _mm256_hadd_pd(v, v);
    let upper = _mm256_extractf128_pd(pairs, 1);
    let lower = _mm256_castpd256_pd128(pairs);
    // Lane 0 now holds (c+d) + (a+b).
    _mm_cvtsd_f64(_mm_add_pd(upper, lower))
}
510
/// AVX-512 FWMA kernel used for short windows (the dispatcher sends `period <= 32` here).
///
/// Same contract as the scalar kernel: `out[i]` receives the weighted sum of the `period`
/// elements of `data` ending at `i`, for every `i >= first + period - 1`.
///
/// # Safety
/// The CPU must support AVX-512F and FMA.
///
/// # Panics
/// Panics if `fib.len() != period` or `out.len() < data.len()`. The output check matters
/// because results are stored through a raw pointer.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn fwma_avx512_short(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    const SIMD_WIDTH: usize = 8;
    let simd_chunks = period / SIMD_WIDTH;
    let remainder = period % SIMD_WIDTH;

    let data_ptr = data.as_ptr();
    let out_ptr = out.as_mut_ptr();

    // Stage the weights in an `AlignedVec` so the aligned loads below are valid
    // (assumes the buffer is at least 64-byte aligned — TODO confirm against `AlignedVec`).
    let mut aligned_fib = AlignedVec::with_capacity(period + SIMD_WIDTH);
    let fib_buf = aligned_fib.as_mut_slice();
    fib_buf[..period].copy_from_slice(fib);
    let fib_ptr = fib_buf.as_ptr();

    let mut fib_vecs = Vec::with_capacity(simd_chunks);
    for chunk in 0..simd_chunks {
        fib_vecs.push(_mm512_load_pd(fib_ptr.add(chunk * SIMD_WIDTH)));
    }

    // Low `remainder` bits set: selects the lanes of the final partial vector.
    let tail_mask: __mmask8 = (1u8 << remainder).wrapping_sub(1);

    for idx in (first + period - 1)..data.len() {
        let start = idx + 1 - period;
        let window_ptr = data_ptr.add(start);

        // The prefetch target may lie past the end of `data`; `wrapping_add` keeps the
        // pointer arithmetic defined, and a prefetch never faults.
        _mm_prefetch(window_ptr.wrapping_add(64) as *const i8, _MM_HINT_T0);

        let mut sum0 = _mm512_setzero_pd();
        let mut sum1 = _mm512_setzero_pd();
        let mut sum2 = _mm512_setzero_pd();
        let mut sum3 = _mm512_setzero_pd();

        // Four independent accumulators over groups of four full vectors.
        let chunks4 = simd_chunks / 4;
        for i in 0..chunks4 {
            let base = window_ptr.add(i * 32);
            sum0 = _mm512_fmadd_pd(_mm512_loadu_pd(base), fib_vecs[i * 4 + 0], sum0);
            sum1 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(8)), fib_vecs[i * 4 + 1], sum1);
            sum2 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(16)), fib_vecs[i * 4 + 2], sum2);
            sum3 = _mm512_fmadd_pd(_mm512_loadu_pd(base.add(24)), fib_vecs[i * 4 + 3], sum3);
        }

        // Remaining full vectors.
        for i in (chunks4 * 4)..simd_chunks {
            let base = window_ptr.add(i * SIMD_WIDTH);
            sum0 = _mm512_fmadd_pd(_mm512_loadu_pd(base), fib_vecs[i], sum0);
        }

        sum0 = _mm512_add_pd(sum0, sum1);
        sum2 = _mm512_add_pd(sum2, sum3);
        sum0 = _mm512_add_pd(sum0, sum2);

        // Masked loads read only the `remainder` valid lanes.
        if remainder != 0 {
            let data_tail =
                _mm512_maskz_loadu_pd(tail_mask, window_ptr.add(simd_chunks * SIMD_WIDTH));
            let weight_tail =
                _mm512_maskz_load_pd(tail_mask, fib_ptr.add(simd_chunks * SIMD_WIDTH));
            sum0 = _mm512_fmadd_pd(data_tail, weight_tail, sum0);
        }

        let total = _mm512_reduce_add_pd(sum0);
        // In bounds: `idx < data.len() <= out.len()` (asserted above).
        *out_ptr.add(idx) = total;
    }
}
580
/// AVX-512 FWMA kernel used for long windows (the dispatcher sends `period > 32` here).
///
/// Computes four consecutive outputs per iteration, sharing each weight vector across the
/// four windows, and finishes the last few outputs one at a time. Output indices before
/// `first + period - 1` are not written.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. Data is read through raw pointers; every read
/// stays within `data` because each window ends at an index below `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn fwma_avx512_long(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first: usize,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    const BLOCK: usize = 4;

    let full_vecs = period / LANES;
    let rem = period % LANES;
    let rem_at = full_vecs * LANES;

    // Stage the weights in an aligned buffer for the aligned loads below.
    let mut staged = AlignedVec::with_capacity(period + LANES);
    let staged_w = staged.as_mut_slice();
    staged_w[..period].copy_from_slice(fib);
    let w_ptr = staged_w.as_ptr();

    let mut w_full = Vec::with_capacity(full_vecs);
    for v in 0..full_vecs {
        w_full.push(_mm512_load_pd(w_ptr.add(v * LANES)));
    }

    // Low `rem` bits set: selects the lanes of the final partial vector.
    let rem_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);
    let w_rem = if rem > 0 {
        Some(_mm512_maskz_load_pd(rem_mask, w_ptr.add(rem_at)))
    } else {
        None
    };

    let n = data.len();
    let src = data.as_ptr();
    // Last start index (exclusive) at which a full block of four outputs still fits.
    let blocked_end = n.saturating_sub(BLOCK - 1);
    let mut pos = first + period - 1;

    while pos < blocked_end {
        let p0 = src.add(pos + 1 - period);
        let p1 = p0.add(1);
        let p2 = p0.add(2);
        let p3 = p0.add(3);

        let mut a0 = _mm512_setzero_pd();
        let mut a1 = _mm512_setzero_pd();
        let mut a2 = _mm512_setzero_pd();
        let mut a3 = _mm512_setzero_pd();

        for (k, &w) in w_full.iter().enumerate() {
            let off = k * LANES;
            a0 = _mm512_fmadd_pd(_mm512_loadu_pd(p0.add(off)), w, a0);
            a1 = _mm512_fmadd_pd(_mm512_loadu_pd(p1.add(off)), w, a1);
            a2 = _mm512_fmadd_pd(_mm512_loadu_pd(p2.add(off)), w, a2);
            a3 = _mm512_fmadd_pd(_mm512_loadu_pd(p3.add(off)), w, a3);
        }

        if let Some(wt) = w_rem {
            a0 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(rem_mask, p0.add(rem_at)), wt, a0);
            a1 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(rem_mask, p1.add(rem_at)), wt, a1);
            a2 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(rem_mask, p2.add(rem_at)), wt, a2);
            a3 = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(rem_mask, p3.add(rem_at)), wt, a3);
        }

        out[pos] = _mm512_reduce_add_pd(a0);
        out[pos + 1] = _mm512_reduce_add_pd(a1);
        out[pos + 2] = _mm512_reduce_add_pd(a2);
        out[pos + 3] = _mm512_reduce_add_pd(a3);

        pos += BLOCK;
    }

    // Remaining outputs, one window at a time.
    for idx in pos..n {
        let p = src.add(idx + 1 - period);
        let mut acc = _mm512_setzero_pd();

        for (k, &w) in w_full.iter().enumerate() {
            acc = _mm512_fmadd_pd(_mm512_loadu_pd(p.add(k * LANES)), w, acc);
        }

        if let Some(wt) = w_rem {
            acc = _mm512_fmadd_pd(_mm512_maskz_loadu_pd(rem_mask, p.add(rem_at)), wt, acc);
        }

        out[idx] = _mm512_reduce_add_pd(acc);
    }
}
681
/// AVX-512 FWMA entry point: picks the short-window kernel for `period <= 32` and the
/// long-window kernel otherwise.
///
/// # Safety
/// The CPU must support AVX-512F and FMA; the selected kernel's requirements apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn fwma_avx512(data: &[f64], fib: &[f64], period: usize, first: usize, out: &mut [f64]) {
    match period {
        0..=32 => fwma_avx512_short(data, fib, period, first, out),
        _ => fwma_avx512_long(data, fib, period, first, out),
    }
}
691
/// AVX2 + FMA FWMA kernel.
///
/// Same contract as the scalar kernel: `out[i]` receives the weighted sum of the `period`
/// elements of `data` ending at `i`, for every `i >= first_valid + period - 1`.
///
/// The `fma` target feature is enabled explicitly because the body uses
/// `_mm256_fmadd_pd`; with only `avx2` enabled the intrinsic's feature requirement was
/// not declared on this function.
///
/// # Safety
/// The CPU must support AVX2 and FMA.
///
/// # Panics
/// Panics if `fib.len() != period` or `out.len() < data.len()`. The output check matters
/// because results are stored through a raw pointer.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn fwma_avx2(
    data: &[f64],
    fib: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert_eq!(fib.len(), period, "fib.len() must equal period");
    assert!(
        out.len() >= data.len(),
        "out must be at least as long as data"
    );

    const W: usize = 4;
    let tail = period % W;

    // Stage the weights in an aligned buffer for the aligned loads below.
    let mut aligned = AlignedVec::with_capacity(period + W);
    let fib_aln = aligned.as_mut_slice();
    fib_aln[..period].copy_from_slice(fib);
    let wptr = fib_aln.as_ptr();

    let dptr = data.as_ptr();
    let optr = out.as_mut_ptr();

    for i in (first_valid + period - 1)..data.len() {
        let base = dptr.add(i + 1 - period);

        let mut acc0 = _mm256_setzero_pd();
        let mut acc1 = _mm256_setzero_pd();

        // Two vectors (8 elements) per iteration with independent accumulators.
        let mut j = 0;
        while j + 8 <= period {
            let v0 = _mm256_loadu_pd(base.add(j));
            let w0 = _mm256_load_pd(wptr.add(j));
            acc0 = _mm256_fmadd_pd(v0, w0, acc0);

            let v1 = _mm256_loadu_pd(base.add(j + 4));
            let w1 = _mm256_load_pd(wptr.add(j + 4));
            acc1 = _mm256_fmadd_pd(v1, w1, acc1);

            j += 8;
        }
        // At most one more full vector remains.
        if j + 4 <= period {
            let v = _mm256_loadu_pd(base.add(j));
            let w = _mm256_load_pd(wptr.add(j));
            acc0 = _mm256_fmadd_pd(v, w, acc0);
        }

        let mut sum = horizontal_sum_avx2(_mm256_add_pd(acc0, acc1));

        // Scalar tail: the last `period % 4` elements.
        for idx in (period - tail)..period {
            sum += *base.add(idx) * *wptr.add(idx);
        }
        // In bounds: `i < data.len() <= out.len()` (asserted above).
        *optr.add(i) = sum;
    }
}
748
/// Streaming FWMA calculator fed one value at a time through `update`.
#[derive(Debug, Clone)]
pub struct FwmaStream {
    /// Window length.
    period: usize,

    /// Normalized Fibonacci weights; `w[0]` applies to the oldest value in the window.
    w: Vec<f64>,
    /// Cached `w[0]`.
    w0: f64,
    /// Cached `w[period - 1]`.
    w_last: f64,
    /// Cached `w[period - 2]`, or 0.0 when `period == 1`.
    w_prev: f64,
    /// `w_last + w_prev`.
    w_next: f64,

    /// Ring buffer of the most recent `period` inputs (NaN until written).
    buffer: Vec<f64>,
    /// Next write position in `buffer`; once filled it also marks the oldest value.
    head: usize,
    /// Whether `period` values have been received.
    filled: bool,

    /// Running accumulators maintained by `update`. The emitted value does not read them:
    /// it is either the direct weighted sum or a fresh dot product over the ring.
    acc_n: f64,
    acc_d: f64,

    /// Number of NaN inputs currently counted inside the window.
    nan_count: usize,
}
768
769impl FwmaStream {
770 pub fn try_new(params: FwmaParams) -> Result<Self, FwmaError> {
771 let period = params.period.unwrap_or(5);
772 if period == 0 {
773 return Err(FwmaError::InvalidPeriod {
774 period,
775 data_len: 0,
776 });
777 }
778
779 let mut w = vec![1.0; period];
780 for i in 2..period {
781 w[i] = w[i - 1] + w[i - 2];
782 }
783
784 let sum: f64 = w.iter().sum();
785 if sum == 0.0 {
786 return Err(FwmaError::ZeroFibonacciSum);
787 }
788 for wi in &mut w {
789 *wi /= sum;
790 }
791
792 let w0 = w[0];
793 let w_last = w[period - 1];
794 let w_prev = if period > 1 { w[period - 2] } else { 0.0 };
795 let w_next = w_last + w_prev;
796
797 Ok(Self {
798 period,
799 w,
800 w0,
801 w_last,
802 w_prev,
803 w_next,
804 buffer: vec![f64::NAN; period],
805 head: 0,
806 filled: false,
807 acc_n: 0.0,
808 acc_d: 0.0,
809 nan_count: 0,
810 })
811 }
812
813 #[inline(always)]
814 pub fn update(&mut self, value: f64) -> Option<f64> {
815 let old_raw = self.buffer[self.head];
816
817 if self.filled && old_raw.is_nan() {
818 self.nan_count = self.nan_count.saturating_sub(1);
819 }
820
821 self.buffer[self.head] = value;
822 if value.is_nan() {
823 self.nan_count += 1;
824 }
825
826 self.head += 1;
827 if self.head == self.period {
828 self.head = 0;
829 }
830
831 if !self.filled {
832 if self.head != 0 {
833 return None;
834 }
835
836 self.filled = true;
837
838 let mut n = 0.0f64;
839 let mut q = 0.0f64;
840 let last = self.period - 1;
841 for j in 0..self.period {
842 let x = self.buffer[j];
843 let xv = if x.is_nan() { 0.0 } else { x };
844
845 n += xv * self.w[j];
846
847 let wn = if j < last { self.w[j + 1] } else { self.w_next };
848 q += xv * wn;
849 }
850 self.acc_n = n;
851 self.acc_d = q - n;
852
853 return Some(if self.nan_count == 0 { n } else { f64::NAN });
854 }
855
856 let x_old = if old_raw.is_nan() { 0.0 } else { old_raw };
857 let x_new = if value.is_nan() { 0.0 } else { value };
858
859 let prev_n = self.acc_n;
860
861 let d_old = self.acc_d;
862 let n_prime = x_new.mul_add(self.w_last, d_old);
863
864 let d_prime = x_new.mul_add(self.w_prev, prev_n - d_old - self.w0 * x_old);
865
866 self.acc_n = n_prime;
867 self.acc_d = d_prime;
868
869 Some(if self.nan_count == 0 {
870 self.dot_ring()
871 } else {
872 f64::NAN
873 })
874 }
875}
876
877impl FwmaStream {
878 #[inline(always)]
879 fn dot_ring(&self) -> f64 {
880 let mut sum = 0.0;
881 let mut idx = self.head;
882 for &wj in &self.w {
883 sum += wj * self.buffer[idx];
884 idx += 1;
885 if idx == self.period {
886 idx = 0;
887 }
888 }
889 sum
890 }
891}
892
/// Parameter sweep description for the batch API.
#[derive(Clone, Debug)]
pub struct FwmaBatchRange {
    /// `(start, end, step)` for the period axis; a step of 0 selects `start` only.
    pub period: (usize, usize, usize),
}
897impl Default for FwmaBatchRange {
898 fn default() -> Self {
899 Self {
900 period: (5, 254, 1),
901 }
902 }
903}
904
905#[derive(Clone, Debug, Default)]
906pub struct FwmaBatchBuilder {
907 range: FwmaBatchRange,
908 kernel: Kernel,
909}
910
911impl FwmaBatchBuilder {
912 pub fn new() -> Self {
913 Self::default()
914 }
915 pub fn kernel(mut self, k: Kernel) -> Self {
916 self.kernel = k;
917 self
918 }
919
920 #[inline]
921 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
922 self.range.period = (start, end, step);
923 self
924 }
925 #[inline]
926 pub fn period_static(mut self, p: usize) -> Self {
927 self.range.period = (p, p, 0);
928 self
929 }
930 pub fn apply_slice(self, data: &[f64]) -> Result<FwmaBatchOutput, FwmaError> {
931 fwma_batch_with_kernel(data, &self.range, self.kernel)
932 }
933 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<FwmaBatchOutput, FwmaError> {
934 FwmaBatchBuilder::new().kernel(k).apply_slice(data)
935 }
936 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<FwmaBatchOutput, FwmaError> {
937 let slice = source_type(c, src);
938 self.apply_slice(slice)
939 }
940 pub fn with_default_candles(c: &Candles) -> Result<FwmaBatchOutput, FwmaError> {
941 FwmaBatchBuilder::new()
942 .kernel(Kernel::Auto)
943 .apply_candles(c, "close")
944 }
945}
946
947pub fn fwma_batch_with_kernel(
948 data: &[f64],
949 sweep: &FwmaBatchRange,
950 k: Kernel,
951) -> Result<FwmaBatchOutput, FwmaError> {
952 let kernel = match k {
953 Kernel::Auto => detect_best_batch_kernel(),
954 other if other.is_batch() => other,
955 other => {
956 return Err(FwmaError::InvalidKernelForBatch(other));
957 }
958 };
959 let simd = match kernel {
960 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
961 Kernel::Avx512Batch => Kernel::Avx512,
962 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
963 Kernel::Avx2Batch => Kernel::Avx2,
964 Kernel::ScalarBatch => Kernel::Scalar,
965 _ => unreachable!(),
966 };
967 fwma_batch_par_slice(data, sweep, simd)
968}
969
970#[derive(Clone, Debug)]
971pub struct FwmaBatchOutput {
972 pub values: Vec<f64>,
973 pub combos: Vec<FwmaParams>,
974 pub rows: usize,
975 pub cols: usize,
976}
977impl FwmaBatchOutput {
978 pub fn row_for_params(&self, p: &FwmaParams) -> Option<usize> {
979 self.combos
980 .iter()
981 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
982 }
983 pub fn values_for(&self, p: &FwmaParams) -> Option<&[f64]> {
984 self.row_for_params(p).map(|row| {
985 let start = row * self.cols;
986 &self.values[start..start + self.cols]
987 })
988 }
989}
990
991#[inline(always)]
992fn expand_grid(r: &FwmaBatchRange) -> Vec<FwmaParams> {
993 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
994 if step == 0 || start == end {
995 return vec![start];
996 }
997 let (lo, hi) = if start <= end {
998 (start, end)
999 } else {
1000 (end, start)
1001 };
1002 (lo..=hi).step_by(step).collect()
1003 }
1004 let periods = axis_usize(r.period);
1005 let mut out = Vec::with_capacity(periods.len());
1006 for &p in &periods {
1007 out.push(FwmaParams { period: Some(p) });
1008 }
1009 out
1010}
1011
/// Writes NaN into the warm-up prefix of each row of a row-major matrix.
///
/// Row `r` starts at `r * cols`; its first `min(warmup_periods[r], cols)` cells become NaN.
/// One row is processed per entry of `warmup_periods`; `rows` is not consulted.
///
/// # Panics
/// Panics if a prefix extends past the end of `out_slice`.
#[inline(always)]
fn fill_nan_prefixes_slice(
    rows: usize,
    cols: usize,
    warmup_periods: &[usize],
    out_slice: &mut [f64],
) {
    warmup_periods
        .iter()
        .enumerate()
        .for_each(|(row, &warmup)| {
            let base = row * cols;
            let span = warmup.min(cols);
            (base..base + span).for_each(|idx| out_slice[idx] = f64::NAN);
        });
}
1027
1028#[inline(always)]
1029pub fn fwma_batch_slice(
1030 data: &[f64],
1031 sweep: &FwmaBatchRange,
1032 kern: Kernel,
1033) -> Result<FwmaBatchOutput, FwmaError> {
1034 fwma_batch_inner(data, sweep, kern, false)
1035}
1036
1037#[inline(always)]
1038pub fn fwma_batch_par_slice(
1039 data: &[f64],
1040 sweep: &FwmaBatchRange,
1041 kern: Kernel,
1042) -> Result<FwmaBatchOutput, FwmaError> {
1043 fwma_batch_inner(data, sweep, kern, true)
1044}
1045
1046#[inline(always)]
1047fn fwma_batch_inner(
1048 data: &[f64],
1049 sweep: &FwmaBatchRange,
1050 kern: Kernel,
1051 parallel: bool,
1052) -> Result<FwmaBatchOutput, FwmaError> {
1053 let combos = expand_grid(sweep);
1054 if combos.is_empty() {
1055 let (s, e, t) = sweep.period;
1056 return Err(FwmaError::InvalidRange {
1057 start: s,
1058 end: e,
1059 step: t,
1060 });
1061 }
1062
1063 let rows = combos.len();
1064 let cols = data.len();
1065
1066 let _total = rows
1067 .checked_mul(cols)
1068 .ok_or(FwmaError::ArithmeticOverflow {
1069 context: "rows*cols in fwma_batch_inner",
1070 })?;
1071
1072 let mut buf_mu = make_uninit_matrix(rows, cols);
1073
1074 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1075 let warm: Vec<usize> = combos
1076 .iter()
1077 .map(|c| first + c.period.unwrap() - 1)
1078 .collect();
1079
1080 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1081
1082 let mut buf_guard = ManuallyDrop::new(buf_mu);
1083 let values_slice: &mut [f64] = unsafe {
1084 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1085 };
1086
1087 fwma_batch_inner_into(data, sweep, kern, parallel, values_slice)?;
1088
1089 let values = unsafe {
1090 Vec::from_raw_parts(
1091 buf_guard.as_mut_ptr() as *mut f64,
1092 buf_guard.len(),
1093 buf_guard.capacity(),
1094 )
1095 };
1096
1097 Ok(FwmaBatchOutput {
1098 values,
1099 combos,
1100 rows,
1101 cols,
1102 })
1103}
1104#[inline(always)]
1105fn fwma_batch_inner_into(
1106 data: &[f64],
1107 sweep: &FwmaBatchRange,
1108 kern: Kernel,
1109 parallel: bool,
1110 out: &mut [f64],
1111) -> Result<Vec<FwmaParams>, FwmaError> {
1112 let combos = expand_grid(sweep);
1113 if combos.is_empty() {
1114 let (s, e, t) = sweep.period;
1115 return Err(FwmaError::InvalidRange {
1116 start: s,
1117 end: e,
1118 step: t,
1119 });
1120 }
1121
1122 let first = data
1123 .iter()
1124 .position(|x| !x.is_nan())
1125 .ok_or(FwmaError::AllValuesNaN)?;
1126 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1127 if data.len() - first < max_p {
1128 return Err(FwmaError::NotEnoughValidData {
1129 needed: max_p,
1130 valid: data.len() - first,
1131 });
1132 }
1133
1134 let rows = combos.len();
1135 let cols = data.len();
1136
1137 let cap = rows
1138 .checked_mul(max_p)
1139 .ok_or(FwmaError::ArithmeticOverflow {
1140 context: "rows*max_p in fwma_batch_inner_into",
1141 })?;
1142 let mut aligned = AlignedVec::with_capacity(cap);
1143 let flat_fib = aligned.as_mut_slice();
1144
1145 for (row, prm) in combos.iter().enumerate() {
1146 let period = prm.period.unwrap();
1147 let base = row * max_p;
1148 let slice = &mut flat_fib[base..base + period];
1149 slice[0] = 1.0;
1150 if period > 1 {
1151 slice[1] = 1.0;
1152 }
1153 for i in 2..period {
1154 slice[i] = slice[i - 1] + slice[i - 2];
1155 }
1156 let sum: f64 = slice[..period].iter().sum();
1157 for w in &mut slice[..period] {
1158 *w /= sum;
1159 }
1160 }
1161
1162 let do_row = |row: usize, dst: &mut [f64]| unsafe {
1163 let period = combos[row].period.unwrap();
1164 let fib_ptr = flat_fib.as_ptr().add(row * max_p);
1165
1166 match kern {
1167 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1168 Kernel::Avx512 => fwma_row_avx512(data, first, period, max_p, fib_ptr, dst),
1169 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1170 Kernel::Avx2 => fwma_row_avx2(data, first, period, max_p, fib_ptr, dst),
1171 _ => fwma_row_scalar(data, first, period, max_p, fib_ptr, dst),
1172 }
1173 };
1174
1175 if parallel {
1176 #[cfg(not(target_arch = "wasm32"))]
1177 {
1178 out.par_chunks_mut(cols)
1179 .enumerate()
1180 .for_each(|(row, slice)| do_row(row, slice));
1181 }
1182 #[cfg(target_arch = "wasm32")]
1183 {
1184 for (row, slice) in out.chunks_mut(cols).enumerate() {
1185 do_row(row, slice);
1186 }
1187 }
1188 } else {
1189 for (row, slice) in out.chunks_mut(cols).enumerate() {
1190 do_row(row, slice);
1191 }
1192 }
1193
1194 Ok(combos)
1195}
1196
1197#[inline]
1198unsafe fn fwma_row_scalar(
1199 data: &[f64],
1200 first: usize,
1201 period: usize,
1202 _stride: usize,
1203 fib_ptr: *const f64,
1204 out: &mut [f64],
1205) {
1206 let fib = std::slice::from_raw_parts(fib_ptr, period);
1207 fwma_scalar(data, fib, period, first, out);
1208}
1209
/// AVX2 row kernel used by the batch driver.
///
/// Reinterprets the `period` weights that start at `fib_ptr` as a slice and
/// forwards everything to `fwma_avx2`. `_stride` is accepted but never read.
///
/// # Safety
/// `fib_ptr` must be non-null, properly aligned and valid for reads of
/// `period` consecutive `f64` values, and the executing CPU must support the
/// `avx2` target feature enabled on this function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn fwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    fib_ptr: *const f64,
    out: &mut [f64],
) {
    let weights: &[f64] = std::slice::from_raw_parts(fib_ptr, period);
    fwma_avx2(data, weights, period, first, out);
}
1224
/// AVX-512 row kernel used by the batch driver.
///
/// Reinterprets the `period` weights that start at `fib_ptr` as a slice and
/// forwards everything to `fwma_avx512`. `_stride` is accepted but never read.
///
/// # Safety
/// `fib_ptr` must be non-null, properly aligned and valid for reads of
/// `period` consecutive `f64` values, and the executing CPU must support the
/// `avx512f` and `fma` target features enabled on this function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn fwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    fib_ptr: *const f64,
    out: &mut [f64],
) {
    let weights: &[f64] = std::slice::from_raw_parts(fib_ptr, period);
    fwma_avx512(data, weights, period, first, out);
}
1239
1240#[cfg(test)]
1241mod tests {
1242 use super::*;
1243 use crate::skip_if_unsupported;
1244 use crate::utilities::data_loader::read_candles_from_csv;
1245
1246 fn check_fwma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1247 skip_if_unsupported!(kernel, test_name);
1248 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1249 let candles = read_candles_from_csv(file_path)?;
1250
1251 let default_params = FwmaParams { period: None };
1252 let input = FwmaInput::from_candles(&candles, "close", default_params);
1253 let output = fwma_with_kernel(&input, kernel)?;
1254 assert_eq!(output.values.len(), candles.close.len());
1255
1256 Ok(())
1257 }
1258
1259 fn check_fwma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1260 skip_if_unsupported!(kernel, test_name);
1261 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1262 let candles = read_candles_from_csv(file_path)?;
1263
1264 let input = FwmaInput::with_default_candles(&candles);
1265 let result = fwma_with_kernel(&input, kernel)?;
1266 let expected_last_five = [
1267 59273.583333333336,
1268 59252.5,
1269 59167.083333333336,
1270 59151.0,
1271 58940.333333333336,
1272 ];
1273 let start = result.values.len().saturating_sub(5);
1274 for (i, &val) in result.values[start..].iter().enumerate() {
1275 let diff = (val - expected_last_five[i]).abs();
1276 assert!(
1277 diff < 1e-8,
1278 "[{}] FWMA {:?} mismatch at idx {}: got {}, expected {}",
1279 test_name,
1280 kernel,
1281 i,
1282 val,
1283 expected_last_five[i]
1284 );
1285 }
1286 Ok(())
1287 }
1288
1289 fn check_fwma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1290 skip_if_unsupported!(kernel, test_name);
1291 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1292 let candles = read_candles_from_csv(file_path)?;
1293
1294 let input = FwmaInput::with_default_candles(&candles);
1295 match input.data {
1296 FwmaData::Candles { source, .. } => assert_eq!(source, "close"),
1297 _ => panic!("Expected FwmaData::Candles"),
1298 }
1299 let output = fwma_with_kernel(&input, kernel)?;
1300 assert_eq!(output.values.len(), candles.close.len());
1301
1302 Ok(())
1303 }
1304
1305 fn check_fwma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1306 skip_if_unsupported!(kernel, test_name);
1307 let input_data = [10.0, 20.0, 30.0];
1308 let params = FwmaParams { period: Some(0) };
1309 let input = FwmaInput::from_slice(&input_data, params);
1310 let res = fwma_with_kernel(&input, kernel);
1311 assert!(
1312 res.is_err(),
1313 "[{}] FWMA should fail with zero period",
1314 test_name
1315 );
1316 Ok(())
1317 }
1318
1319 fn check_fwma_period_exceeds_length(
1320 test_name: &str,
1321 kernel: Kernel,
1322 ) -> Result<(), Box<dyn Error>> {
1323 skip_if_unsupported!(kernel, test_name);
1324 let data_small = [10.0, 20.0, 30.0];
1325 let params = FwmaParams { period: Some(10) };
1326 let input = FwmaInput::from_slice(&data_small, params);
1327 let res = fwma_with_kernel(&input, kernel);
1328 assert!(
1329 res.is_err(),
1330 "[{}] FWMA should fail with period exceeding length",
1331 test_name
1332 );
1333 Ok(())
1334 }
1335
1336 fn check_fwma_very_small_dataset(
1337 test_name: &str,
1338 kernel: Kernel,
1339 ) -> Result<(), Box<dyn Error>> {
1340 skip_if_unsupported!(kernel, test_name);
1341 let single_point = [42.0];
1342 let params = FwmaParams { period: Some(5) };
1343 let input = FwmaInput::from_slice(&single_point, params);
1344 let res = fwma_with_kernel(&input, kernel);
1345 assert!(
1346 res.is_err(),
1347 "[{}] FWMA should fail with insufficient data",
1348 test_name
1349 );
1350 Ok(())
1351 }
1352
1353 fn check_fwma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1354 skip_if_unsupported!(kernel, test_name);
1355 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1356 let candles = read_candles_from_csv(file_path)?;
1357
1358 let first_params = FwmaParams { period: Some(5) };
1359 let first_input = FwmaInput::from_candles(&candles, "close", first_params);
1360 let first_result = fwma_with_kernel(&first_input, kernel)?;
1361
1362 let second_params = FwmaParams { period: Some(3) };
1363 let second_input = FwmaInput::from_slice(&first_result.values, second_params);
1364 let second_result = fwma_with_kernel(&second_input, kernel)?;
1365
1366 assert_eq!(second_result.values.len(), first_result.values.len());
1367 for i in 240..second_result.values.len() {
1368 assert!(
1369 !second_result.values[i].is_nan(),
1370 "[{}] NaN found at idx {}",
1371 test_name,
1372 i
1373 );
1374 }
1375 Ok(())
1376 }
1377
1378 fn check_fwma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1379 skip_if_unsupported!(kernel, test_name);
1380 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1381 let candles = read_candles_from_csv(file_path)?;
1382
1383 let input = FwmaInput::from_candles(&candles, "close", FwmaParams { period: Some(5) });
1384 let res = fwma_with_kernel(&input, kernel)?;
1385 assert_eq!(res.values.len(), candles.close.len());
1386 if res.values.len() > 50 {
1387 for (i, &val) in res.values[50..].iter().enumerate() {
1388 assert!(
1389 !val.is_nan(),
1390 "[{}] Found unexpected NaN at out-index {}",
1391 test_name,
1392 50 + i
1393 );
1394 }
1395 }
1396 Ok(())
1397 }
1398
1399 fn check_fwma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1400 skip_if_unsupported!(kernel, test_name);
1401
1402 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1403 let candles = read_candles_from_csv(file_path)?;
1404
1405 let period = 5;
1406
1407 let input = FwmaInput::from_candles(
1408 &candles,
1409 "close",
1410 FwmaParams {
1411 period: Some(period),
1412 },
1413 );
1414 let batch_output = fwma_with_kernel(&input, kernel)?.values;
1415
1416 let mut stream = FwmaStream::try_new(FwmaParams {
1417 period: Some(period),
1418 })?;
1419
1420 let mut stream_values = Vec::with_capacity(candles.close.len());
1421 for &price in &candles.close {
1422 match stream.update(price) {
1423 Some(val) => stream_values.push(val),
1424 None => stream_values.push(f64::NAN),
1425 }
1426 }
1427
1428 assert_eq!(batch_output.len(), stream_values.len());
1429 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1430 if b.is_nan() && s.is_nan() {
1431 continue;
1432 }
1433 let diff = (b - s).abs();
1434 assert!(
1435 diff < 1e-9,
1436 "[{}] FWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1437 test_name,
1438 i,
1439 b,
1440 s,
1441 diff
1442 );
1443 }
1444 Ok(())
1445 }
1446
1447 #[cfg(debug_assertions)]
1448 fn check_fwma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1449 skip_if_unsupported!(kernel, test_name);
1450
1451 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1452 let candles = read_candles_from_csv(file_path)?;
1453
1454 let test_cases = vec![
1455 FwmaParams::default(),
1456 FwmaParams { period: Some(2) },
1457 FwmaParams { period: Some(3) },
1458 FwmaParams { period: Some(5) },
1459 FwmaParams { period: Some(8) },
1460 FwmaParams { period: Some(10) },
1461 FwmaParams { period: Some(15) },
1462 FwmaParams { period: Some(20) },
1463 FwmaParams { period: Some(30) },
1464 FwmaParams { period: Some(50) },
1465 ];
1466
1467 for params in test_cases {
1468 let input = FwmaInput::from_candles(&candles, "close", params.clone());
1469 let output = fwma_with_kernel(&input, kernel)?;
1470
1471 for (i, &val) in output.values.iter().enumerate() {
1472 if val.is_nan() {
1473 continue;
1474 }
1475
1476 let bits = val.to_bits();
1477
1478 if bits == 0x11111111_11111111 {
1479 panic!(
1480 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params period={:?}",
1481 test_name, val, bits, i, params.period
1482 );
1483 }
1484
1485 if bits == 0x22222222_22222222 {
1486 panic!(
1487 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params period={:?}",
1488 test_name, val, bits, i, params.period
1489 );
1490 }
1491
1492 if bits == 0x33333333_33333333 {
1493 panic!(
1494 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params period={:?}",
1495 test_name, val, bits, i, params.period
1496 );
1497 }
1498 }
1499 }
1500
1501 Ok(())
1502 }
1503
    /// Release-build stand-in for the poison check: performs no work and always
    /// returns `Ok(())`, so the generated test names exist in every profile.
    #[cfg(not(debug_assertions))]
    fn check_fwma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1508
1509 macro_rules! generate_all_fwma_tests {
1510 ($($test_fn:ident),*) => {
1511 paste::paste! {
1512 $(
1513 #[test]
1514 fn [<$test_fn _scalar_f64>]() {
1515 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1516 }
1517
1518 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1519 #[test]
1520 fn [<$test_fn _avx2_f64>]() {
1521 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1522 }
1523
1524 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1525 #[test]
1526 fn [<$test_fn _avx512_f64>]() {
1527 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1528 }
1529
1530
1531 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1532 #[test]
1533 fn [<$test_fn _simd128_f64>]() {
1534 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
1535 }
1536 )*
1537 }
1538 }
1539 }
1540
    // Instantiate the per-kernel `#[test]` wrappers for every single-series check.
    generate_all_fwma_tests!(
        check_fwma_partial_params,
        check_fwma_accuracy,
        check_fwma_default_candles,
        check_fwma_zero_period,
        check_fwma_period_exceeds_length,
        check_fwma_very_small_dataset,
        check_fwma_reinput,
        check_fwma_nan_handling,
        check_fwma_streaming,
        check_fwma_no_poison
    );

    // The property-based check is only compiled with the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_fwma_tests!(check_fwma_property);
1556
1557 #[test]
1558 fn test_fwma_into_matches_api() -> Result<(), Box<dyn Error>> {
1559 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1560 let candles = read_candles_from_csv(file_path)?;
1561
1562 let input = FwmaInput::with_default_candles(&candles);
1563 let baseline = fwma(&input)?.values;
1564
1565 let mut out = vec![0.0f64; baseline.len()];
1566
1567 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1568 {
1569 fwma_into(&input, &mut out)?;
1570 }
1571 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1572 {
1573 fwma_into_slice(&mut out, &input, Kernel::Auto)?;
1574 }
1575
1576 assert_eq!(out.len(), baseline.len());
1577
1578 for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1579 let equal = (a.is_nan() && b.is_nan()) || (a == b);
1580 assert!(
1581 equal,
1582 "into parity mismatch at idx {}: got {}, expected {}",
1583 i, a, b
1584 );
1585 }
1586
1587 Ok(())
1588 }
1589
1590 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1591 skip_if_unsupported!(kernel, test);
1592
1593 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1594 let c = read_candles_from_csv(file)?;
1595
1596 let output = FwmaBatchBuilder::new()
1597 .kernel(kernel)
1598 .apply_candles(&c, "close")?;
1599
1600 let def = FwmaParams::default();
1601 let row = output.values_for(&def).expect("default row missing");
1602
1603 assert_eq!(row.len(), c.close.len());
1604
1605 let expected = [
1606 59273.583333333336,
1607 59252.5,
1608 59167.083333333336,
1609 59151.0,
1610 58940.333333333336,
1611 ];
1612 let start = row.len() - 5;
1613 for (i, &v) in row[start..].iter().enumerate() {
1614 assert!(
1615 (v - expected[i]).abs() < 1e-8,
1616 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1617 );
1618 }
1619 Ok(())
1620 }
1621
1622 #[cfg(debug_assertions)]
1623 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1624 skip_if_unsupported!(kernel, test);
1625
1626 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1627 let c = read_candles_from_csv(file)?;
1628
1629 let test_configs = vec![
1630 (2, 5, 1),
1631 (5, 15, 2),
1632 (10, 30, 5),
1633 (3, 10, 1),
1634 (5, 50, 10),
1635 (2, 20, 1),
1636 ];
1637
1638 for (start, end, step) in test_configs {
1639 let output = FwmaBatchBuilder::new()
1640 .kernel(kernel)
1641 .period_range(start, end, step)
1642 .apply_candles(&c, "close")?;
1643
1644 for (idx, &val) in output.values.iter().enumerate() {
1645 if val.is_nan() {
1646 continue;
1647 }
1648
1649 let bits = val.to_bits();
1650 let row = idx / output.cols;
1651 let col = idx % output.cols;
1652 let params = &output.combos[row];
1653
1654 if bits == 0x11111111_11111111 {
1655 panic!(
1656 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1657 test, val, bits, row, col, params.period
1658 );
1659 }
1660
1661 if bits == 0x22222222_22222222 {
1662 panic!(
1663 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1664 test, val, bits, row, col, params.period
1665 );
1666 }
1667
1668 if bits == 0x33333333_33333333 {
1669 panic!(
1670 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (params: period={:?})",
1671 test, val, bits, row, col, params.period
1672 );
1673 }
1674 }
1675 }
1676
1677 Ok(())
1678 }
1679
    /// Release-build stand-in for the batch poison check: performs no work and
    /// always returns `Ok(())`, so the generated test names exist in every profile.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1684
    /// Property-based checks for FWMA under `kernel`, run in three phases:
    ///
    /// 1. Random finite data (period 1..=64, `period..400` values in
    ///    [-1e6, 1e6)): warm-up NaNs, window bounds, closed-form results for
    ///    periods 1, 2 and 3, agreement with the scalar kernel, monotonicity
    ///    and magnitude bounds.
    /// 2. A fixed 12-value ramp with a single NaN injected: every output whose
    ///    window contains the NaN must itself be NaN.
    /// 3. Constant series at 1e308 or 1e-308: the call must succeed and any
    ///    finite output must equal the constant to a relative 1e-9.
    ///
    /// Property failures surface as a panic via `.unwrap()` on each runner.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_fwma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Phase 1 strategy: pick a period, then a data vector at least that long.
        let strat = (1usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(mut data, period)| {
                // For some inputs, shrink the data to exactly `period` values so
                // the "just enough data" boundary is exercised as well.
                if data.len() > period && period > 1 {
                    if data.len() % 10 == 0 {
                        data.truncate(period);
                    }
                }

                let params = FwmaParams {
                    period: Some(period),
                };
                let input = FwmaInput::from_slice(&data, params);

                // Output of the kernel under test.
                let FwmaOutput { values: out } = fwma_with_kernel(&input, kernel).unwrap();

                // Reference output from the scalar kernel.
                let FwmaOutput { values: ref_out } =
                    fwma_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len());
                prop_assert_eq!(ref_out.len(), data.len());

                // The first `period - 1` outputs are the warm-up and must be NaN.
                for i in 0..(period - 1).min(data.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                // Period 2: the first valid output is the plain mean of the
                // first two inputs.
                if period == 2 && data.len() >= 2 {
                    let expected = (data[0] + data[1]) / 2.0;
                    if out[1].is_finite() && data[0].is_finite() && data[1].is_finite() {
                        prop_assert!(
                            (out[1] - expected).abs() <= 1e-9,
                            "Period=2: output {} should equal average {} at index 1",
                            out[1],
                            expected
                        );
                    }
                }

                // Per-index checks on every post-warm-up output.
                for i in (period - 1)..data.len() {
                    let window = &data[i + 1 - period..=i];
                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];
                    let r = ref_out[i];

                    // The output must lie within the window's [min, max].
                    prop_assert!(
                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                        "idx {}: {} ∉ [{}, {}]",
                        i,
                        y,
                        lo,
                        hi
                    );

                    // Period 1 must reproduce the input.
                    if period == 1 {
                        prop_assert!(
                            (y - data[i]).abs() <= f64::EPSILON,
                            "Period=1: output {} should equal input {} at index {}",
                            y,
                            data[i],
                            i
                        );
                    }

                    // A constant series must map to that constant.
                    if data.windows(2).all(|w| w[0] == w[1]) && !data.is_empty() {
                        prop_assert!(
                            (y - data[0]).abs() <= 1e-9,
                            "Constant data: output {} should equal constant {} at index {}",
                            y,
                            data[0],
                            i
                        );
                    }

                    // A NaN anywhere in the window must yield a NaN output.
                    if window.iter().any(|x| x.is_nan()) {
                        prop_assert!(
                            y.is_nan(),
                            "Window contains NaN but output {} is not NaN at index {}",
                            y,
                            i
                        );
                    }

                    // Non-finite results must match the reference bit-for-bit.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    // Finite results: accept 1e-9 absolute error or up to 4 ULPs.
                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "mismatch idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                // Monotonic input must give a monotonic output in the same direction.
                let is_monotonic_inc = data.windows(2).all(|w| w[0] <= w[1]);
                let is_monotonic_dec = data.windows(2).all(|w| w[0] >= w[1]);

                if (is_monotonic_inc || is_monotonic_dec) && data.len() >= period + 1 {
                    for i in period..out.len() {
                        if out[i].is_finite() && out[i - 1].is_finite() {
                            if is_monotonic_inc {
                                prop_assert!(
                                    out[i] >= out[i-1] - 1e-9,
                                    "Monotonic increasing data but output decreases: {} < {} at index {}",
                                    out[i],
                                    out[i-1],
                                    i
                                );
                            }
                            if is_monotonic_dec {
                                prop_assert!(
                                    out[i] <= out[i-1] + 1e-9,
                                    "Monotonic decreasing data but output increases: {} > {} at index {}",
                                    out[i],
                                    out[i-1],
                                    i
                                );
                            }
                        }
                    }
                }

                // On an ascending window the FWMA must be at least the simple
                // average of that window.
                if period >= 3 && data.len() >= period * 2 {
                    let test_start = period;
                    if test_start + period <= data.len() {
                        let all_ascending = (0..period).all(|j| {
                            let idx = test_start + j;
                            idx == 0
                                || !data[idx].is_finite()
                                || !data[idx - 1].is_finite()
                                || data[idx] >= data[idx - 1]
                        });

                        if all_ascending && out[test_start + period - 1].is_finite() {
                            let window = &data[test_start..test_start + period];
                            let window_avg = window.iter().sum::<f64>() / period as f64;
                            if window.iter().all(|x| x.is_finite()) {
                                prop_assert!(
                                    out[test_start + period - 1] >= window_avg - 1e-9,
                                    "FWMA {} should be >= average {} for ascending window",
                                    out[test_start + period - 1],
                                    window_avg
                                );
                            }
                        }
                    }
                }

                // Outputs may not exceed the largest input magnitude by more than 10%.
                if period > 1 && data.len() >= period * 2 {
                    let data_range = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b.abs()));
                    for &val in &out[(period - 1)..] {
                        if val.is_finite() && data_range > 0.0 {
                            prop_assert!(
                                val.abs() <= data_range * 1.1,
                                "Output {} exceeds reasonable bounds for data range {}",
                                val,
                                data_range
                            );
                        }
                    }
                }

                // Period 3: the first valid output uses weights 0.25, 0.25, 0.5
                // from oldest to newest.
                if period == 3 && data.len() >= 3 {
                    let idx = period - 1;
                    if data[idx - 2].is_finite()
                        && data[idx - 1].is_finite()
                        && data[idx].is_finite()
                    {
                        let expected =
                            data[idx - 2] * 0.25 + data[idx - 1] * 0.25 + data[idx] * 0.5;
                        prop_assert!(
                            (out[idx] - expected).abs() <= 1e-9,
                            "Period=3: output {} should equal weighted avg {} at index {}",
                            out[idx],
                            expected,
                            idx
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        // Phase 2 strategy: a period in 2..=10 and a position 1..10 for the NaN.
        let nan_strat = (2usize..=10).prop_flat_map(|period| (Just(period), 1usize..10));

        proptest::test_runner::TestRunner::default()
            .run(&nan_strat, |(period, nan_pos)| {
                let mut data = vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
                ];
                if nan_pos < data.len() {
                    data[nan_pos] = f64::NAN;
                }

                let params = FwmaParams {
                    period: Some(period),
                };
                let input = FwmaInput::from_slice(&data, params);
                let FwmaOutput { values: out } = fwma_with_kernel(&input, kernel).unwrap();

                // Every window that covers the injected NaN must produce NaN.
                for i in (period - 1)..data.len() {
                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];
                    let has_nan = window.iter().any(|x| x.is_nan());

                    if has_nan {
                        prop_assert!(
                            out[i].is_nan(),
                            "Window [{}, {}] contains NaN but output {} is not NaN at index {}",
                            window_start,
                            i,
                            out[i],
                            i
                        );
                    }
                }
                Ok(())
            })
            .unwrap();

        // Phase 3 strategy: a period in 1..=10 and a flag choosing 1e308 or 1e-308.
        let extreme_strat = (1usize..=10).prop_flat_map(|period| (Just(period), prop::bool::ANY));

        proptest::test_runner::TestRunner::default()
            .run(&extreme_strat, |(period, use_max)| {
                let extreme_val = if use_max { 1e308 } else { 1e-308 };
                let data = vec![extreme_val; period * 2];

                let params = FwmaParams {
                    period: Some(period),
                };
                let input = FwmaInput::from_slice(&data, params);
                let result = fwma_with_kernel(&input, kernel);

                prop_assert!(result.is_ok(), "Failed to handle extreme values");

                // Only finite outputs are compared, using a relative tolerance.
                if let Ok(FwmaOutput { values: out }) = result {
                    for i in (period - 1)..data.len() {
                        if out[i].is_finite() {
                            prop_assert!(
                                (out[i] - extreme_val).abs() / extreme_val.abs() <= 1e-9,
                                "Extreme constant value {} doesn't match output {} at index {}",
                                extreme_val,
                                out[i],
                                i
                            );
                        }
                    }
                }
                Ok(())
            })
            .unwrap();

        Ok(())
    }
1977 macro_rules! gen_batch_tests {
1978 ($fn_name:ident) => {
1979 paste::paste! {
1980 #[test]
1981 fn [<$fn_name _scalar>]() {
1982 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1983 }
1984 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1985 #[test]
1986 fn [<$fn_name _avx2>]() {
1987 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1988 }
1989 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1990 #[test]
1991 fn [<$fn_name _avx512>]() {
1992 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1993 }
1994 #[test]
1995 fn [<$fn_name _auto_detect>]() {
1996 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1997 }
1998 }
1999 };
2000 }
    // Instantiate the per-kernel `#[test]` wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2003
2004 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2005 #[test]
2006 fn test_fwma_batch_into_warmup() {
2007 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2008 let len = data.len();
2009
2010 let period_start = 2;
2011 let period_end = 4;
2012 let period_step = 1;
2013
2014 let sweep = FwmaBatchRange {
2015 period: (period_start, period_end, period_step),
2016 };
2017 let combos = expand_grid(&sweep);
2018 let rows = combos.len();
2019
2020 let mut output = vec![999.0; rows * len];
2021
2022 let result = unsafe {
2023 fwma_batch_into(
2024 data.as_ptr(),
2025 output.as_mut_ptr(),
2026 len,
2027 period_start,
2028 period_end,
2029 period_step,
2030 )
2031 };
2032
2033 assert!(result.is_ok());
2034 assert_eq!(result.unwrap(), rows);
2035
2036 for (r, params) in combos.iter().enumerate() {
2037 let period = params.period.unwrap();
2038 let warmup_end = period - 1;
2039
2040 for i in 0..warmup_end {
2041 let idx = r * len + i;
2042 assert!(
2043 output[idx].is_nan(),
2044 "Expected NaN at row {} col {} (period {}) but got {}",
2045 r,
2046 i,
2047 period,
2048 output[idx]
2049 );
2050 }
2051
2052 let first_valid_idx = r * len + warmup_end;
2053 assert!(
2054 !output[first_valid_idx].is_nan(),
2055 "Expected valid value at row {} col {} (period {}) but got NaN",
2056 r,
2057 warmup_end,
2058 period
2059 );
2060 }
2061 }
2062
2063 #[test]
2064 fn print_actual_fwma_values() {
2065 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2066 let candles = read_candles_from_csv(file_path).unwrap();
2067
2068 println!("\nCandle data info:");
2069 println!("Total candles: {}", candles.close.len());
2070 println!("Last 10 close prices:");
2071 let close_len = candles.close.len();
2072 for i in (close_len - 10)..close_len {
2073 println!(" [{}]: {}", i - (close_len - 10), candles.close[i]);
2074 }
2075
2076 let input = FwmaInput::with_default_candles(&candles);
2077 let result = fwma(&input).unwrap();
2078
2079 println!("\nFWMA results (period = 5):");
2080 println!("Last 5 values:");
2081 let result_len = result.values.len();
2082 for i in (result_len - 5)..result_len {
2083 println!(" [{}]: {:.12}", i - (result_len - 5), result.values[i]);
2084 }
2085
2086 let expected_last_five = [
2087 59273.583333333336,
2088 59252.5,
2089 59167.083333333336,
2090 59151.0,
2091 58940.333333333336,
2092 ];
2093
2094 println!("\nExpected values:");
2095 for (i, val) in expected_last_five.iter().enumerate() {
2096 println!(" [{}]: {:.12}", i, val);
2097 }
2098
2099 println!("\nDifferences:");
2100 for i in 0..5 {
2101 let actual = result.values[result_len - 5 + i];
2102 let expected = expected_last_five[i];
2103 println!(
2104 " [{}]: {:.12} (diff: {:.2e})",
2105 i,
2106 actual - expected,
2107 (actual - expected).abs()
2108 );
2109 }
2110 }
2111}
2112
// Python entry point `fwma_cuda_batch_dev`: runs the FWMA period sweep on a
// CUDA device and returns a handle to the device-resident f32 result.
//
// The f64 input is narrowed to f32 on the host before upload. Raises
// `ValueError` when CUDA is unavailable or when device setup / the batch
// computation reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fwma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn fwma_cuda_batch_dev_py(
    py: Python<'_>,
    data: PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f64 = data.as_slice()?;
    let sweep = FwmaBatchRange {
        period: period_range,
    };
    let host_f32: Vec<f32> = host_f64.iter().copied().map(|v| v as f32).collect();

    // The GPU work runs inside `allow_threads`, i.e. with the GIL released.
    let gpu_result = py.allow_threads(|| {
        let cuda = CudaFwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc_clone();
        let dev = cuda.device_id();
        let arr = cuda
            .fwma_batch_dev(&host_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev))
    });
    let (inner, ctx, dev_id) = gpu_result?;

    // The context handle is stored next to the device array in the wrapper.
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
        stream: 0,
    })
}
2151
// Python entry point `fwma_cuda_many_series_one_param_dev`: applies one FWMA
// period to a 2-D f32 array on a CUDA device and returns a handle to the
// device-resident result.
//
// `data_tm_f32` is read as a flat slice; its shape is taken as
// `(rows, cols)` and passed on as `(cols, rows)` to the time-major device
// routine. NOTE(review): the name suggests rows are time steps and columns are
// series; confirm against `CudaFwma`.
//
// Raises `ValueError` when CUDA is unavailable, when the array is not
// contiguous (from `as_slice`), or when device setup / the computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "fwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn fwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = FwmaParams {
        period: Some(period),
    };

    // The GPU work runs inside `allow_threads`, i.e. with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaFwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc_clone();
        let dev = cuda.device_id();
        let arr = cuda
            // Fixed: the argument had been corrupted to `¶ms` (a mangled
            // `&params`), which is not valid Rust.
            .fwma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev))
    })?;

    // The context handle is stored next to the device array in the wrapper.
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
        stream: 0,
    })
}
2191
// Python entry point `fwma`: computes FWMA over a 1-D f64 NumPy array and
// returns a new NumPy array of the results.
//
// `kernel` is validated with `validate_kernel(kernel, false)`; indicator
// errors are re-raised as `ValueError`. The computation itself runs inside
// `allow_threads`, i.e. with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "fwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn fwma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = FwmaInput::from_slice(
        series,
        FwmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| fwma_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2217
// Python class `FwmaStream`: a thin wrapper around the Rust `FwmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "FwmaStream")]
pub struct FwmaStreamPy {
    // Wrapped streaming state; every call is delegated to it.
    stream: FwmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl FwmaStreamPy {
    // Builds a stream for `period`; construction errors become `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = FwmaParams {
            period: Some(period),
        };
        match FwmaStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Pushes one value and returns whatever the wrapped stream returns
    // (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2241
// Python entry point `fwma_batch`: evaluates FWMA for every period in
// `period_range` and returns a dict with
//   "values"  - a (rows, cols) f64 array, one row per period, and
//   "periods" - the period of each row as u64.
//
// `kernel` is validated with `validate_kernel(kernel, true)`. Overflow of
// rows*cols and errors from the batch computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "fwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn fwma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::PyArrayMethods;

    let series = data.as_slice()?;
    let sweep = FwmaBatchRange {
        period: period_range,
    };

    let grid = expand_grid(&sweep);
    let (rows, cols) = (grid.len(), series.len());

    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("fwma_batch: rows*cols overflow")),
    };

    // The NumPy buffer is allocated without initialisation and written in
    // place below. NOTE(review): presumably the NaN-prefix fill plus the row
    // kernels cover every element; confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    // Warm-up length per row: index of the first non-NaN input plus period - 1.
    let first = series.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warm: Vec<usize> = grid
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();

    fill_nan_prefixes_slice(rows, cols, &warm, slice_out);

    // Resolve the batch kernel to its per-row kernel and compute inside
    // `allow_threads`, i.e. with the GIL released.
    let batch_result = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };
        let row_kernel = match batch_kernel {
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512Batch => Kernel::Avx512,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            _ => Kernel::Scalar,
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            _ => unreachable!(),
        };
        fwma_batch_inner_into(series, &sweep, row_kernel, true, slice_out)
    });
    let combos = batch_result.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2313
2314#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2315#[wasm_bindgen]
2316pub fn fwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
2317 let params = FwmaParams {
2318 period: Some(period),
2319 };
2320 let input = FwmaInput::from_slice(data, params);
2321
2322 let mut output = vec![0.0; data.len()];
2323
2324 fwma_into_slice(&mut output, &input, Kernel::Auto)
2325 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2326
2327 Ok(output)
2328}
2329
2330#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2331#[wasm_bindgen]
2332pub fn fwma_batch_js(
2333 data: &[f64],
2334 period_start: usize,
2335 period_end: usize,
2336 period_step: usize,
2337) -> Result<Vec<f64>, JsValue> {
2338 let sweep = FwmaBatchRange {
2339 period: (period_start, period_end, period_step),
2340 };
2341
2342 fwma_batch_inner(data, &sweep, Kernel::Scalar, false)
2343 .map(|output| output.values)
2344 .map_err(|e| JsValue::from_str(&e.to_string()))
2345}
2346
2347#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2348#[wasm_bindgen]
/// Lists the period values described by `(period_start, period_end, period_step)`
/// as `f64`s.
///
/// * A zero step or `period_start == period_end` yields just `period_start`.
/// * Otherwise the values run in ascending order from the smaller bound to the
///   larger bound (inclusive), advancing by `period_step`.
///
/// NOTE(review): a reversed range is listed in ascending order; confirm this
/// matches the row order used by the batch functions.
pub fn fwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<f64> {
    if period_step == 0 || period_start == period_end {
        return vec![period_start as f64];
    }

    let lower = period_start.min(period_end);
    let upper = period_start.max(period_end);

    (lower..=upper)
        .step_by(period_step)
        .map(|period| period as f64)
        .collect()
}
2369
2370#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2371#[wasm_bindgen]
/// Allocates room for `len` `f64` values and returns a pointer to the buffer.
///
/// Ownership of the allocation is deliberately given up here, so the memory
/// stays alive until it is reclaimed by other code (see `fwma_free`). The
/// contents are uninitialised.
pub fn fwma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2378
2379#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2380#[wasm_bindgen]
/// Releases a buffer of `len` `f64` values starting at `ptr`.
///
/// A null `ptr` is ignored. Otherwise the memory is rebuilt into a `Vec` with
/// length and capacity `len` and dropped, so `ptr`/`len` must describe an
/// allocation made the way `fwma_alloc` makes one.
pub fn fwma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from a `Vec<f64>` allocation of
    // capacity `len` that has not been freed yet.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
2388
2389#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2390#[wasm_bindgen]
2391pub fn fwma_into(
2392 in_ptr: *const f64,
2393 out_ptr: *mut f64,
2394 len: usize,
2395 period: usize,
2396) -> Result<(), JsValue> {
2397 if in_ptr.is_null() || out_ptr.is_null() {
2398 return Err(JsValue::from_str("null pointer passed to fwma_into"));
2399 }
2400
2401 unsafe {
2402 let data = std::slice::from_raw_parts(in_ptr, len);
2403
2404 if period == 0 || period > len {
2405 return Err(JsValue::from_str("Invalid period"));
2406 }
2407
2408 let params = FwmaParams {
2409 period: Some(period),
2410 };
2411 let input = FwmaInput::from_slice(data, params);
2412
2413 if in_ptr == out_ptr {
2414 let mut temp = vec![0.0; len];
2415 fwma_into_slice(&mut temp, &input, Kernel::Auto)
2416 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2417
2418 let out = std::slice::from_raw_parts_mut(out_ptr, len);
2419 out.copy_from_slice(&temp);
2420 } else {
2421 let out = std::slice::from_raw_parts_mut(out_ptr, len);
2422 fwma_into_slice(out, &input, Kernel::Auto)
2423 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2424 }
2425
2426 Ok(())
2427 }
2428}
2429
2430#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2431#[derive(Serialize, Deserialize)]
/// Configuration object deserialized from the JS value passed to `fwma_batch`.
pub struct FwmaBatchConfig {
    /// Period sweep forwarded unchanged as `FwmaBatchRange::period`; elsewhere
    /// in this file that tuple is built as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
2435
2436#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2437#[derive(Serialize, Deserialize)]
/// Result object serialized back to JS by `fwma_batch`; each field is copied
/// from the output of `fwma_batch_inner`.
pub struct FwmaBatchJsOutput {
    /// Flat buffer of computed values (presumably `rows * cols` entries, one
    /// row per parameter combination — TODO confirm layout).
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<FwmaParams>,
    /// Row count reported by the batch routine.
    pub rows: usize,
    /// Column count reported by the batch routine.
    pub cols: usize,
}
2444
2445#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2446#[wasm_bindgen(js_name = fwma_batch)]
2447pub fn fwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
2448 let config: FwmaBatchConfig = serde_wasm_bindgen::from_value(config)
2449 .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
2450
2451 let sweep = FwmaBatchRange {
2452 period: config.period_range,
2453 };
2454
2455 let output = fwma_batch_inner(data, &sweep, Kernel::Auto, false)
2456 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2457
2458 let js_output = FwmaBatchJsOutput {
2459 values: output.values,
2460 combos: output.combos,
2461 rows: output.rows,
2462 cols: output.cols,
2463 };
2464
2465 serde_wasm_bindgen::to_value(&js_output)
2466 .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
2467}
2468
2469#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2470#[wasm_bindgen]
2471pub fn fwma_batch_into(
2472 in_ptr: *const f64,
2473 out_ptr: *mut f64,
2474 len: usize,
2475 period_start: usize,
2476 period_end: usize,
2477 period_step: usize,
2478) -> Result<usize, JsValue> {
2479 if in_ptr.is_null() || out_ptr.is_null() {
2480 return Err(JsValue::from_str("null pointer passed to fwma_batch_into"));
2481 }
2482
2483 unsafe {
2484 let data = std::slice::from_raw_parts(in_ptr, len);
2485
2486 let sweep = FwmaBatchRange {
2487 period: (period_start, period_end, period_step),
2488 };
2489
2490 let combos = expand_grid(&sweep);
2491 let rows = combos.len();
2492 let cols = len;
2493
2494 let total = rows
2495 .checked_mul(cols)
2496 .ok_or(JsValue::from_str("fwma_batch_into: rows*cols overflow"))?;
2497
2498 let out = std::slice::from_raw_parts_mut(out_ptr, total);
2499
2500 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
2501 let warm: Vec<usize> = combos
2502 .iter()
2503 .map(|c| first + c.period.unwrap() - 1)
2504 .collect();
2505
2506 fill_nan_prefixes_slice(rows, cols, &warm, out);
2507
2508 fwma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
2509 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2510
2511 Ok(rows)
2512 }
2513}