1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::{CudaCwma, DeviceArrayF32};
5use crate::utilities::aligned_vector::AlignedVec;
6use crate::utilities::data_loader::{source_type, Candles};
7use crate::utilities::enums::Kernel;
8use crate::utilities::helpers::{
9 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
10 make_uninit_matrix,
11};
12#[cfg(feature = "python")]
13use crate::utilities::kernel_validation::validate_kernel;
14#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
15use core::arch::x86_64::*;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(feature = "python")]
19use numpy::{IntoPyArray, PyArray1};
20#[cfg(feature = "python")]
21use pyo3::exceptions::PyValueError;
22#[cfg(feature = "python")]
23use pyo3::prelude::*;
24#[cfg(feature = "python")]
25use pyo3::types::PyDict;
26#[cfg(not(target_arch = "wasm32"))]
27use rayon::prelude::*;
28#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
29use serde::{Deserialize, Serialize};
30use std::convert::AsRef;
31use std::mem::MaybeUninit;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34use thiserror::Error;
35#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
36use wasm_bindgen::prelude::*;
37
/// Returns `x` raised to the third power.
///
/// Used to build the cubic CWMA weights `period^3, (period-1)^3, ..., 2^3`.
#[inline(always)]
fn cube(x: f64) -> f64 {
    let squared = x * x;
    squared * x
}
42
43impl<'a> AsRef<[f64]> for CwmaInput<'a> {
44 #[inline(always)]
45 fn as_ref(&self) -> &[f64] {
46 match &self.data {
47 CwmaData::Slice(slice) => slice,
48 CwmaData::Candles { candles, source } => source_type(candles, source),
49 }
50 }
51}
52
/// Where a [`CwmaInput`] takes its price series from.
#[derive(Debug, Clone)]
pub enum CwmaData<'a> {
    /// A candle set plus the name of the field to read from it (resolved with
    /// `source_type`; the convenience constructors use `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw, caller-provided series.
    Slice(&'a [f64]),
}
61
/// Result of a single-series CWMA computation.
#[derive(Debug, Clone)]
pub struct CwmaOutput {
    /// One value per input sample. The warm-up prefix (indices before the
    /// first complete window) is filled by `alloc_with_nan_prefix`, presumably
    /// with NaN.
    pub values: Vec<f64>,
}
66
/// Parameters for the cubic-weighted moving average.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CwmaParams {
    /// Look-back length. `None` means "use the default of 14".
    pub period: Option<usize>,
}

