1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use thiserror::Error;
13
14impl<'a> AsRef<[f64]> for MomInput<'a> {
15 #[inline(always)]
16 fn as_ref(&self) -> &[f64] {
17 match &self.data {
18 MomData::Slice(slice) => slice,
19 MomData::Candles { candles, source } => source_type(candles, source),
20 }
21 }
22}
23
/// Where a momentum calculation reads its input series from.
#[derive(Debug, Clone)]
pub enum MomData<'a> {
    /// A candle set plus the name of the field to extract from it
    /// (the default constructors in this file use `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series of values.
    Slice(&'a [f64]),
}
32
/// Result of a single-series momentum computation.
#[derive(Debug, Clone)]
pub struct MomOutput {
    /// One value per input element; the warm-up prefix
    /// (`first valid index + period` entries) is NaN.
    pub values: Vec<f64>,
}
37
/// Tunable parameters of the momentum indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MomParams {
    /// Look-back distance in samples. Consumers in this file treat `None` as 10.
    pub period: Option<usize>,
}

impl Default for MomParams {
    /// Produces parameters with the period explicitly set to 10.
    fn default() -> Self {
        MomParams { period: Some(10) }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct MomInput<'a> {
55 pub data: MomData<'a>,
56 pub params: MomParams,
57}
58
59impl<'a> MomInput<'a> {
60 #[inline]
61 pub fn from_candles(c: &'a Candles, s: &'a str, p: MomParams) -> Self {
62 Self {
63 data: MomData::Candles {
64 candles: c,
65 source: s,
66 },
67 params: p,
68 }
69 }
70 #[inline]
71 pub fn from_slice(sl: &'a [f64], p: MomParams) -> Self {
72 Self {
73 data: MomData::Slice(sl),
74 params: p,
75 }
76 }
77 #[inline]
78 pub fn with_default_candles(c: &'a Candles) -> Self {
79 Self::from_candles(c, "close", MomParams::default())
80 }
81 #[inline]
82 pub fn get_period(&self) -> usize {
83 self.params.period.unwrap_or(10)
84 }
85}
86
87#[derive(Clone, Debug)]
88pub struct MomBuilder {
89 period: Option<usize>,
90 kernel: Kernel,
91}
92
93impl Default for MomBuilder {
94 fn default() -> Self {
95 Self {
96 period: None,
97 kernel: Kernel::Auto,
98 }
99 }
100}
101
102impl MomBuilder {
103 #[inline(always)]
104 pub fn new() -> Self {
105 Self::default()
106 }
107 #[inline(always)]
108 pub fn period(mut self, n: usize) -> Self {
109 self.period = Some(n);
110 self
111 }
112 #[inline(always)]
113 pub fn kernel(mut self, k: Kernel) -> Self {
114 self.kernel = k;
115 self
116 }
117 #[inline(always)]
118 pub fn apply(self, c: &Candles) -> Result<MomOutput, MomError> {
119 let p = MomParams {
120 period: self.period,
121 };
122 let i = MomInput::from_candles(c, "close", p);
123 mom_with_kernel(&i, self.kernel)
124 }
125 #[inline(always)]
126 pub fn apply_slice(self, d: &[f64]) -> Result<MomOutput, MomError> {
127 let p = MomParams {
128 period: self.period,
129 };
130 let i = MomInput::from_slice(d, p);
131 mom_with_kernel(&i, self.kernel)
132 }
133 #[inline(always)]
134 pub fn into_stream(self) -> Result<MomStream, MomError> {
135 let p = MomParams {
136 period: self.period,
137 };
138 MomStream::try_new(p)
139 }
140}
141
/// Errors reported by the momentum indicator's single-series, streaming and batch APIs.
#[derive(Debug, Error)]
pub enum MomError {
    /// The input series has length zero.
    #[error("mom: Input data slice is empty.")]
    EmptyInputData,
    /// No element of the input is a non-NaN value.
    #[error("mom: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or larger than the input length.
    #[error("mom: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer elements remain after the first non-NaN value than the period requires.
    #[error("mom: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not have the expected length.
    #[error("mom: Output length mismatch (expected {expected}, got {got})")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep range expanded to no periods.
    #[error("mom: invalid range: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to the batch entry point.
    #[error("mom: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// Catch-all for other invalid inputs (used here for a `rows*cols` overflow).
    #[error("mom: invalid input: {0}")]
    InvalidInput(&'static str),
}
171
172#[inline]
173pub fn mom(input: &MomInput) -> Result<MomOutput, MomError> {
174 mom_with_kernel(input, Kernel::Auto)
175}
176
177#[inline(always)]
178fn mom_prepare<'a>(
179 input: &'a MomInput,
180 kernel: Kernel,
181) -> Result<(&'a [f64], usize, usize, Kernel), MomError> {
182 let data: &[f64] = input.as_ref();
183 let len = data.len();
184 if len == 0 {
185 return Err(MomError::EmptyInputData);
186 }
187 let first = data
188 .iter()
189 .position(|x| !x.is_nan())
190 .ok_or(MomError::AllValuesNaN)?;
191 let period = input.get_period();
192
193 if period == 0 || period > len {
194 return Err(MomError::InvalidPeriod {
195 period,
196 data_len: len,
197 });
198 }
199 if len - first < period {
200 return Err(MomError::NotEnoughValidData {
201 needed: period,
202 valid: len - first,
203 });
204 }
205
206 let chosen = match kernel {
207 Kernel::Auto => Kernel::Scalar,
208 k => k,
209 };
210 Ok((data, period, first, chosen))
211}
212
213#[inline(always)]
214fn mom_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
215 unsafe {
216 match kernel {
217 Kernel::Scalar | Kernel::ScalarBatch => mom_scalar(data, period, first, out),
218 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219 Kernel::Avx2 | Kernel::Avx2Batch => mom_avx2(data, period, first, out),
220 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
221 Kernel::Avx512 | Kernel::Avx512Batch => mom_avx512(data, period, first, out),
222 _ => unreachable!(),
223 }
224 }
225}
226
227pub fn mom_with_kernel(input: &MomInput, kernel: Kernel) -> Result<MomOutput, MomError> {
228 let (data, period, first, chosen) = mom_prepare(input, kernel)?;
229 let warm = first + period;
230 let mut out = alloc_with_nan_prefix(data.len(), warm);
231 mom_compute_into(data, period, first, chosen, &mut out);
232 Ok(MomOutput { values: out })
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn mom_into(input: &MomInput, out: &mut [f64]) -> Result<(), MomError> {
238 let (data, period, first, chosen) = mom_prepare(input, Kernel::Auto)?;
239
240 if out.len() != data.len() {
241 return Err(MomError::OutputLengthMismatch {
242 expected: data.len(),
243 got: out.len(),
244 });
245 }
246
247 let warm = first + period;
248 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
249 let n = warm.min(out.len());
250 for v in &mut out[..n] {
251 *v = qnan;
252 }
253
254 mom_compute_into(data, period, first, chosen, out);
255
256 Ok(())
257}
258
/// Scalar momentum kernel: `out[i] = data[i] - data[i - period]`.
///
/// Only indices `i >= first_valid + period` are written; everything before that
/// is left untouched so the caller's warm-up prefix survives.
///
/// This is a safe public function, so it must not rely on `out` being at least
/// as long as `data`. Writes are limited to indices present in both slices, and
/// an overflowing `first_valid + period` is treated as "nothing to compute".
#[inline(always)]
pub fn mom_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let start = match first_valid.checked_add(period) {
        Some(start) => start,
        None => return,
    };
    let end = data.len().min(out.len());
    if start >= end {
        return;
    }

    // `start < end` implies `end - period > first_valid`, so both ranges are in
    // bounds and have exactly `end - start` elements.
    let current = &data[start..end];
    let lagged = &data[first_valid..end - period];
    for ((dst, &cur), &prev) in out[start..end].iter_mut().zip(current).zip(lagged) {
        *dst = cur - prev;
    }
}
294
/// AVX-512 momentum entry point; routes to the short- or long-period variant.
///
/// Periods up to and including 32 use the "short" variant, larger periods the
/// "long" one.
///
/// NOTE(review): nothing in this function verifies that the CPU supports
/// AVX-512F or that `out` is as long as `data`; callers are presumed to
/// guarantee both — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mom_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    // SAFETY: relies on the caller-side guarantees described above.
    unsafe {
        match period {
            0..=32 => mom_avx512_short(data, period, first_valid, out),
            _ => mom_avx512_long(data, period, first_valid, out),
        }
    }
}
306
/// Short-period AVX-512 variant; currently forwards to `mom_avx512_core` unchanged.
///
/// # Safety
/// Must only be called on a CPU with AVX-512F; the arguments are passed straight
/// to `mom_avx512_core`, which accesses `data` and `out` through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    mom_avx512_core(data, period, first_valid, out)
}
312
/// Long-period AVX-512 variant; currently forwards to `mom_avx512_core` unchanged.
///
/// # Safety
/// Must only be called on a CPU with AVX-512F; the arguments are passed straight
/// to `mom_avx512_core`, which accesses `data` and `out` through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    mom_avx512_core(data, period, first_valid, out)
}
318
/// AVX2 momentum kernel: `out[i] = data[i] - data[i - period]` for every
/// `i >= first_valid + period`, four lanes at a time with a scalar tail.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `out` must hold at least `data.len()` elements; all accesses go through
///   raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn mom_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let len = data.len();
    let start = first_valid + period;
    if start >= len {
        return;
    }

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let count = len - start;
    // Largest multiple of four not exceeding `count`.
    let vector_count = count & !3;

    let mut off = 0usize;
    while off < vector_count {
        let cur = _mm256_loadu_pd(src.add(start + off));
        // `start - period == first_valid`, so the lagged element lives here.
        let lag = _mm256_loadu_pd(src.add(first_valid + off));
        _mm256_storeu_pd(dst.add(start + off), _mm256_sub_pd(cur, lag));
        off += 4;
    }

    while off < count {
        *dst.add(start + off) = *src.add(start + off) - *src.add(first_valid + off);
        off += 1;
    }
}
354
/// AVX-512 momentum kernel: `out[i] = data[i] - data[i - period]` for every
/// `i >= first_valid + period`, eight lanes at a time with a masked tail.
///
/// # Safety
/// * The CPU must support AVX-512F.
/// * `out` must hold at least `data.len()` elements; all accesses go through
///   raw pointers without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_core(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    let len = data.len();
    let start = first_valid + period;
    if start >= len {
        return;
    }

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let count = len - start;
    // Largest multiple of eight not exceeding `count`.
    let full = count & !7;

    let mut off = 0usize;
    while off < full {
        let cur = _mm512_loadu_pd(src.add(start + off));
        // `start - period == first_valid`, so the lagged element lives here.
        let lag = _mm512_loadu_pd(src.add(first_valid + off));
        _mm512_storeu_pd(dst.add(start + off), _mm512_sub_pd(cur, lag));
        off += 8;
    }

    let tail = count - full;
    if tail != 0 {
        // Low `tail` bits set: only the remaining lanes are loaded and stored.
        let mask: __mmask8 = ((1u16 << tail) - 1) as u8;
        let cur = _mm512_maskz_loadu_pd(mask, src.add(start + full));
        let lag = _mm512_maskz_loadu_pd(mask, src.add(first_valid + full));
        _mm512_mask_storeu_pd(dst.add(start + full), mask, _mm512_sub_pd(cur, lag));
    }
}
390
391#[derive(Debug, Clone)]
392pub struct MomStream {
393 period: usize,
394 buffer: Vec<f64>,
395 head: usize,
396 filled: bool,
397}
398
399impl MomStream {
400 pub fn try_new(params: MomParams) -> Result<Self, MomError> {
401 let period = params.period.unwrap_or(10);
402 if period == 0 {
403 return Err(MomError::InvalidPeriod {
404 period,
405 data_len: 0,
406 });
407 }
408 Ok(Self {
409 period,
410 buffer: vec![f64::NAN; period],
411 head: 0,
412 filled: false,
413 })
414 }
415
416 #[inline(always)]
417 pub fn update(&mut self, value: f64) -> Option<f64> {
418 if self.period == 1 {
419 let prev = self.buffer[0];
420 self.buffer[0] = value;
421
422 if !self.filled {
423 self.filled = true;
424 return None;
425 }
426 return Some(value - prev);
427 }
428
429 let idx = self.head;
430 let prev = self.buffer[idx];
431 self.buffer[idx] = value;
432
433 let next = idx + 1;
434 if self.period.is_power_of_two() {
435 let mask = self.period - 1;
436 self.head = next & mask;
437 } else if next == self.period {
438 self.head = 0;
439 } else {
440 self.head = next;
441 }
442
443 if !self.filled {
444 if self.head == 0 {
445 self.filled = true;
446 }
447 return None;
448 }
449
450 Some(value - prev)
451 }
452
453 #[inline(always)]
454 pub fn update_reset_on_nan(&mut self, value: f64) -> Option<f64> {
455 if value.is_nan() {
456 self.head = 0;
457 self.filled = false;
458 return None;
459 }
460 self.update(value)
461 }
462}
463
/// Period sweep for batch computations, expressed as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct MomBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for MomBatchRange {
    /// Sweeps periods from 10 to 259 in steps of 1.
    fn default() -> Self {
        let (start, end, step) = (10, 259, 1);
        MomBatchRange {
            period: (start, end, step),
        }
    }
}
476
477#[derive(Clone, Debug, Default)]
478pub struct MomBatchBuilder {
479 range: MomBatchRange,
480 kernel: Kernel,
481}
482
483impl MomBatchBuilder {
484 pub fn new() -> Self {
485 Self::default()
486 }
487 pub fn kernel(mut self, k: Kernel) -> Self {
488 self.kernel = k;
489 self
490 }
491 #[inline]
492 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
493 self.range.period = (start, end, step);
494 self
495 }
496 #[inline]
497 pub fn period_static(mut self, p: usize) -> Self {
498 self.range.period = (p, p, 0);
499 self
500 }
501 pub fn apply_slice(self, data: &[f64]) -> Result<MomBatchOutput, MomError> {
502 mom_batch_with_kernel(data, &self.range, self.kernel)
503 }
504 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MomBatchOutput, MomError> {
505 MomBatchBuilder::new().kernel(k).apply_slice(data)
506 }
507 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MomBatchOutput, MomError> {
508 let slice = source_type(c, src);
509 self.apply_slice(slice)
510 }
511 pub fn with_default_candles(c: &Candles) -> Result<MomBatchOutput, MomError> {
512 MomBatchBuilder::new()
513 .kernel(Kernel::Auto)
514 .apply_candles(c, "close")
515 }
516}
517
518pub fn mom_batch_with_kernel(
519 data: &[f64],
520 sweep: &MomBatchRange,
521 k: Kernel,
522) -> Result<MomBatchOutput, MomError> {
523 let kernel = match k {
524 Kernel::Auto => Kernel::ScalarBatch,
525 other if other.is_batch() => other,
526 other => return Err(MomError::InvalidKernelForBatch(other)),
527 };
528
529 let simd = match kernel {
530 Kernel::Avx512Batch => Kernel::Avx512,
531 Kernel::Avx2Batch => Kernel::Avx2,
532 Kernel::ScalarBatch => Kernel::Scalar,
533 _ => unreachable!(),
534 };
535 mom_batch_par_slice(data, sweep, simd)
536}
537
538#[derive(Clone, Debug)]
539pub struct MomBatchOutput {
540 pub values: Vec<f64>,
541 pub combos: Vec<MomParams>,
542 pub rows: usize,
543 pub cols: usize,
544}
545impl MomBatchOutput {
546 pub fn row_for_params(&self, p: &MomParams) -> Option<usize> {
547 self.combos
548 .iter()
549 .position(|c| c.period.unwrap_or(10) == p.period.unwrap_or(10))
550 }
551
552 pub fn values_for(&self, p: &MomParams) -> Option<&[f64]> {
553 self.row_for_params(p).map(|row| {
554 let start = row * self.cols;
555 &self.values[start..start + self.cols]
556 })
557 }
558}
559
560#[inline(always)]
561fn expand_grid_checked(r: &MomBatchRange) -> Result<Vec<MomParams>, MomError> {
562 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MomError> {
563 if step == 0 || start == end {
564 return Ok(vec![start]);
565 }
566 let mut out = Vec::new();
567 if start < end {
568 let mut v = start;
569 while v <= end {
570 out.push(v);
571 match v.checked_add(step) {
572 Some(next) if next > v => v = next,
573 _ => break,
574 }
575 }
576 } else {
577 let mut v = start;
578 while v >= end {
579 out.push(v);
580 if v < end + step {
581 break;
582 }
583 v = v.saturating_sub(step);
584 if v == 0 {
585 break;
586 }
587 }
588 }
589 if out.is_empty() {
590 return Err(MomError::InvalidRange { start, end, step });
591 }
592 Ok(out)
593 }
594
595 let periods = axis_usize(r.period)?;
596 let mut out = Vec::with_capacity(periods.len());
597 for &p in &periods {
598 out.push(MomParams { period: Some(p) });
599 }
600 Ok(out)
601}
602
603#[inline(always)]
604pub fn mom_batch_slice(
605 data: &[f64],
606 sweep: &MomBatchRange,
607 kern: Kernel,
608) -> Result<MomBatchOutput, MomError> {
609 mom_batch_inner(data, sweep, kern, false)
610}
611
612#[inline(always)]
613pub fn mom_batch_par_slice(
614 data: &[f64],
615 sweep: &MomBatchRange,
616 kern: Kernel,
617) -> Result<MomBatchOutput, MomError> {
618 mom_batch_inner(data, sweep, kern, true)
619}
620
621#[inline(always)]
622fn mom_batch_inner(
623 data: &[f64],
624 sweep: &MomBatchRange,
625 kern: Kernel,
626 parallel: bool,
627) -> Result<MomBatchOutput, MomError> {
628 let combos = expand_grid_checked(sweep)?;
629 let len = data.len();
630 if len == 0 {
631 return Err(MomError::EmptyInputData);
632 }
633 let first = data
634 .iter()
635 .position(|x| !x.is_nan())
636 .ok_or(MomError::AllValuesNaN)?;
637 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
638 if len - first < max_p {
639 return Err(MomError::NotEnoughValidData {
640 needed: max_p,
641 valid: len - first,
642 });
643 }
644
645 let rows = combos.len();
646 let cols = len;
647 let _total = rows
648 .checked_mul(cols)
649 .ok_or(MomError::InvalidInput("rows*cols overflow"))?;
650
651 let mut buf_mu = make_uninit_matrix(rows, cols);
652 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
653 init_matrix_prefixes(&mut buf_mu, cols, &warm);
654
655 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
656 let out_mu: &mut [std::mem::MaybeUninit<f64>] =
657 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
658
659 let simd = match kern {
660 Kernel::Auto => Kernel::Scalar,
661 Kernel::Avx512Batch => Kernel::Avx512,
662 Kernel::Avx2Batch => Kernel::Avx2,
663 Kernel::ScalarBatch => Kernel::Scalar,
664 k if !k.is_batch() => k,
665 _ => unreachable!(),
666 };
667
668 let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| {
669 let period = combos[row].period.unwrap();
670 let dst = unsafe {
671 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
672 };
673 mom_compute_into(data, period, first, simd, dst);
674 };
675
676 if parallel {
677 #[cfg(not(target_arch = "wasm32"))]
678 out_mu
679 .par_chunks_mut(cols)
680 .enumerate()
681 .for_each(|(r, mu)| do_row(r, mu));
682 #[cfg(target_arch = "wasm32")]
683 for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
684 do_row(r, mu);
685 }
686 } else {
687 for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
688 do_row(r, mu);
689 }
690 }
691
692 let values = unsafe {
693 Vec::from_raw_parts(
694 guard.as_mut_ptr() as *mut f64,
695 guard.len(),
696 guard.capacity(),
697 )
698 };
699
700 Ok(MomBatchOutput {
701 values,
702 combos,
703 rows,
704 cols,
705 })
706}
707
/// Row kernel: `out[i] = data[i] - data[i - period]` for every
/// `i >= first + period`; earlier entries of `out` are left untouched.
///
/// # Safety
/// The body uses only bounds-checked indexing; the function is marked `unsafe`
/// to share a signature with the other row kernels.
#[inline(always)]
unsafe fn mom_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let total = data.len();
    let mut idx = first + period;
    while idx < total {
        let lagged = data[idx - period];
        out[idx] = data[idx] - lagged;
        idx += 1;
    }
}
714
/// AVX2 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
720
/// AVX-512 row kernel; routes periods up to and including 32 to the "short"
/// variant and larger periods to the "long" one.
///
/// # Safety
/// Same contract as the variant it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => mom_row_avx512_short(data, first, period, out),
        _ => mom_row_avx512_long(data, first, period, out),
    }
}
730
/// Short-period AVX-512 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
736
/// Long-period AVX-512 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
742
743#[inline(always)]
744pub fn mom_batch_inner_into(
745 data: &[f64],
746 sweep: &MomBatchRange,
747 kern: Kernel,
748 parallel: bool,
749 output: &mut [f64],
750) -> Result<Vec<MomParams>, MomError> {
751 let combos = expand_grid_checked(sweep)?;
752 let len = data.len();
753 if len == 0 {
754 return Err(MomError::EmptyInputData);
755 }
756 let first = data
757 .iter()
758 .position(|x| !x.is_nan())
759 .ok_or(MomError::AllValuesNaN)?;
760 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
761 if len - first < max_p {
762 return Err(MomError::NotEnoughValidData {
763 needed: max_p,
764 valid: len - first,
765 });
766 }
767
768 let rows = combos.len();
769 let cols = len;
770 let expected = rows
771 .checked_mul(cols)
772 .ok_or(MomError::InvalidInput("rows*cols overflow"))?;
773 if output.len() != expected {
774 return Err(MomError::OutputLengthMismatch {
775 expected,
776 got: output.len(),
777 });
778 }
779
780 let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
781 core::slice::from_raw_parts_mut(
782 output.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
783 output.len(),
784 )
785 };
786 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
787 init_matrix_prefixes(out_mu, cols, &warm);
788
789 let simd = match kern {
790 Kernel::Auto => Kernel::Scalar,
791 Kernel::Avx512Batch => Kernel::Avx512,
792 Kernel::Avx2Batch => Kernel::Avx2,
793 Kernel::ScalarBatch => Kernel::Scalar,
794 k if !k.is_batch() => k,
795 _ => unreachable!(),
796 };
797
798 let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| {
799 let period = combos[row].period.unwrap();
800 let dst = unsafe {
801 core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
802 };
803 mom_compute_into(data, period, first, simd, dst);
804 };
805
806 if parallel {
807 #[cfg(not(target_arch = "wasm32"))]
808 out_mu
809 .par_chunks_mut(cols)
810 .enumerate()
811 .for_each(|(r, mu)| do_row(r, mu));
812 #[cfg(target_arch = "wasm32")]
813 for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
814 do_row(r, mu);
815 }
816 } else {
817 for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
818 do_row(r, mu);
819 }
820 }
821
822 Ok(combos)
823}
824
825#[cfg(test)]
826mod tests {
827 use super::*;
828 use crate::skip_if_unsupported;
829 use crate::utilities::data_loader::read_candles_from_csv;
830 #[cfg(feature = "proptest")]
831 use proptest::prelude::*;
832
833 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
834 #[test]
835 fn test_mom_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
836 let mut data = Vec::with_capacity(256);
837 data.push(f64::NAN);
838 for i in 0..255 {
839 let x = (i as f64).sin() * 10.0 + (i as f64) * 0.5;
840 data.push(x);
841 }
842
843 let params = MomParams { period: Some(10) };
844 let input = MomInput::from_slice(&data, params);
845
846 let base = mom(&input)?.values;
847
848 let mut into_out = vec![0.0; data.len()];
849 mom_into(&input, &mut into_out)?;
850
851 assert_eq!(base.len(), into_out.len());
852
853 for (i, (a, b)) in base.iter().zip(into_out.iter()).enumerate() {
854 let ok = (a.is_nan() && b.is_nan()) || (a == b);
855 assert!(ok, "mismatch at index {}: base={}, into={}", i, a, b);
856 }
857 Ok(())
858 }
859
860 fn check_mom_partial_params(
861 test_name: &str,
862 kernel: Kernel,
863 ) -> Result<(), Box<dyn std::error::Error>> {
864 skip_if_unsupported!(kernel, test_name);
865 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
866 let candles = read_candles_from_csv(file_path)?;
867
868 let default_params = MomParams { period: None };
869 let input = MomInput::from_candles(&candles, "close", default_params);
870 let output = mom_with_kernel(&input, kernel)?;
871 assert_eq!(output.values.len(), candles.close.len());
872
873 Ok(())
874 }
875
876 fn check_mom_accuracy(
877 test_name: &str,
878 kernel: Kernel,
879 ) -> Result<(), Box<dyn std::error::Error>> {
880 skip_if_unsupported!(kernel, test_name);
881 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
882 let candles = read_candles_from_csv(file_path)?;
883
884 let input = MomInput::from_candles(&candles, "close", MomParams::default());
885 let result = mom_with_kernel(&input, kernel)?;
886 let expected_last_five = [-134.0, -331.0, -194.0, -294.0, -896.0];
887 let start = result.values.len().saturating_sub(5);
888 for (i, &val) in result.values[start..].iter().enumerate() {
889 let diff = (val - expected_last_five[i]).abs();
890 assert!(
891 diff < 1e-1,
892 "[{}] MOM {:?} mismatch at idx {}: got {}, expected {}",
893 test_name,
894 kernel,
895 i,
896 val,
897 expected_last_five[i]
898 );
899 }
900 Ok(())
901 }
902
903 fn check_mom_default_candles(
904 test_name: &str,
905 kernel: Kernel,
906 ) -> Result<(), Box<dyn std::error::Error>> {
907 skip_if_unsupported!(kernel, test_name);
908 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
909 let candles = read_candles_from_csv(file_path)?;
910
911 let input = MomInput::with_default_candles(&candles);
912 match input.data {
913 MomData::Candles { source, .. } => assert_eq!(source, "close"),
914 _ => panic!("Expected MomData::Candles"),
915 }
916 let output = mom_with_kernel(&input, kernel)?;
917 assert_eq!(output.values.len(), candles.close.len());
918
919 Ok(())
920 }
921
922 fn check_mom_zero_period(
923 test_name: &str,
924 kernel: Kernel,
925 ) -> Result<(), Box<dyn std::error::Error>> {
926 skip_if_unsupported!(kernel, test_name);
927 let input_data = [10.0, 20.0, 30.0];
928 let params = MomParams { period: Some(0) };
929 let input = MomInput::from_slice(&input_data, params);
930 let res = mom_with_kernel(&input, kernel);
931 assert!(
932 res.is_err(),
933 "[{}] MOM should fail with zero period",
934 test_name
935 );
936 Ok(())
937 }
938
939 fn check_mom_period_exceeds_length(
940 test_name: &str,
941 kernel: Kernel,
942 ) -> Result<(), Box<dyn std::error::Error>> {
943 skip_if_unsupported!(kernel, test_name);
944 let data_small = [10.0, 20.0, 30.0];
945 let params = MomParams { period: Some(10) };
946 let input = MomInput::from_slice(&data_small, params);
947 let res = mom_with_kernel(&input, kernel);
948 assert!(
949 res.is_err(),
950 "[{}] MOM should fail with period exceeding length",
951 test_name
952 );
953 Ok(())
954 }
955
956 fn check_mom_very_small_dataset(
957 test_name: &str,
958 kernel: Kernel,
959 ) -> Result<(), Box<dyn std::error::Error>> {
960 skip_if_unsupported!(kernel, test_name);
961 let single_point = [42.0];
962 let params = MomParams { period: Some(9) };
963 let input = MomInput::from_slice(&single_point, params);
964 let res = mom_with_kernel(&input, kernel);
965 assert!(
966 res.is_err(),
967 "[{}] MOM should fail with insufficient data",
968 test_name
969 );
970 Ok(())
971 }
972
973 fn check_mom_reinput(
974 test_name: &str,
975 kernel: Kernel,
976 ) -> Result<(), Box<dyn std::error::Error>> {
977 skip_if_unsupported!(kernel, test_name);
978 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
979 let candles = read_candles_from_csv(file_path)?;
980
981 let first_params = MomParams { period: Some(14) };
982 let first_input = MomInput::from_candles(&candles, "close", first_params);
983 let first_result = mom_with_kernel(&first_input, kernel)?;
984
985 let second_params = MomParams { period: Some(14) };
986 let second_input = MomInput::from_slice(&first_result.values, second_params);
987 let second_result = mom_with_kernel(&second_input, kernel)?;
988
989 assert_eq!(second_result.values.len(), first_result.values.len());
990 for i in 28..second_result.values.len() {
991 assert!(
992 !second_result.values[i].is_nan(),
993 "[{}] MOM Slice Reinput {:?} unexpected NaN at idx {}",
994 test_name,
995 kernel,
996 i
997 );
998 }
999 Ok(())
1000 }
1001
1002 fn check_mom_nan_handling(
1003 test_name: &str,
1004 kernel: Kernel,
1005 ) -> Result<(), Box<dyn std::error::Error>> {
1006 skip_if_unsupported!(kernel, test_name);
1007 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1008 let candles = read_candles_from_csv(file_path)?;
1009
1010 let input = MomInput::from_candles(&candles, "close", MomParams { period: Some(10) });
1011 let res = mom_with_kernel(&input, kernel)?;
1012 assert_eq!(res.values.len(), candles.close.len());
1013 if res.values.len() > 240 {
1014 for (i, &val) in res.values[240..].iter().enumerate() {
1015 assert!(
1016 !val.is_nan(),
1017 "[{}] Found unexpected NaN at out-index {}",
1018 test_name,
1019 240 + i
1020 );
1021 }
1022 }
1023 Ok(())
1024 }
1025
1026 fn check_mom_streaming(
1027 test_name: &str,
1028 kernel: Kernel,
1029 ) -> Result<(), Box<dyn std::error::Error>> {
1030 skip_if_unsupported!(kernel, test_name);
1031
1032 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1033 let candles = read_candles_from_csv(file_path)?;
1034
1035 let period = 10;
1036 let input = MomInput::from_candles(
1037 &candles,
1038 "close",
1039 MomParams {
1040 period: Some(period),
1041 },
1042 );
1043 let batch_output = mom_with_kernel(&input, kernel)?.values;
1044
1045 let mut stream = MomStream::try_new(MomParams {
1046 period: Some(period),
1047 })?;
1048 let mut stream_values = Vec::with_capacity(candles.close.len());
1049 for &price in &candles.close {
1050 match stream.update(price) {
1051 Some(val) => stream_values.push(val),
1052 None => stream_values.push(f64::NAN),
1053 }
1054 }
1055 assert_eq!(batch_output.len(), stream_values.len());
1056 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1057 if b.is_nan() && s.is_nan() {
1058 continue;
1059 }
1060 let diff = (b - s).abs();
1061 assert!(
1062 diff < 1e-9,
1063 "[{}] MOM streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1064 test_name,
1065 i,
1066 b,
1067 s,
1068 diff
1069 );
1070 }
1071 Ok(())
1072 }
1073
1074 #[cfg(debug_assertions)]
1075 fn check_mom_no_poison(
1076 test_name: &str,
1077 kernel: Kernel,
1078 ) -> Result<(), Box<dyn std::error::Error>> {
1079 skip_if_unsupported!(kernel, test_name);
1080
1081 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1082 let candles = read_candles_from_csv(file_path)?;
1083
1084 let test_params = vec![
1085 MomParams::default(),
1086 MomParams { period: Some(2) },
1087 MomParams { period: Some(5) },
1088 MomParams { period: Some(7) },
1089 MomParams { period: Some(14) },
1090 MomParams { period: Some(20) },
1091 MomParams { period: Some(50) },
1092 MomParams { period: Some(100) },
1093 MomParams { period: Some(200) },
1094 ];
1095
1096 for (param_idx, params) in test_params.iter().enumerate() {
1097 let input = MomInput::from_candles(&candles, "close", params.clone());
1098 let output = mom_with_kernel(&input, kernel)?;
1099
1100 for (i, &val) in output.values.iter().enumerate() {
1101 if val.is_nan() {
1102 continue;
1103 }
1104
1105 let bits = val.to_bits();
1106
1107 if bits == 0x11111111_11111111 {
1108 panic!(
1109 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1110 with params: period={} (param set {})",
1111 test_name,
1112 val,
1113 bits,
1114 i,
1115 params.period.unwrap_or(10),
1116 param_idx
1117 );
1118 }
1119
1120 if bits == 0x22222222_22222222 {
1121 panic!(
1122 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1123 with params: period={} (param set {})",
1124 test_name,
1125 val,
1126 bits,
1127 i,
1128 params.period.unwrap_or(10),
1129 param_idx
1130 );
1131 }
1132
1133 if bits == 0x33333333_33333333 {
1134 panic!(
1135 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1136 with params: period={} (param set {})",
1137 test_name,
1138 val,
1139 bits,
1140 i,
1141 params.period.unwrap_or(10),
1142 param_idx
1143 );
1144 }
1145 }
1146 }
1147
1148 Ok(())
1149 }
1150
    /// Release-build stand-in for the poison check: does nothing and succeeds,
    /// since the debug variant is compiled only with `debug_assertions`.
    #[cfg(not(debug_assertions))]
    fn check_mom_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
