Skip to main content

vector_ta/indicators/
mom.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use thiserror::Error;
13
14impl<'a> AsRef<[f64]> for MomInput<'a> {
15    #[inline(always)]
16    fn as_ref(&self) -> &[f64] {
17        match &self.data {
18            MomData::Slice(slice) => slice,
19            MomData::Candles { candles, source } => source_type(candles, source),
20        }
21    }
22}
23
/// Where the input series for the momentum calculation comes from.
#[derive(Debug, Clone)]
pub enum MomData<'a> {
    /// A candle set plus the selector string that is handed to `source_type`
    /// to pick the series to read (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
32
/// Result of a single momentum run.
#[derive(Debug, Clone)]
pub struct MomOutput {
    /// One value per input element; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
37
/// Tunable parameters of the momentum indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct MomParams {
    /// Lag in samples; `None` is resolved to 10 by the consumers in this file.
    pub period: Option<usize>,
}

impl Default for MomParams {
    /// Default parameters: a period of 10.
    fn default() -> Self {
        let period = Some(10);
        MomParams { period }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct MomInput<'a> {
55    pub data: MomData<'a>,
56    pub params: MomParams,
57}
58
59impl<'a> MomInput<'a> {
60    #[inline]
61    pub fn from_candles(c: &'a Candles, s: &'a str, p: MomParams) -> Self {
62        Self {
63            data: MomData::Candles {
64                candles: c,
65                source: s,
66            },
67            params: p,
68        }
69    }
70    #[inline]
71    pub fn from_slice(sl: &'a [f64], p: MomParams) -> Self {
72        Self {
73            data: MomData::Slice(sl),
74            params: p,
75        }
76    }
77    #[inline]
78    pub fn with_default_candles(c: &'a Candles) -> Self {
79        Self::from_candles(c, "close", MomParams::default())
80    }
81    #[inline]
82    pub fn get_period(&self) -> usize {
83        self.params.period.unwrap_or(10)
84    }
85}
86
87#[derive(Clone, Debug)]
88pub struct MomBuilder {
89    period: Option<usize>,
90    kernel: Kernel,
91}
92
93impl Default for MomBuilder {
94    fn default() -> Self {
95        Self {
96            period: None,
97            kernel: Kernel::Auto,
98        }
99    }
100}
101
102impl MomBuilder {
103    #[inline(always)]
104    pub fn new() -> Self {
105        Self::default()
106    }
107    #[inline(always)]
108    pub fn period(mut self, n: usize) -> Self {
109        self.period = Some(n);
110        self
111    }
112    #[inline(always)]
113    pub fn kernel(mut self, k: Kernel) -> Self {
114        self.kernel = k;
115        self
116    }
117    #[inline(always)]
118    pub fn apply(self, c: &Candles) -> Result<MomOutput, MomError> {
119        let p = MomParams {
120            period: self.period,
121        };
122        let i = MomInput::from_candles(c, "close", p);
123        mom_with_kernel(&i, self.kernel)
124    }
125    #[inline(always)]
126    pub fn apply_slice(self, d: &[f64]) -> Result<MomOutput, MomError> {
127        let p = MomParams {
128            period: self.period,
129        };
130        let i = MomInput::from_slice(d, p);
131        mom_with_kernel(&i, self.kernel)
132    }
133    #[inline(always)]
134    pub fn into_stream(self) -> Result<MomStream, MomError> {
135        let p = MomParams {
136            period: self.period,
137        };
138        MomStream::try_new(p)
139    }
140}
141
/// Errors reported by the momentum functions in this module.
#[derive(Debug, Error)]
pub enum MomError {
    /// The input series has length zero.
    #[error("mom: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN element was found in the input.
    #[error("mom: All values are NaN.")]
    AllValuesNaN,

