1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaNma, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
11use core::arch::wasm32::*;
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(all(feature = "python", feature = "cuda"))]
15use cust::context::Context;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::prelude::*;
20use std::convert::AsRef;
21use std::error::Error;
22use std::mem::MaybeUninit;
23#[cfg(all(feature = "python", feature = "cuda"))]
24use std::sync::Arc;
25use thiserror::Error;
26
27impl<'a> AsRef<[f64]> for NmaInput<'a> {
28 #[inline(always)]
29 fn as_ref(&self) -> &[f64] {
30 match &self.data {
31 NmaData::Slice(slice) => slice,
32 NmaData::Candles { candles, source } => source_type(candles, source),
33 }
34 }
35}
36
/// Borrowed input series for the NMA indicator: either a candle container
/// plus the name of the source column, or a raw `f64` slice.
#[derive(Debug, Clone)]
pub enum NmaData<'a> {
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    Slice(&'a [f64]),
}
45
/// Result of a single NMA computation. `values` has the same length as the
/// input; entries before `first + period` are left as the NaN warm-up prefix
/// (see `nma_with_kernel`).
#[derive(Debug, Clone)]
pub struct NmaOutput {
    pub values: Vec<f64>,
}
50
/// Parameters for the NMA indicator.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct NmaParams {
    /// Look-back period; `None` falls back to the default of 40.
    pub period: Option<usize>,
}
59
impl Default for NmaParams {
    /// Default look-back period of 40 bars.
    fn default() -> Self {
        Self { period: Some(40) }
    }
}
65
/// Complete input to the NMA entry points: the data source plus parameters.
#[derive(Debug, Clone)]
pub struct NmaInput<'a> {
    pub data: NmaData<'a>,
    pub params: NmaParams,
}
71
72impl<'a> NmaInput<'a> {
73 #[inline]
74 pub fn from_candles(c: &'a Candles, s: &'a str, p: NmaParams) -> Self {
75 Self {
76 data: NmaData::Candles {
77 candles: c,
78 source: s,
79 },
80 params: p,
81 }
82 }
83 #[inline]
84 pub fn from_slice(sl: &'a [f64], p: NmaParams) -> Self {
85 Self {
86 data: NmaData::Slice(sl),
87 params: p,
88 }
89 }
90 #[inline]
91 pub fn with_default_candles(c: &'a Candles) -> Self {
92 Self::from_candles(c, "close", NmaParams::default())
93 }
94 #[inline]
95 pub fn get_period(&self) -> usize {
96 self.params.period.unwrap_or(40)
97 }
98}
99
/// Fluent builder for one-off NMA runs (see `apply` / `apply_slice`).
#[derive(Copy, Clone, Debug)]
pub struct NmaBuilder {
    // Period override; `None` lets the indicator default (40) apply.
    period: Option<usize>,
    // Compute kernel; `Kernel::Auto` picks the best available at run time.
    kernel: Kernel,
}
105
impl Default for NmaBuilder {
    /// No period override and automatic kernel detection.
    fn default() -> Self {
        Self {
            period: None,
            kernel: Kernel::Auto,
        }
    }
}
114
115impl NmaBuilder {
116 #[inline(always)]
117 pub fn new() -> Self {
118 Self::default()
119 }
120 #[inline(always)]
121 pub fn period(mut self, n: usize) -> Self {
122 self.period = Some(n);
123 self
124 }
125 #[inline(always)]
126 pub fn kernel(mut self, k: Kernel) -> Self {
127 self.kernel = k;
128 self
129 }
130 #[inline(always)]
131 pub fn apply(self, c: &Candles) -> Result<NmaOutput, NmaError> {
132 let p = NmaParams {
133 period: self.period,
134 };
135 let i = NmaInput::from_candles(c, "close", p);
136 nma_with_kernel(&i, self.kernel)
137 }
138 #[inline(always)]
139 pub fn apply_slice(self, d: &[f64]) -> Result<NmaOutput, NmaError> {
140 let p = NmaParams {
141 period: self.period,
142 };
143 let i = NmaInput::from_slice(d, p);
144 nma_with_kernel(&i, self.kernel)
145 }
146 #[inline(always)]
147 pub fn into_stream(self) -> Result<NmaStream, NmaError> {
148 let p = NmaParams {
149 period: self.period,
150 };
151 NmaStream::try_new(p)
152 }
153}
154
/// Errors produced by NMA input validation, kernel dispatch and batch
/// range expansion.
#[derive(Debug, Error)]
pub enum NmaError {
    #[error("nma: Input data slice is empty.")]
    EmptyInputData,
    #[error("nma: All values are NaN.")]
    AllValuesNaN,
    #[error("nma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    #[error("nma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    #[error("nma: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    #[error("nma: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    #[error("nma: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
    #[error("nma: invalid input: {0}")]
    InvalidInput(String),
}
178
/// Computes the NMA with automatic kernel selection.
///
/// Shorthand for `nma_with_kernel(input, Kernel::Auto)`.
#[inline]
pub fn nma(input: &NmaInput) -> Result<NmaOutput, NmaError> {
    nma_with_kernel(input, Kernel::Auto)
}
183
/// Validates the input and precomputes state shared by all compute kernels.
///
/// Returns `(data, period, first_valid_index, ln_values, sqrt_diffs, kernel)`.
/// `ln_values` is only populated for the scalar kernel selections (which
/// includes the wasm simd128 path — it dispatches on `Scalar`/`ScalarBatch`);
/// the AVX kernels fill the buffer themselves.
///
/// # Errors
/// `EmptyInputData`, `AllValuesNaN`, `InvalidPeriod` (period 0 or > len),
/// or `NotEnoughValidData` (fewer than `period + 1` samples after the first
/// non-NaN value).
#[inline(always)]
fn nma_prepare<'a>(
    input: &'a NmaInput,
    kernel: Kernel,
) -> Result<(&'a [f64], usize, usize, Vec<f64>, Vec<f64>, Kernel), NmaError> {
    let data: &[f64] = input.as_ref();
    let len = data.len();

    if len == 0 {
        return Err(NmaError::EmptyInputData);
    }

    // Index of the first non-NaN sample; everything before it is warm-up.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(NmaError::AllValuesNaN)?;

    let period = input.get_period();

    if period == 0 || period > len {
        return Err(NmaError::InvalidPeriod {
            period,
            data_len: len,
        });
    }
    if (len - first) < (period + 1) {
        return Err(NmaError::NotEnoughValidData {
            needed: period + 1,
            valid: len - first,
        });
    }

    // Resolve Auto here so downstream dispatch sees a concrete kernel.
    let chosen = match kernel {
        Kernel::Auto => detect_best_kernel(),
        other => other,
    };

    // ln of the series, clamped to 1e-10 so non-positive inputs stay finite.
    // Only filled for scalar selections; AVX kernels recompute in-place.
    let mut ln_values = alloc_with_nan_prefix(len, 0);
    if matches!(chosen, Kernel::Scalar | Kernel::ScalarBatch) {
        for i in 0..len {
            ln_values[i] = data[i].max(1e-10).ln();
        }
    }

    // sqrt_diffs[i] = sqrt(i + 1) - sqrt(i): the NMA weight for lag i.
    let mut sqrt_diffs = vec![0.0; period];
    for i in 0..period {
        let s0 = (i as f64).sqrt();
        let s1 = ((i + 1) as f64).sqrt();
        sqrt_diffs[i] = s1 - s0;
    }

    Ok((data, period, first, ln_values, sqrt_diffs, chosen))
}
237
/// Dispatches one NMA computation to the selected kernel, writing into `out`.
///
/// On wasm32+simd128 the scalar selections are rerouted to the SIMD kernel.
/// On x86_64 with `nightly-avx`, AVX2/AVX512 selections run the intrinsic
/// kernels (AVX512 uses the v2 prefix-sum variant); without that feature
/// they fall back to the scalar kernel. Batch kernel variants are treated
/// identically to their non-batch counterparts here.
fn nma_compute_into(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    kernel: Kernel,
    out: &mut [f64],
) {
    // SAFETY: the intrinsic kernels require the CPU features implied by
    // `kernel`; kernel selection (detect_best_kernel / the caller) is
    // responsible for only choosing kernels the host supports.
    unsafe {
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                nma_simd128(data, period, first, ln_values, sqrt_diffs, out);
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => {
                nma_scalar_with_precomputed(data, period, first, ln_values, sqrt_diffs, out)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                nma_avx2(data, period, first, ln_values, sqrt_diffs, out)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                nma_avx512_v2(data, period, first, ln_values, sqrt_diffs, out)
            }

            // Without nightly AVX support the AVX selections degrade to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                nma_scalar_with_precomputed(data, period, first, ln_values, sqrt_diffs, out)
            }
            // `Auto` was resolved in nma_prepare; no other variant is expected.
            _ => unreachable!(),
        }
    }
}
279
280pub fn nma_with_kernel(input: &NmaInput, kernel: Kernel) -> Result<NmaOutput, NmaError> {
281 let (data, period, first, mut ln_values, mut sqrt_diffs, chosen) = nma_prepare(input, kernel)?;
282
283 let warm = first + period;
284 let mut out = alloc_with_nan_prefix(data.len(), warm);
285
286 nma_compute_into(
287 data,
288 period,
289 first,
290 &mut ln_values,
291 &mut sqrt_diffs,
292 chosen,
293 &mut out,
294 );
295
296 Ok(NmaOutput { values: out })
297}
298#[inline]
299pub fn nma_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
300 let len = data.len();
301
302 let mut ln_values = alloc_with_nan_prefix(len, 0);
303 for i in 0..len {
304 ln_values[i] = data[i].max(1e-10).ln();
305 }
306
307 let mut sqrt_diffs = vec![0.0; period];
308 for i in 0..period {
309 let s0 = (i as f64).sqrt();
310 let s1 = ((i + 1) as f64).sqrt();
311 sqrt_diffs[i] = s1 - s0;
312 }
313
314 for j in (first + period)..len {
315 let mut num = 0.0;
316 let mut denom = 0.0;
317
318 for i in 0..period {
319 let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
320 num += oi * sqrt_diffs[i];
321 denom += oi;
322 }
323
324 let ratio = if denom == 0.0 { 0.0 } else { num / denom };
325
326 let i = period - 1;
327 out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
328 }
329}
330
/// Scalar NMA kernel over precomputed inputs.
///
/// `ln_values` must hold ln(max(x, 1e-10)) of `data`, and `sqrt_diffs[i]`
/// must equal sqrt(i + 1) - sqrt(i). Writes `out[j]` for every
/// `j >= first + period`; earlier entries are left untouched.
#[inline]
pub fn nma_scalar_with_precomputed(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &[f64],
    sqrt_diffs: &[f64],
    out: &mut [f64],
) {
    let len = data.len();

    for j in (first + period)..len {
        let window_start = j - period;

        // Walk the window forward, pairing each |log-return| with its
        // lag-reversed sqrt-difference weight.
        let mut num = 0.0f64;
        let mut denom = 0.0f64;
        let mut prev = ln_values[window_start];
        for (t, cur) in ln_values[window_start + 1..=j].iter().enumerate() {
            let diff = (cur - prev).abs();
            prev = *cur;
            num += diff * sqrt_diffs[period - 1 - t];
            denom += diff;
        }

        let ratio = if denom == 0.0 { 0.0 } else { num / denom };

        // Interpolate between the two oldest samples of the window.
        let x0 = data[window_start];
        let x1 = data[window_start + 1];
        out[j] = (x1 - x0).mul_add(ratio, x0);
    }
}
365
/// wasm32 simd128 NMA kernel: processes two lags per iteration.
///
/// For each output index `j`, lane 0 carries lag `i` and lane 1 carries
/// lag `i + 1`; the matching weights are loaded directly from `sqrt_diffs`.
/// Any odd remainder lag is handled by the scalar tail loop.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn nma_simd128(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &[f64],
    sqrt_diffs: &[f64],
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    const STEP: usize = 2;
    let len = data.len();

    for j in (first + period)..len {
        let chunks = period / STEP;
        let tail = period % STEP;

        let mut num_acc = f64x2_splat(0.0);
        let mut denom_acc = f64x2_splat(0.0);

        for blk in 0..chunks {
            let i = blk * STEP;

            // lane 0 -> lag i, lane 1 -> lag i + 1 (consecutive ln samples).
            let ln_curr_0 = f64x2(ln_values[j - i], ln_values[j - i - 1]);
            let ln_prev_0 = f64x2(ln_values[j - i - 1], ln_values[j - i - 2]);

            let diff = f64x2_sub(ln_curr_0, ln_prev_0);
            let abs_diff = f64x2_abs(diff);

            // wasm v128 loads are alignment-tolerant, so the unaligned load
            // from the middle of `sqrt_diffs` is fine.
            let sqrt_d = v128_load(sqrt_diffs.as_ptr().add(i) as *const v128);

            num_acc = f64x2_add(num_acc, f64x2_mul(abs_diff, sqrt_d));
            denom_acc = f64x2_add(denom_acc, abs_diff);
        }

        // Horizontal reduction of both lane accumulators.
        let mut num = f64x2_extract_lane::<0>(num_acc) + f64x2_extract_lane::<1>(num_acc);
        let mut denom = f64x2_extract_lane::<0>(denom_acc) + f64x2_extract_lane::<1>(denom_acc);

        // Scalar tail for the remaining lag when `period` is odd.
        for i in (chunks * STEP)..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num += oi * sqrt_diffs[i];
            denom += oi;
        }

        let ratio = if denom == 0.0 { 0.0 } else { num / denom };
        // Blend the two oldest samples of the window by the ratio.
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
417
/// High-accuracy vectorized ln(x) for 8 packed doubles.
///
/// Two per-lane paths, blended by mask at the end:
/// - lanes with |x - 1| < 0.2 use an 8-term Taylor expansion of ln(1 + y)
///   around y = x - 1 (better accuracy near 1);
/// - all other lanes decompose x = m * 2^e (with m folded into
///   [sqrt(1/2), sqrt(2))) and evaluate an fdlibm-style log polynomial in
///   s = f / (2 + f), f = m - 1, so ln(x) = e*ln2 + ln1p(f).
///
/// NOTE(review): the bit-level exponent/mantissa extraction assumes x is
/// positive, finite and normal (inputs here are clamped to >= 1e-10) —
/// confirm before reusing elsewhere. Not referenced in this chunk of the
/// file; verify callers before removing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,avx512bw,fma")]
unsafe fn fast_ln_avx512_hi(x: __m512d) -> __m512d {
    let one = _mm512_set1_pd(1.0);
    let two = _mm512_set1_pd(2.0);
    let half = _mm512_set1_pd(0.5);
    let ln2 = _mm512_set1_pd(std::f64::consts::LN_2);
    let sqrt_half = _mm512_set1_pd(0.7071067811865475244);

    // Select which lanes take the Taylor path.
    let threshold = _mm512_set1_pd(0.2);
    let x_minus_1 = _mm512_sub_pd(x, one);
    let abs_x_minus_1 = _mm512_abs_pd(x_minus_1);
    let near_one_mask = _mm512_cmp_pd_mask(abs_x_minus_1, threshold, _CMP_LT_OQ);

    // Taylor coefficients of ln(1 + y): y - y^2/2 + y^3/3 - ... - y^8/8.
    let c2 = _mm512_set1_pd(-0.5);
    let c3 = _mm512_set1_pd(1.0 / 3.0);
    let c4 = _mm512_set1_pd(-0.25);
    let c5 = _mm512_set1_pd(0.2);
    let c6 = _mm512_set1_pd(-1.0 / 6.0);
    let c7 = _mm512_set1_pd(1.0 / 7.0);
    let c8 = _mm512_set1_pd(-0.125);

    let y = x_minus_1;
    let y2 = _mm512_mul_pd(y, y);
    let y3 = _mm512_mul_pd(y2, y);
    let y4 = _mm512_mul_pd(y2, y2);

    let mut taylor = y;
    taylor = _mm512_fmadd_pd(y2, c2, taylor);
    taylor = _mm512_fmadd_pd(y3, c3, taylor);
    taylor = _mm512_fmadd_pd(y4, c4, taylor);
    let y5 = _mm512_mul_pd(y4, y);
    let y6 = _mm512_mul_pd(y4, y2);
    let y7 = _mm512_mul_pd(y4, y3);
    let y8 = _mm512_mul_pd(y4, y4);
    taylor = _mm512_fmadd_pd(y5, c5, taylor);
    taylor = _mm512_fmadd_pd(y6, c6, taylor);
    taylor = _mm512_fmadd_pd(y7, c7, taylor);
    taylor = _mm512_fmadd_pd(y8, c8, taylor);

    // Extract the unbiased binary exponent e from the IEEE-754 bits.
    let ix = _mm512_castpd_si512(x);
    let exp_mask = _mm512_set1_epi64(0x7FF0000000000000u64 as i64);
    let mantissa_mask = _mm512_set1_epi64(0x000FFFFFFFFFFFFFu64 as i64);
    let bias = _mm512_set1_epi64(1023);

    let exp_bits = _mm512_and_si512(ix, exp_mask);
    let exp_shifted = _mm512_srli_epi64::<52>(exp_bits);
    let e = _mm512_sub_epi64(exp_shifted, bias);
    let e_f64 = _mm512_cvtepi64_pd(e);

    // Rebuild the mantissa m in [1, 2) by forcing an exponent of 0.
    let mantissa_bits = _mm512_and_si512(ix, mantissa_mask);
    let one_bits = _mm512_set1_epi64(0x3FF0000000000000u64 as i64);
    let m_bits = _mm512_or_si512(mantissa_bits, one_bits);
    let mut m = _mm512_castsi512_pd(m_bits);

    // Fold m into [sqrt(1/2), sqrt(2)) so f = m - 1 stays small.
    let needs_fold = _mm512_cmp_pd_mask(m, sqrt_half, _CMP_LT_OQ);
    m = _mm512_mask_mul_pd(m, needs_fold, m, two);
    let e_adjust = _mm512_mask_sub_pd(e_f64, needs_fold, e_f64, one);

    let f = _mm512_sub_pd(m, one);

    // s = f / (2 + f); the polynomial is evaluated in z = s^2.
    let two_plus_f = _mm512_add_pd(two, f);
    let s = _mm512_div_pd(f, two_plus_f);
    let z = _mm512_mul_pd(s, s);
    let w = _mm512_mul_pd(z, z);

    // fdlibm log() minimax coefficients (odd/even split below).
    let lg1 = _mm512_set1_pd(6.666666666666735130e-01);
    let lg2 = _mm512_set1_pd(3.999999999940941908e-01);
    let lg3 = _mm512_set1_pd(2.857142874366239149e-01);
    let lg4 = _mm512_set1_pd(2.222219843214978396e-01);
    let lg5 = _mm512_set1_pd(1.818357216161805012e-01);
    let lg6 = _mm512_set1_pd(1.531383769920937332e-01);
    let lg7 = _mm512_set1_pd(1.479819860511658591e-01);

    let lg8 = _mm512_set1_pd(1.333355814642869980e-01);
    let lg9 = _mm512_set1_pd(1.253141636393179328e-01);

    // Odd-power terms, Horner in z.
    let mut r1 = lg9;
    r1 = _mm512_fmadd_pd(r1, z, lg7);
    r1 = _mm512_fmadd_pd(r1, z, lg5);
    r1 = _mm512_fmadd_pd(r1, z, lg3);
    r1 = _mm512_fmadd_pd(r1, z, lg1);
    r1 = _mm512_mul_pd(r1, z);

    // Even-power terms, Horner in z, scaled by w = z^2.
    let mut r2 = lg8;
    r2 = _mm512_fmadd_pd(r2, z, lg6);
    r2 = _mm512_fmadd_pd(r2, z, lg4);
    r2 = _mm512_fmadd_pd(r2, z, lg2);
    r2 = _mm512_mul_pd(r2, w);

    let r = _mm512_add_pd(r1, r2);

    let hfsq = _mm512_mul_pd(_mm512_mul_pd(half, f), f);

    // ln(1 + f) = f - hfsq + s^2 * f * r  (fdlibm identity).
    let ln1pf = _mm512_sub_pd(f, hfsq);
    let s_squared_times_f = _mm512_mul_pd(_mm512_mul_pd(s, s), f);
    let ln1pf = _mm512_fmadd_pd(s_squared_times_f, r, ln1pf);

    let general_result = _mm512_fmadd_pd(e_adjust, ln2, ln1pf);

    // Near-one lanes take the Taylor result, the rest the general path.
    _mm512_mask_blend_pd(near_one_mask, general_result, taylor)
}
521
/// AVX2 NMA kernel.
///
/// Fills `ln_values` with ln(max(x, 1e-10)) — only the clamp is vectorized;
/// ln itself round-trips through a scalar array — then accumulates the
/// weighted/unweighted |log-return| sums four lags at a time with FMA,
/// finishing each window with a scalar tail.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn nma_avx2(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    let len = data.len();

    let epsilon = _mm256_set1_pd(1e-10);

    let one = _mm256_set1_pd(1.0);
    let zero = _mm256_setzero_pd();

    // Pass 1: ln_values[i] = ln(max(data[i], 1e-10)), 4 lanes at a time.
    let mut i = 0;
    while i + 4 <= len {
        let vals = _mm256_loadu_pd(data.as_ptr().add(i));
        let clamped = _mm256_max_pd(vals, epsilon);

        // Spill to a stack array for the scalar ln calls, then reload.
        let mut ln_vals = [0.0f64; 4];
        _mm256_storeu_pd(ln_vals.as_mut_ptr(), clamped);
        for j in 0..4 {
            ln_vals[j] = ln_vals[j].ln();
        }
        let ln_result = _mm256_loadu_pd(ln_vals.as_ptr());

        _mm256_storeu_pd(ln_values.as_mut_ptr().add(i), ln_result);

        i += 4;
    }

    // Scalar remainder of pass 1.
    for j in i..len {
        ln_values[j] = data[j].max(1e-10).ln();
    }

    // Pass 2: one output per index j >= first + period.
    for j in (first + period)..len {
        let mut num_accum = zero;
        let mut denom_accum = zero;

        let mut idx = 0;
        while idx + 4 <= period {
            // |log-return| for lags idx..idx+4, gathered scalar into a vector.
            let mut diffs = [0.0f64; 4];
            for k in 0..4 {
                let i = idx + k;
                let diff = (ln_values[j - i] - ln_values[j - i - 1]).abs();
                diffs[k] = diff;
            }
            let oi_vec = _mm256_loadu_pd(diffs.as_ptr());

            let weights = _mm256_loadu_pd(sqrt_diffs.as_ptr().add(idx));

            num_accum = _mm256_fmadd_pd(oi_vec, weights, num_accum);
            denom_accum = _mm256_add_pd(denom_accum, oi_vec);

            idx += 4;
        }

        let num_scalar = horizontal_sum_avx2(num_accum);
        let denom_scalar = horizontal_sum_avx2(denom_accum);

        let mut num_final = num_scalar;
        let mut denom_final = denom_scalar;

        // Scalar tail for the remaining (period % 4) lags.
        for i in idx..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num_final += oi * sqrt_diffs[i];
            denom_final += oi;
        }

        let ratio = if denom_final == 0.0 {
            0.0
        } else {
            num_final / denom_final
        };
        // Blend the two oldest samples of the window by the ratio.
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
604
/// Horizontal sum of the four lanes of a `__m256d`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256d) -> f64 {
    // Fold 256 -> 128 bits, then 128 -> 64.
    let vlow = _mm256_castpd256_pd128(v);
    let vhigh = _mm256_extractf128_pd(v, 1);

    let sum128 = _mm_add_pd(vlow, vhigh);

    let high64 = _mm_unpackhi_pd(sum128, sum128);

    _mm_cvtsd_f64(_mm_add_sd(sum128, high64))
}
618
/// Vectorized ln(x) for 4 packed doubles (AVX2 variant).
///
/// Exponent/mantissa extraction is done scalar per lane (AVX2 lacks the
/// 64-bit integer conveniences of AVX-512); the mantissa is folded into
/// [sqrt(1/2), sqrt(2)) and an fdlibm-style 7-coefficient log polynomial in
/// s = f / (2 + f) is evaluated, with ln(x) = e*ln2 + ln1p(f).
///
/// NOTE(review): assumes x is positive, finite and normal — the bit
/// manipulation does not handle 0, negatives, subnormals, inf or NaN.
/// Not referenced in this chunk of the file; verify callers before removing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn fast_ln_avx2_hi(x: __m256d) -> __m256d {
    let one = _mm256_set1_pd(1.0);
    let two = _mm256_set1_pd(2.0);
    let half = _mm256_set1_pd(0.5);
    let ln2 = _mm256_set1_pd(std::f64::consts::LN_2);
    let sqrt_half = _mm256_set1_pd(0.7071067811865475244);

    // Scalar frexp-style split per lane: m in [1, 2), unbiased exponent e.
    let mut mantissa = [0.0f64; 4];
    let mut exponent = [0i32; 4];
    _mm256_storeu_pd(mantissa.as_mut_ptr(), x);

    for j in 0..4 {
        let bits = mantissa[j].to_bits();
        let exp_bits = ((bits >> 52) & 0x7FF) as i32;
        exponent[j] = exp_bits - 1023;

        // Clear the exponent field and force it to 0 (biased 1023).
        let mantissa_bits = (bits & !0x7FF0000000000000) | 0x3FF0000000000000;
        mantissa[j] = f64::from_bits(mantissa_bits);
    }

    let mut m = _mm256_loadu_pd(mantissa.as_ptr());
    let e_vals = [
        exponent[0] as f64,
        exponent[1] as f64,
        exponent[2] as f64,
        exponent[3] as f64,
    ];
    let mut e_f64 = _mm256_loadu_pd(e_vals.as_ptr());

    // Fold m into [sqrt(1/2), sqrt(2)) so f = m - 1 stays small.
    let mask = _mm256_cmp_pd(m, sqrt_half, _CMP_LT_OQ);
    m = _mm256_blendv_pd(m, _mm256_mul_pd(m, two), mask);
    e_f64 = _mm256_blendv_pd(e_f64, _mm256_sub_pd(e_f64, one), mask);

    let f = _mm256_sub_pd(m, one);

    // s = f / (2 + f); the polynomial is evaluated in z = s^2.
    let two_plus_f = _mm256_add_pd(two, f);
    let s = _mm256_div_pd(f, two_plus_f);
    let z = _mm256_mul_pd(s, s);
    let w = _mm256_mul_pd(z, z);

    // fdlibm log() minimax coefficients (odd/even split below).
    let lg1 = _mm256_set1_pd(6.666666666666735130e-01);
    let lg2 = _mm256_set1_pd(3.999999999940941908e-01);
    let lg3 = _mm256_set1_pd(2.857142874366239149e-01);
    let lg4 = _mm256_set1_pd(2.222219843214978396e-01);
    let lg5 = _mm256_set1_pd(1.818357216161805012e-01);
    let lg6 = _mm256_set1_pd(1.531383769920937332e-01);
    let lg7 = _mm256_set1_pd(1.479819860511658591e-01);

    // Odd-power terms, Horner in z.
    let mut r1 = lg7;
    r1 = _mm256_fmadd_pd(r1, z, lg5);
    r1 = _mm256_fmadd_pd(r1, z, lg3);
    r1 = _mm256_fmadd_pd(r1, z, lg1);
    r1 = _mm256_mul_pd(r1, z);

    // Even-power terms, Horner in z, scaled by w = z^2.
    let mut r2 = lg6;
    r2 = _mm256_fmadd_pd(r2, z, lg4);
    r2 = _mm256_fmadd_pd(r2, z, lg2);
    r2 = _mm256_mul_pd(r2, w);

    let r = _mm256_add_pd(r1, r2);

    // ln(1 + f) = f - hfsq + f*hfsq*r  (fdlibm-style correction).
    let hfsq = _mm256_mul_pd(_mm256_mul_pd(half, f), f);
    let f_times_hfsq = _mm256_mul_pd(f, hfsq);
    let ln1pf = _mm256_sub_pd(f, hfsq);
    let ln1pf = _mm256_fmadd_pd(f_times_hfsq, r, ln1pf);

    _mm256_fmadd_pd(e_f64, ln2, ln1pf)
}
690
/// Lane-wise |a| for 8 packed doubles: clears each sign bit by and-not
/// with -0.0 (which has only the sign bit set).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn _mm512_abs_pd(a: __m512d) -> __m512d {
    let sign_mask = _mm512_set1_pd(-0.0);
    _mm512_andnot_pd(sign_mask, a)
}
698
/// AVX-512 NMA kernel (v1): 8 lags per iteration with FMA accumulation.
///
/// Fills `ln_values` with ln(max(x, 1e-10)), then for each output index
/// accumulates the weighted and unweighted |log-return| sums 8 lags at a
/// time, reduces horizontally, and finishes with a scalar tail.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,avx512bw,fma")]
pub unsafe fn nma_avx512(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    let len = data.len();

    let zero = _mm512_setzero_pd();

    // Scalar ln pass: accurate and memory-bound anyway.
    for i in 0..len {
        ln_values[i] = data[i].max(1e-10).ln();
    }

    for j in (first + period)..len {
        let mut num_accum = zero;
        let mut denom_accum = zero;

        let mut idx = 0;
        while idx + 8 <= period {
            if j >= idx + 8 {
                // Load 9 consecutive ln values; after the subtraction,
                // abs_diff lane k holds the |log-return| at lag idx + 7 - k.
                let base_ptr = ln_values.as_ptr().add(j - idx - 8);

                let prev = _mm512_loadu_pd(base_ptr);

                let curr = _mm512_loadu_pd(base_ptr.add(1));

                let diff = _mm512_sub_pd(curr, prev);
                let abs_diff = _mm512_abs_pd(diff);

                // Reverse the lanes so lane k corresponds to lag idx + k,
                // matching the forward layout of the `sqrt_diffs` load below.
                // BUG FIX: `_mm512_set_epi64` takes arguments e7..e0, so the
                // reversal permutation is (0, 1, ..., 7); the previous
                // (7, 6, ..., 0) spelled the IDENTITY permutation and paired
                // every diff with the wrong weight.
                let perm_indices = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
                let oi_vec = _mm512_permutexvar_pd(perm_indices, abs_diff);

                let weights = _mm512_loadu_pd(sqrt_diffs.as_ptr().add(idx));

                num_accum = _mm512_fmadd_pd(oi_vec, weights, num_accum);
                denom_accum = _mm512_add_pd(denom_accum, oi_vec);
            } else {
                // Defensive masked fallback for windows that would read
                // before index 0. Unreachable in practice: the loop condition
                // gives idx + 8 <= period and j >= first + period >= period.
                for k in 0..8 {
                    let i = idx + k;
                    let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
                    let weight = sqrt_diffs[i];
                    num_accum = _mm512_mask_add_pd(
                        num_accum,
                        1 << k,
                        num_accum,
                        _mm512_set1_pd(oi * weight),
                    );
                    denom_accum =
                        _mm512_mask_add_pd(denom_accum, 1 << k, denom_accum, _mm512_set1_pd(oi));
                }
            }

            idx += 8;
        }

        let mut num_scalar = _mm512_reduce_add_pd(num_accum);
        let mut denom_scalar = _mm512_reduce_add_pd(denom_accum);

        // Scalar tail for the remaining (period % 8) lags.
        for i in idx..period {
            let oi = (ln_values[j - i] - ln_values[j - i - 1]).abs();
            num_scalar += oi * sqrt_diffs[i];
            denom_scalar += oi;
        }

        let ratio = if denom_scalar == 0.0 {
            0.0
        } else {
            num_scalar / denom_scalar
        };
        // Blend the two oldest samples of the window by the ratio.
        let i = period - 1;
        out[j] = data[j - i] * ratio + data[j - i - 1] * (1.0 - ratio);
    }
}
779
/// AVX-512 NMA kernel, v2: restated as a sliding dot product.
///
/// `ln_values` is repurposed in place: first it holds ln(max(x, 1e-10)),
/// then it is overwritten with the absolute log-return series `d`. Prefix
/// sums `s` of `d` give each window's denominator in O(1); the numerator is
/// a dot product of `d` with the lag-reversed, zero-padded weight table
/// `w_rev`, computed with 8-wide FMAs (2x unrolled, masked tail).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
pub unsafe fn nma_avx512_v2(
    data: &[f64],
    period: usize,
    first: usize,
    ln_values: &mut [f64],
    sqrt_diffs: &mut [f64],
    out: &mut [f64],
) {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let len = data.len();
    debug_assert!(len == ln_values.len());
    debug_assert!(period >= 1 && period <= len);

    // Pass 1: clamped natural log of the series.
    for i in 0..len {
        ln_values[i] = data[i].max(1e-10).ln();
    }

    // Pass 2 (in place): d[i] = |ln[i + 1] - ln[i]|; last slot padded with 0.
    for i in 0..len - 1 {
        ln_values[i] = (ln_values[i + 1] - ln_values[i]).abs();
    }
    ln_values[len - 1] = 0.0;
    let d = ln_values;

    // Prefix sums: s[k + 1] - s[k'] = sum of d over [k', k].
    let mut s = alloc_with_nan_prefix(len + 1, 0);
    s[0] = 0.0;
    for k in 0..len {
        s[k + 1] = s[k] + d[k];
    }

    // Lag-reversed weights, zero-padded to a multiple of 8 for full-width loads.
    let wlen_padded = (period + 7) & !7;
    let mut w_rev = AVec::<f64>::with_capacity(64, wlen_padded);
    w_rev.resize(wlen_padded, 0.0);
    for i in 0..period {
        w_rev[i] = sqrt_diffs[period - 1 - i];
    }

    let warm = first + period;
    let zero = _mm512_setzero_pd();

    for j in warm..len {
        let base = j - period;

        // Denominator: unweighted sum of |log-returns| over the window.
        let denom = s[j] - s[j - period];

        let mut num_acc = zero;
        let mut t = 0usize;

        // 16 elements per iteration (two fused 8-wide FMAs).
        while t + 16 <= period {
            let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
            let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
            let d1 = _mm512_loadu_pd(d.as_ptr().add(base + t + 8));
            let w1 = _mm512_loadu_pd(w_rev.as_ptr().add(t + 8));
            num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
            num_acc = _mm512_fmadd_pd(d1, w1, num_acc);
            t += 16;
        }
        while t + 8 <= period {
            let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
            let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
            num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
            t += 8;
        }
        // Masked tail for the remaining (period % 8) elements.
        if t < period {
            let tail = (period - t) as u32;
            let mask: __mmask8 = ((1u32 << tail) - 1) as u8;
            let d0 = _mm512_maskz_loadu_pd(mask, d.as_ptr().add(base + t));
            let w0 = _mm512_maskz_loadu_pd(mask, w_rev.as_ptr().add(t));
            num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
        }

        let num = _mm512_reduce_add_pd(num_acc);
        let ratio = if denom == 0.0 { 0.0 } else { num / denom };

        // Interpolate between the two oldest samples of the window.
        let i0 = period - 1;
        let x2 = data[j - i0 - 1];
        let dx = data[j - i0] - x2;
        out[j] = ratio.mul_add(dx, x2);
    }
}
864
/// Batch NMA over a parameter sweep with the AVX-512 dot-product kernel.
///
/// Shares one |log-return| series `d` and one prefix-sum table `s` across
/// all rows; each row (one period) only rebuilds its reversed weight table.
/// Output is a row-major `rows x cols` matrix with NaN warm-up prefixes.
///
/// NOTE(review): unlike the scalar batch path, this function does not
/// validate that the data is long enough for the largest period — confirm
/// callers guarantee `len - first >= max_period + 1`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
unsafe fn nma_batch_avx512_optimized(
    data: &[f64],
    sweep: &NmaBatchRange,
    first: usize,
    parallel: bool,
) -> Result<NmaBatchOutput, NmaError> {
    use aligned_vec::AVec;
    use core::arch::x86_64::*;

    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(NmaError::InvalidPeriod {
            period: 0,
            data_len: 0,
        });
    }

    let len = data.len();
    let rows = combos.len();
    let cols = len;

    // Shared pass 1: clamped natural log of the series.
    let mut ln_values = alloc_with_nan_prefix(len, 0);
    for i in 0..len {
        ln_values[i] = data[i].max(1e-10).ln();
    }

    // Shared pass 2 (in place): d[i] = |ln[i + 1] - ln[i]|, last slot 0.
    for i in 0..len - 1 {
        ln_values[i] = (ln_values[i + 1] - ln_values[i]).abs();
    }
    ln_values[len - 1] = 0.0;
    let d = &mut ln_values;

    // Shared prefix sums for O(1) window denominators.
    let mut s = alloc_with_nan_prefix(len + 1, 0);
    s[0] = 0.0;
    for k in 0..len {
        s[k + 1] = s[k] + d[k];
    }

    // Per-row warm-up lengths; init_matrix_prefixes seeds the NaN prefixes.
    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
    let mut raw = make_uninit_matrix(rows, cols);
    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };

    // Computes one output row (one period) into its uninitialized slice.
    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();
        let warm = first + period;

        // SAFETY: the prefix [0, warm) was initialized above and the loop
        // below writes every index in [warm, len).
        let out_row =
            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());

        // Lag-reversed weights, zero-padded to a multiple of 8.
        let wlen_padded = (period + 7) & !7;
        let mut w_rev = AVec::<f64>::with_capacity(64, wlen_padded);
        w_rev.resize(wlen_padded, 0.0);

        for i in 0..period {
            let s0 = ((period - 1 - i) as f64).sqrt();
            let s1 = ((period - i) as f64).sqrt();
            w_rev[i] = s1 - s0;
        }

        let zero = _mm512_setzero_pd();

        for j in warm..len {
            let base = j - period;

            let denom = s[j] - s[j - period];

            let mut num_acc = zero;
            let mut t = 0usize;

            // 16 elements per iteration (two fused 8-wide FMAs).
            while t + 16 <= period {
                let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
                let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
                let d1 = _mm512_loadu_pd(d.as_ptr().add(base + t + 8));
                let w1 = _mm512_loadu_pd(w_rev.as_ptr().add(t + 8));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
                num_acc = _mm512_fmadd_pd(d1, w1, num_acc);
                t += 16;
            }
            while t + 8 <= period {
                let d0 = _mm512_loadu_pd(d.as_ptr().add(base + t));
                let w0 = _mm512_loadu_pd(w_rev.as_ptr().add(t));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
                t += 8;
            }
            // Masked tail for the remaining (period % 8) elements.
            if t < period {
                let tail = (period - t) as u32;
                let mask: __mmask8 = ((1u32 << tail) - 1) as u8;
                let d0 = _mm512_maskz_loadu_pd(mask, d.as_ptr().add(base + t));
                let w0 = _mm512_maskz_loadu_pd(mask, w_rev.as_ptr().add(t));
                num_acc = _mm512_fmadd_pd(d0, w0, num_acc);
            }

            let num = _mm512_reduce_add_pd(num_acc);
            let ratio = if denom == 0.0 { 0.0 } else { num / denom };

            // Interpolate between the two oldest samples of the window.
            let i0 = period - 1;
            let x2 = data[j - i0 - 1];
            let dx = data[j - i0] - x2;
            out_row[j] = ratio.mul_add(dx, x2);
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            use rayon::prelude::*;
            raw.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in raw.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in raw.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // SAFETY: every element was written (NaN prefix + per-row loop), and
    // MaybeUninit<f64> has the same layout as f64.
    let values: Vec<f64> = unsafe { std::mem::transmute(raw) };

    Ok(NmaBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
998
999#[inline(always)]
1000pub fn nma_batch_with_kernel(
1001 data: &[f64],
1002 sweep: &NmaBatchRange,
1003 k: Kernel,
1004) -> Result<NmaBatchOutput, NmaError> {
1005 let kernel = match k {
1006 Kernel::Auto => detect_best_batch_kernel(),
1007 other if other.is_batch() => other,
1008 _ => return Err(NmaError::InvalidKernelForBatch(k)),
1009 };
1010
1011 let simd = match kernel {
1012 Kernel::Avx512Batch => Kernel::Avx512,
1013 Kernel::Avx2Batch => Kernel::Avx2,
1014 Kernel::ScalarBatch => Kernel::Scalar,
1015 _ => Kernel::Scalar,
1016 };
1017 nma_batch_par_slice(data, sweep, simd)
1018}
1019
/// Inclusive parameter sweep for batch NMA: `(start, end, step)` for the
/// period axis. `step == 0` (or `start == end`) denotes a single value
/// (see `expand_grid`).
#[derive(Clone, Debug)]
pub struct NmaBatchRange {
    pub period: (usize, usize, usize),
}
1024
impl Default for NmaBatchRange {
    /// Sweeps periods 40..=289 in steps of 1.
    fn default() -> Self {
        Self {
            period: (40, 289, 1),
        }
    }
}
1032
/// Fluent builder for batch NMA sweeps over a parameter range.
#[derive(Clone, Debug, Default)]
pub struct NmaBatchBuilder {
    range: NmaBatchRange,
    kernel: Kernel,
}
1038
1039impl NmaBatchBuilder {
1040 pub fn new() -> Self {
1041 Self::default()
1042 }
1043 pub fn kernel(mut self, k: Kernel) -> Self {
1044 self.kernel = k;
1045 self
1046 }
1047
1048 #[inline]
1049 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1050 self.range.period = (start, end, step);
1051 self
1052 }
1053 #[inline]
1054 pub fn period_static(mut self, p: usize) -> Self {
1055 self.range.period = (p, p, 0);
1056 self
1057 }
1058
1059 pub fn apply_slice(self, data: &[f64]) -> Result<NmaBatchOutput, NmaError> {
1060 nma_batch_with_kernel(data, &self.range, self.kernel)
1061 }
1062
1063 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<NmaBatchOutput, NmaError> {
1064 NmaBatchBuilder::new().kernel(k).apply_slice(data)
1065 }
1066
1067 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<NmaBatchOutput, NmaError> {
1068 let slice = source_type(c, src);
1069 self.apply_slice(slice)
1070 }
1071
1072 pub fn with_default_candles(c: &Candles) -> Result<NmaBatchOutput, NmaError> {
1073 NmaBatchBuilder::new()
1074 .kernel(Kernel::Auto)
1075 .apply_candles(c, "close")
1076 }
1077}
1078
/// Row-major batch output: `values` is a `rows x cols` matrix where row `r`
/// holds the NMA series for `combos[r]`, and `cols` equals the input length.
#[derive(Clone, Debug)]
pub struct NmaBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<NmaParams>,
    pub rows: usize,
    pub cols: usize,
}
1086
1087impl NmaBatchOutput {
1088 pub fn row_for_params(&self, p: &NmaParams) -> Option<usize> {
1089 self.combos
1090 .iter()
1091 .position(|c| c.period.unwrap_or(40) == p.period.unwrap_or(40))
1092 }
1093
1094 pub fn values_for(&self, p: &NmaParams) -> Option<&[f64]> {
1095 self.row_for_params(p).map(|row| {
1096 let start = row * self.cols;
1097 &self.values[start..start + self.cols]
1098 })
1099 }
1100}
1101
1102#[inline(always)]
1103fn expand_grid(r: &NmaBatchRange) -> Result<Vec<NmaParams>, NmaError> {
1104 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, NmaError> {
1105 if step == 0 || start == end {
1106 return Ok(vec![start]);
1107 }
1108 if start < end {
1109 let mut v = Vec::new();
1110 let mut cur = start;
1111 while cur <= end {
1112 v.push(cur);
1113 cur = cur
1114 .checked_add(step)
1115 .ok_or_else(|| NmaError::InvalidRange { start, end, step })?;
1116 }
1117 if v.is_empty() {
1118 return Err(NmaError::InvalidRange { start, end, step });
1119 }
1120 Ok(v)
1121 } else {
1122 Err(NmaError::InvalidRange { start, end, step })
1123 }
1124 }
1125 let periods = axis_usize(r.period)?;
1126
1127 let mut out = Vec::with_capacity(periods.len());
1128 for &p in &periods {
1129 out.push(NmaParams { period: Some(p) });
1130 }
1131 Ok(out)
1132}
1133
/// Rounds `x` up to the next multiple of 8 (add 7, clear the low three bits).
#[inline]
fn round_up8(x: usize) -> usize {
    const MASK: usize = 0b111;
    (x + MASK) & !MASK
}
1138
1139#[inline(always)]
1140fn nma_batch_inner_into_scalar_reuse(
1141 data: &[f64],
1142 sweep: &NmaBatchRange,
1143 parallel: bool,
1144 out: &mut [f64],
1145) -> Result<Vec<NmaParams>, NmaError> {
1146 let combos = expand_grid(sweep)?;
1147 if combos.is_empty() {
1148 return Err(NmaError::InvalidInput("no parameter combinations".into()));
1149 }
1150
1151 let len = data.len();
1152 let first = data
1153 .iter()
1154 .position(|x| !x.is_nan())
1155 .ok_or(NmaError::AllValuesNaN)?;
1156 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1157 if len - first < max_p + 1 {
1158 return Err(NmaError::NotEnoughValidData {
1159 needed: max_p + 1,
1160 valid: len - first,
1161 });
1162 }
1163
1164 let rows = combos.len();
1165 let cols = len;
1166 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1167 let out_mu = unsafe {
1168 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1169 };
1170 unsafe { init_matrix_prefixes(out_mu, cols, &warm) };
1171
1172 let mut ln = alloc_with_nan_prefix(len, 0);
1173 for i in 0..len {
1174 ln[i] = data[i].max(1e-10).ln();
1175 }
1176 for i in 0..len.saturating_sub(1) {
1177 ln[i] = (ln[i + 1] - ln[i]).abs();
1178 }
1179 ln[len.saturating_sub(1)] = 0.0;
1180 let d = &ln;
1181
1182 let mut s = alloc_with_nan_prefix(len + 1, 0);
1183 s[0] = 0.0;
1184 for i in 0..len {
1185 s[i + 1] = s[i] + d[i];
1186 }
1187
1188 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1189 let p = combos[row].period.unwrap();
1190 let warm = first + p;
1191
1192 let mut w_rev = Vec::with_capacity(p);
1193 for i in 0..p {
1194 let s0 = ((p - 1 - i) as f64).sqrt();
1195 let s1 = ((p - i) as f64).sqrt();
1196 w_rev.push(s1 - s0);
1197 }
1198 let dst = unsafe {
1199 std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1200 };
1201
1202 for j in warm..len {
1203 let base = j - p;
1204 let denom = s[j] - s[j - p];
1205
1206 let mut num = 0.0;
1207
1208 for t in 0..p {
1209 num += d[base + t] * w_rev[t];
1210 }
1211
1212 let ratio = if denom == 0.0 { 0.0 } else { num / denom };
1213 let x2 = data[j - p];
1214 let x1 = data[j - p + 1];
1215 dst[j] = ratio.mul_add(x1 - x2, x2);
1216 }
1217 };
1218
1219 if parallel {
1220 #[cfg(not(target_arch = "wasm32"))]
1221 {
1222 use rayon::prelude::*;
1223 out_mu
1224 .par_chunks_mut(cols)
1225 .enumerate()
1226 .for_each(|(r, row)| do_row(r, row));
1227 }
1228 #[cfg(target_arch = "wasm32")]
1229 for (r, row) in out_mu.chunks_mut(cols).enumerate() {
1230 do_row(r, row);
1231 }
1232 } else {
1233 for (r, row) in out_mu.chunks_mut(cols).enumerate() {
1234 do_row(r, row);
1235 }
1236 }
1237
1238 Ok(combos)
1239}
1240
#[inline(always)]
/// Sequential batch NMA over `data` for every period in `sweep`.
pub fn nma_batch_slice(
    data: &[f64],
    sweep: &NmaBatchRange,
    kern: Kernel,
) -> Result<NmaBatchOutput, NmaError> {
    nma_batch_inner(data, sweep, kern, false)
}
1249
#[inline(always)]
/// Parallel (rayon) batch NMA over `data`; on wasm32 the inner dispatcher
/// degrades to a sequential row loop.
pub fn nma_batch_par_slice(
    data: &[f64],
    sweep: &NmaBatchRange,
    kern: Kernel,
) -> Result<NmaBatchOutput, NmaError> {
    nma_batch_inner(data, sweep, kern, true)
}
1258
#[inline(always)]
/// Core batch dispatcher: expands `sweep` into parameter combos and computes
/// one NMA row per combo into a freshly allocated row-major matrix.
///
/// Rows run sequentially or via rayon depending on `parallel` (wasm32 is
/// always sequential). Warm-up prefixes are NaN-initialized before kernels
/// fill the remainder of each row.
fn nma_batch_inner(
    data: &[f64],
    sweep: &NmaBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<NmaBatchOutput, NmaError> {
    // AVX-512 has a dedicated whole-batch kernel with its own layout/validation.
    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
    if kern == Kernel::Avx512 {
        let first = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or(NmaError::AllValuesNaN)?;
        return unsafe { nma_batch_avx512_optimized(data, sweep, first, parallel) };
    }

    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(NmaError::InvalidInput("no parameter combinations".into()));
    }
    let rows = combos.len();
    let cols = data.len();
    // Reject sweeps whose matrix size would overflow before allocating.
    let _ = rows
        .checked_mul(cols)
        .ok_or_else(|| NmaError::InvalidInput("rows*cols overflow".into()))?;

    // Scalar path: shared-prefix kernel that reuses |Δln| sums across rows.
    if kern == Kernel::Scalar {
        let first = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or(NmaError::AllValuesNaN)?;
        let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
        let mut raw = make_uninit_matrix(rows, cols);
        // SAFETY: NaN-fills each row's warm-up prefix before any reads.
        unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };

        // SAFETY: MaybeUninit<f64> shares f64's layout; the reuse kernel
        // writes every element past each row's warm-up prefix.
        let out: &mut [f64] =
            unsafe { std::slice::from_raw_parts_mut(raw.as_mut_ptr() as *mut f64, raw.len()) };
        let combos = nma_batch_inner_into_scalar_reuse(data, sweep, parallel, out)?;

        // SAFETY: transfer ownership of the now fully initialized buffer to
        // a Vec<f64>; ManuallyDrop prevents a double free of the storage.
        let mut guard = core::mem::ManuallyDrop::new(raw);
        let values = unsafe {
            Vec::from_raw_parts(
                guard.as_mut_ptr() as *mut f64,
                guard.len(),
                guard.capacity(),
            )
        };
        return Ok(NmaBatchOutput {
            values,
            combos,
            rows,
            cols,
        });
    }

    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(NmaError::AllValuesNaN)?;
    // SIMD rows pad the window to a multiple of 8 lanes, so validate against
    // the rounded-up maximum period.
    let max_p = combos
        .iter()
        .map(|c| round_up8(c.period.unwrap()))
        .max()
        .unwrap();
    if data.len() - first < max_p + 1 {
        return Err(NmaError::NotEnoughValidData {
            needed: max_p + 1,
            valid: data.len() - first,
        });
    }

    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();

    let mut raw = make_uninit_matrix(rows, cols);
    // SAFETY: NaN-fill warm-up prefixes before handing rows to the kernels.
    unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };

    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
        let period = combos[row].period.unwrap();

        // SAFETY: each kernel writes every element past the NaN prefix.
        let out_row = unsafe {
            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
        };

        match kern {
            Kernel::Scalar => nma_row_scalar(data, first, period, out_row),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 => unsafe { nma_row_avx2(data, first, period, out_row) },
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 => unsafe { nma_row_avx512(data, first, period, out_row) },
            // Without nightly AVX the SIMD kernel requests fall back to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx512 => nma_row_scalar(data, first, period, out_row),
            _ => nma_row_scalar(data, first, period, out_row),
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            use rayon::prelude::*;
            raw.par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in raw.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in raw.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    // SAFETY: every row is now fully initialized; reinterpret the storage as
    // Vec<f64> without double-dropping (ManuallyDrop disarms the original).
    let mut guard = core::mem::ManuallyDrop::new(raw);
    let values = unsafe {
        Vec::from_raw_parts(
            guard.as_mut_ptr() as *mut f64,
            guard.len(),
            guard.capacity(),
        )
    };

    Ok(NmaBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
1391
1392#[inline(always)]
1393fn nma_batch_inner_into(
1394 data: &[f64],
1395 sweep: &NmaBatchRange,
1396 kern: Kernel,
1397 parallel: bool,
1398 out: &mut [f64],
1399) -> Result<Vec<NmaParams>, NmaError> {
1400 let combos = expand_grid(sweep)?;
1401 if combos.is_empty() {
1402 return Err(NmaError::InvalidInput("no parameter combinations".into()));
1403 }
1404
1405 let first = data
1406 .iter()
1407 .position(|x| !x.is_nan())
1408 .ok_or(NmaError::AllValuesNaN)?;
1409 let max_p = combos
1410 .iter()
1411 .map(|c| round_up8(c.period.unwrap()))
1412 .max()
1413 .unwrap();
1414 if data.len() - first < max_p + 1 {
1415 return Err(NmaError::NotEnoughValidData {
1416 needed: max_p + 1,
1417 valid: data.len() - first,
1418 });
1419 }
1420
1421 let rows = combos.len();
1422 let cols = data.len();
1423
1424 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
1425
1426 let out_uninit = unsafe {
1427 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1428 };
1429
1430 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
1431
1432 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1433 let period = combos[row].period.unwrap();
1434
1435 let out_row = unsafe {
1436 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1437 };
1438
1439 match kern {
1440 Kernel::Scalar => nma_row_scalar(data, first, period, out_row),
1441 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1442 Kernel::Avx2 => unsafe { nma_row_avx2(data, first, period, out_row) },
1443 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1444 Kernel::Avx512 => unsafe { nma_row_avx512(data, first, period, out_row) },
1445 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1446 Kernel::Avx2 | Kernel::Avx512 => nma_row_scalar(data, first, period, out_row),
1447 _ => nma_row_scalar(data, first, period, out_row),
1448 }
1449 };
1450
1451 if parallel {
1452 #[cfg(not(target_arch = "wasm32"))]
1453 {
1454 out_uninit
1455 .par_chunks_mut(cols)
1456 .enumerate()
1457 .for_each(|(row, slice)| do_row(row, slice));
1458 }
1459 #[cfg(target_arch = "wasm32")]
1460 {
1461 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1462 do_row(row, slice);
1463 }
1464 }
1465 } else {
1466 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1467 do_row(row, slice);
1468 }
1469 }
1470
1471 Ok(combos)
1472}
1473
#[inline(always)]
/// Row adapter for the scalar kernel; note `nma_scalar` takes `period`
/// before `first`, the reverse of this signature.
fn nma_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    nma_scalar(data, period, first, out)
}
1478
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Prepares the clamped-ln series and the forward sqrt-difference weights
/// for one period, then delegates to the AVX2 kernel.
unsafe fn nma_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let len = data.len();

    // ln of each input, clamped to 1e-10 so ln never sees a non-positive value.
    let mut ln_values = alloc_with_nan_prefix(len, 0);
    for (dst, &src) in ln_values.iter_mut().zip(data.iter()) {
        *dst = src.max(1e-10).ln();
    }

    // sqrt_diffs[k] = sqrt(k + 1) - sqrt(k).
    let mut sqrt_diffs: Vec<f64> = (0..period)
        .map(|k| ((k + 1) as f64).sqrt() - (k as f64).sqrt())
        .collect();

    nma_avx2(data, period, first, &mut ln_values, &mut sqrt_diffs, out);
}
1499
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Prepares the clamped-ln series and the forward sqrt-difference weights
/// for one period, then delegates to the AVX-512 kernel (v2).
pub unsafe fn nma_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let len = data.len();

    // ln of each input, clamped to 1e-10 so ln never sees a non-positive value.
    let mut ln_values = alloc_with_nan_prefix(len, 0);
    for (dst, &src) in ln_values.iter_mut().zip(data.iter()) {
        *dst = src.max(1e-10).ln();
    }

    // sqrt_diffs[k] = sqrt(k + 1) - sqrt(k).
    let mut sqrt_diffs: Vec<f64> = (0..period)
        .map(|k| ((k + 1) as f64).sqrt() - (k as f64).sqrt())
        .collect();

    nma_avx512_v2(data, period, first, &mut ln_values, &mut sqrt_diffs, out);
}
1520
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Short-period variant; currently an alias for `nma_row_avx512`.
pub unsafe fn nma_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    nma_row_avx512(data, first, period, out)
}
1526
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
/// Long-period variant; currently an alias for `nma_row_avx512`.
pub unsafe fn nma_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    nma_row_avx512(data, first, period, out)
}
1532
/// Streaming NMA with O(m) updates: approximates the sqrt-difference window
/// weights with a small bank of exponential accumulators (see `try_new`).
#[derive(Debug, Clone)]
pub struct NmaStream {
    // Lookback window length.
    period: usize,

