Skip to main content

vector_ta/indicators/moving_averages/
nma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaNma, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7    make_uninit_matrix,
8};
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
11use core::arch::wasm32::*;
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(all(feature = "python", feature = "cuda"))]
15use cust::context::Context;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::prelude::*;
20use std::convert::AsRef;
21use std::error::Error;
22use std::mem::MaybeUninit;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use std::sync::Arc;
25use thiserror::Error;
26
27impl<'a> AsRef<[f64]> for NmaInput<'a> {
28    #[inline(always)]
29    fn as_ref(&self) -> &[f64] {
30        match &self.data {
31            NmaData::Slice(slice) => slice,
32            NmaData::Candles { candles, source } => source_type(candles, source),
33        }
34    }
35}
36
37#[derive(Debug, Clone)]
38pub enum NmaData<'a> {
39    Candles {
40        candles: &'a Candles,
41        source: &'a str,
42    },
43    Slice(&'a [f64]),
44}
45
/// Result of a single-series NMA run: one output slot per input sample.
#[derive(Clone, Debug)]
pub struct NmaOutput {
    pub values: Vec<f64>,
}
50
/// Tunable parameters of the NMA.
///
/// `period` is the lookback length; `None` lets consumers pick their default.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct NmaParams {
    pub period: Option<usize>,
}