    /// The period is zero or larger than the input length.
    #[error("mom: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer elements remain from the first non-NaN value onward than the period requires.
    #[error("mom: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not have the required length.
    #[error("mom: Output length mismatch (expected {expected}, got {got})")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A batch sweep range expanded to no values.
    #[error("mom: invalid range: start={start} end={end} step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("mom: invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// Any other invalid input, described by the attached message.
    #[error("mom: invalid input: {0}")]
    InvalidInput(&'static str),
}
171
/// Computes momentum for `input` with automatic kernel selection.
///
/// Thin wrapper around [`mom_with_kernel`] called with `Kernel::Auto`.
///
/// # Errors
/// Propagates whatever [`mom_with_kernel`] returns.
#[inline]
pub fn mom(input: &MomInput) -> Result<MomOutput, MomError> {
    mom_with_kernel(input, Kernel::Auto)
}
176
177#[inline(always)]
178fn mom_prepare<'a>(
179    input: &'a MomInput,
180    kernel: Kernel,
181) -> Result<(&'a [f64], usize, usize, Kernel), MomError> {
182    let data: &[f64] = input.as_ref();
183    let len = data.len();
184    if len == 0 {
185        return Err(MomError::EmptyInputData);
186    }
187    let first = data
188        .iter()
189        .position(|x| !x.is_nan())
190        .ok_or(MomError::AllValuesNaN)?;
191    let period = input.get_period();
192
193    if period == 0 || period > len {
194        return Err(MomError::InvalidPeriod {
195            period,
196            data_len: len,
197        });
198    }
199    if len - first < period {
200        return Err(MomError::NotEnoughValidData {
201            needed: period,
202            valid: len - first,
203        });
204    }
205
206    let chosen = match kernel {
207        Kernel::Auto => Kernel::Scalar,
208        k => k,
209    };
210    Ok((data, period, first, chosen))
211}
212
213#[inline(always)]
214fn mom_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
215    unsafe {
216        match kernel {
217            Kernel::Scalar | Kernel::ScalarBatch => mom_scalar(data, period, first, out),
218            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
219            Kernel::Avx2 | Kernel::Avx2Batch => mom_avx2(data, period, first, out),
220            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
221            Kernel::Avx512 | Kernel::Avx512Batch => mom_avx512(data, period, first, out),
222            _ => unreachable!(),
223        }
224    }
225}
226
227pub fn mom_with_kernel(input: &MomInput, kernel: Kernel) -> Result<MomOutput, MomError> {
228    let (data, period, first, chosen) = mom_prepare(input, kernel)?;
229    let warm = first + period;
230    let mut out = alloc_with_nan_prefix(data.len(), warm);
231    mom_compute_into(data, period, first, chosen, &mut out);
232    Ok(MomOutput { values: out })
233}
234
235#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
236#[inline]
237pub fn mom_into(input: &MomInput, out: &mut [f64]) -> Result<(), MomError> {
238    let (data, period, first, chosen) = mom_prepare(input, Kernel::Auto)?;
239
240    if out.len() != data.len() {
241        return Err(MomError::OutputLengthMismatch {
242            expected: data.len(),
243            got: out.len(),
244        });
245    }
246
247    let warm = first + period;
248    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
249    let n = warm.min(out.len());
250    for v in &mut out[..n] {
251        *v = qnan;
252    }
253
254    mom_compute_into(data, period, first, chosen, out);
255
256    Ok(())
257}
258
/// Scalar momentum kernel: writes `data[i] - data[i - period]` into `out[i]`
/// for every `i` in `first_valid + period .. data.len()`.
///
/// Elements of `out` before `first_valid + period` (and any beyond
/// `data.len()`) are left untouched. The call is a no-op when
/// `first_valid + period` is not a valid index into `data`, including when the
/// sum overflows `usize`.
///
/// # Panics
/// Panics if there is work to do and `out` is shorter than `data`. This
/// function is safe and public, so the length contract is checked rather than
/// assumed.
#[inline(always)]
pub fn mom_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let n = data.len();
    let start = match first_valid.checked_add(period) {
        Some(s) if s < n => s,
        _ => return,
    };
    assert!(
        out.len() >= n,
        "mom_scalar: output buffer shorter than input (out = {}, data = {})",
        out.len(),
        n
    );

    // `current[k]` is `data[start + k]` and `lagged[k]` is the value `period`
    // samples earlier; both slices hold exactly `n - start` elements.
    let current = &data[start..];
    let lagged = &data[first_valid..n - period];
    for ((dst, &cur), &prev) in out[start..n].iter_mut().zip(current).zip(lagged) {
        *dst = cur - prev;
    }
}
294
/// AVX-512 momentum kernel entry point.
///
/// Writes `data[i] - data[i - period]` into `out[i]` for every `i` in
/// `first_valid + period .. data.len()`. Because this is a safe, public
/// function, it verifies its own preconditions before entering the
/// `avx512f`-gated code:
///
/// * the call is a no-op when `first_valid + period` overflows or is not a
///   valid index into `data`;
/// * it falls back to [`mom_scalar`] when the running CPU does not report
///   `avx512f`.
///
/// # Panics
/// Panics if there is work to do and `out` is shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub fn mom_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    match first_valid.checked_add(period) {
        Some(start) if start < data.len() => {}
        _ => return,
    }
    assert!(
        out.len() >= data.len(),
        "mom_avx512: output buffer shorter than input (out = {}, data = {})",
        out.len(),
        data.len()
    );
    if !is_x86_feature_detected!("avx512f") {
        mom_scalar(data, period, first_valid, out);
        return;
    }

    // SAFETY: `avx512f` was detected above, `out` covers every index of `data`,
    // and `first_valid + period` is a valid, non-overflowing index.
    unsafe {
        if period <= 32 {
            mom_avx512_short(data, period, first_valid, out);
        } else {
            mom_avx512_long(data, period, first_valid, out);
        }
    }
}
306
/// AVX-512 variant selected by `mom_avx512` for `period <= 32`; forwards
/// unchanged to `mom_avx512_core`.
///
/// # Safety
/// Compiled with `avx512f` enabled, so the running CPU must support it.
/// `mom_avx512_core` writes through raw pointers, so `out` must be at least as
/// long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    mom_avx512_core(data, period, first_valid, out)
}
312
/// AVX-512 variant selected by `mom_avx512` for `period > 32`; forwards
/// unchanged to `mom_avx512_core`.
///
/// # Safety
/// Compiled with `avx512f` enabled, so the running CPU must support it.
/// `mom_avx512_core` writes through raw pointers, so `out` must be at least as
/// long as `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    mom_avx512_core(data, period, first_valid, out)
}
318
/// AVX2 momentum kernel: four lanes per iteration, scalar tail.
///
/// Writes `data[i] - data[i - period]` into `out[i]` for every `i` in
/// `first_valid + period .. data.len()`.
///
/// # Safety
/// The running CPU must support AVX2, and `out` must be at least as long as
/// `data` because the stores go through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn mom_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let n = data.len();
    let start = first_valid + period;
    if start >= n {
        return;
    }

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();

    // Vector body: unaligned loads of the current and lagged windows.
    let mut i = start;
    while n - i >= 4 {
        let cur = _mm256_loadu_pd(src.add(i));
        let lag = _mm256_loadu_pd(src.add(i - period));
        _mm256_storeu_pd(dst.add(i), _mm256_sub_pd(cur, lag));
        i += 4;
    }

    // Scalar tail for the last (at most three) elements.
    while i < n {
        *dst.add(i) = *src.add(i) - *src.add(i - period);
        i += 1;
    }
}
354
/// AVX-512 momentum kernel: eight lanes per iteration, masked tail.
///
/// Writes `data[i] - data[i - period]` into `out[i]` for every `i` in
/// `first_valid + period .. data.len()`.
///
/// # Safety
/// The running CPU must support `avx512f`, and `out` must be at least as long
/// as `data` because the stores go through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn mom_avx512_core(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    let n = data.len();
    let start = first_valid + period;
    if start >= n {
        return;
    }

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();

    let mut i = start;
    while n - i >= 8 {
        let cur = _mm512_loadu_pd(src.add(i));
        let lag = _mm512_loadu_pd(src.add(i - period));
        _mm512_storeu_pd(dst.add(i), _mm512_sub_pd(cur, lag));
        i += 8;
    }

    // Remaining 1..=7 elements: enable only the low `rem` lanes.
    let rem = n - i;
    if rem != 0 {
        let mask: __mmask8 = ((1u16 << rem) - 1) as u8;
        let cur = _mm512_maskz_loadu_pd(mask, src.add(i));
        let lag = _mm512_maskz_loadu_pd(mask, src.add(i - period));
        _mm512_mask_storeu_pd(dst.add(i), mask, _mm512_sub_pd(cur, lag));
    }
}
390
391#[derive(Debug, Clone)]
392pub struct MomStream {
393    period: usize,
394    buffer: Vec<f64>,
395    head: usize,
396    filled: bool,
397}
398
399impl MomStream {
400    pub fn try_new(params: MomParams) -> Result<Self, MomError> {
401        let period = params.period.unwrap_or(10);
402        if period == 0 {
403            return Err(MomError::InvalidPeriod {
404                period,
405                data_len: 0,
406            });
407        }
408        Ok(Self {
409            period,
410            buffer: vec![f64::NAN; period],
411            head: 0,
412            filled: false,
413        })
414    }
415
416    #[inline(always)]
417    pub fn update(&mut self, value: f64) -> Option<f64> {
418        if self.period == 1 {
419            let prev = self.buffer[0];
420            self.buffer[0] = value;
421
422            if !self.filled {
423                self.filled = true;
424                return None;
425            }
426            return Some(value - prev);
427        }
428
429        let idx = self.head;
430        let prev = self.buffer[idx];
431        self.buffer[idx] = value;
432
433        let next = idx + 1;
434        if self.period.is_power_of_two() {
435            let mask = self.period - 1;
436            self.head = next & mask;
437        } else if next == self.period {
438            self.head = 0;
439        } else {
440            self.head = next;
441        }
442
443        if !self.filled {
444            if self.head == 0 {
445                self.filled = true;
446            }
447            return None;
448        }
449
450        Some(value - prev)
451    }
452
453    #[inline(always)]
454    pub fn update_reset_on_nan(&mut self, value: f64) -> Option<f64> {
455        if value.is_nan() {
456            self.head = 0;
457            self.filled = false;
458            return None;
459        }
460        self.update(value)
461    }
462}
463
/// Period sweep for batch runs, expressed as `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct MomBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for MomBatchRange {
    /// Default sweep: periods 10 through 259 in steps of 1.
    fn default() -> Self {
        let period = (10, 259, 1);
        MomBatchRange { period }
    }
}
476
477#[derive(Clone, Debug, Default)]
478pub struct MomBatchBuilder {
479    range: MomBatchRange,
480    kernel: Kernel,
481}
482
483impl MomBatchBuilder {
484    pub fn new() -> Self {
485        Self::default()
486    }
487    pub fn kernel(mut self, k: Kernel) -> Self {
488        self.kernel = k;
489        self
490    }
491    #[inline]
492    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
493        self.range.period = (start, end, step);
494        self
495    }
496    #[inline]
497    pub fn period_static(mut self, p: usize) -> Self {
498        self.range.period = (p, p, 0);
499        self
500    }
501    pub fn apply_slice(self, data: &[f64]) -> Result<MomBatchOutput, MomError> {
502        mom_batch_with_kernel(data, &self.range, self.kernel)
503    }
504    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<MomBatchOutput, MomError> {
505        MomBatchBuilder::new().kernel(k).apply_slice(data)
506    }
507    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<MomBatchOutput, MomError> {
508        let slice = source_type(c, src);
509        self.apply_slice(slice)
510    }
511    pub fn with_default_candles(c: &Candles) -> Result<MomBatchOutput, MomError> {
512        MomBatchBuilder::new()
513            .kernel(Kernel::Auto)
514            .apply_candles(c, "close")
515    }
516}
517
518pub fn mom_batch_with_kernel(
519    data: &[f64],
520    sweep: &MomBatchRange,
521    k: Kernel,
522) -> Result<MomBatchOutput, MomError> {
523    let kernel = match k {
524        Kernel::Auto => Kernel::ScalarBatch,
525        other if other.is_batch() => other,
526        other => return Err(MomError::InvalidKernelForBatch(other)),
527    };
528
529    let simd = match kernel {
530        Kernel::Avx512Batch => Kernel::Avx512,
531        Kernel::Avx2Batch => Kernel::Avx2,
532        Kernel::ScalarBatch => Kernel::Scalar,
533        _ => unreachable!(),
534    };
535    mom_batch_par_slice(data, sweep, simd)
536}
537
538#[derive(Clone, Debug)]
539pub struct MomBatchOutput {
540    pub values: Vec<f64>,
541    pub combos: Vec<MomParams>,
542    pub rows: usize,
543    pub cols: usize,
544}
545impl MomBatchOutput {
546    pub fn row_for_params(&self, p: &MomParams) -> Option<usize> {
547        self.combos
548            .iter()
549            .position(|c| c.period.unwrap_or(10) == p.period.unwrap_or(10))
550    }
551
552    pub fn values_for(&self, p: &MomParams) -> Option<&[f64]> {
553        self.row_for_params(p).map(|row| {
554            let start = row * self.cols;
555            &self.values[start..start + self.cols]
556        })
557    }
558}
559
560#[inline(always)]
561fn expand_grid_checked(r: &MomBatchRange) -> Result<Vec<MomParams>, MomError> {
562    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, MomError> {
563        if step == 0 || start == end {
564            return Ok(vec![start]);
565        }
566        let mut out = Vec::new();
567        if start < end {
568            let mut v = start;
569            while v <= end {
570                out.push(v);
571                match v.checked_add(step) {
572                    Some(next) if next > v => v = next,
573                    _ => break,
574                }
575            }
576        } else {
577            let mut v = start;
578            while v >= end {
579                out.push(v);
580                if v < end + step {
581                    break;
582                }
583                v = v.saturating_sub(step);
584                if v == 0 {
585                    break;
586                }
587            }
588        }
589        if out.is_empty() {
590            return Err(MomError::InvalidRange { start, end, step });
591        }
592        Ok(out)
593    }
594
595    let periods = axis_usize(r.period)?;
596    let mut out = Vec::with_capacity(periods.len());
597    for &p in &periods {
598        out.push(MomParams { period: Some(p) });
599    }
600    Ok(out)
601}
602
/// Runs the period sweep over `data` sequentially (row after row).
///
/// Forwards to `mom_batch_inner` with `parallel = false`.
///
/// # Errors
/// Propagates the errors of `mom_batch_inner`.
#[inline(always)]
pub fn mom_batch_slice(
    data: &[f64],
    sweep: &MomBatchRange,
    kern: Kernel,
) -> Result<MomBatchOutput, MomError> {
    mom_batch_inner(data, sweep, kern, false)
}
611
/// Runs the period sweep over `data` with `parallel = true`.
///
/// Forwards to `mom_batch_inner`, which distributes rows with rayon on
/// non-wasm32 targets and iterates sequentially on wasm32.
///
/// # Errors
/// Propagates the errors of `mom_batch_inner`.
#[inline(always)]
pub fn mom_batch_par_slice(
    data: &[f64],
    sweep: &MomBatchRange,
    kern: Kernel,
) -> Result<MomBatchOutput, MomError> {
    mom_batch_inner(data, sweep, kern, true)
}
620
621#[inline(always)]
622fn mom_batch_inner(
623    data: &[f64],
624    sweep: &MomBatchRange,
625    kern: Kernel,
626    parallel: bool,
627) -> Result<MomBatchOutput, MomError> {
628    let combos = expand_grid_checked(sweep)?;
629    let len = data.len();
630    if len == 0 {
631        return Err(MomError::EmptyInputData);
632    }
633    let first = data
634        .iter()
635        .position(|x| !x.is_nan())
636        .ok_or(MomError::AllValuesNaN)?;
637    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
638    if len - first < max_p {
639        return Err(MomError::NotEnoughValidData {
640            needed: max_p,
641            valid: len - first,
642        });
643    }
644
645    let rows = combos.len();
646    let cols = len;
647    let _total = rows
648        .checked_mul(cols)
649        .ok_or(MomError::InvalidInput("rows*cols overflow"))?;
650
651    let mut buf_mu = make_uninit_matrix(rows, cols);
652    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
653    init_matrix_prefixes(&mut buf_mu, cols, &warm);
654
655    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
656    let out_mu: &mut [std::mem::MaybeUninit<f64>] =
657        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
658
659    let simd = match kern {
660        Kernel::Auto => Kernel::Scalar,
661        Kernel::Avx512Batch => Kernel::Avx512,
662        Kernel::Avx2Batch => Kernel::Avx2,
663        Kernel::ScalarBatch => Kernel::Scalar,
664        k if !k.is_batch() => k,
665        _ => unreachable!(),
666    };
667
668    let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| {
669        let period = combos[row].period.unwrap();
670        let dst = unsafe {
671            core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
672        };
673        mom_compute_into(data, period, first, simd, dst);
674    };
675
676    if parallel {
677        #[cfg(not(target_arch = "wasm32"))]
678        out_mu
679            .par_chunks_mut(cols)
680            .enumerate()
681            .for_each(|(r, mu)| do_row(r, mu));
682        #[cfg(target_arch = "wasm32")]
683        for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
684            do_row(r, mu);
685        }
686    } else {
687        for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
688            do_row(r, mu);
689        }
690    }
691
692    let values = unsafe {
693        Vec::from_raw_parts(
694            guard.as_mut_ptr() as *mut f64,
695            guard.len(),
696            guard.capacity(),
697        )
698    };
699
700    Ok(MomBatchOutput {
701        values,
702        combos,
703        rows,
704        cols,
705    })
706}
707
/// Writes `data[i] - data[i - period]` into `out[i]` for every `i` in
/// `first + period .. data.len()`, using bounds-checked indexing.
///
/// # Safety
/// The body performs no unchecked operation; the `unsafe` marker only keeps
/// the signature in line with the other row kernels.
#[inline(always)]
unsafe fn mom_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    for (i, &cur) in data.iter().enumerate().skip(first + period) {
        out[i] = cur - data[i - period];
    }
}
714
/// AVX2 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
720
/// AVX-512 row kernel: picks the short variant for periods up to 32 and the
/// long variant otherwise.
///
/// # Safety
/// Same contract as the variant it forwards to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => mom_row_avx512_short(data, first, period, out),
        _ => mom_row_avx512_long(data, first, period, out),
    }
}
730
/// Short-period AVX-512 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
736
/// Long-period AVX-512 row kernel; currently forwards to `mom_row_scalar` unchanged.
///
/// # Safety
/// Same contract as `mom_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn mom_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    mom_row_scalar(data, first, period, out)
}
742
743#[inline(always)]
744pub fn mom_batch_inner_into(
745    data: &[f64],
746    sweep: &MomBatchRange,
747    kern: Kernel,
748    parallel: bool,
749    output: &mut [f64],
750) -> Result<Vec<MomParams>, MomError> {
751    let combos = expand_grid_checked(sweep)?;
752    let len = data.len();
753    if len == 0 {
754        return Err(MomError::EmptyInputData);
755    }
756    let first = data
757        .iter()
758        .position(|x| !x.is_nan())
759        .ok_or(MomError::AllValuesNaN)?;
760    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
761    if len - first < max_p {
762        return Err(MomError::NotEnoughValidData {
763            needed: max_p,
764            valid: len - first,
765        });
766    }
767
768    let rows = combos.len();
769    let cols = len;
770    let expected = rows
771        .checked_mul(cols)
772        .ok_or(MomError::InvalidInput("rows*cols overflow"))?;
773    if output.len() != expected {
774        return Err(MomError::OutputLengthMismatch {
775            expected,
776            got: output.len(),
777        });
778    }
779
780    let out_mu: &mut [std::mem::MaybeUninit<f64>] = unsafe {
781        core::slice::from_raw_parts_mut(
782            output.as_mut_ptr() as *mut std::mem::MaybeUninit<f64>,
783            output.len(),
784        )
785    };
786    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
787    init_matrix_prefixes(out_mu, cols, &warm);
788
789    let simd = match kern {
790        Kernel::Auto => Kernel::Scalar,
791        Kernel::Avx512Batch => Kernel::Avx512,
792        Kernel::Avx2Batch => Kernel::Avx2,
793        Kernel::ScalarBatch => Kernel::Scalar,
794        k if !k.is_batch() => k,
795        _ => unreachable!(),
796    };
797
798    let do_row = |row: usize, row_mu: &mut [std::mem::MaybeUninit<f64>]| {
799        let period = combos[row].period.unwrap();
800        let dst = unsafe {
801            core::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len())
802        };
803        mom_compute_into(data, period, first, simd, dst);
804    };
805
806    if parallel {
807        #[cfg(not(target_arch = "wasm32"))]
808        out_mu
809            .par_chunks_mut(cols)
810            .enumerate()
811            .for_each(|(r, mu)| do_row(r, mu));
812        #[cfg(target_arch = "wasm32")]
813        for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
814            do_row(r, mu);
815        }
816    } else {
817        for (r, mu) in out_mu.chunks_mut(cols).enumerate() {
818            do_row(r, mu);
819        }
820    }
821
822    Ok(combos)
823}
824
825#[cfg(test)]
826mod tests {
827    use super::*;
828    use crate::skip_if_unsupported;
829    use crate::utilities::data_loader::read_candles_from_csv;
830    #[cfg(feature = "proptest")]
831    use proptest::prelude::*;
832
/// `mom_into` must produce exactly what `mom` returns, NaN warm-up included.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[test]
fn test_mom_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
    // One leading NaN followed by 255 samples of a drifting sine wave.
    let data: Vec<f64> = std::iter::once(f64::NAN)
        .chain((0..255).map(|i| (i as f64).sin() * 10.0 + (i as f64) * 0.5))
        .collect();

    let input = MomInput::from_slice(&data, MomParams { period: Some(10) });
    let base = mom(&input)?.values;

    let mut into_out = vec![0.0; data.len()];
    mom_into(&input, &mut into_out)?;

    assert_eq!(base.len(), into_out.len());

    for (i, (a, b)) in base.iter().zip(&into_out).enumerate() {
        if a.is_nan() && b.is_nan() {
            continue;
        }
        assert!(a == b, "mismatch at index {}: base={}, into={}", i, a, b);
    }
    Ok(())
}
859
860    fn check_mom_partial_params(
861        test_name: &str,
862        kernel: Kernel,
863    ) -> Result<(), Box<dyn std::error::Error>> {
864        skip_if_unsupported!(kernel, test_name);
865        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
866        let candles = read_candles_from_csv(file_path)?;
867
868        let default_params = MomParams { period: None };
869        let input = MomInput::from_candles(&candles, "close", default_params);
870        let output = mom_with_kernel(&input, kernel)?;
871        assert_eq!(output.values.len(), candles.close.len());
872
873        Ok(())
874    }
875
876    fn check_mom_accuracy(
877        test_name: &str,
878        kernel: Kernel,
879    ) -> Result<(), Box<dyn std::error::Error>> {
880        skip_if_unsupported!(kernel, test_name);
881        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
882        let candles = read_candles_from_csv(file_path)?;
883
884        let input = MomInput::from_candles(&candles, "close", MomParams::default());
885        let result = mom_with_kernel(&input, kernel)?;
886        let expected_last_five = [-134.0, -331.0, -194.0, -294.0, -896.0];
887        let start = result.values.len().saturating_sub(5);
888        for (i, &val) in result.values[start..].iter().enumerate() {
889            let diff = (val - expected_last_five[i]).abs();
890            assert!(
891                diff < 1e-1,
892                "[{}] MOM {:?} mismatch at idx {}: got {}, expected {}",
893                test_name,
894                kernel,
895                i,
896                val,
897                expected_last_five[i]
898            );
899        }
900        Ok(())
901    }
902
903    fn check_mom_default_candles(
904        test_name: &str,
905        kernel: Kernel,
906    ) -> Result<(), Box<dyn std::error::Error>> {
907        skip_if_unsupported!(kernel, test_name);
908        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
909        let candles = read_candles_from_csv(file_path)?;
910
911        let input = MomInput::with_default_candles(&candles);
912        match input.data {
913            MomData::Candles { source, .. } => assert_eq!(source, "close"),
914            _ => panic!("Expected MomData::Candles"),
915        }
916        let output = mom_with_kernel(&input, kernel)?;
917        assert_eq!(output.values.len(), candles.close.len());
918
919        Ok(())
920    }
921
922    fn check_mom_zero_period(
923        test_name: &str,
924        kernel: Kernel,
925    ) -> Result<(), Box<dyn std::error::Error>> {
926        skip_if_unsupported!(kernel, test_name);
927        let input_data = [10.0, 20.0, 30.0];
928        let params = MomParams { period: Some(0) };
929        let input = MomInput::from_slice(&input_data, params);
930        let res = mom_with_kernel(&input, kernel);
931        assert!(
932            res.is_err(),
933            "[{}] MOM should fail with zero period",
934            test_name
935        );
936        Ok(())
937    }
938
939    fn check_mom_period_exceeds_length(
940        test_name: &str,
941        kernel: Kernel,
942    ) -> Result<(), Box<dyn std::error::Error>> {
943        skip_if_unsupported!(kernel, test_name);
944        let data_small = [10.0, 20.0, 30.0];
945        let params = MomParams { period: Some(10) };
946        let input = MomInput::from_slice(&data_small, params);
947        let res = mom_with_kernel(&input, kernel);
948        assert!(
949            res.is_err(),
950            "[{}] MOM should fail with period exceeding length",
951            test_name
952        );
953        Ok(())
954    }
955
956    fn check_mom_very_small_dataset(
957        test_name: &str,
958        kernel: Kernel,
959    ) -> Result<(), Box<dyn std::error::Error>> {
960        skip_if_unsupported!(kernel, test_name);
961        let single_point = [42.0];
962        let params = MomParams { period: Some(9) };
963        let input = MomInput::from_slice(&single_point, params);
964        let res = mom_with_kernel(&input, kernel);
965        assert!(
966            res.is_err(),
967            "[{}] MOM should fail with insufficient data",
968            test_name
969        );
970        Ok(())
971    }
972
973    fn check_mom_reinput(
974        test_name: &str,
975        kernel: Kernel,
976    ) -> Result<(), Box<dyn std::error::Error>> {
977        skip_if_unsupported!(kernel, test_name);
978        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
979        let candles = read_candles_from_csv(file_path)?;
980
981        let first_params = MomParams { period: Some(14) };
982        let first_input = MomInput::from_candles(&candles, "close", first_params);
983        let first_result = mom_with_kernel(&first_input, kernel)?;
984
985        let second_params = MomParams { period: Some(14) };
986        let second_input = MomInput::from_slice(&first_result.values, second_params);
987        let second_result = mom_with_kernel(&second_input, kernel)?;
988
989        assert_eq!(second_result.values.len(), first_result.values.len());
990        for i in 28..second_result.values.len() {
991            assert!(
992                !second_result.values[i].is_nan(),
993                "[{}] MOM Slice Reinput {:?} unexpected NaN at idx {}",
994                test_name,
995                kernel,
996                i
997            );
998        }
999        Ok(())
1000    }
1001
1002    fn check_mom_nan_handling(
1003        test_name: &str,
1004        kernel: Kernel,
1005    ) -> Result<(), Box<dyn std::error::Error>> {
1006        skip_if_unsupported!(kernel, test_name);
1007        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1008        let candles = read_candles_from_csv(file_path)?;
1009
1010        let input = MomInput::from_candles(&candles, "close", MomParams { period: Some(10) });
1011        let res = mom_with_kernel(&input, kernel)?;
1012        assert_eq!(res.values.len(), candles.close.len());
1013        if res.values.len() > 240 {
1014            for (i, &val) in res.values[240..].iter().enumerate() {
1015                assert!(
1016                    !val.is_nan(),
1017                    "[{}] Found unexpected NaN at out-index {}",
1018                    test_name,
1019                    240 + i
1020                );
1021            }
1022        }
1023        Ok(())
1024    }
1025
1026    fn check_mom_streaming(
1027        test_name: &str,
1028        kernel: Kernel,
1029    ) -> Result<(), Box<dyn std::error::Error>> {
1030        skip_if_unsupported!(kernel, test_name);
1031
1032        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1033        let candles = read_candles_from_csv(file_path)?;
1034
1035        let period = 10;
1036        let input = MomInput::from_candles(
1037            &candles,
1038            "close",
1039            MomParams {
1040                period: Some(period),
1041            },
1042        );
1043        let batch_output = mom_with_kernel(&input, kernel)?.values;
1044
1045        let mut stream = MomStream::try_new(MomParams {
1046            period: Some(period),
1047        })?;
1048        let mut stream_values = Vec::with_capacity(candles.close.len());
1049        for &price in &candles.close {
1050            match stream.update(price) {
1051                Some(val) => stream_values.push(val),
1052                None => stream_values.push(f64::NAN),
1053            }
1054        }
1055        assert_eq!(batch_output.len(), stream_values.len());
1056        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1057            if b.is_nan() && s.is_nan() {
1058                continue;
1059            }
1060            let diff = (b - s).abs();
1061            assert!(
1062                diff < 1e-9,
1063                "[{}] MOM streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1064                test_name,
1065                i,
1066                b,
1067                s,
1068                diff
1069            );
1070        }
1071        Ok(())
1072    }
1073
1074    #[cfg(debug_assertions)]
1075    fn check_mom_no_poison(
1076        test_name: &str,
1077        kernel: Kernel,
1078    ) -> Result<(), Box<dyn std::error::Error>> {
1079        skip_if_unsupported!(kernel, test_name);
1080
1081        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1082        let candles = read_candles_from_csv(file_path)?;
1083
1084        let test_params = vec![
1085            MomParams::default(),
1086            MomParams { period: Some(2) },
1087            MomParams { period: Some(5) },
1088            MomParams { period: Some(7) },
1089            MomParams { period: Some(14) },
1090            MomParams { period: Some(20) },
1091            MomParams { period: Some(50) },
1092            MomParams { period: Some(100) },
1093            MomParams { period: Some(200) },
1094        ];
1095
1096        for (param_idx, params) in test_params.iter().enumerate() {
1097            let input = MomInput::from_candles(&candles, "close", params.clone());
1098            let output = mom_with_kernel(&input, kernel)?;
1099
1100            for (i, &val) in output.values.iter().enumerate() {
1101                if val.is_nan() {
1102                    continue;
1103                }
1104
1105                let bits = val.to_bits();
1106
1107                if bits == 0x11111111_11111111 {
1108                    panic!(
1109                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1110						 with params: period={} (param set {})",
1111                        test_name,
1112                        val,
1113                        bits,
1114                        i,
1115                        params.period.unwrap_or(10),
1116                        param_idx
1117                    );
1118                }
1119
1120                if bits == 0x22222222_22222222 {
1121                    panic!(
1122                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1123						 with params: period={} (param set {})",
1124                        test_name,
1125                        val,
1126                        bits,
1127                        i,
1128                        params.period.unwrap_or(10),
1129                        param_idx
1130                    );
1131                }
1132
1133                if bits == 0x33333333_33333333 {
1134                    panic!(
1135                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1136						 with params: period={} (param set {})",
1137                        test_name,
1138                        val,
1139                        bits,
1140                        i,
1141                        params.period.unwrap_or(10),
1142                        param_idx
1143                    );
1144                }
1145            }
1146        }
1147
1148        Ok(())
1149    }
1150
1151    #[cfg(not(debug_assertions))]
1152    fn check_mom_no_poison(
1153        _test_name: &str,
1154        _kernel: Kernel,
1155    ) -> Result<(), Box<dyn std::error::Error>> {
1156        Ok(())
1157    }
1158
1159    #[cfg(feature = "proptest")]
1160    #[allow(clippy::float_cmp)]
1161    fn check_mom_property(
1162        test_name: &str,
1163        kernel: Kernel,
1164    ) -> Result<(), Box<dyn std::error::Error>> {
1165        skip_if_unsupported!(kernel, test_name);
1166
1167        let strat = (1usize..=64).prop_flat_map(|period| {
1168            (
1169                prop::collection::vec(
1170                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1171                    period..400,
1172                ),
1173                Just(period),
1174            )
1175        });
1176
1177        proptest::test_runner::TestRunner::default()
1178            .run(&strat, |(data, period)| {
1179                let params = MomParams {
1180                    period: Some(period),
1181                };
1182                let input = MomInput::from_slice(&data, params.clone());
1183
1184                let MomOutput { values: out } = mom_with_kernel(&input, kernel).unwrap();
1185
1186                let MomOutput { values: ref_out } =
1187                    mom_with_kernel(&input, Kernel::Scalar).unwrap();
1188
1189                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1190                let warmup_period = first_valid + period;
1191
1192                for i in 0..warmup_period.min(out.len()) {
1193                    prop_assert!(
1194                        out[i].is_nan(),
1195                        "[{}] Expected NaN during warmup at index {}, got {}",
1196                        test_name,
1197                        i,
1198                        out[i]
1199                    );
1200                }
1201
1202                for i in warmup_period..data.len() {
1203                    let expected = data[i] - data[i - period];
1204                    let actual = out[i];
1205
1206                    if expected.is_finite() && actual.is_finite() {
1207                        prop_assert!(
1208                            (actual - expected).abs() < 1e-10,
1209                            "[{}] MOM formula mismatch at index {}: expected {}, got {}",
1210                            test_name,
1211                            i,
1212                            expected,
1213                            actual
1214                        );
1215                    }
1216                }
1217
1218                for i in 0..out.len() {
1219                    let y = out[i];
1220                    let r = ref_out[i];
1221
1222                    if !y.is_finite() || !r.is_finite() {
1223                        prop_assert!(
1224                            y.to_bits() == r.to_bits(),
1225                            "[{}] NaN/Inf mismatch at index {}: {} vs {}",
1226                            test_name,
1227                            i,
1228                            y,
1229                            r
1230                        );
1231                    } else {
1232                        let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1233                        prop_assert!(
1234                            (y - r).abs() <= 1e-10 || ulp_diff <= 4,
1235                            "[{}] Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
1236                            test_name,
1237                            i,
1238                            y,
1239                            r,
1240                            ulp_diff
1241                        );
1242                    }
1243                }
1244
1245                let all_same = data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10);
1246                if all_same && data.len() > period {
1247                    for i in warmup_period..out.len() {
1248                        prop_assert!(
1249                            out[i].abs() < 1e-10,
1250                            "[{}] Constant data should produce zero momentum at index {}, got {}",
1251                            test_name,
1252                            i,
1253                            out[i]
1254                        );
1255                    }
1256                }
1257
1258                if data.len() > period + 2 {
1259                    let diffs: Vec<f64> = data.windows(2).map(|w| w[1] - w[0]).collect();
1260                    let is_linear = diffs.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9);
1261
1262                    if is_linear && diffs.len() > 0 {
1263                        let expected_momentum = diffs[0] * period as f64;
1264                        for i in warmup_period..out.len().min(warmup_period + 10) {
1265                            if out[i].is_finite() {
1266                                prop_assert!(
1267									(out[i] - expected_momentum).abs() < 1e-9,
1268									"[{}] Linear data should produce constant momentum {} at index {}, got {}",
1269									test_name, expected_momentum, i, out[i]
1270								);
1271                            }
1272                        }
1273                    }
1274                }
1275
1276                if data.len() < 100 {
1277                    let neg_data: Vec<f64> = data.iter().map(|&x| -x).collect();
1278                    let neg_input = MomInput::from_slice(&neg_data, params);
1279                    let MomOutput { values: neg_out } =
1280                        mom_with_kernel(&neg_input, kernel).unwrap();
1281
1282                    for i in warmup_period..out.len() {
1283                        if out[i].is_finite() && neg_out[i].is_finite() {
1284                            prop_assert!(
1285                                (out[i] + neg_out[i]).abs() < 1e-10,
1286                                "[{}] Symmetry violated at index {}: {} vs {}",
1287                                test_name,
1288                                i,
1289                                out[i],
1290                                -neg_out[i]
1291                            );
1292                        }
1293                    }
1294                }
1295
1296                if period == 1 && data.len() > 1 {
1297                    for i in 1..data.len() {
1298                        let expected = data[i] - data[i - 1];
1299                        if out[i].is_finite() && expected.is_finite() {
1300                            prop_assert!(
1301								(out[i] - expected).abs() < 1e-10,
1302								"[{}] Period=1 should give adjacent differences at index {}: expected {}, got {}",
1303								test_name, i, expected, out[i]
1304							);
1305                        }
1306                    }
1307                }
1308
1309                if data.len() > period {
1310                    let min_val = data
1311                        .iter()
1312                        .filter(|x| x.is_finite())
1313                        .fold(f64::INFINITY, |a, &b| a.min(b));
1314                    let max_val = data
1315                        .iter()
1316                        .filter(|x| x.is_finite())
1317                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1318                    let max_diff = max_val - min_val;
1319
1320                    for i in warmup_period..out.len() {
1321                        if out[i].is_finite() {
1322                            prop_assert!(
1323								out[i].abs() <= max_diff + 1e-9,
1324								"[{}] Output magnitude {} exceeds max possible difference {} at index {}",
1325								test_name, out[i].abs(), max_diff, i
1326							);
1327                        }
1328                    }
1329                }
1330
1331                Ok(())
1332            })
1333            .unwrap();
1334
1335        Ok(())
1336    }
1337
1338    macro_rules! generate_all_mom_tests {
1339        ($($test_fn:ident),*) => {
1340            paste::paste! {
1341                $( #[test]
1342                fn [<$test_fn _scalar_f64>]() {
1343                    let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1344                })*
1345                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1346                $( #[test]
1347                fn [<$test_fn _avx2_f64>]() {
1348                    let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1349                }
1350                #[test]
1351                fn [<$test_fn _avx512_f64>]() {
1352                    let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1353                })*
1354            }
1355        }
1356    }
1357
1358    generate_all_mom_tests!(
1359        check_mom_partial_params,
1360        check_mom_accuracy,
1361        check_mom_default_candles,
1362        check_mom_zero_period,
1363        check_mom_period_exceeds_length,
1364        check_mom_very_small_dataset,
1365        check_mom_reinput,
1366        check_mom_nan_handling,
1367        check_mom_streaming,
1368        check_mom_no_poison
1369    );
1370
1371    #[cfg(feature = "proptest")]
1372    generate_all_mom_tests!(check_mom_property);
1373
1374    fn check_batch_default_row(
1375        test: &str,
1376        kernel: Kernel,
1377    ) -> Result<(), Box<dyn std::error::Error>> {
1378        skip_if_unsupported!(kernel, test);
1379
1380        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1381        let c = read_candles_from_csv(file)?;
1382
1383        let output = MomBatchBuilder::new()
1384            .kernel(kernel)
1385            .apply_candles(&c, "close")?;
1386
1387        let def = MomParams::default();
1388        let row = output.values_for(&def).expect("default row missing");
1389
1390        assert_eq!(row.len(), c.close.len());
1391
1392        let expected = [-134.0, -331.0, -194.0, -294.0, -896.0];
1393        let start = row.len() - 5;
1394        for (i, &v) in row[start..].iter().enumerate() {
1395            assert!(
1396                (v - expected[i]).abs() < 1e-1,
1397                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1398            );
1399        }
1400        Ok(())
1401    }
1402
1403    #[cfg(debug_assertions)]
1404    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1405        skip_if_unsupported!(kernel, test);
1406
1407        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1408        let c = read_candles_from_csv(file)?;
1409
1410        let test_configs = vec![
1411            (2, 10, 2),
1412            (5, 25, 5),
1413            (30, 60, 15),
1414            (2, 5, 1),
1415            (10, 50, 10),
1416            (100, 200, 50),
1417        ];
1418
1419        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1420            let output = MomBatchBuilder::new()
1421                .kernel(kernel)
1422                .period_range(p_start, p_end, p_step)
1423                .apply_candles(&c, "close")?;
1424
1425            for (idx, &val) in output.values.iter().enumerate() {
1426                if val.is_nan() {
1427                    continue;
1428                }
1429
1430                let bits = val.to_bits();
1431                let row = idx / output.cols;
1432                let col = idx % output.cols;
1433                let combo = &output.combos[row];
1434
1435                if bits == 0x11111111_11111111 {
1436                    panic!(
1437                        "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1438						 at row {} col {} (flat index {}) with params: period={}",
1439                        test,
1440                        cfg_idx,
1441                        val,
1442                        bits,
1443                        row,
1444                        col,
1445                        idx,
1446                        combo.period.unwrap_or(10)
1447                    );
1448                }
1449
1450                if bits == 0x22222222_22222222 {
1451                    panic!(
1452                        "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1453						 at row {} col {} (flat index {}) with params: period={}",
1454                        test,
1455                        cfg_idx,
1456                        val,
1457                        bits,
1458                        row,
1459                        col,
1460                        idx,
1461                        combo.period.unwrap_or(10)
1462                    );
1463                }
1464
1465                if bits == 0x33333333_33333333 {
1466                    panic!(
1467                        "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1468						 at row {} col {} (flat index {}) with params: period={}",
1469                        test,
1470                        cfg_idx,
1471                        val,
1472                        bits,
1473                        row,
1474                        col,
1475                        idx,
1476                        combo.period.unwrap_or(10)
1477                    );
1478                }
1479            }
1480        }
1481
1482        Ok(())
1483    }
1484
1485    #[cfg(not(debug_assertions))]
1486    fn check_batch_no_poison(
1487        _test: &str,
1488        _kernel: Kernel,
1489    ) -> Result<(), Box<dyn std::error::Error>> {
1490        Ok(())
1491    }
1492
1493    macro_rules! gen_batch_tests {
1494        ($fn_name:ident) => {
1495            paste::paste! {
1496                #[test] fn [<$fn_name _scalar>]()      {
1497                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1498                }
1499                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1500                #[test] fn [<$fn_name _avx2>]()        {
1501                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1502                }
1503                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1504                #[test] fn [<$fn_name _avx512>]()      {
1505                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1506                }
1507                #[test] fn [<$fn_name _auto_detect>]() {
1508                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1509                }
1510            }
1511        };
1512    }
1513    gen_batch_tests!(check_batch_default_row);
1514    gen_batch_tests!(check_batch_no_poison);
1515}
1516
1517#[inline(always)]
1518pub fn mom_into_slice(dst: &mut [f64], input: &MomInput, kernel: Kernel) -> Result<(), MomError> {
1519    let (data, period, first, chosen) = mom_prepare(input, kernel)?;
1520    if dst.len() != data.len() {
1521        return Err(MomError::OutputLengthMismatch {
1522            expected: data.len(),
1523            got: dst.len(),
1524        });
1525    }
1526    let warm = first + period;
1527    for v in &mut dst[..warm] {
1528        *v = f64::NAN;
1529    }
1530    mom_compute_into(data, period, first, chosen, dst);
1531    Ok(())
1532}
1533
1534#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1535use serde::{Deserialize, Serialize};
1536#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1537use wasm_bindgen::prelude::*;
1538
/// JS entry point: computes MOM over `data` with the given `period` and
/// returns a freshly allocated vector of the same length.
///
/// Any `MomError` is converted to its display string and returned as a
/// `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = MomInput::from_slice(
        data,
        MomParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match mom_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1553