impl Default for CwmaParams {
    /// A period of 14.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 14;
        CwmaParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
81
82#[derive(Debug, Clone)]
83pub struct CwmaInput<'a> {
84 pub data: CwmaData<'a>,
85 pub params: CwmaParams,
86}
87
88impl<'a> CwmaInput<'a> {
89 #[inline]
90 pub fn from_candles(c: &'a Candles, s: &'a str, p: CwmaParams) -> Self {
91 Self {
92 data: CwmaData::Candles {
93 candles: c,
94 source: s,
95 },
96 params: p,
97 }
98 }
99 #[inline]
100 pub fn from_slice(sl: &'a [f64], p: CwmaParams) -> Self {
101 Self {
102 data: CwmaData::Slice(sl),
103 params: p,
104 }
105 }
106 #[inline]
107 pub fn with_default_candles(c: &'a Candles) -> Self {
108 Self::from_candles(c, "close", CwmaParams::default())
109 }
110 #[inline]
111 pub fn get_period(&self) -> usize {
112 self.params.period.unwrap_or(14)
113 }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub struct CwmaBuilder {
118 period: Option<usize>,
119 kernel: Kernel,
120}
121
122impl Default for CwmaBuilder {
123 fn default() -> Self {
124 Self {
125 period: None,
126 kernel: Kernel::Auto,
127 }
128 }
129}
130
131impl CwmaBuilder {
132 #[inline(always)]
133 pub fn new() -> Self {
134 Self::default()
135 }
136 #[inline(always)]
137 pub fn period(mut self, n: usize) -> Self {
138 self.period = Some(n);
139 self
140 }
141 #[inline(always)]
142 pub fn kernel(mut self, k: Kernel) -> Self {
143 self.kernel = k;
144 self
145 }
146
147 #[inline(always)]
148 pub fn apply(self, c: &Candles) -> Result<CwmaOutput, CwmaError> {
149 let p = CwmaParams {
150 period: self.period,
151 };
152 let i = CwmaInput::from_candles(c, "close", p);
153 cwma_with_kernel(&i, self.kernel)
154 }
155
156 #[inline(always)]
157 pub fn apply_slice(self, d: &[f64]) -> Result<CwmaOutput, CwmaError> {
158 let p = CwmaParams {
159 period: self.period,
160 };
161 let i = CwmaInput::from_slice(d, p);
162 cwma_with_kernel(&i, self.kernel)
163 }
164
165 #[inline(always)]
166 pub fn into_stream(self) -> Result<CwmaStream, CwmaError> {
167 let p = CwmaParams {
168 period: self.period,
169 };
170 CwmaStream::try_new(p)
171 }
172}
173
/// Errors produced by the CWMA single-series, streaming and batch APIs.
#[derive(Debug, Error)]
pub enum CwmaError {
    /// The input series has no elements.
    #[error("cwma: Input data slice is empty.")]
    EmptyInputData,
    /// Every element of the input series is NaN.
    #[error("cwma: All values are NaN.")]
    AllValuesNaN,
    /// The period is unusable: 0, 1, or longer than the data. The streaming
    /// constructor reports `data_len = 0` because it has no data yet.
    #[error("cwma: Invalid period specified for CWMA calculation: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer samples remain after the first non-NaN value than the computation needs.
    #[error(
        "cwma: Not enough valid data points to compute CWMA: needed = {needed}, valid = {valid}"
    )]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("cwma: output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// An invalid `(start, end, step)` sweep; not produced by code visible in
    /// this section.
    #[error("cwma: invalid sweep range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to the batch API.
    #[error("cwma: invalid kernel for batch API: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// A buffer-size computation (described by `ctx`) overflowed `usize`.
    #[error("cwma: size overflow while computing {ctx}")]
    SizeOverflow { ctx: &'static str },
}
199
200#[inline]
201pub fn cwma(input: &CwmaInput) -> Result<CwmaOutput, CwmaError> {
202 cwma_with_kernel(input, Kernel::Auto)
203}
204
205#[inline(always)]
206fn cwma_prepare<'a>(
207 input: &'a CwmaInput,
208 kernel: Kernel,
209) -> Result<(&'a [f64], Vec<f64>, usize, usize, f64, usize, Kernel), CwmaError> {
210 let data: &[f64] = match &input.data {
211 CwmaData::Candles { candles, source } => source_type(candles, source),
212 CwmaData::Slice(sl) => sl,
213 };
214 let len = data.len();
215 if len == 0 {
216 return Err(CwmaError::EmptyInputData);
217 }
218 let first = data
219 .iter()
220 .position(|x| !x.is_nan())
221 .ok_or(CwmaError::AllValuesNaN)?;
222
223 let period = input.get_period();
224
225 if period == 0 || period > len {
226 return Err(CwmaError::InvalidPeriod {
227 period,
228 data_len: len,
229 });
230 }
231
232 if period == 1 {
233 return Err(CwmaError::InvalidPeriod {
234 period,
235 data_len: len,
236 });
237 }
238 if (len - first) < period {
239 return Err(CwmaError::NotEnoughValidData {
240 needed: period,
241 valid: len - first,
242 });
243 }
244
245 let mut weights = Vec::with_capacity(period - 1);
246 let mut norm = 0.0;
247 for i in 0..period - 1 {
248 let w = cube((period - i) as f64);
249 weights.push(w);
250 norm += w;
251 }
252 let inv_norm = 1.0 / norm;
253
254 let warm = first + period - 1;
255
256 let chosen = match kernel {
257 Kernel::Auto => detect_best_kernel(),
258 k => k,
259 };
260
261 Ok((data, weights, period, first, inv_norm, warm, chosen))
262}
263
264#[inline(always)]
265fn cwma_compute_into(
266 data: &[f64],
267 weights: &[f64],
268 period: usize,
269 first: usize,
270 inv_norm: f64,
271 chosen: Kernel,
272 out: &mut [f64],
273) {
274 unsafe {
275 match chosen {
276 Kernel::Scalar | Kernel::ScalarBatch => {
277 cwma_scalar(data, weights, period, first, inv_norm, out)
278 }
279 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
280 Kernel::Avx2 | Kernel::Avx2Batch => {
281 cwma_avx2(data, weights, period, first, inv_norm, out)
282 }
283 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
284 Kernel::Avx512 | Kernel::Avx512Batch => {
285 cwma_avx512(data, weights, period, first, inv_norm, out)
286 }
287 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
288 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
289 cwma_scalar(data, weights, period, first, inv_norm, out)
290 }
291 _ => unreachable!(),
292 }
293 }
294}
295
296#[inline]
297pub fn cwma_into_slice(dst: &mut [f64], input: &CwmaInput, kern: Kernel) -> Result<(), CwmaError> {
298 let (data, weights, period, first, inv_norm, warm, chosen) = cwma_prepare(input, kern)?;
299
300 if dst.len() != data.len() {
301 return Err(CwmaError::OutputLengthMismatch {
302 expected: data.len(),
303 got: dst.len(),
304 });
305 }
306
307 cwma_compute_into(data, &weights, period, first, inv_norm, chosen, dst);
308
309 for v in &mut dst[..warm] {
310 *v = f64::NAN;
311 }
312
313 Ok(())
314}
315
316pub fn cwma_with_kernel(input: &CwmaInput, kernel: Kernel) -> Result<CwmaOutput, CwmaError> {
317 let (data, weights, period, first, inv_norm, warm, chosen) = cwma_prepare(input, kernel)?;
318 let len = data.len();
319 let mut out = alloc_with_nan_prefix(len, warm);
320 cwma_compute_into(data, &weights, period, first, inv_norm, chosen, &mut out);
321 Ok(CwmaOutput { values: out })
322}
323
324#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
325#[inline]
326pub fn cwma_into(input: &CwmaInput, out: &mut [f64]) -> Result<(), CwmaError> {
327 cwma_into_slice(out, input, Kernel::Auto)
328}
329
/// Scalar CWMA kernel (x86 / x86_64 build).
///
/// For every `i` in `first_val + weights.len() .. data.len()` it writes
/// `out[i] = inv_norm * sum_k weights[k] * data[i - k]`. Indices below
/// `first_val + weights.len()` are not written. `_period` is unused; the
/// window length is `weights.len()`.
///
/// The dot product is unrolled by 8 into independent accumulators `s0..s7`
/// using `f64::mul_add`. The leftover `weights.len() % 8` terms are folded in
/// by the `match`.
///
/// NOTE(review): unless the build enables the `fma` target feature,
/// `f64::mul_add` on x86 may lower to a library call instead of a single
/// instruction. Worth profiling if this path is hot.
///
/// # Safety
/// `out.len()` must be at least `data.len()`, because results are written with
/// `get_unchecked_mut`. Reads stay within `data[first_val ..= i]` because
/// `i >= first_val + weights.len()`.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub unsafe fn cwma_scalar(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    let wlen = weights.len();
    let wptr = weights.as_ptr();
    // First index whose whole window lies after `first_val`.
    let first_out = first_val + wlen;

    for i in first_out..data.len() {
        // `d` walks backwards from data[i]; weight k multiplies data[i - k].
        let mut d = data.as_ptr().add(i);

        // Eight independent partial sums, combined pairwise after the loop.
        let mut s0 = 0.0;
        let mut s1 = 0.0;
        let mut s2 = 0.0;
        let mut s3 = 0.0;
        let mut s4 = 0.0;
        let mut s5 = 0.0;
        let mut s6 = 0.0;
        let mut s7 = 0.0;

        let mut k = 0usize;
        while k + 7 < wlen {
            s0 = (*d).mul_add(*wptr.add(k + 0), s0);
            d = d.sub(1);
            s1 = (*d).mul_add(*wptr.add(k + 1), s1);
            d = d.sub(1);
            s2 = (*d).mul_add(*wptr.add(k + 2), s2);
            d = d.sub(1);
            s3 = (*d).mul_add(*wptr.add(k + 3), s3);
            d = d.sub(1);
            s4 = (*d).mul_add(*wptr.add(k + 4), s4);
            d = d.sub(1);
            s5 = (*d).mul_add(*wptr.add(k + 5), s5);
            d = d.sub(1);
            s6 = (*d).mul_add(*wptr.add(k + 6), s6);
            d = d.sub(1);
            s7 = (*d).mul_add(*wptr.add(k + 7), s7);
            d = d.sub(1);
            k += 8;
        }

        let mut sum = (s0 + s1) + (s2 + s3) + (s4 + s5) + (s6 + s7);

        // Remaining `wlen - k` (< 8) terms; `d` already points at data[i - k].
        match wlen - k {
            0 => {}
            1 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
            }
            2 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
            }
            3 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 2), sum);
            }
            4 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 2), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 3), sum);
            }
            5 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 2), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 3), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 4), sum);
            }
            6 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 2), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 3), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 4), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 5), sum);
            }
            7 => {
                sum = (*d).mul_add(*wptr.add(k + 0), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 1), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 2), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 3), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 4), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 5), sum);
                d = d.sub(1);
                sum = (*d).mul_add(*wptr.add(k + 6), sum);
            }
            // The unrolled loop exits with fewer than 8 weights left.
            _ => core::hint::unreachable_unchecked(),
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
450
/// Scalar CWMA kernel (non-x86 build).
///
/// It has the same contract as the x86 variant. For every `i` in
/// `first_val + weights.len() .. data.len()` it writes
/// `out[i] = inv_norm * sum_k weights[k] * data[i - k]`, and leaves earlier
/// indices untouched. `_period` is unused.
///
/// This build uses a separate multiply and add (`+=` of a product) instead of
/// `f64::mul_add`. Results can therefore differ in the last bits from the x86
/// build.
///
/// # Safety
/// `out.len()` must be at least `data.len()`, because results are written with
/// `get_unchecked_mut`. Reads stay within `data[first_val ..= i]`.
#[inline]
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub unsafe fn cwma_scalar(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    let wlen = weights.len();
    let wptr = weights.as_ptr();
    // First index whose whole window lies after `first_val`.
    let first_out = first_val + wlen;

    for i in first_out..data.len() {
        // `d` walks backwards from data[i]; weight k multiplies data[i - k].
        let mut d = data.as_ptr().add(i);

        // Eight independent partial sums, combined pairwise after the loop.
        let mut s0 = 0.0;
        let mut s1 = 0.0;
        let mut s2 = 0.0;
        let mut s3 = 0.0;
        let mut s4 = 0.0;
        let mut s5 = 0.0;
        let mut s6 = 0.0;
        let mut s7 = 0.0;

        let mut k = 0usize;
        while k + 7 < wlen {
            s0 += *d * *wptr.add(k + 0);
            d = d.sub(1);
            s1 += *d * *wptr.add(k + 1);
            d = d.sub(1);
            s2 += *d * *wptr.add(k + 2);
            d = d.sub(1);
            s3 += *d * *wptr.add(k + 3);
            d = d.sub(1);
            s4 += *d * *wptr.add(k + 4);
            d = d.sub(1);
            s5 += *d * *wptr.add(k + 5);
            d = d.sub(1);
            s6 += *d * *wptr.add(k + 6);
            d = d.sub(1);
            s7 += *d * *wptr.add(k + 7);
            d = d.sub(1);
            k += 8;
        }

        let mut sum = (s0 + s1) + (s2 + s3) + (s4 + s5) + (s6 + s7);

        // Remaining `wlen - k` (< 8) terms; `d` already points at data[i - k].
        match wlen - k {
            0 => {}
            1 => {
                sum += *d * *wptr.add(k + 0);
            }
            2 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
            }
            3 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 2);
            }
            4 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 2);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 3);
            }
            5 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 2);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 3);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 4);
            }
            6 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 2);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 3);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 4);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 5);
            }
            7 => {
                sum += *d * *wptr.add(k + 0);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 1);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 2);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 3);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 4);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 5);
                d = d.sub(1);
                sum += *d * *wptr.add(k + 6);
            }
            // The unrolled loop exits with fewer than 8 weights left.
            _ => core::hint::unreachable_unchecked(),
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
571
/// AVX2/FMA CWMA kernel. It has the same contract as `cwma_scalar` and
/// processes four weights per step.
///
/// For each 4-weight block starting at `idx`, the samples
/// `data[i - idx - 3 ..= i - idx]` are loaded and lane-reversed with
/// `_mm256_permute4x64_pd(.., 0b00011011)`. After the reversal, lane `j` holds
/// `data[i - idx - j]`, which lines up with `weights[idx + j]`. The
/// `weights.len() % 4` tail is accumulated with scalar FMA and added after the
/// horizontal reduction. `_period` is unused.
///
/// # Safety
/// The CPU must support AVX2 and FMA, and `out.len()` must be at least
/// `data.len()` (unchecked writes).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn cwma_avx2(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;

    let wlen = weights.len();
    let chunks = wlen / STEP;
    let tail = wlen % STEP;
    let first_out = first_valid + wlen;

    for i in first_out..data.len() {
        let mut acc = _mm256_setzero_pd();

        for blk in 0..chunks {
            let idx = blk * STEP;
            let w = _mm256_loadu_pd(weights.as_ptr().add(idx));

            // Oldest of the four samples paired with weights[idx..idx + 4].
            let base = i - idx - (STEP - 1);
            let mut d = _mm256_loadu_pd(data.as_ptr().add(base));
            // Reverse lanes so lane j holds data[i - idx - j].
            d = _mm256_permute4x64_pd(d, 0b00011011);

            acc = _mm256_fmadd_pd(d, w, acc);
        }

        let mut tail_sum = 0.0;
        if tail != 0 {
            let base = chunks * STEP;
            for k in 0..tail {
                let w = *weights.get_unchecked(base + k);
                let d = *data.get_unchecked(i - (base + k));
                tail_sum = d.mul_add(w, tail_sum);
            }
        }

        // Horizontal sum of the four accumulator lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(hi, lo);
        let sum1 = _mm_add_pd(sum2, _mm_unpackhi_pd(sum2, sum2));
        let mut sum = _mm_cvtsd_f64(sum1);

        sum += tail_sum;
        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
/// AVX-512 entry point for the single-series CWMA.
///
/// Kernel selection:
/// - fewer than 24 weights: `cwma_avx2`;
/// - otherwise, `period <= 32`: `cwma_avx512_short` (weights kept in registers);
/// - otherwise: `cwma_avx512_long`.
///
/// NOTE(review): this is a safe `pub fn` that calls `#[target_feature]` code
/// without a runtime CPU check. It must only be reached on hardware with
/// AVX-512F (and AVX2/FMA for the short-window fallback).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cwma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    let short_window = weights.len() < 24;
    // SAFETY: relies on the caller having selected an AVX-512 kernel only on
    // capable hardware (see the note above).
    unsafe {
        if short_window {
            cwma_avx2(data, weights, period, first_valid, inv_norm, out)
        } else if period <= 32 {
            cwma_avx512_short(data, weights, period, first_valid, inv_norm, out)
        } else {
            cwma_avx512_long(data, weights, period, first_valid, inv_norm, out)
        }
    }
}
/// Reverses the eight `f64` lanes of `v`: lane 0 swaps with lane 7, lane 1 with
/// lane 6, and so on.
///
/// `_mm512_set_epi64` takes its arguments from the highest lane down, so the
/// index vector below is `[7, 6, ..., 0]` in lane order.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn reverse_vec(v: __m512d) -> __m512d {
    _mm512_permutexvar_pd(_mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7), v)
}
/// Packs `weights` into lane-reversed 8-wide vectors, one per complete chunk
/// of 8. Any trailing `weights.len() % 8` entries are not included.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn make_weight_blocks(weights: &[f64]) -> Vec<__m512d> {
    const LANES: usize = 8;
    let mut blocks = Vec::with_capacity(weights.len() / LANES);
    for chunk in weights.chunks_exact(LANES) {
        blocks.push(reverse_vec(_mm512_loadu_pd(chunk.as_ptr())));
    }
    blocks
}
/// AVX-512 CWMA kernel for short windows (debug-asserted `period <= 32`).
///
/// Up to four lane-reversed 8-weight blocks are preloaded into `wv`, so the
/// inner loop has no weight loads. For block `b`, the load starts at
/// `data[i - b*8 - 7]`. Lane `j` then holds `data[i - b*8 - 7 + j]`, which pairs
/// with the reversed weight `weights[b*8 + 7 - j]`. The `weights.len() % 8`
/// tail is accumulated in scalar FMA.
///
/// # Safety
/// The CPU must support AVX-512F and FMA, `weights.len() == period - 1`, and
/// `out.len()` must be at least `data.len()` (unchecked writes).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn cwma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    debug_assert!(period <= 32);
    debug_assert_eq!(weights.len(), period - 1);

    const STEP: usize = 8;
    let wlen = weights.len();
    let chunks = wlen / STEP;
    let tail = wlen % STEP;
    let first_out = first_valid + wlen;

    // Reversed weight blocks kept in registers for the whole series.
    let mut wv: [__m512d; 4] = [_mm512_setzero_pd(); 4];
    if chunks >= 1 {
        wv[0] = reverse_vec(_mm512_loadu_pd(weights.as_ptr()));
    }
    if chunks >= 2 {
        wv[1] = reverse_vec(_mm512_loadu_pd(weights.as_ptr().add(STEP)));
    }
    if chunks >= 3 {
        wv[2] = reverse_vec(_mm512_loadu_pd(weights.as_ptr().add(2 * STEP)));
    }
    if chunks == 4 {
        wv[3] = reverse_vec(_mm512_loadu_pd(weights.as_ptr().add(3 * STEP)));
    }

    for i in first_out..data.len() {
        let mut acc = _mm512_setzero_pd();

        // The `i >= ...` guards protect the index subtraction from underflow.
        if chunks >= 1 && i >= 7 {
            let d0 = _mm512_loadu_pd(data.as_ptr().add(i - 7));
            acc = _mm512_fmadd_pd(d0, wv[0], acc);
        }
        if chunks >= 2 && i >= STEP + 7 {
            let d1 = _mm512_loadu_pd(data.as_ptr().add(i - STEP - 7));
            acc = _mm512_fmadd_pd(d1, wv[1], acc);
        }
        if chunks >= 3 && i >= 2 * STEP + 7 {
            let d2 = _mm512_loadu_pd(data.as_ptr().add(i - 2 * STEP - 7));
            acc = _mm512_fmadd_pd(d2, wv[2], acc);
        }
        if chunks == 4 && i >= 3 * STEP + 7 {
            let d3 = _mm512_loadu_pd(data.as_ptr().add(i - 3 * STEP - 7));
            acc = _mm512_fmadd_pd(d3, wv[3], acc);
        }

        let mut sum = _mm512_reduce_add_pd(acc);

        // Scalar tail for the weights beyond the last full 8-block.
        if tail != 0 {
            let base = chunks * STEP;
            for k in 0..tail {
                let w = *weights.get_unchecked(base + k);
                let d = *data.get_unchecked(i - (base + k));
                sum = d.mul_add(w, sum);
            }
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
/// AVX-512 CWMA kernel for long windows.
///
/// It falls back to `cwma_avx2` when there are fewer than 24 weights.
/// Otherwise it uses the lane-reversed weight blocks from
/// `make_weight_blocks`. Pairs of blocks feed two independent accumulators,
/// an odd trailing block goes into `acc0`, and the `weights.len() % 8` tail is
/// accumulated in scalar FMA. `_period` is only forwarded to the fallback.
///
/// # Safety
/// The CPU must support AVX-512F and FMA (plus AVX2 for the fallback), and
/// `out.len()` must be at least `data.len()` (unchecked writes).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn cwma_avx512_long(
    data: &[f64],
    weights: &[f64],
    _period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;

    let wlen = weights.len();
    let chunks = wlen / STEP;
    let tail = wlen % STEP;
    let first_out = first_valid + wlen;

    if wlen < 24 {
        cwma_avx2(data, weights, _period, first_valid, inv_norm, out);
        return;
    }

    // Built once per call (heap-allocated Vec), reused for every output index.
    let wblocks = make_weight_blocks(weights);

    for i in first_out..data.len() {
        let mut acc0 = _mm512_setzero_pd();
        let mut acc1 = _mm512_setzero_pd();

        // Largest even number of blocks, processed two at a time.
        let paired = chunks & !1;
        let mut blk = 0;

        while blk < paired {
            // Underflow guards for the index subtraction (the second implies the first).
            if i >= blk * STEP + 7 && i >= (blk + 1) * STEP + 7 {
                let d0 = _mm512_loadu_pd(data.as_ptr().add(i - blk * STEP - 7));
                let d1 = _mm512_loadu_pd(data.as_ptr().add(i - (blk + 1) * STEP - 7));

                acc0 = _mm512_fmadd_pd(d0, *wblocks.get_unchecked(blk), acc0);
                acc1 = _mm512_fmadd_pd(d1, *wblocks.get_unchecked(blk + 1), acc1);
            }

            blk += 2;
        }

        // Odd leftover block, if any.
        if blk < chunks && i >= blk * STEP + 7 {
            let d = _mm512_loadu_pd(data.as_ptr().add(i - blk * STEP - 7));
            acc0 = _mm512_fmadd_pd(d, *wblocks.get_unchecked(blk), acc0);
        }

        let mut sum = _mm512_reduce_add_pd(_mm512_add_pd(acc0, acc1));

        // Scalar tail for the weights beyond the last full 8-block.
        if tail != 0 {
            let base = chunks * STEP;
            for k in 0..tail {
                let w = *weights.get_unchecked(base + k);
                let d = *data.get_unchecked(i - (base + k));
                sum = d.mul_add(w, sum);
            }
        }

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
790
/// Incremental (one sample at a time) CWMA.
///
/// It keeps the last `n = period - 1` samples in a ring buffer. It emits its
/// first value on the `period`-th sample counted from the first non-NaN input,
/// which matches the warm-up of the slice kernels.
#[derive(Debug, Clone)]
pub struct CwmaStream {
    // Configured period.
    period: usize,
    // Reciprocal of 2^3 + 3^3 + ... + period^3.
    inv_norm: f64,

    // Window length, `period - 1`.
    n: usize,

    // Ring buffer of the last `n` samples. `head` is the slot written next,
    // which holds the oldest sample once the buffer is full.
    ring: Vec<f64>,
    head: usize,
    // Samples accepted since the first non-NaN one; stops growing at `n + 1`.
    filled: usize,
    // Number of NaNs currently inside the window.
    nan_count: usize,

    // Total number of `update` calls, including leading NaNs.
    total_count: usize,
    // Whether a non-NaN sample has been seen yet.
    found_first: bool,
    // `update`-call index of the first non-NaN sample.
    first_idx: usize,

    // Power moments m_k = sum_r r^k * v_r of the window (r = 1 is the newest
    // sample), maintained incrementally.
    m0: f64,
    m1: f64,
    m2: f64,
    m3: f64,

    // Running cubic-weighted sum, sum_r (a - r)^3 * v_r, maintained incrementally.
    s: f64,

    // a = n + 2, so position r carries weight (a - r)^3.
    a: f64,
    // Weight of the newest sample, (n + 1)^3 = period^3.
    w1: f64,
    // Weight of the oldest sample, 2^3.
    wn: f64,
    // Coefficients of (a - r - 1)^3 - (a - r)^3 = alpha0 + alpha1*r + alpha2*r^2,
    // i.e. how a sample's weight changes when it ages by one position.
    alpha0: f64,
    alpha1: f64,
    alpha2: f64,

    // Cached n, n^2, n + 1 and (n + 1)^2 as f64.
    n_f: f64,
    n_sq: f64,
    n_p1: f64,
    n_p1_sq: f64,

    // Whether m0..m3 and `s` describe the current, NaN-free window.
    moments_ready: bool,
}
828
impl CwmaStream {
    /// Creates a stream for `params.period` (14 when unset).
    ///
    /// # Errors
    /// `InvalidPeriod` with `data_len: 0` when the period is 0 or 1.
    pub fn try_new(params: CwmaParams) -> Result<Self, CwmaError> {
        let period = params.period.unwrap_or(14);
        if period <= 1 {
            return Err(CwmaError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Window length: the newest `period - 1` samples carry weights period^3 .. 2^3.
        let n = period - 1;

        // Same normaliser as the slice path: 2^3 + 3^3 + ... + period^3.
        let mut norm = 0.0;
        for j in 2..=period {
            let jf = j as f64;
            norm += jf * jf * jf;
        }
        let inv_norm = 1.0 / norm;

        let n_f = n as f64;
        let n_p1 = (n + 1) as f64;
        let n_p1_sq = n_p1 * n_p1;
        let a = (n + 2) as f64;

        // Weights of the newest (r = 1) and oldest (r = n) window positions.
        let w1 = n_p1 * n_p1 * n_p1;
        let wn = 8.0;

        // (a - r - 1)^3 - (a - r)^3 expanded as a polynomial in r.
        let alpha0 = -3.0 * a * a + 3.0 * a - 1.0;
        let alpha1 = 6.0 * a - 3.0;
        let alpha2 = -3.0;

        Ok(Self {
            period,
            inv_norm,
            n,
            ring: vec![f64::NAN; n.max(1)],
            head: 0,
            filled: 0,
            nan_count: 0,
            total_count: 0,
            found_first: false,
            first_idx: 0,

            m0: 0.0,
            m1: 0.0,
            m2: 0.0,
            m3: 0.0,
            s: 0.0,

            a,
            w1,
            wn,
            alpha0,
            alpha1,
            alpha2,
            n_f,
            n_sq: n_f * n_f,
            n_p1,
            n_p1_sq,
            moments_ready: false,
        })
    }

    /// Pushes one sample and returns the CWMA of the current window.
    ///
    /// Returns `None` for leading NaNs and during warm-up. It returns
    /// `Some(NaN)` while any NaN is inside the window.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        let idx = self.total_count;
        self.total_count = idx + 1;

        // Skip leading NaNs entirely; the window starts at the first real value.
        if !self.found_first {
            if value.is_nan() {
                return None;
            } else {
                self.found_first = true;
                self.first_idx = idx;
            }
        }

        // Once the ring is full, `ring[head]` is the oldest sample being evicted.
        let mut old = f64::NAN;
        if self.filled >= self.n {
            old = self.ring[self.head];
        }

        let new_nan = value.is_nan() as usize;
        let old_nan = (self.filled >= self.n && old.is_nan()) as usize;

        if self.n > 0 {
            self.ring[self.head] = value;
            self.head = (self.head + 1) % self.n;
        }

        // Warm-up phase: runs until the first full window has been seen.
        if self.filled <= self.n {
            self.filled += 1;
            self.nan_count += new_nan;

            // First call that actually evicted a sample.
            if self.filled == self.n + 1 {
                self.nan_count -= old_nan;
            }

            if self.filled <= self.n {
                return None;
            }

            if self.nan_count > 0 {
                self.moments_ready = false;
                return Some(f64::NAN);
            }

            self.rebuild_moments_and_sum();
            self.moments_ready = true;
            return Some(self.sum_weighted() * self.inv_norm);
        }

        // Steady state: one sample in, one out.
        self.nan_count = self.nan_count + new_nan - old_nan;

        if self.nan_count > 0 {
            self.moments_ready = false;
            return Some(f64::NAN);
        }

        // First NaN-free window after a NaN run: recompute from scratch.
        if !self.moments_ready {
            self.rebuild_moments_and_sum();
            self.moments_ready = true;
            return Some(self.sum_weighted() * self.inv_norm);
        }

        let m0_prev = self.m0;
        let m1_prev = self.m1;
        let m2_prev = self.m2;
        let m3_prev = self.m3;

        let newv = value;
        let oldv = old;

        // Shift every surviving sample from position r to r + 1, add `newv` at
        // r = 1 and drop `oldv` from position n + 1 (binomial expansion of (r + 1)^k).
        self.m0 = m0_prev + newv - oldv;
        self.m1 = (-self.n_p1).mul_add(oldv, m1_prev + m0_prev + newv);
        let tmp2 = m1_prev.mul_add(2.0, m2_prev + m0_prev + newv);
        self.m2 = (-self.n_p1_sq).mul_add(oldv, tmp2);
        let np13 = self.n_p1 * self.n_p1 * self.n_p1;
        let tmp3 = m2_prev.mul_add(3.0, m3_prev + m0_prev + newv);
        let tmp3 = m1_prev.mul_add(3.0, tmp3);
        self.m3 = (-np13).mul_add(oldv, tmp3);

        // delta s = w1*new - wn*old + sum over survivors (r = 1..n-1) of
        // (alpha0 + alpha1*r + alpha2*r^2) * v_r, using the previous moments
        // with the evicted sample's contribution removed.
        let mut ds = newv.mul_add(self.w1, 0.0);
        ds = oldv.mul_add(-self.wn, ds);
        let t1 = self.alpha0.mul_add(m0_prev - oldv, ds);
        let u1 = (-self.n_f).mul_add(oldv, m1_prev);
        let t2 = self.alpha1.mul_add(u1, t1);
        let u2 = (-self.n_sq).mul_add(oldv, m2_prev);
        let delta_s = self.alpha2.mul_add(u2, t2);
        self.s += delta_s;

        // NOTE(review): the returned value is recomputed in O(n) by
        // `sum_weighted()`. The incrementally maintained `s` (and `m3`) is
        // never read here, so the O(1) update above does not affect the output.
        Some(self.sum_weighted() * self.inv_norm)
    }

    /// Recomputes m0..m3 and `s` directly from the ring buffer (r = 1 is newest).
    #[inline(always)]
    fn rebuild_moments_and_sum(&mut self) {
        debug_assert!(self.nan_count == 0, "rebuild called with NaNs present");
        let mut m0 = 0.0;
        let mut m1 = 0.0;
        let mut m2 = 0.0;
        let mut m3 = 0.0;
        let mut s = 0.0;

        let a = self.a;
        for r in 1..=self.n {
            // The slot just before `head` holds the newest sample.
            let idx = (self.head + self.n - r) % self.n;
            let v = self.ring[idx];
            let rf = r as f64;

            m0 += v;
            m1 += rf * v;
            m2 += (rf * rf) * v;
            m3 += (rf * rf * rf) * v;

            let w = {
                let t = a - rf;
                t * t * t
            };
            s = v.mul_add(w, s);
        }

        self.m0 = m0;
        self.m1 = m1;
        self.m2 = m2;
        self.m3 = m3;
        self.s = s;
    }

    /// Direct O(n) evaluation of sum_r (a - r)^3 * v_r over the ring buffer.
    #[inline(always)]
    fn sum_weighted(&self) -> f64 {
        let mut s = 0.0;
        let a = self.a;
        if self.n == 0 {
            return 0.0;
        }
        for r in 1..=self.n {
            let idx = (self.head + self.n - r) % self.n;
            let v = self.ring[idx];
            let rf = r as f64;
            let t = a - rf;
            let w = t * t * t;
            s = v.mul_add(w, s);
        }
        s
    }
}
1035
/// Inclusive `(start, end, step)` sweep of periods for the batch API.
#[derive(Clone, Debug)]
pub struct CwmaBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for CwmaBatchRange {
    /// Periods 14 through 263 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (14, 263, 1);
        CwmaBatchRange {
            period: (start, end, step),
        }
    }
}
1048
1049#[derive(Clone, Debug, Default)]
1050pub struct CwmaBatchBuilder {
1051 range: CwmaBatchRange,
1052 kernel: Kernel,
1053}
1054
1055impl CwmaBatchBuilder {
1056 pub fn new() -> Self {
1057 Self::default()
1058 }
1059 pub fn kernel(mut self, k: Kernel) -> Self {
1060 self.kernel = k;
1061 self
1062 }
1063
1064 #[inline]
1065 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1066 self.range.period = (start, end, step);
1067 self
1068 }
1069 #[inline]
1070 pub fn period_static(mut self, p: usize) -> Self {
1071 self.range.period = (p, p, 0);
1072 self
1073 }
1074
1075 pub fn apply_slice(self, data: &[f64]) -> Result<CwmaBatchOutput, CwmaError> {
1076 cwma_batch_with_kernel(data, &self.range, self.kernel)
1077 }
1078
1079 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CwmaBatchOutput, CwmaError> {
1080 CwmaBatchBuilder::new().kernel(k).apply_slice(data)
1081 }
1082
1083 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CwmaBatchOutput, CwmaError> {
1084 let slice = source_type(c, src);
1085 self.apply_slice(slice)
1086 }
1087
1088 pub fn with_default_candles(c: &Candles) -> Result<CwmaBatchOutput, CwmaError> {
1089 CwmaBatchBuilder::new()
1090 .kernel(Kernel::Auto)
1091 .apply_candles(c, "close")
1092 }
1093}
1094
1095pub fn cwma_batch_with_kernel(
1096 data: &[f64],
1097 sweep: &CwmaBatchRange,
1098 k: Kernel,
1099) -> Result<CwmaBatchOutput, CwmaError> {
1100 let kernel = match k {
1101 Kernel::Auto => detect_best_batch_kernel(),
1102 other if other.is_batch() => other,
1103 other => return Err(CwmaError::InvalidKernelForBatch(other)),
1104 };
1105
1106 let simd = match kernel {
1107 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1108 Kernel::Avx512Batch => Kernel::Avx512,
1109 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1110 Kernel::Avx2Batch => Kernel::Avx2,
1111 Kernel::ScalarBatch => Kernel::Scalar,
1112 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1113 Kernel::Avx2Batch | Kernel::Avx512Batch => Kernel::Scalar,
1114 _ => unreachable!(),
1115 };
1116 cwma_batch_par_slice(data, sweep, simd)
1117}
1118
1119#[derive(Clone, Debug)]
1120pub struct CwmaBatchOutput {
1121 pub values: Vec<f64>,
1122 pub combos: Vec<CwmaParams>,
1123 pub rows: usize,
1124 pub cols: usize,
1125}
1126
1127impl CwmaBatchOutput {
1128 pub fn row_for_params(&self, p: &CwmaParams) -> Option<usize> {
1129 self.combos
1130 .iter()
1131 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
1132 }
1133 pub fn values_for(&self, p: &CwmaParams) -> Option<&[f64]> {
1134 self.row_for_params(p).map(|row| {
1135 let start = row * self.cols;
1136 &self.values[start..start + self.cols]
1137 })
1138 }
1139}
1140
1141#[inline(always)]
1142fn expand_grid(r: &CwmaBatchRange) -> Vec<CwmaParams> {
1143 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
1144 if step == 0 || start == end {
1145 return vec![start];
1146 }
1147 let (lo, hi) = if start <= end {
1148 (start, end)
1149 } else {
1150 (end, start)
1151 };
1152 (lo..=hi).step_by(step).collect()
1153 }
1154
1155 let periods = axis_usize(r.period);
1156
1157 let mut out = Vec::with_capacity(periods.len());
1158 for &p in &periods {
1159 out.push(CwmaParams { period: Some(p) });
1160 }
1161 out
1162}
1163
1164#[inline(always)]
1165pub fn cwma_batch_slice(
1166 data: &[f64],
1167 sweep: &CwmaBatchRange,
1168 kern: Kernel,
1169) -> Result<CwmaBatchOutput, CwmaError> {
1170 cwma_batch_inner(data, sweep, kern, false)
1171}
1172
1173#[inline(always)]
1174pub fn cwma_batch_par_slice(
1175 data: &[f64],
1176 sweep: &CwmaBatchRange,
1177 kern: Kernel,
1178) -> Result<CwmaBatchOutput, CwmaError> {
1179 cwma_batch_inner(data, sweep, kern, true)
1180}
1181
/// Rounds `x` up to the next multiple of 8. A value that is already a
/// multiple is returned unchanged.
#[inline]
fn round_up8(x: usize) -> usize {
    const MASK: usize = 8 - 1;
    (x + MASK) & !MASK
}
1186
1187#[inline(always)]
1188fn cwma_batch_inner(
1189 data: &[f64],
1190 sweep: &CwmaBatchRange,
1191 kern: Kernel,
1192 parallel: bool,
1193) -> Result<CwmaBatchOutput, CwmaError> {
1194 let combos = expand_grid(sweep);
1195 let cols = data.len();
1196 let rows = combos.len();
1197
1198 let _total = rows
1199 .checked_mul(cols)
1200 .ok_or(CwmaError::SizeOverflow { ctx: "rows*cols" })?;
1201
1202 if cols == 0 {
1203 return Err(CwmaError::EmptyInputData);
1204 }
1205
1206 let first = data
1207 .iter()
1208 .position(|x| !x.is_nan())
1209 .ok_or(CwmaError::AllValuesNaN)?;
1210
1211 let max_p = combos
1212 .iter()
1213 .map(|c| round_up8(c.period.unwrap()))
1214 .max()
1215 .unwrap();
1216
1217 if (cols - first) < max_p {
1218 return Err(CwmaError::NotEnoughValidData {
1219 needed: max_p,
1220 valid: cols - first,
1221 });
1222 }
1223
1224 let mut buf_mu = make_uninit_matrix(rows, cols);
1225
1226 let warm: Vec<usize> = combos
1227 .iter()
1228 .map(|c| first + c.period.unwrap() - 1)
1229 .collect();
1230
1231 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1232
1233 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1234 let out: &mut [f64] = unsafe {
1235 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1236 };
1237
1238 cwma_batch_inner_into(data, sweep, kern, parallel, out)?;
1239
1240 let values = unsafe {
1241 Vec::from_raw_parts(
1242 buf_guard.as_mut_ptr() as *mut f64,
1243 buf_guard.len(),
1244 buf_guard.capacity(),
1245 )
1246 };
1247
1248 Ok(CwmaBatchOutput {
1249 values,
1250 combos,
1251 rows,
1252 cols,
1253 })
1254}
1255
1256#[inline(always)]
1257fn cwma_batch_inner_into(
1258 data: &[f64],
1259 sweep: &CwmaBatchRange,
1260 kern: Kernel,
1261 parallel: bool,
1262 out: &mut [f64],
1263) -> Result<Vec<CwmaParams>, CwmaError> {
1264 let combos = expand_grid(sweep);
1265
1266 let first = data
1267 .iter()
1268 .position(|x| !x.is_nan())
1269 .ok_or(CwmaError::AllValuesNaN)?;
1270
1271 let max_p = combos
1272 .iter()
1273 .map(|c| round_up8(c.period.unwrap()))
1274 .max()
1275 .unwrap();
1276
1277 let rows = combos.len();
1278 let cols = data.len();
1279 let expected = rows
1280 .checked_mul(cols)
1281 .ok_or(CwmaError::SizeOverflow { ctx: "rows*cols" })?;
1282 if out.len() != expected {
1283 return Err(CwmaError::OutputLengthMismatch {
1284 expected,
1285 got: out.len(),
1286 });
1287 }
1288 let mut inv_norms = vec![0.0; rows];
1289
1290 let cap = rows
1291 .checked_mul(max_p)
1292 .ok_or(CwmaError::SizeOverflow { ctx: "rows*max_p" })?;
1293 let mut aligned = AlignedVec::with_capacity(cap);
1294 let flat_w = aligned.as_mut_slice();
1295
1296 for (row, prm) in combos.iter().enumerate() {
1297 let period = prm.period.unwrap();
1298 let mut norm = 0.0;
1299 for i in 0..period - 1 {
1300 let w = cube((period - i) as f64);
1301 flat_w[row * max_p + i] = w;
1302 norm += w;
1303 }
1304 let inv_norm = 1.0 / norm;
1305 inv_norms[row] = inv_norm;
1306 }
1307
1308 let out_uninit = unsafe {
1309 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1310 };
1311
1312 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1313 let period = combos[row].period.unwrap();
1314 let w_ptr = flat_w.as_ptr().add(row * max_p);
1315 let inv_n = *inv_norms.get_unchecked(row);
1316
1317 let out_row =
1318 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1319
1320 match kern {
1321 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1322 Kernel::Avx512 => cwma_row_avx512(data, first, period, max_p, w_ptr, inv_n, out_row),
1323 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1324 Kernel::Avx2 => cwma_row_avx2(data, first, period, max_p, w_ptr, inv_n, out_row),
1325 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1326 Kernel::Avx2 | Kernel::Avx512 => {
1327 cwma_row_scalar(data, first, period, max_p, w_ptr, inv_n, out_row)
1328 }
1329 _ => cwma_row_scalar(data, first, period, max_p, w_ptr, inv_n, out_row),
1330 }
1331 };
1332
1333 if parallel {
1334 #[cfg(not(target_arch = "wasm32"))]
1335 {
1336 out_uninit
1337 .par_chunks_mut(cols)
1338 .enumerate()
1339 .for_each(|(row, slice)| do_row(row, slice));
1340 }
1341
1342 #[cfg(target_arch = "wasm32")]
1343 {
1344 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1345 do_row(row, slice);
1346 }
1347 }
1348 } else {
1349 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1350 do_row(row, slice);
1351 }
1352 }
1353
1354 Ok(combos)
1355}
1356
/// Portable CWMA kernel for one batch row.
///
/// For every `i` in `first + period - 1 .. min(data.len(), out.len())` it writes
/// `out[i] = inv_n * Σ_{k=0}^{period-2} w[k] * data[i - k]`, accumulated in ascending `k`.
/// Cells before that range are not touched. `stride` is unused; it is kept so every row
/// kernel shares the same signature.
///
/// # Safety
/// `period >= 1`, and `w_ptr` must be valid for reading `period - 1` consecutive `f64`s.
#[inline(always)]
unsafe fn cwma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    let _ = stride;
    let taps = period - 1;
    let weights = core::slice::from_raw_parts(w_ptr, taps);
    let end = data.len().min(out.len());

    let mut i = first + taps;
    while i < end {
        // The window is oldest..=newest. Walking it backwards pairs data[i - k] with w[k].
        let window = &data[i + 1 - taps..i + 1];
        let acc = weights
            .iter()
            .zip(window.iter().rev())
            .fold(0.0, |acc, (&w, &d)| acc + d * w);
        out[i] = acc * inv_n;
        i += 1;
    }
}
1379
/// Loads 8 consecutive `f64`s from `ptr` (unaligned load) and returns them with the lane
/// order reversed: lane `j` of the result holds `ptr[7 - j]`.
///
/// `_mm512_set_epi64` takes its arguments from the highest lane down, so the index vector
/// is `[7, 6, 5, 4, 3, 2, 1, 0]` in lane order. `_mm512_permutexvar_pd` then places
/// `src[idx[j]]` in lane `j`.
///
/// # Safety
/// `ptr` must be valid for reading 8 `f64`s, and the CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn load_w512_rev(ptr: *const f64) -> __m512d {
    let rev = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    _mm512_permutexvar_pd(rev, _mm512_loadu_pd(ptr))
}
1386
/// Loads 4 consecutive `f64`s from `ptr` (unaligned load) and returns them reversed:
/// lane `j` of the result holds `ptr[3 - j]`.
///
/// In the immediate `0b00_01_10_11`, each 2-bit field (lowest first) names the source lane
/// for result lanes 0..3, giving the source order 3, 2, 1, 0.
///
/// # Safety
/// `ptr` must be valid for reading 4 `f64`s, and the CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn load_w256_rev(ptr: *const f64) -> __m256d {
    _mm256_permute4x64_pd::<0b00011011>(_mm256_loadu_pd(ptr))
}
1392
/// AVX-512 CWMA for one batch row.
///
/// Periods up to 32 go to the short kernel, which keeps its few weight vectors in
/// registers. Longer periods go to the long kernel. `_stride` is unused.
///
/// # Safety
/// The CPU must support the features the selected kernel enables. `w_ptr` must point at
/// this row's `period - 1` weights, and `period >= 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn cwma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => cwma_row_avx512_short(data, first, period, w_ptr, inv_n, out),
        _ => cwma_row_avx512_long(data, first, period, w_ptr, inv_n, out),
    }
}
1411
/// AVX-512 CWMA row kernel for short periods (dispatched for `period <= 32`).
///
/// For every `i` in `first + period - 1 .. min(data.len(), out.len())` it writes
/// `out[i] = inv_n * Σ_{k=0}^{period-2} w[k] * data[i - k]`.
/// - Each full block of 8 lags is one FMA against a reversed weight vector that is preloaded once per row.
/// - The remaining `(period - 1) % 8` lags are summed in scalar.
///
/// All full chunks are accumulated. Up to 4 chunks are supported, which covers every
/// period ≤ 33.
///
/// # Safety
/// The CPU must support AVX-512F and FMA. `w_ptr` must be valid for reading `period - 1`
/// `f64`s, and `period >= 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn cwma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    const MAX_CHUNKS: usize = 4;
    let wlen = period - 1;
    let chunks = wlen / STEP;
    let tail_len = wlen % STEP;
    assert!(
        chunks <= MAX_CHUNKS,
        "cwma_row_avx512_short: period {} exceeds the short-kernel limit",
        period
    );

    // Chunk c holds reversed weights w[8c+7 ..= 8c]. Only chunks that exist are loaded,
    // so weight slots past `wlen` are never read.
    let mut wv = [_mm512_setzero_pd(); MAX_CHUNKS];
    for c in 0..chunks {
        wv[c] = load_w512_rev(w_ptr.add(c * STEP));
    }

    let tail_base = chunks * STEP;
    let start_idx = first + wlen;
    for i in start_idx..data.len().min(out.len()) {
        let mut acc = _mm512_setzero_pd();

        // Lane j of chunk c holds data[i - 8c - 7 + j], paired with w[8c + 7 - j],
        // i.e. data[i - k] * w[k]. Here i >= wlen >= 8(c + 1), so the load stays in bounds.
        for c in 0..chunks {
            let d = _mm512_loadu_pd(data.as_ptr().add(i - c * STEP - (STEP - 1)));
            acc = _mm512_fmadd_pd(d, wv[c], acc);
        }

        let mut tail_sum = 0.0;
        for k in tail_base..tail_base + tail_len {
            let w = *w_ptr.add(k);
            let d = *data.get_unchecked(i - k);
            tail_sum = d.mul_add(w, tail_sum);
        }

        let sum = _mm512_reduce_add_pd(acc) + tail_sum;
        *out.get_unchecked_mut(i) = sum * inv_n;
    }
}
/// AVX-512 CWMA row kernel for long periods (dispatched for `period > 32`).
///
/// For every `i` in `first + period - 1 .. min(data.len(), out.len())` it writes
/// `out[i] = inv_n * Σ_{k=0}^{period-2} w[k] * data[i - k]`.
///
/// The weights are first copied once per row into `wrev`, oldest-first
/// (`wrev[m] = w[wlen - 1 - m]`) and zero-padded to a multiple of 8. The window
/// `data[i + 1 - wlen ..= i]` then lines up element-for-element with `wrev` and needs no
/// per-iteration permutes. The partial last block uses a masked data load against the
/// zero-padded weights. There is no fixed upper bound on `period`.
///
/// # Safety
/// The CPU must support AVX-512F/DQ and FMA. `w_ptr` must be valid for reading
/// `period - 1` `f64`s, and `period >= 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
#[inline]
unsafe fn cwma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;

    let wlen = period - 1;
    let n_chunks = wlen / STEP;
    let tail_len = wlen % STEP;
    let tmask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    // Oldest-first weights, zero padded to whole vectors.
    let padded = (n_chunks + (tail_len != 0) as usize) * STEP;
    let mut wrev = vec![0.0f64; padded];
    for (m, slot) in wrev[..wlen].iter_mut().enumerate() {
        *slot = *w_ptr.add(wlen - 1 - m);
    }
    let wr = wrev.as_ptr();

    let start_idx = first + wlen;
    for i in start_idx..data.len().min(out.len()) {
        // Oldest sample in the window is data[i - (wlen - 1)], which pairs with wrev[0] = w[wlen - 1].
        let base_ptr = data.as_ptr().add(i + 1 - wlen);

        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk + 3 < n_chunks {
            let d0 = _mm512_loadu_pd(base_ptr.add(blk * STEP));
            let d1 = _mm512_loadu_pd(base_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(base_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(base_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, _mm512_loadu_pd(wr.add(blk * STEP)), s0);
            s1 = _mm512_fmadd_pd(d1, _mm512_loadu_pd(wr.add((blk + 1) * STEP)), s1);
            s2 = _mm512_fmadd_pd(d2, _mm512_loadu_pd(wr.add((blk + 2) * STEP)), s2);
            s3 = _mm512_fmadd_pd(d3, _mm512_loadu_pd(wr.add((blk + 3) * STEP)), s3);

            blk += 4;
        }

        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(base_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, _mm512_loadu_pd(wr.add(r * STEP)), s0);
        }

        let mut sum =
            _mm512_reduce_add_pd(_mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3)));

        if tail_len != 0 {
            // Masked-off lanes are neither read nor faulted on. The last data element read is data[i].
            let d_tail = _mm512_maskz_loadu_pd(tmask, base_ptr.add(n_chunks * STEP));
            let w_tail = _mm512_loadu_pd(wr.add(n_chunks * STEP));
            sum += _mm512_reduce_add_pd(_mm512_mul_pd(d_tail, w_tail));
        }

        out[i] = sum * inv_n;
    }
}
/// AVX2/FMA CWMA row kernel.
///
/// For every `i` in `first + period - 1 .. min(data.len(), out.len())` it writes
/// `out[i] = inv_n * Σ_{k=0}^{period-2} w[k] * data[i - k]`.
/// - Lags are processed four at a time with FMA.
/// - The remaining `(period - 1) % 4` lags are added in scalar with `mul_add`.
///
/// `_stride` is unused.
///
/// # Safety
/// The CPU must support AVX2 and FMA. `w_ptr` must be valid for reading `period - 1`
/// `f64`s, and `period >= 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn cwma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let wlen = period - 1;
    let vec_blks = wlen / STEP;
    let tail = wlen % STEP;
    let tail_base = vec_blks * STEP;

    let start_idx = first + wlen;
    for i in start_idx..data.len().min(out.len()) {
        let mut acc = _mm256_setzero_pd();

        for blk in 0..vec_blks {
            // Lane j holds data[i - 4b - 3 + j] and reversed weight w[4b + 3 - j],
            // i.e. data[i - k] * w[k].
            let d = _mm256_loadu_pd(data.as_ptr().add(i - blk * STEP - 3));
            let w = load_w256_rev(w_ptr.add(blk * STEP));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Horizontal sum of the four accumulator lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let pair = _mm_add_pd(hi, lo);
        let vec_sum = _mm_cvtsd_f64(_mm_add_pd(pair, _mm_unpackhi_pd(pair, pair)));

        let total = if tail != 0 {
            let mut tail_sum = 0.0;
            for k in tail_base..tail_base + tail {
                let w = *w_ptr.add(k);
                let d = *data.get_unchecked(i - k);
                tail_sum = d.mul_add(w, tail_sum);
            }
            vec_sum + tail_sum
        } else {
            vec_sum
        };

        out[i] = total * inv_n;
    }
}
1604
/// Python entry point `cwma(data, period, kernel=None)`.
///
/// Computes CWMA over a contiguous 1-D float64 array with the GIL released, and returns
/// a new NumPy array of the same length. Errors from `as_slice` (non-contiguous input)
/// and `validate_kernel` are propagated as-is. CWMA computation errors become `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cwma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn cwma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = CwmaInput::from_slice(
        prices,
        CwmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| cwma_with_kernel(&input, chosen));
    let values = match computed {
        Ok(output) => output.values,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1630
/// Python class `CwmaStream`: a thin wrapper around the Rust `CwmaStream` streaming state.
#[cfg(feature = "python")]
#[pyclass(name = "CwmaStream")]
pub struct CwmaStreamPy {
    // The wrapped Rust stream; `update` calls are forwarded to it.
    stream: CwmaStream,
}
1636
#[cfg(feature = "python")]
#[pymethods]
impl CwmaStreamPy {
    /// Builds a stream for `period`. `CwmaStream::try_new` errors are raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        match CwmaStream::try_new(CwmaParams {
            period: Some(period),
        }) {
            Ok(stream) => Ok(CwmaStreamPy { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Pushes one sample and returns whatever `CwmaStream::update` yields (`None` is passed
    /// through unchanged).
    fn update(&mut self, value: f64) -> Option<f64> {
        let CwmaStreamPy { stream } = self;
        stream.update(value)
    }
}
1654
/// Python entry point `cwma_batch(data, period_range, kernel=None)`.
///
/// Computes CWMA for every period in `period_range = (start, end, step)` into one
/// row-major NumPy buffer and returns a dict:
/// - `values`: `(rows, cols)` float64 array; each row's warm-up prefix is NaN.
/// - `periods`: uint64 array with the period for each row.
///
/// The numeric work runs with the GIL released. `rows * cols` is checked for overflow
/// before allocating, and overflow is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cwma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cwma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = CwmaBatchRange {
        period: period_range,
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();
    // An unchecked product could wrap in release builds and under-allocate.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("cwma_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is uninitialised. Every cell is written before it is returned:
    // warm-up prefixes below, everything else by `cwma_batch_inner_into`. On error the
    // array is dropped without being exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = slice_in
        .iter()
        .position(|x| !x.is_nan())
        .unwrap_or(slice_in.len());
    for (row, combo) in combos.iter().enumerate() {
        let warmup = (first + combo.period.unwrap() - 1).min(cols);
        let row_start = row * cols;
        slice_out[row_start..row_start + warmup].fill(f64::NAN);
    }

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel to the per-row SIMD kernel used by the row loop.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };

            cwma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1723
/// Python handle to a device-resident `f32` CWMA result (`DeviceArrayF32`), exported to
/// other libraries through `__cuda_array_interface__` / `__dlpack__`.
///
/// Marked `unsendable`, so PyO3 only permits access from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32CwmaPy {
    // Device buffer plus its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    // CUDA wrapper that produced `inner`. `__dlpack_device__` queries it for the device id.
    // NOTE(review): presumably also held so its CUDA context outlives the buffer; confirm.
    pub(crate) guard: Arc<CudaCwma>,
}
1730
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32CwmaPy {
    /// CUDA Array Interface (version 3) description of the buffer: a C-contiguous
    /// `(rows, cols)` little-endian `f32` matrix with byte strides, and `data = (ptr, false)`,
    /// where `false` is the read-only flag (i.e. writable).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        // Strides are in bytes: one row is `cols` f32s, one element is one f32.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. Type 2 is DLPack's CUDA device type;
    /// the id comes from the owning `CudaCwma`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.guard.device_id() as i32)
    }

    /// Exports the buffer as a 2-D `f32` DLPack capsule via `export_f32_cuda_dlpack_2d`.
    ///
    /// - If `dl_device` is a `(type, id)` tuple that differs from this buffer's device, the
    ///   call fails: with "device copy not implemented" when `copy` is truthy, otherwise with
    ///   "dl_device mismatch". A `dl_device` that does not extract as `(i32, i32)` is ignored.
    /// - `stream` is accepted but ignored (no stream synchronisation happens here).
    /// - Ownership of the device buffer moves into the capsule. This object is left holding
    ///   an empty 0 x 0 placeholder, so a second export yields an empty array.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Cross-device export would need a copy, which is not supported.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // The consumer's stream is not used for synchronisation.
        let _ = stream;

        // Swap in an empty buffer so the real one can be moved into the capsule.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1809
/// Python entry point `cwma_cuda_batch_dev(data_f32, period_range, device_id=0)`.
///
/// Runs the CWMA period sweep on the GPU through `CudaCwma::cwma_batch_dev` with the GIL
/// released. The result stays on the device, wrapped in `DeviceArrayF32CwmaPy`.
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cwma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn cwma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32CwmaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices: &[f32] = data_f32.as_slice()?;
    let sweep = CwmaBatchRange {
        period: period_range,
    };

    let engine = match CudaCwma::new(device_id) {
        Ok(cuda) => Arc::new(cuda),
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };
    let inner = py.allow_threads(|| match engine.cwma_batch_dev(prices, &sweep) {
        Ok(dev) => Ok(dev),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    })?;

    Ok(DeviceArrayF32CwmaPy {
        inner,
        guard: engine,
    })
}
1836
/// Python entry point `cwma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
///
/// Runs one CWMA `period` over many series on the GPU with the GIL released, and returns
/// the device-resident result.
///
/// NOTE(review): the `_tm` naming suggests `data_tm_f32` is time-major, i.e. shape
/// `(time, series)`. `shape[1]` and `shape[0]` are passed as the 2nd/3rd arguments of
/// `cwma_multi_series_one_param_time_major_dev`; confirm against `CudaCwma`.
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cwma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn cwma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32CwmaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = CwmaParams {
        period: Some(period),
    };

    let cuda =
        Arc::new(CudaCwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?);
    let inner = py.allow_threads(|| {
        cuda.cwma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32CwmaPy { inner, guard: cuda })
}
1868
/// JS entry point: computes CWMA over `data` with `period` and returns a new array of
/// the same length. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = CwmaInput::from_slice(
        data,
        CwmaParams {
            period: Some(period),
        },
    );
    let mut result = vec![0.0; data.len()];

    match cwma_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1884
