1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::wma_wrapper::DeviceArrayF32Py;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::cuda::moving_averages::CudaWma;
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(feature = "python")]
10use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
11#[cfg(feature = "python")]
12use pyo3::exceptions::PyValueError;
13#[cfg(feature = "python")]
14use pyo3::prelude::*;
15#[cfg(feature = "python")]
16use pyo3::types::PyDict;
17#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
18use serde::{Deserialize, Serialize};
19#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
20use wasm_bindgen::prelude::*;
21
22use crate::utilities::data_loader::{source_type, Candles};
23use crate::utilities::enums::Kernel;
24use crate::utilities::helpers::{
25 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
26 make_uninit_matrix,
27};
28use aligned_vec::{AVec, CACHELINE_ALIGN};
29#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
30use core::arch::x86_64::*;
31#[cfg(not(target_arch = "wasm32"))]
32use rayon::prelude::*;
33use std::convert::AsRef;
34use std::error::Error;
35use std::mem::MaybeUninit;
36use thiserror::Error;
37
38impl<'a> AsRef<[f64]> for WmaInput<'a> {
39 #[inline(always)]
40 fn as_ref(&self) -> &[f64] {
41 match &self.data {
42 WmaData::Slice(slice) => slice,
43 WmaData::Candles { candles, source } => source_type(candles, source),
44 }
45 }
46}
47
/// Where a WMA computation reads its input series from.
#[derive(Debug, Clone)]
pub enum WmaData<'a> {
    /// A candle set plus the name of the field to read (e.g. `"close"`);
    /// the name is resolved through `source_type` when the data is accessed.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-supplied slice of values used as-is.
    Slice(&'a [f64]),
}
56
/// Result of a single-series WMA computation.
#[derive(Debug, Clone)]
pub struct WmaOutput {
    /// One value per input element; the warm-up prefix
    /// (`first_valid + period - 1` elements) is NaN.
    pub values: Vec<f64>,
}
61
/// Tunable parameters for the WMA.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct WmaParams {
    /// Window length. `None` is treated as 30 by the code in this file that
    /// reads it (`WmaInput::get_period`, `WmaStream::try_new`).
    pub period: Option<usize>,
}
70
71impl Default for WmaParams {
72 fn default() -> Self {
73 Self { period: Some(30) }
74 }
75}
76
/// Everything needed for one WMA run: the data source and the parameters.
#[derive(Debug, Clone)]
pub struct WmaInput<'a> {
    /// Borrowed input series (candles + source name, or a raw slice).
    pub data: WmaData<'a>,
    /// Parameters to apply to `data`.
    pub params: WmaParams,
}
82
83impl<'a> WmaInput<'a> {
84 #[inline]
85 pub fn from_candles(c: &'a Candles, s: &'a str, p: WmaParams) -> Self {
86 Self {
87 data: WmaData::Candles {
88 candles: c,
89 source: s,
90 },
91 params: p,
92 }
93 }
94 #[inline]
95 pub fn from_slice(sl: &'a [f64], p: WmaParams) -> Self {
96 Self {
97 data: WmaData::Slice(sl),
98 params: p,
99 }
100 }
101 #[inline]
102 pub fn with_default_candles(c: &'a Candles) -> Self {
103 Self::from_candles(c, "close", WmaParams::default())
104 }
105 #[inline]
106 pub fn get_period(&self) -> usize {
107 self.params.period.unwrap_or(30)
108 }
109}
110
/// Fluent builder for single-series and streaming WMA computations.
#[derive(Copy, Clone, Debug)]
pub struct WmaBuilder {
    /// Requested period; `None` defers to the downstream default (30).
    period: Option<usize>,
    /// Kernel passed through to `wma_with_kernel`.
    kernel: Kernel,
}
116
117impl Default for WmaBuilder {
118 fn default() -> Self {
119 Self {
120 period: None,
121 kernel: Kernel::Auto,
122 }
123 }
124}
125
126impl WmaBuilder {
127 #[inline(always)]
128 pub fn new() -> Self {
129 Self::default()
130 }
131 #[inline(always)]
132 pub fn period(mut self, n: usize) -> Self {
133 self.period = Some(n);
134 self
135 }
136 #[inline(always)]
137 pub fn kernel(mut self, k: Kernel) -> Self {
138 self.kernel = k;
139 self
140 }
141
142 #[inline(always)]
143 pub fn apply(self, c: &Candles) -> Result<WmaOutput, WmaError> {
144 let p = WmaParams {
145 period: self.period,
146 };
147 let i = WmaInput::from_candles(c, "close", p);
148 wma_with_kernel(&i, self.kernel)
149 }
150
151 #[inline(always)]
152 pub fn apply_slice(self, d: &[f64]) -> Result<WmaOutput, WmaError> {
153 let p = WmaParams {
154 period: self.period,
155 };
156 let i = WmaInput::from_slice(d, p);
157 wma_with_kernel(&i, self.kernel)
158 }
159
160 #[inline(always)]
161 pub fn into_stream(self) -> Result<WmaStream, WmaError> {
162 let p = WmaParams {
163 period: self.period,
164 };
165 WmaStream::try_new(p)
166 }
167}
168
/// Errors reported by the WMA single-series, streaming and batch entry points.
#[derive(Debug, Error)]
pub enum WmaError {
    /// The input series has length zero.
    #[error("wma: Input data slice is empty.")]
    EmptyInputData,

    /// No non-NaN element was found in the input.
    #[error("wma: All values are NaN.")]
    AllValuesNaN,

    /// The period was rejected: below 2, or (for slice/candle inputs) larger
    /// than the series. The streaming constructor reports `data_len: 0`.
    #[error("wma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer elements remain from the first non-NaN value onward than the period needs.
    #[error("wma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A kernel that is neither `Auto` nor a batch kernel was given to the batch entry point.
    #[error("wma: Non-batch kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),

