1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::error::Error;
14use std::f64::consts::PI;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::cuda_available;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::cuda::moving_averages::CudaSinwma;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26#[cfg(feature = "python")]
27use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
28#[cfg(all(feature = "python", feature = "cuda"))]
29use numpy::{PyReadonlyArray2, PyUntypedArrayMethods};
30#[cfg(feature = "python")]
31use pyo3::exceptions::PyValueError;
32#[cfg(feature = "python")]
33use pyo3::prelude::*;
34#[cfg(feature = "python")]
35use pyo3::types::{PyDict, PyList};
36
37#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
38use serde::{Deserialize, Serialize};
39#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
40use wasm_bindgen::prelude::*;
41
42impl<'a> AsRef<[f64]> for SinWmaInput<'a> {
43 #[inline(always)]
44 fn as_ref(&self) -> &[f64] {
45 match &self.data {
46 SinWmaData::Slice(slice) => slice,
47 SinWmaData::Candles { candles, source } => source_type(candles, source),
48 }
49 }
50}
51
/// Where the input series of a SINWMA calculation comes from.
#[derive(Debug, Clone)]
pub enum SinWmaData<'a> {
    /// A candle set plus the source name that is handed to `source_type` to pick the series.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
60
/// Result of a single-series SINWMA calculation.
#[derive(Debug, Clone)]
pub struct SinWmaOutput {
    /// Output series as built by `sinwma_with_kernel`; it is allocated with the input's length
    /// (presumably NaN-filled for the warm-up prefix by `alloc_with_nan_prefix` — confirm there).
    pub values: Vec<f64>,
}
65
/// Tunable parameters of the sine-weighted moving average.
#[derive(Debug, Clone)]
pub struct SinWmaParams {
    /// Window length in samples; `None` is resolved to 14 by the consumers in this file.
    pub period: Option<usize>,
}