/// Allocates an uninitialised buffer with room for `len` `f64`s and hands its pointer to
/// JS. Release it with `cwma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1893
/// Frees a buffer obtained from `cwma_alloc(len)`. Null pointers and `len == 0` are no-ops.
///
/// The Vec is rebuilt with length 0 and capacity `len`. `Vec::from_raw_parts` requires the
/// first `length` elements to be initialised, and `cwma_alloc` memory may never have been
/// written. `f64` has no destructor, so only the allocation itself needs releasing.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // SAFETY: `ptr` was returned by `cwma_alloc(len)`, whose Vec had capacity `len`.
    // A length of 0 makes no claim about the contents.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1903
/// Raw-pointer CWMA for JS callers: reads `len` values at `in_ptr` and writes `len`
/// results to `out_ptr`.
///
/// If the input and output byte ranges overlap (including `in_ptr == out_ptr`), the result
/// is computed into a temporary buffer and then copied. A shared input slice and a mutable
/// output slice therefore never alias.
///
/// # Errors
/// Returns a JS string error for null pointers, for `period == 0 || period > len`, or for
/// any error from `cwma_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cwma_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let bytes = len.saturating_mul(core::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(bytes) && out_start < in_start.saturating_add(bytes);

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = CwmaParams {
            period: Some(period),
        };
        let input = CwmaInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            cwma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // `data`/`input` are not used past this point, so writing the output is sound.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            cwma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1944
