1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(not(target_arch = "wasm32"))]
13use rayon::prelude::*;
14use std::convert::AsRef;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::{CudaSma, DeviceArrayF32};
20#[cfg(all(feature = "python", feature = "cuda"))]
21use cust::context::Context;
22#[cfg(all(feature = "python", feature = "cuda"))]
23use cust::memory::DeviceBuffer;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
26#[cfg(feature = "python")]
27use pyo3::exceptions::PyValueError;
28#[cfg(feature = "python")]
29use pyo3::prelude::*;
30#[cfg(feature = "python")]
31use pyo3::types::PyDict;
32#[cfg(all(feature = "python", feature = "cuda"))]
33use std::sync::Arc;
34
35#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
36use serde::{Deserialize, Serialize};
37#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
38use wasm_bindgen::prelude::*;
39
40impl<'a> AsRef<[f64]> for SmaInput<'a> {
41 #[inline(always)]
42 fn as_ref(&self) -> &[f64] {
43 match &self.data {
44 SmaData::Slice(slice) => slice,
45 SmaData::Candles { candles, source } => source_type(candles, source),
46 }
47 }
48}
49
/// Source of the price series an SMA is computed over.
#[derive(Debug, Clone)]
pub enum SmaData<'a> {
    /// Read a named source series (e.g. "close") out of a candle set.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// Use a raw price slice directly.
    Slice(&'a [f64]),
}
58
/// Result of an SMA computation; same length as the input, with the
/// warm-up prefix filled with NaN.
#[derive(Debug, Clone)]
pub struct SmaOutput {
    pub values: Vec<f64>,
}
63
/// SMA parameters. `period: None` means "use the default" (9).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SmaParams {
    pub period: Option<usize>,
}
72
73impl Default for SmaParams {
74 fn default() -> Self {
75 Self { period: Some(9) }
76 }
77}
78
/// Input bundle for a single SMA run: the data source plus parameters.
#[derive(Debug, Clone)]
pub struct SmaInput<'a> {
    pub data: SmaData<'a>,
    pub params: SmaParams,
}
84
85impl<'a> SmaInput<'a> {
86 #[inline]
87 pub fn from_candles(c: &'a Candles, s: &'a str, p: SmaParams) -> Self {
88 Self {
89 data: SmaData::Candles {
90 candles: c,
91 source: s,
92 },
93 params: p,
94 }
95 }
96 #[inline]
97 pub fn from_slice(sl: &'a [f64], p: SmaParams) -> Self {
98 Self {
99 data: SmaData::Slice(sl),
100 params: p,
101 }
102 }
103 #[inline]
104 pub fn with_default_candles(c: &'a Candles) -> Self {
105 Self::from_candles(c, "close", SmaParams::default())
106 }
107 #[inline]
108 pub fn get_period(&self) -> usize {
109 self.params.period.unwrap_or(9)
110 }
111}
112
/// Fluent builder for a single SMA run.
#[derive(Copy, Clone, Debug)]
pub struct SmaBuilder {
    // None = use the default period.
    period: Option<usize>,
    // Which compute kernel to dispatch to; Auto picks at runtime.
    kernel: Kernel,
}
118
119impl Default for SmaBuilder {
120 fn default() -> Self {
121 Self {
122 period: None,
123 kernel: Kernel::Auto,
124 }
125 }
126}
127
128impl SmaBuilder {
129 #[inline(always)]
130 pub fn new() -> Self {
131 Self::default()
132 }
133 #[inline(always)]
134 pub fn period(mut self, n: usize) -> Self {
135 self.period = Some(n);
136 self
137 }
138 #[inline(always)]
139 pub fn kernel(mut self, k: Kernel) -> Self {
140 self.kernel = k;
141 self
142 }
143 #[inline(always)]
144 pub fn apply(self, c: &Candles) -> Result<SmaOutput, SmaError> {
145 let p = SmaParams {
146 period: self.period,
147 };
148 let i = SmaInput::from_candles(c, "close", p);
149 sma_with_kernel(&i, self.kernel)
150 }
151 #[inline(always)]
152 pub fn apply_slice(self, d: &[f64]) -> Result<SmaOutput, SmaError> {
153 let p = SmaParams {
154 period: self.period,
155 };
156 let i = SmaInput::from_slice(d, p);
157 sma_with_kernel(&i, self.kernel)
158 }
159 #[inline(always)]
160 pub fn into_stream(self) -> Result<SmaStream, SmaError> {
161 let p = SmaParams {
162 period: self.period,
163 };
164 SmaStream::try_new(p)
165 }
166}
167
/// Errors produced by the SMA functions in this module.
#[derive(Debug, Error)]
pub enum SmaError {
    /// Input slice/series had length zero.
    #[error("sma: Empty input data.")]
    EmptyInputData,
    /// Period was zero or longer than the data.
    #[error("sma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `period` non-NaN values after the leading NaN prefix.
    #[error("sma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Every input value was NaN.
    #[error("sma: All values are NaN.")]
    AllValuesNaN,
    /// Caller-provided output buffer has the wrong length.
    #[error("sma: Output buffer size mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A (start, end, step) sweep axis expands to nothing or overflows.
    #[error("sma: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("sma: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
189
/// Computes the SMA of `input`, picking the best available kernel automatically.
///
/// # Errors
/// Propagates the validation errors of [`sma_with_kernel`].
#[inline]
pub fn sma(input: &SmaInput) -> Result<SmaOutput, SmaError> {
    sma_with_kernel(input, Kernel::Auto)
}
194
/// Computes the SMA of `input` with an explicitly chosen kernel.
///
/// Allocates a fresh output the same length as the input; the warm-up prefix
/// (`first_valid + period - 1` leading slots) is NaN-seeded by the helper.
///
/// # Errors
/// Propagates any validation error from `sma_prepare`.
pub fn sma_with_kernel(input: &SmaInput, kernel: Kernel) -> Result<SmaOutput, SmaError> {
    let (data, period, first, chosen) = sma_prepare(input, kernel)?;
    // Helper presumably NaN-fills the warm-up region; the kernel overwrites the rest.
    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
    sma_compute_into(data, period, first, chosen, &mut out);
    Ok(SmaOutput { values: out })
}
201
202#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
203#[inline]
204pub fn sma_into(input: &SmaInput, out: &mut [f64]) -> Result<(), SmaError> {
205 let (data, period, first, chosen) = sma_prepare(input, Kernel::Auto)?;
206
207 if out.len() != data.len() {
208 return Err(SmaError::OutputLengthMismatch {
209 expected: data.len(),
210 got: out.len(),
211 });
212 }
213
214 let warm = (first + period - 1).min(out.len());
215 for v in &mut out[..warm] {
216 *v = f64::from_bits(0x7ff8_0000_0000_0000);
217 }
218
219 sma_compute_into(data, period, first, chosen, out);
220 Ok(())
221}
222
223#[inline]
224pub fn sma_into_slice(dst: &mut [f64], input: &SmaInput, kern: Kernel) -> Result<(), SmaError> {
225 let (data, period, first, chosen) = sma_prepare(input, kern)?;
226
227 if dst.len() != data.len() {
228 return Err(SmaError::OutputLengthMismatch {
229 expected: data.len(),
230 got: dst.len(),
231 });
232 }
233
234 let warmup = first + period - 1;
235 for v in &mut dst[..warmup] {
236 *v = f64::NAN;
237 }
238
239 sma_compute_into(data, period, first, chosen, dst);
240
241 Ok(())
242}
243
244#[inline(always)]
245fn sma_prepare<'a>(
246 input: &'a SmaInput,
247 kernel: Kernel,
248) -> Result<(&'a [f64], usize, usize, Kernel), SmaError> {
249 let data: &[f64] = input.as_ref();
250 if data.is_empty() {
251 return Err(SmaError::EmptyInputData);
252 }
253
254 let period = input.get_period();
255 let len = data.len();
256 if period == 0 || period > len {
257 return Err(SmaError::InvalidPeriod {
258 period,
259 data_len: len,
260 });
261 }
262
263 let first = data
264 .iter()
265 .position(|x| !x.is_nan())
266 .ok_or(SmaError::AllValuesNaN)?;
267 if len - first < period {
268 return Err(SmaError::NotEnoughValidData {
269 needed: period,
270 valid: len - first,
271 });
272 }
273
274 let chosen = match kernel {
275 Kernel::Auto => detect_best_kernel(),
276 k => k,
277 };
278 Ok((data, period, first, chosen))
279}
280
/// Dispatches the already-validated SMA computation to the requested kernel.
/// `out` must be the same length as `data`; the warm-up prefix is assumed to
/// be initialized by the caller.
#[inline]
fn sma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
    unsafe {
        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => {
                sma_scalar(data, period, first, out);
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                // NOTE(review): this arm falls back to the scalar kernel even
                // though `sma_avx2` exists below — confirm whether intentional.
                sma_scalar(data, period, first, out);
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                sma_avx512(data, period, first, out);
            }
            // All remaining variants are unreachable after kernel resolution.
            _ => unreachable!(),
        }
    }
}
300
/// Scalar SMA kernel: rolling-sum sliding window.
///
/// Writes `out[first + period - 1 ..]`; earlier slots are left untouched.
///
/// # Safety
/// Caller must ensure `data.len() == out.len()` and `first + period <= data.len()`.
#[inline(always)]
pub unsafe fn sma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());
    let n = data.len();

    // period == 1: the SMA is the input itself.
    if period == 1 {
        out[first..n].copy_from_slice(&data[first..n]);
        return;
    }

    // Seed the window with the first `period` values (left-to-right, same
    // accumulation order as a manual loop).
    let mut acc: f64 = data[first..first + period].iter().sum();
    let inv = 1.0 / (period as f64);
    out[first + period - 1] = acc * inv;

    // Slide the window: add the incoming sample, drop the outgoing one.
    for idx in (first + period)..n {
        acc += data[idx] - data[idx - period];
        out[idx] = acc * inv;
    }
}
330
/// AVX2 SMA kernel: vectorized initial window sum, then a 4-lane vectorized
/// sliding window using an in-register prefix sum of the per-step deltas.
///
/// # Safety
/// Requires AVX2. Caller must ensure `data.len() == out.len()` and
/// `first + period <= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn sma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());

    let len = data.len();
    let dp = data.as_ptr();
    let op = out.as_mut_ptr();

    // period == 1: the SMA is the input itself.
    if period == 1 {
        let mut i = first;
        while i < len {
            *op.add(i) = *dp.add(i);
            i += 1;
        }
        return;
    }

    // Sum the first window 4 lanes at a time (p4 = period rounded down to a
    // multiple of 4), then reduce the 4 lanes to one scalar.
    let mut acc256 = _mm256_setzero_pd();
    let mut k = 0usize;
    let base = first;
    let p4 = period & !3;

    while k < p4 {
        let v = _mm256_loadu_pd(dp.add(base + k));
        acc256 = _mm256_add_pd(acc256, v);
        k += 4;
    }

    // Horizontal reduction: hadd pairs lanes, then low + high 128-bit halves.
    let hadd = _mm256_hadd_pd(acc256, acc256);
    let lo = _mm256_castpd256_pd128(hadd);
    let hi = _mm256_extractf128_pd(hadd, 1);
    let sum128 = _mm_add_sd(lo, hi);
    let mut sum = _mm_cvtsd_f64(sum128);

    // Scalar tail of the initial window (period not a multiple of 4).
    while k < period {
        sum += *dp.add(base + k);
        k += 1;
    }

    let inv = 1.0 / (period as f64);
    let inv_v = _mm256_set1_pd(inv);
    let mut warm = first + period - 1;
    *op.add(warm) = sum.mul_add(inv, 0.0);

    let mut i = warm + 1;
    let end = len;
    let stride = 4usize;

    // Main loop: 4 outputs per iteration. d = new - old per lane; an
    // in-register inclusive prefix sum of d turns the 4 deltas into 4
    // running-sum offsets, which are added to the scalar `sum`.
    while i + stride - 1 < end {
        let v_new = _mm256_loadu_pd(dp.add(i));
        let v_old = _mm256_loadu_pd(dp.add(i - period));
        let d = _mm256_sub_pd(v_new, v_old);

        let d_lo = _mm256_castpd256_pd128(d);
        let d_hi = _mm256_extractf128_pd(d, 1);

        // Prefix sum within each 128-bit half (shift-by-one-lane + add)...
        let t_lo = _mm_unpacklo_pd(_mm_setzero_pd(), d_lo);
        let p_lo = _mm_add_pd(d_lo, t_lo);

        let t_hi = _mm_unpacklo_pd(_mm_setzero_pd(), d_hi);
        let mut p_hi = _mm_add_pd(d_hi, t_hi);

        // ...then carry the low half's total into the high half.
        let carry = _mm_permute_pd(p_lo, 0b11);
        p_hi = _mm_add_pd(p_hi, carry);

        let mut prefix = _mm256_castpd128_pd256(p_lo);
        prefix = _mm256_insertf128_pd(prefix, p_hi, 1);

        let sum_v = _mm256_set1_pd(sum);
        let sums = _mm256_add_pd(sum_v, prefix);

        let out_v = _mm256_mul_pd(sums, inv_v);
        _mm256_storeu_pd(op.add(i), out_v);

        // The last lane holds the running sum after all 4 steps.
        let sums_hi = _mm256_extractf128_pd(sums, 1);
        let last = _mm_unpackhi_pd(sums_hi, sums_hi);
        sum = _mm_cvtsd_f64(last);

        i += stride;
    }

    // Scalar tail for the final < 4 outputs.
    while i < end {
        sum += *dp.add(i) - *dp.add(i - period);
        *op.add(i) = sum.mul_add(inv, 0.0);
        i += 1;
    }
}
422
/// AVX-512 SMA entry point: picks the short- or long-period variant.
/// (Both currently share one implementation; see `sma_avx512_short`.)
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    if period <= 32 {
        unsafe { sma_avx512_short(data, period, first, out) }
    } else {
        unsafe { sma_avx512_long(data, period, first, out) }
    }
}
432
/// Short-period AVX-512 SMA kernel.
///
/// NOTE(review): currently just delegates to `sma_avx512_long`; a dedicated
/// short-period implementation has not been provided.
///
/// # Safety
/// Same contract as `sma_avx512_long` (requires AVX-512F, valid lengths).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn sma_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    sma_avx512_long(data, period, first, out);
}
439
/// AVX-512 SMA kernel: vectorized initial window sum, then 8 outputs per
/// iteration via a log-step (1, 2, 4 lane shifts) in-register prefix sum of
/// the sliding-window deltas.
///
/// # Safety
/// Requires AVX-512F. Caller must ensure `data.len() == out.len()` and
/// `first + period <= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn sma_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    debug_assert!(period >= 1);
    debug_assert_eq!(data.len(), out.len());

    let len = data.len();
    let dp = data.as_ptr();
    let op = out.as_mut_ptr();

    // period == 1: the SMA is the input itself.
    if period == 1 {
        let mut i = first;
        while i < len {
            *op.add(i) = *dp.add(i);
            i += 1;
        }
        return;
    }

    // Sum the first window 8 lanes at a time (p8 = period rounded down to a
    // multiple of 8).
    let mut acc512 = _mm512_setzero_pd();
    let mut k = 0usize;
    let base = first;
    let p8 = period & !7;

    while k < p8 {
        let v = _mm512_loadu_pd(dp.add(base + k));
        acc512 = _mm512_add_pd(acc512, v);
        k += 8;
    }

    // Reduce 8 lanes -> 4 -> scalar.
    let acc_lo256 = _mm512_castpd512_pd256(acc512);
    let acc_hi256 = _mm512_extractf64x4_pd(acc512, 1);
    let acc256 = _mm256_add_pd(acc_lo256, acc_hi256);

    let hadd = _mm256_hadd_pd(acc256, acc256);
    let lo = _mm256_castpd256_pd128(hadd);
    let hi = _mm256_extractf128_pd(hadd, 1);
    let sum128 = _mm_add_sd(lo, hi);
    let mut sum = _mm_cvtsd_f64(sum128);

    // Scalar tail of the initial window (period not a multiple of 8).
    while k < period {
        sum += *dp.add(base + k);
        k += 1;
    }

    let inv = 1.0 / (period as f64);
    let inv_v = _mm512_set1_pd(inv);
    let warm = first + period - 1;
    *op.add(warm) = sum.mul_add(inv, 0.0);

    // Permutation indices that shift lanes up by 1, 2, and 4 positions
    // (masked-off low lanes become zero), used for the log-step prefix sum.
    let idx_sl1 = _mm512_set_epi64(6, 5, 4, 3, 2, 1, 0, 0);

    let idx_sl2 = _mm512_set_epi64(5, 4, 3, 2, 1, 0, 0, 0);

    let idx_sl4 = _mm512_set_epi64(3, 2, 1, 0, 0, 0, 0, 0);

    let mut i = warm + 1;
    let end = len;

    // Main loop: 8 outputs per iteration. d = new - old per lane; three
    // shift-and-add rounds produce the inclusive prefix sum of d.
    while i + 7 < end {
        let v_new = _mm512_loadu_pd(dp.add(i));
        let v_old = _mm512_loadu_pd(dp.add(i - period));
        let d = _mm512_sub_pd(v_new, v_old);

        let mut pref = d;
        let sh1 = _mm512_maskz_permutexvar_pd(0b1111_1110, idx_sl1, pref);
        pref = _mm512_add_pd(pref, sh1);

        let sh2 = _mm512_maskz_permutexvar_pd(0b1111_1100, idx_sl2, pref);
        pref = _mm512_add_pd(pref, sh2);

        let sh4 = _mm512_maskz_permutexvar_pd(0b1111_0000, idx_sl4, pref);
        pref = _mm512_add_pd(pref, sh4);

        let sums = _mm512_add_pd(_mm512_set1_pd(sum), pref);

        let out_v = _mm512_mul_pd(sums, inv_v);
        _mm512_storeu_pd(op.add(i), out_v);

        // The last lane holds the running sum after all 8 steps.
        let sums_hi256 = _mm512_extractf64x4_pd(sums, 1);
        let sums_hi128 = _mm256_extractf128_pd(sums_hi256, 1);
        let last = _mm_unpackhi_pd(sums_hi128, sums_hi128);
        sum = _mm_cvtsd_f64(last);

        i += 8;
    }

    // Scalar tail for the final < 8 outputs.
    while i < end {
        sum += *dp.add(i) - *dp.add(i - period);
        *op.add(i) = sum.mul_add(inv, 0.0);
        i += 1;
    }
}
535
/// Incremental (streaming) SMA over the last `period` samples.
#[derive(Debug, Clone)]
pub struct SmaStream {
    // Window length.
    period: usize,
    // Ring buffer holding the most recent `period` samples.
    buffer: Vec<f64>,
    // Next write position in the ring buffer.
    head: usize,
    // Running sum of the buffered samples.
    sum: f64,
    // Samples seen so far, saturating at `period` once the window is full.
    count: usize,
    // Precomputed 1 / period.
    inv: f64,