    /// A caller-provided output buffer does not have the required length.
    #[error("wma: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// The batch period range expanded to no periods at all.
    #[error("wma: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// Any other invalid input (e.g. `rows * cols` overflowing `usize`).
    #[error("wma: invalid input: {0}")]
    InvalidInput(String),
}
199
/// Computes the WMA for `input` with `Kernel::Auto`.
///
/// Thin wrapper over [`wma_with_kernel`]; see it for the error conditions.
#[inline]
pub fn wma(input: &WmaInput) -> Result<WmaOutput, WmaError> {
    wma_with_kernel(input, Kernel::Auto)
}
204
205pub fn wma_with_kernel(input: &WmaInput, kernel: Kernel) -> Result<WmaOutput, WmaError> {
206 let (data, period, first, chosen) = wma_prepare(input, kernel)?;
207 let len = data.len();
208 let warm = first + period - 1;
209 let mut out = alloc_with_nan_prefix(len, warm);
210
211 wma_compute_into(data, period, first, chosen, &mut out);
212
213 Ok(WmaOutput { values: out })
214}
215
216#[inline]
217pub fn wma_into_slice(dst: &mut [f64], input: &WmaInput, kern: Kernel) -> Result<(), WmaError> {
218 let (data, period, first, chosen) = wma_prepare(input, kern)?;
219
220 if dst.len() != data.len() {
221 return Err(WmaError::OutputLengthMismatch {
222 expected: data.len(),
223 got: dst.len(),
224 });
225 }
226
227 wma_compute_into(data, period, first, chosen, dst);
228
229 let warmup_end = first + period - 1;
230 for v in &mut dst[..warmup_end] {
231 *v = f64::NAN;
232 }
233
234 Ok(())
235}
236
/// Computes the WMA for `input` into `out` using `Kernel::Auto`.
///
/// Thin wrapper over [`wma_into_slice`] (note the swapped argument order);
/// not compiled for the wasm32 + `wasm` feature combination.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn wma_into(input: &WmaInput, out: &mut [f64]) -> Result<(), WmaError> {
    wma_into_slice(out, input, Kernel::Auto)
}
242
243fn wma_prepare<'a>(
244 input: &'a WmaInput,
245 kernel: Kernel,
246) -> Result<(&'a [f64], usize, usize, Kernel), WmaError> {
247 let data: &[f64] = input.as_ref();
248 let len = data.len();
249 if len == 0 {
250 return Err(WmaError::EmptyInputData);
251 }
252
253 let first = data
254 .iter()
255 .position(|x| !x.is_nan())
256 .ok_or(WmaError::AllValuesNaN)?;
257 let period = input.get_period();
258
259 if period < 2 || period > len {
260 return Err(WmaError::InvalidPeriod {
261 period,
262 data_len: len,
263 });
264 }
265 if len - first < period {
266 return Err(WmaError::NotEnoughValidData {
267 needed: period,
268 valid: len - first,
269 });
270 }
271
272 let chosen = match kernel {
273 Kernel::Auto => Kernel::Scalar,
274 k => k,
275 };
276
277 Ok((data, period, first, chosen))
278}
279
280#[inline(always)]
281fn wma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
282 unsafe {
283 match kernel {
284 Kernel::Scalar | Kernel::ScalarBatch => wma_scalar(data, period, first, out),
285 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
286 Kernel::Avx2 | Kernel::Avx2Batch => wma_avx2(data, period, first, out),
287 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
288 Kernel::Avx512 | Kernel::Avx512Batch => wma_avx512(data, period, first, out),
289 _ => unreachable!(),
290 }
291 }
292}
293
/// Scalar linearly-weighted moving average.
///
/// For every index `i >= first_val + period - 1` writes
/// `sum(w_j * data[i - period + j]) / (period * (period + 1) / 2)` with
/// weights `w_j = 1..=period` (the newest element carries weight `period`).
/// Elements of `out` before that index are left untouched.
///
/// The running update keeps two accumulators — the plain window sum and the
/// weighted sum — so each output costs O(1) regardless of `period`.
///
/// # Parameters
/// * `data` – input series.
/// * `period` – window length.
/// * `first_val` – index of the first element that takes part in the average.
/// * `out` – destination, expected to be as long as `data`.
///
/// # Panics
/// Panics if `out` is shorter than `data` (debug builds additionally require
/// equal lengths).
///
/// If `period == 0`, or no full window fits into `data` starting at
/// `first_val`, the function returns without writing anything.
#[inline]
pub fn wma_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    debug_assert_eq!(out.len(), data.len());
    assert!(
        out.len() >= data.len(),
        "wma_scalar: output buffer ({}) shorter than input ({})",
        out.len(),
        data.len()
    );

    if period == 0 {
        return;
    }
    let lookback = period - 1;

    // Index of the first output; bail out when no complete window exists.
    let start = match first_val.checked_add(lookback) {
        Some(idx) if idx < data.len() => idx,
        _ => return,
    };

    let period_f = period as f64;
    let weights = period_f * (period_f + 1.0) * 0.5;

    let window = &data[first_val..];

    // Seed both accumulators with the first `period - 1` elements.
    let mut sum = 0.0_f64;
    let mut weight_sum = 0.0_f64;
    for (k, &v) in window[..lookback].iter().enumerate() {
        weight_sum += v * (k as f64 + 1.0);
        sum += v;
    }

    // `incoming` enters the window with weight `period`; `outgoing` is the
    // element that leaves it after the current output has been produced.
    let incoming = &window[lookback..];
    let outgoing = window;
    let dst = &mut out[start..];

    for ((slot, &v), &old) in dst.iter_mut().zip(incoming).zip(outgoing) {
        weight_sum += v * period_f;
        sum += v;

        *slot = weight_sum / weights;

        // Sliding by one lowers every weight by 1, i.e. subtracts the plain sum.
        weight_sum -= sum;
        sum -= old;
    }
}
337
/// AVX2 entry point for the WMA.
///
/// Currently forwards to [`wma_scalar`] with the same arguments; no AVX2
/// intrinsics are used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn wma_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    wma_scalar(data, period, first_valid, out)
}
343
/// AVX-512 entry point for the WMA.
///
/// Currently forwards to [`wma_scalar`] with the same arguments; no AVX-512
/// intrinsics are used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn wma_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    wma_scalar(data, period, first_valid, out)
}
349
/// "Short period" AVX-512 variant; currently forwards to [`wma_scalar`].
///
/// Unlike `wma_avx512`, this function is not feature-gated.
#[inline]
pub fn wma_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    wma_scalar(data, period, first_valid, out)
}
354
/// "Long period" AVX-512 variant; currently forwards to [`wma_scalar`].
///
/// Unlike `wma_avx512`, this function is not feature-gated.
#[inline]
pub fn wma_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    wma_scalar(data, period, first_valid, out)
}
359
360#[inline(always)]
361pub fn wma_with_kernel_batch(
362 data: &[f64],
363 sweep: &WmaBatchRange,
364 k: Kernel,
365) -> Result<WmaBatchOutput, WmaError> {
366 let kernel = match k {
367 Kernel::Auto => detect_best_batch_kernel(),
368 other if other.is_batch() => other,
369 _ => return Err(WmaError::InvalidKernelForBatch(k)),
370 };
371
372 let simd = match kernel {
373 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
374 Kernel::Avx512Batch => Kernel::Avx512,
375 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
376 Kernel::Avx2Batch => Kernel::Avx2,
377 Kernel::ScalarBatch => Kernel::Scalar,
378 _ => Kernel::Scalar,
379 };
380 wma_batch_par_slice(data, sweep, simd)
381}
382
/// Incremental WMA calculator fed one value at a time.
#[derive(Clone, Debug)]
pub struct WmaStream {
    /// Window length.
    period: usize,
    /// Ring buffer holding the most recent `period` values.
    buffer: Vec<f64>,
    /// Index in `buffer` where the next value will be written.
    head: usize,
    /// Set once the buffer has been filled for the first time.
    filled: bool,