impl Default for NmaParams {
    /// The default configuration uses a 40-sample period.
    fn default() -> Self {
        let period = Some(40);
        NmaParams { period }
    }
}
65
66#[derive(Debug, Clone)]
67pub struct NmaInput<'a> {
68    pub data: NmaData<'a>,
69    pub params: NmaParams,
70}
71
72impl<'a> NmaInput<'a> {
73    #[inline]
74    pub fn from_candles(c: &'a Candles, s: &'a str, p: NmaParams) -> Self {
75        Self {
76            data: NmaData::Candles {
77                candles: c,
78                source: s,
79            },
80            params: p,
81        }
82    }
83    #[inline]
84    pub fn from_slice(sl: &'a [f64], p: NmaParams) -> Self {
85        Self {
86            data: NmaData::Slice(sl),
87            params: p,
88        }
89    }
90    #[inline]
91    pub fn with_default_candles(c: &'a Candles) -> Self {
92        Self::from_candles(c, "close", NmaParams::default())
93    }
94    #[inline]
95    pub fn get_period(&self) -> usize {
96        self.params.period.unwrap_or(40)
97    }
98}
99
100#[derive(Copy, Clone, Debug)]
101pub struct NmaBuilder {
102    period: Option<usize>,
103    kernel: Kernel,
104}
105
106impl Default for NmaBuilder {
107    fn default() -> Self {
108        Self {
109            period: None,
110            kernel: Kernel::Auto,
111        }
112    }
113}
114
115impl NmaBuilder {
116    #[inline(always)]
117    pub fn new() -> Self {
118        Self::default()
119    }
120    #[inline(always)]
121    pub fn period(mut self, n: usize) -> Self {
122        self.period = Some(n);
123        self
124    }
125    #[inline(always)]
126    pub fn kernel(mut self, k: Kernel) -> Self {
127        self.kernel = k;
128        self
129    }
130    #[inline(always)]
131    pub fn apply(self, c: &Candles) -> Result<NmaOutput, NmaError> {
132        let p = NmaParams {
133            period: self.period,
134        };
135        let i = NmaInput::from_candles(c, "close", p);
136        nma_with_kernel(&i, self.kernel)
137    }
138    #[inline(always)]
139    pub fn apply_slice(self, d: &[f64]) -> Result<NmaOutput, NmaError> {
140        let p = NmaParams {
141            period: self.period,
142        };
143        let i = NmaInput::from_slice(d, p);
144        nma_with_kernel(&i, self.kernel)
145    }
146    #[inline(always)]
147    pub fn into_stream(self) -> Result<NmaStream, NmaError> {
148        let p = NmaParams {
149            period: self.period,
150        };
151        NmaStream::try_new(p)
152    }
153}
154
155#[derive(Debug, Error)]
156pub enum NmaError {
157    #[error("nma: Input data slice is empty.")]
158    EmptyInputData,
159    #[error("nma: All values are NaN.")]
160    AllValuesNaN,
161    #[error("nma: Invalid period: period = {period}, data length = {data_len}")]
162    InvalidPeriod { period: usize, data_len: usize },
163    #[error("nma: Not enough valid data: needed = {needed}, valid = {valid}")]
164    NotEnoughValidData { needed: usize, valid: usize },
165    #[error("nma: Output length mismatch: expected = {expected}, got = {got}")]
166    OutputLengthMismatch { expected: usize, got: usize },
167    #[error("nma: Invalid range: start = {start}, end = {end}, step = {step}")]
168    InvalidRange {
169        start: usize,
170        end: usize,
171        step: usize,
172    },
173    #[error("nma: Invalid kernel for batch path: {0:?}")]
174    InvalidKernelForBatch(Kernel),
175    #[error("nma: invalid input: {0}")]
176    InvalidInput(String),
177}
178
179#[inline]
180pub fn nma(input: &NmaInput) -> Result<NmaOutput, NmaError> {
181    nma_with_kernel(input, Kernel::Auto)
182}
183
184#[inline(always)]
185fn nma_prepare<'a>(
186    input: &'a NmaInput,
187    kernel: Kernel,
188) -> Result<(&'a [f64], usize, usize, Vec<f64>, Vec<f64>, Kernel), NmaError> {
189    let data: &[f64] = input.as_ref();
190    let len = data.len();
191
192    if len == 0 {
193        return Err(NmaError::EmptyInputData);
194    }
195
196    let first = data
197        .iter()
198        .position(|x| !x.is_nan())
199        .ok_or(NmaError::AllValuesNaN)?;
200
201    let period = input.get_period();
202
203    if period == 0 || period > len {
204        return Err(NmaError::InvalidPeriod {
205            period,
206            data_len: len,
207        });
208    }
209    if (len - first) < (period + 1) {
210        return Err(NmaError::NotEnoughValidData {
211            needed: period + 1,
212            valid: len - first,
213        });
214    }
215
216    let chosen = match kernel {
217        Kernel::Auto => detect_best_kernel(),
218        other => other,
219    };
220
221    let mut ln_values = alloc_with_nan_prefix(len, 0);
222    if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
223        for i in 0..len {
224            ln_values[i] = data[i].max(1e-10).ln();
225        }
226    }
227
228    let mut sqrt_diffs = vec![0.0; period];
229    for i in 0..period {
230        let s0 = (i as f64).sqrt();
231        let s1 = ((i + 1) as f64).sqrt();
232        sqrt_diffs[i] = s1 - s0;
233    }
234
235    Ok((data, period, first, ln_values, sqrt_diffs, chosen))
236}
237
238fn nma_compute_into(
239    data: &[f64],
240    period: usize,
241    first: usize,
242    ln_values: &mut [f64],
243    sqrt_diffs: &mut [f64],
244    kernel: Kernel,
245    out: &mut [f64],
246) {
247    unsafe {
248        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
249        {
250            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
251                nma_simd128(data, period, first, ln_values, sqrt_diffs, out);
252                return;
253            }
254        }
255
256        match kernel {
257            Kernel::Scalar | Kernel::ScalarBatch => {
258                nma_scalar_with_precomputed(data, period, first, ln_values, sqrt_diffs, out)
259            }
260
261            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
262            Kernel::Avx2 | Kernel::Avx2Batch => {
263                nma_avx2(data, period, first, ln_values, sqrt_diffs, out)
264            }
265
266            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
267            Kernel::Avx512 | Kernel::Avx512Batch => {
268                nma_avx512_v2(data, period, first, ln_values, sqrt_diffs, out)
269            }
270
271            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
272            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
273                nma_scalar_with_precomputed(data, period, first, ln_values, sqrt_diffs, out)
274            }
275            _ => unreachable!(),
276        }
277    }
278}
279
280pub fn nma_with_kernel(input: &NmaInput, kernel: Kernel) -> Result<NmaOutput, NmaError> {
281    let (data, period, first, mut ln_values, mut sqrt_diffs, chosen) = nma_prepare(input, kernel)?;
282
283    let warm = first + period;
284    let mut out = alloc_with_nan_prefix(data.len(), warm);
285
286    nma_compute_into(
287        data,
288        period,
289        first,
290        &mut ln_values,
291        &mut sqrt_diffs,
292        chosen,
293        &mut out,
294    );
295
296    Ok(NmaOutput { values: out })
297}
/// Self-contained scalar NMA kernel.
///
/// Builds the clamped log series `ln(max(x, 1e-10))` and the weight table
/// `sqrt(k + 1) - sqrt(k)` itself, then fills `out[j]` for `j >= first + period`.
/// Earlier entries of `out` are left untouched.
///
/// For each `j`, the absolute log steps at lags `0..period` (lag 0 = newest) are
/// weighted by the table. Their weighted mean `ratio` then blends
/// `data[j - period + 1]` and `data[j - period]`.
///
/// `period == 0` is unsupported (it overflows `period - 1`), and `out` must be at
/// least as long as `data`.
#[inline]
pub fn nma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();

    let logs: Vec<f64> = data.iter().map(|&x| x.max(1e-10).ln()).collect();
    let weights: Vec<f64> = (0..period)
        .map(|k| ((k + 1) as f64).sqrt() - (k as f64).sqrt())
        .collect();

    for j in (first + period)..len {
        let mut num = 0.0;
        let mut denom = 0.0;

        // Lag 0 is the newest step and carries the largest weight (sqrt(1) - sqrt(0)).
        for (lag, &w) in weights.iter().enumerate() {
            let step = (logs[j - lag] - logs[j - lag - 1]).abs();
            num += step * w;
            denom += step;
        }

        let ratio = if denom == 0.0 { 0.0 } else { num / denom };

        let back = period - 1;
        out[j] = data[j - back] * ratio + data[j - back - 1] * (1.0 - ratio);
    }
}
330
/// Scalar NMA kernel over precomputed logs and weights.
///
/// `ln_values` must hold the clamped log series and `sqrt_diffs[i]` the weight of the
/// step at lag `i` (lag 0 = newest). For every `j >= first + period` the window
/// `ln_values[j - period ..= j]` is scanned oldest to newest. The weighted share of the
/// absolute log steps gives `ratio`, and `out[j]` interpolates from
/// `data[j - period]` toward `data[j - period + 1]` by that ratio. A flat window
/// (zero total step) yields `data[j - period]`.
///
/// Entries of `out` before `first + period` are not written.
#[inline]
pub fn nma_scalar_with_precomputed(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &[f64],
    sqrt_diffs: &[f64],
    out: &mut [f64],
) {
    for j in (first + period)..data.len() {
        let start = j - period;

        // Steps are visited oldest -> newest while the weight table is newest-first,
        // hence the reversed weight iterator.
        let (num, denom) = ln_values[start..=j]
            .windows(2)
            .zip(sqrt_diffs[..period].iter().rev())
            .fold((0.0f64, 0.0f64), |(n, d), (pair, &w)| {
                let step = (pair[1] - pair[0]).abs();
                (n + step * w, d + step)
            });

        let ratio = if denom == 0.0 { 0.0 } else { num / denom };

        let older = data[start];
        let newer = data[start + 1];
        out[j] = (newer - older).mul_add(ratio, older);
    }
}
365
/// wasm `simd128` NMA kernel: same result layout as the scalar kernel, with two lags
/// accumulated per `f64x2` operation and a scalar tail for odd periods.
///
/// # Safety
/// `ln_values` must hold the clamped logs for every index of `data`, and `sqrt_diffs`
/// must hold at least `period` weights (the weight loads use raw pointers).
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn nma_simd128(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &[f64],
    sqrt_diffs: &[f64],
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    const STEP: usize = 2;
    let len = data.len();

    // The 2-lane / scalar split of each window depends only on `period`.
    let chunks = period / STEP;
    let vec_end = chunks * STEP;

    for j in (first + period)..len {
        let mut num_acc = f64x2_splat(0.0);
        let mut denom_acc = f64x2_splat(0.0);

        for blk in 0..chunks {
            let i = blk * STEP;

            // Lane 0 holds lag i, lane 1 lag i + 1, matching sqrt_diffs[i], sqrt_diffs[i + 1].
            let ln_curr_0 = f64x2(ln_values[j - i], ln_values[j - i - 1]);
            let ln_prev_0 = f64x2(ln_values[j - i - 1], ln_values[j - i - 2]);

            let abs_diff = f64x2_abs(f64x2_sub(ln_curr_0, ln_prev_0));

            let sqrt_d = v128_load(sqrt_diffs.as_ptr().add(i) as *const v128);

            num_acc = f64x2_add(num_acc, f64x2_mul(abs_diff, sqrt_d));
            denom_acc = f64x2_add(denom_acc, abs_diff);
        }

        let mut num = f64x2_extract_lane::<0>(num_acc) + f64x2_extract_lane::<1>(num_acc);
        let mut denom = f64x2_extract_lane::<0>(denom_acc) + f64x2_extract_lane::<1>(denom_acc);

        for i in vec_end..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num += oi * sqrt_diffs[i];
            denom += oi;
        }

        let ratio = if denom == 0.0 { 0.0 } else { num / denom };
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
417
/// Natural logarithm of eight `f64` lanes using the fdlibm `__ieee754_log` scheme.
///
/// Each lane is split as `x = 2^k * m` with `m` folded into `[sqrt(2)/2, sqrt(2))`.
/// With `f = m - 1`, `s = f / (2 + f)`, `hfsq = f^2 / 2` and `R` fdlibm's minimax
/// polynomial in `s^2`:
/// `ln(x) = k * ln2 + f - hfsq + s * (hfsq + R)`.
///
/// # Safety
/// Requires AVX-512F/DQ and FMA. Lanes must be positive, finite and normal. Zero,
/// negative, subnormal, infinite or NaN lanes give meaningless results.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,avx512bw,fma")]
unsafe fn fast_ln_avx512_hi(x: __m512d) -> __m512d {
    let one = _mm512_set1_pd(1.0);
    let two = _mm512_set1_pd(2.0);
    let half = _mm512_set1_pd(0.5);
    let ln2 = _mm512_set1_pd(std::f64::consts::LN_2);
    let sqrt2 = _mm512_set1_pd(std::f64::consts::SQRT_2);

    // Split the IEEE-754 bits: unbiased exponent k, mantissa m in [1, 2).
    let ix = _mm512_castpd_si512(x);
    let exp_mask = _mm512_set1_epi64(0x7FF0_0000_0000_0000u64 as i64);
    let mantissa_mask = _mm512_set1_epi64(0x000F_FFFF_FFFF_FFFFu64 as i64);
    let one_bits = _mm512_set1_epi64(0x3FF0_0000_0000_0000u64 as i64);
    let bias = _mm512_set1_epi64(1023);

    let biased_exp = _mm512_srli_epi64::<52>(_mm512_and_si512(ix, exp_mask));
    let mut k = _mm512_cvtepi64_pd(_mm512_sub_epi64(biased_exp, bias));
    let mut m = _mm512_castsi512_pd(_mm512_or_si512(
        _mm512_and_si512(ix, mantissa_mask),
        one_bits,
    ));

    // Fold m from [1, 2) into [sqrt(2)/2, sqrt(2)) so |f| <= sqrt(2) - 1, the range the
    // fdlibm polynomial is fitted on.
    let fold = _mm512_cmp_pd_mask(m, sqrt2, _CMP_GT_OQ);
    m = _mm512_mask_mul_pd(m, fold, m, half);
    k = _mm512_mask_add_pd(k, fold, k, one);

    let f = _mm512_sub_pd(m, one);
    let s = _mm512_div_pd(f, _mm512_add_pd(two, f));
    let z = _mm512_mul_pd(s, s);
    let w = _mm512_mul_pd(z, z);

    // fdlibm e_log.c coefficients.
    let lg1 = _mm512_set1_pd(6.666666666666735130e-01);
    let lg2 = _mm512_set1_pd(3.999999999940941908e-01);
    let lg3 = _mm512_set1_pd(2.857142874366239149e-01);
    let lg4 = _mm512_set1_pd(2.222219843214978396e-01);
    let lg5 = _mm512_set1_pd(1.818357216161805012e-01);
    let lg6 = _mm512_set1_pd(1.531383769920937332e-01);
    let lg7 = _mm512_set1_pd(1.479819860511658591e-01);

    // R = t1 + t2: even and odd powers of z evaluated in w = z^2 (fdlibm's split).
    // t1 = w * (Lg2 + w * (Lg4 + w * Lg6))
    let t1 = _mm512_mul_pd(w, _mm512_fmadd_pd(w, _mm512_fmadd_pd(w, lg6, lg4), lg2));
    // t2 = z * (Lg1 + w * (Lg3 + w * (Lg5 + w * Lg7)))
    let t2 = _mm512_mul_pd(
        z,
        _mm512_fmadd_pd(
            w,
            _mm512_fmadd_pd(w, _mm512_fmadd_pd(w, lg7, lg5), lg3),
            lg1,
        ),
    );
    let r = _mm512_add_pd(t1, t2);

    let hfsq = _mm512_mul_pd(_mm512_mul_pd(half, f), f);
    // ln(1 + f) = f - hfsq + s * (hfsq + R)
    let log1pf = _mm512_fmadd_pd(s, _mm512_add_pd(hfsq, r), _mm512_sub_pd(f, hfsq));

    _mm512_fmadd_pd(k, ln2, log1pf)
}
521
/// AVX2/FMA NMA kernel.
///
/// Overwrites `ln_values[..data.len()]` with `ln(max(x, 1e-10))`, then for every
/// `j >= first + period` accumulates the absolute log steps at lags `0..period`
/// (lag 0 = newest), four lags per FMA, weighted by `sqrt_diffs[lag]`. The weighted
/// mean `ratio` blends `data[j - period + 1]` and `data[j - period]` into `out[j]`.
///
/// # Safety
/// The CPU must support AVX2 and FMA. `sqrt_diffs` must hold at least `period`
/// weights (the weight loads are unchecked); `ln_values` and `out` must hold at least
/// `data.len()` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn nma_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    let len = data.len();
    let zero = _mm256_setzero_pd();

    // Clamped natural log of every sample, identical to the scalar path. `f64::max`
    // maps NaN to the 1e-10 floor just as `_mm256_max_pd(x, eps)` does, and the
    // slice keeps the writes bounds-checked.
    for (slot, &x) in ln_values[..len].iter_mut().zip(data) {
        *slot = x.max(1e-10).ln();
    }

    for j in (first + period)..len {
        let mut num_accum = zero;
        let mut denom_accum = zero;

        let mut idx = 0;
        while idx + 4 <= period {
            let mut diffs = [0.0f64; 4];
            for k in 0..4 {
                let i = idx + k;
                diffs[k] = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            }
            let oi_vec = _mm256_loadu_pd(diffs.as_ptr());

            let weights = _mm256_loadu_pd(sqrt_diffs.as_ptr().add(idx));

            num_accum = _mm256_fmadd_pd(oi_vec, weights, num_accum);
            denom_accum = _mm256_add_pd(denom_accum, oi_vec);

            idx += 4;
        }

        let mut num_final = horizontal_sum_avx2(num_accum);
        let mut denom_final = horizontal_sum_avx2(denom_accum);

        // Scalar tail for period % 4 remaining lags.
        for i in idx..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num_final += oi * sqrt_diffs[i];
            denom_final += oi;
        }

        let ratio = if denom_final == 0.0 {
            0.0
        } else {
            num_final / denom_final
        };
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
604
/// Sums the four lanes of `v` as `(v0 + v2) + (v1 + v3)`.
///
/// # Safety
/// Requires AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256d) -> f64 {
    // Fold the upper 128-bit half onto the lower half, then add the two survivors.
    let folded = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
    let mut pair = [0.0f64; 2];
    _mm_storeu_pd(pair.as_mut_ptr(), folded);
    pair[0] + pair[1]
}
618
/// Natural logarithm of four `f64` lanes using the fdlibm `__ieee754_log` scheme.
///
/// AVX2 lacks a 64-bit integer to `f64` conversion, so the exponent/mantissa split
/// (including the fold of `m` into `[sqrt(2)/2, sqrt(2))`) is done per lane in scalar
/// code. The polynomial is then evaluated vectorised:
/// `ln(x) = k * ln2 + f - hfsq + s * (hfsq + R)` with `f = m - 1`,
/// `s = f / (2 + f)`, `hfsq = f^2 / 2`.
///
/// # Safety
/// Requires AVX2 and FMA (`_mm256_fmadd_pd`). Lanes must be positive, finite and
/// normal; other inputs give meaningless results.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn fast_ln_avx2_hi(x: __m256d) -> __m256d {
    let one = _mm256_set1_pd(1.0);
    let two = _mm256_set1_pd(2.0);
    let half = _mm256_set1_pd(0.5);
    let ln2 = _mm256_set1_pd(std::f64::consts::LN_2);

    let mut mantissa = [0.0f64; 4];
    let mut exponent = [0.0f64; 4];
    _mm256_storeu_pd(mantissa.as_mut_ptr(), x);

    for (m_lane, k_lane) in mantissa.iter_mut().zip(exponent.iter_mut()) {
        let bits = m_lane.to_bits();
        let mut k = ((bits >> 52) & 0x7FF) as i32 - 1023;
        // Mantissa only (sign and exponent cleared), rebased to [1, 2).
        let mut m = f64::from_bits((bits & 0x000F_FFFF_FFFF_FFFF) | 0x3FF0_0000_0000_0000);
        // Fold into [sqrt(2)/2, sqrt(2)); the fdlibm polynomial is fitted on that range.
        if m > std::f64::consts::SQRT_2 {
            m *= 0.5;
            k += 1;
        }
        *m_lane = m;
        *k_lane = f64::from(k);
    }

    let m = _mm256_loadu_pd(mantissa.as_ptr());
    let k = _mm256_loadu_pd(exponent.as_ptr());

    let f = _mm256_sub_pd(m, one);
    let s = _mm256_div_pd(f, _mm256_add_pd(two, f));
    let z = _mm256_mul_pd(s, s);
    let w = _mm256_mul_pd(z, z);

    // fdlibm e_log.c coefficients.
    let lg1 = _mm256_set1_pd(6.666666666666735130e-01);
    let lg2 = _mm256_set1_pd(3.999999999940941908e-01);
    let lg3 = _mm256_set1_pd(2.857142874366239149e-01);
    let lg4 = _mm256_set1_pd(2.222219843214978396e-01);
    let lg5 = _mm256_set1_pd(1.818357216161805012e-01);
    let lg6 = _mm256_set1_pd(1.531383769920937332e-01);
    let lg7 = _mm256_set1_pd(1.479819860511658591e-01);

    // t1 = w * (Lg2 + w * (Lg4 + w * Lg6))
    let t1 = _mm256_mul_pd(w, _mm256_fmadd_pd(w, _mm256_fmadd_pd(w, lg6, lg4), lg2));
    // t2 = z * (Lg1 + w * (Lg3 + w * (Lg5 + w * Lg7)))
    let t2 = _mm256_mul_pd(
        z,
        _mm256_fmadd_pd(
            w,
            _mm256_fmadd_pd(w, _mm256_fmadd_pd(w, lg7, lg5), lg3),
            lg1,
        ),
    );
    let r = _mm256_add_pd(t1, t2);

    let hfsq = _mm256_mul_pd(_mm256_mul_pd(half, f), f);
    // ln(1 + f) = f - hfsq + s * (hfsq + R)
    let log1pf = _mm256_fmadd_pd(s, _mm256_add_pd(hfsq, r), _mm256_sub_pd(f, hfsq));

    _mm256_fmadd_pd(k, ln2, log1pf)
}
690
/// Lane-wise absolute value of eight `f64`s (clears each IEEE-754 sign bit).
///
/// Implemented with an integer AND so that every intrinsic used is AVX-512F,
/// matching the declared target feature (`_mm512_andnot_pd` is AVX-512DQ).
///
/// # Safety
/// Requires AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn _mm512_abs_pd(a: __m512d) -> __m512d {
    let magnitude_mask = _mm512_set1_epi64(i64::MAX);
    _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), magnitude_mask))
}
698
/// AVX-512 NMA kernel (eight lags per FMA).
///
/// Overwrites `ln_values[..data.len()]` with `ln(max(x, 1e-10))`. Then, for every
/// `j >= first + period`, it weights the absolute log step at lag `i` (lag 0 = newest)
/// by `sqrt_diffs[i]` and blends `data[j - period + 1]` / `data[j - period]` by the
/// weighted mean, exactly like the scalar kernel.
///
/// # Safety
/// The CPU must support AVX-512F/DQ/VL/BW and FMA. `sqrt_diffs` must hold at least
/// `period` weights and `ln_values` at least `data.len()` values (the vector loads
/// are unchecked); `out` must hold at least `data.len()` elements.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,avx512bw,fma")]
pub unsafe fn nma_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    let len = data.len();
    let zero = _mm512_setzero_pd();

    // The 8-sample block loaded below runs oldest -> newest, so lane l holds the step
    // at lag (idx + 7 - l). Reverse the lanes (lane 0 <- element 7, ...) so lane l
    // holds lag idx + l, aligned with sqrt_diffs[idx + l].
    let reverse = _mm512_setr_epi64(7, 6, 5, 4, 3, 2, 1, 0);

    for (slot, &x) in ln_values[..len].iter_mut().zip(data) {
        *slot = x.max(1e-10).ln();
    }

    for j in (first + period)..len {
        let mut num_accum = zero;
        let mut denom_accum = zero;

        let mut idx = 0;
        while idx + 8 <= period {
            // j >= first + period >= idx + 8, so ln_values[j - idx - 8 ..= j - idx]
            // is in bounds.
            let base_ptr = ln_values.as_ptr().add(j - idx - 8);
            let prev = _mm512_loadu_pd(base_ptr);
            let curr = _mm512_loadu_pd(base_ptr.add(1));

            let abs_diff = _mm512_abs_pd(_mm512_sub_pd(curr, prev));
            let oi_vec = _mm512_permutexvar_pd(reverse, abs_diff);

            let weights = _mm512_loadu_pd(sqrt_diffs.as_ptr().add(idx));

            num_accum = _mm512_fmadd_pd(oi_vec, weights, num_accum);
            denom_accum = _mm512_add_pd(denom_accum, oi_vec);

            idx += 8;
        }

        let mut num_scalar = _mm512_reduce_add_pd(num_accum);
        let mut denom_scalar = _mm512_reduce_add_pd(denom_accum);

        // Scalar tail for period % 8 remaining lags.
        for i in idx..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num_scalar += oi * sqrt_diffs[i];
            denom_scalar += oi;
        }

        let ratio = if denom_scalar == 0.0 {
            0.0
        } else {
            num_scalar / denom_scalar
        };
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
779
/// AVX-512 NMA kernel using prefix sums for the denominator.
///
/// The logs are written into `ln_values` and then replaced in place by the absolute
/// log steps `|ln x[k+1] - ln x[k]|` (last slot zero), so on return `ln_values` holds
/// steps, not logs. The denominator of each window comes from a prefix-sum
/// difference. The numerator is an FMA dot product of the window's steps
/// (oldest -> newest) with the reversed weight table, zero-padded to a multiple of
/// eight lanes.
///
/// # Safety
/// The CPU must support AVX-512F/DQ/VL and FMA. `ln_values` must hold `data.len()`
/// values, `sqrt_diffs` at least `period` weights, `out` at least `data.len()`
/// elements, and `1 <= period <= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
pub unsafe fn nma_avx512_v2(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let len = data.len();
    debug_assert!(len == ln_values.len());
    debug_assert!(period >= 1 && period <= len);

    // Pass 1: clamped natural logs.
    for (slot, &price) in ln_values[..len].iter_mut().zip(data) {
        *slot = price.max(1e-10).ln();
    }

    // Pass 2 (in place, forward so each successor is still a log when read).
    for k in 1..len {
        ln_values[k - 1] = (ln_values[k] - ln_values[k - 1]).abs();
    }
    ln_values[len - 1] = 0.0;
    let steps: &[f64] = ln_values;

    // Pass 3: prefix[k] = steps[0] + ... + steps[k - 1].
    let mut prefix: Vec<f64> = Vec::with_capacity(len + 1);
    prefix.push(0.0);
    let mut running = 0.0f64;
    for &step in &steps[..len] {
        running += step;
        prefix.push(running);
    }

    // Weights in window order (oldest first), padded with zeros to whole 8-lane blocks.
    let padded = (period + 7) & !7;
    let mut weights_rev = AVec::<f64>::with_capacity(64, padded);
    weights_rev.resize(padded, 0.0);
    for (dst, &w) in weights_rev.iter_mut().zip(sqrt_diffs[..period].iter().rev()) {
        *dst = w;
    }

    let zero = _mm512_setzero_pd();
    let weights_ptr = weights_rev.as_ptr();

    for j in (first + period)..len {
        let start = j - period;
        let denom = prefix[j] - prefix[start];
        let steps_ptr = steps.as_ptr().add(start);

        let mut acc = zero;
        let mut t = 0usize;

        while t + 16 <= period {
            let d0 = _mm512_loadu_pd(steps_ptr.add(t));
            let w0 = _mm512_loadu_pd(weights_ptr.add(t));
            let d1 = _mm512_loadu_pd(steps_ptr.add(t + 8));
            let w1 = _mm512_loadu_pd(weights_ptr.add(t + 8));
            acc = _mm512_fmadd_pd(d0, w0, acc);
            acc = _mm512_fmadd_pd(d1, w1, acc);
            t += 16;
        }
        // After the 16-wide loop fewer than 16 lags remain: at most one full 8-block.
        if t + 8 <= period {
            let d0 = _mm512_loadu_pd(steps_ptr.add(t));
            let w0 = _mm512_loadu_pd(weights_ptr.add(t));
            acc = _mm512_fmadd_pd(d0, w0, acc);
            t += 8;
        }
        if t < period {
            let lanes = (period - t) as u32;
            let mask: __mmask8 = ((1u32 << lanes) - 1) as u8;
            let d0 = _mm512_maskz_loadu_pd(mask, steps_ptr.add(t));
            let w0 = _mm512_maskz_loadu_pd(mask, weights_ptr.add(t));
            acc = _mm512_fmadd_pd(d0, w0, acc);
        }

        let num = _mm512_reduce_add_pd(acc);
        let ratio = if denom == 0.0 { 0.0 } else { num / denom };

        let older = data[start];
        let delta = data[start + 1] - older;
        out[j] = ratio.mul_add(delta, older);
    }
}
864
/// AVX-512 batch NMA: evaluates every period of `sweep` over `data`, one output row
/// per parameter combination (rows optionally processed in parallel with rayon).
///
/// The absolute log steps `|ln x[k+1] - ln x[k]|` and their prefix sums are computed
/// once and shared by all rows. Each row builds its own reversed `sqrt` weight table
/// and runs the same dot-product kernel as `nma_avx512_v2`.
///
/// # Errors
/// - Errors from `expand_grid`.
/// - [`NmaError::InvalidPeriod`] when the sweep is empty or contains a zero period.
/// - [`NmaError::EmptyInputData`] when `data` is empty.
///
/// # Safety
/// The CPU must support AVX-512F/DQ/VL and FMA.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
unsafe fn nma_batch_avx512_optimized(
    data: &[f64],
    sweep: &NmaBatchRange,
    first: usize,
    parallel: bool,
) -> Result<NmaBatchOutput, NmaError> {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(NmaError::InvalidPeriod {
            period: 0,
            data_len: 0,
        });
    }

    let len = data.len();
    // `len - 1` below and `chunks_mut(cols)` both require a non-empty series.
    if len == 0 {
        return Err(NmaError::EmptyInputData);
    }
    // A zero period would underflow `period - 1` in the row kernel.
    if combos.iter().any(|c| c.period == Some(0)) {
        return Err(NmaError::InvalidPeriod {
            period: 0,
            data_len: len,
        });
    }

    let rows = combos.len();
    let cols = len;

    // Shared across rows: d[k] = |ln x[k+1] - ln x[k]| (d[len - 1] = 0).
    let mut ln_values = alloc_with_nan_prefix(len, 0);
    for i in 0..len {
        ln_values[i] = data[i].max(1e-10).ln();
    }
    for i in 0..len - 1 {
        ln_values[i] = (ln_values[i + 1] - ln_values[i]).abs();
    }
    ln_values[len - 1] = 0.0;
    let d: &[f64] = &ln_values;

    // s[k] = d[0] + ... + d[k - 1]; window denominators are prefix differences.
    let mut s = alloc_with_nan_prefix(len + 1, 0);
    s[0] = 0.0;
    for k in 0..len {
        s[k + 1] = s[k] + d[k];
    }

    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
    let mut raw = make_uninit_matrix(rows, cols);
    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };

    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();
        let warm = first + period;

        // Reversed weights sqrt(period - i) - sqrt(period - 1 - i), zero-padded to 8 lanes.
        let wlen_padded = (period + 7) & !7;
        let mut w_rev = AVec::<f64>::with_capacity(64, wlen_padded);
        w_rev.resize(wlen_padded, 0.0);
        for i in 0..period {
            let s0 = ((period - 1 - i) as f64).sqrt();
            let s1 = ((period - i) as f64).sqrt();
            w_rev[i] = s1 - s0;
        }

        let zero = _mm512_setzero_pd();

        for j in warm..len {
            let base = j - period;
            let denom = s[j] - s[j - period];

            let mut num_acc = zero;
            let mut t = 0usize;

            while t + 16 <= period {
                let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
                let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
                let d1 = _mm512_loadu_pd(d.as_ptr().add(base + t + 8));
                let w1 = _mm512_loadu_pd(w_rev.as_ptr().add(t + 8));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
                num_acc = _mm512_fmadd_pd(d1, w1, num_acc);
                t += 16;
            }
            while t + 8 <= period {
                let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
                let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
                t += 8;
            }
            if t < period {
                let tail = (period - t) as u32;
                let mask: __mmask8 = ((1u32 << tail) - 1) as u8;
                let d0 = _mm512_maskz_loadu_pd(mask, d.as_ptr().add(base + t));
                let w0 = _mm512_maskz_loadu_pd(mask, w_rev.as_ptr().add(t));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
            }

            let num = _mm512_reduce_add_pd(num_acc);
            let ratio = if denom == 0.0 { 0.0 } else { num / denom };

            let i0 = period - 1;
            let x2 = data[j - i0 - 1];
            let dx = data[j - i0] - x2;
            // Write through MaybeUninit: no `&mut f64` is ever formed over the row's
            // not-yet-initialised slots.
            dst_mu[j].write(ratio.mul_add(dx, x2));
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            use rayon::prelude::*;
            raw.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in raw.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in raw.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // Each row's `[0, warm)` prefix was written by `init_matrix_prefixes` (assumed to
    // cover the prefix — TODO confirm) and `[warm, len)` by `do_row`. Rebuild the Vec
    // from its raw parts: `MaybeUninit<f64>` and `f64` share a layout, but `Vec<T>`
    // itself has no layout guarantee, so transmuting the whole Vec is not sound.
    let mut raw = core::mem::ManuallyDrop::new(raw);
    let values = unsafe {
        Vec::from_raw_parts(raw.as_mut_ptr() as *mut f64, raw.len(), raw.capacity())
    };

    Ok(NmaBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
998
999#[inline(always)]
1000pub fn nma_batch_with_kernel(
1001    data: &[f64],
1002    sweep: &NmaBatchRange,
1003    k: Kernel,
1004) -> Result<NmaBatchOutput, NmaError> {
1005    let kernel = match k {
1006        Kernel::Auto => detect_best_batch_kernel(),
1007        other if other.is_batch() => other,
1008        _ => return Err(NmaError::InvalidKernelForBatch(k)),
1009    };
1010
1011    let simd = match kernel {
1012        Kernel::Avx512Batch => Kernel::Avx512,
1013        Kernel::Avx2Batch => Kernel::Avx2,
1014        Kernel::ScalarBatch => Kernel::Scalar,
1015        _ => Kernel::Scalar,
1016    };
1017    nma_batch_par_slice(data, sweep, simd)
1018}
1019
/// Inclusive `(start, end, step)` sweep over NMA periods for batch runs.
///
/// A `step` of 0, or `start == end`, selects the single period `start`.
#[derive(Clone, Debug)]
pub struct NmaBatchRange {
    /// `(start, end, step)` of the period axis.
    pub period: (usize, usize, usize),
}

impl Default for NmaBatchRange {
    /// Every period from 40 through 289, one apart.
    fn default() -> Self {
        let (start, end, step) = (40, 289, 1);
        NmaBatchRange {
            period: (start, end, step),
        }
    }
}
1032
1033#[derive(Clone, Debug, Default)]
1034pub struct NmaBatchBuilder {
1035    range: NmaBatchRange,
1036    kernel: Kernel,
1037}
1038
1039impl NmaBatchBuilder {
1040    pub fn new() -> Self {
1041        Self::default()
1042    }
1043    pub fn kernel(mut self, k: Kernel) -> Self {
1044        self.kernel = k;
1045        self
1046    }
1047
1048    #[inline]
1049    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1050        self.range.period = (start, end, step);
1051        self
1052    }
1053    #[inline]
1054    pub fn period_static(mut self, p: usize) -> Self {
1055        self.range.period = (p, p, 0);
1056        self
1057    }
1058
1059    pub fn apply_slice(self, data: &[f64]) -> Result<NmaBatchOutput, NmaError> {
1060        nma_batch_with_kernel(data, &self.range, self.kernel)
1061    }
1062
1063    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<NmaBatchOutput, NmaError> {
1064        NmaBatchBuilder::new().kernel(k).apply_slice(data)
1065    }
1066
1067    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<NmaBatchOutput, NmaError> {
1068        let slice = source_type(c, src);
1069        self.apply_slice(slice)
1070    }
1071
1072    pub fn with_default_candles(c: &Candles) -> Result<NmaBatchOutput, NmaError> {
1073        NmaBatchBuilder::new()
1074            .kernel(Kernel::Auto)
1075            .apply_candles(c, "close")
1076    }
1077}
1078
1079#[derive(Clone, Debug)]
1080pub struct NmaBatchOutput {
1081    pub values: Vec<f64>,
1082    pub combos: Vec<NmaParams>,
1083    pub rows: usize,
1084    pub cols: usize,
1085}
1086
1087impl NmaBatchOutput {
1088    pub fn row_for_params(&self, p: &NmaParams) -> Option<usize> {
1089        self.combos
1090            .iter()
1091            .position(|c| c.period.unwrap_or(40) == p.period.unwrap_or(40))
1092    }
1093
1094    pub fn values_for(&self, p: &NmaParams) -> Option<&[f64]> {
1095        self.row_for_params(p).map(|row| {
1096            let start = row * self.cols;
1097            &self.values[start..start + self.cols]
1098        })
1099    }
1100}
1101
1102#[inline(always)]
1103fn expand_grid(r: &NmaBatchRange) -> Result<Vec<NmaParams>, NmaError> {
1104    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, NmaError> {
1105        if step == 0 || start == end {
1106            return Ok(vec![start]);
1107        }
1108        if start < end {
1109            let mut v = Vec::new();
1110            let mut cur = start;
1111            while cur <= end {
1112                v.push(cur);
1113                cur = cur
1114                    .checked_add(step)
1115                    .ok_or_else(|| NmaError::InvalidRange { start, end, step })?;
1116            }
1117            if v.is_empty() {
1118                return Err(NmaError::InvalidRange { start, end, step });
1119            }
1120            Ok(v)
1121        } else {
1122            Err(NmaError::InvalidRange { start, end, step })
1123        }
1124    }
1125    let periods = axis_usize(r.period)?;
1126
1127    let mut out = Vec::with_capacity(periods.len());
1128    for &p in &periods {
1129        out.push(NmaParams { period: Some(p) });
1130    }
1131    Ok(out)
1132}
1133
/// Rounds `x` up to the next multiple of 8; multiples of 8 come back unchanged.
/// NOTE(review): presumably sized for 8-lane f64 SIMD blocks; confirm.
#[inline]
fn round_up8(x: usize) -> usize {
    (x + 7) / 8 * 8
}
1138
1139#[inline(always)]
1140fn nma_batch_inner_into_scalar_reuse(
1141    data: &[f64],
1142    sweep: &NmaBatchRange,
1143    parallel: bool,
1144    out: &mut [f64],
1145) -> Result<Vec<NmaParams>, NmaError> {
1146    let combos = expand_grid(sweep)?;
1147    if combos.is_empty() {
1148        return Err(NmaError::InvalidInput("no parameter combinations".into()));
1149    }
1150
1151    let len = data.len();
1152    let first = data
1153        .iter()
1154        .position(|x| !x.is_nan())
1155        .ok_or(NmaError::AllValuesNaN)?;
1156    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1157    if len - first < max_p + 1 {
1158        return Err(NmaError::NotEnoughValidData {
1159            needed: max_p + 1,
1160            valid: len - first,
1161        });
1162    }
1163
1164    let rows = combos.len();
1165    let cols = len;
1166    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1167    let out_mu = unsafe {
1168        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1169    };
1170    unsafe { init_matrix_prefixes(out_mu, cols, &warm) };
1171
1172    let mut ln = alloc_with_nan_prefix(len, 0);
1173    for i in 0..len {
1174        ln[i] = data[i].max(1e-10).ln();
1175    }
1176    for i in 0..len.saturating_sub(1) {
1177        ln[i] = (ln[i + 1] - ln[i]).abs();
1178    }
1179    ln[len.saturating_sub(1)] = 0.0;
1180    let d = &ln;
1181
1182    let mut s = alloc_with_nan_prefix(len + 1, 0);
1183    s[0] = 0.0;
1184    for i in 0..len {
1185        s[i + 1] = s[i] + d[i];
1186    }
1187
1188    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1189        let p = combos[row].period.unwrap();
1190        let warm = first + p;
1191
1192        let mut w_rev = Vec::with_capacity(p);
1193        for i in 0..p {
1194            let s0 = ((p - 1 - i) as f64).sqrt();
1195            let s1 = ((p - i) as f64).sqrt();
1196            w_rev.push(s1 - s0);
1197        }
1198        let dst = unsafe {
1199            std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1200        };
1201
1202        for j in warm..len {
1203            let base = j - p;
1204            let denom = s[j] - s[j - p];
1205
1206            let mut num = 0.0;
1207
1208            for t in 0..p {
1209                num += d[base + t] * w_rev[t];
1210            }
1211
1212            let ratio = if denom == 0.0 { 0.0 } else { num / denom };
1213            let x2 = data[j - p];
1214            let x1 = data[j - p + 1];
1215            dst[j] = ratio.mul_add(x1 - x2, x2);
1216        }
1217    };
1218
1219    if parallel {
1220        #[cfg(not(target_arch = "wasm32"))]
1221        {
1222            use rayon::prelude::*;
1223            out_mu
1224                .par_chunks_mut(cols)
1225                .enumerate()
1226                .for_each(|(r, row)| do_row(r, row));
1227        }
1228        #[cfg(target_arch = "wasm32")]
1229        for (r, row) in out_mu.chunks_mut(cols).enumerate() {
1230            do_row(r, row);
1231        }
1232    } else {
1233        for (r, row) in out_mu.chunks_mut(cols).enumerate() {
1234            do_row(r, row);
1235        }
1236    }
1237
1238    Ok(combos)
1239}
1240
1241#[inline(always)]
1242pub fn nma_batch_slice(
1243    data: &[f64],
1244    sweep: &NmaBatchRange,
1245    kern: Kernel,
1246) -> Result<NmaBatchOutput, NmaError> {
1247    nma_batch_inner(data, sweep, kern, false)
1248}
1249
1250#[inline(always)]
1251pub fn nma_batch_par_slice(
1252    data: &[f64],
1253    sweep: &NmaBatchRange,
1254    kern: Kernel,
1255) -> Result<NmaBatchOutput, NmaError> {
1256    nma_batch_inner(data, sweep, kern, true)
1257}
1258
1259#[inline(always)]
1260fn nma_batch_inner(
1261    data: &[f64],
1262    sweep: &NmaBatchRange,
1263    kern: Kernel,
1264    parallel: bool,
1265) -> Result<NmaBatchOutput, NmaError> {
1266    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1267    if kern == Kernel::Avx512 {
1268        let first = data
1269            .iter()
1270            .position(|x| !x.is_nan())
1271            .ok_or(NmaError::AllValuesNaN)?;
1272        return unsafe { nma_batch_avx512_optimized(data, sweep, first, parallel) };
1273    }
1274
1275    let combos = expand_grid(sweep)?;
1276    if combos.is_empty() {
1277        return Err(NmaError::InvalidInput("no parameter combinations".into()));
1278    }
1279    let rows = combos.len();
1280    let cols = data.len();
1281    let _ = rows
1282        .checked_mul(cols)
1283        .ok_or_else(|| NmaError::InvalidInput("rows*cols overflow".into()))?;
1284
1285    if kern == Kernel::Scalar {
1286        let first = data
1287            .iter()
1288            .position(|x| !x.is_nan())
1289            .ok_or(NmaError::AllValuesNaN)?;
1290        let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1291        let mut raw = make_uninit_matrix(rows, cols);
1292        unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1293
1294        let out: &mut [f64] =
1295            unsafe { std::slice::from_raw_parts_mut(raw.as_mut_ptr() as *mut f64, raw.len()) };
1296        let combos = nma_batch_inner_into_scalar_reuse(data, sweep, parallel, out)?;
1297
1298        let mut guard = core::mem::ManuallyDrop::new(raw);
1299        let values = unsafe {
1300            Vec::from_raw_parts(
1301                guard.as_mut_ptr() as *mut f64,
1302                guard.len(),
1303                guard.capacity(),
1304            )
1305        };
1306        return Ok(NmaBatchOutput {
1307            values,
1308            combos,
1309            rows,
1310            cols,
1311        });
1312    }
1313
1314    let first = data
1315        .iter()
1316        .position(|x| !x.is_nan())
1317        .ok_or(NmaError::AllValuesNaN)?;
1318    let max_p = combos
1319        .iter()
1320        .map(|c| round_up8(c.period.unwrap()))
1321        .max()
1322        .unwrap();
1323    if data.len() - first < max_p + 1 {
1324        return Err(NmaError::NotEnoughValidData {
1325            needed: max_p + 1,
1326            valid: data.len() - first,
1327        });
1328    }
1329
1330    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1331
1332    let mut raw = make_uninit_matrix(rows, cols);
1333    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
1334
1335    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1336        let period = combos[row].period.unwrap();
1337
1338        let out_row = unsafe {
1339            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1340        };
1341
1342        match kern {
1343            Kernel::Scalar => nma_row_scalar(data, first, period, out_row),
1344            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1345            Kernel::Avx2 => unsafe { nma_row_avx2(data, first, period, out_row) },
1346            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1347            Kernel::Avx512 => unsafe { nma_row_avx512(data, first, period, out_row) },
1348            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1349            Kernel::Avx2 | Kernel::Avx512 => nma_row_scalar(data, first, period, out_row),
1350            _ => nma_row_scalar(data, first, period, out_row),
1351        }
1352    };
1353
1354    if parallel {
1355        #[cfg(not(target_arch = "wasm32"))]
1356        {
1357            use rayon::prelude::*;
1358            raw.par_chunks_mut(cols)
1359                .enumerate()
1360                .for_each(|(row, slice)| do_row(row, slice));
1361        }
1362
1363        #[cfg(target_arch = "wasm32")]
1364        {
1365            for (row, slice) in raw.chunks_mut(cols).enumerate() {
1366                do_row(row, slice);
1367            }
1368        }
1369    } else {
1370        for (row, slice) in raw.chunks_mut(cols).enumerate() {
1371            do_row(row, slice);
1372        }
1373    }
1374
1375    let mut guard = core::mem::ManuallyDrop::new(raw);
1376    let values = unsafe {
1377        Vec::from_raw_parts(
1378            guard.as_mut_ptr() as *mut f64,
1379            guard.len(),
1380            guard.capacity(),
1381        )
1382    };
1383
1384    Ok(NmaBatchOutput {
1385        values,
1386        combos,
1387        rows,
1388        cols,
1389    })
1390}
1391
1392#[inline(always)]
1393fn nma_batch_inner_into(
1394    data: &[f64],
1395    sweep: &NmaBatchRange,
1396    kern: Kernel,
1397    parallel: bool,
1398    out: &mut [f64],
1399) -> Result<Vec<NmaParams>, NmaError> {
1400    let combos = expand_grid(sweep)?;
1401    if combos.is_empty() {
1402        return Err(NmaError::InvalidInput("no parameter combinations".into()));
1403    }
1404
1405    let first = data
1406        .iter()
1407        .position(|x| !x.is_nan())
1408        .ok_or(NmaError::AllValuesNaN)?;
1409    let max_p = combos
1410        .iter()
1411        .map(|c| round_up8(c.period.unwrap()))
1412        .max()
1413        .unwrap();
1414    if data.len() - first < max_p + 1 {
1415        return Err(NmaError::NotEnoughValidData {
1416            needed: max_p + 1,
1417            valid: data.len() - first,
1418        });
1419    }
1420
1421    let rows = combos.len();
1422    let cols = data.len();
1423
1424    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1425
1426    let out_uninit = unsafe {
1427        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1428    };
1429
1430    unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1431
1432    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1433        let period = combos[row].period.unwrap();
1434
1435        let out_row = unsafe {
1436            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1437        };
1438
1439        match kern {
1440            Kernel::Scalar => nma_row_scalar(data, first, period, out_row),
1441            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1442            Kernel::Avx2 => unsafe { nma_row_avx2(data, first, period, out_row) },
1443            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1444            Kernel::Avx512 => unsafe { nma_row_avx512(data, first, period, out_row) },
1445            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1446            Kernel::Avx2 | Kernel::Avx512 => nma_row_scalar(data, first, period, out_row),
1447            _ => nma_row_scalar(data, first, period, out_row),
1448        }
1449    };
1450
1451    if parallel {
1452        #[cfg(not(target_arch = "wasm32"))]
1453        {
1454            out_uninit
1455                .par_chunks_mut(cols)
1456                .enumerate()
1457                .for_each(|(row, slice)| do_row(row, slice));
1458        }
1459        #[cfg(target_arch = "wasm32")]
1460        {
1461            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1462                do_row(row, slice);
1463            }
1464        }
1465    } else {
1466        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1467            do_row(row, slice);
1468        }
1469    }
1470
1471    Ok(combos)
1472}
1473
1474#[inline(always)]
1475fn nma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1476    nma_scalar(data, period, first, out)
1477}
1478
/// Batch row adapter for the AVX2 kernel.
///
/// Builds the two inputs `nma_avx2` consumes and forwards to it:
/// - the natural log of every input, after clamping it to at least `1e-10`;
/// - `sqrt(k + 1) - sqrt(k)` for `k in 0..period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn nma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let mut ln_values = alloc_with_nan_prefix(data.len(), 0);
    for (slot, &x) in ln_values.iter_mut().zip(data.iter()) {
        *slot = x.max(1e-10).ln();
    }

    let mut sqrt_diffs: Vec<f64> = (0..period)
        .map(|k| ((k + 1) as f64).sqrt() - (k as f64).sqrt())
        .collect();

    nma_avx2(data, period, first, &mut ln_values, &mut sqrt_diffs, out);
}
1499
/// Batch row adapter for the AVX-512 kernel.
///
/// Builds the two inputs `nma_avx512_v2` consumes and forwards to it:
/// - the natural log of every input, after clamping it to at least `1e-10`;
/// - `sqrt(k + 1) - sqrt(k)` for `k in 0..period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn nma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let mut ln_values = alloc_with_nan_prefix(data.len(), 0);
    for (slot, &x) in ln_values.iter_mut().zip(data.iter()) {
        *slot = x.max(1e-10).ln();
    }

    let mut sqrt_diffs: Vec<f64> = (0..period)
        .map(|k| ((k + 1) as f64).sqrt() - (k as f64).sqrt())
        .collect();

    nma_avx512_v2(data, period, first, &mut ln_values, &mut sqrt_diffs, out);
}
1515        sqrt_diffs[k] = s1 - s0;
1516    }
1517
1518    nma_avx512_v2(data, period, first, &mut ln_values, &mut sqrt_diffs, out);
1519}
1520
1521#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1522#[inline(always)]
1523pub unsafe fn nma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1524    nma_row_avx512(data, first, period, out)
1525}
1526
1527#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1528#[inline(always)]
1529pub unsafe fn nma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1530    nma_row_avx512(data, first, period, out)
1531}
1532
1533#[derive(Debug, Clone)]
1534pub struct NmaStream {
1535    period: usize,
1536
1537    m: usize,
1538
1539    alpha: Vec<f64>,
1540    beta: Vec<f64>,
1541    beta_pow_p: Vec<f64>,
1542
1543    d_ring: Vec<f64>,
1544    d_head: usize,
1545    d_count: usize,
1546    denom: f64,
1547    x_acc: Vec<f64>,
1548
1549    buffer: Vec<f64>,
1550    ln_buffer: Vec<f64>,
1551    head: usize,
1552    filled: bool,
1553
1554    sqrt_diffs: Vec<f64>,
1555}
1556
/// Natural logarithm of a value the caller guarantees is strictly positive.
/// The precondition is checked in debug builds only.
#[inline(always)]
fn ln_pos(x: f64) -> f64 {
    debug_assert!(x > 0.0);
    f64::ln(x)
}
1562
1563impl NmaStream {
1564    pub fn try_new(params: NmaParams) -> Result<Self, NmaError> {
1565        let period = params.period.unwrap_or(40);
1566        if period == 0 {
1567            return Err(NmaError::InvalidPeriod {
1568                period,
1569                data_len: 0,
1570            });
1571        }
1572
1573        let mut sqrt_diffs = Vec::with_capacity(period);
1574        for i in 0..period {
1575            let s0 = (i as f64).sqrt();
1576            let s1 = ((i + 1) as f64).sqrt();
1577            sqrt_diffs.push(s1 - s0);
1578        }
1579
1580        const GAMMAS: [f64; 4] = [0.25, 1.2, 3.0, 8.0];
1581        let m = if period <= 64 { 3 } else { 4 };
1582        let mut beta = Vec::with_capacity(m);
1583        for g in GAMMAS.iter().take(m) {
1584            beta.push((-g / (period as f64)).exp());
1585        }
1586
1587        let alpha = fit_exp_weights_least_squares(&sqrt_diffs, &beta);
1588
1589        let mut beta_pow_p = Vec::with_capacity(m);
1590        for &b in &beta {
1591            beta_pow_p.push(b.powi(period as i32));
1592        }
1593
1594        Ok(Self {
1595            period,
1596            m,
1597            alpha,
1598            beta,
1599            beta_pow_p,
1600            d_ring: vec![0.0; period],
1601            d_head: 0,
1602            d_count: 0,
1603            denom: 0.0,
1604            x_acc: vec![0.0; m],
1605
1606            buffer: vec![f64::NAN; period + 1],
1607            ln_buffer: vec![f64::NAN; period + 1],
1608            head: 0,
1609            filled: false,
1610
1611            sqrt_diffs,
1612        })
1613    }
1614
1615    #[inline(always)]
1616    pub fn update(&mut self, value: f64) -> Option<f64> {
1617        if !value.is_finite() {
1618            self.reset_state();
1619            return None;
1620        }
1621
1622        let ln_val = ln_pos(value.max(1e-10));
1623
1624        let prev_idx = (self.head + self.period) % (self.period + 1);
1625        let prev_ln = self.ln_buffer[prev_idx];
1626
1627        self.buffer[self.head] = value;
1628        self.ln_buffer[self.head] = ln_val;
1629
1630        self.head = (self.head + 1) % (self.period + 1);
1631        if !self.filled && self.head == 0 {
1632            self.filled = true;
1633        }
1634
1635        if prev_ln.is_nan() {
1636            return None;
1637        }
1638
1639        let d_new = (ln_val - prev_ln).abs();
1640
1641        if self.d_count < self.period {
1642            self.d_ring[self.d_head] = d_new;
1643            self.d_head = (self.d_head + 1) % self.period;
1644            self.d_count += 1;
1645            self.denom += d_new;
1646
1647            for m in 0..self.m {
1648                self.x_acc[m] = self.beta[m] * self.x_acc[m] + d_new;
1649            }
1650        } else {
1651            let d_old = self.d_ring[self.d_head];
1652            self.d_ring[self.d_head] = d_new;
1653            self.d_head = (self.d_head + 1) % self.period;
1654
1655            self.denom += d_new - d_old;
1656
1657            for m in 0..self.m {
1658                self.x_acc[m] = self.beta[m] * self.x_acc[m] + d_new - self.beta_pow_p[m] * d_old;
1659            }
1660        }
1661
1662        if !self.filled {
1663            return None;
1664        }
1665
1666        let mut num = 0.0f64;
1667        for m in 0..self.m {
1668            num = (self.alpha[m] * self.x_acc[m]).mul_add(1.0, num);
1669        }
1670        let ratio = if self.denom == 0.0 {
1671            0.0
1672        } else {
1673            num / self.denom
1674        };
1675
1676        let x0 = self.buffer[self.head];
1677        let x1 = self.buffer[(self.head + 1) % (self.period + 1)];
1678
1679        Some((x1 - x0).mul_add(ratio, x0))
1680    }
1681
1682    #[inline(always)]
1683    fn reset_state(&mut self) {
1684        self.d_head = 0;
1685        self.d_count = 0;
1686        self.denom = 0.0;
1687        for v in &mut self.d_ring {
1688            *v = 0.0;
1689        }
1690        for v in &mut self.x_acc {
1691            *v = 0.0;
1692        }
1693        for v in &mut self.buffer {
1694            *v = f64::NAN;
1695        }
1696        for v in &mut self.ln_buffer {
1697            *v = f64::NAN;
1698        }
1699        self.head = 0;
1700        self.filled = false;
1701    }
1702}
1703
/// Least-squares fit of `w[i] ~ sum_k alpha[k] * beta[k]^i` for `i in 0..w.len()`.
///
/// Solves the normal equations `(A^T A) alpha = A^T w`, where
/// `A[i][k] = beta[k]^i`. `A^T A` uses the closed-form geometric sum, and a
/// `1e-12` ridge is added to its diagonal.
fn fit_exp_weights_least_squares(w: &[f64], beta: &[f64]) -> Vec<f64> {
    const RIDGE: f64 = 1e-12;

    let n = beta.len();
    let samples = w.len();

    let mut gram = vec![0.0f64; n * n];
    for row in 0..n {
        for col in row..n {
            let ratio = beta[row] * beta[col];
            // sum_{i < samples} ratio^i, falling back to `samples` when ratio ~ 1.
            let geo = if (1.0 - ratio).abs() < 1e-15 {
                samples as f64
            } else {
                (1.0 - ratio.powi(samples as i32)) / (1.0 - ratio)
            };
            gram[row * n + col] = geo;
            gram[col * n + row] = geo;
        }
    }

    let mut rhs: Vec<f64> = beta
        .iter()
        .map(|&b| {
            let mut power = 1.0f64;
            w.iter().fold(0.0f64, |acc, &wi| {
                let term = wi * power;
                power *= b;
                acc + term
            })
        })
        .collect();

    for d in 0..n {
        gram[d * n + d] += RIDGE;
    }

    solve_linear_system(&mut gram, &mut rhs, n)
}

