1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7use aligned_vec::{AVec, CACHELINE_ALIGN};
8#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
9use core::arch::x86_64::*;
10#[cfg(not(target_arch = "wasm32"))]
11use rayon::prelude::*;
12use std::convert::AsRef;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(feature = "python", feature = "cuda"))]
17use crate::cuda::cuda_available;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use crate::cuda::moving_averages::CudaSwma;
20#[cfg(all(feature = "python", feature = "cuda"))]
21use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
22#[cfg(feature = "python")]
23use crate::utilities::kernel_validation::validate_kernel;
24#[cfg(feature = "python")]
25use numpy::{IntoPyArray, PyArray1};
26#[cfg(feature = "python")]
27use pyo3::exceptions::PyValueError;
28#[cfg(feature = "python")]
29use pyo3::prelude::*;
30#[cfg(feature = "python")]
31use pyo3::types::{PyDict, PyList};
32#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
33use serde::{Deserialize, Serialize};
34#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
35use wasm_bindgen::prelude::*;
36
37impl<'a> AsRef<[f64]> for SwmaInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 SwmaData::Slice(slice) => slice,
42 SwmaData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
/// Where an [`SwmaInput`] takes its values from.
#[derive(Debug, Clone)]
pub enum SwmaData<'a> {
    /// A column of a candle set; `source` names the column and is resolved with `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values.
    Slice(&'a [f64]),
}
55
/// Result of a single-series SWMA computation.
#[derive(Debug, Clone)]
pub struct SwmaOutput {
    /// One value per input element; see `swma_with_kernel` for how the warm-up prefix is filled.
    pub values: Vec<f64>,
}
60
/// Parameters for SWMA.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SwmaParams {
    /// Window length; `None` falls back to 5 (see `SwmaInput::get_period`).
    pub period: Option<usize>,
}
69
70impl Default for SwmaParams {
71 fn default() -> Self {
72 Self { period: Some(5) }
73 }
74}
75
/// Input to the SWMA functions: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct SwmaInput<'a> {
    /// Where the values come from (candle column or slice).
    pub data: SwmaData<'a>,
    /// Indicator parameters (period).
    pub params: SwmaParams,
}
81
82impl<'a> SwmaInput<'a> {
83 #[inline]
84 pub fn from_candles(c: &'a Candles, s: &'a str, p: SwmaParams) -> Self {
85 Self {
86 data: SwmaData::Candles {
87 candles: c,
88 source: s,
89 },
90 params: p,
91 }
92 }
93 #[inline]
94 pub fn from_slice(sl: &'a [f64], p: SwmaParams) -> Self {
95 Self {
96 data: SwmaData::Slice(sl),
97 params: p,
98 }
99 }
100 #[inline]
101 pub fn with_default_candles(c: &'a Candles) -> Self {
102 Self::from_candles(c, "close", SwmaParams::default())
103 }
104 #[inline]
105 pub fn get_period(&self) -> usize {
106 self.params.period.unwrap_or(5)
107 }
108}
109
/// Fluent builder for one-off SWMA computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct SwmaBuilder {
    /// Window length; `None` defers to the default period used by `SwmaInput`/`SwmaStream`.
    period: Option<usize>,
    /// Kernel used by `apply`/`apply_slice`.
    kernel: Kernel,
}
115
116impl Default for SwmaBuilder {
117 fn default() -> Self {
118 Self {
119 period: None,
120 kernel: Kernel::Auto,
121 }
122 }
123}
124
125impl SwmaBuilder {
126 #[inline(always)]
127 pub fn new() -> Self {
128 Self::default()
129 }
130 #[inline(always)]
131 pub fn period(mut self, n: usize) -> Self {
132 self.period = Some(n);
133 self
134 }
135 #[inline(always)]
136 pub fn kernel(mut self, k: Kernel) -> Self {
137 self.kernel = k;
138 self
139 }
140
141 #[inline(always)]
142 pub fn apply(self, c: &Candles) -> Result<SwmaOutput, SwmaError> {
143 let p = SwmaParams {
144 period: self.period,
145 };
146 let i = SwmaInput::from_candles(c, "close", p);
147 swma_with_kernel(&i, self.kernel)
148 }
149
150 #[inline(always)]
151 pub fn apply_slice(self, d: &[f64]) -> Result<SwmaOutput, SwmaError> {
152 let p = SwmaParams {
153 period: self.period,
154 };
155 let i = SwmaInput::from_slice(d, p);
156 swma_with_kernel(&i, self.kernel)
157 }
158
159 #[inline(always)]
160 pub fn into_stream(self) -> Result<SwmaStream, SwmaError> {
161 let p = SwmaParams {
162 period: self.period,
163 };
164 SwmaStream::try_new(p)
165 }
166}
167
/// Errors returned by the SWMA functions in this module.
#[derive(Debug, Error)]
pub enum SwmaError {
    /// The input series has no elements.
    #[error("swma: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so there is no first valid index.
    #[error("swma: All values are NaN.")]
    AllValuesNaN,