    // When `period` is a power of two, wrap `head` with a bitmask instead of
    // a compare-and-reset.
    use_mask: bool,
    mask: usize,
}
548
549impl SmaStream {
550 #[inline(always)]
551 pub fn try_new(params: SmaParams) -> Result<Self, SmaError> {
552 let period = params.period.unwrap_or(9);
553 if period == 0 {
554 return Err(SmaError::InvalidPeriod {
555 period,
556 data_len: 0,
557 });
558 }
559 let use_mask = period.is_power_of_two();
560 Ok(Self {
561 period,
562 buffer: vec![0.0; period],
563 head: 0,
564 sum: 0.0,
565 count: 0,
566 inv: (period as f64).recip(),
567 use_mask,
568 mask: period.wrapping_sub(1),
569 })
570 }
571
572 #[inline(always)]
573 fn advance_head(&mut self) {
574 if self.use_mask {
575 self.head = (self.head + 1) & self.mask;
576 } else {
577 let next = self.head + 1;
578 self.head = if next == self.period { 0 } else { next };
579 }
580 }
581
582 #[inline(always)]
583 pub fn update(&mut self, value: f64) -> Option<f64> {
584 if self.period == 1 {
585 self.sum = value;
586 self.buffer[0] = value;
587 self.count = 1;
588 return Some(value);
589 }
590
591 if self.count < self.period {
592 self.sum += value;
593 self.buffer[self.head] = value;
594 self.advance_head();
595 self.count += 1;
596 if self.count == self.period {
597 return Some(self.sum * self.inv);
598 }
599 return None;
600 }
601
602 let old = self.buffer[self.head];
603 self.sum += value - old;
604 self.buffer[self.head] = value;
605 self.advance_head();
606 Some(self.sum * self.inv)
607 }
608}
609
/// Parameter sweep for batch SMA: `(start, end, step)` over the period axis.
/// A step of 0 pins the axis to `start`.
#[derive(Clone, Debug)]
pub struct SmaBatchRange {
    pub period: (usize, usize, usize),
}
614
615impl Default for SmaBatchRange {
616 fn default() -> Self {
617 Self {
618 period: (9, 258, 1),
619 }
620 }
621}
622
/// Fluent builder for a batch (parameter-sweep) SMA run.
#[derive(Clone, Debug, Default)]
pub struct SmaBatchBuilder {
    range: SmaBatchRange,
    kernel: Kernel,
}
628
629impl SmaBatchBuilder {
630 pub fn new() -> Self {
631 Self::default()
632 }
633 pub fn kernel(mut self, k: Kernel) -> Self {
634 self.kernel = k;
635 self
636 }
637 #[inline]
638 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
639 self.range.period = (start, end, step);
640 self
641 }
642 #[inline]
643 pub fn period_static(mut self, p: usize) -> Self {
644 self.range.period = (p, p, 0);
645 self
646 }
647 pub fn apply_slice(self, data: &[f64]) -> Result<SmaBatchOutput, SmaError> {
648 sma_batch_with_kernel(data, &self.range, self.kernel)
649 }
650 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SmaBatchOutput, SmaError> {
651 SmaBatchBuilder::new().kernel(k).apply_slice(data)
652 }
653 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SmaBatchOutput, SmaError> {
654 let slice = source_type(c, src);
655 self.apply_slice(slice)
656 }
657 pub fn with_default_candles(c: &Candles) -> Result<SmaBatchOutput, SmaError> {
658 SmaBatchBuilder::new()
659 .kernel(Kernel::Auto)
660 .apply_candles(c, "close")
661 }
662}
663
664pub fn sma_batch_with_kernel(
665 data: &[f64],
666 sweep: &SmaBatchRange,
667 k: Kernel,
668) -> Result<SmaBatchOutput, SmaError> {
669 let kernel = match k {
670 Kernel::Auto => detect_best_batch_kernel(),
671 other if other.is_batch() => other,
672 other => return Err(SmaError::InvalidKernelForBatch(other)),
673 };
674 let simd = match kernel {
675 Kernel::Avx512Batch => Kernel::Avx512,
676 Kernel::Avx2Batch => Kernel::Avx2,
677 Kernel::ScalarBatch => Kernel::Scalar,
678 _ => unreachable!(),
679 };
680 sma_batch_par_slice(data, sweep, simd)
681}
682
/// Result of a batch sweep: a row-major `rows x cols` matrix (one row per
/// parameter combination, one column per input sample).
#[derive(Clone, Debug)]
pub struct SmaBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SmaParams>,
    pub rows: usize,
    pub cols: usize,
}
690impl SmaBatchOutput {
691 pub fn row_for_params(&self, p: &SmaParams) -> Option<usize> {
692 self.combos
693 .iter()
694 .position(|c| c.period.unwrap_or(9) == p.period.unwrap_or(9))
695 }
696 pub fn values_for(&self, p: &SmaParams) -> Option<&[f64]> {
697 self.row_for_params(p).map(|row| {
698 let start = row * self.cols;
699 &self.values[start..start + self.cols]
700 })
701 }
702}
703
704#[inline(always)]
705pub fn expand_grid_sma(r: &SmaBatchRange) -> Result<Vec<SmaParams>, SmaError> {
706 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SmaError> {
707 if step == 0 {
708 return Ok(vec![start]);
709 }
710 if start == end {
711 return Ok(vec![start]);
712 }
713 let mut vals = Vec::new();
714 if start < end {
715 let mut v = start;
716 while v <= end {
717 vals.push(v);
718 match v.checked_add(step) {
719 Some(next) => {
720 if next == v {
721 break;
722 }
723 v = next;
724 }
725 None => break,
726 }
727 }
728 } else {
729 let mut v = start;
730 while v >= end {
731 vals.push(v);
732 if v == 0 {
733 break;
734 }
735 let next = v.saturating_sub(step);
736 if next == v {
737 break;
738 }
739 v = next;
740 if v < end {
741 break;
742 }
743 }
744 }
745 if vals.is_empty() {
746 return Err(SmaError::InvalidRange { start, end, step });
747 }
748 Ok(vals)
749 }
750 let periods = axis_usize(r.period)?;
751 let mut out = Vec::with_capacity(periods.len());
752 for &p in &periods {
753 out.push(SmaParams { period: Some(p) });
754 }
755 Ok(out)
756}
757
/// Runs the batch sweep sequentially (single-threaded).
#[inline(always)]
pub fn sma_batch_slice(
    data: &[f64],
    sweep: &SmaBatchRange,
    kern: Kernel,
) -> Result<SmaBatchOutput, SmaError> {
    sma_batch_inner(data, sweep, kern, false)
}
766
/// Runs the batch sweep with one rayon task per row (sequential on wasm32).
#[inline(always)]
pub fn sma_batch_par_slice(
    data: &[f64],
    sweep: &SmaBatchRange,
    kern: Kernel,
) -> Result<SmaBatchOutput, SmaError> {
    sma_batch_inner(data, sweep, kern, true)
}
775
776#[inline(always)]
777fn sma_batch_inner(
778 data: &[f64],
779 sweep: &SmaBatchRange,
780 kern: Kernel,
781 parallel: bool,
782) -> Result<SmaBatchOutput, SmaError> {
783 let combos = expand_grid_sma(sweep)?;
784 if data.is_empty() {
785 return Err(SmaError::EmptyInputData);
786 }
787
788 let cols = data.len();
789 let first = data
790 .iter()
791 .position(|x| !x.is_nan())
792 .ok_or(SmaError::AllValuesNaN)?;
793 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
794 if cols - first < max_p {
795 return Err(SmaError::NotEnoughValidData {
796 needed: max_p,
797 valid: cols - first,
798 });
799 }
800
801 let rows = combos.len();
802
803 rows.checked_mul(cols).ok_or(SmaError::InvalidRange {
804 start: sweep.period.0,
805 end: sweep.period.1,
806 step: sweep.period.2,
807 })?;
808
809 let mut buf_mu = make_uninit_matrix(rows, cols);
810
811 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
812 let out_slice: &mut [f64] =
813 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
814
815 sma_batch_inner_into(data, sweep, kern, parallel, out_slice)?;
816
817 let values = unsafe {
818 Vec::from_raw_parts(
819 guard.as_mut_ptr() as *mut f64,
820 guard.len(),
821 guard.capacity(),
822 )
823 };
824
825 Ok(SmaBatchOutput {
826 values,
827 combos,
828 rows,
829 cols,
830 })
831}
832
/// Scalar row kernel over precomputed prefix sums `ps`:
/// `dst[i] = (ps[i] - ps[i - period]) * inv` for `i` in `i..cols`.
///
/// # Safety
/// Caller must guarantee `i >= period`, `cols <= ps.len()`, and that `dst`
/// points to at least `cols` writable elements.
#[inline(always)]
unsafe fn sma_batch_row_prefixsum_scalar(
    ps: &[f64],
    period: usize,
    mut i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    for idx in i..cols {
        let hi = *ps.get_unchecked(idx);
        let lo = *ps.get_unchecked(idx - period);
        *dst.add(idx) = (hi - lo) * inv;
    }
    i = cols;
    let _ = i;
}
849
/// AVX2 row kernel over precomputed prefix sums: 4 outputs per iteration,
/// scalar fallback for the tail.
///
/// # Safety
/// Requires AVX2. Same index/pointer contract as the scalar variant:
/// `i >= period`, `cols <= ps.len()`, `dst` valid for `cols` writes.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn sma_batch_row_prefixsum_avx2(
    ps: &[f64],
    period: usize,
    mut i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    use core::arch::x86_64::*;

    let inv_v = _mm256_set1_pd(inv);
    let ps_ptr = ps.as_ptr();
    let lanes = 4usize;

    while i + (lanes - 1) < cols {
        let hi = _mm256_loadu_pd(ps_ptr.add(i));
        let lo = _mm256_loadu_pd(ps_ptr.add(i - period));
        let diff = _mm256_sub_pd(hi, lo);
        let out_v = _mm256_mul_pd(diff, inv_v);
        _mm256_storeu_pd(dst.add(i), out_v);
        i += lanes;
    }

    // Remaining < 4 columns.
    sma_batch_row_prefixsum_scalar(ps, period, i, cols, inv, dst);
}
878
/// AVX-512 row kernel over precomputed prefix sums: 8 outputs per iteration,
/// scalar fallback for the tail.
///
/// # Safety
/// Requires AVX-512F. Same index/pointer contract as the scalar variant:
/// `i >= period`, `cols <= ps.len()`, `dst` valid for `cols` writes.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn sma_batch_row_prefixsum_avx512(
    ps: &[f64],
    period: usize,
    mut i: usize,
    cols: usize,
    inv: f64,
    dst: *mut f64,
) {
    use core::arch::x86_64::*;

    let inv_v = _mm512_set1_pd(inv);
    let ps_ptr = ps.as_ptr();
    let lanes = 8usize;

    while i + (lanes - 1) < cols {
        let hi = _mm512_loadu_pd(ps_ptr.add(i));
        let lo = _mm512_loadu_pd(ps_ptr.add(i - period));
        let diff = _mm512_sub_pd(hi, lo);
        let out_v = _mm512_mul_pd(diff, inv_v);
        _mm512_storeu_pd(dst.add(i), out_v);
        i += lanes;
    }

    // Remaining < 8 columns.
    sma_batch_row_prefixsum_scalar(ps, period, i, cols, inv, dst);
}
907
/// Core batch pipeline: validates, NaN-seeds each row's warm-up prefix,
/// builds the input's prefix-sum array once, then fills every row as
/// `(ps[i] - ps[i - period]) / period` — optionally with one rayon task per
/// row. Returns the expanded combos on success.
///
/// `out` must hold exactly `combos.len() * data.len()` elements; it may be
/// uninitialized, since every cell is written before being read.
#[inline(always)]
fn sma_batch_inner_into(
    data: &[f64],
    sweep: &SmaBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<SmaParams>, SmaError> {
    let combos = expand_grid_sma(sweep)?;
    if data.is_empty() {
        return Err(SmaError::EmptyInputData);
    }

    // Locate the first non-NaN sample; everything before it is warm-up.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(SmaError::AllValuesNaN)?;
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if data.len() - first < max_p {
        return Err(SmaError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }

    let rows = combos.len();
    let cols = data.len();
    // Overflow guard for the matrix size.
    rows.checked_mul(cols).ok_or(SmaError::InvalidRange {
        start: sweep.period.0,
        end: sweep.period.1,
        step: sweep.period.2,
    })?;

    // Resolve Auto, then map batch kernels onto their per-row SIMD kernels.
    let actual_kern = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let actual_kern = match actual_kern {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        other => other,
    };

    // Reinterpret the output as MaybeUninit so the prefix initializer can
    // work on (potentially) uninitialized storage.
    let out_uninit: &mut [MaybeUninit<f64>] = unsafe {
        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
    };

    // Per-row warm-up lengths; `init_matrix_prefixes` seeds those prefixes.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap_or(9) - 1)
        .collect();
    init_matrix_prefixes(out_uninit, cols, &warm);

    // One shared prefix-sum array over the valid region; entries before
    // `first` stay 0.0 so windows touching the boundary subtract zero.
    let mut ps = vec![0.0_f64; cols];
    if first < cols {
        ps[first] = data[first];
        for i in (first + 1)..cols {
            ps[i] = ps[i - 1] + data[i];
        }
    }

    // Fills one row: first full-window value, then the vectorized remainder.
    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();
        let warm = first + period - 1;

        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
        if warm >= cols {
            return;
        }
        let inv = (period as f64).recip();

        // First output: ps[warm] minus the prefix just before the window;
        // when the window starts at index 0 there is nothing to subtract.
        let s_hi = *ps.get_unchecked(warm);
        let s_lo = if warm >= period {
            *ps.get_unchecked(warm - period)
        } else {
            0.0
        };
        dst[warm] = (s_hi - s_lo) * inv;

        let mut i = warm + 1;
        if i >= cols {
            return;
        }

        let dst_ptr = dst.as_mut_ptr();
        match actual_kern {
            Kernel::Scalar => sma_batch_row_prefixsum_scalar(&ps, period, i, cols, inv, dst_ptr),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => sma_batch_row_prefixsum_avx2(&ps, period, i, cols, inv, dst_ptr),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => sma_batch_row_prefixsum_avx512(&ps, period, i, cols, inv, dst_ptr),
            _ => sma_batch_row_prefixsum_scalar(&ps, period, i, cols, inv, dst_ptr),
        }
    };

    if parallel {
        // Parallel path is unavailable on wasm32; fall back to sequential.
        #[cfg(not(target_arch = "wasm32"))]
        out_uninit
            .par_chunks_mut(cols)
            .enumerate()
            .for_each(|(row, slice)| do_row(row, slice));
        #[cfg(target_arch = "wasm32")]
        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    } else {
        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
1022
/// Row adapter: `sma_scalar` with the (first, period) arguments swapped.
///
/// # Safety
/// Same contract as `sma_scalar`.
#[inline(always)]
unsafe fn sma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    sma_scalar(data, period, first, out);
}
1027
/// Row adapter: `sma_avx2` with the (first, period) arguments swapped.
///
/// # Safety
/// Same contract as `sma_avx2` (requires AVX2).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    sma_avx2(data, period, first, out);
}
1033
/// Row adapter: dispatches to the short/long AVX-512 kernel by period.
///
/// # Safety
/// Same contract as the AVX-512 kernels (requires AVX-512F).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        sma_avx512_short(data, period, first, out);
    } else {
        sma_avx512_long(data, period, first, out);
    }
}
1043
/// Row adapter for the short-period AVX-512 kernel.
///
/// NOTE(review): unlike the other `sma_row_*` adapters, this one takes
/// (period, first) rather than (first, period) — confirm callers agree.
///
/// # Safety
/// Same contract as `sma_avx512_short` (requires AVX-512F).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sma_row_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    sma_avx512_short(data, period, first, out);
}
1049
/// Row adapter for the long-period AVX-512 kernel.
///
/// NOTE(review): takes (period, first), not (first, period) like most
/// `sma_row_*` adapters — confirm callers agree.
///
/// # Safety
/// Same contract as `sma_avx512_long` (requires AVX-512F).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn sma_row_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    sma_avx512_long(data, period, first, out);
}
1055
/// Python binding: computes the SMA of a 1-D float64 array.
///
/// Releases the GIL for the computation; non-contiguous inputs are copied
/// into a contiguous buffer first.
#[cfg(feature = "python")]
#[pyfunction(name = "sma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn sma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::IntoPyArray;

    let kern = validate_kernel(kernel, false)?;

    let params = SmaParams {
        period: Some(period),
    };

    // Shared path: run the indicator without the GIL, mapping Rust errors
    // onto Python ValueError.
    let compute = |slice: &[f64]| -> PyResult<Vec<f64>> {
        let input = SmaInput::from_slice(slice, params.clone());
        py.allow_threads(|| sma_with_kernel(&input, kern).map(|o| o.values))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    };

    let result_vec = match data.as_slice() {
        // Contiguous arrays are borrowed directly.
        Ok(slice) => compute(slice)?,
        // Non-contiguous arrays are first copied into an owned buffer.
        Err(_) => {
            let owned = data.as_array().to_owned();
            let slice = owned
                .as_slice()
                .expect("owned numpy array should be contiguous");
            compute(slice)?
        }
    };

    Ok(result_vec.into_pyarray(py))
}
1090
/// Python binding: sweeps SMA periods over `data` and returns a dict with
/// "values" (rows x cols float64 matrix) and "periods" (u64 array).
#[cfg(feature = "python")]
#[pyfunction(name = "sma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn sma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    let kern = validate_kernel(kernel, true)?;

    let data_slice = data.as_slice()?;
    let range = SmaBatchRange {
        period: period_range,
    };

    // Expand the sweep up-front so the output matrix can be sized.
    let combos = expand_grid_sma(&range).map_err(|e| PyValueError::new_err(e.to_string()))?;
    if data_slice.is_empty() {
        return Err(PyValueError::new_err("Empty data"));
    }

    let rows = combos.len();
    let cols = data_slice.len();

    let nelems = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is created uninitialized; `sma_batch_inner_into`
    // writes every element (warm-up prefixes included) before it is exposed.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [nelems], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // Release the GIL while the Rust kernels fill the matrix.
    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };

            sma_batch_inner_into(data_slice, &range, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap_or(9) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict.into())
}
1157
/// Python binding: runs the SMA period sweep on the GPU, returning a
/// device-resident f32 array plus a dict of sweep metadata ("periods").
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn sma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(SmaDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = SmaBatchRange {
        period: period_range,
    };

    // Run the CUDA work without the GIL; keep a context handle so the
    // returned device array outlives this call.
    let (inner, combos, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaSma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, combos) = cuda
            .sma_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, combos, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    let dict = PyDict::new(py);
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok((
        SmaDeviceArrayF32Py {
            inner,
            _ctx: ctx_arc,
            device_id: dev_id,
        },
        dict,
    ))
}
1203
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
/// Computes a single-period SMA over many series at once on the GPU.
///
/// `data_tm_f32` is a time-major 2-D array (`rows` = time steps, `cols` =
/// series); the result stays on the device and is returned wrapped in
/// `SmaDeviceArrayF32Py`. The kernel launch and synchronization run with the
/// GIL released.
///
/// # Errors
/// Returns a `ValueError` when CUDA is unavailable, the input is not
/// contiguous, or any device operation fails.
pub fn sma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<SmaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = SmaParams {
        period: Some(period),
    };

    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaSma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Fixed: the params argument was mojibake ("¶ms", an HTML-entity
        // garbling of "&params"), which does not compile.
        let dev = cuda
            .sma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(SmaDeviceArrayF32Py {
        inner,
        // Keeps the CUDA context alive while Python holds the buffer.
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
1243
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "SmaDeviceArrayF32", unsendable)]
/// Python-visible handle to a CUDA device allocation of f32 SMA results.
///
/// Exposes `__cuda_array_interface__` and `__dlpack__` so NumPy/CuPy/PyTorch
/// can consume the buffer without a host round-trip. Marked `unsendable`
/// because the CUDA context is bound to the creating thread.
pub struct SmaDeviceArrayF32Py {
    // Device-resident matrix (rows x cols) produced by the CUDA kernels.
    pub(crate) inner: DeviceArrayF32,
    // Keeps the CUDA context alive for as long as Python can reach the buffer.
    pub(crate) _ctx: Arc<Context>,
    // CUDA device ordinal, reported via __dlpack_device__.
    pub(crate) device_id: u32,
}
1251
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl SmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) dict describing the device buffer.
    ///
    /// Strides are in bytes and assume a C-contiguous (row-major) layout,
    /// which is what the batch kernels produce.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);

        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // "<f4": little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only=false) — consumers may write in place.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: (2, id) where 2 is kDLCUDA.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership.
    ///
    /// NOTE(review): this *steals* the allocation — `self.inner` is replaced
    /// with an empty 0x0 dummy, so the object is unusable for further exports
    /// after a successful `__dlpack__` call.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        // Reject requests for a different device; cross-device copy is not
        // implemented, regardless of the caller's `copy` preference.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // Stream synchronization is not handled here; the producer already
        // synchronized before returning the buffer to Python.
        let _ = stream;

        // Swap a zero-sized placeholder into `self` so the real buffer can be
        // moved into the DLPack capsule (which owns it from now on).
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d(
            py,
            buf,
            rows,
            cols,
            alloc_dev,
            max_version_bound,
        )
    }
}
1334
#[cfg(feature = "python")]
#[pyclass(name = "SmaStream")]
/// Python wrapper around the incremental (streaming) SMA calculator.
pub struct SmaStreamPy {
    // The underlying Rust streaming state (ring buffer + running sum — see SmaStream).
    inner: SmaStream,
}
1341
#[cfg(feature = "python")]
#[pymethods]
impl SmaStreamPy {
    /// Creates a streaming SMA with the given window length.
    ///
    /// Raises `ValueError` when the period is rejected by the Rust-side
    /// validation (e.g. zero).
    #[new]
    #[pyo3(signature = (period))]
    pub fn new(period: usize) -> PyResult<Self> {
        SmaStream::try_new(SmaParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one observation; returns the SMA once the window is full,
    /// `None` during warmup.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1359
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "sma")]
/// JS-facing SMA over a flat f64 slice; returns a full-length vector with
/// NaN warmup values, or a JS error string on invalid parameters.
pub fn sma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SmaInput::from_slice(
        data,
        SmaParams {
            period: Some(period),
        },
    );