    // Number of exponential basis components (3 for period <= 64, else 4).
    m: usize,

    // Least-squares mixture weights over the exponential basis.
    alpha: Vec<f64>,
    // Per-component decay factors exp(-gamma / period).
    beta: Vec<f64>,
    // beta[k]^period, used to retire the sample leaving the window.
    beta_pow_p: Vec<f64>,

    // Ring of the last `period` |Δln| terms.
    d_ring: Vec<f64>,
    d_head: usize,
    d_count: usize,
    // Running sum of d_ring (the ratio denominator).
    denom: f64,
    // One decayed accumulator per basis component (the ratio numerator parts).
    x_acc: Vec<f64>,

    // Rings of the last `period + 1` raw inputs and their clamped lns.
    buffer: Vec<f64>,
    ln_buffer: Vec<f64>,
    // Next write slot in buffer/ln_buffer.
    head: usize,
    // True once the value ring has wrapped at least once (warm-up done).
    filled: bool,

    // Exact target weights sqrt(i + 1) - sqrt(i); consumed by the fit in
    // try_new and retained, but not read by `update`.
    sqrt_diffs: Vec<f64>,
}
1556
#[inline(always)]
/// Natural log of a value asserted (debug builds only) to be strictly positive.
fn ln_pos(x: f64) -> f64 {
    debug_assert!(x > 0.0);
    f64::ln(x)
}
1562
impl NmaStream {
    /// Builds a streaming NMA for `params.period` (default 40).
    ///
    /// Instead of applying the exact sqrt-difference weights each step, the
    /// stream fits `m` exponential decays to those weights by least squares,
    /// making every update O(m) instead of O(period).
    ///
    /// Errors with `InvalidPeriod` when the period is zero.
    pub fn try_new(params: NmaParams) -> Result<Self, NmaError> {
        let period = params.period.unwrap_or(40);
        if period == 0 {
            return Err(NmaError::InvalidPeriod {
                period,
                data_len: 0,
            });
        }

        // Exact target weights: sqrt(i + 1) - sqrt(i) for i in 0..period.
        let mut sqrt_diffs = Vec::with_capacity(period);
        for i in 0..period {
            let s0 = (i as f64).sqrt();
            let s1 = ((i + 1) as f64).sqrt();
            sqrt_diffs.push(s1 - s0);
        }

        // Decay rates for the exponential basis, scaled by the period.
        const GAMMAS: [f64; 4] = [0.25, 1.2, 3.0, 8.0];
        let m = if period <= 64 { 3 } else { 4 };
        let mut beta = Vec::with_capacity(m);
        for g in GAMMAS.iter().take(m) {
            beta.push((-g / (period as f64)).exp());
        }

        // Least-squares mixture weights approximating sqrt_diffs.
        let alpha = fit_exp_weights_least_squares(&sqrt_diffs, &beta);

        // beta^period: residual decay carried by a sample leaving the window.
        let mut beta_pow_p = Vec::with_capacity(m);
        for &b in &beta {
            beta_pow_p.push(b.powi(period as i32));
        }

        Ok(Self {
            period,
            m,
            alpha,
            beta,
            beta_pow_p,
            d_ring: vec![0.0; period],
            d_head: 0,
            d_count: 0,
            denom: 0.0,
            x_acc: vec![0.0; m],

            buffer: vec![f64::NAN; period + 1],
            ln_buffer: vec![f64::NAN; period + 1],
            head: 0,
            filled: false,

            sqrt_diffs,
        })
    }