/// Deprecated JS batch entry point: returns only the flattened row-major values for the
/// period sweep `(period_start, period_end, period_step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use cwma_batch instead")]
pub fn cwma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = CwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match cwma_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(CwmaBatchOutput { values, .. }) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1962
/// Returns the periods (as `f64`, one per output row) that the sweep
/// `(period_start, period_end, period_step)` expands to.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let grid = expand_grid(&CwmaBatchRange {
        period: (period_start, period_end, period_step),
    });

    let mut periods = Vec::with_capacity(grid.len());
    for params in &grid {
        periods.push(params.period.unwrap() as f64);
    }
    Ok(periods)
}
1982
/// Configuration object accepted by the JS `cwma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CwmaBatchConfig {
    /// `(start, end, step)` period sweep, forwarded unchanged to `CwmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1988
/// Serialisable result returned by the JS `cwma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CwmaBatchJsOutput {
    /// Flattened row-major `rows x cols` matrix, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row, in row order.
    pub combos: Vec<CwmaParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1997
/// JS entry point `cwma_batch(data, config)`.
///
/// `config` is deserialised as `CwmaBatchConfig`. The full batch result (`values`,
/// `combos`, `rows`, `cols`) is returned as a JS object. Deserialisation, computation and
/// serialisation errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cwma_batch)]
pub fn cwma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CwmaBatchConfig { period_range } =
        serde_wasm_bindgen::from_value::<CwmaBatchConfig>(config)
            .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CwmaBatchRange {
        period: period_range,
    };

    let CwmaBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = cwma_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = CwmaBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2021