    /// Running plain sum carried between updates.
    plain_sum: f64,
    /// Running weighted sum carried between updates.
    weighted_sum: f64,
    /// Reciprocal of the weight total `period * (period + 1) / 2`.
    inv_div: f64,
    /// `period` as `f64`, cached for the per-update multiply.
    p_f64: f64,
}
395
396impl WmaStream {
397 pub fn try_new(params: WmaParams) -> Result<Self, WmaError> {
398 let period = params.period.unwrap_or(30);
399 if period < 2 {
400 return Err(WmaError::InvalidPeriod {
401 period,
402 data_len: 0,
403 });
404 }
405
406 let sum_of_weights = (period * (period + 1)) as f64 * 0.5;
407 Ok(Self {
408 period,
409 buffer: vec![f64::NAN; period],
410 head: 0,
411 filled: false,
412 plain_sum: 0.0,
413 weighted_sum: 0.0,
414 inv_div: 1.0 / sum_of_weights,
415 p_f64: period as f64,
416 })
417 }
418
419 #[inline(always)]
420 pub fn update(&mut self, value: f64) -> Option<f64> {
421 let write_idx = self.head;
422 self.buffer[write_idx] = value;
423
424 self.head += 1;
425 if self.head == self.period {
426 self.head = 0;
427 }
428
429 if !self.filled {
430 if self.head == 0 {
431 let mut wsum = 0.0;
432 let mut ssum = 0.0;
433
434 let mut idx = self.head;
435 for w in 1..=self.period {
436 let v = self.buffer[idx];
437 ssum += v;
438 wsum += (w as f64) * v;
439 idx += 1;
440 if idx == self.period {
441 idx = 0;
442 }
443 }
444 let out = wsum * self.inv_div;
445
446 self.weighted_sum = wsum - ssum;
447 let oldest_next = self.buffer[self.head];
448 self.plain_sum = ssum - oldest_next;
449
450 self.filled = true;
451 Some(out)
452 } else {
453 None
454 }
455 } else {
456 let oldest = self.buffer[self.head];
457 self.weighted_sum += self.p_f64 * value;
458 self.plain_sum += value;
459
460 let out = self.weighted_sum * self.inv_div;
461
462 self.weighted_sum -= self.plain_sum;
463 self.plain_sum -= oldest;
464
465 Some(out)
466 }
467 }
468}
469
/// Period sweep for batch computations.
#[derive(Clone, Debug)]
pub struct WmaBatchRange {
    /// `(start, end, step)`; expanded into concrete periods by `expand_grid`.
    /// A `step` of 0 (or `start == end`) yields the single period `start`.
    pub period: (usize, usize, usize),
}
474
475impl Default for WmaBatchRange {
476 fn default() -> Self {
477 Self {
478 period: (2, 251, 1),
479 }
480 }
481}
482
/// Fluent builder for batch (multi-period) WMA computations.
#[derive(Clone, Debug, Default)]
pub struct WmaBatchBuilder {
    /// Period sweep to evaluate.
    range: WmaBatchRange,
    /// Kernel handed to `wma_with_kernel_batch`.
    kernel: Kernel,
}
488
489impl WmaBatchBuilder {
490 pub fn new() -> Self {
491 Self::default()
492 }
493 pub fn kernel(mut self, k: Kernel) -> Self {
494 self.kernel = k;
495 self
496 }
497
498 #[inline]
499 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
500 self.range.period = (start, end, step);
501 self
502 }
503 #[inline]
504 pub fn period_static(mut self, p: usize) -> Self {
505 self.range.period = (p, p, 0);
506 self
507 }
508
509 pub fn apply_slice(self, data: &[f64]) -> Result<WmaBatchOutput, WmaError> {
510 wma_with_kernel_batch(data, &self.range, self.kernel)
511 }
512
513 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<WmaBatchOutput, WmaError> {
514 WmaBatchBuilder::new().kernel(k).apply_slice(data)
515 }
516
517 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<WmaBatchOutput, WmaError> {
518 let slice = source_type(c, src);
519 self.apply_slice(slice)
520 }
521
522 pub fn with_default_candles(c: &Candles) -> Result<WmaBatchOutput, WmaError> {
523 WmaBatchBuilder::new()
524 .kernel(Kernel::Auto)
525 .apply_candles(c, "close")
526 }
527}
528
/// Result of a batch WMA sweep, stored as a row-major `rows x cols` matrix.
#[derive(Clone, Debug)]
pub struct WmaBatchOutput {
    /// Flattened matrix; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter set that produced each row, in row order.
    pub combos: Vec<WmaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
536impl WmaBatchOutput {
537 pub fn row_for_params(&self, p: &WmaParams) -> Option<usize> {
538 self.combos
539 .iter()
540 .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
541 }
542
543 pub fn values_for(&self, p: &WmaParams) -> Option<&[f64]> {
544 self.row_for_params(p).map(|row| {
545 let start = row * self.cols;
546 &self.values[start..start + self.cols]
547 })
548 }
549}
550
551#[inline(always)]
552fn expand_grid(r: &WmaBatchRange) -> Vec<WmaParams> {
553 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
554 if step == 0 || start == end {
555 return vec![start];
556 }
557 if start < end {
558 return (start..=end).step_by(step.max(1)).collect();
559 }
560
561 let mut out = Vec::new();
562 let mut x = start as isize;
563 let end_i = end as isize;
564 let st = (step as isize).max(1);
565 while x >= end_i {
566 out.push(x as usize);
567 x -= st;
568 }
569 if out.is_empty() {
570 return out;
571 }
572 if *out.last().unwrap() != end {
573 out.push(end);
574 }
575 out
576 }
577
578 let periods = axis_usize(r.period);
579 let mut out = Vec::with_capacity(periods.len());
580 for &p in &periods {
581 out.push(WmaParams { period: Some(p) });
582 }
583 out
584}
585
/// Runs the batch sweep over `data` with row processing done sequentially
/// (`parallel = false`).
#[inline(always)]
pub fn wma_batch_slice(
    data: &[f64],
    sweep: &WmaBatchRange,
    kern: Kernel,
) -> Result<WmaBatchOutput, WmaError> {
    wma_batch_inner(data, sweep, kern, false)
}
594
/// Runs the batch sweep over `data` with the `parallel` flag set; rows are
/// then processed through rayon on non-wasm32 targets.
#[inline(always)]
pub fn wma_batch_par_slice(
    data: &[f64],
    sweep: &WmaBatchRange,
    kern: Kernel,
) -> Result<WmaBatchOutput, WmaError> {
    wma_batch_inner(data, sweep, kern, true)
}
603
604#[inline(always)]
605fn wma_batch_inner(
606 data: &[f64],
607 sweep: &WmaBatchRange,
608 kern: Kernel,
609 parallel: bool,
610) -> Result<WmaBatchOutput, WmaError> {
611 let combos = expand_grid(sweep);
612 let rows = combos.len();
613 let cols = data.len();
614 if cols == 0 {
615 return Err(WmaError::EmptyInputData);
616 }
617
618 rows.checked_mul(cols)
619 .ok_or_else(|| WmaError::InvalidInput("rows*cols overflow".into()))?;
620
621 let mut buf_mu = make_uninit_matrix(rows, cols);
622
623 let first = data
624 .iter()
625 .position(|x| !x.is_nan())
626 .ok_or(WmaError::AllValuesNaN)?;
627 let warm: Vec<usize> = combos
628 .iter()
629 .map(|c| first + c.period.unwrap() - 1)
630 .collect();
631 init_matrix_prefixes(&mut buf_mu, cols, &warm);
632
633 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
634 let out: &mut [f64] =
635 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
636
637 wma_batch_inner_into(data, sweep, kern, parallel, out)?;
638
639 let values = unsafe {
640 Vec::from_raw_parts(
641 guard.as_mut_ptr() as *mut f64,
642 guard.len(),
643 guard.capacity(),
644 )
645 };
646
647 Ok(WmaBatchOutput {
648 values,
649 combos,
650 rows,
651 cols,
652 })
653}
654
655#[inline(always)]
656fn wma_batch_inner_into(
657 data: &[f64],
658 sweep: &WmaBatchRange,
659 kern: Kernel,
660 parallel: bool,
661 out: &mut [f64],
662) -> Result<Vec<WmaParams>, WmaError> {
663 let combos = expand_grid(sweep);
664 if combos.is_empty() {
665 let (start, end, step) = sweep.period;
666 return Err(WmaError::InvalidRange { start, end, step });
667 }
668
669 let first = data
670 .iter()
671 .position(|x| !x.is_nan())
672 .ok_or(WmaError::AllValuesNaN)?;
673 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
674 if data.len() - first < max_p {
675 return Err(WmaError::NotEnoughValidData {
676 needed: max_p,
677 valid: data.len() - first,
678 });
679 }
680
681 let rows = combos.len();
682 let cols = data.len();
683
684 let needed = rows
685 .checked_mul(cols)
686 .ok_or_else(|| WmaError::InvalidInput("rows*cols overflow".into()))?;
687 if out.len() != needed {
688 return Err(WmaError::OutputLengthMismatch {
689 expected: needed,
690 got: out.len(),
691 });
692 }
693
694 let warm: Vec<usize> = combos
695 .iter()
696 .map(|c| first + c.period.unwrap() - 1)
697 .collect();
698
699 let out_uninit = unsafe {
700 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
701 };
702
703 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
704
705 let cols = data.len();
706 let mut pref_a = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
707 let mut pref_b = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cols + 1);
708 unsafe {
709 pref_a.set_len(cols + 1);
710 pref_b.set_len(cols + 1);
711 }
712 pref_a[0] = 0.0;
713 pref_b[0] = 0.0;
714 for i in 0..cols {
715 let x = if i < first { 0.0 } else { data[i] };
716 pref_a[i + 1] = pref_a[i] + x;
717 pref_b[i + 1] = pref_b[i] + (i as f64) * x;
718 }
719
720 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
721 let period = combos[row].period.unwrap();
722 let denom = (period * (period + 1)) as f64 / 2.0;
723 let inv_div = 1.0 / denom;
724 let warm_end = first + period - 1;
725
726 let out_row =
727 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
728
729 for i in warm_end..cols {
730 let s_a = pref_a[i + 1] - pref_a[i + 1 - period];
731 let s_b = pref_b[i + 1] - pref_b[i + 1 - period];
732 let wsum = s_b - ((i + 1 - period) as f64 - 1.0) * s_a;
733 out_row[i] = wsum * inv_div;
734 }
735 };
736
737 if parallel {
738 #[cfg(not(target_arch = "wasm32"))]
739 {
740 out_uninit
741 .par_chunks_mut(cols)
742 .enumerate()
743 .for_each(|(row, slice)| do_row(row, slice));
744 }
745
746 #[cfg(target_arch = "wasm32")]
747 {
748 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
749 do_row(row, slice);
750 }
751 }
752 } else {
753 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
754 do_row(row, slice);
755 }
756 }
757
758 Ok(combos)
759}
760
/// Computes one batch row with the scalar kernel.
///
/// Forwards to [`wma_scalar`]; note the swapped `first` / `period` order.
///
/// # Safety
/// The body only calls `wma_scalar`. Callers should pass an `out` as long as
/// `data` (debug-asserted by `wma_scalar`) and a `first`/`period` pair for
/// which at least one full window fits into `data`.
#[inline(always)]
unsafe fn wma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    wma_scalar(data, period, first, out)
}
765
/// AVX2 row kernel; currently forwards to `wma_row_scalar`.
///
/// # Safety
/// Same requirements as `wma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    wma_row_scalar(data, first, period, out)
}
771
/// AVX-512 row kernel: picks the "short" variant for `period <= 32` and the
/// "long" variant otherwise (both currently end up in the scalar code).
///
/// # Safety
/// Same requirements as `wma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn wma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        wma_row_avx512_short(data, first, period, out);
    } else {
        wma_row_avx512_long(data, first, period, out);
    }
}
781
/// Short-period AVX-512 row kernel; currently forwards to `wma_row_scalar`.
///
/// # Safety
/// Same requirements as `wma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    wma_row_scalar(data, first, period, out)
}
787
/// Long-period AVX-512 row kernel; currently forwards to `wma_row_scalar`.
///
/// # Safety
/// Same requirements as `wma_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn wma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    wma_row_scalar(data, first, period, out)
}
793
794#[cfg(test)]
795mod tests {
796 use super::*;
797 use crate::skip_if_unsupported;
798 use crate::utilities::data_loader::read_candles_from_csv;
799 use paste::paste;
800
801 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
802 #[test]
803 fn test_wma_into_matches_api() -> Result<(), Box<dyn Error>> {
804 let mut data = Vec::with_capacity(256);
805
806 data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN, f64::NAN]);
807 for i in 0..252u32 {
808 let v = 0.5 * (i as f64) + ((i % 7) as f64);
809 data.push(v);
810 }
811
812 let params = WmaParams { period: Some(30) };
813 let input = WmaInput::from_slice(&data, params);
814
815 let baseline = wma(&input)?.values;
816
817 let mut out = vec![0.0; data.len()];
818 wma_into(&input, &mut out)?;
819
820 assert_eq!(baseline.len(), out.len());
821 for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
822 let equal = (a.is_nan() && b.is_nan()) || (a == b);
823 assert!(equal, "Mismatch at index {}: api={} into={}", i, a, b);
824 }
825
826 Ok(())
827 }
828
829 fn check_wma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
830 skip_if_unsupported!(kernel, test_name);
831 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
832 let candles = read_candles_from_csv(file_path)?;
833
834 let default_params = WmaParams { period: None };
835 let input = WmaInput::from_candles(&candles, "close", default_params);
836 let output = wma_with_kernel(&input, kernel)?;
837 assert_eq!(output.values.len(), candles.close.len());
838
839 let params_period_14 = WmaParams { period: Some(14) };
840 let input2 = WmaInput::from_candles(&candles, "hl2", params_period_14);
841 let output2 = wma_with_kernel(&input2, kernel)?;
842 assert_eq!(output2.values.len(), candles.close.len());
843
844 let params_custom = WmaParams { period: Some(20) };
845 let input3 = WmaInput::from_candles(&candles, "hlc3", params_custom);
846 let output3 = wma_with_kernel(&input3, kernel)?;
847 assert_eq!(output3.values.len(), candles.close.len());
848 Ok(())
849 }
850
851 fn check_wma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
852 skip_if_unsupported!(kernel, test_name);
853 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
854 let candles = read_candles_from_csv(file_path)?;
855 let data = &candles.close;
856 let default_params = WmaParams::default();
857 let input = WmaInput::from_candles(&candles, "close", default_params);
858 let result = wma_with_kernel(&input, kernel)?;
859
860 let expected_last_five = [
861 59638.52903225806,
862 59563.7376344086,
863 59489.4064516129,
864 59432.02580645162,
865 59350.58279569892,
866 ];
867 assert!(result.values.len() >= 5, "Not enough WMA values");
868 assert_eq!(
869 result.values.len(),
870 data.len(),
871 "WMA output length should match input length"
872 );
873 let start_index = result.values.len().saturating_sub(5);
874 let last_five = &result.values[start_index..];
875 for (i, &value) in last_five.iter().enumerate() {
876 assert!(
877 (value - expected_last_five[i]).abs() < 1e-6,
878 "WMA value mismatch at index {}: expected {}, got {}",
879 i,
880 expected_last_five[i],
881 value
882 );
883 }
884 let period = input.params.period.unwrap_or(30);
885 for val in result.values.iter().skip(period - 1) {
886 if !val.is_nan() {
887 assert!(val.is_finite(), "WMA output should be finite");
888 }
889 }
890 Ok(())
891 }
892
893 fn check_wma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
894 skip_if_unsupported!(kernel, test_name);
895 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
896 let candles = read_candles_from_csv(file_path)?;
897 let input = WmaInput::with_default_candles(&candles);
898 match input.data {
899 WmaData::Candles { source, .. } => assert_eq!(source, "close"),
900 _ => panic!("Expected WmaData::Candles"),
901 }
902 let output = wma_with_kernel(&input, kernel)?;
903 assert_eq!(output.values.len(), candles.close.len());
904 Ok(())
905 }
906
907 fn check_wma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
908 skip_if_unsupported!(kernel, test_name);
909 let input_data = [10.0, 20.0, 30.0];
910 let params = WmaParams { period: Some(0) };
911 let input = WmaInput::from_slice(&input_data, params);
912 let res = wma_with_kernel(&input, kernel);
913 assert!(
914 res.is_err(),
915 "[{}] WMA should fail with zero period",
916 test_name
917 );
918 Ok(())
919 }
920
921 fn check_wma_period_exceeds_length(
922 test_name: &str,
923 kernel: Kernel,
924 ) -> Result<(), Box<dyn Error>> {
925 skip_if_unsupported!(kernel, test_name);
926 let data_small = [10.0, 20.0, 30.0];
927 let params = WmaParams { period: Some(10) };
928 let input = WmaInput::from_slice(&data_small, params);
929 let res = wma_with_kernel(&input, kernel);
930 assert!(
931 res.is_err(),
932 "[{}] WMA should fail with period exceeding length",
933 test_name
934 );
935 Ok(())
936 }
937
938 fn check_wma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
939 skip_if_unsupported!(kernel, test_name);
940 let single_point = [42.0];
941 let params = WmaParams { period: Some(9) };
942 let input = WmaInput::from_slice(&single_point, params);
943 let res = wma_with_kernel(&input, kernel);
944 assert!(
945 res.is_err(),
946 "[{}] WMA should fail with insufficient data",
947 test_name
948 );
949 Ok(())
950 }
951
952 fn check_wma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
953 skip_if_unsupported!(kernel, test_name);
954 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
955 let candles = read_candles_from_csv(file_path)?;
956 let first_params = WmaParams { period: Some(14) };
957 let first_input = WmaInput::from_candles(&candles, "close", first_params);
958 let first_result = wma_with_kernel(&first_input, kernel)?;
959 let second_params = WmaParams { period: Some(5) };
960 let second_input = WmaInput::from_slice(&first_result.values, second_params);
961 let second_result = wma_with_kernel(&second_input, kernel)?;
962 assert_eq!(second_result.values.len(), first_result.values.len());
963 for val in &second_result.values[50..] {
964 assert!(!val.is_nan());
965 }
966 Ok(())
967 }
968
969 fn check_wma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
970 skip_if_unsupported!(kernel, test_name);
971 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
972 let candles = read_candles_from_csv(file_path)?;
973 let params = WmaParams { period: Some(14) };
974 let input = WmaInput::from_candles(&candles, "close", params);
975 let result = wma_with_kernel(&input, kernel)?;
976 assert_eq!(result.values.len(), candles.close.len());
977 if result.values.len() > 50 {
978 for i in 50..result.values.len() {
979 assert!(!result.values[i].is_nan());
980 }
981 }
982 Ok(())
983 }
984
985 fn check_wma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
986 skip_if_unsupported!(kernel, test_name);
987 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
988 let candles = read_candles_from_csv(file_path)?;
989 let period = 30;
990 let input = WmaInput::from_candles(
991 &candles,
992 "close",
993 WmaParams {
994 period: Some(period),
995 },
996 );
997 let batch_output = wma_with_kernel(&input, kernel)?.values;
998
999 let mut stream = WmaStream::try_new(WmaParams {
1000 period: Some(period),
1001 })?;
1002 let mut stream_values = Vec::with_capacity(candles.close.len());
1003 for &price in &candles.close {
1004 match stream.update(price) {
1005 Some(val) => stream_values.push(val),
1006 None => stream_values.push(f64::NAN),
1007 }
1008 }
1009 assert_eq!(batch_output.len(), stream_values.len());
1010 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1011 if b.is_nan() && s.is_nan() {
1012 continue;
1013 }
1014 let diff = (b - s).abs();
1015 assert!(
1016 diff < 1e-8,
1017 "[{}] WMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1018 test_name,
1019 i,
1020 b,
1021 s,
1022 diff
1023 );
1024 }
1025 Ok(())
1026 }
1027
1028 #[cfg(debug_assertions)]
1029 fn check_wma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1030 skip_if_unsupported!(kernel, test_name);
1031
1032 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1033 let candles = read_candles_from_csv(file_path)?;
1034
1035 let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1036
1037 for &period in &test_periods {
1038 if period > candles.close.len() {
1039 continue;
1040 }
1041
1042 let input = WmaInput::from_candles(
1043 &candles,
1044 "close",
1045 WmaParams {
1046 period: Some(period),
1047 },
1048 );
1049 let output = wma_with_kernel(&input, kernel)?;
1050
1051 for (i, &val) in output.values.iter().enumerate() {
1052 if val.is_nan() {
1053 continue;
1054 }
1055
1056 let bits = val.to_bits();
1057
1058 if bits == 0x11111111_11111111 {
1059 panic!(
1060 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1061 test_name, val, bits, i, period
1062 );
1063 }
1064
1065 if bits == 0x22222222_22222222 {
1066 panic!(
1067 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1068 test_name, val, bits, i, period
1069 );
1070 }
1071
1072 if bits == 0x33333333_33333333 {
1073 panic!(
1074 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1075 test_name, val, bits, i, period
1076 );
1077 }
1078 }
1079 }
1080
1081 Ok(())
1082 }
1083
1084 #[cfg(not(debug_assertions))]
1085 fn check_wma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1086 Ok(())
1087 }
1088
1089 macro_rules! generate_all_wma_tests {
1090 ($($test_fn:ident),*) => {
1091 paste! {
1092 $(
1093 #[test]
1094 fn [<$test_fn _scalar_f64>]() {
1095 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1096 }
1097 )*
1098 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1099 $(
1100 #[test]
1101 fn [<$test_fn _avx2_f64>]() {
1102 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1103 }
1104 #[test]
1105 fn [<$test_fn _avx512_f64>]() {
1106 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1107 }
1108 )*
1109 }
1110 }
1111 }
1112
    /// Property-based test of the WMA implementation for a given `kernel`.
    ///
    /// Random finite series and periods are generated with proptest and a set of
    /// invariants is asserted on the output of `wma_with_kernel`: NaN warm-up
    /// prefix, finite values afterwards, output bounded by the window min/max,
    /// agreement with the scalar kernel, the closed-form weighting for small
    /// periods, a step-response check, the `data.len() == period` edge case,
    /// monotonicity, and (in debug builds) absence of poison bit patterns.
    ///
    /// Any violated property panics through the final `.unwrap()` on the runner
    /// result; otherwise the function returns `Ok(())`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_wma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        // NOTE(review): presumably returns early when `kernel` is not usable on
        // this machine; the macro is defined outside this view.
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick a period in 2..=100, then a vector of `period..400`
        // finite values in [-1e6, 1e6) so the series is never shorter than the
        // period.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = WmaParams {
                    period: Some(period),
                };
                let input = WmaInput::from_slice(&data, params.clone());

                // Output of the kernel under test.
                let WmaOutput { values: out } = wma_with_kernel(&input, kernel).unwrap();

                // Scalar output used as the reference for cross-kernel comparison.
                let WmaOutput { values: ref_out } =
                    wma_with_kernel(&input, Kernel::Scalar).unwrap();

                // First non-NaN input index; the first full window ends at
                // `first + period - 1`.
                let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup_end = first + period - 1;

                // Property: everything before the first full window is NaN.
                for i in 0..warmup_end.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                // Property: everything from the first full window onward is finite.
                for i in warmup_end..out.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "Expected finite value after warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in warmup_end..data.len() {
                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];

                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];

                    // Property: the output lies within [min, max] of its window
                    // (with a small absolute tolerance).
                    prop_assert!(
                        y >= lo - 1e-9 && y <= hi + 1e-9,
                        "WMA at index {} = {} is outside window bounds [{}, {}]",
                        i,
                        y,
                        lo,
                        hi
                    );

                    // Property: a (near-)constant window yields that constant.
                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
                        prop_assert!(
                            (y - window[0]).abs() <= 1e-9,
                            "Constant input should produce constant output: {} vs {}",
                            y,
                            window[0]
                        );
                    }

                    // Property: the kernel under test agrees with the scalar
                    // reference, either within 1e-9 absolute or within 4 ULPs.
                    let r = ref_out[i];
                    if y.is_finite() && r.is_finite() {
                        let y_bits = y.to_bits();
                        let r_bits = r.to_bits();
                        let ulp_diff = y_bits.abs_diff(r_bits);

                        prop_assert!(
                            (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                            "Kernel mismatch at index {}: {} vs {} (ULP={})",
                            i,
                            y,
                            r,
                            ulp_diff
                        );
                    }
                }

                // Property: for period 2 the first valid value is (x0 + 2*x1) / 3.
                if period == 2 && out.len() >= 2 {
                    let idx = warmup_end;
                    if idx < out.len() {
                        let expected = (data[idx - 1] + 2.0 * data[idx]) / 3.0;
                        prop_assert!(
                            (out[idx] - expected).abs() <= 1e-9,
                            "Period=2 calculation mismatch: {} vs expected {}",
                            out[idx],
                            expected
                        );
                    }
                }

                // Property: for small periods the first valid value matches the
                // explicit linear weighting 1, 2, ..., period (oldest to newest).
                if period <= 5 && warmup_end < out.len() {
                    let idx = warmup_end;
                    let window_start = idx + 1 - period;

                    let mut weighted_sum = 0.0;
                    let mut weight_sum = 0.0;
                    for (j, &val) in data[window_start..=idx].iter().enumerate() {
                        let weight = (j + 1) as f64;
                        weighted_sum += weight * val;
                        weight_sum += weight;
                    }
                    let expected = weighted_sum / weight_sum;

                    prop_assert!(
                        (out[idx] - expected).abs() <= 1e-9,
                        "Weight formula verification failed at index {}: {} vs expected {}",
                        idx,
                        out[idx],
                        expected
                    );
                }

                // Step response: a synthetic series that jumps from 10 to 100 at
                // `mid`. The sample checked (`mid + period - 1`) is the first
                // whose window lies entirely at or after the step.
                if data.len() >= period * 2 {
                    let mid = data.len() / 2;
                    if mid > warmup_end {
                        let mut step_data = vec![10.0; data.len()];
                        for i in mid..step_data.len() {
                            step_data[i] = 100.0;
                        }

                        let step_input = WmaInput::from_slice(&step_data, params.clone());
                        let WmaOutput { values: step_out } =
                            wma_with_kernel(&step_input, kernel).unwrap();

                        if mid + period < step_out.len() {
                            let wma_after_step = step_out[mid + period - 1];
                            let distance_to_new = (wma_after_step - 100.0).abs();
                            let distance_to_old = (wma_after_step - 10.0).abs();
                            prop_assert!(
                                distance_to_new < distance_to_old,
                                "WMA should respond more to recent values: {} should be closer to 100 than 10",
                                wma_after_step
                            );
                        }
                    }
                }

                // Edge case: exactly one full window fits, so exactly the last
                // output is finite.
                if data.len() == period {
                    let valid_count = out.iter().filter(|x| x.is_finite()).count();
                    prop_assert!(
                        valid_count == 1,
                        "With data.len() == period, should have exactly 1 valid output, got {}",
                        valid_count
                    );

                    prop_assert!(
                        out[data.len() - 1].is_finite(),
                        "Last value should be valid when data.len() == period"
                    );
                }

                // Property: a non-decreasing input gives a non-decreasing output
                // (within tolerance).
                let is_monotonic_increasing = data.windows(2).all(|w| w[1] >= w[0] - 1e-12);
                if is_monotonic_increasing && out.len() > warmup_end + 1 {
                    for i in (warmup_end + 1)..out.len() {
                        prop_assert!(
                            out[i] >= out[i - 1] - 1e-9,
                            "Monotonic input should produce monotonic WMA: {} < {} at index {}",
                            out[i],
                            out[i - 1],
                            i
                        );
                    }
                }

                // Debug builds only: no non-NaN output may carry one of the three
                // poison bit patterns.
                // NOTE(review): presumably these are the debug fill values of the
                // allocation helpers imported at the top of the file; confirm.
                #[cfg(debug_assertions)]
                {
                    for (i, &val) in out.iter().enumerate() {
                        if !val.is_nan() {
                            let bits = val.to_bits();
                            prop_assert!(
                                bits != 0x11111111_11111111
                                    && bits != 0x22222222_22222222
                                    && bits != 0x33333333_33333333,
                                "Found poison value at index {}: {} (0x{:016X})",
                                i,
                                val,
                                bits
                            );
                        }
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1319
    // Instantiate the per-kernel `#[test]` wrappers for every single-series
    // check function in this module.
    generate_all_wma_tests!(
        check_wma_partial_params,
        check_wma_accuracy,
        check_wma_default_candles,
        check_wma_zero_period,
        check_wma_period_exceeds_length,
        check_wma_very_small_dataset,
        check_wma_reinput,
        check_wma_nan_handling,
        check_wma_streaming,
        check_wma_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled, so
    // its wrappers are generated under the same gate.
    #[cfg(feature = "proptest")]
    generate_all_wma_tests!(check_wma_property);
1335
1336 fn check_invalid_kernel_error(test: &str) -> Result<(), Box<dyn Error>> {
1337 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
1338 let sweep = WmaBatchRange { period: (2, 5, 1) };
1339
1340 let non_batch_kernels = vec![Kernel::Scalar, Kernel::Avx2, Kernel::Avx512];
1341 for kernel in non_batch_kernels {
1342 let result = wma_with_kernel_batch(&data, &sweep, kernel);
1343 assert!(
1344 matches!(result, Err(WmaError::InvalidKernelForBatch(_))),
1345 "[{}] Expected InvalidKernelForBatch error for {:?}, got {:?}",
1346 test,
1347 kernel,
1348 result
1349 );
1350 }
1351
1352 let batch_kernels = vec![Kernel::Auto, Kernel::ScalarBatch];
1353 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1354 let batch_kernels = vec![
1355 Kernel::Auto,
1356 Kernel::ScalarBatch,
1357 Kernel::Avx2Batch,
1358 Kernel::Avx512Batch,
1359 ];
1360
1361 for kernel in batch_kernels {
1362 let result = wma_with_kernel_batch(&data, &sweep, kernel);
1363 assert!(
1364 result.is_ok(),
1365 "[{}] Expected success for batch kernel {:?}, got error: {:?}",
1366 test,
1367 kernel,
1368 result.err()
1369 );
1370 }
1371
1372 Ok(())
1373 }
1374
1375 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1376 skip_if_unsupported!(kernel, test);
1377 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1378 let c = read_candles_from_csv(file)?;
1379
1380 let output = WmaBatchBuilder::new()
1381 .kernel(kernel)
1382 .apply_candles(&c, "close")?;
1383
1384 let def = WmaParams::default();
1385 let row = output.values_for(&def).expect("default row missing");
1386
1387 assert_eq!(row.len(), c.close.len());
1388
1389 let expected = [
1390 59638.52903225806,
1391 59563.7376344086,
1392 59489.4064516129,
1393 59432.02580645162,
1394 59350.58279569892,
1395 ];
1396 let start = row.len() - 5;
1397 for (i, &v) in row[start..].iter().enumerate() {
1398 assert!(
1399 (v - expected[i]).abs() < 1e-6,
1400 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1401 );
1402 }
1403 Ok(())
1404 }
1405
1406 #[cfg(debug_assertions)]
1407 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1408 skip_if_unsupported!(kernel, test);
1409
1410 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1411 let c = read_candles_from_csv(file)?;
1412
1413 let batch_configs = vec![
1414 (2, 10, 1),
1415 (5, 25, 5),
1416 (10, 30, 10),
1417 (20, 100, 10),
1418 (30, 150, 30),
1419 (50, 200, 50),
1420 (2, 5, 1),
1421 ];
1422
1423 for (start, end, step) in batch_configs {
1424 if start > c.close.len() {
1425 continue;
1426 }
1427
1428 let output = WmaBatchBuilder::new()
1429 .kernel(kernel)
1430 .period_range(start, end, step)
1431 .apply_candles(&c, "close")?;
1432
1433 for (idx, &val) in output.values.iter().enumerate() {
1434 if val.is_nan() {
1435 continue;
1436 }
1437
1438 let bits = val.to_bits();
1439 let row = idx / output.cols;
1440 let col = idx % output.cols;
1441 let period = output.combos[row].period.unwrap_or(0);
1442
1443 if bits == 0x11111111_11111111 {
1444 panic!(
1445 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1446 test, val, bits, row, col, idx, period, start, end, step
1447 );
1448 }
1449
1450 if bits == 0x22222222_22222222 {
1451 panic!(
1452 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1453 test, val, bits, row, col, idx, period, start, end, step
1454 );
1455 }
1456
1457 if bits == 0x33333333_33333333 {
1458 panic!(
1459 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1460 test, val, bits, row, col, idx, period, start, end, step
1461 );
1462 }
1463 }
1464 }
1465
1466 Ok(())
1467 }
1468
    /// Release-build stand-in for the batch poison check.
    ///
    /// Compiled only when `debug_assertions` is off; ignores its arguments and
    /// always returns `Ok(())` so `gen_batch_tests!` can reference the function
    /// in every build profile.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1473
1474 macro_rules! gen_batch_tests {
1475 ($fn_name:ident) => {
1476 paste! {
1477 #[test] fn [<$fn_name _scalar>]() {
1478 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1479 }
1480 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1481 #[test] fn [<$fn_name _avx2>]() {
1482 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1483 }
1484 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1485 #[test] fn [<$fn_name _avx512>]() {
1486 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1487 }
1488 #[test] fn [<$fn_name _auto_detect>]() {
1489 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1490 }
1491 }
1492 };
1493 }
1494 gen_batch_tests!(check_batch_default_row);
1495 gen_batch_tests!(check_batch_no_poison);
1496
1497 #[test]
1498 fn test_invalid_kernel_error() {
1499 let _ = check_invalid_kernel_error("test_invalid_kernel_error");
1500 }
1501}
1502
/// Python binding `wma(data, period, kernel=None)`.
///
/// Borrows `data` as a slice, validates the optional kernel name as a
/// single-series kernel, runs `wma_with_kernel` with the GIL released and
/// returns the result as a new 1-D NumPy array.
///
/// # Errors
/// Propagates the error from `as_slice` or `validate_kernel`; a WMA
/// computation error is converted to `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "wma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn wma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let wma_in = WmaInput::from_slice(
        series,
        WmaParams {
            period: Some(period),
        },
    );

    // The computation itself runs without holding the GIL.
    match py.allow_threads(|| wma_with_kernel(&wma_in, kern)) {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1530
/// Python class `WmaStream`: a thin wrapper that owns a `WmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "WmaStream")]
pub struct WmaStreamPy {
    stream: WmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl WmaStreamPy {
    /// Builds the stream for `period`; a construction error from
    /// `WmaStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        WmaStream::try_new(WmaParams {
            period: Some(period),
        })
        .map(|stream| WmaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Forwards `value` to the wrapped stream and returns whatever it yields.
    // NOTE(review): presumably `None` until enough samples have been seen;
    // confirm against `WmaStream::update`.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1554
/// Python binding `wma_batch(data, period_range, kernel=None)`.
///
/// Evaluates the WMA for every period produced by expanding
/// `period_range = (start, end, step)` and returns a dict with:
/// * `"values"`  - 2-D array of shape `(rows, cols)`,
/// * `"periods"` - 1-D `u64` array with the period of each row,
/// * `"rows"` / `"cols"` - the matrix dimensions.
///
/// # Errors
/// Propagates errors from `as_slice`, `validate_kernel` and the NumPy calls;
/// a `rows * cols` overflow or a batch computation error becomes `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "wma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn wma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let series = data.as_slice()?;
    let sweep = WmaBatchRange {
        period: period_range,
    };
    let kern = validate_kernel(kernel, true)?;

    // One output row per parameter combination, one column per input sample.
    let rows = expand_grid(&sweep).len();
    let cols = series.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The flat output buffer is allocated by NumPy and handed to the batch
    // routine as a mutable slice so results are written in place.
    // NOTE(review): relies on `wma_batch_inner_into` writing every element
    // before the array is exposed to Python; confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Resolve `Auto` to a concrete batch kernel first.
            let batch_kernel = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            // Translate the batch kernel into its per-row kernel; anything
            // that is not an enabled AVX batch kernel falls back to scalar.
            let simd = match batch_kernel {
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx512Batch => Kernel::Avx512,
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                Kernel::Avx2Batch => Kernel::Avx2,
                _ => Kernel::Scalar,
            };

            wma_batch_inner_into(series, &sweep, simd, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|p| p.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;

    Ok(dict)
}
1620
/// Python binding `wma_cuda_batch_dev(data_f32, period_range, device_id=0)`.
///
/// Runs the period sweep on the CUDA device `device_id` and returns the
/// result wrapped in a `DeviceArrayF32Py` together with the CUDA context and
/// device id it was created with.
///
/// # Errors
/// Returns `ValueError("CUDA not available")` when `cuda_available()` is
/// false; errors from `CudaWma::new` / `wma_batch_dev` also become
/// `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn wma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    // Shared conversion of any stringifiable error into a Python `ValueError`.
    fn to_py_err<E: ToString>(err: E) -> pyo3::PyErr {
        PyValueError::new_err(err.to_string())
    }

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let series = data_f32.as_slice()?;
    let sweep = WmaBatchRange {
        period: period_range,
    };

    // Device setup and the computation happen with the GIL released.
    let launched: PyResult<_> = py.allow_threads(|| {
        let cuda = CudaWma::new(device_id).map_err(to_py_err)?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda.wma_batch_dev(series, &sweep).map_err(to_py_err)?;
        Ok((arr, ctx, dev_id))
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(DeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1651
/// Python binding
/// `wma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
///
/// Applies a single WMA period to a 2-D `f32` array on the CUDA device
/// `device_id` and returns the result as a `DeviceArrayF32Py`.
///
/// The array is read as one flat slice; `shape()[1]` and `shape()[0]` are
/// passed to the device routine as its second and third arguments.
// NOTE(review): the `_tm` suffix and the callee name suggest a time-major
// layout (rows = time steps, cols = series); confirm against `CudaWma`.
///
/// # Errors
/// Returns `ValueError("CUDA not available")` when `cuda_available()` is
/// false; `as_slice` errors are propagated and errors from the CUDA wrapper
/// become `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "wma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn wma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = WmaParams {
        period: Some(period),
    };

    // Device setup and the computation happen with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaWma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        let arr = cuda
            .wma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, pyo3::PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32Py::new_from_rust(inner, ctx, dev_id))
}
1686
/// Configuration object deserialized from the JS value passed to `wma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WmaBatchConfig {
    /// Period sweep; copied verbatim into `WmaBatchRange::period`.
    // NOTE(review): presumably `(start, end, step)`, matching the argument
    // order of the other batch entry points; confirm.
    pub period_range: (usize, usize, usize),
}
1692
/// Result object serialized back to JS by `wma_batch`; each field is copied
/// from the batch output of the same name.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct WmaBatchJsOutput {
    /// Flattened output values.
    // NOTE(review): presumably row-major with `rows * cols` elements; confirm.
    pub values: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<WmaParams>,
    /// Number of rows in the batch output.
    pub rows: usize,
    /// Number of columns in the batch output.
    pub cols: usize,
}
1701
/// JS export `wma_batch(data, config)`.
///
/// Deserializes `config` into a `WmaBatchConfig`, runs the batch computation
/// with the kernel chosen by `detect_best_kernel()` and returns the values,
/// parameter combinations and dimensions as a serialized `WmaBatchJsOutput`.
///
/// # Errors
/// Returns a string `JsValue` for an invalid config, a computation error or a
/// serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = wma_batch)]
pub fn wma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = match serde_wasm_bindgen::from_value::<WmaBatchConfig>(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let sweep = WmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match wma_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = WmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {e}"))),
    }
}
1723
/// JS export `wma_js(data, period)`: computes the WMA of `data` with
/// `Kernel::Auto` into a freshly allocated vector of the same length.
///
/// # Errors
/// A computation error is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = WmaInput::from_slice(
        data,
        WmaParams {
            period: Some(period),
        },
    );

    // Zero-filled destination that `wma_into_slice` writes into.
    let mut output = vec![0.0; data.len()];
    match wma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1738