    /// Feeds one observation; returns `Some(nma)` once the window is warm.
    ///
    /// A non-finite input resets all rolling state so warm-up restarts.
    #[inline(always)]
    pub fn update(&mut self, value: f64) -> Option<f64> {
        if !value.is_finite() {
            self.reset_state();
            return None;
        }

        // Clamp to 1e-10 so ln never sees a non-positive value.
        let ln_val = ln_pos(value.max(1e-10));

        // The slot `period` steps behind head holds the previous sample's ln.
        let prev_idx = (self.head + self.period) % (self.period + 1);
        let prev_ln = self.ln_buffer[prev_idx];

        self.buffer[self.head] = value;
        self.ln_buffer[self.head] = ln_val;

        self.head = (self.head + 1) % (self.period + 1);
        if !self.filled && self.head == 0 {
            self.filled = true;
        }

        // No previous ln yet: still warming up after start/reset.
        if prev_ln.is_nan() {
            return None;
        }

        let d_new = (ln_val - prev_ln).abs();

        if self.d_count < self.period {
            // Window not yet full: only add the new term.
            self.d_ring[self.d_head] = d_new;
            self.d_head = (self.d_head + 1) % self.period;
            self.d_count += 1;
            self.denom += d_new;

            for m in 0..self.m {
                self.x_acc[m] = self.beta[m] * self.x_acc[m] + d_new;
            }
        } else {
            // Window full: slide it by retiring the oldest term.
            let d_old = self.d_ring[self.d_head];
            self.d_ring[self.d_head] = d_new;
            self.d_head = (self.d_head + 1) % self.period;

            self.denom += d_new - d_old;

            // The retiring term has accumulated decay beta^period.
            for m in 0..self.m {
                self.x_acc[m] = self.beta[m] * self.x_acc[m] + d_new - self.beta_pow_p[m] * d_old;
            }
        }

        if !self.filled {
            return None;
        }

        // Weighted numerator approximated by the exponential basis mixture.
        let mut num = 0.0f64;
        for m in 0..self.m {
            num = (self.alpha[m] * self.x_acc[m]).mul_add(1.0, num);
        }
        let ratio = if self.denom == 0.0 {
            0.0
        } else {
            num / self.denom
        };

        // The two oldest raw samples in the window bracket the interpolation.
        let x0 = self.buffer[self.head];
        let x1 = self.buffer[(self.head + 1) % (self.period + 1)];

        Some((x1 - x0).mul_add(ratio, x0))
    }

