1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
10use serde::{Deserialize, Serialize};
11#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
12use wasm_bindgen::prelude::*;
13
14#[cfg(all(feature = "python", feature = "cuda"))]
15use crate::cuda::cuda_available;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::moving_averages::smma_wrapper::DeviceArrayF32Smma;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaSmma;
20use crate::utilities::data_loader::{source_type, Candles};
21use crate::utilities::enums::Kernel;
22use crate::utilities::helpers::{
23 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
24 make_uninit_matrix,
25};
26#[cfg(feature = "python")]
27use crate::utilities::kernel_validation::validate_kernel;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::mem::MaybeUninit;
31use thiserror::Error;
32
33impl<'a> AsRef<[f64]> for SmmaInput<'a> {
34 #[inline(always)]
35 fn as_ref(&self) -> &[f64] {
36 match &self.data {
37 SmmaData::Slice(slice) => slice,
38 SmmaData::Candles { candles, source } => source_type(candles, source),
39 }
40 }
41}
42
/// Source of the series an SMMA is computed over.
#[derive(Debug, Clone)]
pub enum SmmaData<'a> {
    /// A column of a [`Candles`] set. The column is selected by name through
    /// `source_type` (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A raw price slice, used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-series SMMA computation.
#[derive(Debug, Clone)]
pub struct SmmaOutput {
    /// One value per input sample. `smma_with_kernel` allocates this through
    /// `alloc_with_nan_prefix`, so indices before the first complete window
    /// hold that helper's prefix value (NaN, per its name — confirm).
    pub values: Vec<f64>,
}

/// Tunable SMMA parameters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SmmaParams {
    /// Smoothing period. `None` is treated as 7 by `SmmaInput::get_period`.
    pub period: Option<usize>,
}

impl Default for SmmaParams {
    /// Period 7.
    fn default() -> Self {
        Self { period: Some(7) }
    }
}

/// A series plus the parameters to run SMMA over it.
#[derive(Debug, Clone)]
pub struct SmmaInput<'a> {
    /// Where the series comes from.
    pub data: SmmaData<'a>,
    /// SMMA parameters.
    pub params: SmmaParams,
}
77
78impl<'a> SmmaInput<'a> {
79 #[inline]
80 pub fn from_candles(c: &'a Candles, s: &'a str, p: SmmaParams) -> Self {
81 Self {
82 data: SmmaData::Candles {
83 candles: c,
84 source: s,
85 },
86 params: p,
87 }
88 }
89 #[inline]
90 pub fn from_slice(sl: &'a [f64], p: SmmaParams) -> Self {
91 Self {
92 data: SmmaData::Slice(sl),
93 params: p,
94 }
95 }
96 #[inline]
97 pub fn with_default_candles(c: &'a Candles) -> Self {
98 Self::from_candles(c, "close", SmmaParams::default())
99 }
100 #[inline]
101 pub fn get_period(&self) -> usize {
102 self.params.period.unwrap_or(7)
103 }
104}
105
106#[derive(Copy, Clone, Debug)]
107pub struct SmmaBuilder {
108 period: Option<usize>,
109 kernel: Kernel,
110}
111
112impl Default for SmmaBuilder {
113 fn default() -> Self {
114 Self {
115 period: None,
116 kernel: Kernel::Auto,
117 }
118 }
119}
120
121impl SmmaBuilder {
122 #[inline(always)]
123 pub fn new() -> Self {
124 Self::default()
125 }
126 #[inline(always)]
127 pub fn period(mut self, n: usize) -> Self {
128 self.period = Some(n);
129 self
130 }
131 #[inline(always)]
132 pub fn kernel(mut self, k: Kernel) -> Self {
133 self.kernel = k;
134 self
135 }
136
137 #[inline(always)]
138 pub fn apply(self, c: &Candles) -> Result<SmmaOutput, SmmaError> {
139 let p = SmmaParams {
140 period: self.period,
141 };
142 let i = SmmaInput::from_candles(c, "close", p);
143 smma_with_kernel(&i, self.kernel)
144 }
145
146 #[inline(always)]
147 pub fn apply_slice(self, d: &[f64]) -> Result<SmmaOutput, SmmaError> {
148 let p = SmmaParams {
149 period: self.period,
150 };
151 let i = SmmaInput::from_slice(d, p);
152 smma_with_kernel(&i, self.kernel)
153 }
154
155 #[inline(always)]
156 pub fn into_stream(self) -> Result<SmmaStream, SmmaError> {
157 let p = SmmaParams {
158 period: self.period,
159 };
160 SmmaStream::try_new(p)
161 }
162}
163
/// Errors returned by the SMMA entry points in this module.
#[derive(Debug, Error)]
pub enum SmmaError {
    /// The input series has no elements.
    #[error("smma: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so there is no first valid sample.
    #[error("smma: All values are NaN.")]
    AllValuesNaN,
    /// The period is 0 or larger than the input length `data_len`.
    /// `SmmaStream::try_new` reports `data_len = 0` because it has no series.
    #[error("smma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` samples remain from the first non-NaN value onwards.
    #[error("smma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A non-batch kernel was passed to `smma_batch_with_kernel`.
    #[error("smma: Invalid kernel for batch operation: {kernel:?} (expected batch kernel)")]
    InvalidKernelForBatch { kernel: Kernel },
    /// The period sweep `(start, end, step)` expanded to no values.
    #[error("smma: Invalid range expansion: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// An output buffer's length differs from the `expected` element count.
    #[error("smma: Output buffer length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
}
185
186#[inline]
187pub fn smma(input: &SmmaInput) -> Result<SmmaOutput, SmmaError> {
188 smma_with_kernel(input, Kernel::Auto)
189}
190
191pub fn smma_with_kernel(input: &SmmaInput, kernel: Kernel) -> Result<SmmaOutput, SmmaError> {
192 let data: &[f64] = match &input.data {
193 SmmaData::Candles { candles, source } => source_type(candles, source),
194 SmmaData::Slice(sl) => sl,
195 };
196
197 if data.is_empty() {
198 return Err(SmmaError::EmptyInputData);
199 }
200
201 let first = data
202 .iter()
203 .position(|x| !x.is_nan())
204 .ok_or(SmmaError::AllValuesNaN)?;
205 let len = data.len();
206 let period = input.get_period();
207
208 if period == 0 || period > len {
209 return Err(SmmaError::InvalidPeriod {
210 period,
211 data_len: len,
212 });
213 }
214 if (len - first) < period {
215 return Err(SmmaError::NotEnoughValidData {
216 needed: period,
217 valid: len - first,
218 });
219 }
220
221 let chosen = match kernel {
222 Kernel::Auto => detect_best_kernel(),
223 other => other,
224 };
225
226 let warm = first + period - 1;
227 let mut out = alloc_with_nan_prefix(len, warm);
228
229 unsafe {
230 match chosen {
231 Kernel::Scalar | Kernel::ScalarBatch => smma_scalar(data, period, first, &mut out),
232 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233 Kernel::Avx2 | Kernel::Avx2Batch => smma_avx2(data, period, first, &mut out),
234 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
235 Kernel::Avx512 | Kernel::Avx512Batch => smma_avx512(data, period, first, &mut out),
236 _ => smma_scalar(data, period, first, &mut out),
237 }
238 }
239 Ok(SmmaOutput { values: out })
240}
241
/// Scalar SMMA kernel.
///
/// What it writes into `out`:
/// * `out[first + period - 1]`: the mean of `data[first..first + period]`.
/// * every later index `i`: `(prev * (period - 1) + data[i]) / period`.
///
/// Indices before `first + period - 1` are not touched. When `period == 1`
/// the input is copied through unchanged from `first` onwards.
///
/// # Panics
/// Panics if `first + period > data.len()` or `out` is shorter than `data`.
#[inline]
pub fn smma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();

    if period == 1 {
        // A one-sample average is the sample itself. Copying avoids the
        // `prev * 0.0` term, which would turn NaN/inf into NaN permanently.
        out[first] = data[first];
        out[first + 1..n].copy_from_slice(&data[first + 1..n]);
        return;
    }

    let window_end = first + period;
    let seed = data[first..window_end]
        .iter()
        .fold(0.0f64, |acc, &x| acc + x);

    let p = period as f64;
    let p_minus_one = p - 1.0;

    let mut smma = seed / p;
    out[window_end - 1] = smma;

    for (dst, &x) in out[window_end..n].iter_mut().zip(&data[window_end..n]) {
        smma = (smma * p_minus_one + x) / p;
        *dst = smma;
    }
}
273
/// wasm32 (`simd128`) SMMA kernel, used by `smma_compute_into` for scalar
/// requests on that target.
///
/// The seed is the mean of `data[first..first + period]`, written to
/// `out[first + period - 1]`. Later values use `(prev * (period - 1) + x) * (1 / period)`.
/// When `period == 1` the input is copied through unchanged, matching
/// `smma_scalar` / `smma_avx2`. Without that path, `prev * 0.0` would turn
/// every output after a NaN or ±inf sample into NaN.
///
/// # Safety
/// Declared `unsafe`, but the body performs only bounds-checked slice
/// indexing and arithmetic. It panics (rather than causing UB) if
/// `first + period > data.len()` or `out` is shorter than `data`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn smma_simd128(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();

    if period == 1 {
        out[first..len].copy_from_slice(&data[first..len]);
        return;
    }

    let end = first + period;
    let sum: f64 = data[first..end].iter().sum();
    let mut prev = sum / period as f64;
    out[end - 1] = prev;

    let period_f64 = period as f64;
    let period_minus_1 = period_f64 - 1.0;
    let inv_period = 1.0 / period_f64;

    for i in end..len {
        prev = (prev * period_minus_1 + data[i]) * inv_period;
        out[i] = prev;
    }
}
291
/// AVX-512 entry point for single-series SMMA.
///
/// NOTE(review): this forwards to `smma_avx512_long`, which in turn forwards
/// to `smma_avx2`; no AVX-512 specific code exists in this file.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn smma_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    // SAFETY: `smma_avx512_long` only calls the safe `smma_avx2`, whose slice
    // accesses are all bounds-checked.
    unsafe { smma_avx512_long(data, period, first, out) }
}