/// JS export `wma_batch_js(data, period_start, period_end, period_step)`.
///
/// Runs the batch computation for the given period sweep with `Kernel::Auto`
/// and returns only the flattened `values` of the batch output.
///
/// # Errors
/// A computation error is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match wma_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1755
/// JS export `wma_batch_metadata_js(period_start, period_end, period_step)`.
///
/// Expands the period sweep and returns the period of every combination, in
/// expansion order, converted to `f64`.
///
/// # Panics
/// Panics if an expanded combination has no period set.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let periods: Vec<f64> = expand_grid(&sweep)
        .into_iter()
        .map(|combo| combo.period.unwrap() as f64)
        .collect();

    Ok(periods)
}
1776
/// JS export `wma_alloc(len)`: reserves room for `len` `f64` values and
/// returns a raw pointer to the buffer.
///
/// The backing vector is deliberately never dropped here, so the allocation
/// stays alive until it is handed back to `wma_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1785
/// JS export `wma_free(ptr, len)`: releases a buffer by rebuilding a
/// `Vec<f64>` of length and capacity `len` from `ptr` and dropping it.
/// A null `ptr` is ignored.
// NOTE(review): presumably `ptr`/`len` must come from a matching `wma_alloc`
// call; nothing here verifies that.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer/length pair describing a live
    // allocation with exactly this capacity; ownership is reclaimed here.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1795