impl Default for SinWmaParams {
    /// Default parameters: an explicit 14-sample window.
    fn default() -> Self {
        let period = Some(14);
        Self { period }
    }
}
76
/// Everything a SINWMA calculation needs: the data source and the parameters.
#[derive(Debug, Clone)]
pub struct SinWmaInput<'a> {
    /// Series to average (candles + source name, or a raw slice).
    pub data: SinWmaData<'a>,
    /// Calculation parameters (window length).
    pub params: SinWmaParams,
}
82
83impl<'a> SinWmaInput<'a> {
84 #[inline]
85 pub fn from_candles(c: &'a Candles, s: &'a str, p: SinWmaParams) -> Self {
86 Self {
87 data: SinWmaData::Candles {
88 candles: c,
89 source: s,
90 },
91 params: p,
92 }
93 }
94 #[inline]
95 pub fn from_slice(sl: &'a [f64], p: SinWmaParams) -> Self {
96 Self {
97 data: SinWmaData::Slice(sl),
98 params: p,
99 }
100 }
101 #[inline]
102 pub fn with_default_candles(c: &'a Candles) -> Self {
103 Self::from_candles(c, "close", SinWmaParams::default())
104 }
105 #[inline]
106 pub fn get_period(&self) -> usize {
107 self.params.period.unwrap_or(14)
108 }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct SinWmaBuilder {
113 period: Option<usize>,
114 kernel: Kernel,
115}
116
117impl Default for SinWmaBuilder {
118 fn default() -> Self {
119 Self {
120 period: None,
121 kernel: Kernel::Auto,
122 }
123 }
124}
125
126impl SinWmaBuilder {
127 #[inline(always)]
128 pub fn new() -> Self {
129 Self::default()
130 }
131 #[inline(always)]
132 pub fn period(mut self, n: usize) -> Self {
133 self.period = Some(n);
134 self
135 }
136 #[inline(always)]
137 pub fn kernel(mut self, k: Kernel) -> Self {
138 self.kernel = k;
139 self
140 }
141
142 #[inline(always)]
143 pub fn apply(self, c: &Candles) -> Result<SinWmaOutput, SinWmaError> {
144 let p = SinWmaParams {
145 period: self.period,
146 };
147 let i = SinWmaInput::from_candles(c, "close", p);
148 sinwma_with_kernel(&i, self.kernel)
149 }
150
151 #[inline(always)]
152 pub fn apply_slice(self, d: &[f64]) -> Result<SinWmaOutput, SinWmaError> {
153 let p = SinWmaParams {
154 period: self.period,
155 };
156 let i = SinWmaInput::from_slice(d, p);
157 sinwma_with_kernel(&i, self.kernel)
158 }
159
160 #[inline(always)]
161 pub fn into_stream(self) -> Result<SinWmaStream, SinWmaError> {
162 let p = SinWmaParams {
163 period: self.period,
164 };
165 SinWmaStream::try_new(p)
166 }
167}
168
/// Errors reported by the SINWMA entry points in this file.
#[derive(Debug, Error)]
pub enum SinWmaError {
    /// The input series has length zero.
    #[error("sinwma: No data provided (empty slice).")]
    EmptyInputData,
    /// The input series contains no non-NaN value.
    #[error("sinwma: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or, on the single-series path, longer than the input.
    #[error("sinwma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer samples remain from the first non-NaN value onward than the period requires.
    #[error("sinwma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The sum of the raw sine weights is too close to zero to normalise by.
    #[error("sinwma: Sum of sines is zero or too close to zero. sum_sines = {sum_sines}")]
    ZeroSumSines { sum_sines: f64 },
    /// A caller-supplied output buffer does not have the required length.
    #[error("sinwma: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period range could not be expanded, or its matrix size overflowed `usize`.
    #[error("sinwma: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was requested for a batch computation.
    #[error("sinwma: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
192
193#[inline]
194pub fn sinwma(input: &SinWmaInput) -> Result<SinWmaOutput, SinWmaError> {
195 sinwma_with_kernel(input, Kernel::Auto)
196}
197
198pub fn sinwma_with_kernel(
199 input: &SinWmaInput,
200 kernel: Kernel,
201) -> Result<SinWmaOutput, SinWmaError> {
202 let (data, weights, period, first, chosen) = sinwma_prepare(input, kernel)?;
203
204 let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
205
206 sinwma_compute_into(data, &weights, period, first, chosen, &mut out);
207
208 Ok(SinWmaOutput { values: out })
209}
210
211#[inline]
212pub fn sinwma_into_slice(
213 dst: &mut [f64],
214 input: &SinWmaInput,
215 kern: Kernel,
216) -> Result<(), SinWmaError> {
217 let (data, weights, period, first, chosen) = sinwma_prepare(input, kern)?;
218
219 if dst.len() != data.len() {
220 return Err(SinWmaError::OutputLengthMismatch {
221 expected: data.len(),
222 got: dst.len(),
223 });
224 }
225
226 sinwma_compute_into(data, &weights, period, first, chosen, dst);
227
228 let warmup_end = first + period - 1;
229 for v in &mut dst[..warmup_end] {
230 *v = f64::NAN;
231 }
232
233 Ok(())
234}
235
236#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
237#[inline]
238pub fn sinwma_into(input: &SinWmaInput, out: &mut [f64]) -> Result<(), SinWmaError> {
239 sinwma_into_slice(out, input, Kernel::Auto)
240}
241
242#[inline(always)]
243fn sinwma_prepare<'a>(
244 input: &'a SinWmaInput,
245 kernel: Kernel,
246) -> Result<(&'a [f64], AVec<f64>, usize, usize, Kernel), SinWmaError> {
247 let data: &[f64] = input.as_ref();
248 let len = data.len();
249
250 if len == 0 {
251 return Err(SinWmaError::EmptyInputData);
252 }
253
254 let first = data
255 .iter()
256 .position(|x| !x.is_nan())
257 .ok_or(SinWmaError::AllValuesNaN)?;
258
259 let period = input.get_period();
260
261 if period == 0 || period > len {
262 return Err(SinWmaError::InvalidPeriod {
263 period,
264 data_len: len,
265 });
266 }
267 if (len - first) < period {
268 return Err(SinWmaError::NotEnoughValidData {
269 needed: period,
270 valid: len - first,
271 });
272 }
273
274 let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, period);
275 weights.resize(period, 0.0);
276 let mut sum_sines = 0.0;
277 for k in 0..period {
278 let angle = (k as f64 + 1.0) * PI / (period as f64 + 1.0);
279 let val = angle.sin();
280 weights[k] = val;
281 sum_sines += val;
282 }
283
284 if sum_sines.abs() < f64::EPSILON {
285 return Err(SinWmaError::ZeroSumSines { sum_sines });
286 }
287 let inv_sum = 1.0 / sum_sines;
288 for w in &mut weights[..] {
289 *w *= inv_sum;
290 }
291
292 let chosen = match kernel {
293 Kernel::Auto => detect_best_kernel(),
294 k => k,
295 };
296
297 Ok((data, weights, period, first, chosen))
298}
299
/// Runs the kernel selected by `kernel` over `data`, writing results into `out`.
///
/// Scalar kernels (single or batch flavour) go to `sinwma_scalar`. AVX2 / AVX-512 kernels go to
/// their SIMD implementations when built with the `nightly-avx` feature on x86_64, and fall back
/// to `sinwma_scalar` in every other build.
///
/// # Panics
/// Hits `unreachable!()` for any kernel not listed in the match, e.g. an unresolved
/// `Kernel::Auto`.
#[inline(always)]
fn sinwma_compute_into(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first: usize,
    kernel: Kernel,
    out: &mut [f64],
) {
    // The unsafe block is needed for the `target_feature` kernels; the scalar path is safe code.
    unsafe {
        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => {
                sinwma_scalar(data, weights, period, first, out)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => sinwma_avx2(data, weights, period, first, out),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                sinwma_avx512(data, weights, period, first, out)
            }
            // Builds without SIMD support route every SIMD kernel to the scalar code.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                sinwma_scalar(data, weights, period, first, out)
            }
            _ => unreachable!(),
        }
    }
}
328
/// Horizontal sum of the eight `f64` lanes of a 512-bit vector.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn hsum_pd_zmm(v: __m512d) -> f64 {
    _mm512_reduce_add_pd(v)
}
335
/// AVX-512 entry point: picks the short-window (`period <= 32`) or long-window kernel.
///
/// Computes `out[i] = sum(data[i + 1 - period + k] * weights[k])` for every `i` from
/// `first_valid + period - 1` to the end of `data`; earlier slots of `out` are left untouched.
///
/// Because this function is safe to call, it validates what the unchecked SIMD kernels rely on:
/// buffer lengths are asserted, nothing is done when no full window fits, and the scalar kernel
/// is used when the running CPU lacks the required instruction sets.
///
/// # Panics
/// Panics if `weights` holds fewer than `period` values or `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sinwma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    assert!(
        weights.len() >= period,
        "`weights` must hold at least `period` values"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    // Without at least one full window there is nothing to write, and the long kernel would
    // otherwise form out-of-bounds pointers.
    let has_window = first_valid
        .checked_add(period)
        .map_or(false, |end| end <= data.len());
    if !has_window {
        return;
    }

    let short = period <= 32;
    // The long kernel is additionally compiled with AVX-512DQ enabled.
    let supported = is_x86_feature_detected!("avx512f")
        && is_x86_feature_detected!("fma")
        && (short || is_x86_feature_detected!("avx512dq"));
    if !supported {
        sinwma_scalar(data, &weights[..period], period, first_valid, out);
        return;
    }

    // SAFETY: the required CPU features were detected above; `weights` covers `period` values,
    // `out` covers every index of `data`, and `first_valid + period <= data.len()`, so every
    // load and store performed by the kernels stays inside the slices.
    unsafe {
        if short {
            sinwma_avx512_short(data, weights, period, first_valid, out)
        } else {
            sinwma_avx512_long(data, weights, period, first_valid, out)
        }
    }
}
351
/// Portable SINWMA kernel.
///
/// For every index `i` from `first_val + period - 1` to the end of `data`, stores in `out[i]`
/// the dot product of `weights` with the `period` samples ending at `i`. Slots before that are
/// left untouched. Products are accumulated in groups of four, followed by the leftover terms.
///
/// # Panics
/// Panics if `weights.len() != period` or if `out` is shorter than `data`.
#[inline]
pub fn sinwma_scalar(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    assert_eq!(weights.len(), period, "weights.len() must equal `period`");
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    let first_out = first_val + period - 1;
    for end in first_out..data.len() {
        let begin = end + 1 - period;
        let window = &data[begin..begin + period];

        let mut data_quads = window.chunks_exact(4);
        let mut weight_quads = weights.chunks_exact(4);

        let mut acc = 0.0;
        for (d, w) in data_quads.by_ref().zip(weight_quads.by_ref()) {
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
        }

        // Fewer than four terms are left once the full groups are consumed.
        for (d, w) in data_quads.remainder().iter().zip(weight_quads.remainder()) {
            acc += d * w;
        }

        out[end] = acc;
    }
}
387
/// AVX2 + FMA SINWMA kernel.
///
/// For every index `i` from `first_valid + period - 1` to the end of `data`, stores in `out[i]`
/// the dot product of `weights` with the `period` samples ending at `i`, processing four lanes
/// per step and finishing the leftover terms with scalar fused multiply-adds.
///
/// # Safety
/// - The CPU must support AVX2 and FMA.
/// - `weights` must hold at least `period` values and `out` must cover every index of `data`;
///   these are only checked by `debug_assert!`, and all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn sinwma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    // Largest multiple of four not exceeding `period`.
    let p4 = period & !3usize;
    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        let mut k = 0usize;
        while k < p4 {
            let w = _mm256_loadu_pd(weights.as_ptr().add(k));
            let d = _mm256_loadu_pd(data.as_ptr().add(start + k));
            acc = _mm256_fmadd_pd(d, w, acc);
            k += 4;
        }

        // Reduce the four lanes: add the two 128-bit halves, then the two remaining lanes.
        let sum128 = _mm_add_pd(_mm256_castpd256_pd128(acc), _mm256_extractf128_pd(acc, 1));
        let mut sum = _mm_cvtsd_f64(_mm_hadd_pd(sum128, sum128));

        // Scalar tail for the last `period % 4` terms.
        while k < period {
            sum = (*data.get_unchecked(start + k)).mul_add(*weights.get_unchecked(k), sum);
            k += 1;
        }

        *out.get_unchecked_mut(i) = sum;
    }
}
426
/// AVX-512 SINWMA kernel used by `sinwma_avx512` for windows of at most 32 samples.
///
/// For every index `i` from `first_valid + period - 1` to the end of `data`, stores in `out[i]`
/// the dot product of `weights` with the `period` samples ending at `i`, eight lanes at a time
/// with a masked load for the final partial group.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `weights` must hold at least `period` values and `out` must cover every index of `data`;
///   these are only checked by `debug_assert!`, and all accesses are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn sinwma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    const STEP: usize = 8;
    let chunks = period / STEP;
    let tail_len = period % STEP;
    // Mask with the low `tail_len` bits set, selecting the leftover lanes.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    // Window shorter than one vector: a single masked load of the weights is reused for every
    // output position.
    if chunks == 0 {
        let wv = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr());
        for i in (first_valid + period - 1)..data.len() {
            let start = i + 1 - period;
            let dv = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start));
            let sum = hsum_pd_zmm(_mm512_mul_pd(dv, wv));
            *out.get_unchecked_mut(i) = sum;
        }
        return;
    }

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm512_setzero_pd();

        // Full eight-lane groups.
        for blk in 0..chunks {
            let w = _mm512_loadu_pd(weights.as_ptr().add(blk * STEP));
            let d = _mm512_loadu_pd(data.as_ptr().add(start + blk * STEP));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        // Leftover lanes; masked-out lanes are loaded as zero and contribute nothing.
        if tail_len != 0 {
            let wt = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr().add(chunks * STEP));
            let dt = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start + chunks * STEP));
            acc = _mm512_fmadd_pd(dt, wt, acc);
        }

        *out.get_unchecked_mut(i) = hsum_pd_zmm(acc);
    }
}
476
/// AVX-512 SINWMA kernel used by `sinwma_avx512` for windows longer than 32 samples.
///
/// The weights are loaded into vector registers once, then a window pointer slides over `data`
/// one sample at a time. Each window is reduced with four independent accumulators (groups of
/// four vectors per step), followed by the remaining full vectors and a masked tail.
/// Results are written starting at `out[first_valid + period - 1]`.
///
/// # Safety
/// - The CPU must support AVX-512F, AVX-512DQ and FMA.
/// - `weights` must hold at least `period` values and `out` must cover every index of `data`;
///   these are only checked by `debug_assert!`.
/// - `first_valid + period` must not exceed `data.len()`: the start pointers are formed with
///   `add` before the loop condition is evaluated.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let n_chunks = period / STEP;
    let tail_len = period % STEP;
    // Number of full vectors covered by the four-way unrolled loop (multiple of four).
    let unroll8 = n_chunks & !3;
    // Mask with the low `tail_len` bits set, selecting the leftover lanes.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    debug_assert_eq!(weights.len(), period);

    // Load every full weight vector once; they are reused for every output position.
    let mut wregs: Vec<__m512d> = Vec::with_capacity(n_chunks);
    for blk in 0..n_chunks {
        wregs.push(_mm512_loadu_pd(weights.as_ptr().add(blk * STEP)));
    }
    let w_tail = if tail_len != 0 {
        Some(_mm512_maskz_loadu_pd(
            tail_mask,
            weights.as_ptr().add(n_chunks * STEP),
        ))
    } else {
        None
    };

    // `data_ptr` marks the start of the current window, `dst_ptr` the slot for its result.
    let mut data_ptr = data.as_ptr().add(first_valid);
    let stop_ptr = data.as_ptr().add(data.len());
    let mut dst_ptr = out.as_mut_ptr().add(first_valid + period - 1);

    if tail_len == 0 {
        // Window length is a multiple of eight: no masked tail needed.
        while data_ptr.add(period) <= stop_ptr {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();

            for blk in (0..unroll8).step_by(4) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
            }

            // Full vectors left over after the four-way unrolled groups.
            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let tot = _mm512_add_pd(sum01, sum23);
            *dst_ptr = hsum_pd_zmm(tot);

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    } else {
        // Same reduction, plus one masked vector for the last `tail_len` samples.
        let wt = w_tail.unwrap();
        while data_ptr.add(period) <= stop_ptr {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();

            for blk in (0..unroll8).step_by(4) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
            }

            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            // Masked-out lanes are loaded as zero and contribute nothing.
            let dt = _mm512_maskz_loadu_pd(tail_mask, data_ptr.add(n_chunks * STEP));
            s0 = _mm512_fmadd_pd(dt, wt, s0);

            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let tot = _mm512_add_pd(sum01, sum23);
            *dst_ptr = hsum_pd_zmm(tot);

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    }
}
584
/// Streaming SINWMA calculator that consumes one sample per `update` call.
///
/// With `alpha = PI / (period + 1)`, the state `z` is a complex accumulator over the current
/// window that `update` advances by one rotation of angle `alpha` per sample.
#[derive(Debug, Clone)]
pub struct SinWmaStream {
    /// Window length in samples.
    period: usize,