/// "AVX2" SMMA kernel: a 4-way unrolled scalar recurrence using fused
/// multiply-add, `smma = mul_add(smma, period - 1, x) * (1 / period)`.
///
/// The seed (mean of `data[first..first + period]`) goes to
/// `out[first + period - 1]`. With `period == 1` the input is copied through
/// unchanged from `first`. Earlier indices are not touched.
///
/// # Panics
/// Panics if `first + period > data.len()` or `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn smma_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let n = data.len();

    if period == 1 {
        out[first] = data[first];
        for k in first + 1..n {
            out[k] = data[k];
        }
        return;
    }

    let seed_end = first + period;
    let mut seed = 0.0f64;
    for &x in &data[first..seed_end] {
        seed += x;
    }

    let p = period as f64;
    let p_minus_one = p - 1.0;
    let recip = 1.0 / p;

    let mut smma = seed * recip;
    out[seed_end - 1] = smma;

    // One recurrence step; the same operation sequence as the rolled loop.
    let mut step = |x: f64| -> f64 {
        smma = f64::mul_add(smma, p_minus_one, x) * recip;
        smma
    };

    let src = &data[seed_end..n];
    let dst = &mut out[seed_end..n];

    // Four samples per iteration; the tail is handled below in order.
    let mut src_quads = src.chunks_exact(4);
    let mut dst_quads = dst.chunks_exact_mut(4);
    for (d, s) in (&mut dst_quads).zip(&mut src_quads) {
        d[0] = step(s[0]);
        d[1] = step(s[1]);
        d[2] = step(s[2]);
        d[3] = step(s[3]);
    }
    for (d, &s) in dst_quads
        .into_remainder()
        .iter_mut()
        .zip(src_quads.remainder())
    {
        *d = step(s);
    }
}

/// AVX-512 short-period variant. Currently forwards to [`smma_avx2`].
///
/// # Safety
/// No extra preconditions: the body only calls the safe `smma_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn smma_avx512_short(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out)
}

/// AVX-512 long-period variant. Currently forwards to [`smma_avx2`].
///
/// # Safety
/// No extra preconditions: the body only calls the safe `smma_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn smma_avx512_long(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out)
}
367
368#[inline(always)]
369fn smma_prepare<'a>(
370 input: &'a SmmaInput,
371 kernel: Kernel,
372) -> Result<(&'a [f64], usize, usize, Kernel), SmmaError> {
373 let data: &[f64] = input.as_ref();
374
375 if data.is_empty() {
376 return Err(SmmaError::EmptyInputData);
377 }
378
379 let first = data
380 .iter()
381 .position(|x| !x.is_nan())
382 .ok_or(SmmaError::AllValuesNaN)?;
383 let len = data.len();
384 let period = input.get_period();
385
386 if period == 0 || period > len {
387 return Err(SmmaError::InvalidPeriod {
388 period,
389 data_len: len,
390 });
391 }
392 if (len - first) < period {
393 return Err(SmmaError::NotEnoughValidData {
394 needed: period,
395 valid: len - first,
396 });
397 }
398
399 let chosen = match kernel {
400 Kernel::Auto => detect_best_kernel(),
401 other => other,
402 };
403
404 Ok((data, period, first, chosen))
405}
406
407#[inline(always)]
408fn smma_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
409 unsafe {
410 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
411 {
412 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
413 smma_simd128(data, period, first, out);
414 return;
415 }
416 }
417
418 match kernel {
419 Kernel::Scalar | Kernel::ScalarBatch => smma_scalar(data, period, first, out),
420 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
421 Kernel::Avx2 | Kernel::Avx2Batch => smma_avx2(data, period, first, out),
422 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
423 Kernel::Avx512 | Kernel::Avx512Batch => smma_avx512(data, period, first, out),
424 _ => smma_scalar(data, period, first, out),
425 }
426 }
427}
428
429#[inline]
430pub fn expand_grid(r: &SmmaBatchRange) -> Vec<SmmaParams> {
431 let axis_usize = |(start, end, step): (usize, usize, usize)| -> Vec<usize> {
432 if step == 0 {
433 return vec![start];
434 }
435 if start == end {
436 return vec![start];
437 }
438 let mut out = Vec::new();
439 if start < end {
440 let mut v = start;
441 while v <= end {
442 out.push(v);
443 match v.checked_add(step) {
444 Some(next) => {
445 if next == v {
446 break;
447 }
448 v = next;
449 }
450 None => break,
451 }
452 }
453 } else {
454 let mut v = start;
455 while v >= end {
456 out.push(v);
457
458 let next = v.saturating_sub(step);
459 if next == v {
460 break;
461 }
462 v = next;
463 if v == 0 {
464 if end > 0 {
465 break;
466 }
467 }
468 }
469 }
470 out
471 };
472 axis_usize(r.period)
473 .into_iter()
474 .map(|p| SmmaParams { period: Some(p) })
475 .collect()
476}
477
/// Incremental SMMA calculator fed one value at a time via `update`.
#[derive(Debug, Clone)]
pub struct SmmaStream {
    /// Smoothing period; `try_new` rejects 0.
    period: usize,