/// Computes MOM from a raw input buffer into a raw output buffer of the same length.
///
/// `in_ptr` and `out_ptr` must each point to `len` readable / writable `f64`
/// values. The buffers may be identical or overlap partially: in either case
/// the result is computed into a scratch vector first and then copied out, so
/// a shared and a mutable slice over the same memory are never live together.
///
/// # Errors
///
/// Returns a `JsValue` string if either pointer is null or if the MOM
/// computation itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Both buffers span `len` elements, so they overlap exactly when their
    // start addresses are less than one buffer length apart.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = (in_ptr as usize).abs_diff(out_ptr as usize) < byte_len;

    // SAFETY: the caller guarantees `in_ptr` is valid for reading `len` f64
    // values; it was checked to be non-null above.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = MomInput::from_slice(
        data,
        MomParams {
            period: Some(period),
        },
    );

    if overlaps {
        let mut temp = vec![0.0; len];
        mom_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `temp` is a fresh allocation of `len` elements, so it cannot
        // overlap `out_ptr`, which the caller guarantees is valid for writing
        // `len` values. `data` is not read again after this point.
        unsafe { std::ptr::copy_nonoverlapping(temp.as_ptr(), out_ptr, len) };
    } else {
        // SAFETY: the caller guarantees `out_ptr` is valid for writing `len`
        // values, and the overlap test above shows it is disjoint from `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        mom_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1587
/// Allocates room for `len` `f64` values and hands the raw pointer to the
/// caller without freeing it.
///
/// The buffer is uninitialized. Release it with `mom_free`, passing the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buf = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1596
/// Releases a buffer by rebuilding a `Vec<f64>` of length and capacity `len`
/// from `ptr` and dropping it. A null `ptr` is ignored.
///
/// NOTE(review): presumably `ptr` and `len` must come from a matching
/// `mom_alloc(len)` call — confirm against the JS callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/length pair that
    // describes a live `Vec<f64>` allocation with capacity `len`.
    drop(unsafe { Vec::from_raw_parts(ptr, len, len) });
}
1606
/// Batch-sweep configuration deserialized from the `config` argument of the
/// JS `mom_batch` call.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MomBatchConfig {
    /// Period sweep, forwarded unchanged as `MomBatchRange::period`.
    /// Presumably `(start, end, step)` — confirm against `MomBatchRange`.
    pub period_range: (usize, usize, usize),
}
1612
/// Serializable result of a batch MOM run, returned to JS by `mom_batch`.
///
/// Every field is copied straight from the output of `mom_batch_inner`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct MomBatchJsOutput {
    /// Flat buffer of all results; presumably `rows * cols` values with one
    /// row per entry of `combos` — confirm against `mom_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter set evaluated for each row.
    pub combos: Vec<MomParams>,
    /// Number of rows reported by the batch run.
    pub rows: usize,
    /// Number of columns reported by the batch run.
    pub cols: usize,
}
1621
/// JS entry point (`mom_batch`): runs a period sweep over `data`.
///
/// `config` is deserialized into `MomBatchConfig`; the result is returned as
/// a serialized `MomBatchJsOutput`. Deserialization, computation and
/// serialization failures are all reported as `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = mom_batch)]
pub fn mom_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<MomBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = MomBatchRange {
        period: parsed.period_range,
    };

    let batch = match mom_batch_inner(data, &sweep, Kernel::Auto, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = MomBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1645
/// Runs a MOM period sweep from a raw input buffer into a raw output buffer.
///
/// `in_ptr` must point to `len` readable `f64` values. `out_ptr` must point
/// to `rows * len` writable values, where `rows` is the number of parameter
/// combinations produced by `(period_start, period_end, period_step)`.
/// Returns `rows` on success.
///
/// If the input and output ranges overlap, the input is copied to an owned
/// vector first, so a shared and a mutable slice over the same memory are
/// never live together (`mom_into` gives the same guarantee for the
/// single-series case).
///
/// # Errors
///
/// Returns a `JsValue` string if either pointer is null, the sweep is
/// rejected by `expand_grid_checked`, `rows * len` overflows, or the batch
/// computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn mom_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to mom_batch_into"));
    }

    let sweep = MomBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // Byte ranges [in_start, in_end) and [out_start, out_end) overlap when
    // each one starts before the other ends.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    let owned_input: Vec<f64>;
    let data: &[f64] = if overlaps {
        // SAFETY: the caller guarantees `in_ptr` is valid for reading `len`
        // values; the temporary slice ends before `out` is created below.
        owned_input = unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec();
        &owned_input
    } else {
        // SAFETY: the caller guarantees `in_ptr` is valid for reading `len`
        // values, and the ranges were just shown to be disjoint.
        unsafe { std::slice::from_raw_parts(in_ptr, len) }
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for writing `total`
    // values; `data` either lives in a separate allocation or is disjoint.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    mom_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1682
1683#[cfg(feature = "python")]
1684use crate::utilities::kernel_validation::validate_kernel;
1685#[cfg(feature = "python")]
1686use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1687#[cfg(feature = "python")]
1688use pyo3::exceptions::PyValueError;
1689#[cfg(feature = "python")]
1690use pyo3::prelude::*;
1691#[cfg(feature = "python")]
1692use pyo3::types::PyDict;
1693
// Python `mom(data, period, kernel=None)`: computes MOM over a 1-D float64
// array and returns a new NumPy array.
//
// The computation runs with the GIL released. Any `MomError` is raised as
// `ValueError` carrying its display string.
#[cfg(feature = "python")]
#[pyfunction(name = "mom")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn mom_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = MomInput::from_slice(
        slice_in,
        MomParams {
            period: Some(period),
        },
    );

    let computed = py.allow_threads(|| mom_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1717
// Python `mom_batch(data, period_range, kernel=None)`: runs a period sweep
// and returns a dict with:
//   "values"  - 2-D array of shape (rows, cols), written in place by the batch kernel
//   "periods" - 1-D array holding the period of each row
//
// The output array is allocated up front without initialization and filled
// with the GIL released. Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "mom_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn mom_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = MomBatchRange {
        period: period_range,
    };

    // Expanded here only to size the output; the batch call returns the combos again.
    let rows = expand_grid_checked(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            mom_batch_inner_into(slice_in, &sweep, resolved, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1768
// Python-facing wrapper (`MomStream`) around the Rust `MomStream`.
#[cfg(feature = "python")]
#[pyclass(name = "MomStream")]
pub struct MomStreamPy {
    inner: MomStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl MomStreamPy {
    // Builds a stream for the given period; a construction error is raised
    // as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        let params = MomParams {
            period: Some(period),
        };
        match MomStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Pushes one value into the wrapped stream and returns whatever it yields.
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1791
/// Adds the `mom` and `mom_batch` Python functions to module `m`.
///
/// NOTE(review): `MomStreamPy` is not added here — presumably it is
/// registered elsewhere; confirm.
#[cfg(feature = "python")]
pub fn register_mom_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(mom_py, m)?;
    m.add_function(single)?;

    let batch = wrap_pyfunction!(mom_batch_py, m)?;
    m.add_function(batch)?;

    Ok(())
}
1798
1799#[cfg(all(feature = "python", feature = "cuda"))]
1800use crate::cuda::moving_averages::DeviceArrayF32;
1801#[cfg(all(feature = "python", feature = "cuda"))]
1802use crate::cuda::oscillators::CudaMom;
1803#[cfg(all(feature = "python", feature = "cuda"))]
1804use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1805#[cfg(all(feature = "python", feature = "cuda"))]
1806use cust::context::Context;
1807#[cfg(all(feature = "python", feature = "cuda"))]
1808use std::sync::Arc;
1809
// Python handle to a 2-D f32 CUDA result buffer.
//
// `inner` becomes `None` once the buffer has been handed out through
// `__dlpack__`. `ctx` holds a shared reference to the CUDA context and
// `device_id` is the device reported by `__dlpack_device__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct MomDeviceArrayF32Py {
    pub(crate) inner: Option<DeviceArrayF32>,
    pub(crate) ctx: Arc<Context>,
    pub(crate) device_id: i32,
}

#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl MomDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict (version 3) for the buffer:
    // shape (rows, cols), little-endian f32, byte strides, and the device
    // pointer with the read-only flag set to false. Fails once the buffer has
    // been exported via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_size = std::mem::size_of::<f32>();
        let d = PyDict::new(py);
        d.set_item("shape", (inner.rows, inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (inner.cols * elem_size, elem_size))?;
        d.set_item("data", (inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(2, device_id)` as the DLPack device tuple.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id)
    }

    // Exports the buffer through DLPack, consuming it: a second call fails.
    //
    // A `dl_device` that parses as `(i32, i32)` and differs from this
    // buffer's device is rejected; the error text depends on whether `copy`
    // is truthy. A `dl_device` that does not parse is ignored, and `stream`
    // is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        let requested = dl_device
            .as_ref()
            .and_then(|dev_obj| dev_obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        let inner = match self.inner.take() {
            Some(arr) => arr,
            None => {
                return Err(PyValueError::new_err(
                    "__dlpack__ may only be called once",
                ))
            }
        };

        let (rows, cols) = (inner.rows, inner.cols);
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1890
// Python entry point `mom_cuda_batch_dev`: runs the MOM batch sweep on the GPU
// and returns the result as a device-array wrapper (no copy back to the host
// is performed in this function).
//
// Arguments:
//   data         - 1-D contiguous f32 NumPy array; must be non-empty.
//   period_range - tuple stored verbatim in `MomBatchRange::period`
//                  (presumably (start, end, step) - confirm against MomBatchRange).
//   device_id    - device index passed to `CudaMom::new` (default 0).
//
// Errors: every failure is reported as ValueError - CUDA unavailable, empty
// input, or any error returned by the CUDA wrapper. A non-contiguous `data`
// array is rejected by `as_slice`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mom_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn mom_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<MomDeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let input = data.as_slice()?;
    if input.is_empty() {
        return Err(PyValueError::new_err("empty input"));
    }

    let ranges = MomBatchRange {
        period: period_range,
    };

    // The GPU work runs with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaMom::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id() as i32;
        match engine.mom_batch_dev(input, &ranges) {
            Ok(device_array) => Ok((device_array, context, ordinal)),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    });
    let (device_array, ctx, ordinal) = launched?;

    Ok(MomDeviceArrayF32Py {
        inner: Some(device_array),
        ctx,
        device_id: ordinal,
    })
}
1926
// Python entry point `mom_cuda_many_series_one_param_dev`: computes MOM with a
// single `period` over many series on the GPU and returns the result as a
// device-array wrapper.
//
// Arguments:
//   data_tm   - 1-D contiguous f32 NumPy array holding exactly cols * rows
//               values, forwarded to the callee's "time-major" kernel
//               (NOTE(review): axis order is defined by the CUDA wrapper - confirm).
//   cols/rows - matrix dimensions; their product is overflow-checked.
//   period    - MOM look-back, forwarded unchanged.
//   device_id - device index passed to `CudaMom::new` (default 0).
//
// Errors: every failure is reported as ValueError - CUDA unavailable,
// rows*cols overflow, length mismatch, or any error from the CUDA wrapper.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "mom_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn mom_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<MomDeviceArrayF32Py> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let input = data_tm.as_slice()?;

    // The flat buffer must contain exactly one value per (row, col) cell.
    let expected_len = match cols.checked_mul(rows) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };
    if input.len() != expected_len {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }

    // The GPU work runs with the GIL released.
    let launched = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaMom::new(device_id).map_err(|err| PyValueError::new_err(err.to_string()))?;
        let context = engine.context_arc();
        let ordinal = engine.device_id() as i32;
        match engine.mom_many_series_one_param_time_major_dev(input, cols, rows, period) {
            Ok(device_array) => Ok((device_array, context, ordinal)),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    });
    let (device_array, ctx, ordinal) = launched?;

    Ok(MomDeviceArrayF32Py {
        inner: Some(device_array),
        ctx,
        device_id: ordinal,
    })
}