    /// Clears all rolling state back to the post-construction baseline.
    #[inline(always)]
    fn reset_state(&mut self) {
        self.d_head = 0;
        self.d_count = 0;
        self.denom = 0.0;
        for v in &mut self.d_ring {
            *v = 0.0;
        }
        for v in &mut self.x_acc {
            *v = 0.0;
        }
        for v in &mut self.buffer {
            *v = f64::NAN;
        }
        for v in &mut self.ln_buffer {
            *v = f64::NAN;
        }
        self.head = 0;
        self.filled = false;
    }
}
1703
1704fn fit_exp_weights_least_squares(w: &[f64], beta: &[f64]) -> Vec<f64> {
1705 let p = w.len();
1706 let m = beta.len();
1707
1708 let mut ata = vec![0.0f64; m * m];
1709 for u in 0..m {
1710 for v in u..m {
1711 let r = beta[u] * beta[v];
1712 let s = if (1.0 - r).abs() < 1e-15 {
1713 p as f64
1714 } else {
1715 (1.0 - r.powi(p as i32)) / (1.0 - r)
1716 };
1717 ata[u * m + v] = s;
1718 ata[v * m + u] = s;
1719 }
1720 }
1721
1722 let mut atw = vec![0.0f64; m];
1723 for u in 0..m {
1724 let mut pow = 1.0f64;
1725 let bu = beta[u];
1726 let mut sum = 0.0f64;
1727 for i in 0..p {
1728 sum += w[i] * pow;
1729 pow *= bu;
1730 }
1731 atw[u] = sum;
1732 }
1733
1734 let lambda = 1e-12;
1735 for i in 0..m {
1736 ata[i * m + i] += lambda;
1737 }
1738
1739 solve_linear_system(&mut ata, &mut atw, m)
1740}
1741
/// Solves the dense `n x n` system `a * x = b` by Gaussian elimination with
/// partial pivoting; `a` (row-major) and `b` are clobbered. Near-zero pivots
/// are clamped to 1e-18 rather than reported, so callers should pre-condition
/// `a` (e.g. with a ridge term).
fn solve_linear_system(a: &mut [f64], b: &mut [f64], n: usize) -> Vec<f64> {
    for col in 0..n {
        // Choose the row with the largest magnitude in this column as pivot.
        let mut pivot_row = col;
        let mut best = a[col * n + col].abs();
        for row in (col + 1)..n {
            let mag = a[row * n + col].abs();
            if mag > best {
                best = mag;
                pivot_row = row;
            }
        }
        if pivot_row != col {
            for j in col..n {
                a.swap(col * n + j, pivot_row * n + j);
            }
            b.swap(col, pivot_row);
        }
        // Clamp a vanishing pivot instead of dividing by (near) zero.
        if a[col * n + col].abs() < 1e-18 {
            a[col * n + col] = 1e-18;
        }

        // Eliminate entries below the pivot.
        for row in (col + 1)..n {
            let factor = a[row * n + col] / a[col * n + col];
            if factor != 0.0 {
                for j in col..n {
                    a[row * n + j] -= factor * a[col * n + j];
                }
                b[row] -= factor * b[col];
            }
        }
    }

    // Back-substitution on the upper-triangular system.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let mut sum = b[i];
        for j in (i + 1)..n {
            sum -= a[i * n + j] * x[j];
        }
        x[i] = sum / a[i * n + i];
    }
    x
}
1785
1786#[cfg(feature = "python")]
1787use crate::utilities::kernel_validation::validate_kernel;
1788#[cfg(feature = "python")]
1789use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1790#[cfg(feature = "python")]
1791use pyo3::exceptions::PyValueError;
1792#[cfg(feature = "python")]
1793use pyo3::prelude::*;
1794#[cfg(feature = "python")]
1795use pyo3::types::PyDict;
1796
#[cfg(feature = "python")]
#[pyfunction(name = "nma")]
#[pyo3(signature = (data, period, kernel=None))]
/// Python entry point: computes NMA over a 1-D float64 array and returns a
/// NumPy array of the same length. The GIL is released while computing.
pub fn nma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let input_slice = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = NmaInput::from_slice(
        input_slice,
        NmaParams {
            period: Some(period),
        },
    );

    let values = py
        .allow_threads(|| nma_with_kernel(&input, kern).map(|out| out.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
1821
#[cfg(feature = "python")]
#[pyclass(name = "NmaStream")]
/// Python wrapper around the streaming NMA.
pub struct NmaStreamPy {
    // Wrapped Rust stream holding all rolling state.
    stream: NmaStream,
}
1827
#[cfg(feature = "python")]
#[pymethods]
impl NmaStreamPy {
    /// Creates a streaming NMA with the given period.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let stream = NmaStream::try_new(NmaParams {
            period: Some(period),
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(Self { stream })
    }

    /// Pushes one value; returns the NMA once the window is warm, else None.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1845
#[cfg(feature = "python")]
#[pyfunction(name = "nma_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
/// Python entry point: sweeps `period_range = (start, end, step)` and returns
/// a dict with a `(rows, cols)` values matrix and the matching `periods`.
pub fn nma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = NmaBatchRange {
        period: period_range,
    };

    // Pre-compute the output shape so the NumPy buffer can be allocated
    // uninitialized and filled in place by the batch kernel.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let expected = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: `new(.., false)` yields an uninitialized array; every element
    // is written by nma_batch_inner_into (NaN prefixes + computed values).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [expected], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Resolve Auto, then map batch-kernel variants to row kernels.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };

            nma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1905
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "nma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
/// Python entry point: runs the CUDA batch NMA sweep on `device_id` and
/// returns a device-resident result handle plus a dict with the `periods`.
pub fn nma_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(NmaDeviceArrayF32Py, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in = data_f32.as_slice()?;
    let sweep = NmaBatchRange {
        period: period_range,
    };

    // Run with the GIL released; keep a context clone so the returned device
    // buffer stays valid after `cuda` is dropped.
    let (inner, combos, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaNma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, combos) = cuda
            .nma_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, combos, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    let dict = PyDict::new(py);
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok((
        NmaDeviceArrayF32Py {
            inner,
            _ctx: ctx_arc,
            device_id: dev_id,
        },
        dict,
    ))
}
1951
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "nma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
/// Python entry point: applies one NMA period to many series stored
/// time-major in a 2-D f32 array; returns a device-resident result handle.
pub fn nma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<NmaDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    // NOTE(review): shape[1] (cols) is passed before shape[0] (rows) below —
    // confirm this matches the CUDA helper's (series, length) convention.
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = NmaParams {
        period: Some(period),
    };

    // Run with the GIL released; retain the context so the buffer outlives
    // the local `cuda` handle.
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaNma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let dev = cuda
            .nma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.synchronize()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((dev, cuda.context_arc_clone(), cuda.device_id()))
    })?;

    Ok(NmaDeviceArrayF32Py {
        inner,
        _ctx: ctx_arc,
        device_id: dev_id,
    })
}
1991
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "NmaDeviceArrayF32", unsendable)]
/// Python handle to a device-resident f32 result matrix; keeps the owning
/// CUDA context alive for as long as the buffer is referenced.
pub struct NmaDeviceArrayF32Py {
    // Device buffer plus its (rows, cols) shape.
    pub(crate) inner: DeviceArrayF32,
    // Keeps the CUDA context alive while the buffer exists.
    pub(crate) _ctx: Arc<Context>,
    // CUDA device ordinal the buffer was allocated on.
    pub(crate) device_id: u32,
}
1999
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl NmaDeviceArrayF32Py {
    /// CUDA Array Interface (version 3) descriptor so consumers like CuPy or
    /// Numba can wrap the device buffer without copying.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // "<f4": little-endian float32.
        d.set_item("typestr", "<f4")?;
        // Row-major strides, in bytes.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // (device pointer, read_only = false).
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple: (device type 2 = CUDA, device ordinal).
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, transferring ownership to the
    /// consumer; afterwards this object holds an empty placeholder buffer.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        // Honor an explicit dl_device request only when it matches where the
        // buffer already lives; cross-device copies are not implemented.
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // NOTE(review): the stream argument is ignored — confirm consumers
        // synchronize before reading the exported buffer.
        let _ = stream;

        // Swap the real buffer out for an empty placeholder so the capsule
        // becomes the sole owner of the device memory.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2076
2077pub fn nma_into_slice(dst: &mut [f64], input: &NmaInput, kern: Kernel) -> Result<(), NmaError> {
2078 let (data, period, first, mut ln_values, mut sqrt_diffs, chosen) = nma_prepare(input, kern)?;
2079
2080 if dst.len() != data.len() {
2081 return Err(NmaError::OutputLengthMismatch {
2082 expected: data.len(),
2083 got: dst.len(),
2084 });
2085 }
2086
2087 nma_compute_into(
2088 data,
2089 period,
2090 first,
2091 &mut ln_values,
2092 &mut sqrt_diffs,
2093 chosen,
2094 dst,
2095 );
2096
2097 let warmup_end = first + period;
2098 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
2099 for v in &mut dst[..warmup_end] {
2100 *v = qnan;
2101 }
2102
2103 Ok(())
2104}
2105
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
/// Native convenience wrapper: computes NMA into `out` with `Kernel::Auto`.
pub fn nma_into(input: &NmaInput, out: &mut [f64]) -> Result<(), NmaError> {
    nma_into_slice(out, input, Kernel::Auto)
}
2110
2111#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2112use serde::{Deserialize, Serialize};
2113#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2114use wasm_bindgen::prelude::*;
2115
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: computes NMA for `data` with the runtime-detected best kernel.
pub fn nma_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = NmaInput::from_slice(
        data,
        NmaParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    nma_into_slice(&mut values, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(values)
}
2131
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// JS-facing config for a batch sweep: a `(start, end, step)` period range.
pub struct NmaBatchConfig {
    pub period_range: (usize, usize, usize),
}
2137
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
/// Serializable mirror of `NmaBatchOutput` for returning to JS.
pub struct NmaBatchJsOutput {
    // Row-major rows x cols matrix of NMA values.
    pub values: Vec<f64>,
    // Parameter combination per row.
    pub combos: Vec<NmaParams>,
    pub rows: usize,
    pub cols: usize,
}
2146
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = nma_batch)]
/// JS binding: batch NMA driven by a config object; returns a serialized
/// `{ values, combos, rows, cols }` structure.
pub fn nma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: NmaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = NmaBatchRange {
        period: parsed.period_range,
    };

    let batch = nma_batch_inner(data, &sweep, Kernel::ScalarBatch, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = NmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2170
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: batch NMA returning only the flattened row-major values
/// matrix (use `nma_batch_rows_cols_js` for the shape).
pub fn nma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    nma_batch_inner(data, &sweep, Kernel::Scalar, false)
        .map(|output| output.values)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
2187
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: returns the period of each sweep row as f64s, without
/// computing any values.
pub fn nma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };

    expand_grid(&sweep)
        .map(|combos| {
            combos
                .iter()
                .map(|combo| combo.period.unwrap() as f64)
                .collect()
        })
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
2207
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: returns `[rows, cols]` for a sweep without computing values;
/// an invalid range yields zero rows.
pub fn nma_batch_rows_cols_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    data_len: usize,
) -> Vec<usize> {
    let sweep = NmaBatchRange {
        period: (period_start, period_end, period_step),
    };
    let rows = expand_grid(&sweep).map(|combos| combos.len()).unwrap_or(0);
    vec![rows, data_len]
}
2222
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Allocates an uninitialized f64 buffer of `len` elements for the JS side.
/// Ownership passes to the caller; it must be returned via `nma_free` with
/// the same `len`.
pub fn nma_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    // Intentionally leak: the JS caller owns the buffer until nma_free.
    std::mem::forget(vec);
    ptr
}
2231
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// Releases a buffer previously returned by `nma_alloc` (no-op on null).
///
/// NOTE(review): this rebuilds the Vec with capacity == len, which assumes
/// `Vec::with_capacity(len)` in `nma_alloc` allocated exactly `len` slots —
/// confirm, since a capacity mismatch in `from_raw_parts` is UB.
pub fn nma_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: caller passes the exact pointer/len pair from nma_alloc,
        // exactly once; dropping the Vec frees the allocation.
        unsafe {
            let _ = Vec::from_raw_parts(ptr, len, len);
        }
    }
}
2241
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: computes NMA from a raw input pointer into a raw output
/// pointer of the same length. Aliased buffers (`in_ptr == out_ptr`) are
/// handled via a temporary.
///
/// The caller must guarantee both pointers reference `len` valid f64 slots.
pub fn nma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to nma_into"));
    }

    // SAFETY: the JS caller guarantees both pointers reference `len` valid
    // f64 slots for the duration of this call.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let params = NmaParams {
            period: Some(period),
        };
        let input = NmaInput::from_slice(data, params);

        if in_ptr == out_ptr {
            // Aliased buffers: compute into a scratch buffer so the kernel
            // never reads values it has already overwritten.
            let mut temp = alloc_with_nan_prefix(len, 0);
            nma_into_slice(&mut temp, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            nma_into_slice(out, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
2282
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
/// JS binding: batch NMA written into a caller-allocated buffer of
/// `rows * len` f64 slots; returns the number of rows.
///
/// The caller must guarantee `in_ptr` points to `len` readable f64s and
/// `out_ptr` to `rows * len` writable f64s for the duration of the call.
pub fn nma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to nma_batch_into"));
    }

    // SAFETY: the JS caller guarantees the pointer/length contracts above.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = NmaBatchRange {
            period: (period_start, period_end, period_step),
        };

        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;

        // Guard against rows*cols wrapping (usize is 32-bit on wasm32)
        // before forming the output slice over caller memory.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        nma_batch_inner_into(data, &sweep, Kernel::ScalarBatch, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}
2316
2317#[cfg(test)]
2318mod tests {
2319 use super::*;
2320 use crate::skip_if_unsupported;
2321 use crate::utilities::data_loader::read_candles_from_csv;
2322
2323 fn check_nma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2324 skip_if_unsupported!(kernel, test_name);
2325 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2326 let candles = read_candles_from_csv(file_path)?;
2327
2328 let default_params = NmaParams { period: None };
2329 let input = NmaInput::from_candles(&candles, "close", default_params);
2330 let output = nma_with_kernel(&input, kernel)?;
2331 assert_eq!(output.values.len(), candles.close.len());
2332
2333 Ok(())
2334 }
2335
    #[test]
    /// Verifies the zero-allocation `nma_into` path matches the allocating
    /// `nma` API exactly (NaN-for-NaN, near bit-exact elsewhere).
    fn test_nma_into_matches_api() -> Result<(), Box<dyn Error>> {
        // Synthetic input: linear trend plus a small sine wobble.
        let n = 256usize;
        let mut data = vec![0.0f64; n];
        for i in 0..n {
            let t = i as f64;
            data[i] = 100.0 + 0.1 * t + (t * 0.07).sin();
        }

        let params = NmaParams::default();
        let input = NmaInput::from_slice(&data, params);

        let baseline = nma(&input)?.values;

        let mut out = vec![0.0f64; n];
        // On native builds `nma_into` is the slice-based public API; on wasm
        // that name is the raw-pointer export, so call the slice-level helper
        // there instead.
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            nma_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            nma_into_slice(&mut out, &input, detect_best_kernel())?;
        }

        assert_eq!(baseline.len(), out.len());

        // Two NaNs compare equal; finite values must agree to 1e-12.
        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }

        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
            assert!(
                eq_or_both_nan(a, b),
                "Mismatch at index {}: baseline={} out={}",
                i,
                a,
                b
            );
        }

        Ok(())
    }