    /// The period is 0 or larger than the input length. `SwmaStream::try_new` reports
    /// `data_len: 0` because it has no series length.
    #[error(
        "swma: Invalid period: period = {period}, data length = {data_len}. Period must be between 1 and data length."
    )]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` values remain from the first non-NaN value onward.
    #[error("swma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer has the wrong length.
    #[error("swma: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch period sweep expanded to nothing, or a derived buffer size overflowed `usize`.
    #[error("swma: Invalid range expansion: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to `swma_batch_with_kernel`.
    #[error("swma: Invalid kernel passed to batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
196
197#[inline]
198pub fn swma(input: &SwmaInput) -> Result<SwmaOutput, SwmaError> {
199 swma_with_kernel(input, Kernel::Auto)
200}
201
202#[inline]
203fn swma_prepare<'a>(
204 input: &'a SwmaInput,
205 kernel: Kernel,
206) -> Result<(&'a [f64], AVec<f64>, usize, usize, Kernel), SwmaError> {
207 let data: &[f64] = input.as_ref();
208 let len = data.len();
209 if len == 0 {
210 return Err(SwmaError::EmptyInputData);
211 }
212
213 let first = data
214 .iter()
215 .position(|x| !x.is_nan())
216 .ok_or(SwmaError::AllValuesNaN)?;
217 let period = input.get_period();
218
219 if period == 0 || period > len {
220 return Err(SwmaError::InvalidPeriod {
221 period,
222 data_len: len,
223 });
224 }
225 if len - first < period {
226 return Err(SwmaError::NotEnoughValidData {
227 needed: period,
228 valid: len - first,
229 });
230 }
231
232 let weights = build_symmetric_triangle_avec(period);
233 let chosen = match kernel {
234 Kernel::Auto => Kernel::Scalar,
235 k => k,
236 };
237
238 Ok((data, weights, period, first, chosen))
239}
240
241#[inline(always)]
242fn swma_compute_into(
243 data: &[f64],
244 weights: &[f64],
245 period: usize,
246 first: usize,
247 kernel: Kernel,
248 out: &mut [f64],
249) {
250 unsafe {
251 match kernel {
252 Kernel::Scalar | Kernel::ScalarBatch => swma_scalar(data, weights, period, first, out),
253 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
254 Kernel::Avx2 | Kernel::Avx2Batch => swma_avx2(data, weights, period, first, out),
255 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
256 Kernel::Avx512 | Kernel::Avx512Batch => swma_avx512(data, weights, period, first, out),
257 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
258 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
259 swma_scalar(data, weights, period, first, out)
260 }
261 _ => unreachable!(),
262 }
263 }
264}
265
266pub fn swma_with_kernel(input: &SwmaInput, kernel: Kernel) -> Result<SwmaOutput, SwmaError> {
267 let (data, weights, period, first, chosen) = swma_prepare(input, kernel)?;
268
269 let len = data.len();
270 let warm = first + period - 1;
271 let mut out = alloc_with_nan_prefix(len, warm);
272
273 swma_compute_into(data, &weights, period, first, chosen, &mut out);
274
275 Ok(SwmaOutput { values: out })
276}
277
278#[inline]
279pub fn swma_into_slice(dst: &mut [f64], input: &SwmaInput, kern: Kernel) -> Result<(), SwmaError> {
280 let (data, weights, period, first, chosen) = swma_prepare(input, kern)?;
281
282 if dst.len() != data.len() {
283 return Err(SwmaError::OutputLengthMismatch {
284 expected: data.len(),
285 got: dst.len(),
286 });
287 }
288
289 swma_compute_into(data, &weights, period, first, chosen, dst);
290
291 let warmup_end = first + period - 1;
292 for v in &mut dst[..warmup_end] {
293 *v = f64::from_bits(0x7ff8_0000_0000_0000);
294 }
295
296 Ok(())
297}
298
299#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
300#[inline]
301pub fn swma_into(input: &SwmaInput, out: &mut [f64]) -> Result<(), SwmaError> {
302 let (data, weights, period, first, chosen) = swma_prepare(input, Kernel::Auto)?;
303
304 if out.len() != data.len() {
305 return Err(SwmaError::OutputLengthMismatch {
306 expected: data.len(),
307 got: out.len(),
308 });
309 }
310
311 let warm = (first + period - 1).min(out.len());
312 for v in &mut out[..warm] {
313 *v = f64::from_bits(0x7ff8_0000_0000_0000);
314 }
315
316 swma_compute_into(data, &weights, period, first, chosen, out);
317 Ok(())
318}
319
320#[inline(always)]
321fn build_symmetric_triangle_vec(n: usize) -> Vec<f64> {
322 let mut w = Vec::with_capacity(n);
323 if n == 1 {
324 w.push(1.0);
325 } else if n == 2 {
326 w.extend_from_slice(&[0.5, 0.5]);
327 } else if n % 2 == 0 {
328 let half = n / 2;
329 for i in 1..=half {
330 w.push(i as f64);
331 }
332 for i in (1..=half).rev() {
333 w.push(i as f64);
334 }
335 let sum: f64 = triangle_weight_sum(n);
336 for x in &mut w {
337 *x /= sum;
338 }
339 } else {
340 let half_plus = (n + 1) / 2;
341 for i in 1..=half_plus {
342 w.push(i as f64);
343 }
344 for i in (1..half_plus).rev() {
345 w.push(i as f64);
346 }
347 let sum: f64 = triangle_weight_sum(n);
348 for x in &mut w {
349 *x /= sum;
350 }
351 }
352 w
353}
354
/// Sum of the un-normalised symmetric triangle weights of length `n`.
///
/// For odd `n` the weights are `1..=m..=1` with `m = (n + 1) / 2`, giving `m²`. For even `n`
/// they are `1..=h, h..=1` with `h = n / 2`, giving `h·(h + 1)`.
#[inline(always)]
fn triangle_weight_sum(n: usize) -> f64 {
    if n % 2 == 1 {
        let peak = ((n + 1) / 2) as f64;
        peak * peak
    } else {
        let half = (n / 2) as f64;
        half * (half + 1.0)
    }
}
365
366#[inline(always)]
367fn build_symmetric_triangle_avec(n: usize) -> AVec<f64> {
368 let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, n);
369
370 if n == 1 {
371 weights.push(1.0);
372 } else if n == 2 {
373 weights.push(0.5);
374 weights.push(0.5);
375 } else if n % 2 == 0 {
376 let half = n / 2;
377
378 for i in 1..=half {
379 weights.push(i as f64);
380 }
381
382 for i in (1..=half).rev() {
383 weights.push(i as f64);
384 }
385 } else {
386 let half_plus = (n + 1) / 2;
387
388 for i in 1..=half_plus {
389 weights.push(i as f64);
390 }
391
392 for i in (1..half_plus).rev() {
393 weights.push(i as f64);
394 }
395 }
396
397 let sum: f64 = if n <= 2 { 1.0 } else { triangle_weight_sum(n) };
398 for w in weights.iter_mut() {
399 *w /= sum;
400 }
401
402 weights
403}
404
405#[inline]
406pub fn swma_scalar(
407 data: &[f64],
408 _weights: &[f64],
409 period: usize,
410 first_val: usize,
411 out: &mut [f64],
412) {
413 debug_assert!(out.len() >= data.len());
414 debug_assert!(period >= 1);
415
416 let len = data.len();
417 if len == 0 {
418 return;
419 }
420
421 let (a, b) = if (period & 1) != 0 {
422 let m = (period + 1) >> 1;
423 (m, m)
424 } else {
425 let m = period >> 1;
426 (m, m + 1)
427 };
428
429 if period == 1 {
430 unsafe {
431 for i in first_val..len {
432 *out.get_unchecked_mut(i) = *data.get_unchecked(i);
433 }
434 }
435 return;
436 }
437
438 if period == 2 {
439 unsafe {
440 for i in (first_val + 1)..len {
441 *out.get_unchecked_mut(i) =
442 (*data.get_unchecked(i - 1) + *data.get_unchecked(i)) * 0.5;
443 }
444 }
445 return;
446 }
447
448 let inv_ab = 1.0 / ((a as f64) * (b as f64));
449 let start_full_a = first_val + a - 1;
450 let start_full_ab = first_val + period - 1;
451
452 let mut ring = AVec::<f64>::with_capacity(CACHELINE_ALIGN, b);
453 ring.resize(b, 0.0);
454 let mut rb_idx = 0usize;
455
456 let mut s1_sum = 0.0_f64;
457 let mut s2_sum = 0.0_f64;
458
459 unsafe {
460 for i in first_val..len {
461 s1_sum += *data.get_unchecked(i);
462
463 if i >= start_full_a {
464 let old = *ring.get_unchecked(rb_idx);
465 s2_sum = s2_sum + (s1_sum - old);
466 *ring.get_unchecked_mut(rb_idx) = s1_sum;
467
468 rb_idx += 1;
469 if rb_idx == b {
470 rb_idx = 0;
471 }
472
473 if i >= start_full_ab {
474 *out.get_unchecked_mut(i) = s2_sum * inv_ab;
475 }
476
477 s1_sum -= *data.get_unchecked(i + 1 - a);
478 }
479 }
480 }
481}
482
/// AVX-512 entry point for the single-series SWMA.
///
/// Dispatches to the short (`period <= 32`) or long variant. Both are currently
/// `#[target_feature]` wrappers around [`swma_scalar`], so the output equals the scalar kernel's.
///
/// This is a safe public function, so it must never execute `#[target_feature]` code on a CPU
/// lacking those features (that is undefined behaviour). The required features are therefore
/// checked at runtime, falling back to [`swma_scalar`] (same result) when they are missing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn swma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    let has_avx512f_fma = is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma");

    if period <= 32 {
        if has_avx512f_fma {
            // SAFETY: `avx512f` and `fma`, the features `swma_avx512_short` is compiled for,
            // were detected at runtime above.
            unsafe { swma_avx512_short(data, weights, period, first_valid, out) }
        } else {
            swma_scalar(data, weights, period, first_valid, out)
        }
    } else if has_avx512f_fma && is_x86_feature_detected!("avx512dq") {
        // SAFETY: `avx512f`, `avx512dq` and `fma`, the features `swma_avx512_long` is compiled
        // for, were detected at runtime above.
        unsafe { swma_avx512_long(data, weights, period, first_valid, out) }
    } else {
        swma_scalar(data, weights, period, first_valid, out)
    }
}
498
/// AVX2 single-series SWMA kernel.
///
/// Currently a placeholder that delegates to [`swma_scalar`], compiled with AVX2/FMA enabled.
///
/// # Safety
/// The CPU must support `avx2` and `fma`; calling a `#[target_feature]` function otherwise is
/// undefined behaviour. The argument requirements of [`swma_scalar`] also apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn swma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out)
}
510
/// AVX-512 single-series SWMA kernel for short periods (selected by `swma_avx512` when
/// `period <= 32`).
///
/// Currently delegates to [`swma_scalar`].
///
/// # Safety
/// The CPU must support `avx512f` and `fma`. The argument requirements of [`swma_scalar`]
/// also apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn swma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out)
}
522
/// AVX-512 single-series SWMA kernel for long periods (selected by `swma_avx512` when
/// `period > 32`).
///
/// Currently delegates to [`swma_scalar`].
///
/// # Safety
/// The CPU must support `avx512f`, `avx512dq` and `fma`. The argument requirements of
/// [`swma_scalar`] also apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn swma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    out: &mut [f64],
) {
    swma_scalar(data, weights, period, first_valid, out)
}
534
/// Incremental SWMA: feed one value at a time with `update`.
///
/// Uses the same two-box-filter decomposition as the batch kernels: `s1_sum` is a running sum
/// of the last `a` inputs, and `s2_sum` a running sum of the last `b` values of `s1_sum`.
#[derive(Debug, Clone)]
pub struct SwmaStream {
    /// Requested window length (`a + b - 1`).
    period: usize,