    /// `period` as `f64`.
    pf64: f64,
    /// `period - 1` as `f64`, the weight of the previous SMMA value.
    pm1: f64,
    /// `1 / period`. NOTE(review): no `SmmaStream` method visible in this file
    /// reads it; `update` divides by `pf64` instead.
    inv_p: f64,

    /// Warm-up / running state.
    state: SmmaStreamState,
}

/// Life cycle of an [`SmmaStream`].
#[derive(Debug, Clone)]
enum SmmaStreamState {
    /// Waiting for the first usable input value.
    SeekingFirst,
    /// Accumulating the first `period` values: running `sum` over `count` samples.
    Warming { sum: f64, count: usize },
    /// Producing output; `value` is the most recent SMMA.
    Ready { value: f64 },
}
495
496impl SmmaStream {
497 #[inline]
498 pub fn try_new(params: SmmaParams) -> Result<Self, SmmaError> {
499 let period = params.period.unwrap_or(7);
500 if period == 0 {
501 return Err(SmmaError::InvalidPeriod {
502 period,
503 data_len: 0,
504 });
505 }
506 let pf64 = period as f64;
507 Ok(Self {
508 period,
509 pf64,
510 pm1: pf64 - 1.0,
511 inv_p: 1.0 / pf64,
512 state: SmmaStreamState::SeekingFirst,
513 })
514 }
515
516 #[inline(always)]
517 pub fn update(&mut self, v: f64) -> Option<f64> {
518 use SmmaStreamState::*;
519 match &mut self.state {
520 SeekingFirst => {
521 if v.is_finite() {
522 if self.period == 1 {
523 self.state = Ready { value: v };
524 return Some(v);
525 }
526 self.state = Warming { sum: v, count: 1 };
527 }
528 None
529 }
530 Warming { sum, count } => {
531 *sum += v;
532 *count += 1;
533 if *count == self.period {
534 let first_val = *sum / self.pf64;
535 self.state = Ready { value: first_val };
536 Some(first_val)
537 } else {
538 None
539 }
540 }
541 Ready { value } => {
542 let next = (*value * self.pm1 + v) / self.pf64;
543 *value = next;
544 Some(next)
545 }
546 }
547 }
548
549 #[inline(always)]
550 pub fn is_ready(&self) -> bool {
551 matches!(self.state, SmmaStreamState::Ready { .. })
552 }
553 #[inline(always)]
554 pub fn current(&self) -> Option<f64> {
555 match self.state {
556 SmmaStreamState::Ready { value } => Some(value),
557 _ => None,
558 }
559 }
560 #[inline]
561 pub fn reset(&mut self) {
562 self.state = SmmaStreamState::SeekingFirst;
563 }
564}
565
566#[derive(Clone, Debug)]
567pub struct SmmaBatchRange {
568 pub period: (usize, usize, usize),
569}
570
571impl Default for SmmaBatchRange {
572 fn default() -> Self {
573 Self {
574 period: (7, 256, 1),
575 }
576 }
577}
578
579#[derive(Clone, Debug, Default)]
580pub struct SmmaBatchBuilder {
581 range: SmmaBatchRange,
582 kernel: Kernel,
583}
584
585impl SmmaBatchBuilder {
586 pub fn new() -> Self {
587 Self::default()
588 }
589 pub fn kernel(mut self, k: Kernel) -> Self {
590 self.kernel = k;
591 self
592 }
593 #[inline]
594 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
595 self.range.period = (start, end, step);
596 self
597 }
598 #[inline]
599 pub fn period_static(mut self, p: usize) -> Self {
600 self.range.period = (p, p, 0);
601 self
602 }
603
604 pub fn apply_slice(self, data: &[f64]) -> Result<SmmaBatchOutput, SmmaError> {
605 smma_batch_with_kernel(data, &self.range, self.kernel)
606 }
607 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SmmaBatchOutput, SmmaError> {
608 SmmaBatchBuilder::new().kernel(k).apply_slice(data)
609 }
610 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SmmaBatchOutput, SmmaError> {
611 let slice = source_type(c, src);
612 self.apply_slice(slice)
613 }
614 pub fn with_default_candles(c: &Candles) -> Result<SmmaBatchOutput, SmmaError> {
615 SmmaBatchBuilder::new()
616 .kernel(Kernel::Auto)
617 .apply_candles(c, "close")
618 }
619}
620
621pub fn smma_batch_with_kernel(
622 data: &[f64],
623 sweep: &SmmaBatchRange,
624 k: Kernel,
625) -> Result<SmmaBatchOutput, SmmaError> {
626 let kernel = match k {
627 Kernel::Auto => detect_best_batch_kernel(),
628 other if other.is_batch() => other,
629 other => return Err(SmmaError::InvalidKernelForBatch { kernel: other }),
630 };
631 let simd = match kernel {
632 Kernel::Avx512Batch => Kernel::Avx512,
633 Kernel::Avx2Batch => Kernel::Avx2,
634 Kernel::ScalarBatch => Kernel::Scalar,
635 _ => Kernel::Scalar,
636 };
637 smma_batch_par_slice(data, sweep, simd)
638}
639
640#[derive(Clone, Debug)]
641pub struct SmmaBatchOutput {
642 pub values: Vec<f64>,
643 pub combos: Vec<SmmaParams>,
644 pub rows: usize,
645 pub cols: usize,
646}
647impl SmmaBatchOutput {
648 pub fn row_for_params(&self, p: &SmmaParams) -> Option<usize> {
649 self.combos
650 .iter()
651 .position(|c| c.period.unwrap_or(7) == p.period.unwrap_or(7))
652 }
653 pub fn values_for(&self, p: &SmmaParams) -> Option<&[f64]> {
654 self.row_for_params(p).map(|row| {
655 let start = row * self.cols;
656 &self.values[start..start + self.cols]
657 })
658 }
659}
660
/// Sequential batch SMMA: computes every period in `sweep` one row after
/// another. `kern` is the per-row kernel.
#[inline(always)]
pub fn smma_batch_slice(
    data: &[f64],
    sweep: &SmmaBatchRange,
    kern: Kernel,
) -> Result<SmmaBatchOutput, SmmaError> {
    smma_batch_inner(data, sweep, kern, false)
}
/// Parallel batch SMMA: like [`smma_batch_slice`], but rows are processed
/// with rayon on non-wasm32 targets (sequentially on wasm32).
#[inline(always)]
pub fn smma_batch_par_slice(
    data: &[f64],
    sweep: &SmmaBatchRange,
    kern: Kernel,
) -> Result<SmmaBatchOutput, SmmaError> {
    smma_batch_inner(data, sweep, kern, true)
}
677#[inline(always)]
678fn smma_batch_inner(
679 data: &[f64],
680 sweep: &SmmaBatchRange,
681 kern: Kernel,
682 parallel: bool,
683) -> Result<SmmaBatchOutput, SmmaError> {
684 if data.is_empty() {
685 return Err(SmmaError::EmptyInputData);
686 }
687
688 let combos = expand_grid(sweep);
689 if combos.is_empty() {
690 let (s, e, st) = sweep.period;
691 return Err(SmmaError::InvalidRange {
692 start: s,
693 end: e,
694 step: st,
695 });
696 }
697 let first = data
698 .iter()
699 .position(|x| !x.is_nan())
700 .ok_or(SmmaError::AllValuesNaN)?;
701 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
702 if data.len() - first < max_p {
703 return Err(SmmaError::NotEnoughValidData {
704 needed: max_p,
705 valid: data.len() - first,
706 });
707 }
708 let rows = combos.len();
709 let cols = data.len();
710
711 let warm: Vec<usize> = combos
712 .iter()
713 .map(|c| first + c.period.unwrap() - 1)
714 .collect();
715
716 let mut raw = make_uninit_matrix(rows, cols);
717 unsafe {
718 init_matrix_prefixes(&mut raw, cols, &warm);
719 }
720
721 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
722 let period = combos[row].period.unwrap();
723
724 let out_row =
725 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
726
727 match kern {
728 Kernel::Scalar | Kernel::ScalarBatch => smma_row_scalar(data, first, period, out_row),
729 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
730 Kernel::Avx2 | Kernel::Avx2Batch => smma_row_avx2(data, first, period, out_row),
731 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
732 Kernel::Avx512 | Kernel::Avx512Batch => smma_row_avx512(data, first, period, out_row),
733 _ => smma_row_scalar(data, first, period, out_row),
734 }
735 };
736
737 if parallel {
738 #[cfg(not(target_arch = "wasm32"))]
739 {
740 raw.par_chunks_mut(cols)
741 .enumerate()
742 .for_each(|(row, slice)| do_row(row, slice));
743 }
744
745 #[cfg(target_arch = "wasm32")]
746 {
747 for (row, slice) in raw.chunks_mut(cols).enumerate() {
748 do_row(row, slice);
749 }
750 }
751 } else {
752 for (row, slice) in raw.chunks_mut(cols).enumerate() {
753 do_row(row, slice);
754 }
755 }
756
757 let mut guard = core::mem::ManuallyDrop::new(raw);
758 let values = unsafe {
759 Vec::from_raw_parts(
760 guard.as_mut_ptr() as *mut f64,
761 guard.len(),
762 guard.capacity(),
763 )
764 };
765
766 Ok(SmmaBatchOutput {
767 values,
768 combos,
769 rows,
770 cols,
771 })
772}
773
774#[inline(always)]
775fn smma_batch_inner_into(
776 data: &[f64],
777 sweep: &SmmaBatchRange,
778 kern: Kernel,
779 parallel: bool,
780 out: &mut [f64],
781) -> Result<Vec<SmmaParams>, SmmaError> {
782 if data.is_empty() {
783 return Err(SmmaError::EmptyInputData);
784 }
785
786 let combos = expand_grid(sweep);
787 if combos.is_empty() {
788 let (s, e, st) = sweep.period;
789 return Err(SmmaError::InvalidRange {
790 start: s,
791 end: e,
792 step: st,
793 });
794 }
795 let first = data
796 .iter()
797 .position(|x| !x.is_nan())
798 .ok_or(SmmaError::AllValuesNaN)?;
799 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
800 if data.len() - first < max_p {
801 return Err(SmmaError::NotEnoughValidData {
802 needed: max_p,
803 valid: data.len() - first,
804 });
805 }
806
807 let rows = combos.len();
808 let cols = data.len();
809
810 let warm: Vec<usize> = combos
811 .iter()
812 .map(|c| first + c.period.unwrap() - 1)
813 .collect();
814
815 let out_uninit = unsafe {
816 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
817 };
818
819 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
820
821 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
822 let period = combos[row].period.unwrap();
823
824 let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
825
826 match kern {
827 Kernel::Scalar | Kernel::ScalarBatch => smma_row_scalar(data, first, period, dst),
828 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
829 Kernel::Avx2 | Kernel::Avx2Batch => smma_row_avx2(data, first, period, dst),
830 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
831 Kernel::Avx512 | Kernel::Avx512Batch => smma_row_avx512(data, first, period, dst),
832 _ => smma_row_scalar(data, first, period, dst),
833 }
834 };
835
836 if parallel {
837 #[cfg(not(target_arch = "wasm32"))]
838 {
839 out_uninit
840 .par_chunks_mut(cols)
841 .enumerate()
842 .for_each(|(row, slice)| do_row(row, slice));
843 }
844 #[cfg(target_arch = "wasm32")]
845 {
846 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
847 do_row(row, slice);
848 }
849 }
850 } else {
851 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
852 do_row(row, slice);
853 }
854 }
855
856 Ok(combos)
857}
858
/// Batch row kernel (scalar): fills one output row for `period`.
///
/// The argument order is `(data, first, period, out)`, unlike `smma_scalar`'s
/// `(data, period, first, out)`; this wrapper swaps them.
///
/// # Safety
/// Marked `unsafe`, but the body only calls the safe `smma_scalar`, which
/// bounds-checks every access.
#[inline(always)]
pub unsafe fn smma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_scalar(data, period, first, out)
}
/// Batch row kernel (AVX2): forwards to `smma_avx2` with the argument order swapped.
///
/// # Safety
/// Marked `unsafe`, but the body only calls the safe `smma_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_avx2(data, period, first, out)
}
/// Batch row kernel (AVX-512): picks the short variant for `period <= 32` and
/// the long one otherwise. Both currently forward to `smma_row_avx2`.
///
/// # Safety
/// No preconditions beyond those of `smma_row_avx2` (none).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    if period <= 32 {
        smma_row_avx512_short(data, first, period, out);
    } else {
        smma_row_avx512_long(data, first, period, out);
    }
}
/// AVX-512 short-period row kernel. Currently forwards to `smma_row_avx2`.
///
/// # Safety
/// No preconditions beyond those of `smma_row_avx2` (none).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_row_avx2(data, first, period, out)
}
/// AVX-512 long-period row kernel. Currently forwards to `smma_row_avx2`.
///
/// # Safety
/// No preconditions beyond those of `smma_row_avx2` (none).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn smma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    smma_row_avx2(data, first, period, out)
}
887
888#[cfg(test)]
889mod tests {
890 use super::*;
891 use crate::skip_if_unsupported;
892 use crate::utilities::data_loader::read_candles_from_csv;
893 #[cfg(feature = "proptest")]
894 use proptest::prelude::*;
895
896 fn check_smma_partial_params(
897 test_name: &str,
898 kernel: Kernel,
899 ) -> Result<(), Box<dyn std::error::Error>> {
900 skip_if_unsupported!(kernel, test_name);
901 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
902 let candles = read_candles_from_csv(file_path)?;
903 let input = SmmaInput::from_candles(&candles, "close", SmmaParams { period: None });
904 let output = smma_with_kernel(&input, kernel)?;
905 assert_eq!(output.values.len(), candles.close.len());
906 Ok(())
907 }
908 fn check_smma_accuracy(
909 test_name: &str,
910 kernel: Kernel,
911 ) -> Result<(), Box<dyn std::error::Error>> {
912 skip_if_unsupported!(kernel, test_name);
913 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
914 let candles = read_candles_from_csv(file_path)?;
915 let input = SmmaInput::from_candles(&candles, "close", SmmaParams::default());
916 let result = smma_with_kernel(&input, kernel)?;
917 let expected_last_five = [59434.4, 59398.2, 59346.9, 59319.4, 59224.5];
918 let start = result.values.len().saturating_sub(5);
919 for (i, &val) in result.values[start..].iter().enumerate() {
920 let diff = (val - expected_last_five[i]).abs();
921 assert!(
922 diff < 1e-1,
923 "[{}] SMMA {:?} mismatch at idx {}: got {}, expected {}",
924 test_name,
925 kernel,
926 i,
927 val,
928 expected_last_five[i]
929 );
930 }
931 Ok(())
932 }
933 fn check_smma_default_candles(
934 test_name: &str,
935 kernel: Kernel,
936 ) -> Result<(), Box<dyn std::error::Error>> {
937 skip_if_unsupported!(kernel, test_name);
938 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
939 let candles = read_candles_from_csv(file_path)?;
940 let input = SmmaInput::with_default_candles(&candles);
941 match input.data {
942 SmmaData::Candles { source, .. } => assert_eq!(source, "close"),
943 _ => panic!("Expected SmmaData::Candles"),
944 }
945 let output = smma_with_kernel(&input, kernel)?;
946 assert_eq!(output.values.len(), candles.close.len());
947 Ok(())
948 }
949 fn check_smma_zero_period(
950 test_name: &str,
951 kernel: Kernel,
952 ) -> Result<(), Box<dyn std::error::Error>> {
953 skip_if_unsupported!(kernel, test_name);
954 let input_data = [10.0, 20.0, 30.0];
955 let params = SmmaParams { period: Some(0) };
956 let input = SmmaInput::from_slice(&input_data, params);
957 let res = smma_with_kernel(&input, kernel);
958 assert!(
959 res.is_err(),
960 "[{}] SMMA should fail with zero period",
961 test_name
962 );
963 Ok(())
964 }
965 fn check_smma_period_exceeds_length(
966 test_name: &str,
967 kernel: Kernel,
968 ) -> Result<(), Box<dyn std::error::Error>> {
969 skip_if_unsupported!(kernel, test_name);
970 let data_small = [10.0, 20.0, 30.0];
971 let params = SmmaParams { period: Some(10) };
972 let input = SmmaInput::from_slice(&data_small, params);
973 let res = smma_with_kernel(&input, kernel);
974 assert!(
975 res.is_err(),
976 "[{}] SMMA should fail with period exceeding length",
977 test_name
978 );
979 Ok(())
980 }
981 fn check_smma_very_small_dataset(
982 test_name: &str,
983 kernel: Kernel,
984 ) -> Result<(), Box<dyn std::error::Error>> {
985 skip_if_unsupported!(kernel, test_name);
986 let single_point = [42.0];
987 let params = SmmaParams { period: Some(9) };
988 let input = SmmaInput::from_slice(&single_point, params);
989 let res = smma_with_kernel(&input, kernel);
990 assert!(
991 res.is_err(),
992 "[{}] SMMA should fail with insufficient data",
993 test_name
994 );
995 Ok(())
996 }
997 fn check_smma_empty_input(
998 test_name: &str,
999 kernel: Kernel,
1000 ) -> Result<(), Box<dyn std::error::Error>> {
1001 skip_if_unsupported!(kernel, test_name);
1002 let empty: Vec<f64> = vec![];
1003 let params = SmmaParams { period: Some(7) };
1004 let input = SmmaInput::from_slice(&empty, params);
1005 let res = smma_with_kernel(&input, kernel);
1006 assert!(
1007 res.is_err(),
1008 "[{}] SMMA should fail with empty input",
1009 test_name
1010 );
1011 if let Err(SmmaError::EmptyInputData) = res {
1012 } else {
1013 panic!(
1014 "[{}] Expected EmptyInputData error, got {:?}",
1015 test_name, res
1016 );
1017 }
1018 Ok(())
1019 }
1020 fn check_smma_reinput(
1021 test_name: &str,
1022 kernel: Kernel,
1023 ) -> Result<(), Box<dyn std::error::Error>> {
1024 skip_if_unsupported!(kernel, test_name);
1025 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1026 let candles = read_candles_from_csv(file_path)?;
1027 let first_params = SmmaParams { period: Some(7) };
1028 let first_input = SmmaInput::from_candles(&candles, "close", first_params);
1029 let first_result = smma_with_kernel(&first_input, kernel)?;
1030 let second_params = SmmaParams { period: Some(5) };
1031 let second_input = SmmaInput::from_slice(&first_result.values, second_params);
1032 let second_result = smma_with_kernel(&second_input, kernel)?;
1033 assert_eq!(second_result.values.len(), first_result.values.len());
1034 Ok(())
1035 }
1036 fn check_smma_nan_handling(
1037 test_name: &str,
1038 kernel: Kernel,
1039 ) -> Result<(), Box<dyn std::error::Error>> {
1040 skip_if_unsupported!(kernel, test_name);
1041 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1042 let candles = read_candles_from_csv(file_path)?;
1043 let input = SmmaInput::from_candles(&candles, "close", SmmaParams { period: Some(7) });
1044 let res = smma_with_kernel(&input, kernel)?;
1045 assert_eq!(res.values.len(), candles.close.len());
1046 if res.values.len() > 240 {
1047 for (i, &val) in res.values[240..].iter().enumerate() {
1048 assert!(
1049 !val.is_nan(),
1050 "[{}] Found unexpected NaN at out-index {}",
1051 test_name,
1052 240 + i
1053 );
1054 }
1055 }
1056 Ok(())
1057 }
1058 fn check_smma_streaming(
1059 test_name: &str,
1060 kernel: Kernel,
1061 ) -> Result<(), Box<dyn std::error::Error>> {
1062 skip_if_unsupported!(kernel, test_name);
1063 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1064 let candles = read_candles_from_csv(file_path)?;
1065 let period = 7;
1066 let input = SmmaInput::from_candles(
1067 &candles,
1068 "close",
1069 SmmaParams {
1070 period: Some(period),
1071 },
1072 );
1073 let batch_output = smma_with_kernel(&input, kernel)?.values;
1074 let mut stream = SmmaStream::try_new(SmmaParams {
1075 period: Some(period),
1076 })?;
1077 let mut stream_values = Vec::with_capacity(candles.close.len());
1078 for &price in &candles.close {
1079 match stream.update(price) {
1080 Some(smma_val) => stream_values.push(smma_val),
1081 None => stream_values.push(f64::NAN),
1082 }
1083 }
1084 assert_eq!(batch_output.len(), stream_values.len());
1085 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1086 if b.is_nan() && s.is_nan() {
1087 continue;
1088 }
1089 let diff = (b - s).abs();
1090 assert!(
1091 diff < 1e-7,
1092 "[{}] SMMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1093 test_name,
1094 i,
1095 b,
1096 s,
1097 diff
1098 );
1099 }
1100 Ok(())
1101 }
1102 #[cfg(feature = "proptest")]
1103 fn check_smma_property(
1104 test_name: &str,
1105 kernel: Kernel,
1106 ) -> Result<(), Box<dyn std::error::Error>> {
1107 use proptest::prelude::*;
1108 skip_if_unsupported!(kernel, test_name);
1109
1110 let strat = (1usize..=100).prop_flat_map(|period| {
1111 (
1112 prop::collection::vec(
1113 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1114 period.max(2)..400,
1115 ),
1116 Just(period),
1117 )
1118 });
1119
1120 proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1121 let params = SmmaParams {
1122 period: Some(period),
1123 };
1124 let input = SmmaInput::from_slice(&data, params);
1125
1126 let output = smma_with_kernel(&input, kernel)?;
1127
1128 let reference = smma_with_kernel(&input, Kernel::Scalar)?;
1129
1130 prop_assert_eq!(output.values.len(), data.len());
1131 prop_assert_eq!(reference.values.len(), data.len());
1132
1133 if period > 1 {
1134 for i in 0..period - 1 {
1135 prop_assert!(
1136 output.values[i].is_nan(),
1137 "Expected NaN at index {} but got {}",
1138 i,
1139 output.values[i]
1140 );
1141 }
1142 }
1143
1144 let first_smma_idx = if period == 1 { 0 } else { period - 1 };
1145 let first_sum: f64 = data[0..period].iter().sum();
1146 let expected_first = first_sum / period as f64;
1147 let actual_first = output.values[first_smma_idx];
1148 prop_assert!(
1149 (actual_first - expected_first).abs() < 1e-7,
1150 "First SMMA value mismatch: expected {}, got {} (diff: {})",
1151 expected_first,
1152 actual_first,
1153 (actual_first - expected_first).abs()
1154 );
1155
1156 if data.len() > period {
1157 for i in period..data.len().min(period + 10) {
1158 let prev_smma = output.values[i - 1];
1159 let expected = (prev_smma * (period as f64 - 1.0) + data[i]) / period as f64;
1160 let actual = output.values[i];
1161
1162 prop_assert!(
1163 (actual - expected).abs() < 1e-7,
1164 "Recursive formula mismatch at index {}: expected {}, got {} (diff: {})",
1165 i,
1166 expected,
1167 actual,
1168 (actual - expected).abs()
1169 );
1170 }
1171 }
1172
1173 for i in 0..output.values.len() {
1174 let test_val = output.values[i];
1175 let ref_val = reference.values[i];
1176
1177 if test_val.is_nan() && ref_val.is_nan() {
1178 continue;
1179 }
1180
1181 prop_assert!(
1182 test_val.is_finite() == ref_val.is_finite(),
1183 "Finite/NaN mismatch at index {}: test={}, ref={}",
1184 i,
1185 test_val,
1186 ref_val
1187 );
1188
1189 if test_val.is_finite() && ref_val.is_finite() {
1190 let test_bits = test_val.to_bits();
1191 let ref_bits = ref_val.to_bits();
1192 let ulp_diff = test_bits.abs_diff(ref_bits);
1193
1194 let max_ulps = if matches!(kernel, Kernel::Avx512 | Kernel::Avx512Batch) {
1195 20
1196 } else {
1197 10
1198 };
1199
1200 prop_assert!(
1201 ulp_diff <= max_ulps || (test_val - ref_val).abs() < 1e-7,
1202 "Cross-kernel mismatch at index {}: test={}, ref={}, ULP diff={}, abs diff={}",
1203 i,
1204 test_val,
1205 ref_val,
1206 ulp_diff,
1207 (test_val - ref_val).abs()
1208 );
1209 }
1210 }
1211
1212 let start_idx = if period == 1 { 0 } else { period - 1 };
1213 for i in start_idx..output.values.len() {
1214 let val = output.values[i];
1215 if val.is_finite() {
1216 let data_up_to_i = &data[0..=i];
1217 let min = data_up_to_i.iter().copied().fold(f64::INFINITY, f64::min);
1218 let max = data_up_to_i
1219 .iter()
1220 .copied()
1221 .fold(f64::NEG_INFINITY, f64::max);
1222
1223 prop_assert!(
1224 val >= min - 1e-9 && val <= max + 1e-9,
1225 "SMMA value {} at index {} outside historical bounds [{}, {}]",
1226 val,
1227 i,
1228 min,
1229 max
1230 );
1231 }
1232 }
1233
1234 if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON) {
1235 let constant_val = data[0];
1236 let check_start = if period == 1 { 0 } else { period - 1 };
1237 for i in check_start..output.values.len() {
1238 let val = output.values[i];
1239 prop_assert!(
1240 (val - constant_val).abs() < 1e-7,
1241 "SMMA should converge to {} for constant data, but got {} at index {}",
1242 constant_val,
1243 val,
1244 i
1245 );
1246 }
1247 }
1248
1249 if period == 1 {
1250 prop_assert_eq!(
1251 output.values[0],
1252 data[0],
1253 "Period=1: first value should equal input"
1254 );
1255 for i in 1..data.len() {
1256 prop_assert!(
1257 (output.values[i] - data[i]).abs() < 1e-7,
1258 "Period=1: output should equal input at index {}: {} != {}",
1259 i,
1260 output.values[i],
1261 data[i]
1262 );
1263 }
1264 }
1265
1266 Ok(())
1267 })?;
1268
1269 Ok(())
1270 }
1271
1272 #[cfg(debug_assertions)]
1273 fn check_smma_no_poison(
1274 test_name: &str,
1275 kernel: Kernel,
1276 ) -> Result<(), Box<dyn std::error::Error>> {
1277 skip_if_unsupported!(kernel, test_name);
1278
1279 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1280 let candles = read_candles_from_csv(file_path)?;
1281
1282 let test_periods = vec![5, 7, 14, 50, 100, 200];
1283
1284 for &period in &test_periods {
1285 if period > candles.close.len() {
1286 continue;
1287 }
1288
1289 let input = SmmaInput::from_candles(
1290 &candles,
1291 "close",
1292 SmmaParams {
1293 period: Some(period),
1294 },
1295 );
1296 let output = smma_with_kernel(&input, kernel)?;
1297
1298 for (i, &val) in output.values.iter().enumerate() {
1299 if val.is_nan() {
1300 continue;
1301 }
1302
1303 let bits = val.to_bits();
1304
1305 if bits == 0x11111111_11111111 {
1306 panic!(
1307 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
1308 test_name, val, bits, i, period
1309 );
1310 }
1311
1312 if bits == 0x22222222_22222222 {
1313 panic!(
1314 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
1315 test_name, val, bits, i, period
1316 );
1317 }
1318
1319 if bits == 0x33333333_33333333 {
1320 panic!(
1321 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
1322 test_name, val, bits, i, period
1323 );
1324 }
1325 }
1326 }
1327
1328 Ok(())
1329 }
1330
1331 #[cfg(not(debug_assertions))]
1332 fn check_smma_no_poison(
1333 _test_name: &str,
1334 _kernel: Kernel,
1335 ) -> Result<(), Box<dyn std::error::Error>> {
1336 Ok(())
1337 }
1338
1339 macro_rules! generate_all_smma_tests {
1340 ($($test_fn:ident),*) => {
1341 paste::paste! {
1342 $(
1343 #[test]
1344 fn [<$test_fn _scalar_f64>]() {
1345 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1346 }
1347 )*
1348 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1349 $(
1350 #[test]
1351 fn [<$test_fn _avx2_f64>]() {
1352 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1353 }
1354 #[test]
1355 fn [<$test_fn _avx512_f64>]() {
1356 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1357 }
1358 )*
1359 }
1360 }
1361 }
    // Generate per-kernel `#[test]` wrappers for every single-series check
    // function in this module. The macro's pattern is `$($test_fn:ident),*`,
    // so the list must not end with a trailing comma.
    generate_all_smma_tests!(
        check_smma_partial_params,
        check_smma_accuracy,
        check_smma_default_candles,
        check_smma_zero_period,
        check_smma_period_exceeds_length,
        check_smma_very_small_dataset,
        check_smma_empty_input,
        check_smma_reinput,
        check_smma_nan_handling,
        check_smma_streaming,
        check_smma_no_poison
    );

    // The property test function only exists with the `proptest` feature,
    // so its wrappers are gated the same way.
    #[cfg(feature = "proptest")]
    generate_all_smma_tests!(check_smma_property);