    /// Real part of the one-step rotation (`cos(alpha)`).
    r_re: f64,
    /// Imaginary part of the one-step rotation (`sin(alpha)`).
    r_im: f64,
    /// Real part of the `period`-step rotation (`cos(period * alpha)`).
    rp_re: f64,
    /// Imaginary part of the `period`-step rotation (`sin(period * alpha)`).
    rp_im: f64,

    /// `sin(period * alpha)`, used to project `z` onto the output.
    sinp: f64,
    /// `cos(period * alpha)`, used to project `z` onto the output.
    cosp: f64,

    /// Reciprocal of the sum of the sine weights (output normalisation).
    inv_sum: f64,

    /// Real part of the accumulator `z`.
    z_re: f64,
    /// Imaginary part of the accumulator `z`.
    z_im: f64,

    /// Ring buffer holding the last `period` samples; starts out filled with NaN.
    buffer: Vec<f64>,
    /// Index of the slot the next sample is written to (the oldest sample once filled).
    head: usize,
    /// Set once `period` samples have been received.
    filled: bool,

    /// Number of NaN entries currently in `buffer`.
    nan_count: usize,
    /// Whether `z` reflects the current buffer contents.
    z_valid: bool,
}
609
610impl SinWmaStream {
611 pub fn try_new(params: SinWmaParams) -> Result<Self, SinWmaError> {
612 let period = params.period.unwrap_or(14);
613 if period == 0 {
614 return Err(SinWmaError::InvalidPeriod {
615 period,
616 data_len: 0,
617 });
618 }
619
620 let alpha = PI / (period as f64 + 1.0);
621
622 let (sina, cosa) = alpha.sin_cos();
623 let (sinp, cosp) = (alpha * period as f64).sin_cos();
624
625 let denom = (0.5 * alpha).sin();
626 if denom.abs() < f64::EPSILON {
627 return Err(SinWmaError::ZeroSumSines { sum_sines: 0.0 });
628 }
629 let sum_sines = (0.5 * alpha * period as f64).sin() / denom;
630 if sum_sines.abs() < f64::EPSILON {
631 return Err(SinWmaError::ZeroSumSines { sum_sines });
632 }
633 let inv_sum = 1.0 / sum_sines;
634
635 Ok(Self {
636 period,
637 r_re: cosa,
638 r_im: sina,
639 rp_re: cosp,
640 rp_im: sinp,
641 sinp,
642 cosp,
643 inv_sum,
644 z_re: 0.0,
645 z_im: 0.0,
646 buffer: vec![f64::NAN; period],
647 head: 0,
648 filled: false,
649 nan_count: period,
650 z_valid: false,
651 })
652 }
653
654 #[inline(always)]
655 pub fn update(&mut self, value: f64) -> Option<f64> {
656 if !self.filled {
657 let idx = self.head;
658 let old = self.buffer[idx];
659 if old.is_nan() {
660 self.nan_count = self.nan_count.saturating_sub(1);
661 }
662 if value.is_nan() {
663 self.nan_count += 1;
664 }
665 self.buffer[idx] = value;
666 self.head = (idx + 1) % self.period;
667
668 if self.head != 0 {
669 return None;
670 }
671
672 self.filled = true;
673
674 if self.nan_count != 0 {
675 self.z_valid = false;
676 return Some(f64::NAN);
677 }
678
679 self.rebuild_z();
680 return Some(self.output_from_z());
681 }
682
683 let idx_old = self.head;
684 let x_old = self.buffer[idx_old];
685
686 self.buffer[idx_old] = value;
687 self.head = (idx_old + 1) % self.period;
688
689 if x_old.is_nan() {
690 self.nan_count = self.nan_count.saturating_sub(1);
691 }
692 if value.is_nan() {
693 self.nan_count += 1;
694 }
695
696 if self.nan_count != 0 {
697 self.z_valid = false;
698 return Some(f64::NAN);
699 }
700
701 if !self.z_valid {
702 self.rebuild_z();
703 return Some(self.output_from_z());
704 }
705
706 let rZ_re = self.r_re.mul_add(self.z_re, -self.r_im * self.z_im);
707 let rZ_im = self.r_re.mul_add(self.z_im, self.r_im * self.z_re);
708
709 self.z_re = rZ_re + value - self.rp_re * x_old;
710 self.z_im = rZ_im - self.rp_im * x_old;
711
712 Some(self.output_from_z())
713 }
714
715 #[inline(always)]
716 fn output_from_z(&self) -> f64 {
717 let y_unscaled = self.sinp.mul_add(self.z_re, -self.cosp * self.z_im);
718 y_unscaled * self.inv_sum
719 }
720
721 #[inline(always)]
722 fn rebuild_z(&mut self) {
723 debug_assert!(
724 self.nan_count == 0,
725 "rebuild_z called on NaN-contaminated window"
726 );
727 let newest = (self.head + self.period - 1) % self.period;
728
729 let mut rj_re = 1.0f64;
730 let mut rj_im = 0.0f64;
731
732 let mut zr = 0.0f64;
733 let mut zi = 0.0f64;
734
735 for j in 0..self.period {
736 let idx = (newest + self.period - j) % self.period;
737 let x = self.buffer[idx];
738 zr = rj_re.mul_add(x, zr);
739 zi = rj_im.mul_add(x, zi);
740
741 let nr_re = rj_re.mul_add(self.r_re, -rj_im * self.r_im);
742 let nr_im = rj_re.mul_add(self.r_im, rj_im * self.r_re);
743 rj_re = nr_re;
744 rj_im = nr_im;
745 }
746
747 self.z_re = zr;
748 self.z_im = zi;
749 self.z_valid = true;
750 }
751}
752
/// Period sweep for batch calculations.
#[derive(Clone, Debug)]
pub struct SinWmaBatchRange {
    /// `(start, end, step)` of the periods to evaluate.
    pub period: (usize, usize, usize),
}