/// JS export `wma_into(in_ptr, out_ptr, len, period)`: computes the WMA of the
/// `len` values at `in_ptr` with `Kernel::Auto` and writes `len` results to
/// `out_ptr`.
///
/// When both pointers are equal the result is first computed into a
/// temporary buffer and then copied over the input.
///
/// # Errors
/// Returns `"Null pointer provided"` if either pointer is null; a computation
/// error is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller must guarantee both pointers are valid for `len`
    // `f64` values; only the exact-alias case is handled specially below.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = WmaInput::from_slice(
            data,
            WmaParams {
                period: Some(period),
            },
        );

        let in_place = in_ptr == out_ptr as *const f64;
        if in_place {
            // Compute into scratch storage, then overwrite the input buffer.
            let mut scratch = vec![0.0; len];
            wma_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dest = std::slice::from_raw_parts_mut(out_ptr, len);
            wma_into_slice(dest, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        Ok(())
    }
}
1829
/// JS export `wma_batch_into(in_ptr, out_ptr, len, period_start, period_end,
/// period_step)`.
///
/// Runs the period sweep over the `len` values at `in_ptr` with
/// `Kernel::Auto` and writes the flattened result (`rows * len` values) to
/// `out_ptr`. Returns the number of rows, i.e. the number of expanded
/// parameter combinations.
///
/// # Errors
/// * `"Null pointer provided"` if either pointer is null.
/// * `"rows*cols overflow"` if `rows * len` does not fit in `usize`.
/// * The computation error's message otherwise.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn wma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = WmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    // `usize` is 32 bits on wasm32, so an unchecked product could wrap and
    // yield an output slice that is too small; reject it instead.
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the caller must guarantee `in_ptr` is valid for `len` values and
    // `out_ptr` is valid for `rows * len` values, and that the two regions do
    // not overlap.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        wma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}