1378
1379 #[test]
1380 fn test_smma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1381 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1382 let candles = read_candles_from_csv(file_path)?;
1383
1384 let input = SmmaInput::with_default_candles(&candles);
1385 let baseline = smma(&input)?.values;
1386
1387 let mut out = vec![0.0f64; baseline.len()];
1388
1389 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1390 {
1391 smma_into(&input, &mut out)?;
1392 }
1393 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1394 {
1395 smma_into_slice(&mut out, &input, Kernel::Auto)?;
1396 }
1397
1398 assert_eq!(out.len(), baseline.len());
1399
1400 for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1401 let equal = (a.is_nan() && b.is_nan()) || (a == b);
1402 assert!(
1403 equal,
1404 "into parity mismatch at idx {}: got {}, expected {}",
1405 i, a, b
1406 );
1407 }
1408
1409 Ok(())
1410 }
1411 fn check_batch_default_row(
1412 test: &str,
1413 kernel: Kernel,
1414 ) -> Result<(), Box<dyn std::error::Error>> {
1415 skip_if_unsupported!(kernel, test);
1416 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1417 let c = read_candles_from_csv(file)?;
1418 let output = SmmaBatchBuilder::new()
1419 .kernel(kernel)
1420 .apply_candles(&c, "close")?;
1421 let def = SmmaParams::default();
1422 let row = output.values_for(&def).expect("default row missing");
1423 assert_eq!(row.len(), c.close.len());
1424 let expected = [59434.4, 59398.2, 59346.9, 59319.4, 59224.5];
1425 let start = row.len() - 5;
1426 for (i, &v) in row[start..].iter().enumerate() {
1427 assert!(
1428 (v - expected[i]).abs() < 1e-1,
1429 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1430 );
1431 }
1432 Ok(())
1433 }
1434
1435 #[cfg(debug_assertions)]
1436 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1437 skip_if_unsupported!(kernel, test);
1438
1439 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1440 let c = read_candles_from_csv(file)?;
1441
1442 let batch_configs = vec![
1443 (3, 7, 1),
1444 (10, 30, 10),
1445 (5, 100, 5),
1446 (50, 200, 50),
1447 (1, 10, 1),
1448 ];
1449
1450 for (start, end, step) in batch_configs {
1451 if start > c.close.len() {
1452 continue;
1453 }
1454
1455 let output = SmmaBatchBuilder::new()
1456 .kernel(kernel)
1457 .period_range(start, end, step)
1458 .apply_candles(&c, "close")?;
1459
1460 for (idx, &val) in output.values.iter().enumerate() {
1461 if val.is_nan() {
1462 continue;
1463 }
1464
1465 let bits = val.to_bits();
1466 let row = idx / output.cols;
1467 let col = idx % output.cols;
1468 let period = output.combos[row].period.unwrap_or(0);
1469
1470 if bits == 0x11111111_11111111 {
1471 panic!(
1472 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1473 test, val, bits, row, col, idx, period, start, end, step
1474 );
1475 }
1476
1477 if bits == 0x22222222_22222222 {
1478 panic!(
1479 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1480 test, val, bits, row, col, idx, period, start, end, step
1481 );
1482 }
1483
1484 if bits == 0x33333333_33333333 {
1485 panic!(
1486 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) for period {} in range ({}, {}, {})",
1487 test, val, bits, row, col, idx, period, start, end, step
1488 );
1489 }
1490 }
1491 }
1492
1493 Ok(())
1494 }
1495
1496 #[cfg(not(debug_assertions))]
1497 fn check_batch_no_poison(
1498 _test: &str,
1499 _kernel: Kernel,
1500 ) -> Result<(), Box<dyn std::error::Error>> {
1501 Ok(())
1502 }
1503 macro_rules! gen_batch_tests {
1504 ($fn_name:ident) => {
1505 paste::paste! {
1506 #[test] fn [<$fn_name _scalar>]() {
1507 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1508 }
1509 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1510 #[test] fn [<$fn_name _avx2>]() {
1511 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1512 }
1513 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1514 #[test] fn [<$fn_name _avx512>]() {
1515 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1516 }
1517 #[test] fn [<$fn_name _auto_detect>]() {
1518 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1519 }
1520 }
1521 };
1522 }
    // Per-kernel wrappers (scalar, AVX2, AVX-512, auto-detect) for the batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1525}