/// Solves the dense `n x n` system `a x = b` by Gaussian elimination with
/// partial pivoting.
///
/// `a` is row-major and is overwritten with the eliminated upper triangle; `b`
/// is permuted and reduced alongside it. A pivot with magnitude below `1e-18`
/// is replaced by `1e-18` instead of failing.
fn solve_linear_system(a: &mut [f64], b: &mut [f64], n: usize) -> Vec<f64> {
    for col in 0..n {
        // The first row with the strictly largest |a[r][col]| wins.
        let (pivot_row, _) = ((col + 1)..n).fold(
            (col, a[col * n + col].abs()),
            |(best, best_abs), r| {
                let cand = a[r * n + col].abs();
                if cand > best_abs {
                    (r, cand)
                } else {
                    (best, best_abs)
                }
            },
        );

        if pivot_row != col {
            for j in col..n {
                a.swap(col * n + j, pivot_row * n + j);
            }
            b.swap(col, pivot_row);
        }

        if a[col * n + col].abs() < 1e-18 {
            a[col * n + col] = 1e-18;
        }
        let pivot = a[col * n + col];

        for r in (col + 1)..n {
            let factor = a[r * n + col] / pivot;
            if factor == 0.0 {
                continue;
            }
            for j in col..n {
                a[r * n + j] -= factor * a[col * n + j];
            }
            b[r] -= factor * b[col];
        }
    }

    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let residual = ((i + 1)..n).fold(b[i], |acc, j| acc - a[i * n + j] * x[j]);
        x[i] = residual / a[i * n + i];
    }
    x
}
1785
1786#[cfg(feature = "python")]
1787use crate::utilities::kernel_validation::validate_kernel;
1788#[cfg(feature = "python")]
1789use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1790#[cfg(feature = "python")]
1791use pyo3::exceptions::PyValueError;
1792#[cfg(feature = "python")]
1793use pyo3::prelude::*;
1794#[cfg(feature = "python")]
1795use pyo3::types::PyDict;
1796
// Python `nma(data, period, kernel=None)`: single-series NMA computed with the
// GIL released. Errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "nma")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn nma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = NmaInput::from_slice(
        series,
        NmaParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| nma_with_kernel(&input, kern).map(|out| out.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1821
// Python wrapper exposing `NmaStream` as `NmaStream(period)`.
#[cfg(feature = "python")]
#[pyclass(name = "NmaStream")]
pub struct NmaStreamPy {
    stream: NmaStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl NmaStreamPy {
    // Construction errors (e.g. period 0) surface as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        NmaStream::try_new(NmaParams {
            period: Some(period),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Returns `None` until the stream has enough history.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1845
// Python `nma_batch(data, period_range, kernel=None)`.
// Returns a dict with "values" (rows x len matrix) and "periods" (one per row).
#[cfg(feature = "python")]
#[pyfunction(name = "nma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn nma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = NmaBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // The array starts uninitialised. It relies on `nma_batch_inner_into`
    // filling every element; if that call errors, the array is dropped unread.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            // Map batch selectors to row kernels; anything else passes through.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            nma_batch_inner_into(slice_in, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
1905
// Python `nma_cuda_batch_dev(data_f32, period_range, device_id=0)`: runs the
// period sweep on the GPU and returns the device-resident result plus a dict
// with the "periods" of each row.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "nma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn nma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(NmaDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = NmaBatchRange {
        period: period_range,
    };

    // Launch and synchronise with the GIL released. The context handle is
    // cloned out so it travels with the device buffer.
    let (inner, combos, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaNma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, combos) = cuda
            .nma_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, combos, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    let handle = NmaDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    };
    Ok((handle, meta))
}
1951
// Python `nma_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
// Input is a 2-D array passed to the time-major CUDA kernel: shape[0] is handed
// over as the series length and shape[1] as the number of series.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "nma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn nma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<NmaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let (rows, cols) = {
        let shape = data_tm_f32.shape();
        (shape[0], shape[1])
    };
    let params = NmaParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaNma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .nma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((dev, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(NmaDeviceArrayF32Py {
        inner,
        _ctx: ctx,
        device_id: dev_id,
    })
}
1991
/// Python-visible handle to a 2-D `f32` NMA result buffer living in CUDA device memory.
///
/// Exposed to Python as `ta_indicators.cuda.NmaDeviceArrayF32`. Because it is
/// `unsendable`, PyO3 only allows access from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "NmaDeviceArrayF32", unsendable)]
pub struct NmaDeviceArrayF32Py {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// Held only to keep the CUDA context alive while this handle exists
    /// (presumably the context `inner` was allocated in — confirm in `CudaNma`).
    pub(crate) _ctx: Arc<Context>,
    /// Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: u32,
}
1999
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl NmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Reports a 2-D `float32` array of shape `(rows, cols)` with row-major byte
    /// strides `(cols * 4, 4)`. No `stream` entry is provided; the constructors
    /// visible in this file call `synchronize()` before handing the buffer out.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // Little-endian 4-byte float.
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only flag): the buffer is advertised as writable.
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`; `2` is `kDLCUDA` in DLPack's
    /// `DLDeviceType` enum.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership of the device memory.
    ///
    /// After this call `self.inner` is replaced by an empty 0x0 buffer, so the
    /// handle no longer refers to the original data. A `dl_device` that differs
    /// from this buffer's device raises `ValueError` (cross-device copies are not
    /// implemented). A `dl_device` that cannot be extracted as `(int, int)` is
    /// ignored. `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Either way this is an error; `copy=True` only changes the message.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // Stream hand-off is not performed.
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder so `self` stays valid.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2076
2077pub fn nma_into_slice(dst: &mut [f64], input: &NmaInput, kern: Kernel) -> Result<(), NmaError> {
2078    let (data, period, first, mut ln_values, mut sqrt_diffs, chosen) = nma_prepare(input, kern)?;
2079
2080    if dst.len() != data.len() {
2081        return Err(NmaError::OutputLengthMismatch {
2082            expected: data.len(),
2083            got: dst.len(),
2084        });
2085    }
2086
2087    nma_compute_into(
2088        data,
2089        period,
2090        first,
2091        &mut ln_values,
2092        &mut sqrt_diffs,
2093        chosen,
2094        dst,
2095    );
2096
2097    let warmup_end = first + period;
2098    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
2099    for v in &mut dst[..warmup_end] {
2100        *v = qnan;
2101    }
2102
2103    Ok(())
2104}
2105
2106#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2107pub fn nma_into(input: &NmaInput, out: &mut [f64]) -> Result<(), NmaError> {
2108    nma_into_slice(out, input, Kernel::Auto)
2109}
2110
2111#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2112use serde::{Deserialize, Serialize};
2113#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2114use wasm_bindgen::prelude::*;
2115
/// JS entry point: computes NMA over `data` with the given `period`.
///
/// Returns a new array as long as `data`. Errors are converted to JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = NmaInput::from_slice(
        data,
        NmaParams {
            period: Some(period),
        },
    );
    let mut values = vec![0.0; data.len()];

    match nma_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(()) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2131