/// Raw-pointer batch CWMA for JS callers.
///
/// Reads `len` values at `in_ptr` and writes a row-major `rows x len` matrix to `out_ptr`,
/// one row per period in `(period_start, period_end, period_step)`. Returns `rows`.
///
/// Behaviour:
/// - Each row's warm-up prefix (`first + period - 1` cells, `first` = first non-NaN index)
///   is set to NaN. The row kernels only write from that index onward, so the prefix
///   would otherwise keep whatever the caller's buffer held.
/// - If the input and output ranges overlap, results are computed into a scratch buffer
///   and then copied, so the input is never overwritten while it is still being read.
///
/// # Errors
/// Returns a JS string error for:
/// - null pointers;
/// - `rows * len` overflow;
/// - any error from `cwma_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cwma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cwma_batch_into"));
    }

    let sweep = CwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("cwma_batch_into: rows * cols overflows usize"))?;

    let elem = core::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(total.saturating_mul(elem))
        && out_start < in_start.saturating_add(len.saturating_mul(elem));

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if overlaps {
            // NaN-initialised scratch matrix: the kernel fills everything after the warm-up
            // prefixes, and the prefixes stay NaN.
            let mut temp = vec![f64::NAN; total];
            cwma_batch_inner_into(data, &sweep, Kernel::Auto, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, total).copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            cwma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // The prefix cells [0, first + period - 1) are disjoint from the cells the kernel
            // wrote, so filling them afterwards is safe and leaves `out` untouched on error.
            if cols > 0 {
                let first = data.iter().position(|x| !x.is_nan()).unwrap_or(cols);
                for (row_out, combo) in out.chunks_mut(cols).zip(combos.iter()) {
                    let warm = (first + combo.period.unwrap())
                        .saturating_sub(1)
                        .min(cols);
                    row_out[..warm].fill(f64::NAN);
                }
            }
        }
    }

    Ok(rows)
}
2055
/// Registers the CWMA Python API on module `m`:
/// - the `cwma` and `cwma_batch` functions;
/// - the `CwmaStream` class;
/// - the CUDA device functions, when the `cuda` feature is enabled.
#[cfg(feature = "python")]
pub fn register_cwma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(cwma_py, m)?)?;
    m.add_function(wrap_pyfunction!(cwma_batch_py, m)?)?;
    m.add_class::<CwmaStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        register_cwma_cuda_functions(m)?;
    }

    Ok(())
}