1526
// Python binding for single-series SMMA.
//
// `kernel` is an optional kernel name validated by `validate_kernel`. The
// computation runs with the GIL released. Indicator errors surface as Python
// `ValueError`. (Plain `//` comments are used so the Python `__doc__` is unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "smma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn smma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = SmmaInput::from_slice(
        series,
        SmmaParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| smma_with_kernel(&input, kern))
        .map(|out| out.values)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1553
// Python wrapper (`SmmaStream`) around the incremental `SmmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SmmaStream")]
pub struct SmmaStreamPy {
    stream: SmmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl SmmaStreamPy {
    // Builds a stream for `period`. Construction errors are raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        SmmaStream::try_new(SmmaParams {
            period: Some(period),
        })
        .map(|stream| SmmaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one value and returns whatever the inner stream yields for it.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1577
// Python binding: SMMA for every period in `period_range = (start, end, step)`.
//
// Returns a dict with `values` (a `(rows, cols)` array, one row per period,
// `cols == len(data)`) and `periods` (the period of each row, as u64).
// Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "smma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn smma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = SmmaBatchRange {
        period: period_range,
    };

    let rows = {
        let grid = expand_grid(&sweep);
        if grid.is_empty() {
            let (start, end, step) = period_range;
            return Err(PyValueError::new_err(format!(
                "invalid period range: start={}, end={}, step={}",
                start, end, step
            )));
        }
        grid.len()
    };
    let cols = series.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows * cols overflow")),
    };

    // SAFETY: `new` returns uninitialized storage. It is only handed to
    // `smma_batch_inner_into`, which is expected to write all `rows * cols`
    // cells (NOTE(review): confirm it also fills the warmup prefixes). On
    // error the array is dropped without ever reaching Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here and is not shared, so this is
    // the only view of its data.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            // Map the batch kernel onto the per-row kernel the inner routine uses.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            smma_batch_inner_into(series, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1643