/// Configuration object accepted by the JS `nma_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NmaBatchConfig {
    /// `(start, end, step)` of the period sweep, forwarded unchanged to `NmaBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
2137
/// Result object returned to JS by `nma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct NmaBatchJsOutput {
    /// Flattened `rows * cols` output values (presumably one row per entry of `combos`).
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<NmaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of each output row (the input length).
    pub cols: usize,
}
2146
/// JS `nma_batch`: runs a period sweep described by a `NmaBatchConfig` object.
///
/// Returns a serialized `NmaBatchJsOutput` (values plus combos and shape).
/// Config parsing, computation and serialization errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = nma_batch)]
pub fn nma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let NmaBatchConfig { period_range } = serde_wasm_bindgen::from_value::<NmaBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = NmaBatchRange {
        period: period_range,
    };
    let batch = nma_batch_inner(data, &sweep, Kernel::ScalarBatch, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = NmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2170
/// JS entry point: runs the period sweep `(period_start, period_end, period_step)`
/// and returns only the flattened output values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // NOTE(review): this passes `Kernel::Scalar` while `nma_batch_unified_js` passes
    // `Kernel::ScalarBatch`; confirm which one `nma_batch_inner` expects.
    match nma_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2187
/// JS entry point: lists, as `f64`, the period of every combination the sweep expands to.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = match expand_grid(&sweep) {
        Ok(combos) => combos,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let mut periods = Vec::with_capacity(combos.len());
    for combo in &combos {
        periods.push(combo.period.unwrap() as f64);
    }
    Ok(periods)
}
2207
/// JS entry point: returns `[rows, cols]` for a batch over `data_len` points.
///
/// Best-effort: an invalid sweep reports zero rows instead of throwing.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).map(|combos| combos.len()).unwrap_or(0);
    vec![rows, data_len]
}
2222
/// Allocates an uninitialized buffer with room for `len` `f64`s for JS callers.
///
/// Ownership passes to the caller; release it with `nma_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the Vec from freeing the allocation when it goes out of scope.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2231
/// Releases a buffer obtained from `nma_alloc(len)`. A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `nma_alloc(len)`, which allocates a
    // `Vec<f64>` with capacity `len`. The length is 0 so no element is claimed to be
    // initialized (the caller may never have written the buffer). Only the capacity
    // matters for deallocation, and `f64` has no destructor.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2241