1158
1159 #[cfg(feature = "proptest")]
1160 #[allow(clippy::float_cmp)]
1161 fn check_mom_property(
1162 test_name: &str,
1163 kernel: Kernel,
1164 ) -> Result<(), Box<dyn std::error::Error>> {
1165 skip_if_unsupported!(kernel, test_name);
1166
1167 let strat = (1usize..=64).prop_flat_map(|period| {
1168 (
1169 prop::collection::vec(
1170 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1171 period..400,
1172 ),
1173 Just(period),
1174 )
1175 });
1176
1177 proptest::test_runner::TestRunner::default()
1178 .run(&strat, |(data, period)| {
1179 let params = MomParams {
1180 period: Some(period),
1181 };
1182 let input = MomInput::from_slice(&data, params.clone());
1183
1184 let MomOutput { values: out } = mom_with_kernel(&input, kernel).unwrap();
1185
1186 let MomOutput { values: ref_out } =
1187 mom_with_kernel(&input, Kernel::Scalar).unwrap();
1188
1189 let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1190 let warmup_period = first_valid + period;
1191
1192 for i in 0..warmup_period.min(out.len()) {
1193 prop_assert!(
1194 out[i].is_nan(),
1195 "[{}] Expected NaN during warmup at index {}, got {}",
1196 test_name,
1197 i,
1198 out[i]
1199 );
1200 }
1201
1202 for i in warmup_period..data.len() {
1203 let expected = data[i] - data[i - period];
1204 let actual = out[i];
1205
1206 if expected.is_finite() && actual.is_finite() {
1207 prop_assert!(
1208 (actual - expected).abs() < 1e-10,
1209 "[{}] MOM formula mismatch at index {}: expected {}, got {}",
1210 test_name,
1211 i,
1212 expected,
1213 actual
1214 );
1215 }
1216 }
1217
1218 for i in 0..out.len() {
1219 let y = out[i];
1220 let r = ref_out[i];
1221
1222 if !y.is_finite() || !r.is_finite() {
1223 prop_assert!(
1224 y.to_bits() == r.to_bits(),
1225 "[{}] NaN/Inf mismatch at index {}: {} vs {}",
1226 test_name,
1227 i,
1228 y,
1229 r
1230 );
1231 } else {
1232 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1233 prop_assert!(
1234 (y - r).abs() <= 1e-10 || ulp_diff <= 4,
1235 "[{}] Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
1236 test_name,
1237 i,
1238 y,
1239 r,
1240 ulp_diff
1241 );
1242 }
1243 }
1244
1245 let all_same = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1246 if all_same && data.len() > period {
1247 for i in warmup_period..out.len() {
1248 prop_assert!(
1249 out[i].abs() < 1e-10,
1250 "[{}] Constant data should produce zero momentum at index {}, got {}",
1251 test_name,
1252 i,
1253 out[i]
1254 );
1255 }
1256 }
1257
1258 if data.len() > period + 2 {
1259 let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
1260 let is_linear = diffs.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9);
1261
1262 if is_linear && diffs.len() > 0 {
1263 let expected_momentum = diffs[0] * period as f64;
1264 for i in warmup_period..out.len().min(warmup_period + 10) {
1265 if out[i].is_finite() {
1266 prop_assert!(
1267 (out[i] - expected_momentum).abs() < 1e-9,
1268 "[{}] Linear data should produce constant momentum {} at index {}, got {}",
1269 test_name, expected_momentum, i, out[i]
1270 );
1271 }
1272 }
1273 }
1274 }
1275
1276 if data.len() < 100 {
1277 let neg_data: Vec<f64> = data.iter().map(|&x| -x).collect();
1278 let neg_input = MomInput::from_slice(&neg_data, params);
1279 let MomOutput { values: neg_out } =
1280 mom_with_kernel(&neg_input, kernel).unwrap();
1281
1282 for i in warmup_period..out.len() {
1283 if out[i].is_finite() && neg_out[i].is_finite() {
1284 prop_assert!(
1285 (out[i] + neg_out[i]).abs() < 1e-10,
1286 "[{}] Symmetry violated at index {}: {} vs {}",
1287 test_name,
1288 i,
1289 out[i],
1290 -neg_out[i]
1291 );
1292 }
1293 }
1294 }
1295
1296 if period == 1 && data.len() > 1 {
1297 for i in 1..data.len() {
1298 let expected = data[i] - data[i - 1];
1299 if out[i].is_finite() && expected.is_finite() {
1300 prop_assert!(
1301 (out[i] - expected).abs() < 1e-10,
1302 "[{}] Period=1 should give adjacent differences at index {}: expected {}, got {}",
1303 test_name, i, expected, out[i]
1304 );
1305 }
1306 }
1307 }
1308
1309 if data.len() > period {
1310 let min_val = data
1311 .iter()
1312 .filter(|x| x.is_finite())
1313 .fold(f64::INFINITY, |a, &b| a.min(b));
1314 let max_val = data
1315 .iter()
1316 .filter(|x| x.is_finite())
1317 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1318 let max_diff = max_val - min_val;
1319
1320 for i in warmup_period..out.len() {
1321 if out[i].is_finite() {
1322 prop_assert!(
1323 out[i].abs() <= max_diff + 1e-9,
1324 "[{}] Output magnitude {} exceeds max possible difference {} at index {}",
1325 test_name, out[i].abs(), max_diff, i
1326 );
1327 }
1328 }
1329 }
1330
1331 Ok(())
1332 })
1333 .unwrap();
1334
1335 Ok(())
1336 }
1337
1338 macro_rules! generate_all_mom_tests {
1339 ($($test_fn:ident),*) => {
1340 paste::paste! {
1341 $( #[test]
1342 fn [<$test_fn _scalar_f64>]() {
1343 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1344 })*
1345 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1346 $( #[test]
1347 fn [<$test_fn _avx2_f64>]() {
1348 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1349 }
1350 #[test]
1351 fn [<$test_fn _avx512_f64>]() {
1352 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1353 })*
1354 }
1355 }
1356 }
1357
1358 generate_all_mom_tests!(
1359 check_mom_partial_params,
1360 check_mom_accuracy,
1361 check_mom_default_candles,
1362 check_mom_zero_period,
1363 check_mom_period_exceeds_length,
1364 check_mom_very_small_dataset,
1365 check_mom_reinput,
1366 check_mom_nan_handling,
1367 check_mom_streaming,
1368 check_mom_no_poison
1369 );
1370
1371 #[cfg(feature = "proptest")]
1372 generate_all_mom_tests!(check_mom_property);
1373
1374 fn check_batch_default_row(
1375 test: &str,
1376 kernel: Kernel,
1377 ) -> Result<(), Box<dyn std::error::Error>> {
1378 skip_if_unsupported!(kernel, test);
1379
1380 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1381 let c = read_candles_from_csv(file)?;
1382
1383 let output = MomBatchBuilder::new()
1384 .kernel(kernel)
1385 .apply_candles(&c, "close")?;
1386
1387 let def = MomParams::default();
1388 let row = output.values_for(&def).expect("default row missing");
1389
1390 assert_eq!(row.len(), c.close.len());
1391
1392 let expected = [-134.0, -331.0, -194.0, -294.0, -896.0];
1393 let start = row.len() - 5;
1394 for (i, &v) in row[start..].iter().enumerate() {
1395 assert!(
1396 (v - expected[i]).abs() < 1e-1,
1397 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1398 );
1399 }
1400 Ok(())
1401 }
1402
1403 #[cfg(debug_assertions)]
1404 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1405 skip_if_unsupported!(kernel, test);
1406
1407 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1408 let c = read_candles_from_csv(file)?;
1409
1410 let test_configs = vec![
1411 (2, 10, 2),
1412 (5, 25, 5),
1413 (30, 60, 15),
1414 (2, 5, 1),
1415 (10, 50, 10),
1416 (100, 200, 50),
1417 ];
1418
1419 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1420 let output = MomBatchBuilder::new()
1421 .kernel(kernel)
1422 .period_range(p_start, p_end, p_step)
1423 .apply_candles(&c, "close")?;
1424
1425 for (idx, &val) in output.values.iter().enumerate() {
1426 if val.is_nan() {
1427 continue;
1428 }
1429
1430 let bits = val.to_bits();
1431 let row = idx / output.cols;
1432 let col = idx % output.cols;
1433 let combo = &output.combos[row];
1434
1435 if bits == 0x11111111_11111111 {
1436 panic!(
1437 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1438 at row {} col {} (flat index {}) with params: period={}",
1439 test,
1440 cfg_idx,
1441 val,
1442 bits,
1443 row,
1444 col,
1445 idx,
1446 combo.period.unwrap_or(10)
1447 );
1448 }
1449
1450 if bits == 0x22222222_22222222 {
1451 panic!(
1452 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1453 at row {} col {} (flat index {}) with params: period={}",
1454 test,
1455 cfg_idx,
1456 val,
1457 bits,
1458 row,
1459 col,
1460 idx,
1461 combo.period.unwrap_or(10)
1462 );
1463 }
1464
1465 if bits == 0x33333333_33333333 {
1466 panic!(
1467 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1468 at row {} col {} (flat index {}) with params: period={}",
1469 test,
1470 cfg_idx,
1471 val,
1472 bits,
1473 row,
1474 col,
1475 idx,
1476 combo.period.unwrap_or(10)
1477 );
1478 }
1479 }
1480 }
1481
1482 Ok(())
1483 }
1484
1485 #[cfg(not(debug_assertions))]
1486 fn check_batch_no_poison(
1487 _test: &str,
1488 _kernel: Kernel,
1489 ) -> Result<(), Box<dyn std::error::Error>> {
1490 Ok(())
1491 }
1492
1493 macro_rules! gen_batch_tests {
1494 ($fn_name:ident) => {
1495 paste::paste! {
1496 #[test] fn [<$fn_name _scalar>]() {
1497 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1498 }
1499 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1500 #[test] fn [<$fn_name _avx2>]() {
1501 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1502 }
1503 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1504 #[test] fn [<$fn_name _avx512>]() {
1505 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1506 }
1507 #[test] fn [<$fn_name _auto_detect>]() {
1508 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1509 }
1510 }
1511 };
1512 }
1513 gen_batch_tests!(check_batch_default_row);
1514 gen_batch_tests!(check_batch_no_poison);
1515}
1516
1517#[inline(always)]
1518pub fn mom_into_slice(dst: &mut [f64], input: &MomInput, kernel: Kernel) -> Result<(), MomError> {
1519 let (data, period, first, chosen) = mom_prepare(input, kernel)?;
1520 if dst.len() != data.len() {
1521 return Err(MomError::OutputLengthMismatch {
1522 expected: data.len(),
1523 got: dst.len(),
1524 });
1525 }
1526 let warm = first + period;
1527 for v in &mut dst[..warm] {
1528 *v = f64::NAN;
1529 }
1530 mom_compute_into(data, period, first, chosen, dst);
1531 Ok(())
1532}
1533
1534#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1535use serde::{Deserialize, Serialize};
1536#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1537use wasm_bindgen::prelude::*;
1538
/// wasm entry point: computes MOM over `data` with the given `period` and
/// returns a freshly allocated vector of the same length.
///
/// Uses `Kernel::Auto`. Any `MomError` is converted to a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MomInput::from_slice(
        data,
        MomParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match mom_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1553