// Python binding: runs the SMMA period sweep on CUDA device `device_id` and
// returns a device-resident result handle. The f64 input is narrowed to f32
// on the host before the CUDA call. Errors are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "smma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn smma_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SmmaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_f32: Vec<f32> = data.as_slice()?.iter().map(|&v| v as f32).collect();
    let sweep = SmmaBatchRange {
        period: period_range,
    };

    let inner = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaSmma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        engine
            .smma_batch_dev(&host_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SmmaPy { inner: Some(inner) })
}
1673
// Python binding: applies one SMMA period to many series at once on CUDA
// device `device_id`.
//
// `data_tm_f32` is a contiguous 2-D f32 array. Its first dimension is passed
// to the CUDA routine as the series length and its second as the number of
// series. This assumes a time-major layout (rows = time steps, one column per
// series), per the `_tm` / `time_major` naming — confirm against `CudaSmma`.
// Errors are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "smma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn smma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SmmaPy> {
    // `shape()` comes from this trait, which the file does not import globally.
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = SmmaParams {
        period: Some(period),
    };

    let inner = py.allow_threads(|| {
        let cuda = CudaSmma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Pass the params by reference (the original text had a mis-encoded `&params`).
        cuda.smma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SmmaPy { inner: Some(inner) })
}
1704
// Python handle for an SMMA result held in CUDA device memory.
//
// `inner` is `Some` until `__dlpack__` moves the buffer out. Afterwards every
// accessor fails with "buffer already exported via __dlpack__".
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Smma", unsendable)]
pub struct DeviceArrayF32SmmaPy {
    // Device buffer plus the `rows`, `cols` and `device_id` metadata read by
    // this type's Python protocol methods.
    pub(crate) inner: Option<DeviceArrayF32Smma>,
}
1710
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SmmaPy {
    // CUDA Array Interface (version 3) description of the 2-D float32 buffer.
    // Layout is row-major with byte strides `(cols * 4, 4)`, and the buffer is
    // reported as writable. Fails once the buffer has been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        // Little-endian 4-byte float.
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only) as the interface expects.
        d.set_item("data", (inner.device_ptr() as usize, false))?;
        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple. Device type 2 is `kDLCUDA`; the id is the device
    // recorded on the buffer.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        Ok((2, inner.device_id as i32))
    }

    // Exports the buffer as a 2-D DLPack capsule, transferring ownership.
    //
    // - A `dl_device` that differs from the allocation device is rejected; no
    //   cross-device copy is ever made, even if `copy` is truthy.
    // - A `dl_device` that cannot be read as `(int, int)` is ignored.
    // - `copy` is otherwise not consulted.
    // - `stream` is accepted but ignored. NOTE(review): no stream
    //   synchronization happens here — confirm producers finish before export.
    // - The buffer is moved out of `self` before the export call, so every
    //   later access (including a retry after a failed export) reports it as
    //   already exported.
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        // Also fails early if the buffer was already exported.
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        let _ = stream;

        // Take ownership only after validation so a rejected request leaves
        // the handle usable.
        let inner = self
            .inner
            .take()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1788