/// Raw-pointer JS entry point: computes NMA over `len` values at `in_ptr` into `out_ptr`.
///
/// Both pointers must be valid for `len` `f64`s. The input and output ranges may
/// overlap, including in-place calls. In that case the result is computed into a
/// temporary buffer and then copied out.
///
/// # Errors
/// Null pointers, `period == 0 || period > len` (`"Invalid period"`), or any NMA error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to nma_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = NmaParams {
        period: Some(period),
    };

    // Any overlap of the two ranges counts as aliasing, not only identical start
    // pointers. Holding `&[f64]` and `&mut [f64]` over shared memory is undefined behaviour.
    let out_ro = out_ptr as *const f64;
    let overlaps = in_ptr < out_ro.wrapping_add(len) && out_ro < in_ptr.wrapping_add(len);

    if overlaps {
        let mut temp = alloc_with_nan_prefix(len, 0);
        {
            // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads. This
            // shared borrow ends before the mutable output slice is created.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = NmaInput::from_slice(data, params);
            nma_into_slice(&mut temp, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: the caller guarantees `out_ptr` is valid for `len` writes, and no
        // borrow of the input memory is alive any more.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: both pointers are valid for `len` elements (caller contract) and
        // the ranges are disjoint, so the shared and mutable slices do not alias.
        let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        let input = NmaInput::from_slice(data, params);
        nma_into_slice(out, &input, Kernel::Scalar)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2282
/// Raw-pointer JS entry point: computes the whole period sweep
/// `(period_start, period_end, period_step)` into `out_ptr`.
///
/// `in_ptr` must be valid for `len` reads. `out_ptr` must be valid for `rows * len`
/// writes, where `rows` is the number of parameter combinations. On success the
/// function returns `rows`. Input and output may overlap; the input is then copied
/// first so the computation never reads memory it is writing.
///
/// # Errors
/// Null pointers, an invalid sweep, a `rows * len` overflow, or any NMA error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn nma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to nma_batch_into"));
    }

    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // Checked: on wasm32 a wrapped product would silently shrink the output slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflows usize in nma_batch_into"))?;

    let out_start = out_ptr as *const f64;
    let overlaps =
        in_ptr < out_start.wrapping_add(total) && out_start < in_ptr.wrapping_add(len);

    // SAFETY (both uses below): the caller guarantees `in_ptr` is valid for `len`
    // reads. In the overlapping case the borrow is temporary and ends after the copy.
    let owned_input: Option<Vec<f64>> = if overlaps {
        Some(unsafe { std::slice::from_raw_parts(in_ptr, len) }.to_vec())
    } else {
        None
    };
    let data: &[f64] = match owned_input.as_deref() {
        Some(copy) => copy,
        None => unsafe { std::slice::from_raw_parts(in_ptr, len) },
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes. `data`
    // is either a private copy or lives in memory disjoint from `out`.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };

    nma_batch_inner_into(data, &sweep, Kernel::ScalarBatch, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2316
2317#[cfg(test)]
2318mod tests {
2319    use super::*;
2320    use crate::skip_if_unsupported;
2321    use crate::utilities::data_loader::read_candles_from_csv;
2322
2323    fn check_nma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2324        skip_if_unsupported!(kernel, test_name);
2325        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2326        let candles = read_candles_from_csv(file_path)?;
2327
2328        let default_params = NmaParams { period: None };
2329        let input = NmaInput::from_candles(&candles, "close", default_params);
2330        let output = nma_with_kernel(&input, kernel)?;
2331        assert_eq!(output.values.len(), candles.close.len());
2332
2333        Ok(())
2334    }
2335
2336    #[test]
2337    fn test_nma_into_matches_api() -> Result<(), Box<dyn Error>> {
2338        let n = 256usize;
2339        let mut data = vec![0.0f64; n];
2340        for i in 0..n {
2341            let t = i as f64;
2342            data[i] = 100.0 + 0.1 * t + (t * 0.07).sin();
2343        }
2344
2345        let params = NmaParams::default();
2346        let input = NmaInput::from_slice(&data, params);
2347
2348        let baseline = nma(&input)?.values;
2349
2350        let mut out = vec![0.0f64; n];
2351        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2352        {
2353            nma_into(&input, &mut out)?;
2354        }
2355        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2356        {
2357            nma_into_slice(&mut out, &input, detect_best_kernel())?;
2358        }
2359
2360        assert_eq!(baseline.len(), out.len());
2361
2362        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2363            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
2364        }
2365
2366        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
2367            assert!(
2368                eq_or_both_nan(a, b),
2369                "Mismatch at index {}: baseline={} out={}",
2370                i,
2371                a,
2372                b
2373            );
2374        }
2375
2376        Ok(())
2377    }
2378
2379    fn check_nma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2380        skip_if_unsupported!(kernel, test_name);
2381        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2382        let candles = read_candles_from_csv(file_path)?;
2383        let input = NmaInput::from_candles(&candles, "close", NmaParams::default());
2384        let nma_result = nma_with_kernel(&input, kernel)?;
2385
2386        let expected_last_five_nma = [
2387            64320.486018271724,
2388            64227.95719984426,
2389            64180.9249333126,
2390            63966.35530620797,
2391            64039.04719192334,
2392        ];
2393        let start_index = nma_result.values.len() - 5;
2394        let result_last_five_nma = &nma_result.values[start_index..];
2395        for (i, &value) in result_last_five_nma.iter().enumerate() {
2396            let expected_value = expected_last_five_nma[i];
2397
2398            let tolerance = if test_name.contains("avx512") {
2399                1.0
2400            } else {
2401                1e-3
2402            };
2403            assert!(
2404                (value - expected_value).abs() < tolerance,
2405                "[{}] NMA value mismatch at last-5 index {}: expected {}, got {}",
2406                test_name,
2407                i,
2408                expected_value,
2409                value
2410            );
2411        }
2412        Ok(())
2413    }
2414
2415    fn check_nma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2416        skip_if_unsupported!(kernel, test_name);
2417        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2418        let candles = read_candles_from_csv(file_path)?;
2419        let input = NmaInput::with_default_candles(&candles);
2420        match input.data {
2421            NmaData::Candles { source, .. } => assert_eq!(source, "close"),
2422            _ => panic!("Expected NmaData::Candles"),
2423        }
2424        let output = nma_with_kernel(&input, kernel)?;
2425        assert_eq!(output.values.len(), candles.close.len());
2426
2427        Ok(())
2428    }
2429
2430    fn check_nma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2431        skip_if_unsupported!(kernel, test_name);
2432        let input_data = [10.0, 20.0, 30.0];
2433        let params = NmaParams { period: Some(0) };
2434        let input = NmaInput::from_slice(&input_data, params);
2435        let res = nma_with_kernel(&input, kernel);
2436        assert!(
2437            res.is_err(),
2438            "[{}] NMA should fail with zero period",
2439            test_name
2440        );
2441        Ok(())
2442    }
2443
2444    fn check_nma_period_exceeds_length(
2445        test_name: &str,
2446        kernel: Kernel,
2447    ) -> Result<(), Box<dyn Error>> {
2448        skip_if_unsupported!(kernel, test_name);
2449        let data_small = [10.0, 20.0, 30.0];
2450        let params = NmaParams { period: Some(10) };
2451        let input = NmaInput::from_slice(&data_small, params);
2452        let res = nma_with_kernel(&input, kernel);
2453        assert!(
2454            res.is_err(),
2455            "[{}] NMA should fail with period exceeding length",
2456            test_name
2457        );
2458        Ok(())
2459    }
2460
2461    fn check_nma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2462        skip_if_unsupported!(kernel, test_name);
2463        let single_point = [42.0];
2464        let params = NmaParams { period: Some(40) };
2465        let input = NmaInput::from_slice(&single_point, params);
2466        let res = nma_with_kernel(&input, kernel);
2467        assert!(
2468            res.is_err(),
2469            "[{}] NMA should fail with insufficient data",
2470            test_name
2471        );
2472        Ok(())
2473    }
2474
2475    fn check_nma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2476        skip_if_unsupported!(kernel, test_name);
2477        let empty: [f64; 0] = [];
2478        let input = NmaInput::from_slice(&empty, NmaParams::default());
2479        let res = nma_with_kernel(&input, kernel);
2480        assert!(
2481            matches!(res, Err(NmaError::EmptyInputData)),
2482            "[{}] NMA should fail with empty input error",
2483            test_name
2484        );
2485        Ok(())
2486    }
2487
2488    fn check_nma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2489        skip_if_unsupported!(kernel, test_name);
2490        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2491        let candles = read_candles_from_csv(file_path)?;
2492        let first_params = NmaParams { period: Some(40) };
2493        let first_input = NmaInput::from_candles(&candles, "close", first_params);
2494        let first_result = nma_with_kernel(&first_input, kernel)?;
2495        let second_params = NmaParams { period: Some(20) };
2496        let second_input = NmaInput::from_slice(&first_result.values, second_params);
2497        let second_result = nma_with_kernel(&second_input, kernel)?;
2498        assert_eq!(second_result.values.len(), first_result.values.len());
2499        if second_result.values.len() > 240 {
2500            for i in 240..second_result.values.len() {
2501                assert!(second_result.values[i].is_finite());
2502            }
2503        }
2504        Ok(())
2505    }
2506
2507    fn check_nma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2508        skip_if_unsupported!(kernel, test_name);
2509        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2510        let candles = read_candles_from_csv(file_path)?;
2511        let input = NmaInput::from_candles(&candles, "close", NmaParams { period: Some(40) });
2512        let res = nma_with_kernel(&input, kernel)?;
2513        assert_eq!(res.values.len(), candles.close.len());
2514        if res.values.len() > 240 {
2515            for (i, &val) in res.values[240..].iter().enumerate() {
2516                assert!(
2517                    !val.is_nan(),
2518                    "[{}] Found unexpected NaN at out-index {}",
2519                    test_name,
2520                    240 + i
2521                );
2522            }
2523        }
2524        Ok(())
2525    }
2526
    /// Property test over random finite series and periods in `2..=100`.
    ///
    /// Checks, against the scalar kernel as reference:
    /// - warmup prefix is NaN and later values are finite;
    /// - each output lies between `data[i - period]` and `data[i - period + 1]`;
    /// - constant input reproduces the constant;
    /// - the selected kernel agrees with scalar within a tolerance/ULP bound.
    fn check_nma_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Series length is always at least `period + 1`; all values are finite.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    (period + 1)..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = NmaParams {
                    period: Some(period),
                };
                let input = NmaInput::from_slice(&data, params);

                let result = nma_with_kernel(&input, kernel);
                prop_assert!(result.is_ok(), "NMA computation failed: {:?}", result.err());
                let out = result.unwrap().values;

                // Scalar kernel is the reference implementation for the comparison below.
                let ref_result = nma_with_kernel(&input, Kernel::Scalar);
                prop_assert!(ref_result.is_ok(), "Reference NMA failed");
                let ref_out = ref_result.unwrap().values;

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                // Generated data is always finite, so in practice `first_valid` is 0.
                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup_end = first_valid + period;

                for i in 0..warmup_end.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN at index {} (warmup period), got {}",
                        i,
                        out[i]
                    );
                }

                for i in warmup_end..out.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "Expected finite value at index {} (after warmup), got {}",
                        i,
                        out[i]
                    );
                }

                // Output must stay within the two samples at lags `period - 1` and `period`.
                for i in warmup_end..out.len() {
                    let point1 = data[i - period + 1];
                    let point2 = data[i - period];
                    let min_bound = point1.min(point2);
                    let max_bound = point1.max(point2);

                    let tolerance = if test_name.contains("avx512") {
                        1e-7
                    } else {
                        1e-9
                    };
                    prop_assert!(
                        out[i] >= min_bound - tolerance && out[i] <= max_bound + tolerance,
                        "NMA at index {} = {} not in bounds [{}, {}]",
                        i,
                        out[i],
                        min_bound,
                        max_bound
                    );
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && !data.is_empty() {
                    for i in warmup_end..out.len() {
                        prop_assert!(
                            (out[i] - data[0]).abs() < 1e-9,
                            "Constant data: NMA[{}] = {} should equal {}",
                            i,
                            out[i],
                            data[0]
                        );
                    }
                }

                // NOTE(review): unreachable — the strategy draws `period` from 2..=100.
                if period == 1 {
                    for i in (first_valid + 1)..out.len() {
                        prop_assert!(
                            (out[i] - data[i]).abs() < 1e-6,
                            "Period=1: NMA[{}] = {} should be close to data[{}] = {}",
                            i,
                            out[i],
                            i,
                            data[i]
                        );
                    }
                }

                // Same bound, expressed as an interpolation ratio between point2 and point1.
                for i in warmup_end..out.len() {
                    let point1 = data[i - period + 1];
                    let point2 = data[i - period];

                    if (point1 - point2).abs() > 1e-10 {
                        let implied_ratio = (out[i] - point2) / (point1 - point2);
                        prop_assert!(
                            implied_ratio >= -1e-9 && implied_ratio <= 1.0 + 1e-9,
                            "Invalid interpolation ratio {} at index {} (output={}, p1={}, p2={})",
                            implied_ratio,
                            i,
                            out[i],
                            point1,
                            point2
                        );
                    }
                }

                // Kernel vs scalar agreement: NaN positions must match; finite values
                // must be close in absolute/relative terms or within a ULP budget.
                for i in 0..out.len() {
                    if !out[i].is_finite() || !ref_out[i].is_finite() {
                        prop_assert_eq!(
                            out[i].is_nan(),
                            ref_out[i].is_nan(),
                            "NaN mismatch at index {}",
                            i
                        );
                        continue;
                    }

                    let out_bits = out[i].to_bits();
                    let ref_bits = ref_out[i].to_bits();
                    let ulp_diff = out_bits.abs_diff(ref_bits);

                    if test_name.contains("avx512") {
                        let rel_error = if ref_out[i].abs() > 1e-10 {
                            ((out[i] - ref_out[i]) / ref_out[i]).abs()
                        } else {
                            (out[i] - ref_out[i]).abs()
                        };
                        prop_assert!(
                            rel_error < 1e-7 || ulp_diff <= 75,
                            "Kernel mismatch at index {}: {} vs {} (rel_error: {}, ULP diff: {})",
                            i,
                            out[i],
                            ref_out[i],
                            rel_error,
                            ulp_diff
                        );
                    } else {
                        prop_assert!(
                            (out[i] - ref_out[i]).abs() <= 1e-9 || ulp_diff <= 25,
                            "Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
                            i,
                            out[i],
                            ref_out[i],
                            ulp_diff
                        );
                    }
                }

                // Tiny positive inputs: repeats the finiteness check with a dedicated message.
                let has_small_values = data.iter().any(|&x| x > 0.0 && x < 1e-8);
                if has_small_values {
                    for i in warmup_end..out.len() {
                        prop_assert!(
                            out[i].is_finite(),
                            "NMA failed to handle small values at index {}: {}",
                            i,
                            out[i]
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
2703
2704    macro_rules! generate_all_nma_tests {
2705        ($($test_fn:ident),*) => {
2706            paste::paste! {
2707                $(#[test]
2708                fn [<$test_fn _scalar_f64>]() {
2709                    let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2710                })*
2711                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2712                $(
2713                    #[test]
2714                    fn [<$test_fn _avx2_f64>]() {
2715                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2716                    }
2717                    #[test]
2718                    fn [<$test_fn _avx512_f64>]() {
2719                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2720                    }
2721                )*
2722                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2723                $(
2724                    #[test]
2725                    fn [<$test_fn _simd128_f64>]() {
2726                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2727                    }
2728                )*
2729            }
2730        }
2731    }
2732
2733    #[cfg(debug_assertions)]
2734    fn check_nma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2735        skip_if_unsupported!(kernel, test_name);
2736
2737        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2738        let candles = read_candles_from_csv(file_path)?;
2739
2740        let test_cases = vec![
2741            NmaParams { period: Some(40) },
2742            NmaParams { period: Some(10) },
2743            NmaParams { period: Some(5) },
2744            NmaParams { period: Some(20) },
2745            NmaParams { period: Some(60) },
2746            NmaParams { period: Some(100) },
2747            NmaParams { period: Some(3) },
2748            NmaParams { period: Some(80) },
2749            NmaParams { period: None },
2750        ];
2751
2752        for params in test_cases {
2753            let input = NmaInput::from_candles(&candles, "close", params);
2754            let output = nma_with_kernel(&input, kernel)?;
2755
2756            for (i, &val) in output.values.iter().enumerate() {
2757                if val.is_nan() {
2758                    continue;
2759                }
2760
2761                let bits = val.to_bits();
2762
2763                if bits == 0x11111111_11111111 {
2764                    panic!(
2765                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2766                         with params period={:?}",
2767                        test_name, val, bits, i, params.period
2768                    );
2769                }
2770
2771                if bits == 0x22222222_22222222 {
2772                    panic!(
2773                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2774                         with params period={:?}",
2775                        test_name, val, bits, i, params.period
2776                    );
2777                }
2778
2779                if bits == 0x33333333_33333333 {
2780                    panic!(
2781                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2782                         with params period={:?}",
2783                        test_name, val, bits, i, params.period
2784                    );
2785                }
2786            }
2787        }
2788
2789        Ok(())
2790    }
2791
    // Release-build stub: the poison scan is compiled only under `debug_assertions`.
    // Presumably the allocation helpers only write poison patterns in debug builds;
    // TODO confirm.
    #[cfg(not(debug_assertions))]
    fn check_nma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2796
    // Instantiate per-kernel `#[test]` wrappers for every single-series checker above.
    generate_all_nma_tests!(
        check_nma_partial_params,
        check_nma_accuracy,
        check_nma_default_candles,
        check_nma_zero_period,
        check_nma_period_exceeds_length,
        check_nma_very_small_dataset,
        check_nma_empty_input,
        check_nma_reinput,
        check_nma_nan_handling,
        check_nma_no_poison,
        check_nma_property
    );
2810
2811    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2812        skip_if_unsupported!(kernel, test);
2813
2814        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2815        let c = read_candles_from_csv(file)?;
2816
2817        let output = NmaBatchBuilder::new()
2818            .kernel(kernel)
2819            .apply_candles(&c, "close")?;
2820
2821        let def = NmaParams::default();
2822        let row = output.values_for(&def).expect("default row missing");
2823
2824        assert_eq!(row.len(), c.close.len());
2825
2826        let expected = [
2827            64320.486018271724,
2828            64227.95719984426,
2829            64180.924933312606,
2830            63966.35530620797,
2831            64039.04719192333,
2832        ];
2833        let start = row.len() - 5;
2834        for (i, &v) in row[start..].iter().enumerate() {
2835            let tolerance = 1e-3;
2836            assert!(
2837                (v - expected[i]).abs() < tolerance,
2838                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2839            );
2840        }
2841        Ok(())
2842    }
2843
2844    macro_rules! gen_batch_tests {
2845        ($fn_name:ident) => {
2846            paste::paste! {
2847                #[test] fn [<$fn_name _scalar>]()      {
2848                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2849                }
2850                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2851                #[test] fn [<$fn_name _avx2>]()        {
2852                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2853                }
2854                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2855                #[test] fn [<$fn_name _avx512>]()      {
2856                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2857                }
2858                #[test] fn [<$fn_name _auto_detect>]() {
2859                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2860                }
2861            }
2862        };
2863    }
2864
2865    #[cfg(debug_assertions)]
2866    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2867        skip_if_unsupported!(kernel, test);
2868
2869        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2870        let c = read_candles_from_csv(file)?;
2871
2872        let batch_configs = vec![
2873            (10, 30, 10),
2874            (40, 40, 0),
2875            (3, 15, 3),
2876            (50, 100, 25),
2877            (5, 25, 5),
2878            (20, 80, 20),
2879            (8, 24, 8),
2880            (60, 120, 30),
2881        ];
2882
2883        for (p_start, p_end, p_step) in batch_configs {
2884            let output = NmaBatchBuilder::new()
2885                .kernel(kernel)
2886                .period_range(p_start, p_end, p_step)
2887                .apply_candles(&c, "close")?;
2888
2889            for (idx, &val) in output.values.iter().enumerate() {
2890                if val.is_nan() {
2891                    continue;
2892                }
2893
2894                let bits = val.to_bits();
2895                let row = idx / output.cols;
2896                let col = idx % output.cols;
2897                let combo = &output.combos[row];
2898
2899                if bits == 0x11111111_11111111 {
2900                    panic!(
2901						"[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} \
2902                         (flat index {}) with params period={:?}",
2903						test, val, bits, row, col, idx, combo.period
2904					);
2905                }
2906
2907                if bits == 0x22222222_22222222 {
2908                    panic!(
2909						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} \
2910                         (flat index {}) with params period={:?}",
2911						test, val, bits, row, col, idx, combo.period
2912					);
2913                }
2914
2915                if bits == 0x33333333_33333333 {
2916                    panic!(
2917						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} \
2918                         (flat index {}) with params period={:?}",
2919						test, val, bits, row, col, idx, combo.period
2920					);
2921                }
2922            }
2923        }
2924
2925        Ok(())
2926    }
2927
2928    #[cfg(not(debug_assertions))]
2929    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2930        Ok(())
2931    }
2932
2933    gen_batch_tests!(check_batch_default_row);
2934    gen_batch_tests!(check_batch_no_poison);
2935}