impl Default for SinWmaBatchRange {
    /// Default sweep: periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let period = (14, 263, 1);
        Self { period }
    }
}
765
/// Fluent builder for batch (multi-period) SINWMA calculations.
#[derive(Clone, Debug, Default)]
pub struct SinWmaBatchBuilder {
    /// Period sweep to evaluate.
    range: SinWmaBatchRange,
    /// Kernel requested for the batch computation.
    kernel: Kernel,
}
771
772impl SinWmaBatchBuilder {
773 pub fn new() -> Self {
774 Self::default()
775 }
776 pub fn kernel(mut self, k: Kernel) -> Self {
777 self.kernel = k;
778 self
779 }
780
781 #[inline]
782 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
783 self.range.period = (start, end, step);
784 self
785 }
786 #[inline]
787 pub fn period_static(mut self, p: usize) -> Self {
788 self.range.period = (p, p, 0);
789 self
790 }
791
792 pub fn apply_slice(self, data: &[f64]) -> Result<SinWmaBatchOutput, SinWmaError> {
793 sinwma_batch_with_kernel(data, &self.range, self.kernel)
794 }
795
796 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SinWmaBatchOutput, SinWmaError> {
797 SinWmaBatchBuilder::new().kernel(k).apply_slice(data)
798 }
799
800 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SinWmaBatchOutput, SinWmaError> {
801 let slice = source_type(c, src);
802 self.apply_slice(slice)
803 }
804
805 pub fn with_default_candles(c: &Candles) -> Result<SinWmaBatchOutput, SinWmaError> {
806 SinWmaBatchBuilder::new()
807 .kernel(Kernel::Auto)
808 .apply_candles(c, "close")
809 }
810}
811
812pub fn sinwma_batch_with_kernel(
813 data: &[f64],
814 sweep: &SinWmaBatchRange,
815 k: Kernel,
816) -> Result<SinWmaBatchOutput, SinWmaError> {
817 let kernel = match k {
818 Kernel::Auto => detect_best_batch_kernel(),
819 other if other.is_batch() => other,
820 _ => return Err(SinWmaError::InvalidKernelForBatch(k)),
821 };
822
823 let simd = match kernel {
824 Kernel::Avx512Batch => Kernel::Avx512,
825 Kernel::Avx2Batch => Kernel::Avx2,
826 Kernel::ScalarBatch => Kernel::Scalar,
827 _ => unreachable!(),
828 };
829 sinwma_batch_par_slice(data, sweep, simd)
830}
831
832#[derive(Clone, Debug)]
833pub struct SinWmaBatchOutput {
834 pub values: Vec<f64>,
835 pub combos: Vec<SinWmaParams>,
836 pub rows: usize,
837 pub cols: usize,
838}
839impl SinWmaBatchOutput {
840 pub fn row_for_params(&self, p: &SinWmaParams) -> Option<usize> {
841 self.combos
842 .iter()
843 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
844 }
845
846 pub fn values_for(&self, p: &SinWmaParams) -> Option<&[f64]> {
847 self.row_for_params(p).map(|row| {
848 let start = row * self.cols;
849 &self.values[start..start + self.cols]
850 })
851 }
852}
853
854#[inline(always)]
855fn expand_grid(r: &SinWmaBatchRange) -> Result<Vec<SinWmaParams>, SinWmaError> {
856 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SinWmaError> {
857 if step == 0 || start == end {
858 return Ok(vec![start]);
859 }
860 if start < end {
861 let v: Vec<usize> = (start..=end).step_by(step).collect();
862 if v.is_empty() {
863 return Err(SinWmaError::InvalidRange { start, end, step });
864 }
865 return Ok(v);
866 }
867
868 let mut v = Vec::new();
869 let mut cur = start;
870 while cur >= end {
871 v.push(cur);
872 if cur == end {
873 break;
874 }
875
876 let next = cur.saturating_sub(step);
877 if next == cur {
878 break;
879 }
880 cur = next;
881 }
882 if v.is_empty() {
883 return Err(SinWmaError::InvalidRange { start, end, step });
884 }
885 Ok(v)
886 }
887
888 let periods = axis_usize(r.period)?;
889
890 let mut out = Vec::with_capacity(periods.len());
891 for &p in &periods {
892 out.push(SinWmaParams { period: Some(p) });
893 }
894 Ok(out)
895}
896
/// Rounds `x` up to the next multiple of eight (multiples of eight are returned unchanged).
#[inline]
fn round_up8(x: usize) -> usize {
    match x & 7 {
        0 => x,
        rem => x + (8 - rem),
    }
}
901
/// Computes one batch row with the kernel selected by `kern`.
///
/// Scalar kernels go to `sinwma_row_scalar`. AVX2 / AVX-512 kernels go to their SIMD row
/// implementations when built with the `nightly-avx` feature on x86_64, and fall back to
/// `sinwma_row_scalar` in every other build.
///
/// # Safety
/// - `w_ptr` must be valid for reading `period` consecutive `f64` weights.
/// - When a SIMD kernel is selected, the running CPU must support the corresponding
///   instruction set.
///
/// # Panics
/// Hits `unreachable!()` for any kernel not listed in the match.
#[inline(always)]
pub unsafe fn sinwma_row_dispatch(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
    kern: Kernel,
) {
    match kern {
        Kernel::Scalar | Kernel::ScalarBatch => sinwma_row_scalar(data, first, period, w_ptr, out),
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx2 | Kernel::Avx2Batch => sinwma_row_avx2(data, first, period, w_ptr, out),
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        Kernel::Avx512 | Kernel::Avx512Batch => sinwma_row_avx512(data, first, period, w_ptr, out),
        // Builds without SIMD support route every SIMD kernel to the scalar row code.
        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
        Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
            sinwma_row_scalar(data, first, period, w_ptr, out)
        }
        _ => unreachable!(),
    }
}
924
925#[inline(always)]
926pub fn sinwma_batch_slice(
927 data: &[f64],
928 sweep: &SinWmaBatchRange,
929 kern: Kernel,
930) -> Result<SinWmaBatchOutput, SinWmaError> {
931 sinwma_batch_inner(data, sweep, kern, false)
932}
933
934#[inline(always)]
935pub fn sinwma_batch_par_slice(
936 data: &[f64],
937 sweep: &SinWmaBatchRange,
938 kern: Kernel,
939) -> Result<SinWmaBatchOutput, SinWmaError> {
940 sinwma_batch_inner(data, sweep, kern, true)
941}
942
943#[inline(always)]
944fn sinwma_batch_inner(
945 data: &[f64],
946 sweep: &SinWmaBatchRange,
947 kern: Kernel,
948 parallel: bool,
949) -> Result<SinWmaBatchOutput, SinWmaError> {
950 let combos = expand_grid(sweep)?;
951 if combos.is_empty() {
952 return Err(SinWmaError::InvalidRange {
953 start: sweep.period.0,
954 end: sweep.period.1,
955 step: sweep.period.2,
956 });
957 }
958
959 if data.is_empty() {
960 return Err(SinWmaError::EmptyInputData);
961 }
962
963 let first = data
964 .iter()
965 .position(|x| !x.is_nan())
966 .ok_or(SinWmaError::AllValuesNaN)?;
967 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
968 if data.len() - first < max_p {
969 return Err(SinWmaError::NotEnoughValidData {
970 needed: max_p,
971 valid: data.len() - first,
972 });
973 }
974
975 let rows = combos.len();
976 let cols = data.len();
977
978 let _total = rows.checked_mul(cols).ok_or(SinWmaError::InvalidRange {
979 start: sweep.period.0,
980 end: sweep.period.1,
981 step: sweep.period.2,
982 })?;
983 let warm: Vec<usize> = combos
984 .iter()
985 .map(|c| first + c.period.unwrap() - 1)
986 .collect();
987
988 let mut raw = make_uninit_matrix(rows, cols);
989 init_matrix_prefixes(&mut raw, cols, &warm);
990
991 let stride = round_up8(max_p);
992 let cap = rows.checked_mul(stride).ok_or(SinWmaError::InvalidRange {
993 start: sweep.period.0,
994 end: sweep.period.1,
995 step: sweep.period.2,
996 })?;
997 let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
998 flat_w.resize(cap, 0.0);
999
1000 for (row, prm) in combos.iter().enumerate() {
1001 let p = prm.period.unwrap();
1002 let base = row * stride;
1003 let mut sum = 0.0;
1004 for k in 0..p {
1005 let a = (k as f64 + 1.0) * PI / (p as f64 + 1.0);
1006 let v = a.sin();
1007 flat_w[base + k] = v;
1008 sum += v;
1009 }
1010 let inv = 1.0 / sum;
1011 for k in 0..p {
1012 flat_w[base + k] *= inv;
1013 }
1014 }
1015
1016 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1017 let p = combos[row].period.unwrap();
1018 let w_ptr = flat_w.as_ptr().add(row * stride);
1019 let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1020 sinwma_row_dispatch(data, first, p, w_ptr, dst, kern);
1021 };
1022
1023 if parallel {
1024 #[cfg(not(target_arch = "wasm32"))]
1025 {
1026 raw.par_chunks_mut(cols)
1027 .enumerate()
1028 .for_each(|(r, sl)| do_row(r, sl));
1029 }
1030
1031 #[cfg(target_arch = "wasm32")]
1032 {
1033 for (r, sl) in raw.chunks_mut(cols).enumerate() {
1034 do_row(r, sl);
1035 }
1036 }
1037 } else {
1038 for (r, sl) in raw.chunks_mut(cols).enumerate() {
1039 do_row(r, sl);
1040 }
1041 }
1042
1043 let mut guard = core::mem::ManuallyDrop::new(raw);
1044 let values = unsafe {
1045 Vec::from_raw_parts(
1046 guard.as_mut_ptr() as *mut f64,
1047 guard.len(),
1048 guard.capacity(),
1049 )
1050 };
1051
1052 Ok(SinWmaBatchOutput {
1053 values,
1054 combos,
1055 rows,
1056 cols,
1057 })
1058}
1059
1060#[inline(always)]
1061fn sinwma_batch_inner_into(
1062 data: &[f64],
1063 sweep: &SinWmaBatchRange,
1064 kern: Kernel,
1065 parallel: bool,
1066 out: &mut [f64],
1067) -> Result<Vec<SinWmaParams>, SinWmaError> {
1068 let combos = expand_grid(sweep)?;
1069 if combos.is_empty() {
1070 return Err(SinWmaError::InvalidRange {
1071 start: sweep.period.0,
1072 end: sweep.period.1,
1073 step: sweep.period.2,
1074 });
1075 }
1076
1077 if data.is_empty() {
1078 return Err(SinWmaError::EmptyInputData);
1079 }
1080
1081 let first = data
1082 .iter()
1083 .position(|x| !x.is_nan())
1084 .ok_or(SinWmaError::AllValuesNaN)?;
1085 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1086 if data.len() - first < max_p {
1087 return Err(SinWmaError::NotEnoughValidData {
1088 needed: max_p,
1089 valid: data.len() - first,
1090 });
1091 }
1092
1093 let rows = combos.len();
1094 let cols = data.len();
1095 let warm: Vec<usize> = combos
1096 .iter()
1097 .map(|c| first + c.period.unwrap() - 1)
1098 .collect();
1099
1100 let expected = rows.checked_mul(cols).ok_or(SinWmaError::InvalidRange {
1101 start: sweep.period.0,
1102 end: sweep.period.1,
1103 step: sweep.period.2,
1104 })?;
1105 if out.len() != expected {
1106 return Err(SinWmaError::OutputLengthMismatch {
1107 expected,
1108 got: out.len(),
1109 });
1110 }
1111
1112 let out_mu = unsafe {
1113 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1114 };
1115 init_matrix_prefixes(out_mu, cols, &warm);
1116
1117 let stride = round_up8(max_p);
1118 let cap = rows.checked_mul(stride).ok_or(SinWmaError::InvalidRange {
1119 start: sweep.period.0,
1120 end: sweep.period.1,
1121 step: sweep.period.2,
1122 })?;
1123 let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
1124 flat_w.resize(cap, 0.0);
1125
1126 for (row, prm) in combos.iter().enumerate() {
1127 let p = prm.period.unwrap();
1128 let base = row * stride;
1129 let mut sum = 0.0;
1130 for k in 0..p {
1131 let a = (k as f64 + 1.0) * PI / (p as f64 + 1.0);
1132 let v = a.sin();
1133 flat_w[base + k] = v;
1134 sum += v;
1135 }
1136 let inv = 1.0 / sum;
1137 for k in 0..p {
1138 flat_w[base + k] *= inv;
1139 }
1140 }
1141
1142 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1143 let p = combos[row].period.unwrap();
1144 let w_ptr = flat_w.as_ptr().add(row * stride);
1145 let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1146 sinwma_row_dispatch(data, first, p, w_ptr, dst, kern);
1147 };
1148
1149 if parallel {
1150 #[cfg(not(target_arch = "wasm32"))]
1151 {
1152 out_mu
1153 .par_chunks_mut(cols)
1154 .enumerate()
1155 .for_each(|(r, sl)| do_row(r, sl));
1156 }
1157 #[cfg(target_arch = "wasm32")]
1158 {
1159 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1160 do_row(r, sl);
1161 }
1162 }
1163 } else {
1164 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1165 do_row(r, sl);
1166 }
1167 }
1168
1169 Ok(combos)
1170}
1171
/// Portable row kernel for batch calculations, reading its weights through a raw pointer.
///
/// For every index `i` from `first + period - 1` to the end of `data`, stores in `out[i]` the
/// dot product of the `period` weights at `w_ptr` with the `period` samples ending at `i`.
/// Products are accumulated in groups of four, followed by the leftover terms.
///
/// # Safety
/// `w_ptr` must be valid for reading `period` consecutive, properly aligned `f64` values.
///
/// # Panics
/// Panics if `out` is too short for an index that has to be written.
#[inline(always)]
pub unsafe fn sinwma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    let quads = period & !3;
    let first_out = first + period - 1;
    for pos in first_out..data.len() {
        let base = pos + 1 - period;
        let window = &data[base..base + period];
        let weights = std::slice::from_raw_parts(w_ptr, period);

        let (head_d, tail_d) = window.split_at(quads);
        let (head_w, tail_w) = weights.split_at(quads);

        let mut acc = 0.0;
        for (d, w) in head_d.chunks_exact(4).zip(head_w.chunks_exact(4)) {
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
        }
        // Fewer than four terms are left once the full groups are consumed.
        for (d, w) in tail_d.iter().zip(tail_w) {
            acc += *d * *w;
        }

        out[pos] = acc;
    }
}
1195
/// AVX2 + FMA row kernel for batch calculations, reading its weights through a raw pointer.
///
/// For every index `i` from `first + period - 1` to the end of `data`, stores in `out[i]` the
/// dot product of the `period` weights at `w_ptr` with the `period` samples ending at `i`,
/// processing four lanes per step and finishing the leftover terms with scalar fused
/// multiply-adds.
///
/// # Safety
/// - The CPU must support AVX2 and FMA.
/// - `w_ptr` must be valid for reading `period` consecutive `f64` values.
/// - `out` must cover every index of `data`; the store is unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn sinwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    // Largest multiple of four not exceeding `period`.
    let p4 = period & !3usize;
    for i in (first + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        let mut k = 0usize;
        while k < p4 {
            let d = _mm256_loadu_pd(data.as_ptr().add(start + k));
            let w = _mm256_loadu_pd(w_ptr.add(k));
            acc = _mm256_fmadd_pd(d, w, acc);
            k += 4;
        }

        // Reduce the four lanes: add the two 128-bit halves, then the two remaining lanes.
        let sum128 = _mm_add_pd(_mm256_castpd256_pd128(acc), _mm256_extractf128_pd(acc, 1));
        let mut sum = _mm_cvtsd_f64(_mm_hadd_pd(sum128, sum128));

        // Scalar tail for the last `period % 4` terms.
        while k < period {
            sum = (*data.get_unchecked(start + k)).mul_add(*w_ptr.add(k), sum);
            k += 1;
        }

        *out.get_unchecked_mut(i) = sum;
    }
}
1230
/// AVX-512 row kernel entry point.
///
/// Dispatches on the window length: periods up to and including 32 go to
/// [`sinwma_row_avx512_short`], longer ones to [`sinwma_row_avx512_long`].
/// All arguments are forwarded unchanged.
///
/// # Safety
/// Same requirements as the two kernels it forwards to; the CPU must support
/// the target features enabled on this function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => sinwma_row_avx512_short(data, first, period, w_ptr, out),
        _ => sinwma_row_avx512_long(data, first, period, w_ptr, out),
    }
}
1247
/// AVX-512 row kernel for short windows.
///
/// The window is processed in 8-lane blocks with fused multiply-adds; a
/// trailing partial block (`period % 8` taps) is handled with zero-masked
/// loads so no lane beyond the window or the weight table is read. When the
/// whole window fits in one partial block, the weight vector is loaded once
/// and each output is a single masked multiply plus horizontal sum.
///
/// Slots of `out` before index `first + period - 1` are not written.
///
/// # Safety
/// * The CPU must support the `avx512f` and `fma` target features.
/// * `w_ptr` must be valid for reads of `period` consecutive `f64` values.
/// * `out` must have at least `data.len()` elements (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn sinwma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let full_blocks = period / LANES;
    let rem = period % LANES;
    // Low `rem` bits set; zero when the period is a multiple of eight.
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    let warm = first + period - 1;
    let src = data.as_ptr();

    if full_blocks == 0 {
        // Whole window fits in one masked vector: weights are loop-invariant.
        let weights = _mm512_maskz_loadu_pd(tail_mask, w_ptr);
        for i in warm..data.len() {
            let window = _mm512_maskz_loadu_pd(tail_mask, src.add(i + 1 - period));
            *out.get_unchecked_mut(i) = hsum_pd_zmm(_mm512_mul_pd(window, weights));
        }
        return;
    }

    let tail_off = full_blocks * LANES;
    for i in warm..data.len() {
        let win = src.add(i + 1 - period);
        let mut acc = _mm512_setzero_pd();

        let mut off = 0usize;
        while off < tail_off {
            let d = _mm512_loadu_pd(win.add(off));
            let w = _mm512_loadu_pd(w_ptr.add(off));
            acc = _mm512_fmadd_pd(d, w, acc);
            off += LANES;
        }

        if rem != 0 {
            let dt = _mm512_maskz_loadu_pd(tail_mask, win.add(tail_off));
            let wt = _mm512_maskz_loadu_pd(tail_mask, w_ptr.add(tail_off));
            acc = _mm512_fmadd_pd(dt, wt, acc);
        }

        *out.get_unchecked_mut(i) = hsum_pd_zmm(acc);
    }
}
1290
/// AVX-512 row kernel for long windows.
///
/// All full 8-lane weight blocks are loaded once into a vector of registers
/// (plus an optional zero-masked tail block). Each output then runs a 4-way
/// unrolled fused multiply-add over the window using four independent
/// accumulators, folds leftover full blocks and the masked tail into the
/// first accumulator, and finally sums the accumulators pairwise before the
/// horizontal reduction.
///
/// Slots of `out` before index `first + period - 1` are not written.
///
/// # Safety
/// * The CPU must support the target features enabled on this function.
/// * `w_ptr` must be valid for reads of `period` consecutive `f64` values.
/// * `out` must have at least `data.len()` elements (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma,avx512dq")]
pub unsafe fn sinwma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    const LANES: usize = 8;
    let blocks = period / LANES;
    let rem = period % LANES;
    // Number of full blocks covered by the 4-way unrolled loop.
    let quad_blocks = blocks & !3;
    let tail_mask: __mmask8 = (1u8 << rem).wrapping_sub(1);

    // Pre-load every full weight block once.
    let mut weights: Vec<__m512d> = Vec::with_capacity(blocks);
    let mut b = 0usize;
    while b < blocks {
        weights.push(_mm512_loadu_pd(w_ptr.add(b * LANES)));
        b += 1;
    }
    let tail_w = if rem == 0 {
        None
    } else {
        Some(_mm512_maskz_loadu_pd(tail_mask, w_ptr.add(blocks * LANES)))
    };

    let wv = weights.as_ptr();
    let src = data.as_ptr();
    let warm = first + period - 1;

    for i in warm..data.len() {
        let win = src.add(i + 1 - period);
        let mut a0 = _mm512_setzero_pd();
        let mut a1 = _mm512_setzero_pd();
        let mut a2 = _mm512_setzero_pd();
        let mut a3 = _mm512_setzero_pd();

        let mut blk = 0usize;
        while blk < quad_blocks {
            let base = win.add(blk * LANES);
            let d0 = _mm512_loadu_pd(base);
            let d1 = _mm512_loadu_pd(base.add(LANES));
            let d2 = _mm512_loadu_pd(base.add(2 * LANES));
            let d3 = _mm512_loadu_pd(base.add(3 * LANES));

            a0 = _mm512_fmadd_pd(d0, *wv.add(blk), a0);
            a1 = _mm512_fmadd_pd(d1, *wv.add(blk + 1), a1);
            a2 = _mm512_fmadd_pd(d2, *wv.add(blk + 2), a2);
            a3 = _mm512_fmadd_pd(d3, *wv.add(blk + 3), a3);
            blk += 4;
        }

        // Up to three full blocks that did not fit the unrolled loop.
        while blk < blocks {
            let d = _mm512_loadu_pd(win.add(blk * LANES));
            a0 = _mm512_fmadd_pd(d, *wv.add(blk), a0);
            blk += 1;
        }

        if let Some(tw) = tail_w {
            let dt = _mm512_maskz_loadu_pd(tail_mask, win.add(blocks * LANES));
            a0 = _mm512_fmadd_pd(dt, tw, a0);
        }

        let lower = _mm512_add_pd(a0, a1);
        let upper = _mm512_add_pd(a2, a3);
        *out.get_unchecked_mut(i) = hsum_pd_zmm(_mm512_add_pd(lower, upper));
    }
}
1352
1353#[cfg(test)]
1354mod tests {
1355 use super::*;
1356 use crate::skip_if_unsupported;
1357 use crate::utilities::data_loader::read_candles_from_csv;
1358 #[cfg(feature = "proptest")]
1359 use proptest::prelude::*;
1360
1361 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1362 #[test]
1363 fn test_sinwma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1364 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1365 let candles = read_candles_from_csv(file_path)?;
1366
1367 let input = SinWmaInput::with_default_candles(&candles);
1368
1369 let baseline = sinwma(&input)?.values;
1370
1371 let mut out = vec![0.0; candles.close.len()];
1372 sinwma_into(&input, &mut out)?;
1373
1374 assert_eq!(baseline.len(), out.len());
1375
1376 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1377 (a.is_nan() && b.is_nan()) || (a == b)
1378 }
1379
1380 for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
1381 assert!(
1382 eq_or_both_nan(*a, *b),
1383 "mismatch at index {}: baseline={}, into={}",
1384 i,
1385 a,
1386 b
1387 );
1388 }
1389
1390 Ok(())
1391 }
1392
1393 fn check_sinwma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1394 skip_if_unsupported!(kernel, test_name);
1395 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1396 let candles = read_candles_from_csv(file_path)?;
1397 let default_params = SinWmaParams { period: None };
1398 let input = SinWmaInput::from_candles(&candles, "close", default_params);
1399 let output = sinwma_with_kernel(&input, kernel)?;
1400 assert_eq!(output.values.len(), candles.close.len());
1401 Ok(())
1402 }
1403
1404 fn check_sinwma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1405 skip_if_unsupported!(kernel, test_name);
1406 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1407 let candles = read_candles_from_csv(file_path)?;
1408 let input = SinWmaInput::from_candles(&candles, "close", SinWmaParams { period: Some(14) });
1409 let result = sinwma_with_kernel(&input, kernel)?;
1410 let expected_last_five = [
1411 59376.72903536103,
1412 59300.76862770367,
1413 59229.27622157621,
1414 59178.48781774477,
1415 59154.66580703081,
1416 ];
1417 let start = result.values.len().saturating_sub(5);
1418 for (i, &val) in result.values[start..].iter().enumerate() {
1419 let diff = (val - expected_last_five[i]).abs();
1420 assert!(
1421 diff < 1e-6,
1422 "[{}] SINWMA {:?} mismatch at idx {}: got {}, expected {}",
1423 test_name,
1424 kernel,
1425 i,
1426 val,
1427 expected_last_five[i]
1428 );
1429 }
1430 Ok(())
1431 }
1432
1433 fn check_sinwma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1434 skip_if_unsupported!(kernel, test_name);
1435 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1436 let candles = read_candles_from_csv(file_path)?;
1437 let input = SinWmaInput::with_default_candles(&candles);
1438 match input.data {
1439 SinWmaData::Candles { source, .. } => assert_eq!(source, "close"),
1440 _ => panic!("Expected SinWmaData::Candles"),
1441 }
1442 let output = sinwma_with_kernel(&input, kernel)?;
1443 assert_eq!(output.values.len(), candles.close.len());
1444 Ok(())
1445 }
1446
1447 fn check_sinwma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1448 skip_if_unsupported!(kernel, test_name);
1449 let input_data = [10.0, 20.0, 30.0];
1450 let params = SinWmaParams { period: Some(0) };
1451 let input = SinWmaInput::from_slice(&input_data, params);
1452 let res = sinwma_with_kernel(&input, kernel);
1453 assert!(
1454 res.is_err(),
1455 "[{}] SINWMA should fail with zero period",
1456 test_name
1457 );
1458 Ok(())
1459 }
1460
1461 fn check_sinwma_period_exceeds_length(
1462 test_name: &str,
1463 kernel: Kernel,
1464 ) -> Result<(), Box<dyn Error>> {
1465 skip_if_unsupported!(kernel, test_name);
1466 let data_small = [10.0, 20.0, 30.0];
1467 let params = SinWmaParams { period: Some(10) };
1468 let input = SinWmaInput::from_slice(&data_small, params);
1469 let res = sinwma_with_kernel(&input, kernel);
1470 assert!(
1471 res.is_err(),
1472 "[{}] SINWMA should fail with period exceeding length",
1473 test_name
1474 );
1475 Ok(())
1476 }
1477
1478 fn check_sinwma_very_small_dataset(
1479 test_name: &str,
1480 kernel: Kernel,
1481 ) -> Result<(), Box<dyn Error>> {
1482 skip_if_unsupported!(kernel, test_name);
1483 let single_point = [42.0];
1484 let params = SinWmaParams { period: Some(14) };
1485 let input = SinWmaInput::from_slice(&single_point, params);
1486 let res = sinwma_with_kernel(&input, kernel);
1487 assert!(
1488 res.is_err(),
1489 "[{}] SINWMA should fail with insufficient data",
1490 test_name
1491 );
1492 Ok(())
1493 }
1494
1495 fn check_sinwma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1496 skip_if_unsupported!(kernel, test_name);
1497 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1498 let candles = read_candles_from_csv(file_path)?;
1499
1500 let first_params = SinWmaParams { period: Some(14) };
1501 let first_input = SinWmaInput::from_candles(&candles, "close", first_params);
1502 let first_result = sinwma_with_kernel(&first_input, kernel)?;
1503
1504 let second_params = SinWmaParams { period: Some(5) };
1505 let second_input = SinWmaInput::from_slice(&first_result.values, second_params);
1506 let second_result = sinwma_with_kernel(&second_input, kernel)?;
1507
1508 assert_eq!(second_result.values.len(), first_result.values.len());
1509 for &val in second_result.values.iter().skip(240) {
1510 assert!(val.is_finite());
1511 }
1512 Ok(())
1513 }
1514
1515 fn check_sinwma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1516 skip_if_unsupported!(kernel, test_name);
1517 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1518 let candles = read_candles_from_csv(file_path)?;
1519 let input = SinWmaInput::from_candles(&candles, "close", SinWmaParams { period: Some(14) });
1520 let res = sinwma_with_kernel(&input, kernel)?;
1521 assert_eq!(res.values.len(), candles.close.len());
1522 if res.values.len() > 240 {
1523 for (i, &val) in res.values[240..].iter().enumerate() {
1524 assert!(
1525 !val.is_nan(),
1526 "[{}] Found unexpected NaN at out-index {}",
1527 test_name,
1528 240 + i
1529 );
1530 }
1531 }
1532 Ok(())
1533 }
1534
1535 fn check_sinwma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1536 skip_if_unsupported!(kernel, test_name);
1537
1538 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1539 let candles = read_candles_from_csv(file_path)?;
1540
1541 let period = 14;
1542
1543 let input = SinWmaInput::from_candles(
1544 &candles,
1545 "close",
1546 SinWmaParams {
1547 period: Some(period),
1548 },
1549 );
1550 let batch_output = sinwma_with_kernel(&input, kernel)?.values;
1551
1552 let mut stream = SinWmaStream::try_new(SinWmaParams {
1553 period: Some(period),
1554 })?;
1555
1556 let mut stream_values = Vec::with_capacity(candles.close.len());
1557 for &price in &candles.close {
1558 match stream.update(price) {
1559 Some(val) => stream_values.push(val),
1560 None => stream_values.push(f64::NAN),
1561 }
1562 }
1563
1564 assert_eq!(batch_output.len(), stream_values.len());
1565 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1566 if b.is_nan() && s.is_nan() {
1567 continue;
1568 }
1569 let diff = (b - s).abs();
1570 assert!(
1571 diff < 1e-9,
1572 "[{}] SINWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1573 test_name,
1574 i,
1575 b,
1576 s,
1577 diff
1578 );
1579 }
1580 Ok(())
1581 }
1582
1583 #[cfg(debug_assertions)]
1584 fn check_sinwma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1585 skip_if_unsupported!(kernel, test_name);
1586
1587 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1588 let candles = read_candles_from_csv(file_path)?;
1589
1590 let test_periods = vec![5, 10, 14, 20, 30, 50];
1591
1592 for period in test_periods {
1593 let params = SinWmaParams {
1594 period: Some(period),
1595 };
1596 let input = SinWmaInput::from_candles(&candles, "close", params);
1597 let output = sinwma_with_kernel(&input, kernel)?;
1598
1599 for (i, &val) in output.values.iter().enumerate() {
1600 if val.is_nan() {
1601 continue;
1602 }
1603
1604 let bits = val.to_bits();
1605
1606 if bits == 0x11111111_11111111 {
1607 panic!(
1608 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={})",
1609 test_name, val, bits, i, period
1610 );
1611 }
1612
1613 if bits == 0x22222222_22222222 {
1614 panic!(
1615 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={})",
1616 test_name, val, bits, i, period
1617 );
1618 }
1619
1620 if bits == 0x33333333_33333333 {
1621 panic!(
1622 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={})",
1623 test_name, val, bits, i, period
1624 );
1625 }
1626 }
1627 }
1628
1629 Ok(())
1630 }
1631
1632 #[cfg(not(debug_assertions))]
1633 fn check_sinwma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1634 Ok(())
1635 }
1636
    /// Property-based test of `sinwma_with_kernel` for the given `kernel`.
    ///
    /// Generates a period in `1..=100` and a data vector of `period..=500`
    /// values drawn from `-1e5..1e5`, runs the kernel under test and the
    /// scalar kernel, and checks a series of properties on the result.
    ///
    /// Failures are reported through the returned `Err` (the runner result is
    /// propagated with `?`), so callers must not discard the return value.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_sinwma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // The period is drawn first so the data vector is never shorter than it.
        let strat = (1usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e5f64..1e5f64).prop_filter("finite", |x| x.is_finite()),
                    period..=500,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = SinWmaParams { period: Some(period) };
                let input = SinWmaInput::from_slice(&data, params);

                // `out` is the kernel under test, `ref_out` the scalar reference.
                let SinWmaOutput { values: out } = sinwma_with_kernel(&input, kernel).unwrap();
                let SinWmaOutput { values: ref_out } = sinwma_with_kernel(&input, Kernel::Scalar).unwrap();

                // Property: one output per input sample.
                prop_assert_eq!(out.len(), data.len(), "Output length should match input length");

                // Property: the first `period - 1` outputs are NaN.
                let warmup_end = period - 1;
                for i in 0..warmup_end.min(data.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "[{}] Expected NaN at index {} during warmup (period={})",
                        test_name,
                        i,
                        period
                    );
                }

                // Property: every output lies within the min/max of its window
                // (with a small absolute + relative tolerance).
                for i in warmup_end..data.len() {
                    let y = out[i];

                    if y.is_nan() {
                        continue;
                    }

                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];

                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

                    let tolerance = 1e-9 + (hi - lo).abs() * 1e-12;
                    prop_assert!(
                        y >= lo - tolerance && y <= hi + tolerance,
                        "[{}] idx {}: value {} not in window bounds [{}, {}] (period={})",
                        test_name,
                        i,
                        y,
                        lo,
                        hi,
                        period
                    );
                }

                // Property: (near-)constant input yields that constant as output.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && !data.is_empty() {
                    for i in warmup_end..data.len() {
                        if !out[i].is_nan() {
                            prop_assert!(
                                (out[i] - data[0]).abs() <= 1e-9,
                                "[{}] Constant input should produce constant output: expected {}, got {} at index {}",
                                test_name,
                                data[0],
                                out[i],
                                i
                            );
                        }
                    }
                }

                // Property: with period 1 the output reproduces the input.
                if period == 1 {
                    for i in 0..data.len() {
                        if !out[i].is_nan() && !data[i].is_nan() {
                            prop_assert!(
                                (out[i] - data[i]).abs() <= 1e-12,
                                "[{}] Period=1 should pass through input: expected {}, got {} at index {}",
                                test_name,
                                data[i],
                                out[i],
                                i
                            );
                        }
                    }
                }

                // Property: the kernel under test matches the scalar kernel bit for bit.
                // NOTE(review): this is an exact bit comparison; kernels that use fused
                // multiply-add or a different summation order may legitimately differ
                // in the last bits — confirm exact equality is the intended contract.
                for i in 0..data.len() {
                    if out[i].is_nan() && ref_out[i].is_nan() {
                        continue;
                    }

                    let y_bits = out[i].to_bits();
                    let r_bits = ref_out[i].to_bits();

                    prop_assert_eq!(
                        y_bits,
                        r_bits,
                        "[{}] Kernel consistency failed at index {}: {:?} gives {}, Scalar gives {}",
                        test_name,
                        i,
                        kernel,
                        out[i],
                        ref_out[i]
                    );
                }

                // Property: a window of strictly positive (negative) values yields a
                // strictly positive (negative) output.
                for i in warmup_end..data.len() {
                    if out[i].is_nan() {
                        continue;
                    }

                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];

                    let all_positive = window.iter().all(|&x| x > 0.0);
                    let all_negative = window.iter().all(|&x| x < 0.0);

                    if all_positive {
                        prop_assert!(
                            out[i] > 0.0,
                            "[{}] All positive window should produce positive output, got {} at index {}",
                            test_name,
                            out[i],
                            i
                        );
                    }

                    if all_negative {
                        prop_assert!(
                            out[i] < 0.0,
                            "[{}] All negative window should produce negative output, got {} at index {}",
                            test_name,
                            out[i],
                            i
                        );
                    }
                }

                // Heuristic: for windows whose first/middle/last samples are clearly
                // ordered and span most of the window range, the output should sit
                // closer to the middle sample than to either end.
                // NOTE(review): only the first, middle and last samples are constrained;
                // the remaining samples are arbitrary, so this is not guaranteed for
                // every weighted average — verify it cannot fail on valid inputs.
                if data.len() >= period * 2 {
                    for i in (warmup_end + period)..data.len().min(warmup_end + period * 3) {
                        if out[i].is_nan() {
                            continue;
                        }

                        let window_start = i + 1 - period;
                        let window = &data[window_start..=i];

                        let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
                        let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                        let range = window_max - window_min;

                        if range > 1.0 {
                            let middle_idx = period / 2;
                            let middle_value = window[middle_idx];
                            let first_value = window[0];
                            let last_value = window[period - 1];

                            let clearly_ascending = first_value < middle_value && middle_value < last_value
                                && (last_value - first_value) > range * 0.8;
                            let clearly_descending = first_value > middle_value && middle_value > last_value
                                && (first_value - last_value) > range * 0.8;

                            if clearly_ascending || clearly_descending {
                                let dist_to_middle = (out[i] - middle_value).abs();
                                let dist_to_first = (out[i] - first_value).abs();
                                let dist_to_last = (out[i] - last_value).abs();

                                prop_assert!(
                                    dist_to_middle < dist_to_first.min(dist_to_last) * 1.2,
                                    "[{}] idx {}: SINWMA output {} should be closer to middle {} than to extremes [{}, {}] (period={})",
                                    test_name,
                                    i,
                                    out[i],
                                    middle_value,
                                    first_value,
                                    last_value,
                                    period
                                );
                            }
                        }
                    }
                }

                Ok(())
            })?;

        Ok(())
    }