1789#[inline]
1790pub fn smma_into_slice(dst: &mut [f64], input: &SmmaInput, kern: Kernel) -> Result<(), SmmaError> {
1791 let (data, period, first, chosen) = smma_prepare(input, kern)?;
1792
1793 if dst.len() != data.len() {
1794 return Err(SmmaError::OutputLengthMismatch {
1795 expected: data.len(),
1796 got: dst.len(),
1797 });
1798 }
1799
1800 let warmup_end = first + period - 1;
1801 unsafe {
1802 let dst_uninit =
1803 std::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut MaybeUninit<f64>, dst.len());
1804 init_matrix_prefixes(dst_uninit, dst.len(), &[warmup_end]);
1805 }
1806
1807 smma_compute_into(data, period, first, chosen, dst);
1808
1809 Ok(())
1810}
1811
1812#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1813#[inline]
1814pub fn smma_into(input: &SmmaInput, out: &mut [f64]) -> Result<(), SmmaError> {
1815 smma_into_slice(out, input, Kernel::Auto)
1816}
1817
/// JS binding: computes SMMA of `data` for `period` and returns a new array of
/// the same length. Errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma")]
pub fn smma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SmmaInput::from_slice(
        data,
        SmmaParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match smma_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1833
/// Configuration object accepted by the JS `smma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmmaBatchConfig {
    /// `(start, end, step)` period sweep.
    pub period_range: (usize, usize, usize),
}