2378
2379 fn check_nma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2380 skip_if_unsupported!(kernel, test_name);
2381 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2382 let candles = read_candles_from_csv(file_path)?;
2383 let input = NmaInput::from_candles(&candles, "close", NmaParams::default());
2384 let nma_result = nma_with_kernel(&input, kernel)?;
2385
2386 let expected_last_five_nma = [
2387 64320.486018271724,
2388 64227.95719984426,
2389 64180.9249333126,
2390 63966.35530620797,
2391 64039.04719192334,
2392 ];
2393 let start_index = nma_result.values.len() - 5;
2394 let result_last_five_nma = &nma_result.values[start_index..];
2395 for (i, &value) in result_last_five_nma.iter().enumerate() {
2396 let expected_value = expected_last_five_nma[i];
2397
2398 let tolerance = if test_name.contains("avx512") {
2399 1.0
2400 } else {
2401 1e-3
2402 };
2403 assert!(
2404 (value - expected_value).abs() < tolerance,
2405 "[{}] NMA value mismatch at last-5 index {}: expected {}, got {}",
2406 test_name,
2407 i,
2408 expected_value,
2409 value
2410 );
2411 }
2412 Ok(())
2413 }
2414
2415 fn check_nma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2416 skip_if_unsupported!(kernel, test_name);
2417 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2418 let candles = read_candles_from_csv(file_path)?;
2419 let input = NmaInput::with_default_candles(&candles);
2420 match input.data {
2421 NmaData::Candles { source, .. } => assert_eq!(source, "close"),
2422 _ => panic!("Expected NmaData::Candles"),
2423 }
2424 let output = nma_with_kernel(&input, kernel)?;
2425 assert_eq!(output.values.len(), candles.close.len());
2426
2427 Ok(())
2428 }
2429
2430 fn check_nma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2431 skip_if_unsupported!(kernel, test_name);
2432 let input_data = [10.0, 20.0, 30.0];
2433 let params = NmaParams { period: Some(0) };
2434 let input = NmaInput::from_slice(&input_data, params);
2435 let res = nma_with_kernel(&input, kernel);
2436 assert!(
2437 res.is_err(),
2438 "[{}] NMA should fail with zero period",
2439 test_name
2440 );
2441 Ok(())
2442 }
2443
2444 fn check_nma_period_exceeds_length(
2445 test_name: &str,
2446 kernel: Kernel,
2447 ) -> Result<(), Box<dyn Error>> {
2448 skip_if_unsupported!(kernel, test_name);
2449 let data_small = [10.0, 20.0, 30.0];
2450 let params = NmaParams { period: Some(10) };
2451 let input = NmaInput::from_slice(&data_small, params);
2452 let res = nma_with_kernel(&input, kernel);
2453 assert!(
2454 res.is_err(),
2455 "[{}] NMA should fail with period exceeding length",
2456 test_name
2457 );
2458 Ok(())
2459 }
2460
2461 fn check_nma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2462 skip_if_unsupported!(kernel, test_name);
2463 let single_point = [42.0];
2464 let params = NmaParams { period: Some(40) };
2465 let input = NmaInput::from_slice(&single_point, params);
2466 let res = nma_with_kernel(&input, kernel);
2467 assert!(
2468 res.is_err(),
2469 "[{}] NMA should fail with insufficient data",
2470 test_name
2471 );
2472 Ok(())
2473 }
2474
2475 fn check_nma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2476 skip_if_unsupported!(kernel, test_name);
2477 let empty: [f64; 0] = [];
2478 let input = NmaInput::from_slice(&empty, NmaParams::default());
2479 let res = nma_with_kernel(&input, kernel);
2480 assert!(
2481 matches!(res, Err(NmaError::EmptyInputData)),
2482 "[{}] NMA should fail with empty input error",
2483 test_name
2484 );
2485 Ok(())
2486 }
2487
2488 fn check_nma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2489 skip_if_unsupported!(kernel, test_name);
2490 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2491 let candles = read_candles_from_csv(file_path)?;
2492 let first_params = NmaParams { period: Some(40) };
2493 let first_input = NmaInput::from_candles(&candles, "close", first_params);
2494 let first_result = nma_with_kernel(&first_input, kernel)?;
2495 let second_params = NmaParams { period: Some(20) };
2496 let second_input = NmaInput::from_slice(&first_result.values, second_params);
2497 let second_result = nma_with_kernel(&second_input, kernel)?;
2498 assert_eq!(second_result.values.len(), first_result.values.len());
2499 if second_result.values.len() > 240 {
2500 for i in 240..second_result.values.len() {
2501 assert!(second_result.values[i].is_finite());
2502 }
2503 }
2504 Ok(())
2505 }
2506
2507 fn check_nma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2508 skip_if_unsupported!(kernel, test_name);
2509 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2510 let candles = read_candles_from_csv(file_path)?;
2511 let input = NmaInput::from_candles(&candles, "close", NmaParams { period: Some(40) });
2512 let res = nma_with_kernel(&input, kernel)?;
2513 assert_eq!(res.values.len(), candles.close.len());
2514 if res.values.len() > 240 {
2515 for (i, &val) in res.values[240..].iter().enumerate() {
2516 assert!(
2517 !val.is_nan(),
2518 "[{}] Found unexpected NaN at out-index {}",
2519 test_name,
2520 240 + i
2521 );
2522 }
2523 }
2524 Ok(())
2525 }
2526
    /// Property-based cross-kernel test for NMA.
    ///
    /// For random finite data and a random period in `2..=100`, checks: NaN
    /// warmup prefix, finite values afterwards, output bracketed by the two
    /// inputs `period` apart, constant-data and period=1 special cases, a
    /// valid interpolation ratio, and agreement with the scalar reference
    /// kernel (looser ULP/relative tolerance for avx512 variants).
    fn check_nma_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Data length is always at least period + 1, so warmup can complete.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    (period + 1)..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = NmaParams {
                    period: Some(period),
                };
                let input = NmaInput::from_slice(&data, params);

                let result = nma_with_kernel(&input, kernel);
                prop_assert!(result.is_ok(), "NMA computation failed: {:?}", result.err());
                let out = result.unwrap().values;

                // The scalar kernel is the reference implementation.
                let ref_result = nma_with_kernel(&input, Kernel::Scalar);
                prop_assert!(ref_result.is_ok(), "Reference NMA failed");
                let ref_out = ref_result.unwrap().values;

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup_end = first_valid + period;

                // Property 1: the warmup prefix is all NaN.
                for i in 0..warmup_end.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN at index {} (warmup period), got {}",
                        i,
                        out[i]
                    );
                }

                // Property 2: every post-warmup value is finite.
                for i in warmup_end..out.len() {
                    prop_assert!(
                        out[i].is_finite(),
                        "Expected finite value at index {} (after warmup), got {}",
                        i,
                        out[i]
                    );
                }

                // Property 3: the output lies between the two bracketing
                // inputs `period` apart (up to a kernel-dependent tolerance).
                for i in warmup_end..out.len() {
                    let point1 = data[i - period + 1];
                    let point2 = data[i - period];
                    let min_bound = point1.min(point2);
                    let max_bound = point1.max(point2);

                    let tolerance = if test_name.contains("avx512") {
                        1e-7
                    } else {
                        1e-9
                    };
                    prop_assert!(
                        out[i] >= min_bound - tolerance && out[i] <= max_bound + tolerance,
                        "NMA at index {} = {} not in bounds [{}, {}]",
                        i,
                        out[i],
                        min_bound,
                        max_bound
                    );
                }

                // Property 4: constant input reproduces the constant.
                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && !data.is_empty() {
                    for i in warmup_end..out.len() {
                        prop_assert!(
                            (out[i] - data[0]).abs() < 1e-9,
                            "Constant data: NMA[{}] = {} should equal {}",
                            i,
                            out[i],
                            data[0]
                        );
                    }
                }

                // Property 5: period 1 degenerates to (near) identity.
                if period == 1 {
                    for i in (first_valid + 1)..out.len() {
                        prop_assert!(
                            (out[i] - data[i]).abs() < 1e-6,
                            "Period=1: NMA[{}] = {} should be close to data[{}] = {}",
                            i,
                            out[i],
                            i,
                            data[i]
                        );
                    }
                }

                // Property 6: the implied interpolation ratio between the two
                // bracketing points stays within [0, 1] (plus slack).
                for i in warmup_end..out.len() {
                    let point1 = data[i - period + 1];
                    let point2 = data[i - period];

                    if (point1 - point2).abs() > 1e-10 {
                        let implied_ratio = (out[i] - point2) / (point1 - point2);
                        prop_assert!(
                            implied_ratio >= -1e-9 && implied_ratio <= 1.0 + 1e-9,
                            "Invalid interpolation ratio {} at index {} (output={}, p1={}, p2={})",
                            implied_ratio,
                            i,
                            out[i],
                            point1,
                            point2
                        );
                    }
                }

                // Property 7: kernel under test agrees with the scalar
                // reference, NaN-for-NaN, within a ULP/relative budget.
                for i in 0..out.len() {
                    if !out[i].is_finite() || !ref_out[i].is_finite() {
                        prop_assert_eq!(
                            out[i].is_nan(),
                            ref_out[i].is_nan(),
                            "NaN mismatch at index {}",
                            i
                        );
                        continue;
                    }

                    let out_bits = out[i].to_bits();
                    let ref_bits = ref_out[i].to_bits();
                    let ulp_diff = out_bits.abs_diff(ref_bits);

                    if test_name.contains("avx512") {
                        let rel_error = if ref_out[i].abs() > 1e-10 {
                            ((out[i] - ref_out[i]) / ref_out[i]).abs()
                        } else {
                            (out[i] - ref_out[i]).abs()
                        };
                        prop_assert!(
                            rel_error < 1e-7 || ulp_diff <= 75,
                            "Kernel mismatch at index {}: {} vs {} (rel_error: {}, ULP diff: {})",
                            i,
                            out[i],
                            ref_out[i],
                            rel_error,
                            ulp_diff
                        );
                    } else {
                        prop_assert!(
                            (out[i] - ref_out[i]).abs() <= 1e-9 || ulp_diff <= 25,
                            "Kernel mismatch at index {}: {} vs {} (ULP diff: {})",
                            i,
                            out[i],
                            ref_out[i],
                            ulp_diff
                        );
                    }
                }

                // Property 8: tiny positive magnitudes must not break the
                // computation.
                let has_small_values = data.iter().any(|&x| x > 0.0 && x < 1e-8);
                if has_small_values {
                    for i in warmup_end..out.len() {
                        prop_assert!(
                            out[i].is_finite(),
                            "NMA failed to handle small values at index {}: {}",
                            i,
                            out[i]
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
2703
2704 macro_rules! generate_all_nma_tests {
2705 ($($test_fn:ident),*) => {
2706 paste::paste! {
2707 $(#[test]
2708 fn [<$test_fn _scalar_f64>]() {
2709 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2710 })*
2711 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2712 $(
2713 #[test]
2714 fn [<$test_fn _avx2_f64>]() {
2715 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2716 }
2717 #[test]
2718 fn [<$test_fn _avx512_f64>]() {
2719 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2720 }
2721 )*
2722 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2723 $(
2724 #[test]
2725 fn [<$test_fn _simd128_f64>]() {
2726 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2727 }
2728 )*
2729 }
2730 }
2731 }
2732
    #[cfg(debug_assertions)]
    /// Debug-only scan for uninitialised-memory sentinels in NMA output.
    ///
    /// The helpers named in the panic messages fill fresh buffers with
    /// distinctive bit patterns; any such pattern surviving into the final
    /// values means a cell was never written by the kernel.
    fn check_nma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        // A spread of periods (plus the None default) to vary warmup lengths.
        let test_cases = vec![
            NmaParams { period: Some(40) },
            NmaParams { period: Some(10) },
            NmaParams { period: Some(5) },
            NmaParams { period: Some(20) },
            NmaParams { period: Some(60) },
            NmaParams { period: Some(100) },
            NmaParams { period: Some(3) },
            NmaParams { period: Some(80) },
            NmaParams { period: None },
        ];

        for params in test_cases {
            let input = NmaInput::from_candles(&candles, "close", params);
            let output = nma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                // NaN cells are legitimate (warmup prefix) — skip them.
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params period={:?}",
                        test_name, val, bits, i, params.period
                    );
                }
            }
        }

        Ok(())
    }