1865
1866 macro_rules! generate_all_sinwma_tests {
1867 ($($test_fn:ident),*) => {
1868 paste::paste! {
1869 $(
1870 #[test]
1871 fn [<$test_fn _scalar_f64>]() {
1872 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1873 }
1874 )*
1875 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1876 $(
1877 #[test]
1878 fn [<$test_fn _avx2_f64>]() {
1879 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1880 }
1881 #[test]
1882 fn [<$test_fn _avx512_f64>]() {
1883 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1884 }
1885 )*
1886 }
1887 }
1888 }
1889
1890 generate_all_sinwma_tests!(
1891 check_sinwma_partial_params,
1892 check_sinwma_accuracy,
1893 check_sinwma_default_candles,
1894 check_sinwma_zero_period,
1895 check_sinwma_period_exceeds_length,
1896 check_sinwma_very_small_dataset,
1897 check_sinwma_reinput,
1898 check_sinwma_nan_handling,
1899 check_sinwma_streaming,
1900 check_sinwma_no_poison
1901 );
1902
1903 #[cfg(feature = "proptest")]
1904 generate_all_sinwma_tests!(check_sinwma_property);
1905
1906 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1907 skip_if_unsupported!(kernel, test);
1908
1909 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1910 let c = read_candles_from_csv(file)?;
1911
1912 let output = SinWmaBatchBuilder::new()
1913 .kernel(kernel)
1914 .apply_candles(&c, "close")?;
1915
1916 let def = SinWmaParams::default();
1917 let row = output.values_for(&def).expect("default row missing");
1918
1919 assert_eq!(row.len(), c.close.len());
1920
1921 let expected = [
1922 59376.72903536103,
1923 59300.76862770367,
1924 59229.27622157621,
1925 59178.48781774477,
1926 59154.66580703081,
1927 ];
1928 let start = row.len() - 5;
1929 for (i, &v) in row[start..].iter().enumerate() {
1930 assert!(
1931 (v - expected[i]).abs() < 1e-6,
1932 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1933 );
1934 }
1935 Ok(())
1936 }
1937
1938 #[cfg(debug_assertions)]
1939 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1940 skip_if_unsupported!(kernel, test);
1941
1942 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1943 let c = read_candles_from_csv(file)?;
1944
1945 let test_configs = vec![(5, 15, 5), (10, 30, 10), (20, 50, 15), (2, 10, 2)];
1946
1947 for (start, end, step) in test_configs {
1948 let output = SinWmaBatchBuilder::new()
1949 .kernel(kernel)
1950 .period_range(start, end, step)
1951 .apply_candles(&c, "close")?;
1952
1953 for (idx, &val) in output.values.iter().enumerate() {
1954 if val.is_nan() {
1955 continue;
1956 }
1957
1958 let bits = val.to_bits();
1959 let row = idx / output.cols;
1960 let col = idx % output.cols;
1961 let period = output.combos[row].period.unwrap();
1962
1963 if bits == 0x11111111_11111111 {
1964 panic!(
1965 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1966 test, val, bits, row, col, idx, period
1967 );
1968 }
1969
1970 if bits == 0x22222222_22222222 {
1971 panic!(
1972 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1973 test, val, bits, row, col, idx, period
1974 );
1975 }
1976
1977 if bits == 0x33333333_33333333 {
1978 panic!(
1979 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
1980 test, val, bits, row, col, idx, period
1981 );
1982 }
1983 }
1984 }
1985
1986 Ok(())
1987 }
1988
1989 #[cfg(not(debug_assertions))]
1990 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1991 Ok(())
1992 }
1993
1994 macro_rules! gen_batch_tests {
1995 ($fn_name:ident) => {
1996 paste::paste! {
1997 #[test] fn [<$fn_name _scalar>]() {
1998 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1999 }
2000 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2001 #[test] fn [<$fn_name _avx2>]() {
2002 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2003 }
2004 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2005 #[test] fn [<$fn_name _avx512>]() {
2006 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2007 }
2008 #[test] fn [<$fn_name _auto_detect>]() {
2009 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2010 }
2011 }
2012 };
2013 }
2014 gen_batch_tests!(check_batch_default_row);
2015 gen_batch_tests!(check_batch_no_poison);
2016}
2017
// Python binding `sinwma(data, period, kernel=None)`.
//
// Borrows the NumPy input as a contiguous slice, validates the optional kernel
// name, runs the indicator with the GIL released and returns the values as a
// new 1-D NumPy array. Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sinwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn sinwma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = SinWmaInput::from_slice(
        slice_in,
        SinWmaParams {
            period: Some(period),
        },
    );

    // Heavy work happens without the GIL; only plain Rust data crosses back.
    let computed = py.allow_threads(|| sinwma_with_kernel(&input, kern).map(|o| o.values));

    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2044