/// Adds the GPU-backed CWMA entry points to `m` (built only with `python` + `cuda`).
#[cfg(all(feature = "python", feature = "cuda"))]
fn register_cwma_cuda_functions(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(cwma_cuda_batch_dev_py, m)?)?;
    m.add_function(wrap_pyfunction!(cwma_cuda_many_series_one_param_dev_py, m)?)?;
    Ok(())
}
2068
2069#[cfg(test)]
2070mod tests {
2071 use super::*;
2072 use crate::skip_if_unsupported;
2073 use crate::utilities::data_loader::read_candles_from_csv;
2074 #[cfg(feature = "proptest")]
2075 use proptest::prelude::*;
2076
2077 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2078 #[test]
2079 fn test_cwma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
2080 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2081 let candles = read_candles_from_csv(file_path)?;
2082
2083 let input = CwmaInput::with_default_candles(&candles);
2084
2085 let baseline = cwma(&input)?.values;
2086
2087 let mut out = vec![0.0; candles.close.len()];
2088 cwma_into(&input, &mut out)?;
2089
2090 assert_eq!(baseline.len(), out.len());
2091
2092 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2093 (a.is_nan() && b.is_nan()) || (a == b)
2094 }
2095
2096 for (i, (a, b)) in baseline.iter().zip(out.iter()).enumerate() {
2097 assert!(
2098 eq_or_both_nan(*a, *b),
2099 "mismatch at index {}: baseline={}, into={}",
2100 i,
2101 a,
2102 b
2103 );
2104 }
2105
2106 Ok(())
2107 }
2108
2109 fn check_cwma_partial_params(
2110 test_name: &str,
2111 kernel: Kernel,
2112 ) -> Result<(), Box<dyn std::error::Error>> {
2113 skip_if_unsupported!(kernel, test_name);
2114 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2115 let candles = read_candles_from_csv(file_path)?;
2116
2117 let default_params = CwmaParams { period: None };
2118 let input_def = CwmaInput::from_candles(&candles, "close", default_params);
2119 let output_def = cwma_with_kernel(&input_def, kernel)?;
2120 assert_eq!(output_def.values.len(), candles.close.len());
2121
2122 let params_14 = CwmaParams { period: Some(14) };
2123 let input_14 = CwmaInput::from_candles(&candles, "hl2", params_14);
2124 let output_14 = cwma_with_kernel(&input_14, kernel)?;
2125 assert_eq!(output_14.values.len(), candles.close.len());
2126
2127 let params_custom = CwmaParams { period: Some(20) };
2128 let input_custom = CwmaInput::from_candles(&candles, "hlc3", params_custom);
2129 let output_custom = cwma_with_kernel(&input_custom, kernel)?;
2130 assert_eq!(output_custom.values.len(), candles.close.len());
2131
2132 Ok(())
2133 }
2134
2135 fn check_cwma_accuracy(
2136 test_name: &str,
2137 kernel: Kernel,
2138 ) -> Result<(), Box<dyn std::error::Error>> {
2139 skip_if_unsupported!(kernel, test_name);
2140 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2141 let candles = read_candles_from_csv(file_path)?;
2142
2143 let input = CwmaInput::with_default_candles(&candles);
2144 let result = cwma_with_kernel(&input, kernel)?;
2145 assert_eq!(result.values.len(), candles.close.len());
2146
2147 let expected_last_five = [
2148 59224.641237300435,
2149 59213.64831277214,
2150 59171.21190130624,
2151 59167.01279027576,
2152 59039.413552249636,
2153 ];
2154
2155 let start = result.values.len().saturating_sub(5);
2156 for (i, &val) in result.values[start..].iter().enumerate() {
2157 let diff = (val - expected_last_five[i]).abs();
2158 assert!(
2159 diff < 1e-9,
2160 "[{}] CWMA {:?} mismatch at idx {}: got {}, expected {}",
2161 test_name,
2162 kernel,
2163 i,
2164 val,
2165 expected_last_five[i]
2166 );
2167 }
2168 Ok(())
2169 }
2170
2171 fn check_cwma_default_candles(
2172 test_name: &str,
2173 kernel: Kernel,
2174 ) -> Result<(), Box<dyn std::error::Error>> {
2175 skip_if_unsupported!(kernel, test_name);
2176 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2177 let candles = read_candles_from_csv(file_path)?;
2178
2179 let input = CwmaInput::with_default_candles(&candles);
2180 match input.data {
2181 CwmaData::Candles { source, .. } => assert_eq!(source, "close"),
2182 _ => panic!("Expected CwmaData::Candles"),
2183 }
2184 let output = cwma_with_kernel(&input, kernel)?;
2185 assert_eq!(output.values.len(), candles.close.len());
2186
2187 Ok(())
2188 }
2189
2190 fn check_cwma_zero_period(
2191 test_name: &str,
2192 kernel: Kernel,
2193 ) -> Result<(), Box<dyn std::error::Error>> {
2194 skip_if_unsupported!(kernel, test_name);
2195 let input_data = [10.0, 20.0, 30.0];
2196 let params = CwmaParams { period: Some(0) };
2197 let input = CwmaInput::from_slice(&input_data, params);
2198 let res = cwma_with_kernel(&input, kernel);
2199 assert!(res.is_err(), "[{}] Should fail with zero period", test_name);
2200 Ok(())
2201 }
2202
2203 fn check_cwma_period_exceeds_length(
2204 test_name: &str,
2205 kernel: Kernel,
2206 ) -> Result<(), Box<dyn std::error::Error>> {
2207 skip_if_unsupported!(kernel, test_name);
2208 let data_small = [10.0, 20.0, 30.0];
2209 let params = CwmaParams { period: Some(10) };
2210 let input = CwmaInput::from_slice(&data_small, params);
2211 let res = cwma_with_kernel(&input, kernel);
2212 assert!(
2213 res.is_err(),
2214 "[{}] Should fail with period exceeding length",
2215 test_name
2216 );
2217 Ok(())
2218 }
2219
2220 fn check_cwma_very_small_dataset(
2221 test_name: &str,
2222 kernel: Kernel,
2223 ) -> Result<(), Box<dyn std::error::Error>> {
2224 skip_if_unsupported!(kernel, test_name);
2225 let single_point = [42.0];
2226 let params = CwmaParams { period: Some(9) };
2227 let input = CwmaInput::from_slice(&single_point, params);
2228 let res = cwma_with_kernel(&input, kernel);
2229 assert!(
2230 res.is_err(),
2231 "[{}] Should fail with insufficient data",
2232 test_name
2233 );
2234 Ok(())
2235 }
2236
2237 fn check_cwma_empty_input(
2238 test_name: &str,
2239 kernel: Kernel,
2240 ) -> Result<(), Box<dyn std::error::Error>> {
2241 skip_if_unsupported!(kernel, test_name);
2242 let empty: [f64; 0] = [];
2243 let input = CwmaInput::from_slice(&empty, CwmaParams::default());
2244 let res = cwma_with_kernel(&input, kernel);
2245 assert!(
2246 matches!(res, Err(CwmaError::EmptyInputData)),
2247 "[{}] Should fail with empty input",
2248 test_name
2249 );
2250 Ok(())
2251 }
2252
2253 fn check_cwma_reinput(
2254 test_name: &str,
2255 kernel: Kernel,
2256 ) -> Result<(), Box<dyn std::error::Error>> {
2257 skip_if_unsupported!(kernel, test_name);
2258 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2259 let candles = read_candles_from_csv(file_path)?;
2260
2261 let first_params = CwmaParams { period: Some(80) };
2262 let first_input = CwmaInput::from_candles(&candles, "close", first_params);
2263 let first_result = cwma_with_kernel(&first_input, kernel)?;
2264
2265 let second_params = CwmaParams { period: Some(60) };
2266 let second_input = CwmaInput::from_slice(&first_result.values, second_params);
2267 let second_result = cwma_with_kernel(&second_input, kernel)?;
2268 assert_eq!(second_result.values.len(), first_result.values.len());
2269
2270 if second_result.values.len() > 240 {
2271 for i in 240..second_result.values.len() {
2272 assert!(
2273 !second_result.values[i].is_nan(),
2274 "[{}] Found unexpected NaN at index {}",
2275 test_name,
2276 i
2277 );
2278 }
2279 }
2280 Ok(())
2281 }
2282
2283 fn check_cwma_nan_handling(
2284 test_name: &str,
2285 kernel: Kernel,
2286 ) -> Result<(), Box<dyn std::error::Error>> {
2287 skip_if_unsupported!(kernel, test_name);
2288 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2289 let candles = read_candles_from_csv(file_path)?;
2290
2291 let input = CwmaInput::from_candles(&candles, "close", CwmaParams { period: Some(9) });
2292 let res = cwma_with_kernel(&input, kernel)?;
2293 assert_eq!(res.values.len(), candles.close.len());
2294 if res.values.len() > 240 {
2295 for (i, &val) in res.values[240..].iter().enumerate() {
2296 assert!(
2297 !val.is_nan(),
2298 "[{}] Found unexpected NaN at out-index {}",
2299 test_name,
2300 240 + i
2301 );
2302 }
2303 }
2304 Ok(())
2305 }
2306
2307 fn check_cwma_streaming(
2308 test_name: &str,
2309 kernel: Kernel,
2310 ) -> Result<(), Box<dyn std::error::Error>> {
2311 skip_if_unsupported!(kernel, test_name);
2312 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2313 let candles = read_candles_from_csv(file_path)?;
2314
2315 let period = 9;
2316 let input = CwmaInput::from_candles(
2317 &candles,
2318 "close",
2319 CwmaParams {
2320 period: Some(period),
2321 },
2322 );
2323 let batch_output = cwma_with_kernel(&input, kernel)?.values;
2324
2325 let mut stream = CwmaStream::try_new(CwmaParams {
2326 period: Some(period),
2327 })?;
2328 let mut stream_values = Vec::with_capacity(candles.close.len());
2329 for &price in &candles.close {
2330 match stream.update(price) {
2331 Some(val) => stream_values.push(val),
2332 None => stream_values.push(f64::NAN),
2333 }
2334 }
2335
2336 assert_eq!(batch_output.len(), stream_values.len());
2337 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2338 if b.is_nan() && s.is_nan() {
2339 continue;
2340 }
2341 let diff = (b - s).abs();
2342 assert!(
2343 diff < 1e-9,
2344 "[{}] CWMA streaming mismatch at idx {}: batch={}, stream={}, diff={}",
2345 test_name,
2346 i,
2347 b,
2348 s,
2349 diff
2350 );
2351 }
2352 Ok(())
2353 }
2354
2355 #[cfg(debug_assertions)]
2356 fn check_cwma_no_poison(
2357 test_name: &str,
2358 kernel: Kernel,
2359 ) -> Result<(), Box<dyn std::error::Error>> {
2360 skip_if_unsupported!(kernel, test_name);
2361
2362 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2363 let candles = read_candles_from_csv(file_path)?;
2364
2365 let test_params = vec![
2366 CwmaParams::default(),
2367 CwmaParams { period: Some(2) },
2368 CwmaParams { period: Some(3) },
2369 CwmaParams { period: Some(5) },
2370 CwmaParams { period: Some(7) },
2371 CwmaParams { period: Some(10) },
2372 CwmaParams { period: Some(14) },
2373 CwmaParams { period: Some(20) },
2374 CwmaParams { period: Some(30) },
2375 CwmaParams { period: Some(50) },
2376 CwmaParams { period: Some(100) },
2377 CwmaParams { period: Some(200) },
2378 CwmaParams { period: Some(250) },
2379 ];
2380
2381 for (param_idx, params) in test_params.iter().enumerate() {
2382 let input = CwmaInput::from_candles(&candles, "close", params.clone());
2383 let output = cwma_with_kernel(&input, kernel)?;
2384
2385 for (i, &val) in output.values.iter().enumerate() {
2386 if val.is_nan() {
2387 continue;
2388 }
2389
2390 let bits = val.to_bits();
2391
2392 if bits == 0x11111111_11111111 {
2393 panic!(
2394 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2395 with params: period={}",
2396 test_name,
2397 val,
2398 bits,
2399 i,
2400 params.period.unwrap_or(14)
2401 );
2402 }
2403
2404 if bits == 0x22222222_22222222 {
2405 panic!(
2406 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2407 with params: period={}",
2408 test_name,
2409 val,
2410 bits,
2411 i,
2412 params.period.unwrap_or(14)
2413 );
2414 }
2415
2416 if bits == 0x33333333_33333333 {
2417 panic!(
2418 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2419 with params: period={}",
2420 test_name,
2421 val,
2422 bits,
2423 i,
2424 params.period.unwrap_or(14)
2425 );
2426 }
2427 }
2428 }
2429
2430 Ok(())
2431 }
2432
2433 #[cfg(not(debug_assertions))]
2434 fn check_cwma_no_poison(
2435 _test_name: &str,
2436 _kernel: Kernel,
2437 ) -> Result<(), Box<dyn std::error::Error>> {
2438 Ok(())
2439 }
2440
    /// Property-based test (enabled by the `proptest` feature) for CWMA.
    ///
    /// For random finite series (length `period..400`), periods in `2..=32` and
    /// random affine coefficients `a != 0`, `b`, it checks:
    /// - the kernel under test and `Kernel::Scalar` either both fail with the
    ///   same error variant, or both succeed and agree (abs 1e-9 or <= 4 ULP);
    /// - every output lies within the [min, max] of its input window (+/-1e-9);
    /// - a constant window reproduces that constant;
    /// - over a non-decreasing prefix, the output is non-decreasing;
    /// - `cwma` on `a * x + b` equals `a * y + b` (abs/rel 1e-9 or <= 8 ULP);
    /// - `CwmaStream` matches the batch output;
    /// - the warm-up prefix `[0, first + period - 1)` is all NaN.
    ///
    /// After the randomized run, fixed inputs (empty, all-NaN, period longer
    /// than the data, period 0) must all return an error from `cwma`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cwma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Draw the period first, so the data length (>= period) and the affine
        // coefficients can be generated relative to it.
        let strat = (2usize..=32).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                // `a` must be non-zero so the affine map is invertible.
                (-1e3f64..1e3f64).prop_filter("finite a", |a| a.is_finite() && *a != 0.0),
                -1e3f64..1e3f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, a, b)| {
                let params = CwmaParams {
                    period: Some(period),
                };
                let input = CwmaInput::from_slice(&data, params.clone());

                // Kernel under test vs. the scalar reference implementation.
                let fast = cwma_with_kernel(&input, kernel);
                let slow = cwma_with_kernel(&input, Kernel::Scalar);

                match (fast, slow) {
                    // Both failing with the same error variant is acceptable.
                    (Err(e1), Err(e2))
                        if std::mem::discriminant(&e1) == std::mem::discriminant(&e2) =>
                    {
                        return Ok(());
                    }

                    (Err(e1), Err(e2)) => {
                        prop_assert!(false, "different errors: fast={:?} slow={:?}", e1, e2)
                    }

                    (Err(e1), Ok(_)) => {
                        prop_assert!(false, "fast errored {e1:?} but scalar succeeded")
                    }
                    (Ok(_), Err(e2)) => {
                        prop_assert!(false, "scalar errored {e2:?} but fast succeeded")
                    }

                    (Ok(fast), Ok(reference)) => {
                        let CwmaOutput { values: out } = fast;
                        let CwmaOutput { values: rref } = reference;

                        // Streaming output; `None` (not ready yet) maps to NaN.
                        let mut stream = CwmaStream::try_new(params.clone()).unwrap();
                        let mut s_out = Vec::with_capacity(data.len());
                        for &v in &data {
                            s_out.push(stream.update(v).unwrap_or(f64::NAN));
                        }

                        // Output on the affinely transformed series, computed via
                        // `cwma` rather than the kernel under test.
                        let transformed: Vec<f64> = data.iter().map(|x| a * x + b).collect();
                        let t_out = cwma(&CwmaInput::from_slice(&transformed, params))?.values;

                        for i in (period - 1)..data.len() {
                            // Input window ending at `i` and its value range.
                            let w = &data[i + 1 - period..=i];
                            let (lo, hi) = w
                                .iter()
                                .fold((f64::INFINITY, f64::NEG_INFINITY), |(l, h), &v| {
                                    (l.min(v), h.max(v))
                                });
                            let y = out[i];
                            let yr = rref[i];
                            let ys = s_out[i];
                            let yt = t_out[i];

                            prop_assert!(
                                y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                                "idx {i}: {y} ∉ [{lo}, {hi}]"
                            );

                            // NOTE(review): unreachable, because the strategy draws
                            // `period` from 2..=32.
                            if period == 1 && y.is_finite() {
                                prop_assert!((y - data[i]).abs() <= f64::EPSILON);
                            }

                            if w.iter().all(|v| *v == w[0]) {
                                prop_assert!((y - w[0]).abs() <= 1e-9);
                            }

                            // NOTE(review): rescans the whole prefix on every
                            // iteration, which is O(n^2) per case (n < 400). A running
                            // flag would make it linear.
                            if data[..=i].windows(2).all(|p| p[0] <= p[1])
                                && y.is_finite()
                                && out[i - 1].is_finite()
                            {
                                prop_assert!(y >= out[i - 1] - 1e-12);
                            }

                            // Affine equivariance: CWMA(a*x + b) == a*CWMA(x) + b.
                            {
                                let expected = a * y + b;
                                let diff = (yt - expected).abs();
                                let tol_abs = 1e-9_f64;
                                let tol_rel = expected.abs() * 1e-9;
                                let ulp = yt.to_bits().abs_diff(expected.to_bits());

                                prop_assert!(
                                    diff <= tol_abs.max(tol_rel) || ulp <= 8,
                                    "idx {i}: affine mismatch diff={diff:e} ULP={ulp}"
                                );
                            }

                            let ulp = y.to_bits().abs_diff(yr.to_bits());
                            prop_assert!(
                                (y - yr).abs() <= 1e-9 || ulp <= 4,
                                "idx {i}: fast={y} ref={yr} ULP={ulp}"
                            );

                            prop_assert!(
                                (y - ys).abs() <= 1e-9 || (y.is_nan() && ys.is_nan()),
                                "idx {i}: stream mismatch"
                            );
                        }

                        // Warm-up prefix must be NaN. The generated data is finite,
                        // so `first` is 0 here.
                        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());
                        let warm = first + period - 1;
                        prop_assert!(out[..warm].iter().all(|v| v.is_nan()));
                    }
                }

                Ok(())
            })
            .unwrap();

        // Deterministic edge cases that must be rejected.
        assert!(cwma(&CwmaInput::from_slice(&[], CwmaParams::default())).is_err());
        assert!(cwma(&CwmaInput::from_slice(
            &[f64::NAN; 12],
            CwmaParams::default()
        ))
        .is_err());
        assert!(cwma(&CwmaInput::from_slice(
            &[1.0; 5],
            CwmaParams { period: Some(8) }
        ))
        .is_err());
        assert!(cwma(&CwmaInput::from_slice(
            &[1.0; 5],
            CwmaParams { period: Some(0) }
        ))
        .is_err());

        Ok(())
    }