    /// Length of the first box filter (inputs).
    a: usize,
    /// Length of the second box filter (values of `s1_sum`).
    b: usize,
    /// `1 / (a * b)`, the normalisation of the triangle.
    inv_ab: f64,

    /// Last `a` inputs, used as a ring buffer.
    ring_a: aligned_vec::AVec<f64>,
    /// Next write position in `ring_a`.
    idx_a: usize,
    /// Number of real values in `ring_a` (saturates at `a`).
    cnt_a: usize,
    /// Sum of the values currently in `ring_a`.
    s1_sum: f64,

    /// Last `b` values of `s1_sum`, used as a ring buffer.
    ring_b: aligned_vec::AVec<f64>,
    /// Next write position in `ring_b`.
    idx_b: usize,
    /// Number of real values in `ring_b` (saturates at `b`).
    cnt_b: usize,
    /// Sum of the values currently in `ring_b`.
    s2_sum: f64,
}
553
impl SwmaStream {
    /// Creates a stream for `params.period` (5 when `None`).
    ///
    /// # Errors
    /// `SwmaError::InvalidPeriod` (with `data_len: 0`) when the period is 0.
    pub fn try_new(params: SwmaParams) -> Result<Self, SwmaError> {
        let period = params.period.unwrap_or(5);
        if period == 0 {
            return Err(SwmaError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Split the triangle into box filters with `a + b - 1 == period`
        // (odd: `a == b`; even: `b == a + 1`).
        let (a, b) = if (period & 1) != 0 {
            let m = (period + 1) >> 1;
            (m, m)
        } else {
            let m = period >> 1;
            (m, m + 1)
        };

        // Rings start zeroed; `cnt_a`/`cnt_b` record how many slots hold real data.
        let mut ring_a = aligned_vec::AVec::<f64>::with_capacity(aligned_vec::CACHELINE_ALIGN, a);
        ring_a.resize(a, 0.0);

        let mut ring_b = aligned_vec::AVec::<f64>::with_capacity(aligned_vec::CACHELINE_ALIGN, b);
        ring_b.resize(b, 0.0);

        Ok(Self {
            period,
            a,
            b,
            inv_ab: 1.0 / ((a as f64) * (b as f64)),
            ring_a,
            idx_a: 0,
            cnt_a: 0,
            s1_sum: 0.0,
            ring_b,
            idx_b: 0,
            cnt_b: 0,
            s2_sum: 0.0,
        })
    }

    /// Feeds `x` and returns the SWMA of the most recent `period` values, or `None` until
    /// `period` values have been fed.
    ///
    /// The running sums are updated incrementally, so once a NaN is fed, every later output is
    /// NaN.
    #[inline(always)]
    pub fn update(&mut self, x: f64) -> Option<f64> {
        let ia = self.idx_a;

        let old_a = self.ring_a[ia];

        // Evict the oldest input only once the first window is full.
        if self.cnt_a == self.a {
            self.s1_sum -= old_a;
        } else {
            self.cnt_a += 1;
        }
        self.ring_a[ia] = x;
        self.s1_sum += x;

        self.idx_a = ia + 1;
        if self.idx_a == self.a {
            self.idx_a = 0;
        }

        // `s1_sum` is a complete box sum: push it into the second filter.
        if self.cnt_a == self.a {
            let ib = self.idx_b;
            let old_s1 = self.ring_b[ib];

            if self.cnt_b == self.b {
                self.s2_sum -= old_s1;
            } else {
                self.cnt_b += 1;
            }
            self.ring_b[ib] = self.s1_sum;
            self.s2_sum += self.s1_sum;

            self.idx_b = ib + 1;
            if self.idx_b == self.b {
                self.idx_b = 0;
            }

            // Both filters full: `s2_sum` is the un-normalised triangle-weighted sum.
            if self.cnt_b == self.b {
                return Some(self.s2_sum * self.inv_ab);
            }
        }

        None
    }
}
638
/// Period sweep for batch SWMA.
#[derive(Clone, Debug)]
pub struct SwmaBatchRange {
    /// `(start, end, step)`. A zero step or `start == end` yields only `start`; `start > end`
    /// walks downward (see `expand_grid`).
    pub period: (usize, usize, usize),
}
643
644impl Default for SwmaBatchRange {
645 fn default() -> Self {
646 Self {
647 period: (5, 254, 1),
648 }
649 }
650}
651
/// Fluent builder for batch (multi-period) SWMA runs.
#[derive(Clone, Debug, Default)]
pub struct SwmaBatchBuilder {
    /// Periods to evaluate.
    range: SwmaBatchRange,
    /// Kernel passed to `swma_batch_with_kernel`.
    kernel: Kernel,
}
657
658impl SwmaBatchBuilder {
659 pub fn new() -> Self {
660 Self::default()
661 }
662 pub fn kernel(mut self, k: Kernel) -> Self {
663 self.kernel = k;
664 self
665 }
666
667 #[inline]
668 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
669 self.range.period = (start, end, step);
670 self
671 }
672 #[inline]
673 pub fn period_static(mut self, p: usize) -> Self {
674 self.range.period = (p, p, 0);
675 self
676 }
677
678 pub fn apply_slice(self, data: &[f64]) -> Result<SwmaBatchOutput, SwmaError> {
679 swma_batch_with_kernel(data, &self.range, self.kernel)
680 }
681
682 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SwmaBatchOutput, SwmaError> {
683 SwmaBatchBuilder::new().kernel(k).apply_slice(data)
684 }
685
686 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SwmaBatchOutput, SwmaError> {
687 let slice = source_type(c, src);
688 self.apply_slice(slice)
689 }
690
691 pub fn with_default_candles(c: &Candles) -> Result<SwmaBatchOutput, SwmaError> {
692 SwmaBatchBuilder::new()
693 .kernel(Kernel::Auto)
694 .apply_candles(c, "close")
695 }
696}
697
698pub fn swma_batch_with_kernel(
699 data: &[f64],
700 sweep: &SwmaBatchRange,
701 k: Kernel,
702) -> Result<SwmaBatchOutput, SwmaError> {
703 let kernel = match k {
704 Kernel::Auto => detect_best_batch_kernel(),
705 other if other.is_batch() => other,
706 _ => return Err(SwmaError::InvalidKernelForBatch(k)),
707 };
708
709 let simd = match kernel {
710 Kernel::Avx512Batch => Kernel::Avx512,
711 Kernel::Avx2Batch => Kernel::Avx2,
712 Kernel::ScalarBatch => Kernel::Scalar,
713 _ => unreachable!(),
714 };
715 swma_batch_par_slice(data, sweep, simd)
716}
717
/// Result of a batch SWMA sweep.
#[derive(Clone, Debug)]
pub struct SwmaBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` starts at `r * cols`.
    pub values: Vec<f64>,
    /// Parameters of each row, in row order.
    pub combos: Vec<SwmaParams>,
    /// Number of rows (one per combo).
    pub rows: usize,
    /// Number of columns (the input length).
    pub cols: usize,
}
725
726impl SwmaBatchOutput {
727 pub fn row_for_params(&self, p: &SwmaParams) -> Option<usize> {
728 self.combos
729 .iter()
730 .position(|c| c.period.unwrap_or(5) == p.period.unwrap_or(5))
731 }
732
733 pub fn values_for(&self, p: &SwmaParams) -> Option<&[f64]> {
734 self.row_for_params(p).map(|row| {
735 let start = row * self.cols;
736 &self.values[start..start + self.cols]
737 })
738 }
739}
740
741#[inline(always)]
742fn expand_grid(r: &SwmaBatchRange) -> Vec<SwmaParams> {
743 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
744 if step == 0 || start == end {
745 return vec![start];
746 }
747 if start < end {
748 return (start..=end).step_by(step.max(1)).collect();
749 }
750
751 let mut v = Vec::new();
752 let mut cur = start;
753 loop {
754 v.push(cur);
755 if cur <= end {
756 break;
757 }
758 match cur.checked_sub(step.max(1)) {
759 Some(next) => {
760 cur = next;
761 if cur < end {
762 break;
763 }
764 }
765 None => break,
766 }
767 }
768 v
769 }
770 let periods = axis_usize(r.period);
771 let mut out = Vec::with_capacity(periods.len());
772 for &p in &periods {
773 out.push(SwmaParams { period: Some(p) });
774 }
775 out
776}
777
778#[inline(always)]
779pub fn swma_batch_slice(
780 data: &[f64],
781 sweep: &SwmaBatchRange,
782 kern: Kernel,
783) -> Result<SwmaBatchOutput, SwmaError> {
784 swma_batch_inner(data, sweep, kern, false)
785}
786
787#[inline(always)]
788pub fn swma_batch_par_slice(
789 data: &[f64],
790 sweep: &SwmaBatchRange,
791 kern: Kernel,
792) -> Result<SwmaBatchOutput, SwmaError> {
793 swma_batch_inner(data, sweep, kern, true)
794}
795
796pub fn swma_batch_into_slice(
797 dst: &mut [f64],
798 data: &[f64],
799 sweep: &SwmaBatchRange,
800 k: Kernel,
801) -> Result<Vec<SwmaParams>, SwmaError> {
802 swma_batch_inner_into(data, sweep, k, true, dst)
803}
804
805#[inline(always)]
806fn swma_batch_inner(
807 data: &[f64],
808 sweep: &SwmaBatchRange,
809 kern: Kernel,
810 parallel: bool,
811) -> Result<SwmaBatchOutput, SwmaError> {
812 let combos = expand_grid(sweep);
813 if combos.is_empty() {
814 let (s, e, t) = sweep.period;
815 return Err(SwmaError::InvalidRange {
816 start: s,
817 end: e,
818 step: t,
819 });
820 }
821
822 let len = data.len();
823 if len == 0 {
824 return Err(SwmaError::EmptyInputData);
825 }
826
827 let first = data
828 .iter()
829 .position(|x| !x.is_nan())
830 .ok_or(SwmaError::AllValuesNaN)?;
831 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
832
833 if max_p == 0 || max_p > len {
834 return Err(SwmaError::InvalidPeriod {
835 period: max_p,
836 data_len: len,
837 });
838 }
839 if len - first < max_p {
840 return Err(SwmaError::NotEnoughValidData {
841 needed: max_p,
842 valid: len - first,
843 });
844 }
845
846 let rows = combos.len();
847 let cols = data.len();
848 let cap = rows.checked_mul(max_p).ok_or_else(|| {
849 let (s, e, t) = sweep.period;
850 SwmaError::InvalidRange {
851 start: s,
852 end: e,
853 step: t,
854 }
855 })?;
856 let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
857 flat_w.resize(cap, 0.0);
858
859 for (row, combo) in combos.iter().enumerate() {
860 let period = combo.period.unwrap();
861 let w_start = row * max_p;
862
863 if period == 1 {
864 flat_w[w_start] = 1.0;
865 } else if period == 2 {
866 flat_w[w_start] = 0.5;
867 flat_w[w_start + 1] = 0.5;
868 } else if period % 2 == 0 {
869 let half = period / 2;
870
871 for i in 1..=half {
872 flat_w[w_start + i - 1] = i as f64;
873 }
874
875 for i in (1..=half).rev() {
876 flat_w[w_start + period - i] = i as f64;
877 }
878
879 let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
880 for i in 0..period {
881 flat_w[w_start + i] /= sum;
882 }
883 } else {
884 let half_plus = (period + 1) / 2;
885
886 for i in 1..=half_plus {
887 flat_w[w_start + i - 1] = i as f64;
888 }
889
890 for i in (1..half_plus).rev() {
891 flat_w[w_start + period - i] = i as f64;
892 }
893
894 let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
895 for i in 0..period {
896 flat_w[w_start + i] /= sum;
897 }
898 }
899 }
900
901 let warm: Vec<usize> = combos
902 .iter()
903 .map(|c| first + c.period.unwrap() - 1)
904 .collect();
905
906 let _ = rows.checked_mul(cols).ok_or_else(|| {
907 let (s, e, t) = sweep.period;
908 SwmaError::InvalidRange {
909 start: s,
910 end: e,
911 step: t,
912 }
913 })?;
914 let mut buf_mu = make_uninit_matrix(rows, cols);
915 init_matrix_prefixes(&mut buf_mu, cols, &warm);
916
917 let actual_kern = match kern {
918 Kernel::Auto => detect_best_batch_kernel(),
919 k => k,
920 };
921 let simd = match actual_kern {
922 Kernel::Avx512Batch => Kernel::Avx512,
923 Kernel::Avx2Batch => Kernel::Avx2,
924 Kernel::ScalarBatch => Kernel::Scalar,
925
926 other => other,
927 };
928
929 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
930 let period = combos[row].period.unwrap();
931 let w_ptr = flat_w.as_ptr().add(row * max_p);
932 let out_row =
933 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
934 match simd {
935 Kernel::Scalar => swma_row_scalar(data, first, period, w_ptr, out_row),
936 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
937 Kernel::Avx2 => swma_row_avx2(data, first, period, w_ptr, out_row),
938 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
939 Kernel::Avx512 => swma_row_avx512(data, first, period, w_ptr, out_row),
940 _ => swma_row_scalar(data, first, period, w_ptr, out_row),
941 }
942 };
943
944 {
945 use std::mem::MaybeUninit;
946 let rows_mut: &mut [MaybeUninit<f64>] = &mut buf_mu;
947 #[cfg(not(target_arch = "wasm32"))]
948 if parallel {
949 use rayon::prelude::*;
950 rows_mut
951 .par_chunks_mut(cols)
952 .enumerate()
953 .for_each(|(row, slice)| do_row(row, slice));
954 } else {
955 for (row, slice) in rows_mut.chunks_mut(cols).enumerate() {
956 do_row(row, slice);
957 }
958 }
959 #[cfg(target_arch = "wasm32")]
960 {
961 for (row, slice) in rows_mut.chunks_mut(cols).enumerate() {
962 do_row(row, slice);
963 }
964 }
965 }
966
967 use core::mem::ManuallyDrop;
968 let mut guard = ManuallyDrop::new(buf_mu);
969 let values = unsafe {
970 Vec::from_raw_parts(
971 guard.as_mut_ptr() as *mut f64,
972 guard.len(),
973 guard.capacity(),
974 )
975 };
976
977 Ok(SwmaBatchOutput {
978 values,
979 combos,
980 rows,
981 cols,
982 })
983}
984
985#[inline(always)]
986fn swma_batch_inner_into(
987 data: &[f64],
988 sweep: &SwmaBatchRange,
989 kern: Kernel,
990 parallel: bool,
991 out: &mut [f64],
992) -> Result<Vec<SwmaParams>, SwmaError> {
993 let combos = expand_grid(sweep);
994 if combos.is_empty() {
995 let (s, e, t) = sweep.period;
996 return Err(SwmaError::InvalidRange {
997 start: s,
998 end: e,
999 step: t,
1000 });
1001 }
1002
1003 let len = data.len();
1004 if len == 0 {
1005 return Err(SwmaError::EmptyInputData);
1006 }
1007
1008 let first = data
1009 .iter()
1010 .position(|x| !x.is_nan())
1011 .ok_or(SwmaError::AllValuesNaN)?;
1012 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1013
1014 if max_p == 0 || max_p > len {
1015 return Err(SwmaError::InvalidPeriod {
1016 period: max_p,
1017 data_len: len,
1018 });
1019 }
1020 if len - first < max_p {
1021 return Err(SwmaError::NotEnoughValidData {
1022 needed: max_p,
1023 valid: len - first,
1024 });
1025 }
1026
1027 let rows = combos.len();
1028 let cols = data.len();
1029 let cap = rows.checked_mul(max_p).ok_or_else(|| {
1030 let (s, e, t) = sweep.period;
1031 SwmaError::InvalidRange {
1032 start: s,
1033 end: e,
1034 step: t,
1035 }
1036 })?;
1037 let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
1038 flat_w.resize(cap, 0.0);
1039
1040 for (row, combo) in combos.iter().enumerate() {
1041 let period = combo.period.unwrap();
1042 let w_start = row * max_p;
1043
1044 if period == 1 {
1045 flat_w[w_start] = 1.0;
1046 } else if period == 2 {
1047 flat_w[w_start] = 0.5;
1048 flat_w[w_start + 1] = 0.5;
1049 } else if period % 2 == 0 {
1050 let half = period / 2;
1051
1052 for i in 1..=half {
1053 flat_w[w_start + i - 1] = i as f64;
1054 }
1055
1056 for i in (1..=half).rev() {
1057 flat_w[w_start + period - i] = i as f64;
1058 }
1059
1060 let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
1061 for i in 0..period {
1062 flat_w[w_start + i] /= sum;
1063 }
1064 } else {
1065 let half_plus = (period + 1) / 2;
1066
1067 for i in 1..=half_plus {
1068 flat_w[w_start + i - 1] = i as f64;
1069 }
1070
1071 for i in (1..half_plus).rev() {
1072 flat_w[w_start + period - i] = i as f64;
1073 }
1074
1075 let sum: f64 = flat_w[w_start..w_start + period].iter().sum();
1076 for i in 0..period {
1077 flat_w[w_start + i] /= sum;
1078 }
1079 }
1080 }
1081
1082 let warm: Vec<usize> = combos
1083 .iter()
1084 .map(|c| first + c.period.unwrap() - 1)
1085 .collect();
1086 let expected_len = rows.checked_mul(cols).ok_or_else(|| {
1087 let (s, e, t) = sweep.period;
1088 SwmaError::InvalidRange {
1089 start: s,
1090 end: e,
1091 step: t,
1092 }
1093 })?;
1094 if out.len() != expected_len {
1095 return Err(SwmaError::OutputLengthMismatch {
1096 expected: expected_len,
1097 got: out.len(),
1098 });
1099 }
1100 let out_uninit = unsafe {
1101 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1102 };
1103 init_matrix_prefixes(out_uninit, cols, &warm);
1104
1105 let actual_kern = match kern {
1106 Kernel::Auto => detect_best_batch_kernel(),
1107 k => k,
1108 };
1109 let simd = match actual_kern {
1110 Kernel::Avx512Batch => Kernel::Avx512,
1111 Kernel::Avx2Batch => Kernel::Avx2,
1112 Kernel::ScalarBatch => Kernel::Scalar,
1113 other => other,
1114 };
1115
1116 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1117 let period = combos[row].period.unwrap();
1118 let w_ptr = flat_w.as_ptr().add(row * max_p);
1119 let out_row =
1120 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1121 match simd {
1122 Kernel::Scalar => swma_row_scalar(data, first, period, w_ptr, out_row),
1123 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1124 Kernel::Avx2 => swma_row_avx2(data, first, period, w_ptr, out_row),
1125 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1126 Kernel::Avx512 => swma_row_avx512(data, first, period, w_ptr, out_row),
1127 _ => swma_row_scalar(data, first, period, w_ptr, out_row),
1128 }
1129 };
1130
1131 if parallel {
1132 #[cfg(not(target_arch = "wasm32"))]
1133 {
1134 out_uninit
1135 .par_chunks_mut(cols)
1136 .enumerate()
1137 .for_each(|(row, slice)| do_row(row, slice));
1138 }
1139 #[cfg(target_arch = "wasm32")]
1140 {
1141 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1142 do_row(row, slice);
1143 }
1144 }
1145 } else {
1146 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1147 do_row(row, slice);
1148 }
1149 }
1150
1151 Ok(combos)
1152}
1153
/// Scalar SWMA kernel for one row of the batch matrix.
///
/// Same algorithm as `swma_scalar`: the length-`period` triangle is split into two box filters
/// of lengths `a` and `b` (`a + b - 1 == period`) evaluated with running sums; periods 1 and 2
/// use direct formulas. `_w_ptr` is not read. Only `out[first + period - 1..data.len()]` is
/// written; the caller initialises the warm-up prefix.
///
/// For periods of 3 or more, a NaN input at or after `first` makes every later output NaN,
/// because the running sums never drop it.
///
/// # Safety
/// * `out.len() >= data.len()`: writes use `get_unchecked_mut` up to index `data.len() - 1`.
/// * `period >= 1`: a zero period leads to arithmetic underflow or an out-of-bounds read of
///   `data`.
#[inline(always)]
unsafe fn swma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    _w_ptr: *const f64,
    out: &mut [f64],
) {
    let len = data.len();
    if len == 0 {
        return;
    }

    // Box-filter lengths whose convolution is the length-`period` triangle.
    let (a, b) = if (period & 1) != 0 {
        let m = (period + 1) >> 1;
        (m, m)
    } else {
        let m = period >> 1;
        (m, m + 1)
    };

    if period == 1 {
        for i in first..len {
            *out.get_unchecked_mut(i) = *data.get_unchecked(i);
        }
        return;
    }
    if period == 2 {
        for i in (first + 1)..len {
            *out.get_unchecked_mut(i) = (*data.get_unchecked(i - 1) + *data.get_unchecked(i)) * 0.5;
        }
        return;
    }

    let inv_ab = 1.0 / ((a as f64) * (b as f64));
    // First index where `s1_sum` spans `a` inputs, and where the full triangle is available.
    let start_full_a = first + a - 1;
    let start_full_ab = first + period - 1;

    // Ring of the last `b` values of `s1_sum`.
    let mut ring = AVec::<f64>::with_capacity(CACHELINE_ALIGN, b);
    ring.resize(b, 0.0);
    let mut rb_idx = 0usize;

    let mut s1_sum = 0.0_f64;
    let mut s2_sum = 0.0_f64;

    for i in first..len {
        s1_sum += *data.get_unchecked(i);

        if i >= start_full_a {
            let old = *ring.get_unchecked(rb_idx);
            s2_sum = s2_sum + (s1_sum - old);
            *ring.get_unchecked_mut(rb_idx) = s1_sum;
            rb_idx += 1;
            if rb_idx == b {
                rb_idx = 0;
            }

            if i >= start_full_ab {
                *out.get_unchecked_mut(i) = s2_sum * inv_ab;
            }

            // Drop the oldest input so the next add keeps the window at `a` values.
            s1_sum -= *data.get_unchecked(i + 1 - a);
        }
    }
}
1219
/// AVX2 batch-row SWMA kernel; currently delegates to `swma_row_scalar`.
///
/// # Safety
/// The CPU must support `avx2` and `fma`, and the requirements of `swma_row_scalar` apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn swma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    swma_row_scalar(data, first, period, w_ptr, out)
}
1231
/// AVX-512 batch-row SWMA kernel, dispatching to the long (`period > 32`) or short variant.
///
/// # Safety
/// The CPU must support `avx512f`, `avx512dq` and `fma`, and the requirements of
/// `swma_row_scalar` apply (`out.len() >= data.len()`, `period >= 1`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn swma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    if period > 32 {
        swma_row_avx512_long(data, first, period, w_ptr, out);
    } else {
        swma_row_avx512_short(data, first, period, w_ptr, out);
    }
}
1247
/// AVX-512 batch-row kernel for `period <= 32`; currently delegates to `swma_row_scalar`.
///
/// # Safety
/// The CPU must support `avx512f` and `fma`, and the requirements of `swma_row_scalar` apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn swma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    swma_row_scalar(data, first, period, w_ptr, out)
}
1259
/// AVX-512 batch-row kernel for `period > 32`; currently delegates to `swma_row_scalar`.
///
/// # Safety
/// The CPU must support `avx512f`, `avx512dq` and `fma`, and the requirements of
/// `swma_row_scalar` apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
unsafe fn swma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    out: &mut [f64],
) {
    swma_row_scalar(data, first, period, w_ptr, out)
}
1271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// 4h Bitfinex candle fixture shared by every candle-based test below.
    const CANDLE_FIXTURE: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Reference SWMA values (default parameters) for the last five closes of
    /// `CANDLE_FIXTURE`; shared by the single-series and batch accuracy tests.
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        59288.22222222222,
        59301.99999999999,
        59247.33333333333,
        59179.88888888889,
        59080.99999999999,
    ];

    /// Debug-build poison bit patterns, each paired with the allocation helper
    /// named in the failure message. Presumably a surviving pattern means an
    /// output cell was never written.
    #[cfg(debug_assertions)]
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Returns the helper name whose poison pattern equals `bits`, if any.
    #[cfg(debug_assertions)]
    fn poison_origin(bits: u64) -> Option<&'static str> {
        POISON_PATTERNS
            .iter()
            .find(|(pattern, _)| *pattern == bits)
            .map(|(_, helper)| *helper)
    }

    fn check_swma_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let default_params = SwmaParams { period: None };
        let input = SwmaInput::from_candles(&candles, "close", default_params);
        let output = swma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_swma_accuracy(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let input = SwmaInput::from_candles(&candles, "close", SwmaParams::default());
        let result = swma_with_kernel(&input, kernel)?;
        let start = result.values.len().saturating_sub(EXPECTED_LAST_FIVE.len());
        for (i, (&val, &expected)) in result.values[start..]
            .iter()
            .zip(EXPECTED_LAST_FIVE.iter())
            .enumerate()
        {
            let diff = (val - expected).abs();
            assert!(
                diff < 1e-8,
                "[{}] SWMA {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected
            );
        }
        Ok(())
    }

    fn check_swma_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let input = SwmaInput::with_default_candles(&candles);
        match input.data {
            SwmaData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected SwmaData::Candles"),
        }
        let output = swma_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_swma_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = SwmaParams { period: Some(0) };
        let input = SwmaInput::from_slice(&input_data, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_swma_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = SwmaParams { period: Some(10) };
        let input = SwmaInput::from_slice(&data_small, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_swma_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = SwmaParams { period: Some(5) };
        let input = SwmaInput::from_slice(&single_point, params);
        let res = swma_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] SWMA should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeds SWMA output back in as input to make sure a NaN warm-up prefix is
    /// accepted and the length is preserved.
    fn check_swma_reinput(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let first_params = SwmaParams { period: Some(5) };
        let first_input = SwmaInput::from_candles(&candles, "close", first_params);
        let first_result = swma_with_kernel(&first_input, kernel)?;
        let second_params = SwmaParams { period: Some(3) };
        let second_input = SwmaInput::from_slice(&first_result.values, second_params);
        let second_result = swma_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        Ok(())
    }

    fn check_swma_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let params = SwmaParams { period: Some(5) };
        let input = SwmaInput::from_candles(&candles, "close", params);
        let res = swma_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        // Well past any warm-up, no NaN may appear.
        for (i, &val) in res.values.iter().enumerate().skip(240) {
            assert!(
                !val.is_nan(),
                "[{}] Found unexpected NaN at out-index {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    /// The streaming API must reproduce the batch output value-for-value.
    fn check_swma_streaming(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;
        let period = 5;
        let input = SwmaInput::from_candles(
            &candles,
            "close",
            SwmaParams {
                period: Some(period),
            },
        );
        let batch_output = swma_with_kernel(&input, kernel)?.values;
        let mut stream = SwmaStream::try_new(SwmaParams {
            period: Some(period),
        })?;
        // `None` from the stream corresponds to a NaN slot in the batch output.
        let stream_values: Vec<f64> = candles
            .close
            .iter()
            .map(|&price| stream.update(price).unwrap_or(f64::NAN))
            .collect();
        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] SWMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// Expands each check into one `#[test]` per kernel. The generated tests
    /// return the check's `Result`, so `Err` values (e.g. a missing fixture or
    /// an SWMA error) fail the test instead of being discarded. The SIMD
    /// variants are gated individually so each of them is compiled only
    /// when `nightly-avx` on x86_64 is enabled.
    macro_rules! generate_all_swma_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    /// Debug builds only: no output value may carry an allocation poison pattern.
    #[cfg(debug_assertions)]
    fn check_swma_no_poison(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;

        for period in [1, 2, 3, 5, 7, 10, 15, 20, 30, 50, 100] {
            if period > candles.close.len() {
                continue;
            }

            let params = SwmaParams {
                period: Some(period),
            };
            let input = SwmaInput::from_candles(&candles, "close", params);
            let output = swma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(helper) = poison_origin(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, helper, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_swma_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Randomized checks against the scalar kernel, the window bounds, the
    /// weight invariants and a few closed-form special cases.
    #[cfg(feature = "proptest")]
    fn check_swma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period.max(2)..400,
                ),
                Just(period),
            )
        });

        // ULP tolerance against the scalar reference (wider for AVX-512).
        let max_ulp: u64 = if matches!(kernel, Kernel::Avx512) {
            20
        } else {
            10
        };

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = SwmaParams {
                    period: Some(period),
                };
                let input = SwmaInput::from_slice(&data, params);

                let SwmaOutput { values: out } = swma_with_kernel(&input, kernel).unwrap();
                let SwmaOutput { values: ref_out } =
                    swma_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                // Indices before the first full window must be NaN.
                for (i, &v) in out.iter().enumerate().take(period - 1) {
                    prop_assert!(
                        v.is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        v
                    );
                }

                let weights = build_symmetric_triangle_avec(period);

                let weight_sum: f64 = weights.iter().sum();
                prop_assert!(
                    (weight_sum - 1.0).abs() < 1e-10,
                    "Weights don't sum to 1.0, got {}",
                    weight_sum
                );

                for i in 0..period / 2 {
                    let left = weights[i];
                    let right = weights[period - 1 - i];
                    prop_assert!(
                        (left - right).abs() < 1e-10,
                        "Weights not symmetric at positions {} and {}: {} vs {}",
                        i,
                        period - 1 - i,
                        left,
                        right
                    );
                }

                // Computed once instead of once per output index.
                let is_constant = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);

                for i in (period - 1)..data.len() {
                    let window = &data[i + 1 - period..=i];
                    let lo = window.iter().copied().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                        "idx {}: {} ∉ [{}, {}]",
                        i,
                        y,
                        lo,
                        hi
                    );

                    if period == 1 {
                        prop_assert!(
                            (y - data[i]).abs() <= f64::EPSILON,
                            "Period=1 should return input value at idx {}: {} vs {}",
                            i,
                            y,
                            data[i]
                        );
                    }

                    if period == 2 && i >= 1 {
                        let expected = (data[i - 1] + data[i]) / 2.0;
                        prop_assert!(
                            (y - expected).abs() < 1e-9,
                            "Period=2 should return average at idx {}: {} vs {}",
                            i,
                            y,
                            expected
                        );
                    }

                    if is_constant {
                        prop_assert!(
                            (y - data[0]).abs() < 1e-9,
                            "Constant data should produce constant output at idx {}: {} vs {}",
                            i,
                            y,
                            data[0]
                        );
                    }

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff = y.to_bits().abs_diff(r.to_bits());
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= max_ulp,
                        "mismatch idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_swma_tests!(
        check_swma_partial_params,
        check_swma_accuracy,
        check_swma_default_candles,
        check_swma_zero_period,
        check_swma_period_exceeds_length,
        check_swma_very_small_dataset,
        check_swma_reinput,
        check_swma_nan_handling,
        check_swma_streaming,
        check_swma_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_swma_tests!(check_swma_property);

    /// The default-parameter row of a batch run must match the single-series
    /// reference values.
    fn check_batch_default_row(
        test: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_FIXTURE)?;

        let output = SwmaBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = SwmaParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let tail = &row[row.len() - EXPECTED_LAST_FIVE.len()..];
        for (i, (&v, &expected)) in tail.iter().zip(EXPECTED_LAST_FIVE.iter()).enumerate() {
            assert!(
                (v - expected).abs() < 1e-8,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected}"
            );
        }
        Ok(())
    }

    /// Expands a batch check into one `#[test]` per batch kernel (plus
    /// `Kernel::Auto`); like `generate_all_swma_tests!`, errors are propagated.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }

    /// Debug builds only: no cell of any batch matrix may carry a poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLE_FIXTURE)?;

        let batch_configs = [
            (1, 10, 1),
            (3, 9, 3),
            (5, 25, 5),
            (10, 50, 10),
            (2, 2, 1),
            (1, 30, 2),
        ];

        for (start, end, step) in batch_configs {
            if end > c.close.len() {
                continue;
            }

            let output = SwmaBatchBuilder::new()
                .kernel(kernel)
                .period_range(start, end, step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                let Some(helper) = poison_origin(bits) else {
                    continue;
                };

                let row = idx / output.cols;
                let col = idx % output.cols;
                let period = output
                    .combos
                    .get(row)
                    .and_then(|p| p.period)
                    .unwrap_or(0);

                panic!(
                    "[{}] Found {} poison value {} (0x{:016X}) at row {} col {} (flat index {}) with period {} in batch ({}, {}, {})",
                    test, helper, val, bits, row, col, idx, period, start, end, step
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(
        _test: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// The caller-provided-buffer API must reproduce `swma` bit-for-bit.
    #[test]
    fn test_swma_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
        let candles = read_candles_from_csv(CANDLE_FIXTURE)?;

        let input = SwmaInput::with_default_candles(&candles);
        let baseline = swma(&input)?.values;

        let mut out = vec![0.0f64; baseline.len()];

        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            swma_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            swma_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(out.len(), baseline.len());

        for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
            let equal = (a.is_nan() && b.is_nan()) || (a == b);
            assert!(
                equal,
                "into parity mismatch at idx {}: got {}, expected {}",
                i, a, b
            );
        }

        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1901
/// Python binding: SWMA of a 1-D float64 array.
///
/// `kernel` is an optional kernel name checked by `validate_kernel` in
/// non-batch mode. The computation runs with the GIL released; any error
/// from `swma_with_kernel` is raised as `ValueError`. Returns a new 1-D
/// array holding the SWMA values.
#[cfg(feature = "python")]
#[pyfunction(name = "swma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn swma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let values = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = SwmaInput::from_slice(
        values,
        SwmaParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| swma_with_kernel(&input, chosen_kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1928
#[cfg(feature = "python")]
#[pyclass(name = "SwmaStream")]
/// Python-facing wrapper, exposed as `SwmaStream`, around the Rust
/// `SwmaStream` incremental calculator.
pub struct SwmaStreamPy {
    /// Wrapped Rust streaming state; new samples are fed through `update`.
    stream: SwmaStream,
}
1934
#[cfg(feature = "python")]
#[pymethods]
impl SwmaStreamPy {
    /// Create a stream for the given `period`.
    ///
    /// Raises `ValueError` when `SwmaStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        SwmaStream::try_new(SwmaParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Push one sample and return what `SwmaStream::update` produced
    /// (`None` while the stream has no value to report).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1952
/// Python binding: evaluate SWMA for every period produced by `period_range`.
///
/// `period_range` is `(start, end, step)` and is expanded with `expand_grid`.
/// Returns a dict with `"values"`, a `(rows, len(data))` float64 matrix with
/// one row per expanded period, and `"periods"`, the period of each row.
/// Allocation overflow and batch-kernel errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "swma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn swma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let input = data.as_slice()?;
    let batch_kernel = validate_kernel(kernel, true)?;
    let sweep = SwmaBatchRange {
        period: period_range,
    };

    let n_rows = expand_grid(&sweep).len();
    let n_cols = input.len();
    let Some(total) = n_rows.checked_mul(n_cols) else {
        return Err(PyValueError::new_err(
            "swma: rows*cols overflow during allocation",
        ));
    };

    // SAFETY: the array starts uninitialized; `swma_batch_inner_into` receives
    // the full output slice and presumably writes every cell (TODO confirm).
    // On error the array is dropped without ever reaching Python.
    let matrix = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `matrix` was just created here, so no other view of it exists.
    let flat_out = unsafe { matrix.as_slice_mut()? };

    let combos = py
        .allow_threads(|| swma_batch_inner_into(input, &sweep, batch_kernel, true, flat_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<_> = combos.iter().map(|c| c.period.unwrap_or(5)).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", matrix.reshape((n_rows, n_cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
2000
/// Python binding: run the SWMA period sweep on a CUDA device and return the
/// result as a device-resident float32 array handle.
///
/// The float64 input is narrowed to float32 on the host before the device
/// call. Raises `ValueError` when CUDA is unavailable or when device setup or
/// the batch computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "swma_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn swma_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f64>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SwmaPy> {
    use numpy::PyArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data.as_slice()?;
    let sweep = SwmaBatchRange {
        period: period_range,
    };
    let host_f32 = host.iter().map(|&x| x as f32).collect::<Vec<f32>>();

    // Device work happens with the GIL released.
    let (device_buf, context, ordinal) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let ordinal = cuda.device_id();
        let device_buf = cuda
            .swma_batch_dev(&host_f32, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((device_buf, context, ordinal))
    })?;

    let wrapped = DeviceArrayF32Py {
        inner: device_buf,
        _ctx: Some(context),
        device_id: Some(ordinal),
    };
    Ok(DeviceArrayF32SwmaPy {
        inner: Some(wrapped),
    })
}
2039
/// Python binding: SWMA with one `period` applied to many series stored in a
/// contiguous 2-D float32 array, computed on a CUDA device.
///
/// The array has shape `(rows, cols)`. Per the `_tm_` ("time-major") naming,
/// rows are presumably time steps and cols are series — TODO confirm against
/// `CudaSwma::swma_multi_series_one_param_time_major_dev`, which receives
/// `(cols, rows)` in that order. Raises `ValueError` when CUDA is unavailable,
/// the array is not contiguous, or the device call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "swma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn swma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SwmaPy> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = SwmaParams {
        period: Some(period),
    };

    // Device work happens with the GIL released.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaSwma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id();
        cuda.swma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map(|inner| (inner, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SwmaPy {
        inner: Some(DeviceArrayF32Py {
            inner,
            _ctx: Some(ctx),
            device_id: Some(dev_id),
        }),
    })
}
2079
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "DeviceArrayF32Swma", unsendable)]
/// Python handle (`ta_indicators.cuda.DeviceArrayF32Swma`) for a float32
/// device array produced by the SWMA CUDA bindings.
pub struct DeviceArrayF32SwmaPy {
    /// Wrapped device array. `None` once it has been exported through
    /// `__dlpack__`; the protocol methods then raise `ValueError`.
    pub(crate) inner: Option<DeviceArrayF32Py>,
}
2085
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SwmaPy {
    /// CUDA Array Interface dict of the wrapped buffer, delegated to the inner
    /// `DeviceArrayF32Py`. Raises `ValueError` after a successful DLPack export.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__cuda_array_interface__(py)
    }

    /// DLPack device tuple, delegated to the inner `DeviceArrayF32Py`.
    /// Raises `ValueError` after a successful DLPack export.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        let inner = self
            .inner
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        inner.__dlpack_device__()
    }

    /// Export the buffer as a DLPack capsule (one-shot).
    ///
    /// The wrapper gives up its handle only after the inner export succeeds.
    /// If the inner export returns an error, the buffer stays attached, so the
    /// object remains usable and does not falsely report itself as exported.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let inner = self
            .inner
            .as_mut()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?;
        let capsule = inner.__dlpack__(py, stream, max_version, dl_device, copy)?;
        // Export succeeded: release our handle so later calls report it.
        self.inner = None;
        Ok(capsule)
    }
}
2123
/// WASM binding: SWMA of `data` with the given `period` using
/// `Kernel::Auto`. Errors are returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = SwmaInput::from_slice(
        data,
        SwmaParams {
            period: Some(period),
        },
    );

    let mut result = vec![0.0; data.len()];
    match swma_into_slice(&mut result, &input, Kernel::Auto) {
        Ok(_) => Ok(result),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2139
/// WASM binding: run the SWMA period sweep `(period_start, period_end,
/// period_step)` with `Kernel::Auto` and return the flat `values` buffer of
/// the batch output. Errors are returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = (period_start, period_end, period_step);
    let sweep = SwmaBatchRange { period: range };
    match swma_batch_with_kernel(data, &sweep, Kernel::Auto) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2155
/// WASM binding: the period of each row that `swma_batch_js` would produce
/// for the same range, as `f64` values in expansion order (a combo without a
/// period reports 5).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = SwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let periods = expand_grid(&sweep)
        .into_iter()
        .map(|combo| combo.period.unwrap_or(5) as f64)
        .collect::<Vec<f64>>();

    Ok(periods)
}
2176
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Configuration object accepted by the JS `swma_batch` binding.
pub struct SwmaBatchConfig {
    /// `(start, end, step)` of the period sweep.
    pub period_range: (usize, usize, usize),
}
2182
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Serialized result returned by the JS `swma_batch` binding.
pub struct SwmaBatchJsOutput {
    /// Flat batch values copied from the batch output.
    pub values: Vec<f64>,
    /// Parameter set for each row.
    pub combos: Vec<SwmaParams>,
    /// Row count of the batch output.
    pub rows: usize,
    /// Column count of the batch output.
    pub cols: usize,
}
2191
/// WASM binding (`swma_batch` in JS): deserialize a `SwmaBatchConfig`, run
/// the sweep with `Kernel::Auto`, and return a serialized `SwmaBatchJsOutput`.
/// Config, computation and serialization failures are returned as string
/// `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = swma_batch)]
pub fn swma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SwmaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = SwmaBatchRange {
        period: parsed.period_range,
    };

    let batch = swma_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = SwmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    match serde_wasm_bindgen::to_value(&payload) {
        Ok(js) => Ok(js),
        Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
    }
}
2215
/// Allocate an uninitialized buffer with room for `len` `f64` values and hand
/// ownership of it to the caller (the allocation is intentionally leaked).
///
/// Release it with `swma_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2224
/// Release a buffer previously returned by `swma_alloc(len)`.
///
/// `len` must be the value originally passed to `swma_alloc`. A null `ptr`
/// is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `swma_alloc(len)`, i.e. from a `Vec<f64>` with
    // capacity `len`. The Vec is rebuilt with length 0 because the caller may
    // never have written to the buffer, and `from_raw_parts` requires the
    // first `length` elements to be initialized. Dropping it only deallocates.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2234