// Python wrapper class `SinWmaStream` around the Rust `SinWmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SinWmaStream")]
pub struct SinWmaStreamPy {
    stream: SinWmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SinWmaStreamPy {
    // Builds a stream for the given period; construction errors are raised as
    // `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        SinWmaStream::try_new(SinWmaParams {
            period: Some(period),
        })
        .map(|stream| SinWmaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Forwards one value to the underlying stream and returns whatever it
    // yields (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2070
// Python binding `sinwma_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` into parameter combinations, allocates a flat NumPy
// buffer of `rows * cols` elements, lets `sinwma_batch_inner_into` fill it with
// the GIL released, and returns a dict with:
//   "values"  - the buffer reshaped to (rows, cols)
//   "periods" - the period of each row as u64
// Errors from grid expansion or the batch computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sinwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn sinwma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let sweep = SinWmaBatchRange {
        period: period_range,
    };

    // Only the number of combinations is needed here to size the output.
    let rows = match expand_grid(&sweep) {
        Ok(grid) => grid.len(),
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let cols = slice_in.len();

    // The array is created uninitialised; the batch routine below is the one
    // that writes into it.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };
        // Map the batch kernel to the per-row SIMD kernel.
        let simd = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        sinwma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
    });
    let combos = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
2128
// Python binding `sinwma_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise creates a `CudaSinwma` for `device_id`, runs the batch on
// the f32 input with the GIL released and wraps the result in a device array.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sinwma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn sinwma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = SinWmaBatchRange {
        period: period_range,
    };

    let run = || {
        let cuda = CudaSinwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.sinwma_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    };
    let inner = py.allow_threads(run)?;

    Ok(make_device_array_py(device_id, inner)?)
}
2155
// Python binding
// `sinwma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Fails with `ValueError("CUDA not available")` when `cuda_available()` is
// false. Otherwise takes a 2-D f32 array, reads its two dimensions from
// `shape()`, and runs one parameter set over all series on the given device
// with the GIL released. The result is wrapped in a device array.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sinwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn sinwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = SinWmaParams {
        period: Some(period),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaSinwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // NOTE(review): the callee receives (cols, rows) in that order —
        // presumably (num_series, series_len) for time-major data; confirm
        // against the CUDA wrapper's signature.
        cuda.sinwma_many_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, inner)?)
}
2184
/// JS binding: computes the indicator for `data` with the given `period`
/// using `Kernel::Auto` and returns a freshly allocated vector of
/// `data.len()` values. Errors are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SinWmaInput::from_slice(
        data,
        SinWmaParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];

    match sinwma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(_) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2200