2589
2590 macro_rules! generate_all_cwma_tests {
2591 ($($test_fn:ident),*) => {
2592 paste::paste! {
2593 $(
2594 #[test]
2595 fn [<$test_fn _scalar_f64>]() {
2596 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2597 }
2598 )*
2599 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2600 $(
2601 #[test]
2602 fn [<$test_fn _avx2_f64>]() {
2603 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2604 }
2605 #[test]
2606 fn [<$test_fn _avx512_f64>]() {
2607 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2608 }
2609 )*
2610 }
2611 }
2612 }
2613
    // Instantiate the per-kernel `#[test]` wrappers for every single-series
    // CWMA check defined above. The macro's pattern is `$($test_fn:ident),*`,
    // so the list must not end with a trailing comma.
    generate_all_cwma_tests!(
        check_cwma_partial_params,
        check_cwma_accuracy,
        check_cwma_default_candles,
        check_cwma_zero_period,
        check_cwma_period_exceeds_length,
        check_cwma_very_small_dataset,
        check_cwma_empty_input,
        check_cwma_reinput,
        check_cwma_nan_handling,
        check_cwma_streaming,
        check_cwma_no_poison
    );

    // `check_cwma_property` only exists with the `proptest` feature, so its
    // wrappers are gated the same way.
    #[cfg(feature = "proptest")]
    generate_all_cwma_tests!(check_cwma_property);