2791
    #[cfg(not(debug_assertions))]
    // Release counterpart: the poison scan is debug-only, so this stub keeps
    // the generated test compiling and always passes.
    fn check_nma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2796
    // Instantiate the full per-kernel test matrix for every check function.
    generate_all_nma_tests!(
        check_nma_partial_params,
        check_nma_accuracy,
        check_nma_default_candles,
        check_nma_zero_period,
        check_nma_period_exceeds_length,
        check_nma_very_small_dataset,
        check_nma_empty_input,
        check_nma_reinput,
        check_nma_nan_handling,
        check_nma_no_poison,
        check_nma_property
    );
2810
2811 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2812 skip_if_unsupported!(kernel, test);
2813
2814 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2815 let c = read_candles_from_csv(file)?;
2816
2817 let output = NmaBatchBuilder::new()
2818 .kernel(kernel)
2819 .apply_candles(&c, "close")?;
2820
2821 let def = NmaParams::default();
2822 let row = output.values_for(&def).expect("default row missing");
2823
2824 assert_eq!(row.len(), c.close.len());
2825
2826 let expected = [
2827 64320.486018271724,
2828 64227.95719984426,
2829 64180.924933312606,
2830 63966.35530620797,
2831 64039.04719192333,
2832 ];
2833 let start = row.len() - 5;
2834 for (i, &v) in row[start..].iter().enumerate() {
2835 let tolerance = 1e-3;
2836 assert!(
2837 (v - expected[i]).abs() < tolerance,
2838 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2839 );
2840 }
2841 Ok(())
2842 }
2843
    /// Generates batch-kernel `#[test]` wrappers for one check function:
    /// scalar, auto-detect, and (behind `nightly-avx` on x86_64) AVX2/AVX512.
    /// Each gated test carries its own `#[cfg]` attribute.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