/// JS binding: allocates storage for `len` `f64` values and returns a raw
/// pointer to it. Ownership is handed to the caller; the buffer is not freed
/// here and must be released with `sinwma_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
2209
/// JS binding: releases a buffer by rebuilding a `Vec<f64>` with length and
/// capacity `len` from `ptr` and dropping it. A null `ptr` is ignored.
///
/// The caller is responsible for passing a pointer/length pair that describes
/// a live allocation of exactly that capacity (e.g. one from `sinwma_alloc`).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
2219
/// JS binding: computes the indicator for `len` values at `in_ptr` and writes
/// `len` results to `out_ptr`, using `Kernel::Auto`.
///
/// Returns an error when either pointer is null, when `period` is zero or
/// larger than `len`, or when the computation itself fails.
///
/// The input and output ranges may be identical or overlap in any way: in that
/// case the result is computed into a temporary buffer first and copied out
/// afterwards, so a shared and a mutable slice over the same memory never
/// exist at the same time.
///
/// The caller must ensure both pointers are valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to sinwma_into"));
    }

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Byte ranges [addr, addr + bytes) of the two buffers; they overlap when
    // each one starts before the other one ends. Saturating arithmetic keeps
    // the comparison well-defined near the top of the address space.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps =
        in_addr < out_addr.saturating_add(bytes) && out_addr < in_addr.saturating_add(bytes);

    if overlaps {
        let mut temp = vec![0.0; len];
        {
            // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid
            // for `len` reads; no mutable view of this memory exists in this scope.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = SinWmaInput::from_slice(
                data,
                SinWmaParams {
                    period: Some(period),
                },
            );
            sinwma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        // SAFETY: `out_ptr` is non-null and valid for `len` writes; the shared
        // input slice above has gone out of scope.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: both pointers are non-null, valid for `len` elements, and the
        // two ranges were just shown to be disjoint.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = SinWmaInput::from_slice(
            data,
            SinWmaParams {
                period: Some(period),
            },
        );
        sinwma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2260