2630
2631 fn check_batch_default_row(
2632 test: &str,
2633 kernel: Kernel,
2634 ) -> Result<(), Box<dyn std::error::Error>> {
2635 skip_if_unsupported!(kernel, test);
2636
2637 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2638 let c = read_candles_from_csv(file)?;
2639
2640 let output = CwmaBatchBuilder::new()
2641 .kernel(kernel)
2642 .apply_candles(&c, "close")?;
2643
2644 let def = CwmaParams::default();
2645 let row = output.values_for(&def).expect("default row missing");
2646
2647 assert_eq!(row.len(), c.close.len());
2648
2649 let expected = [
2650 59224.641237300435,
2651 59213.64831277214,
2652 59171.21190130624,
2653 59167.01279027576,
2654 59039.413552249636,
2655 ];
2656 let start = row.len() - 5;
2657 for (i, &v) in row[start..].iter().enumerate() {
2658 assert!(
2659 (v - expected[i]).abs() < 1e-8,
2660 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2661 );
2662 }
2663 Ok(())
2664 }
2665
2666 #[cfg(debug_assertions)]
2667 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2668 skip_if_unsupported!(kernel, test);
2669
2670 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2671 let c = read_candles_from_csv(file)?;
2672
2673 let test_configs = vec![
2674 (2, 5, 1),
2675 (5, 25, 5),
2676 (10, 50, 10),
2677 (2, 4, 1),
2678 (50, 150, 25),
2679 (9, 21, 2),
2680 (9, 21, 4),
2681 (100, 300, 50),
2682 ];
2683
2684 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
2685 let output = CwmaBatchBuilder::new()
2686 .kernel(kernel)
2687 .period_range(p_start, p_end, p_step)
2688 .apply_candles(&c, "close")?;
2689
2690 for (idx, &val) in output.values.iter().enumerate() {
2691 if val.is_nan() {
2692 continue;
2693 }
2694
2695 let bits = val.to_bits();
2696 let row = idx / output.cols;
2697 let col = idx % output.cols;
2698 let combo = &output.combos[row];
2699
2700 if bits == 0x11111111_11111111 {
2701 panic!(
2702 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2703 at row {} col {} (flat index {}) with params: period={}",
2704 test,
2705 cfg_idx,
2706 val,
2707 bits,
2708 row,
2709 col,
2710 idx,
2711 combo.period.unwrap_or(14)
2712 );
2713 }
2714
2715 if bits == 0x22222222_22222222 {
2716 panic!(
2717 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2718 at row {} col {} (flat index {}) with params: period={}",
2719 test,
2720 cfg_idx,
2721 val,
2722 bits,
2723 row,
2724 col,
2725 idx,
2726 combo.period.unwrap_or(14)
2727 );
2728 }
2729
2730 if bits == 0x33333333_33333333 {
2731 panic!(
2732 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2733 at row {} col {} (flat index {}) with params: period={}",
2734 test,
2735 cfg_idx,
2736 val,
2737 bits,
2738 row,
2739 col,
2740 idx,
2741 combo.period.unwrap_or(14)
2742 );
2743 }
2744 }
2745 }
2746
2747 Ok(())
2748 }
2749
2750 #[cfg(not(debug_assertions))]
2751 fn check_batch_no_poison(
2752 _test: &str,
2753 _kernel: Kernel,
2754 ) -> Result<(), Box<dyn std::error::Error>> {
2755 Ok(())
2756 }
2757
2758 macro_rules! gen_batch_tests {
2759 ($fn_name:ident) => {
2760 paste::paste! {
2761 #[test] fn [<$fn_name _scalar>]() {
2762 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2763 }
2764 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2765 #[test] fn [<$fn_name _avx2>]() {
2766 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2767 }
2768 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2769 #[test] fn [<$fn_name _avx512>]() {
2770 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2771 }
2772 #[test] fn [<$fn_name _auto_detect>]() {
2773 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2774 }
2775 }
2776 };
2777 }
    // Instantiate the batch-kernel `#[test]` wrappers for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2780}