/// Result object returned by the JS `smma_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SmmaBatchJsOutput {
    /// Flattened `rows * cols` matrix, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameters used for each row.
    pub combos: Vec<SmmaParams>,
    /// Number of rows (parameter combinations).
    pub rows: usize,
    /// Number of columns (input length).
    pub cols: usize,
}
1848
/// JS binding: runs an SMMA period sweep described by a `{ period_range }`
/// config and returns `{ values, combos, rows, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch")]
pub fn smma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SmmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = SmmaBatchRange {
        period: parsed.period_range,
    };

    let batch = match smma_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(b) => b,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = SmmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1872
/// Legacy JS binding: runs an SMMA period sweep and returns only the
/// flattened row-major values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_legacy")]
pub fn smma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = SmmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match smma_batch_with_kernel(data, &range, Kernel::Auto) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1889
/// JS binding: lists the period of each row a batch sweep would produce
/// (falling back to 7 for a combo without an explicit period).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_metadata")]
pub fn smma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Vec<usize> {
    let grid = expand_grid(&SmmaBatchRange {
        period: (period_start, period_end, period_step),
    });
    grid.iter().map(|combo| combo.period.unwrap_or(7)).collect()
}
1904
/// JS binding: returns `[rows, cols]` for a batch sweep over `data_len` inputs,
/// so callers can size the output buffer for `smma_batch_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "smma_batch_rows_cols")]
pub fn smma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let row_count = expand_grid(&SmmaBatchRange {
        period: (period_start, period_end, period_step),
    })
    .len();
    vec![row_count, data_len]
}
1920
/// Allocates an uninitialized buffer with room for `len` f64 values and
/// returns its address. Release it with `smma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1929
/// Releases a buffer obtained from `smma_alloc(len)`.
///
/// `len` must be the value passed to `smma_alloc`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `smma_alloc(len)`, i.e. a `Vec<f64>` created with
    // `with_capacity(len)` (assumed to be exactly `len` slots, as this pairing
    // always has). The Vec is rebuilt with length 0 so it never claims that
    // possibly-unwritten memory holds initialized elements. Dropping it only
    // frees the allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1939
/// JS binding: computes SMMA over `len` values at `in_ptr` into `len` slots at `out_ptr`.
///
/// The input and output buffers may be the same buffer or overlap in any way.
/// In that case the result is computed into a temporary first, so no shared
/// and mutable views of the same memory ever coexist.
///
/// # Errors
/// Returns a string `JsValue` for a null pointer, an invalid `period`
/// (`0` or greater than `len`), or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to smma_into"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = SmmaParams {
        period: Some(period),
    };
    let input = SmmaInput::from_slice(data, params);

    // Detect any byte-range overlap, not only identical start addresses.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span);

    if overlaps {
        let mut temp = vec![0.0; len];
        smma_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes,
        // and `data`/`input` are not used past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes,
        // and the check above proved it does not overlap the input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        smma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1980
/// JS binding: runs an SMMA period sweep over `len` values at `in_ptr`,
/// writing a row-major `rows * len` matrix to `out_ptr`.
///
/// Returns the number of rows. If the output region overlaps the input, the
/// input is copied first so the computation never reads memory it is also
/// writing through a mutable slice.
///
/// # Errors
/// Returns a string `JsValue` for a null pointer, an empty period sweep, a
/// `rows * len` overflow, or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn smma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to smma_batch_into"));
    }

    let sweep = SmmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(JsValue::from_str("invalid period range"));
    }
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte ranges of the two buffers; empty ranges never overlap.
    let elem = std::mem::size_of::<f64>();
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let in_end = in_addr.saturating_add(len.saturating_mul(elem));
    let out_end = out_addr.saturating_add(total.saturating_mul(elem));
    let overlaps = in_addr < out_end && out_addr < in_end;

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let borrowed = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let owned_copy;
    let data: &[f64] = if overlaps {
        owned_copy = borrowed.to_vec();
        &owned_copy
    } else {
        borrowed
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes.
    // Any overlap with the input was resolved above by copying it, so `data`
    // never aliases `out`.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    smma_batch_inner_into(data, &sweep, simd, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}