2864
    #[cfg(debug_assertions)]
    /// Debug-only scan of batch output for uninitialised-memory sentinels,
    /// across several period sweeps. Same sentinel patterns as
    /// `check_nma_no_poison`, reported with row/column context.
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        // (start, end, step) period sweeps, including a degenerate
        // single-period sweep with step 0.
        let batch_configs = vec![
            (10, 30, 10),
            (40, 40, 0),
            (3, 15, 3),
            (50, 100, 25),
            (5, 25, 5),
            (20, 80, 20),
            (8, 24, 8),
            (60, 120, 30),
        ];

        for (p_start, p_end, p_step) in batch_configs {
            let output = NmaBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                // NaN cells are legitimate (warmup prefix) — skip them.
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                // Recover the (row, col) position from the flat index.
                let row = idx / output.cols;
                let col = idx % output.cols;
                let combo = &output.combos[row];

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} \
                         (flat index {}) with params period={:?}",
                        test, val, bits, row, col, idx, combo.period
                    );
                }
            }
        }

        Ok(())
    }
2927
    #[cfg(not(debug_assertions))]
    // Release counterpart: the poison scan is debug-only, so this stub keeps
    // the generated test compiling and always passes.
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2932
    // Instantiate batch-kernel test variants for the two batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2935}