    // Pre-size the output; sma_into_slice overwrites every element.
    let mut out = vec![0.0; data.len()];

    match sma_into_slice(&mut out, &input, Kernel::Auto) {
        Ok(()) => Ok(out),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1376
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// JS-deserializable configuration for a batch SMA sweep.
pub struct SmaBatchConfig {
    // (start, end, step) period range, expanded row-by-row on the Rust side.
    pub period_range: (usize, usize, usize),
}
1382
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// JS-serializable result of a batch SMA sweep.
pub struct SmaBatchJsOutput {
    // Row-major matrix flattened to length rows * cols.
    pub values: Vec<f64>,
    // Full parameter set for each row.
    pub combos: Vec<SmaParams>,
    // Convenience copy of each row's period (defaulting handled upstream).
    pub periods: Vec<usize>,
    // Number of parameter combinations (matrix rows).
    pub rows: usize,
    // Series length (matrix columns).
    pub cols: usize,
}
1392
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "sma_batch")]
/// Unified JS batch entry point: takes a config object, returns a serialized
/// `SmaBatchJsOutput` (values + combos + periods + shape).
pub fn sma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    // Decode the JS config object into the typed sweep description.
    let parsed: SmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(c) => c,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = SmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match sma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(o) => o,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Pull the per-row periods out before the combos are moved into the
    // output struct.
    let periods: Vec<usize> = batch
        .combos
        .iter()
        .map(|c| c.period.unwrap_or(9))
        .collect();

    let js_output = SmaBatchJsOutput {
        values: batch.values,
        periods,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&js_output).map_err(|e| JsValue::from_str(&e.to_string()))
}
1420
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatch")]
#[deprecated(since = "1.0.0", note = "Use sma_batch instead")]
/// Legacy JS batch entry point: returns only the flattened values matrix.
pub fn sma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match sma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1438
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatchMetadata")]
#[deprecated(since = "1.0.0", note = "Use sma_batch which returns metadata")]
/// Legacy helper: lists the period of each row a sweep would produce.
pub fn sma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    // An invalid range yields no combos (and therefore an empty list).
    expand_grid_sma(&sweep)
        .unwrap_or_default()
        .into_iter()
        .map(|c| c.period.unwrap_or(9))
        .collect()
}
1453
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smaBatchRowsCols")]
#[deprecated(since = "1.0.0", note = "Use sma_batch which returns rows and cols")]
/// Legacy helper: `[rows, cols]` shape a sweep would produce for `data_len`.
pub fn sma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = SmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let n_rows = expand_grid_sma(&sweep).unwrap_or_default().len();
    vec![n_rows, data_len]
}
1469
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an f64 buffer of `len` elements and hands the raw pointer to JS.
///
/// The contents are uninitialized; callers must write before reading. The
/// caller owns the allocation and must release it with `sma_free(ptr, len)`
/// using the same `len`.
pub fn sma_alloc(len: usize) -> *mut f64 {
    // Allocate capacity without initializing, then forget the Vec so the
    // allocation outlives this function and can be used from JavaScript.
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    std::mem::forget(vec);
    ptr
}
1478
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `sma_alloc(len)`.
///
/// NOTE(review): this rebuilds the Vec with capacity == `len`, which is only
/// sound if `Vec::with_capacity(len)` allocated exactly `len` elements; the
/// stdlib only guarantees "at least `len`" — TODO confirm this pairing.
pub fn sma_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: caller must pass a pointer obtained from sma_alloc(len),
        // exactly once; dropping the rebuilt Vec frees the allocation.
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
1488
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pointer-based SMA for the wasm FFI: reads `len` f64s from `in_ptr` and
/// writes `len` results to `out_ptr`. Supports `in_ptr == out_ptr` (in-place)
/// via a temporary buffer; partially-overlapping but non-identical ranges are
/// NOT detected and would alias.
pub fn sma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: caller guarantees both pointers are valid for `len` f64s
    // (typically buffers from sma_alloc) for the duration of this call.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let params = SmaParams {
            period: Some(period),
        };
        let input = SmaInput::from_slice(data, params);

        if in_ptr == out_ptr as *const f64 {
            // In-place request: compute into a scratch buffer first so the
            // kernel never reads data it has already overwritten.
            let mut temp = vec![0.0; len];
            sma_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            // Disjoint buffers: write directly into the output slice.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            sma_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1525
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Pointer-based batch SMA for the wasm FFI.
///
/// Expands the `(start, end, step)` period sweep, writes a row-major
/// `rows x len` matrix to `out_ptr`, and returns the number of rows. The
/// caller must have allocated `out_ptr` large enough for every combination
/// (use `smaBatchRowsCols` to size it beforehand).
pub fn sma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: caller guarantees in_ptr is valid for `len` reads and out_ptr
    // for `rows * len` writes, where rows matches the expanded sweep.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = SmaBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid_sma(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let total_size = rows * len;

        let out = std::slice::from_raw_parts_mut(out_ptr, total_size);

        // Map the batch kernel variant to its single-row SIMD counterpart;
        // any other value (e.g. Auto) passes through unchanged.
        let kernel = match detect_best_batch_kernel() {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            other => other,
        };

        sma_batch_inner_into(data, &sweep, kernel, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
1566
#[cfg(test)]
mod tests {
    //! Unit tests for the SMA kernels, the streaming updater and the batch
    //! (parameter-sweep) API. The `check_*` helpers are parameterized by
    //! `Kernel` and stamped out per kernel by the macros at the bottom.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    // The zero-allocation `sma_into` entry point must match the allocating
    // `sma_with_kernel` API element-for-element (NaN-aware comparison).
    // NOTE(review): under the wasm cfg the `sma_into` call is compiled out
    // and `out` stays all zeros, so the comparison below would fail there —
    // presumably these tests are never run for wasm; verify.
    #[test]
    fn test_sma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        // Synthetic series with a NaN warmup prefix to exercise NaN handling.
        let mut data = Vec::with_capacity(256);
        data.extend_from_slice(&[f64::NAN, f64::NAN, f64::NAN]);
        for i in 0..253u32 {
            let v = ((i % 17) as f64) * 1.2345 + (i as f64).sin() * 0.001;
            data.push(v);
        }

        let params = SmaParams::default();
        let input = SmaInput::from_slice(&data, params);

        let base = sma_with_kernel(&input, Kernel::Auto)?.values;

        let mut out = vec![0.0; data.len()];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            sma_into(&input, &mut out)?;
        }

        assert_eq!(base.len(), out.len());

        for (i, (a, b)) in base.iter().zip(out.iter()).enumerate() {
            // NaN == NaN is treated as a match during the warmup prefix.
            let ok = if a.is_nan() && b.is_nan() {
                true
            } else {
                (a - b).abs() <= 1e-12
            };
            assert!(ok, "Mismatch at index {}: base={} vs into={}", i, a, b);
        }
        Ok(())
    }
    // With `period: None` the default period applies and the output keeps
    // the input length.
    fn check_sma_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = SmaParams { period: None };
        let input = SmaInput::from_candles(&candles, "close", default_params);
        let output = sma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }
    // Spot-checks the last five SMA(9) values against known-good figures
    // from the reference dataset (tolerance 0.1).
    fn check_sma_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = SmaParams { period: Some(9) };
        let input = SmaInput::from_candles(&candles, "close", params);
        let result = sma_with_kernel(&input, kernel)?;
        let expected_last_five = [59180.8, 59175.0, 59129.4, 59085.4, 59133.7];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-1,
                "[{}] SMA {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }
    // `with_default_candles` must select the "close" source by default.
    fn check_sma_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = SmaInput::with_default_candles(&candles);
        match input.data {
            SmaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected SmaData::Candles"),
        }
        let output = sma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }
    // A zero period is invalid and must be rejected, not panic.
    fn check_sma_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = SmaParams { period: Some(0) };
        let input = SmaInput::from_slice(&input_data, params);
        let res = sma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SMA should fail with zero period",
            test_name
        );
        Ok(())
    }
    // A period longer than the series is invalid.
    fn check_sma_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = SmaParams { period: Some(10) };
        let input = SmaInput::from_slice(&data_small, params);
        let res = sma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SMA should fail with period exceeding length",
            test_name
        );
        Ok(())
    }
    // A single data point cannot fill a period-9 window.
    fn check_sma_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = SmaParams { period: Some(9) };
        let input = SmaInput::from_slice(&single_point, params);
        let res = sma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SMA should fail with insufficient data",
            test_name
        );
        Ok(())
    }
    // Feeding SMA output (which starts with NaNs) back in must still work
    // and preserve the length.
    fn check_sma_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = SmaParams { period: Some(14) };
        let first_input = SmaInput::from_candles(&candles, "close", first_params);
        let first_result = sma_with_kernel(&first_input, kernel)?;
        let second_params = SmaParams { period: Some(14) };
        let second_input = SmaInput::from_slice(&first_result.values, second_params);
        let second_result = sma_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }
    // After a generous warmup margin (index 240+) no NaNs should remain.
    fn check_sma_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = SmaInput::from_candles(&candles, "close", SmaParams { period: Some(9) });
        let res = sma_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }
    // The streaming updater must agree with the batch kernel to 1e-9.
    fn check_sma_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let period = 9;
        let input = SmaInput::from_candles(
            &candles,
            "close",
            SmaParams {
                period: Some(period),
            },
        );
        let batch_output = sma_with_kernel(&input, kernel)?.values;
        let mut stream = SmaStream::try_new(SmaParams {
            period: Some(period),
        })?;
        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            // The stream reports None during warmup; map it to NaN so the
            // two series align index-for-index.
            match stream.update(price) {
                Some(sma_val) => stream_values.push(sma_val),
                None => stream_values.push(f64::NAN),
            }
        }
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] SMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    // Debug builds fill uninitialized buffers with recognizable bit
    // patterns; none of them may survive into the final output.
    #[cfg(debug_assertions)]
    fn check_sma_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_periods = vec![5, 9, 14, 20, 30, 50];

        for period in test_periods {
            let params = SmaParams {
                period: Some(period),
            };
            let input = SmaInput::from_candles(&candles, "close", params);
            let output = sma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                // NaNs are legitimate warmup values, not poison.
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                // Each helper uses a distinct poison pattern so a leak can
                // be traced back to its allocation site.
                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} (period={})",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} (period={})",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} (period={})",
                        test_name, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds do not poison buffers, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_sma_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    // Property-based test: for random (data, period) pairs, verify warmup
    // NaNs, agreement with a brute-force window mean, min/max bounds,
    // constant- and linear-input properties, cross-kernel agreement with the
    // scalar reference, the sliding-window lag identity, and (debug only)
    // absence of poison values.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_sma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: pick a period in 1..=100, then a finite data vector at
        // least `period` long (so the call never fails on length).
        let strat = (1usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = SmaParams {
                    period: Some(period),
                };
                let input = SmaInput::from_slice(&data, params);

                let SmaOutput { values: out } = sma_with_kernel(&input, kernel).unwrap();
                let SmaOutput { values: ref_out } =
                    sma_with_kernel(&input, Kernel::Scalar).unwrap();

                // Warmup region must be all NaN.
                for i in 0..(period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                for i in (period - 1)..data.len() {
                    let window_start = i + 1 - period;
                    let window = &data[window_start..=i];

                    // Brute-force reference mean for this window.
                    let expected_sum: f64 = window.iter().sum();
                    let expected_mean = expected_sum / period as f64;

                    let abs_tolerance = 1e-8_f64;
                    let rel_tolerance = 1e-12_f64;
                    let tolerance = abs_tolerance.max(expected_mean.abs() * rel_tolerance);

                    // Looser bound for SIMD kernels (different summation
                    // order). NOTE(review): the mean-accuracy assert below
                    // uses `tolerance`, not `kernel_tol` — presumably
                    // intentional; confirm.
                    let kernel_tol = 5e-8_f64.max(tolerance);
                    prop_assert!(
                        (out[i] - expected_mean).abs() <= tolerance,
                        "SMA mismatch at index {}: expected {}, got {} (diff: {})",
                        i,
                        expected_mean,
                        out[i],
                        (out[i] - expected_mean).abs()
                    );

                    // The mean can never leave the window's [min, max].
                    let window_min = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let window_max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

                    prop_assert!(
                        out[i] >= window_min - kernel_tol && out[i] <= window_max + kernel_tol,
                        "SMA out of bounds at index {}: {} not in [{}, {}]",
                        i,
                        out[i],
                        window_min,
                        window_max
                    );

                    // Constant window => SMA equals the constant.
                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
                        let tolerance = kernel_tol.max(if period == 1 { 1e-8 } else { 1e-9 });
                        prop_assert!(
                            (out[i] - window[0]).abs() <= tolerance,
                            "Constant input property failed at index {}: expected {}, got {}",
                            i,
                            window[0],
                            out[i]
                        );
                    }

                    // Linear window => SMA equals the midpoint value (with a
                    // half-step allowance for even periods).
                    if period >= 3 {
                        let diffs: Vec<f64> = window.windows(2).map(|w| w[1] - w[0]).collect();
                        let is_linear = diffs.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9);

                        if is_linear && !diffs.is_empty() {
                            let midpoint_value = window[period / 2];
                            let tolerance = if period % 2 == 0 {
                                (window[period / 2 - 1] - window[period / 2]).abs() / 2.0
                                    + kernel_tol
                            } else {
                                kernel_tol
                            };

                            prop_assert!(
                                (out[i] - midpoint_value).abs() <= tolerance,
                                "Linear trend property failed at index {}: expected ~{}, got {}",
                                i,
                                midpoint_value,
                                out[i]
                            );
                        }
                    }

                    // Every kernel must agree with the scalar reference.
                    prop_assert!(
                        (out[i] - ref_out[i]).abs() <= kernel_tol
                            || (out[i].is_nan() && ref_out[i].is_nan()),
                        "Kernel mismatch at index {}: {} ({:?}) vs {} (Scalar)",
                        i,
                        out[i],
                        kernel,
                        ref_out[i]
                    );

                    // Sliding-window identity: SMA[i] - SMA[i-1] equals
                    // (new - old) / period.
                    if i >= period {
                        let new_value = data[i];
                        let old_value = data[i - period];
                        let expected_sma_change = (new_value - old_value) / period as f64;
                        let actual_sma_change = out[i] - out[i - 1];
                        let lag_tol = (expected_sma_change.abs() * rel_tolerance)
                            .max(5e-8_f64)
                            .max(2.0 * kernel_tol);

                        prop_assert!(
                            (actual_sma_change - expected_sma_change).abs() <= lag_tol,
                            "Lag property failed at index {}: SMA change {} should be {} (new: {}, old: {})",
                            i,
                            actual_sma_change,
                            expected_sma_change,
                            new_value,
                            old_value
                        );
                    }

                    // Debug builds: outputs must never carry a poison bit
                    // pattern from an uninitialized buffer.
                    #[cfg(debug_assertions)]
                    {
                        let bits = out[i].to_bits();
                        prop_assert!(
                            bits != 0x11111111_11111111
                                && bits != 0x22222222_22222222
                                && bits != 0x33333333_33333333,
                            "Found poison value at index {}: {} (0x{:016X})",
                            i,
                            out[i],
                            bits
                        );
                    }
                }

                // Period 1 degenerates to the identity function.
                if period == 1 {
                    for i in 0..data.len() {
                        prop_assert!(
                            (out[i] - data[i]).abs() <= 1e-8,
                            "Period=1 property failed at index {}: expected {}, got {}",
                            i,
                            data[i],
                            out[i]
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    // Stamps out one #[test] per kernel (scalar always; AVX2/AVX-512 only on
    // nightly x86_64) for each check function listed.
    macro_rules! generate_all_sma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
                    }
                )*
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                $(
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
                    }
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
                    }
                )*
            }
        }
    }
    generate_all_sma_tests!(
        check_sma_partial_params,
        check_sma_accuracy,
        check_sma_default_candles,
        check_sma_zero_period,
        check_sma_period_exceeds_length,
        check_sma_very_small_dataset,
        check_sma_reinput,
        check_sma_nan_handling,
        check_sma_streaming,
        check_sma_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_sma_tests!(check_sma_property);
    // The batch builder's default-parameter row must match the single-run
    // SMA(9) reference values.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = SmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = SmaParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        let expected = [59180.8, 59175.0, 59129.4, 59085.4, 59133.7];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    // Batch-matrix variant of the poison check: scans several sweep
    // configurations for leaked debug fill patterns.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let test_configs = vec![(5, 15, 5), (10, 30, 10), (20, 50, 15), (2, 10, 2)];

        for (start, end, step) in test_configs {
            let output = SmaBatchBuilder::new()
                .kernel(kernel)
                .period_range(start, end, step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                // Recover the (row, col) position so failures pinpoint the
                // offending period.
                let row = idx / output.cols;
                let col = idx % output.cols;
                let period = output.combos[row].period.unwrap();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
                        test, val, bits, row, col, idx, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
                        test, val, bits, row, col, idx, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}, period={})",
                        test, val, bits, row, col, idx, period
                    );
                }
            }
        }

        Ok(())
    }

    // Release builds do not poison buffers, so there is nothing to check.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
    // Stamps out one #[test] per batch-kernel variant (scalar, AVX2/AVX-512
    // on nightly x86_64, plus auto-detect) for the given check function.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}