/// Configuration object deserialized from the JS `config` argument of the
/// `sinwma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SinWmaBatchConfig {
    /// Period sweep passed straight into `SinWmaBatchRange::period`;
    /// presumably `(start, end, step)` — confirm against `expand_grid`.
    pub period_range: (usize, usize, usize),
}
2266
/// Result object serialized back to JS by the `sinwma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SinWmaBatchJsOutput {
    /// Flat value buffer taken from the batch output; presumably row-major
    /// with `rows * cols` elements — confirm against the batch implementation.
    pub values: Vec<f64>,
    /// Period of each parameter combination, in the order of the batch output's `combos`.
    pub periods: Vec<usize>,
    /// Row count copied from the batch output.
    pub rows: usize,
    /// Column count copied from the batch output.
    pub cols: usize,
}
2275
/// JS binding exported as `sinwma_batch`: deserializes `config` into a
/// [`SinWmaBatchConfig`], runs the batch with `Kernel::Auto` and returns a
/// serialized [`SinWmaBatchJsOutput`] (values, per-row periods, rows, cols).
///
/// Deserialization, computation and serialization failures are each turned
/// into a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sinwma_batch)]
pub fn sinwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SinWmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = SinWmaBatchRange {
        period: parsed.period_range,
    };

    let output = match sinwma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Collect the periods before `output.values` is moved out below.
    let mut periods: Vec<usize> = Vec::with_capacity(output.combos.len());
    for combo in output.combos.iter() {
        periods.push(combo.period.unwrap());
    }

    let js_output = SinWmaBatchJsOutput {
        values: output.values,
        periods,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2301
/// Runs a SINWMA batch over `data` for the period sweep
/// `(period_start, period_end, period_step)` and returns only the flat
/// value buffer of the batch output.
///
/// # Errors
///
/// Returns the batch error rendered as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match sinwma_batch_with_kernel(data, &range, Kernel::Auto) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2318
/// Lists the periods that the sweep `(period_start, period_end, period_step)`
/// expands to, one `u32` per parameter combination.
///
/// An empty vector is returned when the sweep cannot be expanded.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<u32> {
    let range = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match expand_grid(&range) {
        Ok(combos) => combos
            .iter()
            .map(|combo| combo.period.unwrap() as u32)
            .collect(),
        // Expansion failures are deliberately reported as "no combinations".
        Err(_) => Vec::new(),
    }
}
2333
/// Reports the batch output dimensions as `[rows, cols]`, where `rows` is
/// the number of parameter combinations the sweep expands to and `cols` is
/// `data_len`.
///
/// `rows` is reported as zero when the sweep cannot be expanded.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<u32> {
    let range = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&range).map_or(0, |combos| combos.len());
    let cols = data_len;

    vec![rows as u32, cols as u32]
}
2349
/// Runs a SINWMA batch over a raw input buffer and writes the results into a
/// raw output buffer.
///
/// The sweep `(period_start, period_end, period_step)` is expanded into
/// parameter combinations; `rows` is the number of combinations and `cols`
/// is `len`. The output buffer is treated as `rows * cols` elements and is
/// filled by `sinwma_batch_inner_into`. On success the number of rows is
/// returned.
///
/// If the input and output byte ranges overlap, the input is copied into an
/// owned buffer first so that it is never overwritten while being read.
///
/// # Errors
///
/// Returns a `JsValue` string error when either pointer is null, when the
/// sweep cannot be expanded, when `rows * cols` overflows `usize`, or when
/// the batch computation fails.
///
/// # Safety
///
/// Although not declared `unsafe`, this function dereferences raw pointers:
/// the caller must ensure `in_ptr` is valid for `len` reads of `f64` and
/// `out_ptr` is valid for `rows * len` writes of `f64`, both properly
/// aligned.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sinwma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to sinwma_batch_into",
        ));
    }

    let sweep = SinWmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;

    // A wrapped product would yield an output slice of the wrong length.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str("rows * cols overflows usize in sinwma_batch_into")
    })?;

    // Map the detected batch kernel onto its per-row counterpart.
    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // Detect any overlap between the input and output byte ranges.
    let elem_bytes = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem_bytes));
    let out_end = out_start.saturating_add(total.saturating_mul(elem_bytes));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: `in_ptr` is non-null and the caller guarantees it is valid
        // and aligned for `len` reads. The borrowed view ends with this
        // statement, before any mutable view of the output is created.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: `in_ptr` is non-null, valid and aligned for `len` reads per
        // the caller contract, and disjoint from the output range.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: `out_ptr` is non-null and the caller guarantees it is valid and
    // aligned for `rows * len` writes; `total` was computed without overflow,
    // and no live shared view of this memory exists (see above).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    sinwma_batch_inner_into(data, &sweep, simd, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}