/// wasm entry point: computes MOM from a raw input buffer into a raw output
/// buffer, both holding `len` `f64` values.
///
/// The buffers may be the same or may overlap in any way: whenever their byte
/// ranges intersect, the result is computed into a temporary vector and copied
/// to `out_ptr` afterwards, so a shared and a mutable slice over the same
/// memory are never alive at the same time. (Previously only the exact
/// `in_ptr == out_ptr` case was handled; a partial overlap produced aliasing
/// slices.)
///
/// Caller contract: `in_ptr` must be valid for reading `len` `f64`s and
/// `out_ptr` valid for writing `len` `f64`s.
///
/// # Errors
/// Returns a `JsValue` string when either pointer is null or when
/// `mom_into_slice` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Half-open byte ranges [start, start + span) of both buffers; they
    // intersect iff each one starts before the other ends.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    let params = MomParams {
        period: Some(period),
    };

    if overlaps {
        let mut temp = vec![0.0; len];
        {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
            // This shared slice is dropped before the output slice is created.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = MomInput::from_slice(data, params);
            mom_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes,
        // and no other slice over this memory is alive any more.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: the caller guarantees both pointers are valid for `len`
        // elements, and the ranges were just shown to be disjoint.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = MomInput::from_slice(data, params);
        mom_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1587
/// wasm helper: allocates an uninitialized buffer with capacity for `len`
/// `f64` values and hands ownership to the caller as a raw pointer.
///
/// The allocation is deliberately not dropped here; release it with
/// `mom_free`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor so the buffer outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1596
/// wasm helper: releases a buffer by rebuilding a `Vec<f64>` of length and
/// capacity `len` from `ptr` and dropping it. A null `ptr` is ignored.
///
/// Caller contract: `ptr` must come from an allocation of capacity `len`
/// (such as one returned by `mom_alloc(len)`) and must not be used afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller contract above; the Vec takes the
    // allocation back and frees it when dropped.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1606
/// Batch configuration deserialized from the JS `config` value passed to
/// `mom_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MomBatchConfig {
    /// Period sweep, forwarded unchanged as `MomBatchRange::period`.
    /// Elsewhere in this file the tuple is built as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1612
/// Result object serialized back to JS by `mom_batch_js`; each field is copied
/// from the corresponding field of the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MomBatchJsOutput {
    /// Flat output buffer (presumably row-major `rows * cols` — confirm against
    /// `mom_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter set for each row.
    pub combos: Vec<MomParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1621
/// wasm entry point (exported as `mom_batch`): runs a period sweep over `data`.
///
/// `config` is deserialized into `MomBatchConfig`; the sweep is executed with
/// `Kernel::Auto` and the result is returned as a serialized
/// `MomBatchJsOutput`.
///
/// # Errors
/// Returns a `JsValue` string when `config` cannot be deserialized, the batch
/// computation fails, or the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mom_batch)]
pub fn mom_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let MomBatchConfig { period_range }: MomBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = MomBatchRange {
        period: period_range,
    };
    let batch = mom_batch_inner(data, &sweep, Kernel::Auto, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = MomBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1645
/// wasm entry point: runs a period sweep from a raw input buffer of `len`
/// values into a raw output buffer of `rows * len` values, where `rows` is the
/// number of parameter combinations in the sweep. Returns `rows`.
///
/// If the input and output byte ranges overlap, the input is copied into a
/// private vector before the output slice is created, so the computation never
/// reads memory it is concurrently overwriting and no aliasing
/// `&[f64]` / `&mut [f64]` pair exists. (Previously overlap was not handled
/// here at all, unlike `mom_into`.)
///
/// Caller contract: `in_ptr` must be valid for reading `len` `f64`s and
/// `out_ptr` valid for writing `rows * len` `f64`s.
///
/// # Errors
/// Returns a `JsValue` string when a pointer is null, the sweep cannot be
/// expanded, `rows * len` overflows, or the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mom_batch_into"));
    }

    let sweep = MomBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Half-open byte ranges of the input (`len` elements) and the output
    // (`total` elements); they intersect iff each starts before the other ends.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // On overlap, snapshot the input before any mutable view of the output exists.
    let snapshot: Option<Vec<f64>> = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads; the
        // temporary shared slice ends with this statement.
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match snapshot.as_deref() {
        Some(copy) => copy,
        // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads, and
        // the range is disjoint from the output range.
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes; no
    // live slice overlaps it (the input is either disjoint or was copied).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    mom_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1682
1683#[cfg(feature = "python")]
1684use crate::utilities::kernel_validation::validate_kernel;
1685#[cfg(feature = "python")]
1686use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1687#[cfg(feature = "python")]
1688use pyo3::exceptions::PyValueError;
1689#[cfg(feature = "python")]
1690use pyo3::prelude::*;
1691#[cfg(feature = "python")]
1692use pyo3::types::PyDict;
1693
/// Python `mom(data, period, kernel=None)`: computes MOM over a 1-D float64
/// array and returns a new NumPy array.
///
/// The kernel name is validated with `validate_kernel`; the computation runs
/// with the GIL released. A `MomError` is raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mom")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn mom_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = MomInput::from_slice(
        slice_in,
        MomParams {
            period: Some(period),
        },
    );

    // Release the GIL for the numeric work.
    let computed = py.allow_threads(|| mom_with_kernel(&input, kern));
    match computed {
        Ok(MomOutput { values }) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1717
/// Python `mom_batch(data, period_range, kernel=None)`: runs a period sweep
/// and returns a dict with
/// - `"values"`: a `(rows, cols)` float64 array, and
/// - `"periods"`: a uint64 array holding the period of each row.
///
/// The output array is allocated up front and filled in place with the GIL
/// released. `Kernel::Auto` is resolved via `detect_best_batch_kernel`.
/// Sweep-expansion, overflow and computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mom_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn mom_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = MomBatchRange {
        period: period_range,
    };

    // Expand the sweep once up front just to size the output.
    let rows = expand_grid_checked(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: allocates an uninitialized array of `total` elements; it is
    // presumably fully written by `mom_batch_inner_into` below — TODO confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here, so no other view of it exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            mom_batch_inner_into(slice_in, &sweep, resolved, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1768
/// Python class `MomStream`: a thin wrapper that forwards to the Rust
/// `MomStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MomStream")]
pub struct MomStreamPy {
    /// The wrapped streaming state.
    inner: MomStream,
}
1774
#[cfg(feature = "python")]
#[pymethods]
impl MomStreamPy {
    /// Creates a stream for the given `period`.
    ///
    /// A failure from `MomStream::try_new` is raised as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        let params = MomParams {
            period: Some(period),
        };
        match MomStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one value to the wrapped stream and returns whatever it returns
    /// (`None` on the Python side when the stream yields no value).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1791
/// Adds the `mom` and `mom_batch` Python functions to module `m`.
#[cfg(feature = "python")]
pub fn register_mom_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(mom_py, m)?;
    m.add_function(single)?;

    let batch = wrap_pyfunction!(mom_batch_py, m)?;
    m.add_function(batch)?;

    Ok(())
}
1798
1799#[cfg(all(feature = "python", feature = "cuda"))]
1800use crate::cuda::moving_averages::DeviceArrayF32;
1801#[cfg(all(feature = "python", feature = "cuda"))]
1802use crate::cuda::oscillators::CudaMom;
1803#[cfg(all(feature = "python", feature = "cuda"))]
1804use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1805#[cfg(all(feature = "python", feature = "cuda"))]
1806use cust::context::Context;
1807#[cfg(all(feature = "python", feature = "cuda"))]
1808use std::sync::Arc;
1809
/// Python handle to a 2-D `f32` MOM result that lives on a CUDA device.
///
/// Exposes `__cuda_array_interface__` and a one-shot `__dlpack__` export.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MomDeviceArrayF32Py {
    /// The device array; becomes `None` once `__dlpack__` has taken it.
    pub(crate) inner: Option<DeviceArrayF32>,
    /// CUDA context obtained from the `CudaMom` instance that produced `inner`
    /// (presumably held to keep the context alive — confirm).
    pub(crate) ctx: Arc<Context>,
    /// Device id reported by the producing `CudaMom` instance.
    pub(crate) device_id: i32,
}
1817
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MomDeviceArrayF32Py {
    /// Builds the `__cuda_array_interface__` dict (version 3) for the buffer:
    /// shape `(rows, cols)`, typestr `<f4`, byte strides for a row-contiguous
    /// layout, and the device pointer flagged as not read-only.
    ///
    /// Raises `ValueError` if the buffer was already handed out via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_bytes = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * elem_bytes, elem_bytes))?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(2, device_id)` as the DLPack `(device_type, device_id)` pair.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    /// Exports the buffer as a DLPack capsule; may be called only once.
    ///
    /// If `dl_device` is given as an `(int, int)` pair that differs from this
    /// array's device, a `ValueError` is raised (with a distinct message when
    /// `copy` is `True`). `stream` is accepted but unused. `max_version` is
    /// forwarded to the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let own_device = self.__dlpack_device__();

        // A `dl_device` that is absent or not an (i32, i32) pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(requested) = requested {
            if requested != own_device {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Taking the array leaves `None` behind, which is what makes the export one-shot.
        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let (rows, cols) = (inner.rows, inner.cols);
        let buf = inner.buf;
        let alloc_dev = own_device.1;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1890
/// Python `mom_cuda_batch_dev(data, period_range, device_id=0)`: runs a period
/// sweep on the GPU over a 1-D float32 array and returns a device-array handle.
///
/// Raises `ValueError` when CUDA is unavailable, the input is empty, or
/// `CudaMom` reports an error. The GPU work runs with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mom_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn mom_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MomDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice = data.as_slice()?;
    if slice.is_empty() {
        return Err(PyValueError::new_err("empty input"));
    }

    let sweep = MomBatchRange {
        period: period_range,
    };

    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMom::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id() as i32;
        let arr = cuda
            .mom_batch_dev(slice, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(MomDeviceArrayF32Py {
        inner: Some(inner),
        ctx,
        device_id: dev_id,
    })
}
1926
/// Python `mom_cuda_many_series_one_param_dev(data_tm, cols, rows, period,
/// device_id=0)`: computes MOM on the GPU for many series sharing one period
/// and returns a device-array handle.
///
/// `data_tm` is a flat float32 array whose length must equal `cols * rows`; it
/// is passed to `CudaMom::mom_many_series_one_param_time_major_dev`.
///
/// Raises `ValueError` when CUDA is unavailable, `cols * rows` overflows, the
/// length does not match, or `CudaMom` reports an error. The GPU work runs
/// with the GIL released.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mom_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn mom_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MomDeviceArrayF32Py> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice = data_tm.as_slice()?;
    let expected = match cols.checked_mul(rows) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if slice.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    let launched = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaMom::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = cuda.device_id() as i32;
        let arr = cuda
            .mom_many_series_one_param_time_major_dev(slice, cols, rows, period)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev_id))
    });
    let (inner, ctx, dev_id) = launched?;

    Ok(MomDeviceArrayF32Py {
        inner: Some(inner),
        ctx,
        device_id: dev_id,
    })
}