/// WASM raw-pointer binding: compute SWMA of `len` values at `in_ptr` into
/// `len` slots at `out_ptr`.
///
/// Any overlap between the input and output ranges is supported, not only
/// `in_ptr == out_ptr`. Overlapping calls compute into a temporary buffer so
/// that no shared and mutable slices ever alias. Returns `"Null pointer
/// provided"` for null pointers and `"Invalid period"` when `period` is 0 or
/// exceeds `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = SwmaParams {
        period: Some(period),
    };
    let input = SwmaInput::from_slice(data, params);

    // Byte ranges [in, in+bytes) and [out, out+bytes) intersect?
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(byte_len)
        && out_addr < in_addr.saturating_add(byte_len);

    if overlaps {
        let mut temp = vec![0.0; len];
        swma_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable
        // f64s; the input view is no longer read past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable
        // f64s, and the ranges were just shown to be disjoint.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        swma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2275
/// WASM raw-pointer binding: run the SWMA period sweep over `len` values at
/// `in_ptr` and write the `rows * len` results to `out_ptr`.
///
/// Returns the number of rows (expanded periods). When the input range
/// overlaps the output range, the input is first copied to an owned buffer.
/// Writing the rows would otherwise clobber input that is still being read,
/// and the shared and mutable slices would alias.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn swma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to swma_batch_into"));
    }

    let sweep = SwmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let rows = expand_grid(&sweep).len();
    if rows == 0 {
        return Err(JsValue::from_str(
            "swma: invalid period range (empty expansion)",
        ));
    }
    let rows_cols = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("swma: rows*cols overflow"))?;

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    // Byte ranges [in, in+len) and [out, out+rows*len) intersect?
    let elem = std::mem::size_of::<f64>();
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(rows_cols.saturating_mul(elem))
        && out_addr < in_addr.saturating_add(len.saturating_mul(elem));

    let outcome = if overlaps {
        let owned = data.to_vec();
        // SAFETY: the caller guarantees `out_ptr` addresses `rows_cols`
        // writable f64s; the borrowed input is not read past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, rows_cols) };
        swma_batch_inner_into(&owned, &sweep, Kernel::Auto, false, out)
    } else {
        // SAFETY: the caller guarantees `out_ptr` addresses `rows_cols`
        // writable f64s, and the ranges were just shown to be disjoint.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, rows_cols) };
        swma_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
    };
    outcome.map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}